use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

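/// One item of a multi-turn agent stream: streamed assistant content (text,
/// tool calls, reasoning), streamed user content (tool results), or the
/// [`FinalResponse`] emitted once the conversation settles.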
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamAssistantItem(StreamedAssistantContent<R>),
    StreamUserItem(StreamedUserContent),
    FinalResponse(FinalResponse),
}

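/// The final text response of a multi-turn stream, together with the token
/// usage aggregated across all turns.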
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

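/// A builder for a streaming, possibly multi-turn, agent prompt. The request
/// resolves (via [`IntoFuture`]) to a [`StreamingResult`] that yields
/// [`MultiTurnStreamItem`]s as the model streams them back.
///
/// A minimal sketch of driving a request to completion, assuming an existing
/// `agent: Arc<Agent<M>>` (marked `ignore` since it needs a live provider):
///
/// ```ignore
/// let mut stream = StreamingPromptRequest::new(agent, "What is 2 + 2?")
///     .multi_turn(5)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     println!("{item:?}");
/// }
/// ```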
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_depth: usize,
    agent: Arc<Agent<M>>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
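    /// Create a new streaming prompt request from an agent and a prompt.
    /// The request starts with no chat history, no hook, and a multi-turn
    /// depth of 0.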
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

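    /// Set the maximum number of turns the agent may take while resolving
    /// tool calls. Exceeding this depth ends the stream with
    /// [`PromptError::MaxDepthError`].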
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

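    /// Seed the request with an existing chat history, which is sent to the
    /// model alongside the prompt on every turn.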
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

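    /// Attach a [`StreamingPromptHook`] whose callbacks fire on completion
    /// calls, text deltas, tool calls, and tool results as the stream runs.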
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

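    // Drives the multi-turn conversation: streams one completion, surfaces
    // text/tool-call/reasoning items to the caller, executes any tool calls,
    // feeds their results back as the next prompt, and repeats until the
    // model answers in plain text or `max_depth` is exceeded.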
    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader
                        .last()
                        .cloned()
                        .expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                        return;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    return;
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string(), cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string());

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_call.function.arguments.to_string())
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: OneOrMany::one(ToolResultContent::Text(Text { text })) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    return;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                        return;
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }
        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

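/// Drain a [`StreamingResult`] to stdout, printing text and reasoning chunks
/// as they arrive, and return the [`FinalResponse`] once the stream is
/// exhausted. Stream errors are printed to stderr rather than propagated.
///
/// A minimal sketch, assuming `stream` was obtained by awaiting a
/// [`StreamingPromptRequest`] (marked `ignore`: it needs a live agent):
///
/// ```ignore
/// let mut stream = StreamingPromptRequest::new(agent, "Hello!").await;
/// let final_res = stream_to_stdout(&mut stream).await?;
/// println!("\nToken usage: {:?}", final_res.usage());
/// ```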
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

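/// A set of callbacks fired while a streaming prompt runs: before each
/// completion call, on each text or tool-call delta, around every tool
/// invocation, and when a streamed completion response finishes. Every method
/// defaults to a no-op, so implementors override only the events they need;
/// each callback receives a [`CancelSignal`] that can be used to abort the
/// stream.
///
/// A minimal sketch of a logging hook (the struct name is illustrative):
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel: CancelSignal) {
///         println!("calling tool `{tool_name}` with args {args}");
///     }
/// }
/// ```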
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

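/// The no-op hook: every callback falls through to its default implementation.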
impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}