use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

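/// The stream returned by a streaming prompt request: a pinned, boxed stream of
/// [`MultiTurnStreamItem`]s (additionally `Send` on non-wasm targets).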
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

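/// A single item emitted by a multi-turn prompt stream: streamed assistant
/// content (text, tool calls, reasoning), streamed user content (tool results),
/// or the final aggregated response.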
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamAssistantItem(StreamedAssistantContent<R>),
    StreamUserItem(StreamedUserContent),
    FinalResponse(FinalResponse),
}

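/// The final result of a multi-turn prompt stream: the last text response from
/// the model plus the token usage aggregated across all turns.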
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

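/// A builder-style request for a streaming (and optionally multi-turn) agent
/// prompt. Awaiting the request sends it and yields a [`StreamingResult`].
///
/// A minimal usage sketch, assuming an agent built from the `openai` provider
/// and a `stream_prompt` helper that constructs this request (model name and
/// prompt are hypothetical placeholders):
///
/// ```rust,ignore
/// // Assumes the `openai` provider feature and OPENAI_API_KEY are available.
/// let agent = openai::Client::from_env()
///     .agent("gpt-4o")
///     .preamble("You are a helpful assistant.")
///     .build();
///
/// // `stream_prompt` returns a `StreamingPromptRequest`; awaiting it yields
/// // the stream itself.
/// let mut stream = agent.stream_prompt("Hello!").multi_turn(3).await;
///
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("usage: {:?}", final_response.usage());
/// ```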
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_depth: usize,
    agent: Arc<Agent<M>>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

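    /// Sets the maximum number of turns (model round-trips for tool calls)
    /// allowed before the stream aborts with [`PromptError::MaxDepthError`].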
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

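    /// Attaches a [`StreamingPromptHook`] whose callbacks are invoked at each
    /// stage of the stream (completion calls, text deltas, tool calls and
    /// results), replacing any previously attached hook.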
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

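    /// Sends the request and returns the resulting stream.
    ///
    /// Internally this loops up to `max_depth` times: each turn streams the
    /// model response, executes any requested tool calls (appending the tool
    /// results to the chat history), and re-prompts the model until a turn
    /// produces no tool calls, at which point a
    /// [`MultiTurnStreamItem::FinalResponse`] is emitted.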
    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                        // Stop the stream entirely once cancellation is observed.
                        return;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    // Stop the stream entirely once cancellation is observed.
                                    return;
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_args, cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent.tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_args, &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: OneOrMany::one(ToolResultContent::Text(Text { text })) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    // Stop the stream entirely once cancellation is observed.
                                    return;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                        // Stop the stream entirely once cancellation is observed.
                                        return;
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }
        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

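/// Drains a [`StreamingResult`] to stdout, printing text and reasoning deltas
/// as they arrive, and returns the [`FinalResponse`] (empty if the stream
/// never emitted one).
///
/// A sketch of typical use (the agent setup is assumed to exist):
///
/// ```rust,ignore
/// let mut stream = agent.stream_prompt("Hello!").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// ```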
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

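/// Hooks into the lifecycle of a streaming prompt request. Every method has a
/// default no-op implementation, so implementors only override the callbacks
/// they care about; each callback receives a [`CancelSignal`] that can be used
/// to abort the stream.
///
/// A minimal sketch of a custom hook (the `LoggingHook` type is hypothetical):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool `{tool_name}` with args: {args}");
///     }
/// }
/// ```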
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
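    /// Called immediately before each completion request is sent to the model.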
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

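    /// Called on each text delta received from the model; `aggregated_text` is
    /// the text accumulated so far for the current response.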
    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

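    /// Called on each streamed tool-call delta (for providers that stream
    /// tool-call arguments incrementally).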
    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

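    /// Called when the provider signals the end of a streamed completion
    /// response, with the provider-specific final response payload.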
    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

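    /// Called just before a tool is invoked with the given arguments.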
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

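    /// Called after a tool invocation completes, with the stringified result.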
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}