1use crate::agent::Agent;
13use crate::completion::{
14 CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
15};
16use futures::{Stream, StreamExt};
17use std::boxed::Box;
18use std::fmt::{Display, Formatter};
19use std::future::Future;
20use std::pin::Pin;
21
22#[derive(Debug)]
24pub enum StreamingChoice {
25 Message(String),
27
28 ToolCall(String, String, serde_json::Value),
30}
31
32impl Display for StreamingChoice {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 match self {
35 StreamingChoice::Message(text) => write!(f, "{}", text),
36 StreamingChoice::ToolCall(name, id, params) => {
37 write!(f, "Tool call: {} {} {:?}", name, id, params)
38 }
39 }
40 }
41}
42
43#[cfg(not(target_arch = "wasm32"))]
44pub type StreamingResult =
45 Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
46
47#[cfg(target_arch = "wasm32")]
48pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
49
50pub trait StreamingPrompt: Send + Sync {
52 fn stream_prompt(
54 &self,
55 prompt: &str,
56 ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
57}
58
59pub trait StreamingChat: Send + Sync {
61 fn stream_chat(
63 &self,
64 prompt: &str,
65 chat_history: Vec<Message>,
66 ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
67}
68
69pub trait StreamingCompletion<M: StreamingCompletionModel> {
71 fn stream_completion(
73 &self,
74 prompt: &str,
75 chat_history: Vec<Message>,
76 ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
77}
78
79pub trait StreamingCompletionModel: CompletionModel {
81 fn stream(
83 &self,
84 request: CompletionRequest,
85 ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
86}
87
88pub async fn stream_to_stdout<M: StreamingCompletionModel>(
90 agent: Agent<M>,
91 stream: &mut StreamingResult,
92) -> Result<(), std::io::Error> {
93 print!("Response: ");
94 while let Some(chunk) = stream.next().await {
95 match chunk {
96 Ok(StreamingChoice::Message(text)) => {
97 print!("{}", text);
98 std::io::Write::flush(&mut std::io::stdout())?;
99 }
100 Ok(StreamingChoice::ToolCall(name, _, params)) => {
101 let res = agent
102 .tools
103 .call(&name, params.to_string())
104 .await
105 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
106 println!("\nResult: {}", res);
107 }
108 Err(e) => {
109 eprintln!("Error: {}", e);
110 break;
111 }
112 }
113 }
114 println!(); Ok(())
117}