// kaccy_ai/llm/streaming.rs

//! Streaming support for LLM providers
//!
//! Provides real-time token-by-token streaming for chat completions.

use std::io::Write;
use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};

use super::types::{ChatMessage, ChatRequest, ChatRole};
use crate::error::Result;

/// A chunk of streamed response.
///
/// Providers emit one `StreamChunk` per streamed event; chunks are folded
/// into a complete response by [`StreamAccumulator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// The text delta for this chunk (may be empty).
    pub delta: String,
    /// Whether this is the final chunk of the stream.
    pub is_final: bool,
    /// Stop reason reported by the provider (only set on the final chunk).
    pub stop_reason: Option<String>,
    /// Index of the choice this chunk belongs to (for multi-response requests).
    pub index: u32,
}

/// Final streaming response with usage info.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingChatResponse {
    /// Complete accumulated text of the response.
    pub text: String,
    /// Prompt tokens used, if the provider reported usage.
    pub prompt_tokens: Option<u32>,
    /// Completion tokens generated, if the provider reported usage.
    pub completion_tokens: Option<u32>,
    /// Stop reason, if one was reported.
    pub stop_reason: Option<String>,
}

/// Streaming request configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingChatRequest {
    /// Base chat request to stream.
    pub request: ChatRequest,
    /// Whether to include usage information in the final message
    /// (enabled by default when constructed via `new`).
    pub include_usage: bool,
}

48impl StreamingChatRequest {
49    /// Create a new streaming request from a chat request
50    #[must_use]
51    pub fn new(request: ChatRequest) -> Self {
52        Self {
53            request,
54            include_usage: true,
55        }
56    }
57
58    /// Create with system prompt and user message
59    pub fn with_system(system: impl Into<String>, user: impl Into<String>) -> Self {
60        Self::new(ChatRequest::with_system(system, user))
61    }
62
63    /// Set whether to include usage info
64    #[must_use]
65    pub fn include_usage(mut self, include: bool) -> Self {
66        self.include_usage = include;
67        self
68    }
69}
70
/// Type alias for a streaming response: a pinned, boxed stream of
/// [`StreamChunk`] results that can be sent across threads.
pub type StreamResponse = Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>;

/// Trait for LLM providers that support streaming.
#[async_trait]
pub trait StreamingLlmProvider: Send + Sync {
    /// Stream a chat completion response.
    ///
    /// # Errors
    ///
    /// Returns an error if the stream cannot be started; per-chunk failures
    /// are surfaced through the returned stream's `Result` items.
    async fn chat_stream(&self, request: StreamingChatRequest) -> Result<StreamResponse>;
}

/// Accumulator for building a complete response from a stream of chunks.
pub struct StreamAccumulator {
    /// Text accumulated from chunk deltas so far.
    text: String,
    /// Prompt token count; populated only via `set_usage`.
    prompt_tokens: Option<u32>,
    /// Completion token count; populated only via `set_usage`.
    completion_tokens: Option<u32>,
    /// Stop reason taken from the final chunk, if one has been seen.
    stop_reason: Option<String>,
}

89impl Default for StreamAccumulator {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl StreamAccumulator {
96    /// Create a new accumulator
97    #[must_use]
98    pub fn new() -> Self {
99        Self {
100            text: String::new(),
101            prompt_tokens: None,
102            completion_tokens: None,
103            stop_reason: None,
104        }
105    }
106
107    /// Add a chunk to the accumulator
108    pub fn add_chunk(&mut self, chunk: &StreamChunk) {
109        self.text.push_str(&chunk.delta);
110        if chunk.is_final {
111            self.stop_reason = chunk.stop_reason.clone();
112        }
113    }
114
115    /// Set usage information
116    pub fn set_usage(&mut self, prompt_tokens: u32, completion_tokens: u32) {
117        self.prompt_tokens = Some(prompt_tokens);
118        self.completion_tokens = Some(completion_tokens);
119    }
120
121    /// Build the final response
122    #[must_use]
123    pub fn build(self) -> StreamingChatResponse {
124        StreamingChatResponse {
125            text: self.text,
126            prompt_tokens: self.prompt_tokens,
127            completion_tokens: self.completion_tokens,
128            stop_reason: self.stop_reason,
129        }
130    }
131
132    /// Get the accumulated text so far
133    #[must_use]
134    pub fn text(&self) -> &str {
135        &self.text
136    }
137
138    /// Get current length
139    #[must_use]
140    pub fn len(&self) -> usize {
141        self.text.len()
142    }
143
144    /// Check if empty
145    #[must_use]
146    pub fn is_empty(&self) -> bool {
147        self.text.is_empty()
148    }
149}
150
151/// Helper to convert a streaming response to a complete message
152pub async fn collect_stream(mut stream: StreamResponse) -> Result<StreamingChatResponse> {
153    use futures::StreamExt;
154
155    let mut accumulator = StreamAccumulator::new();
156
157    while let Some(chunk_result) = stream.next().await {
158        let chunk = chunk_result?;
159        accumulator.add_chunk(&chunk);
160    }
161
162    Ok(accumulator.build())
163}
164
165/// Convert streaming response to `ChatMessage`
166impl From<StreamingChatResponse> for ChatMessage {
167    fn from(response: StreamingChatResponse) -> Self {
168        ChatMessage {
169            role: ChatRole::Assistant,
170            content: response.text,
171        }
172    }
173}
174
/// Callback type invoked for each stream chunk as it arrives.
pub type StreamCallback = Box<dyn Fn(&StreamChunk) + Send + Sync>;

/// Stream handler that invokes a callback for each chunk while also
/// accumulating the chunks into a complete response.
pub struct StreamHandler {
    /// Invoked once per chunk, before the chunk is accumulated.
    callback: StreamCallback,
    /// Builds the final response from every handled chunk.
    accumulator: StreamAccumulator,
}

184impl StreamHandler {
185    /// Create a new stream handler with a callback
186    pub fn new(callback: impl Fn(&StreamChunk) + Send + Sync + 'static) -> Self {
187        Self {
188            callback: Box::new(callback),
189            accumulator: StreamAccumulator::new(),
190        }
191    }
192
193    /// Process a chunk
194    pub fn handle_chunk(&mut self, chunk: &StreamChunk) {
195        (self.callback)(chunk);
196        self.accumulator.add_chunk(chunk);
197    }
198
199    /// Get the accumulated response
200    #[must_use]
201    pub fn finish(self) -> StreamingChatResponse {
202        self.accumulator.build()
203    }
204}
205
206/// Simple print handler for debugging
207#[must_use]
208pub fn print_handler() -> StreamHandler {
209    StreamHandler::new(|chunk| {
210        print!("{}", chunk.delta);
211        if chunk.is_final {
212            println!();
213        }
214    })
215}