1use crate::client::config::ClientConfig;
4#[cfg(not(target_arch = "wasm32"))]
5use crate::client::native_network::NativeNetwork;
6#[cfg(target_arch = "wasm32")]
7use crate::client::wasm_network::WasmNetwork;
8use crate::error::{BraidError, Result};
9use crate::traits::BraidNetwork;
10use crate::types::{BraidRequest, BraidResponse};
11use std::sync::Arc;
12
/// Client for issuing Braid requests through the platform's network backend.
///
/// Every field is wrapped in an `Arc`, so the derived `Clone` yields a handle
/// that shares the same backend, configuration and multiplexer map.
#[derive(Clone)]
pub struct BraidClient {
    /// Network backend used on every target except `wasm32`.
    #[cfg(not(target_arch = "wasm32"))]
    pub network: Arc<NativeNetwork>,
    /// Network backend used on `wasm32`.
    #[cfg(target_arch = "wasm32")]
    pub network: Arc<WasmNetwork>,
    /// Configuration this client was constructed with.
    pub config: Arc<ClientConfig>,
    /// Multiplexers created by `fetch_multiplexed`, keyed by URL origin.
    /// Only present on non-`wasm32` targets.
    #[cfg(not(target_arch = "wasm32"))]
    pub multiplexers: Arc<
        tokio::sync::Mutex<
            std::collections::HashMap<String, Arc<crate::client::multiplex::Multiplexer>>,
        >,
    >,
}
29
30impl BraidClient {
    /// Returns the shared native network backend (non-`wasm32` targets).
    #[cfg(not(target_arch = "wasm32"))]
    pub fn network(&self) -> &Arc<NativeNetwork> {
        &self.network
    }
35
    /// Returns the shared wasm network backend (`wasm32` targets).
    #[cfg(target_arch = "wasm32")]
    pub fn network(&self) -> &Arc<WasmNetwork> {
        &self.network
    }
40
    /// Returns the `reqwest::Client` exposed by the native network backend.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn client(&self) -> &reqwest::Client {
        self.network.client()
    }
45
    /// Creates a client using `ClientConfig::default()`.
    ///
    /// # Errors
    /// Propagates any error returned by [`BraidClient::with_config`].
    pub fn new() -> Result<Self> {
        Self::with_config(ClientConfig::default())
    }
49
50 pub fn with_config(config: ClientConfig) -> Result<Self> {
51 #[cfg(not(target_arch = "wasm32"))]
52 {
53 let mut builder = reqwest::Client::builder()
54 .http1_only()
55 .timeout(std::time::Duration::from_millis(config.request_timeout_ms))
56 .pool_idle_timeout(std::time::Duration::from_secs(90))
57 .pool_max_idle_per_host(config.max_total_connections as usize);
58
59 if !config.proxy_url.is_empty() {
60 if let Ok(proxy) = reqwest::Proxy::all(&config.proxy_url) {
61 builder = builder.proxy(proxy);
62 }
63 }
64
65 let client = builder
66 .user_agent("curl/7.81.0")
67 .build()
68 .map_err(|e| BraidError::Config(e.to_string()))?;
69 let network = Arc::new(NativeNetwork::new(client));
70
71 Ok(BraidClient {
72 network,
73 config: Arc::new(config),
74 multiplexers: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
75 })
76 }
77
78 #[cfg(target_arch = "wasm32")]
79 {
80 let network = Arc::new(WasmNetwork);
81 Ok(BraidClient {
82 network,
83 config: Arc::new(config),
84 })
85 }
86 }
87
88 #[cfg(not(target_arch = "wasm32"))]
89 pub fn with_client(client: reqwest::Client) -> Result<Self> {
90 Ok(BraidClient {
91 network: Arc::new(NativeNetwork::new(client)),
92 config: Arc::new(ClientConfig::default()),
93 multiplexers: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
94 })
95 }
96
97 pub async fn get(&self, url: &str) -> Result<BraidResponse> {
98 self.fetch(url, BraidRequest::new()).await
99 }
100
101 pub async fn put(
102 &self,
103 url: &str,
104 body: &str,
105 mut request: BraidRequest,
106 ) -> Result<BraidResponse> {
107 request = request.with_method("PUT").with_body(body.to_string());
108
109 if request.content_type.is_none() {
110 request = request.with_content_type("application/json");
111 }
112
113 if request.version.is_none() {
114 let random_version = uuid::Uuid::new_v4().to_string();
115 request.version = Some(vec![crate::types::Version::new(&random_version)]);
116 }
117
118 self.fetch(url, request).await
119 }
120
121 pub async fn post(
122 &self,
123 url: &str,
124 body: &str,
125 mut request: BraidRequest,
126 ) -> Result<BraidResponse> {
127 request = request.with_method("POST").with_body(body.to_string());
128 self.fetch(url, request).await
129 }
130
131 pub async fn poke(&self, recipient_endpoint: &str, post_url: &str) -> Result<BraidResponse> {
132 let request = BraidRequest::new()
133 .with_method("POST")
134 .with_body(post_url.to_string())
135 .with_content_type("text/plain");
136
137 self.fetch(recipient_endpoint, request).await
138 }
139
    /// Sends `request` to `url`, delegating to `fetch_with_retries` so the
    /// client's retry policy is applied.
    pub async fn fetch(&self, url: &str, request: BraidRequest) -> Result<BraidResponse> {
        self.fetch_with_retries(url, request).await
    }
