Skip to main content

proxmox_api/clients/
reqwest.rs

1use reqwest::{Method, RequestBuilder, StatusCode};
2use serde::{Deserialize, Serialize, de::DeserializeOwned};
3
4use super::base_access::{AuthState, Ticket, TicketResponse};
5
/// Errors produced by the `reqwest`-backed Proxmox API [`Client`].
#[derive(Debug)]
pub enum Error {
    /// Error reported by the underlying `reqwest` client (sending the request or reading
    /// the response body).
    Reqwest(reqwest::Error),
    /// The response carried an `errors` value instead of a `data` payload.
    EncounteredErrors(serde_json::Value),
    /// The response body was not valid UTF-8.
    ResponseWasNotString,
    /// The response body could not be decoded; holds the raw body text and the
    /// `serde_json` error.
    DecodingFailed(String, serde_json::Error),
    /// The request body could not be URL-encoded; holds the encoder's error message.
    UrlEncodingFailed(String),
    /// The request failed in a way not covered by the other variants; holds the HTTP
    /// status and, if available, a message extracted from the response body.
    UnknownFailure(StatusCode, Option<String>),
    /// A fixed error message, e.g. a field missing from the login response.
    Other(&'static str),
}
16
17impl std::fmt::Display for Error {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        match self {
20            Error::Reqwest(e) => write!(f, "{e}"),
21            Error::EncounteredErrors(v) => write!(f, "Proxmox returned errors: {v}"),
22            Error::ResponseWasNotString => write!(f, "response body was not valid UTF-8"),
23            Error::DecodingFailed(text, e) => {
24                write!(f, "failed to decode response: {e}; body: {text}")
25            }
26            Error::UrlEncodingFailed(msg) => write!(f, "failed to URL-encode request body: {msg}"),
27            Error::UnknownFailure(status, body) => {
28                write!(f, "HTTP {status}")?;
29                if let Some(body) = body {
30                    write!(f, ": {body}")?;
31                }
32                Ok(())
33            }
34            Error::Other(msg) => write!(f, "{msg}"),
35        }
36    }
37}
38
39impl std::error::Error for Error {
40    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41        match self {
42            Error::Reqwest(e) => Some(e),
43            _ => None,
44        }
45    }
46}
47
48fn extract_message(body: &str) -> String {
49    serde_json::from_str::<serde_json::Value>(body)
50        .ok()
51        .and_then(|v| v.get("message").and_then(|m| m.as_str().map(String::from)))
52        .unwrap_or_else(|| body.to_string())
53}
54
55impl From<reqwest::Error> for Error {
56    fn from(value: reqwest::Error) -> Self {
57        Self::Reqwest(value)
58    }
59}
60
/// Proxmox API client backed by `reqwest`; requests go to `{host}/api2/json{path}`.
///
/// Authenticate it with [`Client::with_api_token`], [`Client::with_login`] or
/// [`Client::with_ticket`]; the resulting authentication headers are attached to every
/// request.
#[derive(Debug, Clone)]
pub struct Client {
    /// Underlying HTTP client used for all requests.
    client: reqwest::Client,
    /// Base URL that `/api2/json` and the request path are appended to.
    host: String,

    /// User name used for login and API-token authentication.
    user: String,
    /// Realm of `user`, used together with it for login and API-token authentication.
    realm: String,

