use std::io::Write;
use std::sync::Arc;
use futures_util::StreamExt;
use genai::Client;
use genai::chat::{ChatMessage, ChatRequest, ChatStreamEvent};
use tokio::sync::mpsc;
/// A message forwarded from the streaming task created by `spawn_chat_stream`.
///
/// A well-behaved stream yields zero or more `Token`s followed by exactly one
/// terminal message: `Done` on success or `Error` on failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMsg {
    /// Text content of one streamed chunk, in arrival order.
    Token(String),
    /// The provider stream finished without error.
    Done,
    /// The request or the stream failed; carries a human-readable description.
    Error(String),
}
/// One turn of earlier conversation, replayed in order ahead of the new prompt.
///
/// `User` turns become user messages and `Assistant` turns become assistant
/// messages when the request is built. Serializable with serde so a history
/// can be persisted and restored.
// NOTE: variant order is intentionally left unchanged; index-based serde
// formats would break if it were reordered.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum ChatTurn {
    /// Text the user sent.
    User(String),
    /// Text the assistant replied with.
    Assistant(String),
}
pub fn spawn_chat_stream(
client: Arc<Client>,
model: String,
system_prompt: Option<String>,
history: Vec<ChatTurn>,
user_prompt: String,
) -> mpsc::UnboundedReceiver<StreamMsg> {
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(async move {
let mut messages: Vec<ChatMessage> = Vec::new();
if let Some(s) = system_prompt {
if !s.trim().is_empty() {
messages.push(ChatMessage::system(s));
}
}
for turn in history {
match turn {
ChatTurn::User(t) => messages.push(ChatMessage::user(t)),
ChatTurn::Assistant(t) => messages.push(ChatMessage::assistant(t)),
}
}
messages.push(ChatMessage::user(user_prompt));
let req = ChatRequest::new(messages);
let response = match client.exec_chat_stream(model.as_str(), req, None).await {
Ok(r) => r,
Err(e) => {
let _ = tx.send(StreamMsg::Error(format!("exec_chat_stream: {e}")));
return;
}
};
let mut stream = response.stream;
while let Some(event) = stream.next().await {
match event {
Ok(ChatStreamEvent::Chunk(chunk)) => {
if tx.send(StreamMsg::Token(chunk.content)).is_err() {
return;
}
}
Ok(ChatStreamEvent::ReasoningChunk(_))
| Ok(ChatStreamEvent::ThoughtSignatureChunk(_))
| Ok(ChatStreamEvent::ToolCallChunk(_))
| Ok(ChatStreamEvent::Start)
| Ok(ChatStreamEvent::End(_)) => {}
Err(e) => {
let _ = tx.send(StreamMsg::Error(format!("stream event: {e}")));
return;
}
}
}
let _ = tx.send(StreamMsg::Done);
});
rx
}
/// Runs a single-turn chat (no history) and blocks the current thread until
/// the whole response has been received. Returns the concatenated text.
///
/// As a progress indicator, one `.` is written to stderr per received token.
/// This is best effort: stderr write errors are ignored.
///
/// # Errors
/// Returns `Err` with a description when:
/// - the request or the stream fails (the message comes from the streaming task); or
/// - the channel closes before a `Done` arrives. This happens if the streaming
///   task panicked or was cancelled, for example by runtime shutdown, and
///   prevents a truncated response from being reported as success.
///
/// # Panics
/// - `tokio::spawn` (via `spawn_chat_stream`) panics outside a Tokio runtime context.
/// - `blocking_recv` panics when called from within an asynchronous execution context.
///
/// So call this from a plain thread that has a runtime handle entered, for
/// example via `Handle::enter`, or from `spawn_blocking`.
pub fn collect_blocking(
    client: Arc<Client>,
    model: String,
    system_prompt: Option<String>,
    prompt: String,
) -> Result<String, String> {
    let mut rx = spawn_chat_stream(client, model, system_prompt, Vec::new(), prompt);
    let mut raw = String::new();
    let mut stderr = std::io::stderr();
    while let Some(msg) = rx.blocking_recv() {
        match msg {
            StreamMsg::Token(t) => {
                raw.push_str(&t);
                // Progress dots are cosmetic; failing to write them must not abort the request.
                let _ = stderr.write_all(b".");
                let _ = stderr.flush();
            }
            StreamMsg::Done => return Ok(raw),
            StreamMsg::Error(e) => return Err(e),
        }
    }
    // The sender was dropped without a terminal message: the stream did not complete.
    Err(format!(
        "stream closed before completion ({} bytes received)",
        raw.len()
    ))
}