143
144 pub async fn subscribe(
145 &self,
146 url: &str,
147 request: BraidRequest,
148 ) -> Result<crate::client::Subscription> {
149 self.log_request(url, &request);
150 let rx = self.network.subscribe(url, request).await?;
151 Ok(crate::client::Subscription::new(rx))
152 }
153
154 async fn fetch_with_retries(&self, url: &str, request: BraidRequest) -> Result<BraidResponse> {
155 let retry_config = request.retry.clone().unwrap_or_else(|| {
156 if self.config.max_retries == 0 {
157 crate::client::retry::RetryConfig::no_retry()
158 } else {
159 crate::client::retry::RetryConfig::default()
160 .with_max_retries(self.config.max_retries)
161 .with_initial_backoff(std::time::Duration::from_millis(
162 self.config.retry_delay_ms,
163 ))
164 }
165 });
166
167 let mut retry_state = crate::client::retry::RetryState::new(retry_config);
168
169 loop {
170 self.log_request(url, &request);
171
172 match self.fetch_internal(url, &request).await {
173 Ok(response) => {
174 self.log_response(url, &response);
175
176 let status = response.status;
177 if (400..600).contains(&status) {
178 let retry_after = response
179 .headers
180 .get("retry-after")
181 .and_then(|v| crate::client::retry::parse_retry_after(v));
182
183 match retry_state.should_retry_status(status, retry_after) {
184 crate::client::retry::RetryDecision::Retry(delay) => {
185 if self.config.enable_logging {
186 tracing::warn!(
187 "Request status {} (attempt {}), retrying in {:?}",
188 status,
189 retry_state.attempts,
190 delay
191 );
192 }
193 crate::client::utils::sleep(delay).await;
194 continue;
195 }
196 crate::client::retry::RetryDecision::DontRetry => {
197 return Ok(response);
198 }
199 }
200 }
201 retry_state.reset();
202 return Ok(response);
203 }
204 Err(e) => {
205 let is_abort = matches!(&e, BraidError::Aborted);
206
207 match retry_state.should_retry_error(is_abort) {
208 crate::client::retry::RetryDecision::Retry(delay) => {
209 if self.config.enable_logging {
210 tracing::warn!(
211 "Request failed (attempt {}), retrying in {:?}: {}",
212 retry_state.attempts,
213 delay,
214 e
215 );
216 }
217 crate::client::utils::sleep(delay).await;
218 continue;
219 }
220 crate::client::retry::RetryDecision::DontRetry => {
221 return Err(e);
222 }
223 }
224 }
225 }
226 }
227 }
228
    /// Performs a single fetch through the network backend, without retries.
    ///
    /// The request is cloned so the caller can reuse it for another attempt.
    async fn fetch_internal(&self, url: &str, request: &BraidRequest) -> Result<BraidResponse> {
        self.network.fetch(url, request.clone()).await
    }
