Skip to main content

highflame_shield/
stream.rs

1//! SSE streaming support for `POST /v1/guard/stream`.
2
3use crate::ShieldError;
4use async_stream::stream;
5use futures_util::StreamExt as _;
6use std::collections::HashMap;
7
8/// A single server-sent event from `POST /v1/guard/stream`.
9#[derive(Debug, Clone)]
10pub struct ShieldStreamEvent {
11    /// Event type: `"detection"`, `"decision"`, `"error"`, or `"done"`.
12    pub r#type: String,
13    /// Parsed JSON payload.
14    pub data: HashMap<String, serde_json::Value>,
15}
16
17/// Parse a raw SSE response body into a stream of [`ShieldStreamEvent`]s.
18///
19/// Implements the SSE wire protocol (RFC 8895) manually so the SDK stays
20/// dependency-free beyond `reqwest` and `async-stream`.  Handles:
21/// - Multi-chunk delivery (lines that span chunk boundaries)
22/// - Multiple events per chunk
23/// - `event:` / `data:` field pairs
24/// - `id:`, `retry:`, and comment lines are silently ignored
25pub(crate) fn parse_sse_response(
26    response: reqwest::Response,
27) -> impl futures_util::Stream<Item = Result<ShieldStreamEvent, ShieldError>> {
28    stream! {
29        let mut byte_stream = response.bytes_stream();
30        // Accumulates bytes across chunk boundaries until a complete line is seen.
31        let mut line_buf = String::new();
32        // SSE parser state for the current event block.
33        let mut current_event: Option<String> = None;
34        let mut current_data: Option<String> = None;
35
36        while let Some(chunk) = byte_stream.next().await {
37            let chunk = match chunk {
38                Ok(c) => c,
39                Err(e) => {
40                    yield Err(ShieldError::Connection(e.to_string()));
41                    return;
42                }
43            };
44
45            line_buf.push_str(&String::from_utf8_lossy(&chunk));
46
47            // Drain complete lines from the buffer.
48            loop {
49                match line_buf.find('\n') {
50                    None => break,
51                    Some(pos) => {
52                        let raw = line_buf[..pos].trim_end_matches('\r').to_string();
53                        line_buf.drain(..=pos);
54
55                        if raw.is_empty() {
56                            // Blank line — dispatch the accumulated event.
57                            if let Some(data) = current_data.take() {
58                                let event_type = current_event
59                                    .take()
60                                    .unwrap_or_else(|| "detection".to_string());
61                                match serde_json::from_str::<HashMap<String, serde_json::Value>>(
62                                    &data,
63                                ) {
64                                    Ok(parsed) => yield Ok(ShieldStreamEvent {
65                                        r#type: event_type,
66                                        data: parsed,
67                                    }),
68                                    Err(e) => {
69                                        yield Err(ShieldError::Deserialisation(e));
70                                        return;
71                                    }
72                                }
73                            } else {
74                                current_event = None;
75                            }
76                        } else if let Some(rest) = raw.strip_prefix("data:") {
77                            current_data = Some(rest.trim().to_string());
78                        } else if let Some(rest) = raw.strip_prefix("event:") {
79                            current_event = Some(rest.trim().to_string());
80                        }
81                        // Ignore id:, retry:, and comment lines (:...).
82                    }
83                }
84            }
85        }
86
87        // Flush any unterminated trailing message (server closed without blank line).
88        if let Some(data) = current_data.take() {
89            let event_type = current_event
90                .take()
91                .unwrap_or_else(|| "detection".to_string());
92            match serde_json::from_str::<HashMap<String, serde_json::Value>>(&data) {
93                Ok(parsed) => yield Ok(ShieldStreamEvent { r#type: event_type, data: parsed }),
94                Err(e) => yield Err(ShieldError::Deserialisation(e)),
95            }
96        }
97    }
98}