use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};
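
/// A pinned, boxed stream of [`MultiTurnStreamItem`]s (or [`StreamingError`]s) produced by a
/// streaming prompt request. The `Send` bound is dropped on `wasm32` targets.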
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
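
/// A single item yielded by a multi-turn streaming prompt: either a piece of streamed
/// assistant content or the final aggregated response.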
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamItem(StreamedAssistantContent<R>),
    FinalResponse(FinalResponse),
}
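
/// The final response of a multi-turn stream: the last text response produced by the agent
/// together with the token usage aggregated across every turn.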
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}
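
/// Errors that can be produced while driving a streaming prompt request.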
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
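
/// A builder for a streaming, multi-turn prompt sent to an [`Agent`]. Create one with
/// [`StreamingPromptRequest::new`], optionally attach history, a maximum depth and a hook,
/// then `.await` it to obtain a [`StreamingResult`].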
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_depth: usize,
    agent: Arc<Agent<M>>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
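    /// Creates a new streaming prompt request for the given agent and prompt.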
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }
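
    /// Enables multi-turn prompting by setting the maximum conversation depth. Exceeding the
    /// depth terminates the stream with [`PromptError::MaxDepthError`].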
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }
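
    /// Attaches an existing chat history to the request.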
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }
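
    /// Attaches a [`StreamingPromptHook`] that is notified of completion calls, text deltas,
    /// tool calls and tool results as the stream progresses.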
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }
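
    /// Drives the request: repeatedly streams completions from the model, executes any
    /// requested tool calls, and feeds the results back into the conversation until the model
    /// responds with plain text or the maximum depth is exceeded.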
    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader
                        .last()
                        .cloned()
                        .expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            let res = async {
                                let tool_span = tracing::Span::current();
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string(), cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string());

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_call.function.arguments.to_string())
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));

                                did_call_tool = true;
                                Ok(())
                            }.instrument(tool_span).await;

                            if let Err(e) = res {
                                yield Err(e);
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            chat_history.write().await.push(Message::Assistant {
                                id: None,
                                content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
                                    reasoning: reasoning.clone(),
                                    id: id.clone(),
                                    signature: signature.clone(),
                                })),
                            });
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }
        })
    }
}
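
// Awaiting a `StreamingPromptRequest` sends it and resolves to the resulting stream.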
impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}
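
/// Convenience helper that drains a [`StreamingResult`], printing text and reasoning chunks to
/// stdout as they arrive, and returns the [`FinalResponse`].
///
/// A minimal usage sketch; `agent` here stands for any `Arc<Agent<_>>` you have already built:
///
/// ```ignore
/// let mut stream = StreamingPromptRequest::new(agent, "Hello!").multi_turn(3).await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nToken usage: {:?}", final_response.usage());
/// ```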
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}
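
/// Hook for observing the lifecycle of a streaming prompt request: completion calls, text
/// deltas, tool calls, tool results and stream completion. Every method has a default no-op
/// implementation, so implementors only override the events they care about; each callback
/// also receives a [`CancelSignal`] that can be used to abort the request.
///
/// A minimal sketch of a custom hook (the struct name is illustrative):
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     fn on_tool_call(
///         &self,
///         tool_name: &str,
///         args: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + Send {
///         async move { println!("calling tool `{tool_name}` with args {args}") }
///     }
/// }
/// ```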
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
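    /// Called just before a completion request is sent to the model.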
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
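
    /// Called for each text chunk received, along with the text aggregated so far.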
    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
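
    /// Called for each streamed tool call delta, identified by the tool call id.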
    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
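
    /// Called when the model has finished streaming a response, with the provider's final
    /// streaming response.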
    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
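
    /// Called before a tool is invoked, with the tool name and its raw arguments.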
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
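
    /// Called after a tool has been invoked, with the tool name, its arguments and the result.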
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}