Skip to main content

braid_http/client/
fetch.rs

1//! Main Braid HTTP client implementation.
2
3use crate::client::config::ClientConfig;
4#[cfg(not(target_arch = "wasm32"))]
5use crate::client::native_network::NativeNetwork;
6#[cfg(target_arch = "wasm32")]
7use crate::client::wasm_network::WasmNetwork;
8use crate::error::{BraidError, Result};
9use crate::traits::BraidNetwork;
10use crate::types::{BraidRequest, BraidResponse};
11use std::sync::Arc;
12
13/// The main Braid HTTP client
14#[derive(Clone)]
15pub struct BraidClient {
16    #[cfg(not(target_arch = "wasm32"))]
17    pub network: Arc<NativeNetwork>,
18    #[cfg(target_arch = "wasm32")]
19    pub network: Arc<WasmNetwork>,
20    pub config: Arc<ClientConfig>,
21    /// Active multiplexers by origin.
22    #[cfg(not(target_arch = "wasm32"))]
23    pub multiplexers: Arc<
24        tokio::sync::Mutex<
25            std::collections::HashMap<String, Arc<crate::client::multiplex::Multiplexer>>,
26        >,
27    >,
28}
29
30impl BraidClient {
31    #[cfg(not(target_arch = "wasm32"))]
32    pub fn network(&self) -> &Arc<NativeNetwork> {
33        &self.network
34    }
35
36    #[cfg(target_arch = "wasm32")]
37    pub fn network(&self) -> &Arc<WasmNetwork> {
38        &self.network
39    }
40
41    #[cfg(not(target_arch = "wasm32"))]
42    pub fn client(&self) -> &reqwest::Client {
43        self.network.client()
44    }
45
46    pub fn new() -> Result<Self> {
47        Self::with_config(ClientConfig::default())
48    }
49
50    pub fn with_config(config: ClientConfig) -> Result<Self> {
51        #[cfg(not(target_arch = "wasm32"))]
52        {
53            let mut builder = reqwest::Client::builder()
54                .http1_only()
55                .timeout(std::time::Duration::from_millis(config.request_timeout_ms))
56                .pool_idle_timeout(std::time::Duration::from_secs(90))
57                .pool_max_idle_per_host(config.max_total_connections as usize);
58
59            if !config.proxy_url.is_empty() {
60                if let Ok(proxy) = reqwest::Proxy::all(&config.proxy_url) {
61                    builder = builder.proxy(proxy);
62                }
63            }
64
65            let client = builder
66                .user_agent("curl/7.81.0")
67                .build()
68                .map_err(|e| BraidError::Config(e.to_string()))?;
69            let network = Arc::new(NativeNetwork::new(client));
70
71            Ok(BraidClient {
72                network,
73                config: Arc::new(config),
74                multiplexers: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
75            })
76        }
77
78        #[cfg(target_arch = "wasm32")]
79        {
80            let network = Arc::new(WasmNetwork);
81            Ok(BraidClient {
82                network,
83                config: Arc::new(config),
84            })
85        }
86    }
87
88    #[cfg(not(target_arch = "wasm32"))]
89    pub fn with_client(client: reqwest::Client) -> Result<Self> {
90        Ok(BraidClient {
91            network: Arc::new(NativeNetwork::new(client)),
92            config: Arc::new(ClientConfig::default()),
93            multiplexers: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
94        })
95    }
96
97    pub async fn get(&self, url: &str) -> Result<BraidResponse> {
98        self.fetch(url, BraidRequest::new()).await
99    }
100
101    pub async fn put(
102        &self,
103        url: &str,
104        body: &str,
105        mut request: BraidRequest,
106    ) -> Result<BraidResponse> {
107        request = request.with_method("PUT").with_body(body.to_string());
108
109        if request.content_type.is_none() {
110            request = request.with_content_type("application/json");
111        }
112
113        if request.version.is_none() {
114            let random_version = uuid::Uuid::new_v4().to_string();
115            request.version = Some(vec![crate::types::Version::new(&random_version)]);
116        }
117
118        self.fetch(url, request).await
119    }
120
121    pub async fn post(
122        &self,
123        url: &str,
124        body: &str,
125        mut request: BraidRequest,
126    ) -> Result<BraidResponse> {
127        request = request.with_method("POST").with_body(body.to_string());
128        self.fetch(url, request).await
129    }
130
131    pub async fn poke(&self, recipient_endpoint: &str, post_url: &str) -> Result<BraidResponse> {
132        let request = BraidRequest::new()
133            .with_method("POST")
134            .with_body(post_url.to_string())
135            .with_content_type("text/plain");
136
137        self.fetch(recipient_endpoint, request).await
138    }
139
140    pub async fn fetch(&self, url: &str, request: BraidRequest) -> Result<BraidResponse> {
141        self.fetch_with_retries(url, request).await
142    }
143
144    pub async fn subscribe(
145        &self,
146        url: &str,
147        request: BraidRequest,
148    ) -> Result<crate::client::Subscription> {
149        self.log_request(url, &request);
150        let rx = self.network.subscribe(url, request).await?;
151        Ok(crate::client::Subscription::new(rx))
152    }
153
154    async fn fetch_with_retries(&self, url: &str, request: BraidRequest) -> Result<BraidResponse> {
155        let retry_config = request.retry.clone().unwrap_or_else(|| {
156            if self.config.max_retries == 0 {
157                crate::client::retry::RetryConfig::no_retry()
158            } else {
159                crate::client::retry::RetryConfig::default()
160                    .with_max_retries(self.config.max_retries)
161                    .with_initial_backoff(std::time::Duration::from_millis(
162                        self.config.retry_delay_ms,
163                    ))
164            }
165        });
166
167        let mut retry_state = crate::client::retry::RetryState::new(retry_config);
168
169        loop {
170            self.log_request(url, &request);
171
172            match self.fetch_internal(url, &request).await {
173                Ok(response) => {
174                    self.log_response(url, &response);
175
176                    let status = response.status;
177                    if (400..600).contains(&status) {
178                        let retry_after = response
179                            .headers
180                            .get("retry-after")
181                            .and_then(|v| crate::client::retry::parse_retry_after(v));
182
183                        match retry_state.should_retry_status(status, retry_after) {
184                            crate::client::retry::RetryDecision::Retry(delay) => {
185                                if self.config.enable_logging {
186                                    tracing::warn!(
187                                        "Request status {} (attempt {}), retrying in {:?}",
188                                        status,
189                                        retry_state.attempts,
190                                        delay
191                                    );
192                                }
193                                crate::client::utils::sleep(delay).await;
194                                continue;
195                            }
196                            crate::client::retry::RetryDecision::DontRetry => {
197                                return Ok(response);
198                            }
199                        }
200                    }
201                    retry_state.reset();
202                    return Ok(response);
203                }
204                Err(e) => {
205                    let is_abort = matches!(&e, BraidError::Aborted);
206
207                    match retry_state.should_retry_error(is_abort) {
208                        crate::client::retry::RetryDecision::Retry(delay) => {
209                            if self.config.enable_logging {
210                                tracing::warn!(
211                                    "Request failed (attempt {}), retrying in {:?}: {}",
212                                    retry_state.attempts,
213                                    delay,
214                                    e
215                                );
216                            }
217                            crate::client::utils::sleep(delay).await;
218                            continue;
219                        }
220                        crate::client::retry::RetryDecision::DontRetry => {
221                            return Err(e);
222                        }
223                    }
224                }
225            }
226        }
227    }
228
229    async fn fetch_internal(&self, url: &str, request: &BraidRequest) -> Result<BraidResponse> {
230        self.network.fetch(url, request.clone()).await
231    }
232
233    #[cfg(not(target_arch = "wasm32"))]
234    pub async fn fetch_multiplexed(
235        &self,
236        url: &str,
237        mut request: BraidRequest,
238    ) -> Result<BraidResponse> {
239        let origin = self.origin_from_url(url)?;
240
241        let mut multiplexers = self.multiplexers.lock().await;
242        let multiplexer = if let Some(m) = multiplexers.get(&origin) {
243            m.clone()
244        } else {
245            let multiplex_url = format!("{}/.multiplex", origin);
246            let m_id = format!("{:x}", rand::random::<u64>());
247            let m = Arc::new(crate::client::multiplex::Multiplexer::new(
248                origin.clone(),
249                m_id,
250            ));
251
252            let client = self.clone();
253            let m_inner = m.clone();
254            let origin_task = origin.clone();
255            crate::client::utils::spawn_task(async move {
256                let run_multiplex = async {
257                    let multiplex_method =
258                        reqwest::Method::from_bytes(b"MULTIPLEX").map_err(|e| {
259                            BraidError::Protocol(format!("Invalid multiplex method: {}", e))
260                        })?;
261                    let multiplex_header_name = reqwest::header::HeaderName::from_bytes(
262                        crate::protocol::constants::headers::MULTIPLEX_VERSION
263                            .as_str()
264                            .as_bytes(),
265                    )
266                    .map_err(|e| {
267                        BraidError::Protocol(format!("Invalid multiplex header: {}", e))
268                    })?;
269
270                    let resp = client
271                        .network
272                        .client()
273                        .request(multiplex_method, &multiplex_url)
274                        .header(multiplex_header_name, "1.0")
275                        .send()
276                        .await
277                        .map_err(|e| {
278                            BraidError::Http(format!(
279                                "Failed to establish multiplexed connection to {}: {}",
280                                multiplex_url, e
281                            ))
282                        })?;
283
284                    m_inner.run_stream(resp).await
285                };
286
287                if let Err(e) = run_multiplex.await {
288                    tracing::error!("Multiplexer task failed for {}: {}", origin_task, e);
289                }
290            });
291
292            multiplexers.insert(origin.clone(), m.clone());
293            m
294        };
295        drop(multiplexers);
296
297        let r_id = format!("{:x}", rand::random::<u32>());
298        let (tx, rx) = async_channel::bounded(100);
299        multiplexer.add_request(r_id.clone(), tx).await;
300
301        request.extra_headers.insert(
302            crate::protocol::constants::headers::MULTIPLEX_THROUGH.to_string(),
303            format!("/.well-known/multiplexer/{}/{}", multiplexer.id, r_id),
304        );
305
306        self.log_request(url, &request);
307        let initial_response = self.fetch_internal(url, &request).await?;
308        self.log_response(url, &initial_response);
309
310        if initial_response.status == 293 {
311            let mut response_buffer = Vec::new();
312            let mut headers_parsed = None;
313
314            while let Ok(chunk) = rx.recv().await {
315                response_buffer.extend_from_slice(&chunk);
316
317                if headers_parsed.is_none() {
318                    if let Ok((status, headers, body_start)) =
319                        crate::protocol::parse_tunneled_response(&response_buffer)
320                    {
321                        headers_parsed = Some((status, headers, body_start));
322                    }
323                }
324            }
325
326            if let Some((status, headers, body_start)) = headers_parsed {
327                let body = bytes::Bytes::copy_from_slice(&response_buffer[body_start..]);
328                return Ok(BraidResponse {
329                    status,
330                    headers,
331                    body,
332                    is_subscription: false,
333                });
334            } else {
335                return Err(crate::error::BraidError::Protocol(
336                    "Multiplexed response ended before headers received".to_string(),
337                ));
338            }
339        }
340
341        Ok(initial_response)
342    }
343
344    pub fn config(&self) -> &ClientConfig {
345        &self.config
346    }
347
348    fn log_request(&self, _url: &str, _request: &BraidRequest) {}
349
350    fn log_response(&self, _url: &str, _response: &BraidResponse) {}
351
352    fn origin_from_url(&self, url: &str) -> Result<String> {
353        let parsed_url = url::Url::parse(url).map_err(|e| BraidError::Config(e.to_string()))?;
354        Ok(format!(
355            "{}://{}",
356            parsed_url.scheme(),
357            parsed_url.host_str().unwrap_or("")
358        ))
359    }
360}
361
362impl Default for BraidClient {
363    fn default() -> Self {
364        Self::new().unwrap_or_else(|_| {
365            let network = Arc::new(NativeNetwork::new(reqwest::Client::new()));
366            BraidClient {
367                network,
368                config: Arc::new(ClientConfig::default()),
369                multiplexers: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
370            }
371        })
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BraidRequest;

    /// A default-configured client reports a retry budget of 3.
    #[test]
    fn test_client_init() {
        let max_retries = BraidClient::new().unwrap().config().max_retries;
        assert_eq!(max_retries, 3);
    }

    /// The origin of a plain `http` URL is its scheme plus host.
    #[test]
    fn test_origin_extraction() {
        let client = BraidClient::new().unwrap();
        let origin = client.origin_from_url("http://example.com/foo").unwrap();
        assert_eq!(origin, "http://example.com");
    }

    /// Mirrors the request preparation done by `put`, using a fixed version string.
    #[test]
    fn test_put_request_prep() {
        let mut req = BraidRequest::new()
            .with_method("PUT")
            .with_body("test".to_string());
        if req.content_type.is_none() {
            req = req.with_content_type("application/json");
        }
        if req.version.is_none() {
            req.version = Some(vec![crate::types::Version::new("test-version")]);
        }
        assert_eq!(req.method, "PUT");
        let versions = req.version.unwrap();
        assert_eq!(versions[0].to_string(), "test-version");
    }

    /// Mirrors the request built by `poke`: a POST whose body is the target URL.
    #[test]
    fn test_poke_request_prep() {
        let target = "http://example.com/post";
        let req = BraidRequest::new()
            .with_method("POST")
            .with_body(target)
            .with_content_type("text/plain");
        assert_eq!(req.method, "POST");
        assert_eq!(String::from_utf8_lossy(&req.body), target);
    }
}