kaccy_ai/llm/
streaming.rs1use async_trait::async_trait;
6use futures::Stream;
7use serde::{Deserialize, Serialize};
8use std::pin::Pin;
9
10use super::types::{ChatMessage, ChatRequest, ChatRole};
11use crate::error::Result;
12
/// One incremental piece of a streamed chat response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Text fragment carried by this chunk; `StreamAccumulator::add_chunk`
    /// appends it to the text collected so far.
    pub delta: String,
    /// Marks the last chunk of a stream. The accumulator only records
    /// `stop_reason` from a chunk where this is `true`.
    pub is_final: bool,
    /// Why generation stopped, if reported.
    /// NOTE(review): the set of possible values is presumably provider-defined — confirm.
    pub stop_reason: Option<String>,
    /// NOTE(review): presumably the position of this chunk within the stream;
    /// nothing in this file reads it — confirm against the providers.
    pub index: u32,
}
25
/// The complete result of a streamed chat, as produced by
/// `StreamAccumulator::build`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingChatResponse {
    /// Concatenation of every chunk `delta`, in arrival order.
    pub text: String,
    /// Prompt token count; `None` unless `StreamAccumulator::set_usage` was called.
    pub prompt_tokens: Option<u32>,
    /// Completion token count; `None` unless `StreamAccumulator::set_usage` was called.
    pub completion_tokens: Option<u32>,
    /// Stop reason copied from the final chunk, if it carried one.
    pub stop_reason: Option<String>,
}
38
/// A chat request that is to be answered as a stream of chunks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingChatRequest {
    /// The underlying chat request being wrapped.
    pub request: ChatRequest,
    /// Usage flag; `StreamingChatRequest::new` initialises it to `true`.
    /// NOTE(review): presumably asks the provider to report token usage — confirm
    /// against the provider implementations.
    pub include_usage: bool,
}
47
48impl StreamingChatRequest {
49 #[must_use]
51 pub fn new(request: ChatRequest) -> Self {
52 Self {
53 request,
54 include_usage: true,
55 }
56 }
57
58 pub fn with_system(system: impl Into<String>, user: impl Into<String>) -> Self {
60 Self::new(ChatRequest::with_system(system, user))
61 }
62
63 #[must_use]
65 pub fn include_usage(mut self, include: bool) -> Self {
66 self.include_usage = include;
67 self
68 }
69}
70
/// A pinned, boxed, `Send` stream whose items are `Result<StreamChunk>`:
/// each item is either the next chunk or an error.
pub type StreamResponse = Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>;
73
/// Interface for types that can answer a chat request as a stream of chunks.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait StreamingLlmProvider: Send + Sync {
    /// Starts a streaming chat for `request`.
    ///
    /// On success returns a [`StreamResponse`] that yields the chunks; an
    /// `Err` means the stream could not be produced at all.
    async fn chat_stream(&self, request: StreamingChatRequest) -> Result<StreamResponse>;
}
80
/// Collects the pieces of a streamed reply into a single response.
///
/// Chunks are fed in through `add_chunk`, token usage through `set_usage`,
/// and `build` turns the collected state into a `StreamingChatResponse`.
///
/// `Default` is derived and yields the same state as `StreamAccumulator::new`:
/// empty text and every optional field `None`. `Debug` and `Clone` are derived
/// so the accumulator can be logged and snapshotted like the other public
/// types in this module.
#[derive(Debug, Clone, Default)]
pub struct StreamAccumulator {
    /// Concatenation of every chunk delta received so far.
    text: String,
    /// Prompt token count, once `set_usage` has been called.
    prompt_tokens: Option<u32>,
    /// Completion token count, once `set_usage` has been called.
    completion_tokens: Option<u32>,
    /// Stop reason taken from the final chunk, if it carried one.
    stop_reason: Option<String>,
}
94
95impl StreamAccumulator {
96 #[must_use]
98 pub fn new() -> Self {
99 Self {
100 text: String::new(),
101 prompt_tokens: None,
102 completion_tokens: None,
103 stop_reason: None,
104 }
105 }
106
107 pub fn add_chunk(&mut self, chunk: &StreamChunk) {
109 self.text.push_str(&chunk.delta);
110 if chunk.is_final {
111 self.stop_reason = chunk.stop_reason.clone();
112 }
113 }
114
115 pub fn set_usage(&mut self, prompt_tokens: u32, completion_tokens: u32) {
117 self.prompt_tokens = Some(prompt_tokens);
118 self.completion_tokens = Some(completion_tokens);
119 }
120
121 #[must_use]
123 pub fn build(self) -> StreamingChatResponse {
124 StreamingChatResponse {
125 text: self.text,
126 prompt_tokens: self.prompt_tokens,
127 completion_tokens: self.completion_tokens,
128 stop_reason: self.stop_reason,
129 }
130 }
131
132 #[must_use]
134 pub fn text(&self) -> &str {
135 &self.text
136 }
137
138 #[must_use]
140 pub fn len(&self) -> usize {
141 self.text.len()
142 }
143
144 #[must_use]
146 pub fn is_empty(&self) -> bool {
147 self.text.is_empty()
148 }
149}
150
151pub async fn collect_stream(mut stream: StreamResponse) -> Result<StreamingChatResponse> {
153 use futures::StreamExt;
154
155 let mut accumulator = StreamAccumulator::new();
156
157 while let Some(chunk_result) = stream.next().await {
158 let chunk = chunk_result?;
159 accumulator.add_chunk(&chunk);
160 }
161
162 Ok(accumulator.build())
163}
164
165impl From<StreamingChatResponse> for ChatMessage {
167 fn from(response: StreamingChatResponse) -> Self {
168 ChatMessage {
169 role: ChatRole::Assistant,
170 content: response.text,
171 }
172 }
173}
174
/// Boxed `Send + Sync` callback that receives a reference to each chunk;
/// stored by `StreamHandler` and invoked from `StreamHandler::handle_chunk`.
pub type StreamCallback = Box<dyn Fn(&StreamChunk) + Send + Sync>;
177
/// Pairs a per-chunk callback with an accumulator, so chunks can be observed
/// as they arrive while the full response is still being assembled.
pub struct StreamHandler {
    /// Invoked with every chunk passed to `handle_chunk`.
    callback: StreamCallback,
    /// Collects the chunks into the response returned by `finish`.
    accumulator: StreamAccumulator,
}
183
184impl StreamHandler {
185 pub fn new(callback: impl Fn(&StreamChunk) + Send + Sync + 'static) -> Self {
187 Self {
188 callback: Box::new(callback),
189 accumulator: StreamAccumulator::new(),
190 }
191 }
192
193 pub fn handle_chunk(&mut self, chunk: &StreamChunk) {
195 (self.callback)(chunk);
196 self.accumulator.add_chunk(chunk);
197 }
198
199 #[must_use]
201 pub fn finish(self) -> StreamingChatResponse {
202 self.accumulator.build()
203 }
204}
205
206#[must_use]
208pub fn print_handler() -> StreamHandler {
209 StreamHandler::new(|chunk| {
210 print!("{}", chunk.delta);
211 if chunk.is_final {
212 println!();
213 }
214 })
215}