Skip to main content

braid_http/client/
parser.rs

1//! Message parser for Braid protocol streaming.
2
3use crate::error::{BraidError, Result};
4use crate::types::Patch;
5use bytes::{Buf, Bytes, BytesMut};
6use once_cell::sync::Lazy;
7use regex::Regex;
8use std::collections::BTreeMap;
9
/// Phase of the incremental message parser's state machine.
///
/// In the code visible here, `MessageParser::feed` only moves between
/// `WaitingForHeaders`, `WaitingForBody`, `WaitingForPatchHeaders`,
/// `WaitingForPatchBody` and `SkippingSeparator`. It stops processing as soon
/// as it observes any of the other variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseState {
    /// Between messages: waiting for a header block (or an encoding block).
    /// This is the state a freshly created parser starts in.
    WaitingForHeaders,
    /// Not assigned by the parser code visible here.
    ParsingHeaders,
    /// Headers are known; waiting for the body, or about to switch to patch
    /// parsing when the message announced patches.
    WaitingForBody,
    /// Waiting for the header block of the next patch.
    WaitingForPatchHeaders,
    /// Waiting for the body of the patch whose headers were just read.
    WaitingForPatchBody,
    /// Consuming the line break that separates two consecutive patches.
    SkippingSeparator,
    /// Not assigned by the parser code visible here.
    Complete,
    /// Not assigned by the parser code visible here.
    Error,
}
21
22#[derive(Debug)]
23pub struct MessageParser {
24    buffer: BytesMut,
25    state: ParseState,
26    headers: BTreeMap<String, String>,
27    body_buffer: BytesMut,
28    expected_body_length: usize,
29    read_body_length: usize,
30    patches: Vec<Patch>,
31    expected_patches: usize,
32    patches_read: usize,
33    patch_headers: BTreeMap<String, String>,
34    expected_patch_length: usize,
35    read_patch_length: usize,
36    is_encoding_block: bool,
37}
38
39static HTTP_STATUS_REGEX: Lazy<Regex> =
40    Lazy::new(|| Regex::new(r"^HTTP/?\d*\.?\d* (\d{3})").unwrap());
41
42static ENCODING_BLOCK_REGEX: Lazy<Regex> =
43    Lazy::new(|| Regex::new(r"(?i)Encoding:\s*(\w+)\r?\nLength:\s*(\d+)\r?\n").unwrap());
44
45impl MessageParser {
46    pub fn new() -> Self {
47        MessageParser {
48            buffer: BytesMut::with_capacity(8192),
49            state: ParseState::WaitingForHeaders,
50            headers: BTreeMap::new(),
51            body_buffer: BytesMut::new(),
52            expected_body_length: 0,
53            read_body_length: 0,
54            patches: Vec::new(),
55            expected_patches: 0,
56            patches_read: 0,
57            patch_headers: BTreeMap::new(),
58            expected_patch_length: 0,
59            read_patch_length: 0,
60            is_encoding_block: false,
61        }
62    }
63
64    pub fn new_with_state(headers: BTreeMap<String, String>, content_length: usize) -> Self {
65        let mut parser = MessageParser::new();
66        parser.headers = headers;
67        parser.expected_body_length = content_length;
68        if content_length > 0 {
69            parser.state = ParseState::WaitingForBody;
70        } else {
71            // If explicit 0 length, we might have a message ready effectively?
72            // But usually we wait for body. If 0, try_parse_body handles it.
73            parser.state = ParseState::WaitingForBody;
74        }
75        parser
76    }
77
78    pub fn feed(&mut self, data: &[u8]) -> Result<Vec<Message>> {
79        self.buffer.extend_from_slice(data);
80        let mut messages = Vec::new();
81
82        loop {
83            match self.state {
84                ParseState::WaitingForHeaders => {
85                    while !self.buffer.is_empty()
86                        && (self.buffer[0] == b'\r' || self.buffer[0] == b'\n')
87                    {
88                        self.buffer.advance(1);
89                    }
90
91                    if self.buffer.is_empty() {
92                        break;
93                    }
94
95                    if self.check_encoding_block()? {
96                        self.state = ParseState::WaitingForBody;
97                        continue;
98                    }
99
100                    if let Some(pos) = self.find_header_end() {
101                        self.parse_headers(pos)?;
102                        self.state = ParseState::WaitingForBody;
103                    } else {
104                        break;
105                    }
106                }
107                ParseState::WaitingForBody => {
108                    if self.expected_patches > 0 {
109                        self.state = ParseState::WaitingForPatchHeaders;
110                    } else if self.try_parse_body()? {
111                        if let Some(msg) = self.finalize_message()? {
112                            messages.push(msg);
113                        }
114                        self.reset();
115                        self.state = ParseState::WaitingForHeaders;
116                    } else {
117                        break;
118                    }
119                }
120                ParseState::WaitingForPatchHeaders => {
121                    if let Some(pos) = self.find_header_end() {
122                        self.parse_patch_headers(pos)?;
123                        self.state = ParseState::WaitingForPatchBody;
124                    } else {
125                        break;
126                    }
127                }
128                ParseState::WaitingForPatchBody => {
129                    if self.try_parse_patch_body()? {
130                        self.patches_read += 1;
131                        if self.patches_read < self.expected_patches {
132                            self.state = ParseState::SkippingSeparator;
133                        } else {
134                            if let Some(msg) = self.finalize_message()? {
135                                messages.push(msg);
136                            }
137                            self.reset();
138                            self.state = ParseState::WaitingForHeaders;
139                        }
140                    } else {
141                        break;
142                    }
143                }
144                ParseState::SkippingSeparator => {
145                    if self.buffer.len() >= 2 {
146                        if &self.buffer[..2] == b"\r\n" {
147                            self.buffer.advance(2);
148                        } else if self.buffer[0] == b'\n' {
149                            self.buffer.advance(1);
150                        }
151                        self.state = ParseState::WaitingForPatchHeaders;
152                    } else if self.buffer.len() == 1 && self.buffer[0] == b'\n' {
153                        self.buffer.advance(1);
154                        self.state = ParseState::WaitingForPatchHeaders;
155                    } else {
156                        break;
157                    }
158                }
159                _ => break,
160            }
161        }
162        Ok(messages)
163    }
164
165    fn check_encoding_block(&mut self) -> Result<bool> {
166        if self.buffer.is_empty() || (self.buffer[0] != b'E' && self.buffer[0] != b'e') {
167            return Ok(false);
168        }
169
170        if let Some(end) = self.find_double_newline() {
171            let header_bytes = &self.buffer[..end];
172            let header_str = std::str::from_utf8(header_bytes).map_err(|e| {
173                BraidError::Protocol(format!("Invalid encoding block UTF-8: {}", e))
174            })?;
175
176            if let Some(caps) = ENCODING_BLOCK_REGEX.captures(header_str) {
177                let encoding = caps.get(1).unwrap().as_str().to_string();
178                let length: usize = caps.get(2).unwrap().as_str().parse().map_err(|_| {
179                    BraidError::Protocol("Invalid length in encoding block".to_string())
180                })?;
181
182                let _ = self.buffer.split_to(end);
183                self.headers.insert("encoding".to_string(), encoding);
184                self.headers
185                    .insert("length".to_string(), length.to_string());
186                self.expected_body_length = length;
187                self.is_encoding_block = true;
188                return Ok(true);
189            }
190        }
191        Ok(false)
192    }
193
194    fn find_double_newline(&self) -> Option<usize> {
195        if let Some(pos) = self.buffer.windows(4).position(|w| w == b"\r\n\r\n") {
196            return Some(pos + 4);
197        }
198        if let Some(pos) = self.buffer.windows(2).position(|w| w == b"\n\n") {
199            return Some(pos + 2);
200        }
201        None
202    }
203
204    fn find_header_end(&self) -> Option<usize> {
205        self.buffer
206            .windows(4)
207            .position(|w| w == b"\r\n\r\n")
208            .map(|p| p + 4)
209    }
210
211    fn parse_headers(&mut self, end: usize) -> Result<()> {
212        let header_bytes = self.buffer.split_to(end);
213        let mut header_str = String::from_utf8(header_bytes[..header_bytes.len() - 4].to_vec())?;
214
215        if let Some(caps) = HTTP_STATUS_REGEX.captures(&header_str) {
216            if let Some(status_match) = caps.get(1) {
217                let status = status_match.as_str();
218                if let Some(first_newline) = header_str.find('\n') {
219                    let replacement = format!(":status: {}\r", status);
220                    header_str = replacement + &header_str[first_newline..];
221                }
222            }
223        }
224
225        for line in header_str.lines() {
226            if let Some(colon_pos) = line.find(':') {
227                let key = line[..colon_pos].trim().to_lowercase();
228                let value = line[colon_pos + 1..].trim().to_string();
229                self.headers.insert(key, value);
230            }
231        }
232
233        if let Some(patches_str) = self.headers.get("patches") {
234            self.expected_patches = patches_str.parse().unwrap_or(0);
235        }
236
237        if let Some(len_str) = self
238            .headers
239            .get("content-length")
240            .or_else(|| self.headers.get("length"))
241        {
242            self.expected_body_length = len_str.parse().map_err(|_| {
243                BraidError::HeaderParse(format!("Invalid content-length: {}", len_str))
244            })?;
245        }
246        Ok(())
247    }
248
249    fn parse_patch_headers(&mut self, end: usize) -> Result<()> {
250        let header_bytes = self.buffer.split_to(end);
251        let header_str = String::from_utf8(header_bytes[..header_bytes.len() - 4].to_vec())?;
252
253        self.patch_headers.clear();
254        for line in header_str.lines() {
255            if let Some(colon_pos) = line.find(':') {
256                let key = line[..colon_pos].trim().to_lowercase();
257                let value = line[colon_pos + 1..].trim().to_string();
258                self.patch_headers.insert(key, value);
259            }
260        }
261
262        if let Some(len_str) = self.patch_headers.get("content-length") {
263            self.expected_patch_length = len_str.parse().map_err(|_| {
264                BraidError::HeaderParse(format!("Invalid patch content-length: {}", len_str))
265            })?;
266        } else {
267            return Err(BraidError::Protocol(
268                "Every patch MUST include Content-Length".to_string(),
269            ));
270        }
271
272        self.read_patch_length = 0;
273        Ok(())
274    }
275
276    fn try_parse_patch_body(&mut self) -> Result<bool> {
277        let remaining = self.expected_patch_length - self.read_patch_length;
278        if self.buffer.len() >= remaining {
279            let body_chunk = self.buffer.split_to(remaining);
280            let unit = self
281                .patch_headers
282                .get("content-range")
283                .and_then(|cr| cr.split_whitespace().next())
284                .unwrap_or("bytes")
285                .to_string();
286            let range = self
287                .patch_headers
288                .get("content-range")
289                .and_then(|cr| cr.split_whitespace().nth(1))
290                .unwrap_or("")
291                .to_string();
292            let patch = Patch::with_length(unit, range, body_chunk, self.expected_patch_length);
293            self.patches.push(patch);
294            self.read_patch_length += remaining;
295            Ok(true)
296        } else {
297            Ok(false)
298        }
299    }
300
301    fn try_parse_body(&mut self) -> Result<bool> {
302        if self.expected_body_length == 0 {
303            return Ok(true);
304        }
305        let remaining = self.expected_body_length - self.read_body_length;
306        if self.buffer.len() >= remaining {
307            let body_chunk = self.buffer.split_to(remaining);
308            self.body_buffer.extend_from_slice(&body_chunk);
309            self.read_body_length += body_chunk.len();
310            Ok(true)
311        } else {
312            let chunk_len = self.buffer.len();
313            self.body_buffer
314                .extend_from_slice(&self.buffer.split_to(chunk_len));
315            self.read_body_length += chunk_len;
316            Ok(false)
317        }
318    }
319
320    fn finalize_message(&mut self) -> Result<Option<Message>> {
321        let body = self.body_buffer.split().freeze();
322        let headers = std::mem::take(&mut self.headers);
323        let url = headers.get("content-location").cloned();
324        let encoding = headers.get("encoding").cloned();
325
326        Ok(Some(Message {
327            headers,
328            body,
329            patches: std::mem::take(&mut self.patches),
330            status_code: None,
331            encoding,
332            url,
333        }))
334    }
335
336    fn reset(&mut self) {
337        self.headers.clear();
338        self.body_buffer.clear();
339        self.expected_body_length = 0;
340        self.read_body_length = 0;
341        self.patches.clear();
342        self.expected_patches = 0;
343        self.patches_read = 0;
344        self.patch_headers.clear();
345        self.expected_patch_length = 0;
346        self.read_patch_length = 0;
347        self.is_encoding_block = false;
348    }
349
350    pub fn state(&self) -> ParseState {
351        self.state
352    }
353    pub fn headers(&self) -> &BTreeMap<String, String> {
354        &self.headers
355    }
356    pub fn body(&self) -> &[u8] {
357        &self.body_buffer
358    }
359}
360
361impl Default for MessageParser {
362    fn default() -> Self {
363        Self::new()
364    }
365}
366
367#[derive(Debug, Clone)]
368pub struct Message {
369    pub headers: BTreeMap<String, String>,
370    pub body: Bytes,
371    pub patches: Vec<Patch>,
372    pub status_code: Option<u16>,
373    pub encoding: Option<String>,
374    pub url: Option<String>,
375}
376
377impl Message {
378    pub fn status(&self) -> Option<u16> {
379        self.status_code
380            .or_else(|| self.headers.get(":status").and_then(|v| v.parse().ok()))
381    }
382
383    pub fn version(&self) -> Option<&str> {
384        self.headers.get("version").map(|s| s.as_str())
385    }
386    pub fn current_version(&self) -> Option<&str> {
387        self.headers.get("current-version").map(|s| s.as_str())
388    }
389    pub fn parents(&self) -> Option<&str> {
390        self.headers.get("parents").map(|s| s.as_str())
391    }
392
393    pub fn decode_body(&self) -> Result<Bytes> {
394        match self.encoding.as_deref() {
395            Some("dt") => Ok(self.body.clone()),
396            Some(enc) => Err(BraidError::Protocol(format!("Unknown encoding: {}", enc))),
397            None => Ok(self.body.clone()),
398        }
399    }
400
401    pub fn extra_headers(&self) -> BTreeMap<String, String> {
402        const KNOWN_HEADERS: &[&str] = &[
403            "version",
404            "parents",
405            "current-version",
406            "patches",
407            "content-length",
408            "content-range",
409            ":status",
410        ];
411        self.headers
412            .iter()
413            .filter(|(k, _)| !KNOWN_HEADERS.contains(&k.as_str()))
414            .map(|(k, v)| (k.clone(), v.clone()))
415            .collect()
416    }
417
418    pub fn body_text(&self) -> Option<String> {
419        std::str::from_utf8(&self.body).ok().map(|s| s.to_string())
420    }
421}
422
/// Extracts the status code from an HTTP status line such as
/// `HTTP/1.1 200 OK`, `HTTP 209 Subscription` or `HTTP/2 404`.
///
/// The first whitespace-separated token must start with `HTTP`
/// (case-insensitive) and the second must be exactly three ASCII digits,
/// matching what `HTTP_STATUS_REGEX` accepts. Anything else, including signed
/// or shorter/longer numbers, yields `None`.
pub fn parse_status_line(line: &str) -> Option<u16> {
    let mut tokens = line.split_whitespace();
    let protocol = tokens.next()?;
    let code = tokens.next()?;

    // `get(..4)` is `None` for short tokens and for non-ASCII prefixes.
    let is_http = protocol
        .get(..4)
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case("HTTP"));
    let is_three_digits = code.len() == 3 && code.bytes().all(|b| b.is_ascii_digit());

    if is_http && is_three_digits {
        code.parse().ok()
    } else {
        None
    }
}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// A new parser starts out waiting for headers.
    #[test]
    fn test_parser_creation() {
        assert_eq!(
            MessageParser::new().state(),
            ParseState::WaitingForHeaders
        );
    }

    /// A single message with a `Content-Length` body is parsed in one call.
    #[test]
    fn test_simple_message_parsing() {
        let mut parser = MessageParser::new();
        let messages = parser
            .feed(b"Content-Length: 5\r\n\r\nHello")
            .unwrap();
        assert!(!messages.is_empty());
        assert_eq!(messages[0].body, Bytes::from_static(b"Hello"));
    }

    /// Status lines with and without a version yield their numeric code.
    #[test]
    fn test_parse_status_line() {
        let cases = [
            ("HTTP/1.1 200 OK", Some(200)),
            ("HTTP 209 Subscription", Some(209)),
            ("HTTP/2 404", Some(404)),
        ];
        for (line, expected) in cases.iter() {
            assert_eq!(parse_status_line(line), *expected);
        }
    }

    /// Well-known headers are filtered out of `extra_headers`.
    #[test]
    fn test_message_extra_headers() {
        let mut headers = BTreeMap::new();
        headers.insert("version".to_string(), "\"v1\"".to_string());
        headers.insert("x-custom-header".to_string(), "value".to_string());

        let msg = Message {
            headers,
            body: Bytes::new(),
            patches: vec![],
            status_code: None,
            encoding: None,
            url: None,
        };

        let extra = msg.extra_headers();
        assert_eq!(extra.len(), 1);
        assert!(extra.contains_key("x-custom-header"));
        assert!(!extra.contains_key("version"));
    }

    /// A message announcing two patches yields both of them.
    #[test]
    fn test_multi_patch_parsing() {
        let data = b"Patches: 2\r\n\r\n\
                     Content-Length: 5\r\n\
                     Content-Range: json .a\r\n\r\n\
                     hello\r\n\
                     Content-Length: 5\r\n\
                     Content-Range: json .b\r\n\r\n\
                     world\r\n";

        let mut parser = MessageParser::new();
        let messages = parser.feed(data).unwrap();
        assert!(!messages.is_empty());

        let first = &messages[0];
        assert_eq!(first.patches.len(), 2);
    }
}
495}