Skip to main content

kernex_core/
stream.rs

//! Streaming types for provider responses delivered via Server-Sent Events (SSE).

/// An event emitted by a streaming provider.
///
/// Text and tool-input payloads arrive incrementally as deltas rather than as
/// complete values; [`StreamEvent::Done`] marks the end of the stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamEvent {
    /// A text delta from the model.
    TextDelta(String),
    /// A JSON fragment from a tool input being streamed.
    ///
    /// A single fragment is not necessarily valid JSON on its own
    /// (for example `{"foo":`).
    InputJsonDelta(String),
    /// Streaming is complete.
    Done,
    /// A streaming-level error, carrying a description of the failure.
    Error(String),
}

/// Accumulates [`StreamEvent`] deltas into a complete text response.
///
/// The accumulator is a plain text buffer, so it can be cloned to snapshot
/// progress and compared by its accumulated contents.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamAccumulator {
    // Concatenation of every text delta pushed so far.
    text: String,
}

22impl StreamAccumulator {
23    /// Create a new accumulator.
24    pub fn new() -> Self {
25        Self::default()
26    }
27
28    /// Push a streaming event into the accumulator.
29    pub fn push(&mut self, event: &StreamEvent) {
30        if let StreamEvent::TextDelta(delta) = event {
31            self.text.push_str(delta);
32        }
33    }
34
35    /// Returns the accumulated text so far.
36    pub fn text(&self) -> &str {
37        &self.text
38    }
39
40    /// Consumes the accumulator and returns the final text.
41    pub fn into_text(self) -> String {
42        self.text
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accumulates_text_deltas() {
        let mut acc = StreamAccumulator::new();
        acc.push(&StreamEvent::TextDelta("Hello".to_string()));
        acc.push(&StreamEvent::TextDelta(", world".to_string()));
        acc.push(&StreamEvent::Done);
        assert_eq!(acc.text(), "Hello, world");
    }

    #[test]
    fn ignores_non_text_events() {
        let mut acc = StreamAccumulator::new();
        acc.push(&StreamEvent::InputJsonDelta("{\"foo\":".to_string()));
        acc.push(&StreamEvent::Done);
        assert_eq!(acc.text(), "");
    }

    // An error in the middle of a stream must not corrupt or reset the text
    // gathered on either side of it.
    #[test]
    fn ignores_error_events_between_deltas() {
        let mut acc = StreamAccumulator::new();
        acc.push(&StreamEvent::TextDelta("partial".to_string()));
        acc.push(&StreamEvent::Error("connection reset".to_string()));
        acc.push(&StreamEvent::TextDelta(" answer".to_string()));
        assert_eq!(acc.text(), "partial answer");
    }

    #[test]
    fn new_accumulator_is_empty() {
        let acc = StreamAccumulator::new();
        assert_eq!(acc.text(), "");
        assert_eq!(acc.into_text(), "");
    }

    // `text()` is a live view, so it must reflect every push made so far.
    #[test]
    fn text_reflects_progress_between_pushes() {
        let mut acc = StreamAccumulator::new();
        acc.push(&StreamEvent::TextDelta("a".to_string()));
        assert_eq!(acc.text(), "a");
        acc.push(&StreamEvent::TextDelta("b".to_string()));
        assert_eq!(acc.text(), "ab");
    }

    // Deltas containing multi-byte UTF-8 characters must be joined verbatim.
    #[test]
    fn preserves_multibyte_utf8_deltas() {
        let mut acc = StreamAccumulator::new();
        acc.push(&StreamEvent::TextDelta("héllo".to_string()));
        acc.push(&StreamEvent::TextDelta(", 世界".to_string()));
        assert_eq!(acc.into_text(), "héllo, 世界");
    }

    #[test]
    fn into_text_consumes() {
        let mut acc = StreamAccumulator::new();
        acc.push(&StreamEvent::TextDelta("hi".to_string()));
        assert_eq!(acc.into_text(), "hi");
    }
}