// braid_http/client/native_network.rs

use std::collections::BTreeMap;

use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;

use crate::client::parser::MessageParser;
use crate::error::{BraidError, Result};
use crate::protocol;
use crate::traits::BraidNetwork;
use crate::types::{BraidRequest, BraidResponse, Update};

10pub struct NativeNetwork {
11    client: Client,
12}
13
14impl NativeNetwork {
15    pub fn new(client: Client) -> Self {
16        Self { client }
17    }
18
19    pub fn client(&self) -> &Client {
20        &self.client
21    }
22}
23
24#[async_trait]
25impl BraidNetwork for NativeNetwork {
26    async fn fetch(&self, url: &str, request: BraidRequest) -> Result<BraidResponse> {
27        let method = match request.method.to_uppercase().as_str() {
28            "POST" => reqwest::Method::POST,
29            "PUT" => reqwest::Method::PUT,
30            "DELETE" => reqwest::Method::DELETE,
31            "PATCH" => reqwest::Method::PATCH,
32            _ => reqwest::Method::GET,
33        };
34
35        let mut req_builder = self.client.request(method.clone(), url);
36
37        for (k, v) in &request.extra_headers {
38            req_builder = req_builder.header(k, v);
39        }
40
41        if !request.body.is_empty() {
42            let ct = request
43                .content_type
44                .as_deref()
45                .unwrap_or("application/json");
46            req_builder = req_builder.header(reqwest::header::CONTENT_TYPE, ct);
47            req_builder = req_builder.body(request.body.clone());
48        }
49
50        if let Some(versions) = &request.version {
51            req_builder = req_builder.header("Version", protocol::format_version_header(versions));
52        }
53        if let Some(parents) = &request.parents {
54            req_builder = req_builder.header("Parents", protocol::format_version_header(parents));
55        }
56        if request.subscribe {
57            req_builder = req_builder.header("subscribe", "true");
58        }
59        if let Some(peer) = &request.peer {
60            req_builder = req_builder.header("Peer", format!("\"{}\"", peer));
61        }
62        if let Some(merge_type) = &request.merge_type {
63            req_builder = req_builder.header("merge-type", merge_type);
64        }
65
66        tracing::info!(
67            "[BraidHTTP-Out] {} {} headers: {:?}",
68            method,
69            url,
70            request.extra_headers
71        );
72
73        let response = req_builder
74            .send()
75            .await
76            .map_err(|e| BraidError::Http(e.to_string()))?;
77
78        let status = response.status().as_u16();
79        let mut headers = std::collections::BTreeMap::new();
80        for (k, v) in response.headers() {
81            if let Ok(val) = v.to_str() {
82                headers.insert(k.as_str().to_string(), val.to_string());
83            }
84        }
85
86        let body = response
87            .bytes()
88            .await
89            .map_err(|e| BraidError::Http(e.to_string()))?;
90
91        Ok(BraidResponse {
92            status,
93            headers,
94            body,
95            is_subscription: status == 209,
96        })
97    }
98
99    async fn subscribe(
100        &self,
101        url: &str,
102        mut request: BraidRequest,
103    ) -> Result<async_channel::Receiver<Result<Update>>> {
104        request.subscribe = true;
105        let mut req_builder = self.client.get(url).header("subscribe", "true");
106
107        for (k, v) in &request.extra_headers {
108            req_builder = req_builder.header(k, v);
109        }
110
111        if let Some(versions) = &request.version {
112            req_builder = req_builder.header("Version", protocol::format_version_header(versions));
113        }
114
115        if let Some(parents) = &request.parents {
116            req_builder = req_builder.header("Parents", protocol::format_version_header(parents));
117        }
118
119        if let Some(peer) = &request.peer {
120            req_builder = req_builder.header("Peer", format!("\"{}\"", peer));
121        }
122
123        if let Some(merge_type) = &request.merge_type {
124            req_builder = req_builder.header("merge-type", merge_type);
125        }
126
127        tracing::info!(
128            "[BraidHTTP-Sub-Out] GET {} headers: {:?}",
129            url,
130            request.extra_headers
131        );
132
133        let response = req_builder
134            .send()
135            .await
136            .map_err(|e| BraidError::Http(e.to_string()))?;
137
138        let mut headers = std::collections::BTreeMap::new();
139        for (k, v) in response.headers() {
140            if let Ok(val) = v.to_str() {
141                headers.insert(k.as_str().to_lowercase(), val.to_string());
142            }
143        }
144
145        tracing::debug!(
146            "[BraidRequest] Response headers (normalized): {:?}",
147            headers
148        );
149
150        let mut content_length = response.content_length().unwrap_or(0) as usize;
151
152        if content_length == 0 {
153            if let Some(range) = headers.get("content-range") {
154                // Parse Content-Range: unit start-end/total
155                // e.g. "text 0-4455/4455"
156                let parts: Vec<&str> = range.split_whitespace().collect();
157                if parts.len() >= 2 {
158                    if let Some(range_part) = parts.get(1) {
159                        if let Some((start, end)) = range_part.split_once('-') {
160                            if let (Ok(s), Ok(e)) = (
161                                start.parse::<usize>(),
162                                end.split('/').next().unwrap_or("").parse::<usize>(),
163                            ) {
164                                // content_length = e - s; // Redundant assignment fixed below
165                                // Wait, HTTP Content-Range is inclusive: "0-499" means 500 bytes.
166                                // "0-4455/4455"? If total is 4455, bytes are 0-4454.
167                                // If string is "0-4455", it might be start-seq?
168                                // Let's re-read the curl output: "content-range: text 0-4455/4455"
169                                // If total is 4455.
170                                // Usually Content-Range is bytes start-end/total.
171                                // If it is 0-4455... that's 4456 bytes?
172                                // But let's look at `parser.rs` logic for Content-Range.
173                                // It just grabs the unit.
174                                // Wait, Braid `Content-Range` might be different for text?
175                                // Let's assume it works like HTTP.
176                                // Safe bet: if total is there, use total?
177                                // No, valid is end - start.
178                                // Actually, let's just use the length from the part after / if present?
179                                // Or better, let's look at the `dt.js` or `parser.rs`?
180                                // `parser.rs` doesn't parse Content-Range for body length, only for patches.
181                                // It uses `expected_body_length`.
182
183                                // Let's trust "content-length" header if present.
184                                // If not, use the diff.
185                                // HTTP Range: start-end. Length = end - start + 1.
186                                content_length = e - s;
187                            }
188                        }
189                    }
190                }
191            }
192        }
193
194        let (tx, rx) = async_channel::bounded(100);
195        let mut stream = response.bytes_stream();
196
197        tokio::spawn(async move {
198            // Initialize parser with the HTTP headers and content-length
199            // so it can parse the first message (snapshot) correctly
200            let mut parser = MessageParser::new_with_state(headers, content_length);
201
202            while let Some(chunk_res) = stream.next().await {
203                match chunk_res {
204                    Ok(chunk) => {
205                        if let Ok(messages) = parser.feed(&chunk) {
206                            for msg in messages {
207                                let update = crate::client::utils::message_to_update(msg);
208                                let _ = tx.send(Ok(update)).await;
209                            }
210                        }
211                    }
212                    Err(e) => {
213                        let _ = tx.send(Err(BraidError::Http(e.to_string()))).await;
214                        break;
215                    }
216                }
217            }
218        });
219
220        Ok(rx)
221    }
222}