Skip to main content

ai_lib_rust/pipeline/
decode.rs

//! Streaming decoders (`Bytes` -> JSON `Value`).
//!
//! This module intentionally keeps provider-specific logic out of code: it decodes wire *formats*
//! (SSE, NDJSON, etc.) as selected by manifest configuration.
5
6use crate::pipeline::{Decoder, PipelineError};
7use crate::protocol::DecoderConfig;
8use crate::{BoxStream, PipeResult};
9use bytes::Bytes;
10use futures::{stream, StreamExt};
11use serde_json::Value;
12
/// A minimal, manifest-driven SSE decoder.
///
/// - splits the byte stream into frames on `delimiter` (default `"\n\n"`)
/// - strips `prefix` (default `"data: "`) before parsing the JSON payload
/// - stops on `done_signal` (default `"[DONE]"`)
#[derive(Debug, Clone)]
pub struct SseDecoder {
    /// Separator between consecutive frames.
    delimiter: String,
    /// Marker removed from the start of a payload before JSON parsing.
    prefix: String,
    /// Payload that terminates the stream.
    done_signal: String,
}
22
23impl SseDecoder {
24    pub fn new(
25        delimiter: Option<String>,
26        prefix: Option<String>,
27        done_signal: Option<String>,
28    ) -> Self {
29        Self {
30            delimiter: delimiter.unwrap_or_else(|| "\n\n".to_string()),
31            prefix: prefix.unwrap_or_else(|| "data: ".to_string()),
32            done_signal: done_signal.unwrap_or_else(|| "[DONE]".to_string()),
33        }
34    }
35
36    pub fn from_config(cfg: &DecoderConfig) -> Result<Self, PipelineError> {
37        Ok(Self::new(
38            cfg.delimiter.clone(),
39            cfg.prefix.clone(),
40            cfg.done_signal.clone(),
41        ))
42    }
43
44    // NOTE: Parsing is implemented inside `decode_stream()` so we can construct streams that do not
45    // borrow `&self` (required to return `'static` streams for higher-level retry/fallback).
46}
47
48#[async_trait::async_trait]
49impl Decoder for SseDecoder {
50    async fn decode_stream(
51        &self,
52        input: BoxStream<'static, Bytes>,
53    ) -> PipeResult<BoxStream<'static, Value>> {
54        let delimiter = self.delimiter.clone();
55        let delimiter_len = delimiter.len();
56        let prefix = self.prefix.clone();
57        let done_signal = self.done_signal.clone();
58
59        // Incrementally buffer bytes and emit full frames split by delimiter.
60        let stream = stream::unfold((input, String::new()), move |(mut input, mut buf)| {
61            let delimiter = delimiter.clone();
62            let prefix = prefix.clone();
63            let done_signal = done_signal.clone();
64            async move {
65                let is_done = |s: &str| -> bool {
66                    let t = s.trim();
67                    t == done_signal
68                        || t == format!("data: {}", done_signal)
69                        || t == format!("data:{}", done_signal)
70                };
71
72                let parse_payload = |raw: &str| -> Option<Value> {
73                    let trimmed = raw.trim();
74                    if trimmed.is_empty() || is_done(trimmed) {
75                        return None;
76                    }
77
78                    // Ignore SSE comment lines
79                    if trimmed.starts_with(':') {
80                        return None;
81                    }
82
83                    // Strip prefix if present
84                    let payload = if trimmed.starts_with(&prefix) {
85                        &trimmed[prefix.len()..]
86                    } else if trimmed.starts_with("data:") {
87                        trimmed[5..].trim_start()
88                    } else {
89                        trimmed
90                    };
91
92                    serde_json::from_str(payload).ok()
93                };
94
95                loop {
96                    // If we have a full frame in buffer, emit it.
97                    if let Some(idx) = buf.find(&delimiter) {
98                        let frame = buf[..idx].to_string();
99                        let rest_start = idx + delimiter_len;
100                        buf = if rest_start <= buf.len() {
101                            buf[rest_start..].to_string()
102                        } else {
103                            String::new()
104                        };
105
106                        if is_done(&frame) {
107                            return None;
108                        }
109
110                        if let Some(v) = parse_payload(&frame) {
111                            return Some((Ok(v), (input, buf)));
112                        }
113
114                        // Skip non-json frames; keep looping.
115                        continue;
116                    }
117
118                    // Need more data.
119                    match input.next().await {
120                        Some(Ok(bytes)) => {
121                            let s = String::from_utf8_lossy(&bytes);
122                            buf.push_str(&s);
123                            continue;
124                        }
125                        Some(Err(e)) => {
126                            return Some((Err(e), (input, buf)));
127                        }
128                        None => {
129                            // EOF: try parse remaining buffer once
130                            if is_done(&buf) {
131                                return None;
132                            }
133                            if let Some(v) = parse_payload(&buf) {
134                                return Some((Ok(v), (input, String::new())));
135                            }
136                            return None;
137                        }
138                    }
139                }
140            }
141        });
142
143        Ok(Box::pin(stream))
144    }
145}
146
/// NDJSON / JSONL decoder (one JSON value per line). Stateless; carries no configuration.
#[derive(Debug, Clone, Copy, Default)]
pub struct NdjsonDecoder;
149
150#[async_trait::async_trait]
151impl Decoder for NdjsonDecoder {
152    async fn decode_stream(
153        &self,
154        input: BoxStream<'static, Bytes>,
155    ) -> PipeResult<BoxStream<'static, Value>> {
156        let stream = stream::unfold(
157            (input, String::new()),
158            move |(mut input, mut buf)| async move {
159                loop {
160                    if let Some(idx) = buf.find('\n') {
161                        let line = buf[..idx].trim().to_string();
162                        buf = buf[idx + 1..].to_string();
163                        if line.is_empty() {
164                            continue;
165                        }
166                        match serde_json::from_str::<Value>(&line) {
167                            Ok(v) => return Some((Ok(v), (input, buf))),
168                            Err(e) => {
169                                return Some((Err(crate::Error::Serialization(e)), (input, buf)))
170                            }
171                        }
172                    }
173
174                    match input.next().await {
175                        Some(Ok(bytes)) => {
176                            let s = String::from_utf8_lossy(&bytes);
177                            buf.push_str(&s);
178                            continue;
179                        }
180                        Some(Err(e)) => return Some((Err(e), (input, buf))),
181                        None => {
182                            let line = buf.trim();
183                            if line.is_empty() {
184                                return None;
185                            }
186                            match serde_json::from_str::<Value>(line) {
187                                Ok(v) => return Some((Ok(v), (input, String::new()))),
188                                Err(_) => return None,
189                            }
190                        }
191                    }
192                }
193            },
194        );
195
196        Ok(Box::pin(stream))
197    }
198}
199
200pub fn create_decoder(cfg: &DecoderConfig) -> Result<Box<dyn Decoder>, PipelineError> {
201    match cfg.format.as_str() {
202        "sse" => Ok(Box::new(SseDecoder::from_config(cfg)?)),
203        // Many providers (e.g. Anthropic) still speak SSE but differ in event semantics.
204        // We keep this manifest-driven and treat it as standard SSE framing.
205        "anthropic_sse" => Ok(Box::new(SseDecoder::from_config(cfg)?)),
206        "ndjson" | "jsonl" => Ok(Box::new(NdjsonDecoder)),
207        other => Err(PipelineError::Configuration(format!(
208            "Unsupported decoder format: {}. Supported formats: sse, jsonl, ndjson",
209            other
210        ))),
211    }
212}