232
233 #[cfg(not(target_arch = "wasm32"))]
234 pub async fn fetch_multiplexed(
235 &self,
236 url: &str,
237 mut request: BraidRequest,
238 ) -> Result<BraidResponse> {
239 let origin = self.origin_from_url(url)?;
240
241 let mut multiplexers = self.multiplexers.lock().await;
242 let multiplexer = if let Some(m) = multiplexers.get(&origin) {
243 m.clone()
244 } else {
245 let multiplex_url = format!("{}/.multiplex", origin);
246 let m_id = format!("{:x}", rand::random::<u64>());
247 let m = Arc::new(crate::client::multiplex::Multiplexer::new(
248 origin.clone(),
249 m_id,
250 ));
251
252 let client = self.clone();
253 let m_inner = m.clone();
254 let origin_task = origin.clone();
255 crate::client::utils::spawn_task(async move {
256 let run_multiplex = async {
257 let multiplex_method =
258 reqwest::Method::from_bytes(b"MULTIPLEX").map_err(|e| {
259 BraidError::Protocol(format!("Invalid multiplex method: {}", e))
260 })?;
261 let multiplex_header_name = reqwest::header::HeaderName::from_bytes(
262 crate::protocol::constants::headers::MULTIPLEX_VERSION
263 .as_str()
264 .as_bytes(),
265 )
266 .map_err(|e| {
267 BraidError::Protocol(format!("Invalid multiplex header: {}", e))
268 })?;
269
270 let resp = client
271 .network
272 .client()
273 .request(multiplex_method, &multiplex_url)
274 .header(multiplex_header_name, "1.0")
275 .send()
276 .await
277 .map_err(|e| {
278 BraidError::Http(format!(
279 "Failed to establish multiplexed connection to {}: {}",
280 multiplex_url, e
281 ))
282 })?;
283
284 m_inner.run_stream(resp).await
285 };
286
287 if let Err(e) = run_multiplex.await {
288 tracing::error!("Multiplexer task failed for {}: {}", origin_task, e);
289 }
290 });
291
292 multiplexers.insert(origin.clone(), m.clone());
293 m
294 };
295 drop(multiplexers);
296
297 let r_id = format!("{:x}", rand::random::<u32>());
298 let (tx, rx) = async_channel::bounded(100);
299 multiplexer.add_request(r_id.clone(), tx).await;
300
301 request.extra_headers.insert(
302 crate::protocol::constants::headers::MULTIPLEX_THROUGH.to_string(),
303 format!("/.well-known/multiplexer/{}/{}", multiplexer.id, r_id),
304 );
305
306 self.log_request(url, &request);
307 let initial_response = self.fetch_internal(url, &request).await?;
308 self.log_response(url, &initial_response);
309
310 if initial_response.status == 293 {
311 let mut response_buffer = Vec::new();
312 let mut headers_parsed = None;
313
314 while let Ok(chunk) = rx.recv().await {
315 response_buffer.extend_from_slice(&chunk);
316
317 if headers_parsed.is_none() {
318 if let Ok((status, headers, body_start)) =
319 crate::protocol::parse_tunneled_response(&response_buffer)
320 {
321 headers_parsed = Some((status, headers, body_start));
322 }
323 }
324 }
325
326 if let Some((status, headers, body_start)) = headers_parsed {
327 let body = bytes::Bytes::copy_from_slice(&response_buffer[body_start..]);
328 return Ok(BraidResponse {
329 status,
330 headers,
331 body,
332 is_subscription: false,
333 });
334 } else {
335 return Err(crate::error::BraidError::Protocol(
336 "Multiplexed response ended before headers received".to_string(),
337 ));
338 }
339 }
340
341 Ok(initial_response)
342 }
343
    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }
347
    /// Hook called before a request is sent; currently a no-op.
    fn log_request(&self, _url: &str, _request: &BraidRequest) {}
349
    /// Hook called after a response is received; currently a no-op.
    fn log_response(&self, _url: &str, _response: &BraidResponse) {}
351
352 fn origin_from_url(&self, url: &str) -> Result<String> {
353 let parsed_url = url::Url::parse(url).map_err(|e| BraidError::Config(e.to_string()))?;
354 Ok(format!(
355 "{}://{}",
356 parsed_url.scheme(),
357 parsed_url.host_str().unwrap_or("")
358 ))
359 }
360}
361
362impl Default for BraidClient {
363 fn default() -> Self {
364 Self::new().unwrap_or_else(|_| {
365 let network = Arc::new(NativeNetwork::new(reqwest::Client::new()));
366 BraidClient {
367 network,
368 config: Arc::new(ClientConfig::default()),
369 multiplexers: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
370 }
371 })
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BraidRequest;

    /// A client built with defaults exposes the default retry count.
    #[test]
    fn test_client_init() {
        let built = BraidClient::new().unwrap();
        let max_retries = built.config().max_retries;
        assert_eq!(max_retries, 3);
    }

    /// The origin of a URL with a path is just its scheme and host.
    #[test]
    fn test_origin_extraction() {
        let built = BraidClient::new().unwrap();
        let origin = built.origin_from_url("http://example.com/foo").unwrap();
        assert_eq!(origin, "http://example.com");
    }

    /// Mirrors the request preparation that `put` performs.
    #[test]
    fn test_put_request_prep() {
        let prepared = BraidRequest::new()
            .with_method("PUT")
            .with_body("test".to_string());
        let mut prepared = if prepared.content_type.is_some() {
            prepared
        } else {
            prepared.with_content_type("application/json")
        };
        if prepared.version.is_none() {
            prepared.version = Some(vec![crate::types::Version::new("test-version")]);
        }

        assert_eq!(prepared.method, "PUT");
        let versions = prepared.version.unwrap();
        assert_eq!(versions[0].to_string(), "test-version");
    }

    /// Mirrors the request that `poke` builds.
    #[test]
    fn test_poke_request_prep() {
        let notification = BraidRequest::new()
            .with_method("POST")
            .with_body("http://example.com/post")
            .with_content_type("text/plain");

        assert_eq!(notification.method, "POST");
        let body_text = String::from_utf8_lossy(&notification.body);
        assert_eq!(body_text, "http://example.com/post");
    }
}