use crate::agent::Agent;
use crate::completion::{
CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
};
use futures::{Stream, StreamExt};
use std::boxed::Box;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::pin::Pin;
/// A single item yielded by a streaming completion.
///
/// A stream interleaves fragments of generated text with requests from the
/// model to invoke a tool.
#[derive(Debug, Clone, PartialEq)]
pub enum StreamingChoice {
    /// A fragment of generated text.
    Message(String),
    /// A tool invocation request. Fields, in order: the tool name, the call id,
    /// and the tool arguments as a JSON value.
    ToolCall(String, String, serde_json::Value),
}
impl Display for StreamingChoice {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
StreamingChoice::Message(text) => write!(f, "{}", text),
StreamingChoice::ToolCall(name, id, params) => {
write!(f, "Tool call: {} {} {:?}", name, id, params)
}
}
}
}
/// Pinned, boxed stream of streaming-completion items.
///
/// Each item is either a [`StreamingChoice`] or a [`CompletionError`]. On
/// non-`wasm32` targets the stream is additionally required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult =
    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send + 'static>>;

/// Pinned, boxed stream of streaming-completion items (`wasm32` variant: no
/// `Send` requirement on the stream).
#[cfg(target_arch = "wasm32")]
pub type StreamingResult =
    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + 'static>>;
/// Types that can answer a single prompt with a streamed response.
pub trait StreamingPrompt
where
    Self: Send + Sync,
{
    /// Starts streaming the response to `prompt`.
    ///
    /// The outer `Result` reports failures that happen before a stream is
    /// available; failures of individual items are reported inside the
    /// returned [`StreamingResult`].
    fn stream_prompt(
        &self,
        prompt: &str,
    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
}
/// Types that can continue a conversation with a streamed response.
pub trait StreamingChat
where
    Self: Send + Sync,
{
    /// Starts streaming the response to `prompt`, given the prior
    /// `chat_history`.
    ///
    /// The outer `Result` reports failures that happen before a stream is
    /// available; failures of individual items are reported inside the
    /// returned [`StreamingResult`].
    fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
}
/// Types that can prepare a completion request for a streaming-capable model `M`.
pub trait StreamingCompletion<M>
where
    M: StreamingCompletionModel,
{
    /// Builds a request from `prompt` and `chat_history`.
    ///
    /// This yields a [`CompletionRequestBuilder`] rather than a stream.
    /// Presumably the caller finishes the request and hands it to
    /// [`StreamingCompletionModel::stream`]; confirm against call sites.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
/// A [`CompletionModel`] that can also deliver its output as a stream.
pub trait StreamingCompletionModel
where
    Self: CompletionModel,
{
    /// Executes `request` and returns its output as a [`StreamingResult`].
    ///
    /// The outer `Result` reports failures that happen before a stream is
    /// available; failures of individual items are reported inside the stream.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
}
/// Drains `stream`, echoing generated text to stdout and executing any tool
/// calls it contains through `agent`'s tools.
///
/// Text fragments are printed and flushed as they arrive. For each tool call,
/// the named tool is invoked with its JSON arguments serialized to a string.
/// The tool's result is then printed on a new line prefixed with `Result:`.
///
/// An error item from the stream is printed to stderr and stops the loop. It
/// is *not* propagated, so the function still returns `Ok(())` in that case.
///
/// # Errors
///
/// Returns an [`std::io::Error`] if flushing stdout fails or if a tool call
/// fails. A tool error is converted into an `io::Error` carrying its message.
pub async fn stream_to_stdout<M: StreamingCompletionModel>(
    agent: Agent<M>,
    stream: &mut StreamingResult,
) -> Result<(), std::io::Error> {
    print!("Response: ");
    // stdout is frequently line-buffered, so text without a trailing newline
    // must be flushed explicitly. Otherwise the prefix only shows up once the
    // first chunk arrives.
    std::io::Write::flush(&mut std::io::stdout())?;

    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(StreamingChoice::Message(text)) => {
                print!("{}", text);
                // Same reason as above: partial lines need an explicit flush.
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(StreamingChoice::ToolCall(name, _, params)) => {
                let res = agent
                    .tools
                    .call(&name, params.to_string())
                    .await
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
                println!("\nResult: {}", res);
            }
            Err(e) => {
                // Best-effort: report the failure and stop, without failing the caller.
                eprintln!("Error: {}", e);
                break;
            }
        }
    }
    println!();
    Ok(())
}