Skip to main content

couchbase_core/httpx/
client.rs

1/*
2 *
3 *  * Copyright (c) 2025 Couchbase, Inc.
4 *  *
5 *  * Licensed under the Apache License, Version 2.0 (the "License");
6 *  * you may not use this file except in compliance with the License.
7 *  * You may obtain a copy of the License at
8 *  *
9 *  *    http://www.apache.org/licenses/LICENSE-2.0
10 *  *
11 *  * Unless required by applicable law or agreed to in writing, software
12 *  * distributed under the License is distributed on an "AS IS" BASIS,
13 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  * See the License for the specific language governing permissions and
15 *  * limitations under the License.
16 *
17 */
18
19use crate::httpx::error::ErrorKind::Connect;
20use crate::httpx::error::{Error, Result as HttpxResult};
21use crate::httpx::request::{Auth, OboPasswordOrDomain, Request};
22use crate::httpx::response::Response;
23use crate::tls_config::TlsConfig;
24use arc_swap::ArcSwap;
25use async_trait::async_trait;
26use base64::prelude::BASE64_STANDARD;
27use base64::Engine;
28use http::header::{CONTENT_TYPE, USER_AGENT};
29use reqwest::redirect::Policy;
30use std::error::Error as StdError;
31use std::sync::Arc;
32use std::time::Duration;
33use tracing::{debug, trace};
34use uuid::Uuid;
35
36#[async_trait]
37pub trait Client: Send + Sync {
38    async fn execute(&self, req: Request) -> HttpxResult<Response>;
39}
40
41#[derive(Clone, Debug, Default)]
42#[non_exhaustive]
43pub struct ClientConfig {
44    pub tls_config: Option<TlsConfig>,
45    pub idle_connection_timeout: Duration,
46    pub max_idle_connections_per_host: Option<usize>,
47    pub tcp_keep_alive_time: Duration,
48}
49
50impl ClientConfig {
51    pub fn new() -> Self {
52        Default::default()
53    }
54
55    pub fn tls_config(mut self, tls_config: impl Into<Option<TlsConfig>>) -> Self {
56        self.tls_config = tls_config.into();
57        self
58    }
59
60    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
61        self.idle_connection_timeout = timeout;
62        self
63    }
64
65    pub fn max_idle_connections_per_host(mut self, max_idle_connections_per_host: usize) -> Self {
66        self.max_idle_connections_per_host = Some(max_idle_connections_per_host);
67        self
68    }
69}
70
71#[derive(Clone, Debug, Default)]
72#[non_exhaustive]
73pub struct UpdateTlsOptions {
74    pub tls_config: Option<TlsConfig>,
75}
76
77impl UpdateTlsOptions {
78    pub fn new() -> Self {
79        Default::default()
80    }
81
82    pub fn tls_config(mut self, tls_config: impl Into<Option<TlsConfig>>) -> Self {
83        self.tls_config = tls_config.into();
84        self
85    }
86}
87
88#[derive(Debug)]
89pub struct ReqwestClient {
90    inner: ArcSwap<reqwest::Client>,
91    client_id: String,
92
93    idle_connection_timeout: Duration,
94    max_idle_connections_per_host: Option<usize>,
95    tcp_keep_alive_time: Duration,
96}
97
98impl ReqwestClient {
99    pub fn new(cfg: ClientConfig) -> HttpxResult<Self> {
100        let idle_connection_timeout = cfg.idle_connection_timeout;
101        let max_idle_connections_per_host = cfg.max_idle_connections_per_host;
102        let tcp_keep_alive_time = cfg.tcp_keep_alive_time;
103
104        let inner = Self::new_client(cfg)?;
105
106        Ok(Self {
107            inner: ArcSwap::from_pointee(inner),
108            client_id: Uuid::new_v4().to_string(),
109            idle_connection_timeout,
110            max_idle_connections_per_host,
111            tcp_keep_alive_time,
112        })
113    }
114
115    pub fn update_tls(&self, opts: UpdateTlsOptions) -> HttpxResult<()> {
116        let cfg = ClientConfig {
117            tls_config: opts.tls_config,
118            idle_connection_timeout: self.idle_connection_timeout,
119            max_idle_connections_per_host: self.max_idle_connections_per_host,
120            tcp_keep_alive_time: self.tcp_keep_alive_time,
121        };
122
123        let new_client = Self::new_client(cfg)?;
124
125        self.inner.store(Arc::new(new_client));
126
127        debug!("Reconfigured HTTP Client {}", &self.client_id);
128
129        Ok(())
130    }
131
132    fn new_client(cfg: ClientConfig) -> HttpxResult<reqwest::Client> {
133        let mut builder = reqwest::Client::builder()
134            .redirect(Policy::limited(10))
135            .pool_idle_timeout(cfg.idle_connection_timeout)
136            .tcp_keepalive(cfg.tcp_keep_alive_time);
137
138        if let Some(max_idle) = cfg.max_idle_connections_per_host {
139            builder = builder.pool_max_idle_per_host(max_idle);
140        }
141
142        if let Some(config) = cfg.tls_config {
143            builder = Self::add_tls_config(builder, config);
144        }
145
146        let client = builder
147            .build()
148            .map_err(|e| Error::new_message_error(format!("failed to build http client {e}")))?;
149        Ok(client)
150    }
151
152    #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
153    fn add_tls_config(
154        builder: reqwest::ClientBuilder,
155        tls_config: TlsConfig,
156    ) -> reqwest::ClientBuilder {
157        // We have to deref the Arc, otherwise we'll get a runtime error from reqwest.
158        builder.use_preconfigured_tls((*tls_config).clone())
159    }
160
161    #[cfg(feature = "native-tls")]
162    fn add_tls_config(
163        builder: reqwest::ClientBuilder,
164        tls_config: TlsConfig,
165    ) -> reqwest::ClientBuilder {
166        builder.use_preconfigured_tls(tls_config)
167    }
168}
169
170#[async_trait]
171impl Client for ReqwestClient {
172    async fn execute(&self, req: Request) -> HttpxResult<Response> {
173        let inner = self.inner.load();
174
175        let id = if let Some(unique_id) = req.unique_id {
176            unique_id
177        } else {
178            Uuid::new_v4().to_string()
179        };
180
181        trace!(
182            "Writing request on {} to {}. Method={}. Request id={}",
183            &self.client_id,
184            &req.uri,
185            &req.method,
186            &id
187        );
188
189        let mut builder = inner.request(req.method, req.uri);
190
191        if let Some(body) = req.body {
192            builder = builder.body(body);
193        }
194
195        if let Some(content_type) = req.content_type {
196            builder = builder.header(CONTENT_TYPE, content_type);
197        }
198
199        if let Some(user_agent) = req.user_agent {
200            builder = builder.header(USER_AGENT, user_agent);
201        }
202
203        if let Some(auth) = &req.auth {
204            match auth {
205                Auth::BasicAuth(basic) => {
206                    builder = builder.basic_auth(&basic.username, Some(&basic.password))
207                }
208                Auth::BearerAuth(bearer) => builder = builder.bearer_auth(&bearer.token),
209                Auth::OnBehalfOf(obo) => {
210                    match &obo.password_or_domain {
211                        OboPasswordOrDomain::Password(password) => {
212                            // If we have the OBO users password, we just directly set the basic auth
213                            // on the request with those credentials rather than using an on-behalf-of
214                            // header.  This enables support for older server versions.
215                            builder = builder.basic_auth(&obo.username, Some(password));
216                        }
217                        OboPasswordOrDomain::Domain(domain) => {
218                            // Otherwise we send the user/domain using an OBO header.
219                            let obo_hdr_string =
220                                BASE64_STANDARD.encode(format!("{}:{}", obo.username, domain));
221                            builder = builder.header("cb-on-behalf-of", obo_hdr_string);
222                        }
223                    }
224                }
225            }
226        }
227
228        match builder.send().await {
229            Ok(response) => Ok({
230                trace!(
231                    "Received response on {}. Request id={}. Status: {}",
232                    &self.client_id,
233                    &id,
234                    response.status()
235                );
236                Response::from(response)
237            }),
238            Err(err) => {
239                let mut msg = format!(
240                    "Received error on {}. Request id={}. Err: {}",
241                    &self.client_id, &id, &err,
242                );
243
244                if let Some(source) = err.source() {
245                    msg = format!("{msg}. Source: {source}");
246                }
247
248                trace!("{msg}");
249
250                if err.is_connect() {
251                    Err(Error::new_connect_error(err.to_string()))
252                } else if err.is_request() {
253                    Err(Error::new_request_error(err.to_string()))
254                } else {
255                    Err(Error::new_message_error(err.to_string()))
256                }
257            }
258        }
259    }
260}
261
262impl Drop for ReqwestClient {
263    fn drop(&mut self) {
264        debug!("Dropping HTTP Client {}", &self.client_id);
265    }
266}