forge_orchestration/inference/
streaming.rs1use std::pin::Pin;
6use std::task::{Context, Poll};
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use tokio::sync::mpsc;
10
/// Tunable settings for a streaming response.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Capacity handed to the bounded channel built in `StreamingResponse::new`.
    pub buffer_size: usize,
    /// Timing flag. It is stored but not read anywhere in this file.
    /// NOTE(review): presumably toggles timestamps on emitted events — confirm.
    pub include_timing: bool,
    /// Heartbeat value in milliseconds. It is stored but not read in this file.
    /// NOTE(review): presumably the interval between heartbeat events — confirm.
    pub heartbeat_ms: u64,
}

impl Default for StreamingConfig {
    /// Builds the default configuration: a 32-slot buffer, timing enabled and
    /// a heartbeat value of 15 000 ms.
    fn default() -> Self {
        let buffer_size = 32;
        let heartbeat_ms = 15_000;

        StreamingConfig {
            buffer_size,
            include_timing: true,
            heartbeat_ms,
        }
    }
}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct StreamEvent {
35 pub event: String,
37 pub data: String,
39 pub id: Option<String>,
41 pub timestamp_ms: Option<u64>,
43}
44
45impl StreamEvent {
46 pub fn new(event: impl Into<String>, data: impl Into<String>) -> Self {
48 Self {
49 event: event.into(),
50 data: data.into(),
51 id: None,
52 timestamp_ms: None,
53 }
54 }
55
56 pub fn token(token: impl Into<String>) -> Self {
58 Self::new("token", token)
59 }
60
61 pub fn done() -> Self {
63 Self::new("done", "[DONE]")
64 }
65
66 pub fn error(msg: impl Into<String>) -> Self {
68 Self::new("error", msg)
69 }
70
71 pub fn heartbeat() -> Self {
73 Self::new("heartbeat", "")
74 }
75
76 pub fn with_id(mut self, id: impl Into<String>) -> Self {
78 self.id = Some(id.into());
79 self
80 }
81
82 pub fn with_timestamp(mut self, timestamp_ms: u64) -> Self {
84 self.timestamp_ms = Some(timestamp_ms);
85 self
86 }
87
88 pub fn to_sse(&self) -> String {
90 let mut result = String::new();
91
92 if let Some(id) = &self.id {
93 result.push_str(&format!("id: {}\n", id));
94 }
95
96 result.push_str(&format!("event: {}\n", self.event));
97
98 for line in self.data.lines() {
100 result.push_str(&format!("data: {}\n", line));
101 }
102 if self.data.is_empty() {
103 result.push_str("data: \n");
104 }
105
106 result.push('\n');
107 result
108 }
109}
110
111pub struct StreamingResponse {
113 rx: mpsc::Receiver<StreamEvent>,
114 config: StreamingConfig,
115}
116
117impl StreamingResponse {
118 pub fn new(config: StreamingConfig) -> (Self, StreamSender) {
120 let (tx, rx) = mpsc::channel(config.buffer_size);
121 let response = Self { rx, config };
122 let sender = StreamSender { tx };
123 (response, sender)
124 }
125
126 pub fn default_config() -> (Self, StreamSender) {
128 Self::new(StreamingConfig::default())
129 }
130}
131
132impl Stream for StreamingResponse {
133 type Item = StreamEvent;
134
135 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136 Pin::new(&mut self.rx).poll_recv(cx)
137 }
138}
139
140#[derive(Clone)]
142pub struct StreamSender {
143 tx: mpsc::Sender<StreamEvent>,
144}
145
146impl StreamSender {
147 pub async fn send(&self, event: StreamEvent) -> Result<(), StreamError> {
149 self.tx.send(event).await.map_err(|_| StreamError::Closed)
150 }
151
152 pub async fn send_token(&self, token: impl Into<String>) -> Result<(), StreamError> {
154 self.send(StreamEvent::token(token)).await
155 }
156
157 pub async fn send_done(&self) -> Result<(), StreamError> {
159 self.send(StreamEvent::done()).await
160 }
161
162 pub async fn send_error(&self, msg: impl Into<String>) -> Result<(), StreamError> {
164 self.send(StreamEvent::error(msg)).await
165 }
166
167 pub fn is_closed(&self) -> bool {
169 self.tx.is_closed()
170 }
171}
172
173#[derive(Debug, thiserror::Error)]
175pub enum StreamError {
176 #[error("Stream closed")]
178 Closed,
179 #[error("Send failed: {0}")]
181 SendFailed(String),
182}
183
184pub struct TokenStream {
186 tokens: Vec<String>,
187 index: usize,
188 delay_ms: u64,
189}
190
191impl TokenStream {
192 pub fn new(tokens: Vec<String>) -> Self {
194 Self {
195 tokens,
196 index: 0,
197 delay_ms: 0,
198 }
199 }
200
201 pub fn with_delay(mut self, delay_ms: u64) -> Self {
203 self.delay_ms = delay_ms;
204 self
205 }
206
207 pub async fn stream_to(&mut self, sender: &StreamSender) -> Result<(), StreamError> {
209 while let Some(token) = self.next_token() {
210 sender.send_token(token).await?;
211
212 if self.delay_ms > 0 {
213 tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
214 }
215 }
216
217 sender.send_done().await
218 }
219
220 pub fn next_token(&mut self) -> Option<String> {
222 if self.index < self.tokens.len() {
223 let token = self.tokens[self.index].clone();
224 self.index += 1;
225 Some(token)
226 } else {
227 None
228 }
229 }
230
231 pub fn reset(&mut self) {
233 self.index = 0;
234 }
235
236 pub fn remaining(&self) -> usize {
238 self.tokens.len().saturating_sub(self.index)
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// A single-line event is framed as id, event and data lines followed by a
    /// blank line; the timestamp is stored on the event but not framed.
    #[test]
    fn test_stream_event_sse() {
        let event = StreamEvent::token("Hello")
            .with_id("1")
            .with_timestamp(12345);

        assert_eq!(event.timestamp_ms, Some(12345));
        assert_eq!(event.to_sse(), "id: 1\nevent: token\ndata: Hello\n\n");
    }

    /// Every payload line gets its own `data:` line.
    #[test]
    fn test_stream_event_multiline() {
        let event = StreamEvent::new("message", "line1\nline2\nline3");

        assert_eq!(
            event.to_sse(),
            "event: message\ndata: line1\ndata: line2\ndata: line3\n\n"
        );
    }

    /// Events arrive in the order they were sent and the stream ends once the
    /// sender is dropped.
    #[tokio::test]
    async fn test_streaming_response() {
        let (response, sender) = StreamingResponse::default_config();

        let producer = tokio::spawn(async move {
            sender.send_token("Hello").await.unwrap();
            sender.send_token(" World").await.unwrap();
            sender.send_done().await.unwrap();
        });

        let events: Vec<_> = response.collect().await;
        // Join the producer so a panic inside it fails this test instead of
        // being silently discarded with the detached handle.
        producer.await.expect("producer task panicked");

        assert_eq!(events.len(), 3);
        assert_eq!(events[0].event, "token");
        assert_eq!(events[2].event, "done");
    }

    /// `next_token` advances the position and `reset` rewinds it. Nothing here
    /// awaits, so a plain synchronous test is enough.
    #[test]
    fn test_token_stream() {
        let tokens = vec!["Hello".to_string(), " ".to_string(), "World".to_string()];
        let mut stream = TokenStream::new(tokens);

        assert_eq!(stream.remaining(), 3);
        assert_eq!(stream.next_token(), Some("Hello".to_string()));
        assert_eq!(stream.remaining(), 2);

        stream.reset();
        assert_eq!(stream.remaining(), 3);
    }
}