    // NOTE(review): this is updated through `&self` (e.g. in `with_api_token`/`login`), so
    // `AuthState` presumably uses interior mutability. Whether clones of `Client` share
    // auth state depends on `AuthState`'s `Clone` impl — confirm.
    auth_state: AuthState,
}
71
72impl Client {
73    fn client() -> reqwest::Client {
74        reqwest::ClientBuilder::new()
75            .danger_accept_invalid_certs(true)
76            .build()
77            .expect("failed to build HTTP client")
78    }
79
80    pub fn new(host: &str, user: &str, realm: &str, client: Option<reqwest::Client>) -> Self {
81        Self {
82            client: client.unwrap_or_else(Self::client),
83            host: host.to_string(),
84            user: user.into(),
85            realm: realm.into(),
86            auth_state: AuthState::new(),
87        }
88    }
89
90    pub fn with_api_token(self, token_id: &str, token: &str) -> Self {
91        // PVEAPIToken=USER@REALM!TOKENID=UUID
92        self.auth_state
93            .set_api_token(&self.user, &self.realm, token_id, token);
94        self
95    }
96
97    pub async fn with_login(self, password: &str) -> Result<Self, Error> {
98        self.login(password).await?;
99        Ok(self)
100    }
101
102    pub async fn with_ticket(self, ticket: &str, csrf: &str) -> Result<Self, Error> {
103        self.auth_state.set_csrf(ticket.into(), csrf.into());
104        self.refresh_auth_ticket(true).await?;
105        Ok(self)
106    }
107
108    fn route(&self, path: &str) -> String {
109        format!("{}/api2/json{}", self.host, path)
110    }
111
112    fn append_headers(&self, request: RequestBuilder) -> RequestBuilder {
113        let mut request = request;
114        for (k, v) in self.auth_state.headers() {
115            request = request.header(k, v);
116        }
117
118        request
119    }
120
121    async fn login(&self, password: &str) -> Result<(), Error> {
122        let user = self.user.to_string();
123        let realm = self.realm.to_string();
124        let request = Ticket::new(&user, &realm, password);
125
126        let csrf_details: TicketResponse =
127            crate::client::Client::post(self, "/access/ticket", &request).await?;
128
129        let ticket = csrf_details
130            .auth_ticket
131            .ok_or(Error::Other("Missing ticket from access response!"))?;
132        let csrf = csrf_details
133            .csrf_token
134            .ok_or(Error::Other("Missing CSRF token from access response!"))?;
135
136        self.auth_state.set_csrf(ticket, csrf);
137
138        Ok(())
139    }
140
141    /// Call this at least once every two hours.
142    ///
143    /// The ticket will automatically refresh if the last auth ticket was obtained more
144    /// than an hour ago, or if `force` is set to `true`.
145    pub async fn refresh_auth_ticket(&self, force: bool) -> Result<(), Error> {
146        log::trace!("Checking whether auth ticket should be refreshed (force: {force})");
147
148        let auth_ticket = if let Some(ticket) = self.auth_state.auth_ticket() {
149            ticket
150        } else {
151            if self.auth_state.api_token().is_none() {
152                log::warn!(
153                    "Tried to refresh auth ticket without existing auth ticket or API token."
154                );
155            }
156            return Ok(());
157        };
158
159        if force || self.auth_state.should_refresh() {
160            // TODO: lock auth state during entire login operation to avoid
161            // Time Of Check Time Of Use barriers
162            log::debug!("Refreshing auth ticket.");
163            self.login(&auth_ticket).await?;
164        }
165
166        Ok(())
167    }
168}
169
170impl crate::client::Client for Client {
171    type Error = Error;
172
173    async fn request_with_body_and_query<B, Q, R>(
174        &self,
175        method: crate::client::Method,
176        path: &str,
177        body: Option<&B>,
178        query: Option<&Q>,
179    ) -> Result<R, Error>
180    where
181        B: Serialize,
182        Q: Serialize,
183        R: DeserializeOwned,
184    {
185        let method = match method {
186            crate::client::Method::Post => Method::POST,
187            crate::client::Method::Get => Method::GET,
188            crate::client::Method::Put => Method::PUT,
189            crate::client::Method::Delete => Method::DELETE,
190        };
191
192        log::debug!("{} {}", method, path);
193
194        let request = self.client.request(method, self.route(path.as_ref()));
195
196        let request = if let Some(body) = body {
197            let body = serde_urlencoded::to_string(body)
198                .map_err(|e| Error::UrlEncodingFailed(e.to_string()))?;
199            request.body(body)
200        } else {
201            request
202        };
203
204        let request = if let Some(query) = query {
205            request.query(query)
206        } else {
207            request
208        };
209
210        let response = self.append_headers(request).send().await?;
211
212        let response_status = response.status();
213
214        let json_data = response.bytes().await?;
215        let json_str = std::str::from_utf8(&json_data).map_err(|_| Error::ResponseWasNotString)?;
216
217        log::debug!("JSON response: {json_str}");
218
219        if response_status != StatusCode::OK {
220            return Err(Error::UnknownFailure(
221                response_status,
222                Some(extract_message(json_str)),
223            ));
224        }
225
226        let result: Response<R> = serde_json::from_str(json_str)
227            .map_err(|e| Error::DecodingFailed(json_str.into(), e))?;
228
229        if let Some(data) = result.data {
230            Ok(data)
231        } else if let Some(errors) = result.errors {
232            Err(Error::EncounteredErrors(errors))
233        } else {
234            Err(Error::UnknownFailure(
235                response_status,
236                Some(extract_message(json_str)),
237            ))
238        }
239    }
240}
241
/// JSON envelope of an API response: the payload in `data`, error details in `errors`.
#[derive(Debug, Deserialize)]
pub struct Response<T> {
    /// Response payload; `None` when the field is missing or `null`.
    pub data: Option<T>,
    /// Error details, if the response reported any.
    pub errors: Option<serde_json::Value>,
}