couchbase_core/httpx/
client.rs1use crate::httpx::error::ErrorKind::Connect;
20use crate::httpx::error::{Error, Result as HttpxResult};
21use crate::httpx::request::{Auth, OboPasswordOrDomain, Request};
22use crate::httpx::response::Response;
23use crate::tls_config::TlsConfig;
24use arc_swap::ArcSwap;
25use async_trait::async_trait;
26use base64::prelude::BASE64_STANDARD;
27use base64::Engine;
28use http::header::{CONTENT_TYPE, USER_AGENT};
29use reqwest::redirect::Policy;
30use std::error::Error as StdError;
31use std::sync::Arc;
32use std::time::Duration;
33use tracing::{debug, trace};
34use uuid::Uuid;
35
36#[async_trait]
37pub trait Client: Send + Sync {
38 async fn execute(&self, req: Request) -> HttpxResult<Response>;
39}
40
41#[derive(Clone, Debug, Default)]
42#[non_exhaustive]
43pub struct ClientConfig {
44 pub tls_config: Option<TlsConfig>,
45 pub idle_connection_timeout: Duration,
46 pub max_idle_connections_per_host: Option<usize>,
47 pub tcp_keep_alive_time: Duration,
48}
49
50impl ClientConfig {
51 pub fn new() -> Self {
52 Default::default()
53 }
54
55 pub fn tls_config(mut self, tls_config: impl Into<Option<TlsConfig>>) -> Self {
56 self.tls_config = tls_config.into();
57 self
58 }
59
60 pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
61 self.idle_connection_timeout = timeout;
62 self
63 }
64
65 pub fn max_idle_connections_per_host(mut self, max_idle_connections_per_host: usize) -> Self {
66 self.max_idle_connections_per_host = Some(max_idle_connections_per_host);
67 self
68 }
69}
70
71#[derive(Clone, Debug, Default)]
72#[non_exhaustive]
73pub struct UpdateTlsOptions {
74 pub tls_config: Option<TlsConfig>,
75}
76
77impl UpdateTlsOptions {
78 pub fn new() -> Self {
79 Default::default()
80 }
81
82 pub fn tls_config(mut self, tls_config: impl Into<Option<TlsConfig>>) -> Self {
83 self.tls_config = tls_config.into();
84 self
85 }
86}
87
88#[derive(Debug)]
89pub struct ReqwestClient {
90 inner: ArcSwap<reqwest::Client>,
91 client_id: String,
92
93 idle_connection_timeout: Duration,
94 max_idle_connections_per_host: Option<usize>,
95 tcp_keep_alive_time: Duration,
96}
97
98impl ReqwestClient {
99 pub fn new(cfg: ClientConfig) -> HttpxResult<Self> {
100 let idle_connection_timeout = cfg.idle_connection_timeout;
101 let max_idle_connections_per_host = cfg.max_idle_connections_per_host;
102 let tcp_keep_alive_time = cfg.tcp_keep_alive_time;
103
104 let inner = Self::new_client(cfg)?;
105
106 Ok(Self {
107 inner: ArcSwap::from_pointee(inner),
108 client_id: Uuid::new_v4().to_string(),
109 idle_connection_timeout,
110 max_idle_connections_per_host,
111 tcp_keep_alive_time,
112 })
113 }
114
115 pub fn update_tls(&self, opts: UpdateTlsOptions) -> HttpxResult<()> {
116 let cfg = ClientConfig {
117 tls_config: opts.tls_config,
118 idle_connection_timeout: self.idle_connection_timeout,
119 max_idle_connections_per_host: self.max_idle_connections_per_host,
120 tcp_keep_alive_time: self.tcp_keep_alive_time,
121 };
122
123 let new_client = Self::new_client(cfg)?;
124
125 self.inner.store(Arc::new(new_client));
126
127 debug!("Reconfigured HTTP Client {}", &self.client_id);
128
129 Ok(())
130 }
131
132 fn new_client(cfg: ClientConfig) -> HttpxResult<reqwest::Client> {
133 let mut builder = reqwest::Client::builder()
134 .redirect(Policy::limited(10))
135 .pool_idle_timeout(cfg.idle_connection_timeout)
136 .tcp_keepalive(cfg.tcp_keep_alive_time);
137
138 if let Some(max_idle) = cfg.max_idle_connections_per_host {
139 builder = builder.pool_max_idle_per_host(max_idle);
140 }
141
142 if let Some(config) = cfg.tls_config {
143 builder = Self::add_tls_config(builder, config);
144 }
145
146 let client = builder
147 .build()
148 .map_err(|e| Error::new_message_error(format!("failed to build http client {e}")))?;
149 Ok(client)
150 }
151
152 #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
153 fn add_tls_config(
154 builder: reqwest::ClientBuilder,
155 tls_config: TlsConfig,
156 ) -> reqwest::ClientBuilder {
157 builder.use_preconfigured_tls((*tls_config).clone())
159 }
160
161 #[cfg(feature = "native-tls")]
162 fn add_tls_config(
163 builder: reqwest::ClientBuilder,
164 tls_config: TlsConfig,
165 ) -> reqwest::ClientBuilder {
166 builder.use_preconfigured_tls(tls_config)
167 }
168}
169
170#[async_trait]
171impl Client for ReqwestClient {
172 async fn execute(&self, req: Request) -> HttpxResult<Response> {
173 let inner = self.inner.load();
174
175 let id = if let Some(unique_id) = req.unique_id {
176 unique_id
177 } else {
178 Uuid::new_v4().to_string()
179 };
180
181 trace!(
182 "Writing request on {} to {}. Method={}. Request id={}",
183 &self.client_id,
184 &req.uri,
185 &req.method,
186 &id
187 );
188
189 let mut builder = inner.request(req.method, req.uri);
190
191 if let Some(body) = req.body {
192 builder = builder.body(body);
193 }
194
195 if let Some(content_type) = req.content_type {
196 builder = builder.header(CONTENT_TYPE, content_type);
197 }
198
199 if let Some(user_agent) = req.user_agent {
200 builder = builder.header(USER_AGENT, user_agent);
201 }
202
203 if let Some(auth) = &req.auth {
204 match auth {
205 Auth::BasicAuth(basic) => {
206 builder = builder.basic_auth(&basic.username, Some(&basic.password))
207 }
208 Auth::BearerAuth(bearer) => builder = builder.bearer_auth(&bearer.token),
209 Auth::OnBehalfOf(obo) => {
210 match &obo.password_or_domain {
211 OboPasswordOrDomain::Password(password) => {
212 builder = builder.basic_auth(&obo.username, Some(password));
216 }
217 OboPasswordOrDomain::Domain(domain) => {
218 let obo_hdr_string =
220 BASE64_STANDARD.encode(format!("{}:{}", obo.username, domain));
221 builder = builder.header("cb-on-behalf-of", obo_hdr_string);
222 }
223 }
224 }
225 }
226 }
227
228 match builder.send().await {
229 Ok(response) => Ok({
230 trace!(
231 "Received response on {}. Request id={}. Status: {}",
232 &self.client_id,
233 &id,
234 response.status()
235 );
236 Response::from(response)
237 }),
238 Err(err) => {
239 let mut msg = format!(
240 "Received error on {}. Request id={}. Err: {}",
241 &self.client_id, &id, &err,
242 );
243
244 if let Some(source) = err.source() {
245 msg = format!("{msg}. Source: {source}");
246 }
247
248 trace!("{msg}");
249
250 if err.is_connect() {
251 Err(Error::new_connect_error(err.to_string()))
252 } else if err.is_request() {
253 Err(Error::new_request_error(err.to_string()))
254 } else {
255 Err(Error::new_message_error(err.to_string()))
256 }
257 }
258 }
259 }
260}
261
262impl Drop for ReqwestClient {
263 fn drop(&mut self) {
264 debug!("Dropping HTTP Client {}", &self.client_id);
265 }
266}