forge_orchestration/inference/
streaming.rs

1//! Streaming response support for AI/ML inference
2//!
3//! Provides Server-Sent Events (SSE) and token streaming for LLM inference.
4
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use tokio::sync::mpsc;
10
/// Tunable options for a streaming inference response.
///
/// NOTE(review): within this file only `buffer_size` is consumed (as the
/// channel capacity in `StreamingResponse::new`). `include_timing` and
/// `heartbeat_ms` are stored but no visible code reads them — confirm
/// whether another module acts on them.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Capacity of the channel that carries events from sender to receiver.
    pub buffer_size: usize,
    /// Whether events should carry timing information.
    pub include_timing: bool,
    /// Interval between heartbeats, in milliseconds; `0` disables heartbeats.
    pub heartbeat_ms: u64,
}
21
22impl Default for StreamingConfig {
23    fn default() -> Self {
24        Self {
25            buffer_size: 32,
26            include_timing: true,
27            heartbeat_ms: 15000, // 15 seconds
28        }
29    }
30}
31
32/// A streaming event
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct StreamEvent {
35    /// Event type (e.g., "token", "done", "error")
36    pub event: String,
37    /// Event data
38    pub data: String,
39    /// Event ID (optional)
40    pub id: Option<String>,
41    /// Timestamp in milliseconds
42    pub timestamp_ms: Option<u64>,
43}
44
45impl StreamEvent {
46    /// Create a new stream event
47    pub fn new(event: impl Into<String>, data: impl Into<String>) -> Self {
48        Self {
49            event: event.into(),
50            data: data.into(),
51            id: None,
52            timestamp_ms: None,
53        }
54    }
55
56    /// Create a token event
57    pub fn token(token: impl Into<String>) -> Self {
58        Self::new("token", token)
59    }
60
61    /// Create a done event
62    pub fn done() -> Self {
63        Self::new("done", "[DONE]")
64    }
65
66    /// Create an error event
67    pub fn error(msg: impl Into<String>) -> Self {
68        Self::new("error", msg)
69    }
70
71    /// Create a heartbeat event
72    pub fn heartbeat() -> Self {
73        Self::new("heartbeat", "")
74    }
75
76    /// Set event ID
77    pub fn with_id(mut self, id: impl Into<String>) -> Self {
78        self.id = Some(id.into());
79        self
80    }
81
82    /// Set timestamp
83    pub fn with_timestamp(mut self, timestamp_ms: u64) -> Self {
84        self.timestamp_ms = Some(timestamp_ms);
85        self
86    }
87
88    /// Format as SSE
89    pub fn to_sse(&self) -> String {
90        let mut result = String::new();
91        
92        if let Some(id) = &self.id {
93            result.push_str(&format!("id: {}\n", id));
94        }
95        
96        result.push_str(&format!("event: {}\n", self.event));
97        
98        // Handle multi-line data
99        for line in self.data.lines() {
100            result.push_str(&format!("data: {}\n", line));
101        }
102        if self.data.is_empty() {
103            result.push_str("data: \n");
104        }
105        
106        result.push('\n');
107        result
108    }
109}
110
111/// Streaming response for inference
112pub struct StreamingResponse {
113    rx: mpsc::Receiver<StreamEvent>,
114    config: StreamingConfig,
115}
116
117impl StreamingResponse {
118    /// Create a new streaming response with a sender
119    pub fn new(config: StreamingConfig) -> (Self, StreamSender) {
120        let (tx, rx) = mpsc::channel(config.buffer_size);
121        let response = Self { rx, config };
122        let sender = StreamSender { tx };
123        (response, sender)
124    }
125
126    /// Create with default config
127    pub fn default_config() -> (Self, StreamSender) {
128        Self::new(StreamingConfig::default())
129    }
130}
131
132impl Stream for StreamingResponse {
133    type Item = StreamEvent;
134
135    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136        Pin::new(&mut self.rx).poll_recv(cx)
137    }
138}
139
140/// Sender for streaming events
141#[derive(Clone)]
142pub struct StreamSender {
143    tx: mpsc::Sender<StreamEvent>,
144}
145
146impl StreamSender {
147    /// Send an event
148    pub async fn send(&self, event: StreamEvent) -> Result<(), StreamError> {
149        self.tx.send(event).await.map_err(|_| StreamError::Closed)
150    }
151
152    /// Send a token
153    pub async fn send_token(&self, token: impl Into<String>) -> Result<(), StreamError> {
154        self.send(StreamEvent::token(token)).await
155    }
156
157    /// Send done signal
158    pub async fn send_done(&self) -> Result<(), StreamError> {
159        self.send(StreamEvent::done()).await
160    }
161
162    /// Send error
163    pub async fn send_error(&self, msg: impl Into<String>) -> Result<(), StreamError> {
164        self.send(StreamEvent::error(msg)).await
165    }
166
167    /// Check if receiver is still connected
168    pub fn is_closed(&self) -> bool {
169        self.tx.is_closed()
170    }
171}
172
173/// Streaming error
174#[derive(Debug, thiserror::Error)]
175pub enum StreamError {
176    /// Stream was closed
177    #[error("Stream closed")]
178    Closed,
179    /// Send failed
180    #[error("Send failed: {0}")]
181    SendFailed(String),
182}
183
/// Replays a fixed list of tokens, one at a time, for LLM inference streams.
pub struct TokenStream {
    /// All tokens, in the order they are handed out.
    tokens: Vec<String>,
    /// Position of the next token to hand out.
    index: usize,
    /// Pause applied while streaming, in milliseconds; `0` means no pause.
    delay_ms: u64,
}
190
191impl TokenStream {
192    /// Create a new token stream from a list of tokens
193    pub fn new(tokens: Vec<String>) -> Self {
194        Self {
195            tokens,
196            index: 0,
197            delay_ms: 0,
198        }
199    }
200
201    /// Set delay between tokens (for simulation)
202    pub fn with_delay(mut self, delay_ms: u64) -> Self {
203        self.delay_ms = delay_ms;
204        self
205    }
206
207    /// Stream tokens to a sender
208    pub async fn stream_to(&mut self, sender: &StreamSender) -> Result<(), StreamError> {
209        while let Some(token) = self.next_token() {
210            sender.send_token(token).await?;
211            
212            if self.delay_ms > 0 {
213                tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
214            }
215        }
216        
217        sender.send_done().await
218    }
219
220    /// Get next token
221    pub fn next_token(&mut self) -> Option<String> {
222        if self.index < self.tokens.len() {
223            let token = self.tokens[self.index].clone();
224            self.index += 1;
225            Some(token)
226        } else {
227            None
228        }
229    }
230
231    /// Reset stream
232    pub fn reset(&mut self) {
233        self.index = 0;
234    }
235
236    /// Get remaining token count
237    pub fn remaining(&self) -> usize {
238        self.tokens.len().saturating_sub(self.index)
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// A single-line event renders its id, type and payload.
    #[test]
    fn test_stream_event_sse() {
        let event = StreamEvent::token("Hello")
            .with_id("1")
            .with_timestamp(12345);

        let sse = event.to_sse();
        assert!(sse.contains("id: 1"));
        assert!(sse.contains("event: token"));
        assert!(sse.contains("data: Hello"));
    }

    /// Each payload line gets its own `data:` line.
    #[test]
    fn test_stream_event_multiline() {
        let event = StreamEvent::new("message", "line1\nline2\nline3");
        let sse = event.to_sse();

        assert!(sse.contains("data: line1"));
        assert!(sse.contains("data: line2"));
        assert!(sse.contains("data: line3"));
    }

    /// An empty payload still yields one `data:` line and a closing blank line.
    #[test]
    fn test_stream_event_empty_data() {
        let sse = StreamEvent::heartbeat().to_sse();
        assert_eq!(sse, "event: heartbeat\ndata: \n\n");
    }

    /// Events arrive in the order they were sent, and the stream ends once
    /// the sender is gone.
    #[tokio::test]
    async fn test_streaming_response() {
        // `collect` consumes the response, so the binding needs no `mut`.
        let (response, sender) = StreamingResponse::default_config();

        tokio::spawn(async move {
            sender.send_token("Hello").await.unwrap();
            sender.send_token(" World").await.unwrap();
            sender.send_done().await.unwrap();
        });

        let events: Vec<_> = response.collect().await;
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].event, "token");
        assert_eq!(events[1].event, "token");
        assert_eq!(events[2].event, "done");
    }

    /// `next_token`, `remaining` and `reset` keep the cursor consistent.
    #[tokio::test]
    async fn test_token_stream() {
        let tokens = vec!["Hello".to_string(), " ".to_string(), "World".to_string()];
        let mut stream = TokenStream::new(tokens);

        assert_eq!(stream.remaining(), 3);
        assert_eq!(stream.next_token(), Some("Hello".to_string()));
        assert_eq!(stream.remaining(), 2);

        stream.reset();
        assert_eq!(stream.remaining(), 3);
    }

    /// `stream_to` forwards every token and finishes with a done event.
    #[tokio::test]
    async fn test_token_stream_to_sender() {
        let (response, sender) = StreamingResponse::default_config();
        let mut stream = TokenStream::new(vec!["a".to_string(), "b".to_string()]);

        stream.stream_to(&sender).await.unwrap();
        assert_eq!(stream.remaining(), 0);

        // Dropping the only sender lets `collect` observe the end of the stream.
        drop(sender);
        let events: Vec<_> = response.collect().await;
        let kinds: Vec<&str> = events.iter().map(|e| e.event.as_str()).collect();
        assert_eq!(kinds, ["token", "token", "done"]);
    }

    /// Sending after the receiver is gone reports `StreamError::Closed`.
    #[tokio::test]
    async fn test_send_after_receiver_dropped() {
        let (response, sender) = StreamingResponse::default_config();
        drop(response);

        assert!(sender.is_closed());
        assert!(matches!(
            sender.send_token("late").await,
            Err(StreamError::Closed)
        ));
    }
}