use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{
    future::{Future, IntoFuture},
    pin::Pin,
    sync::Arc,
};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

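/// A single item yielded by a multi-turn streaming agent: either a streamed chunk of
/// assistant content or the final aggregated response once all turns have completed.
///
/// A minimal consumption sketch (illustrative only; assumes `stream` was produced by a
/// streaming prompt request on this crate's `Agent`):
///
/// ```rust,ignore
/// while let Some(item) = stream.next().await {
///     match item? {
///         MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(text)) => {
///             print!("{}", text.text);
///         }
///         MultiTurnStreamItem::FinalResponse(res) => {
///             println!("\ntotal usage: {:?}", res.usage());
///         }
///         // the enum is #[non_exhaustive], so a catch-all arm is required
///         _ => {}
///     }
/// }
/// ```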
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed piece of assistant content (text, tool call, reasoning, ...).
    StreamItem(StreamedAssistantContent<R>),
    /// The final response, emitted once the multi-turn interaction has finished.
    FinalResponse(FinalResponse),
}

/// The final aggregated response of a multi-turn streaming interaction.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    /// Creates an empty final response. Primarily useful as a default placeholder.
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    /// The final text response.
    pub fn response(&self) -> &str {
        &self.response
    }

    /// The token usage, aggregated across all turns of the interaction.
    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

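/// A builder-style request for streaming a prompt to an agent, with optional chat
/// history, a multi-turn depth limit, and lifecycle hooks. Awaiting the request (it
/// implements [`IntoFuture`]) yields a [`StreamingResult`].
///
/// A minimal usage sketch (illustrative only; assumes an `agent` wrapped in an
/// [`Arc`] has been constructed elsewhere; `()` is the no-op hook):
///
/// ```rust,ignore
/// let mut stream = StreamingPromptRequest::<_, ()>::new(agent, "Please calculate 2 + 5")
///     .multi_turn(3)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     // handle each MultiTurnStreamItem as it arrives
/// }
/// ```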
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    chat_history: Option<Vec<Message>>,
    /// The maximum number of turns allowed in a multi-turn interaction
    max_depth: usize,
    /// The agent that executes the request
    agent: Arc<Agent<M>>,
    /// An optional hook that is invoked on lifecycle events of the stream
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Creates a new streaming prompt request for the given agent and prompt.
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

    /// Sets the maximum number of turns (tool-call round trips) the request may take.
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Attaches an existing chat history to the request.
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attaches a hook that is called on lifecycle events of the stream.
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                        // end the stream once the request has been cancelled
                        return;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    // end the stream once the request has been cancelled
                                    return;
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            let res = async {
                                let tool_span = tracing::Span::current();
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string(), cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string());

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_call.function.arguments.to_string())
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));

                                did_call_tool = true;
                                Ok(())
                            }.instrument(tool_span).await;

                            if let Err(e) = res {
                                yield Err(e);
                                // the only error produced above is cancellation, so end the stream
                                return;
                            }
                        },
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id })) => {
                            chat_history.write().await.push(Message::Assistant {
                                id: None,
                                content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
                                    reasoning: reasoning.clone(),
                                    id: id.clone(),
                                })),
                            });
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                        // end the stream once the request has been cancelled
                                        return;
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Add any tool calls made this turn to the chat history as a single
                // assistant message.
                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                // Add the tool results to the chat history as user messages.
                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                // The last message in the history becomes the prompt for the next turn.
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }
        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

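/// Drains a [`StreamingResult`] to stdout, printing text and reasoning chunks as they
/// arrive, and returns the [`FinalResponse`] once the stream is exhausted.
///
/// A minimal sketch of the intended call pattern (illustrative only; assumes an agent
/// with a streaming prompt method that yields a [`StreamingResult`]):
///
/// ```rust,ignore
/// let mut stream = agent.stream_prompt("Hello, world!").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\ntoken usage: {:?}", final_response.usage());
/// ```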
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

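/// A hook for observing and reacting to the lifecycle of a streaming prompt request.
/// Every method has a no-op default implementation, so implementors only need to
/// override the events they care about; each callback receives a [`CancelSignal`]
/// that can be used to abort the request.
///
/// A minimal sketch of an implementation (illustrative only; `LoggingHook` is a
/// hypothetical type, not part of this crate):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool `{tool_name}` with args: {args}");
///     }
/// }
/// ```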
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    /// Called before the completion request is sent to the model provider.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called on every text delta, together with the text aggregated so far.
    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called when the provider finishes a streamed completion response.
    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called after a tool has returned its result.
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

/// The no-op hook, used when no hook has been attached to a request.
impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}