genai/chat/
printer.rs

1//! Printer utility to help print a chat stream
2//! > Note: This is primarily for quick testing and temporary debugging
3
4use crate::chat::{ChatStreamEvent, ChatStreamResponse, StreamChunk};
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use tokio::io::{AsyncWriteExt as _, Stdout};
8
9// Note: This module has its own Error type (see end of file)
10type Result<T> = core::result::Result<T, Error>;
11
12// region:    --- PrintChatStreamOptions
13
14/// Options for printing a chat stream with `printer::print_chat_stream`.
15#[derive(Debug, Default, Serialize, Deserialize)]
16pub struct PrintChatStreamOptions {
17	/// When true, also print event markers and tool-call metadata.
18	print_events: Option<bool>,
19}
20
21/// Constructors
22impl PrintChatStreamOptions {
23	/// Build options with `print_events` set.
24	pub fn from_print_events(print_events: bool) -> Self {
25		PrintChatStreamOptions {
26			print_events: Some(print_events),
27		}
28	}
29}
30
31// endregion: --- PrintChatStreamOptions
32
33/// Write the streamed chat response to stdout and return the concatenated content.
34///
35/// Stdout is flushed before returning, even on error.
36pub async fn print_chat_stream(
37	chat_res: ChatStreamResponse,
38	options: Option<&PrintChatStreamOptions>,
39) -> Result<String> {
40	let mut stdout = tokio::io::stdout();
41	let res = print_chat_stream_inner(&mut stdout, chat_res, options).await;
42
43	// Ensure tokio stdout flush is called, regardless of success or failure.
44	let flush_res = stdout.flush().await;
45
46	match (res, flush_res) {
47		// Prefer returning the inner processing error when both fail.
48		(Err(e), Err(_flush_err)) => Err(e),
49
50		// Inner succeeded but flush failed.
51		(Ok(_), Err(flush_err)) => Err(flush_err.into()),
52
53		// Flush succeeded (or not applicable); return inner result.
54		(inner, _) => inner,
55	}
56}
57
58async fn print_chat_stream_inner(
59	stdout: &mut Stdout,
60	chat_res: ChatStreamResponse,
61	options: Option<&PrintChatStreamOptions>,
62) -> Result<String> {
63	let mut stream = chat_res.stream;
64
65	let mut content_capture = String::new();
66
67	let print_events = options.and_then(|o| o.print_events).unwrap_or_default();
68
69	let mut first_chunk = true;
70	let mut first_reasoning_chunk = true;
71	let mut first_tool_chunk = true;
72
73	while let Some(next) = stream.next().await {
74		let (event_info, print_content, capture_content_flag) = match next {
75			Ok(stream_event) => {
76				match stream_event {
77					ChatStreamEvent::Start => {
78						if print_events {
79							// TODO: Might implement pretty JSON formatting
80							(Some("\n-- ChatStreamEvent::Start\n".to_string()), None, false)
81						} else {
82							(None, None, false)
83						}
84					}
85
86					ChatStreamEvent::Chunk(StreamChunk { content }) => {
87						if print_events && first_chunk {
88							first_chunk = false;
89							(
90								Some("\n-- ChatStreamEvent::Chunk (concatenated):\n".to_string()),
91								Some(content),
92								true,
93							)
94						} else {
95							(None, Some(content), true)
96						}
97					}
98
99					ChatStreamEvent::ReasoningChunk(StreamChunk { content }) => {
100						if print_events && first_reasoning_chunk {
101							first_reasoning_chunk = false;
102							(
103								Some("\n-- ChatStreamEvent::ReasoningChunk (concatenated):\n".to_string()),
104								Some(content),
105								false, // print but do not capture
106							)
107						} else {
108							(None, Some(content), false) // print but do not capture
109						}
110					}
111
112					ChatStreamEvent::ToolCallChunk(tool_chunk) => {
113						if print_events && first_tool_chunk {
114							first_tool_chunk = false;
115							(
116								Some(format!(
117									"\n-- ChatStreamEvent::ToolCallChunk: fn: {}, args: {}\n",
118									tool_chunk.tool_call.fn_name, tool_chunk.tool_call.fn_arguments
119								)),
120								None,
121								false,
122							)
123						} else {
124							(None, None, false)
125						}
126					}
127
128					ChatStreamEvent::End(end_event) => {
129						if print_events {
130							// TODO: Might implement pretty JSON formatting
131							(
132								Some(format!("\n\n-- ChatStreamEvent::End {end_event:?}\n")),
133								None,
134								false,
135							)
136						} else {
137							(None, None, false)
138						}
139					}
140				}
141			}
142			Err(e) => return Err(e.into()),
143		};
144
145		if let Some(event_info) = event_info {
146			stdout.write_all(event_info.as_bytes()).await?;
147		}
148
149		if let Some(content) = print_content {
150			if capture_content_flag {
151				content_capture.push_str(&content);
152			}
153			stdout.write_all(content.as_bytes()).await?;
154		};
155
156		stdout.flush().await?;
157	}
158
159	stdout.write_all(b"\n").await?;
160
161	Ok(content_capture)
162}
163
164// region:    --- Error
165
166// Note 1: The printer has its own error type because it is more of a utility, and therefore
167//         making the main crate error aware of the different error types would be unnecessary.
168//
169// Note 2: This Printer Error is not wrapped in the main crate error because the printer
170//         functions are not used by any other crate functions (they are more of a debug utility)
171
172use derive_more::From;
173
174/// The Printer error.
175#[derive(Debug, From)]
176pub enum Error {
177	/// The `tokio::io::Error` when using `tokio::io::stdout`
178	#[from]
179	TokioIo(tokio::io::Error),
180
181	/// The stream returned an error from the main crate.
182	#[from]
183	Stream(crate::Error),
184}
185
186// region:    --- Error Boilerplate
187
188impl core::fmt::Display for Error {
189	fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::result::Result<(), core::fmt::Error> {
190		write!(fmt, "{self:?}")
191	}
192}
193
194impl std::error::Error for Error {}
195
196// endregion: --- Error Boilerplate