1use crate::chat::{ChatStreamEvent, ChatStreamResponse, StreamChunk};
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use tokio::io::{AsyncWriteExt as _, Stdout};
8
9type Result<T> = core::result::Result<T, Error>;
11
12#[derive(Debug, Default, Serialize, Deserialize)]
16pub struct PrintChatStreamOptions {
17 print_events: Option<bool>,
19}
20
21impl PrintChatStreamOptions {
23 pub fn from_print_events(print_events: bool) -> Self {
25 PrintChatStreamOptions {
26 print_events: Some(print_events),
27 }
28 }
29}
30
31pub async fn print_chat_stream(
37 chat_res: ChatStreamResponse,
38 options: Option<&PrintChatStreamOptions>,
39) -> Result<String> {
40 let mut stdout = tokio::io::stdout();
41 let res = print_chat_stream_inner(&mut stdout, chat_res, options).await;
42
43 let flush_res = stdout.flush().await;
45
46 match (res, flush_res) {
47 (Err(e), Err(_flush_err)) => Err(e),
49
50 (Ok(_), Err(flush_err)) => Err(flush_err.into()),
52
53 (inner, _) => inner,
55 }
56}
57
58async fn print_chat_stream_inner(
59 stdout: &mut Stdout,
60 chat_res: ChatStreamResponse,
61 options: Option<&PrintChatStreamOptions>,
62) -> Result<String> {
63 let mut stream = chat_res.stream;
64
65 let mut content_capture = String::new();
66
67 let print_events = options.and_then(|o| o.print_events).unwrap_or_default();
68
69 let mut first_chunk = true;
70 let mut first_reasoning_chunk = true;
71 let mut first_tool_chunk = true;
72
73 while let Some(next) = stream.next().await {
74 let (event_info, print_content, capture_content_flag) = match next {
75 Ok(stream_event) => {
76 match stream_event {
77 ChatStreamEvent::Start => {
78 if print_events {
79 (Some("\n-- ChatStreamEvent::Start\n".to_string()), None, false)
81 } else {
82 (None, None, false)
83 }
84 }
85
86 ChatStreamEvent::Chunk(StreamChunk { content }) => {
87 if print_events && first_chunk {
88 first_chunk = false;
89 (
90 Some("\n-- ChatStreamEvent::Chunk (concatenated):\n".to_string()),
91 Some(content),
92 true,
93 )
94 } else {
95 (None, Some(content), true)
96 }
97 }
98
99 ChatStreamEvent::ReasoningChunk(StreamChunk { content }) => {
100 if print_events && first_reasoning_chunk {
101 first_reasoning_chunk = false;
102 (
103 Some("\n-- ChatStreamEvent::ReasoningChunk (concatenated):\n".to_string()),
104 Some(content),
105 false, )
107 } else {
108 (None, Some(content), false) }
110 }
111
112 ChatStreamEvent::ToolCallChunk(tool_chunk) => {
113 if print_events && first_tool_chunk {
114 first_tool_chunk = false;
115 (
116 Some(format!(
117 "\n-- ChatStreamEvent::ToolCallChunk: fn: {}, args: {}\n",
118 tool_chunk.tool_call.fn_name, tool_chunk.tool_call.fn_arguments
119 )),
120 None,
121 false,
122 )
123 } else {
124 (None, None, false)
125 }
126 }
127
128 ChatStreamEvent::End(end_event) => {
129 if print_events {
130 (
132 Some(format!("\n\n-- ChatStreamEvent::End {end_event:?}\n")),
133 None,
134 false,
135 )
136 } else {
137 (None, None, false)
138 }
139 }
140 }
141 }
142 Err(e) => return Err(e.into()),
143 };
144
145 if let Some(event_info) = event_info {
146 stdout.write_all(event_info.as_bytes()).await?;
147 }
148
149 if let Some(content) = print_content {
150 if capture_content_flag {
151 content_capture.push_str(&content);
152 }
153 stdout.write_all(content.as_bytes()).await?;
154 };
155
156 stdout.flush().await?;
157 }
158
159 stdout.write_all(b"\n").await?;
160
161 Ok(content_capture)
162}
163
164use derive_more::From;
173
174#[derive(Debug, From)]
176pub enum Error {
177 #[from]
179 TokioIo(tokio::io::Error),
180
181 #[from]
183 Stream(crate::Error),
184}
185
186impl core::fmt::Display for Error {
189 fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::result::Result<(), core::fmt::Error> {
190 write!(fmt, "{self:?}")
191 }
192}
193
194impl std::error::Error for Error {}
195
196