couchbase_core/httpx/
client.rs

1/*
2 *
3 *  * Copyright (c) 2025 Couchbase, Inc.
4 *  *
5 *  * Licensed under the Apache License, Version 2.0 (the "License");
6 *  * you may not use this file except in compliance with the License.
7 *  * You may obtain a copy of the License at
8 *  *
9 *  *    http://www.apache.org/licenses/LICENSE-2.0
10 *  *
11 *  * Unless required by applicable law or agreed to in writing, software
12 *  * distributed under the License is distributed on an "AS IS" BASIS,
13 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  * See the License for the specific language governing permissions and
15 *  * limitations under the License.
16 *
17 */
18
19use crate::httpx::error::ErrorKind::Connection;
20use crate::httpx::error::{Error, Result as HttpxResult};
21use crate::httpx::request::{Auth, OboPasswordOrDomain, Request};
22use crate::httpx::response::Response;
23use crate::tls_config::TlsConfig;
24use arc_swap::ArcSwap;
25use async_trait::async_trait;
26use base64::prelude::BASE64_STANDARD;
27use base64::Engine;
28use http::header::{CONTENT_TYPE, USER_AGENT};
29use log::{debug, trace};
30use reqwest::redirect::Policy;
31use std::error::Error as StdError;
32use std::sync::Arc;
33use std::time::Duration;
34use uuid::Uuid;
35
36#[async_trait]
37pub trait Client: Send + Sync {
38    async fn execute(&self, req: Request) -> HttpxResult<Response>;
39}
40
41#[derive(Clone, Debug, Default)]
42#[non_exhaustive]
43pub struct ClientConfig {
44    pub tls_config: Option<TlsConfig>,
45    pub idle_connection_timeout: Duration,
46    pub max_idle_connections_per_host: Option<usize>,
47    pub tcp_keep_alive_time: Duration,
48}
49
50impl ClientConfig {
51    pub fn new() -> Self {
52        Default::default()
53    }
54
55    pub fn tls_config(mut self, tls_config: impl Into<Option<TlsConfig>>) -> Self {
56        self.tls_config = tls_config.into();
57        self
58    }
59
60    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
61        self.idle_connection_timeout = timeout;
62        self
63    }
64
65    pub fn max_idle_connections_per_host(mut self, max_idle_connections_per_host: usize) -> Self {
66        self.max_idle_connections_per_host = Some(max_idle_connections_per_host);
67        self
68    }
69}
70
71#[derive(Debug)]
72pub struct ReqwestClient {
73    inner: ArcSwap<reqwest::Client>,
74    client_id: String,
75}
76
77impl ReqwestClient {
78    pub fn new(cfg: ClientConfig) -> HttpxResult<Self> {
79        let inner = Self::new_client(cfg)?;
80        Ok(Self {
81            inner: ArcSwap::from_pointee(inner),
82            client_id: Uuid::new_v4().to_string(),
83        })
84    }
85
86    // TODO: once options are supported we need to check if they've changed before creating
87    // a new client provider.
88    pub fn reconfigure(&self, cfg: ClientConfig) -> HttpxResult<()> {
89        let new_inner = Self::new_client(cfg)?;
90        let old_inner = self.inner.swap(Arc::new(new_inner));
91
92        // TODO: This will close any in flight requests, do we actually need to do this or will
93        // it get dropped once requests complete anyway?
94        drop(old_inner);
95
96        Ok(())
97    }
98
99    fn new_client(cfg: ClientConfig) -> HttpxResult<reqwest::Client> {
100        let mut builder = reqwest::Client::builder()
101            .redirect(Policy::limited(10))
102            .pool_idle_timeout(cfg.idle_connection_timeout)
103            .tcp_keepalive(cfg.tcp_keep_alive_time);
104
105        if let Some(max_idle) = cfg.max_idle_connections_per_host {
106            builder = builder.pool_max_idle_per_host(max_idle);
107        }
108
109        if let Some(config) = cfg.tls_config {
110            builder = Self::add_tls_config(builder, config);
111        }
112
113        let client = builder
114            .build()
115            .map_err(|e| Error::new_message_error(format!("failed to build http client {e}")))?;
116        Ok(client)
117    }
118
119    #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
120    fn add_tls_config(
121        builder: reqwest::ClientBuilder,
122        tls_config: TlsConfig,
123    ) -> reqwest::ClientBuilder {
124        // We have to deref the Arc, otherwise we'll get a runtime error from reqwest.
125        builder.use_preconfigured_tls((*tls_config).clone())
126    }
127
128    #[cfg(feature = "native-tls")]
129    fn add_tls_config(
130        builder: reqwest::ClientBuilder,
131        tls_config: TlsConfig,
132    ) -> reqwest::ClientBuilder {
133        builder.use_preconfigured_tls(tls_config)
134    }
135}
136
137#[async_trait]
138impl Client for ReqwestClient {
139    async fn execute(&self, req: Request) -> HttpxResult<Response> {
140        let inner = self.inner.load();
141
142        let id = if let Some(unique_id) = req.unique_id {
143            unique_id
144        } else {
145            Uuid::new_v4().to_string()
146        };
147
148        trace!(
149            "Writing request on {} to {}. Method={}. Request id={}",
150            &self.client_id,
151            &req.uri,
152            &req.method,
153            &id
154        );
155
156        let mut builder = inner.request(req.method, req.uri);
157
158        if let Some(body) = req.body {
159            builder = builder.body(body);
160        }
161
162        if let Some(content_type) = req.content_type {
163            builder = builder.header(CONTENT_TYPE, content_type);
164        }
165
166        if let Some(user_agent) = req.user_agent {
167            builder = builder.header(USER_AGENT, user_agent);
168        }
169
170        if let Some(auth) = &req.auth {
171            match auth {
172                Auth::BasicAuth(basic) => {
173                    builder = builder.basic_auth(&basic.username, Some(&basic.password))
174                }
175                Auth::OnBehalfOf(obo) => {
176                    match &obo.password_or_domain {
177                        OboPasswordOrDomain::Password(password) => {
178                            // If we have the OBO users password, we just directly set the basic auth
179                            // on the request with those credentials rather than using an on-behalf-of
180                            // header.  This enables support for older server versions.
181                            builder = builder.basic_auth(&obo.username, Some(password));
182                        }
183                        OboPasswordOrDomain::Domain(domain) => {
184                            // Otherwise we send the user/domain using an OBO header.
185                            let obo_hdr_string =
186                                BASE64_STANDARD.encode(format!("{}:{}", obo.username, domain));
187                            builder = builder.header("cb-on-behalf-of", obo_hdr_string);
188                        }
189                    }
190                }
191            }
192        }
193
194        match builder.send().await {
195            Ok(response) => Ok({
196                trace!(
197                    "Received response on {}. Request id={}. Status: {}",
198                    &self.client_id,
199                    &id,
200                    response.status()
201                );
202                Response::from(response)
203            }),
204            Err(err) => {
205                let mut msg = format!(
206                    "Received error on {}. Request id={}. Err: {}",
207                    &self.client_id, &id, &err,
208                );
209
210                if let Some(source) = err.source() {
211                    msg = format!("{msg}. Source: {source}");
212                }
213
214                trace!("{msg}");
215
216                if err.is_connect() {
217                    Err(Error::new_connection_error(err.to_string()))
218                } else {
219                    Err(Error::new_message_error(err.to_string()))
220                }
221            }
222        }
223    }
224}
225
226impl Drop for ReqwestClient {
227    fn drop(&mut self) {
228        debug!("Dropping HTTP Client {}", &self.client_id);
229    }
230}