Skip to main content

nv_redfish_bmc_http/
reqwest.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::BmcCredentials;
17use crate::CacheableError;
18use crate::HttpClient;
19use futures_util::StreamExt;
20use http::header;
21use http::HeaderMap;
22use nv_redfish_core::AsyncTask;
23use nv_redfish_core::BoxTryStream;
24use nv_redfish_core::ModificationResponse;
25use nv_redfish_core::ODataETag;
26use nv_redfish_core::ODataId;
27use serde::de::DeserializeOwned;
28use serde::Serialize;
29use std::time::Duration;
30use url::Url;
31
/// Errors produced by the reqwest-based BMC HTTP client in this module.
#[derive(Debug)]
pub enum BmcError {
    /// Failure reported by the underlying `reqwest` client.
    ReqwestError(reqwest::Error),
    /// A JSON value could not be deserialized into the requested type;
    /// the wrapped error also carries the path to the offending element.
    JsonError(serde_path_to_error::Error<serde_json::Error>),
    /// The server answered with a status this client does not accept.
    InvalidResponse {
        /// URL of the response.
        url: url::Url,
        /// HTTP status code of the response.
        status: reqwest::StatusCode,
        /// Response body text, or a short description of the problem.
        text: String,
    },
    /// Error raised while decoding a server-sent-events stream.
    SseStreamError(sse_stream::Error),
    /// The requested resource was not found in the cache.
    CacheMiss,
    /// The cache reported an error; the payload is the reason.
    CacheError(String),
    /// A response body could not be parsed as JSON.
    DecodeError(serde_json::Error),
}
46
47impl From<reqwest::Error> for BmcError {
48    fn from(value: reqwest::Error) -> Self {
49        Self::ReqwestError(value)
50    }
51}
52
53impl CacheableError for BmcError {
54    fn is_cached(&self) -> bool {
55        match self {
56            Self::InvalidResponse { status, .. } => status == &reqwest::StatusCode::NOT_MODIFIED,
57            _ => false,
58        }
59    }
60
61    fn cache_miss() -> Self {
62        Self::CacheMiss
63    }
64
65    fn cache_error(reason: String) -> Self {
66        Self::CacheError(reason)
67    }
68}
69
70#[allow(clippy::absolute_paths)]
71impl std::fmt::Display for BmcError {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            Self::ReqwestError(e) => write!(f, "HTTP client error: {e:?}"),
75            Self::InvalidResponse { url, status, text } => {
76                write!(
77                    f,
78                    "Invalid HTTP response - url: {url} status: {status} text: {text}"
79                )
80            }
81            Self::CacheMiss => write!(f, "Resource not found in cache"),
82            Self::CacheError(r) => write!(f, "Error occurred in cache {r:?}"),
83            Self::JsonError(e) => write!(
84                f,
85                "JSON deserialization error at line {} column {} path {}: {e}",
86                e.inner().line(),
87                e.inner().column(),
88                e.path(),
89            ),
90            Self::SseStreamError(e) => write!(f, "SSE stream decode error: {e}"),
91            Self::DecodeError(e) => write!(f, "JSON Decode error: {e}"),
92        }
93    }
94}
95
96#[allow(clippy::absolute_paths)]
97impl std::error::Error for BmcError {
98    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
99        match self {
100            Self::ReqwestError(e) => Some(e),
101            Self::JsonError(e) => Some(e.inner()),
102            Self::SseStreamError(e) => Some(e),
103            Self::DecodeError(e) => Some(e),
104            _ => None,
105        }
106    }
107}
108
/// Configuration parameters for the reqwest HTTP client.
///
/// This struct allows customizing various aspects of the reqwest client behavior,
/// including timeouts, TLS settings, and connection pooling. For every
/// `Option` field, `None` means the corresponding setting is not applied to
/// the reqwest builder at all.
///
/// # Examples
///
/// ```rust
/// use nv_redfish_bmc_http::reqwest::ClientParams;
/// use std::time::Duration;
///
/// let params = ClientParams::new()
///     .timeout(Duration::from_secs(30))
///     .connect_timeout(Duration::from_secs(10))
///     .user_agent("MyApp/1.0")
///     .accept_invalid_certs(true);
/// ```
#[derive(Debug, Clone)]
pub struct ClientParams {
    /// HTTP request timeout
    pub timeout: Option<Duration>,
    /// TCP connection timeout
    pub connect_timeout: Option<Duration>,
    /// User-Agent header value
    pub user_agent: Option<String>,
    /// Whether to accept invalid TLS certificates
    pub accept_invalid_certs: bool,
    /// Maximum number of HTTP redirects to follow
    pub max_redirects: Option<usize>,
    /// TCP keep-alive timeout
    pub tcp_keepalive: Option<Duration>,
    /// Connection pool idle timeout
    pub pool_idle_timeout: Option<Duration>,
    /// Maximum idle connections per host
    pub pool_max_idle_per_host: Option<usize>,
    /// List of default headers, added to every request
    pub default_headers: Option<HeaderMap>,
    /// When `true` the client is built with the rustls TLS backend forced on;
    /// when `false` the backend choice is left to the reqwest builder.
    /// Enabled by default.
    pub use_rust_tls: bool,
}
149
150impl Default for ClientParams {
151    fn default() -> Self {
152        Self {
153            timeout: Some(Duration::from_secs(120)),
154            connect_timeout: Some(Duration::from_secs(5)),
155            user_agent: Some("nv-redfish/v1".to_string()),
156            accept_invalid_certs: false,
157            max_redirects: Some(10),
158            tcp_keepalive: Some(Duration::from_secs(60)),
159            pool_idle_timeout: Some(Duration::from_secs(90)),
160            pool_max_idle_per_host: Some(1),
161            default_headers: None,
162            use_rust_tls: true,
163        }
164    }
165}
166
167impl ClientParams {
168    #[must_use]
169    pub fn new() -> Self {
170        Self::default()
171    }
172
173    #[must_use]
174    pub const fn timeout(mut self, timeout: Duration) -> Self {
175        self.timeout = Some(timeout);
176        self
177    }
178
179    #[must_use]
180    pub const fn connect_timeout(mut self, timeout: Duration) -> Self {
181        self.connect_timeout = Some(timeout);
182        self
183    }
184
185    #[must_use]
186    pub fn user_agent<S: Into<String>>(mut self, user_agent: S) -> Self {
187        self.user_agent = Some(user_agent.into());
188        self
189    }
190
191    #[must_use]
192    pub const fn accept_invalid_certs(mut self, accept: bool) -> Self {
193        self.accept_invalid_certs = accept;
194        self
195    }
196
197    #[must_use]
198    pub const fn max_redirects(mut self, max: usize) -> Self {
199        self.max_redirects = Some(max);
200        self
201    }
202
203    #[must_use]
204    pub const fn tcp_keepalive(mut self, keepalive: Duration) -> Self {
205        self.tcp_keepalive = Some(keepalive);
206        self
207    }
208
209    #[must_use]
210    pub const fn pool_max_idle_per_host(mut self, pool_max_idle_per_host: usize) -> Self {
211        self.pool_max_idle_per_host = Some(pool_max_idle_per_host);
212        self
213    }
214
215    #[must_use]
216    pub const fn idle_timeout(mut self, pool_idle_timeout: Duration) -> Self {
217        self.pool_idle_timeout = Some(pool_idle_timeout);
218        self
219    }
220
221    #[must_use]
222    pub const fn no_timeout(mut self) -> Self {
223        self.timeout = None;
224        self
225    }
226
227    #[must_use]
228    pub fn default_headers(mut self, default_headers: HeaderMap) -> Self {
229        self.default_headers = Some(default_headers);
230        self
231    }
232}
233
/// HTTP client implementation using the reqwest library.
///
/// This provides a concrete implementation of [`HttpClient`] using the popular
/// reqwest HTTP client library. It supports all standard HTTP features including
/// TLS, authentication, and connection pooling.
#[derive(Clone)]
pub struct Client {
    /// Underlying reqwest client used to send every request.
    client: reqwest::Client,
}
244
245#[allow(clippy::missing_errors_doc)]
246#[allow(clippy::absolute_paths)]
247impl Client {
248    pub fn new() -> Result<Self, reqwest::Error> {
249        Self::with_params(ClientParams::default())
250    }
251
252    pub fn with_params(params: ClientParams) -> Result<Self, reqwest::Error> {
253        let mut builder = reqwest::Client::builder();
254
255        if params.use_rust_tls {
256            builder = builder.use_rustls_tls();
257        }
258
259        if let Some(timeout) = params.timeout {
260            builder = builder.timeout(timeout);
261        }
262
263        if let Some(connect_timeout) = params.connect_timeout {
264            builder = builder.connect_timeout(connect_timeout);
265        }
266
267        if let Some(user_agent) = params.user_agent {
268            builder = builder.user_agent(user_agent);
269        }
270
271        if params.accept_invalid_certs {
272            builder = builder.danger_accept_invalid_certs(true);
273        }
274
275        if let Some(max_redirects) = params.max_redirects {
276            builder = builder.redirect(reqwest::redirect::Policy::limited(max_redirects));
277        }
278
279        if let Some(keepalive) = params.tcp_keepalive {
280            builder = builder.tcp_keepalive(keepalive);
281        }
282
283        if let Some(idle_timeout) = params.pool_idle_timeout {
284            builder = builder.pool_idle_timeout(idle_timeout);
285        }
286
287        if let Some(max_idle) = params.pool_max_idle_per_host {
288            builder = builder.pool_max_idle_per_host(max_idle);
289        }
290
291        if let Some(default_headers) = params.default_headers {
292            builder = builder.default_headers(default_headers);
293        }
294
295        Ok(Self {
296            client: builder.build()?,
297        })
298    }
299
300    #[must_use]
301    pub const fn with_client(client: reqwest::Client) -> Self {
302        Self { client }
303    }
304}
305
306impl Client {
307    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T, BmcError>
308    where
309        T: DeserializeOwned,
310    {
311        if !response.status().is_success() {
312            return Err(BmcError::InvalidResponse {
313                url: response.url().clone(),
314                status: response.status(),
315                text: response.text().await.unwrap_or_else(|_| "<no data>".into()),
316            });
317        }
318
319        let headers = response.headers().clone();
320
321        let etag_header = etag_from_headers(&headers);
322
323        let mut value: serde_json::Value = response.json().await.map_err(BmcError::ReqwestError)?;
324
325        if let Some(etag) = etag_header {
326            inject_etag(etag, &mut value);
327        }
328
329        serde_path_to_error::deserialize(value).map_err(BmcError::JsonError)
330    }
331
332    async fn handle_modification_response<T>(
333        &self,
334        response: reqwest::Response,
335    ) -> Result<ModificationResponse<T>, BmcError>
336    where
337        T: DeserializeOwned + Send + Sync,
338    {
339        let status = response.status();
340        let url = response.url().clone();
341        let headers = response.headers().clone();
342        if !status.is_success() {
343            return Err(BmcError::InvalidResponse {
344                url,
345                status,
346                text: response.text().await.unwrap_or_else(|_| "<no data>".into()),
347            });
348        }
349
350        let etag = etag_from_headers(&headers);
351        let location = location_from_headers(&headers);
352
353        match status {
354            reqwest::StatusCode::NO_CONTENT => Ok(ModificationResponse::Empty),
355            reqwest::StatusCode::ACCEPTED => {
356                let Some(task_monitor_id) = location else {
357                    return Err(BmcError::InvalidResponse {
358                        url,
359                        status,
360                        text: String::from("202 Accepted without Location header"),
361                    });
362                };
363                Ok(ModificationResponse::Task(AsyncTask {
364                    id: task_monitor_id,
365                    retry_after_secs: retry_after_from_headers(&headers),
366                }))
367            }
368            reqwest::StatusCode::OK | reqwest::StatusCode::CREATED => {
369                let bytes = response.bytes().await.map_err(BmcError::ReqwestError)?;
370                if !bytes.is_empty() {
371                    let value: serde_json::Value =
372                        serde_json::from_slice(&bytes).map_err(BmcError::DecodeError)?;
373                    let mut value = value;
374
375                    if value.get("@odata.id").is_some() {
376                        if let Some(etag) = etag {
377                            inject_etag(etag, &mut value);
378                        }
379                        return serde_path_to_error::deserialize(value)
380                            .map(ModificationResponse::Entity)
381                            .map_err(BmcError::JsonError);
382                    }
383                }
384
385                if let Some(location) = location {
386                    let value = serde_json::json!({ "@odata.id": location });
387                    return serde_path_to_error::deserialize(value)
388                        .map(ModificationResponse::Entity)
389                        .map_err(BmcError::JsonError);
390                }
391
392                Ok(ModificationResponse::Empty)
393            }
394            _ => Err(BmcError::InvalidResponse {
395                url,
396                status,
397                text: format!("Unexpected successful status code: {status}"),
398            }),
399        }
400    }
401}
402
403fn location_from_headers(headers: &HeaderMap) -> Option<ODataId> {
404    headers
405        .get(header::LOCATION)
406        .and_then(|value| value.to_str().ok())
407        .map(|raw| {
408            if let Ok(url) = Url::parse(raw) {
409                let mut path = url.path().to_string();
410                if let Some(query) = url.query() {
411                    path.push('?');
412                    path.push_str(query);
413                }
414                path.into()
415            } else {
416                raw.to_string().into()
417            }
418        })
419}
420
421fn etag_from_headers(headers: &HeaderMap) -> Option<ODataETag> {
422    headers
423        .get(header::ETAG)
424        .and_then(|value| value.to_str().ok())
425        .map(|v| v.to_string().into())
426}
427
428fn retry_after_from_headers(headers: &HeaderMap) -> Option<u64> {
429    headers
430        .get(header::RETRY_AFTER)
431        .and_then(|value| value.to_str().ok())
432        .and_then(|v| v.trim().parse::<u64>().ok())
433}
434
435fn inject_etag(etag: ODataETag, body: &mut serde_json::Value) {
436    if let Some(obj) = body.as_object_mut() {
437        let etag_value = serde_json::Value::String(etag.to_string());
438
439        // Handles both absent and null values
440        obj.entry("@odata.etag")
441            .and_modify(|v| *v = etag_value.clone())
442            .or_insert(etag_value);
443    }
444}
445
446fn auth_headers(
447    request: reqwest::RequestBuilder,
448    credentials: &BmcCredentials,
449) -> reqwest::RequestBuilder {
450    match credentials {
451        BmcCredentials::UsernamePassword { username, password } => {
452            request.basic_auth(username, password.as_ref())
453        }
454        BmcCredentials::Token { token } => request.header("X-Auth-Token", token),
455    }
456}
457
458impl HttpClient for Client {
459    type Error = BmcError;
460
461    async fn get<T>(
462        &self,
463        url: Url,
464        credentials: &BmcCredentials,
465        etag: Option<ODataETag>,
466        custom_headers: &HeaderMap,
467    ) -> Result<T, Self::Error>
468    where
469        T: DeserializeOwned,
470    {
471        let mut request = auth_headers(self.client.get(url), credentials).headers(custom_headers.clone());
472
473        if let Some(etag) = etag {
474            request = request.header(header::IF_NONE_MATCH, etag.to_string());
475        }
476
477        let response = request.send().await?;
478        self.handle_response(response).await
479    }
480
481    async fn post<B, T>(
482        &self,
483        url: Url,
484        body: &B,
485        credentials: &BmcCredentials,
486        custom_headers: &HeaderMap,
487    ) -> Result<ModificationResponse<T>, Self::Error>
488    where
489        B: Serialize + Send + Sync,
490        T: DeserializeOwned + Send + Sync,
491    {
492        let response = auth_headers(self.client.post(url), credentials)
493            .headers(custom_headers.clone())
494            .json(body)
495            .send()
496            .await?;
497
498        self.handle_modification_response(response).await
499    }
500
501    async fn patch<B, T>(
502        &self,
503        url: Url,
504        etag: ODataETag,
505        body: &B,
506        credentials: &BmcCredentials,
507        custom_headers: &HeaderMap,
508    ) -> Result<ModificationResponse<T>, Self::Error>
509    where
510        B: Serialize + Send + Sync,
511        T: DeserializeOwned + Send + Sync,
512    {
513        let mut request =
514            auth_headers(self.client.patch(url), credentials).headers(custom_headers.clone());
515
516        request = request.header(header::IF_MATCH, etag.to_string());
517
518        let response = request.json(body).send().await?;
519        self.handle_modification_response(response).await
520    }
521
522    async fn delete<T>(
523        &self,
524        url: Url,
525        credentials: &BmcCredentials,
526        custom_headers: &HeaderMap,
527    ) -> Result<ModificationResponse<T>, Self::Error>
528    where
529        T: DeserializeOwned + Send + Sync,
530    {
531        let response = auth_headers(self.client.delete(url), credentials)
532            .headers(custom_headers.clone())
533            .send()
534            .await?;
535
536        self.handle_modification_response(response).await
537    }
538
539    async fn sse<T: Sized + for<'a> serde::Deserialize<'a> + Send + 'static>(
540        &self,
541        url: Url,
542        credentials: &BmcCredentials,
543        custom_headers: &HeaderMap,
544    ) -> Result<BoxTryStream<T, Self::Error>, Self::Error> {
545        let response = auth_headers(self.client.get(url), credentials)
546            .headers(custom_headers.clone())
547            .header(header::ACCEPT, "text/event-stream")
548            .send()
549            .await?;
550
551        if !response.status().is_success() {
552            return Err(BmcError::InvalidResponse {
553                url: response.url().clone(),
554                status: response.status(),
555                text: response.text().await.unwrap_or_else(|_| "<no data>".into()),
556            });
557        }
558
559        let stream = sse_stream::SseStream::from_byte_stream(response.bytes_stream()).filter_map(
560            |event| async move {
561                match event {
562                    Err(err) => Some(Err(BmcError::SseStreamError(err))),
563                    Ok(sse) => sse.data.map(|data| {
564                        serde_path_to_error::deserialize(&mut serde_json::Deserializer::from_str(
565                            &data,
566                        ))
567                        .map_err(BmcError::JsonError)
568                    }),
569                }
570            },
571        );
572
573        Ok(Box::pin(stream))
574    }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `CacheableError` contract of `BmcError`: a 304 response is
    /// "cached", a cache miss is not, and `cache_miss()` builds `CacheMiss`.
    #[test]
    fn test_cacheable_error_trait() {
        // Obtain the 304 status through a real reqwest response object.
        let not_modified = reqwest::Response::from(
            http::Response::builder()
                .status(304)
                .body("")
                .expect("Valid empty body"),
        );

        let cached = BmcError::InvalidResponse {
            url: "http://example.com/redfish/v1".parse().unwrap(),
            status: not_modified.status(),
            text: String::new(),
        };
        assert!(cached.is_cached());

        assert!(!BmcError::CacheMiss.is_cached());

        assert!(matches!(BmcError::cache_miss(), BmcError::CacheMiss));
    }
}