pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc, OnceLock,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};

use futures::{StreamExt, stream};
use tracing::{Instrument, info_span, span::Id};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

/// Type-state marker describing what a [`PromptRequest`] resolves to.
pub trait PromptType {}
/// Marker type: the request resolves to a plain `String`.
pub struct Standard;
/// Marker type: the request resolves to a [`PromptResponse`] carrying
/// extended details such as aggregated token usage.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

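/// A builder for one prompt sent to an [`Agent`], including the multi-turn
/// tool-calling loop needed to produce a final answer. Awaiting the request
/// (via `IntoFuture`) sends it.
///
/// A minimal usage sketch (assumes an `agent` built elsewhere; `ignore`d
/// because it needs a live completion provider):
///
/// ```ignore
/// let mut history = Vec::new();
/// let answer: String = PromptRequest::new(&agent, "What is 2 + 2?")
///     .with_history(&mut history)
///     .multi_turn(5) // allow up to 5 turns of tool calls
///     .await?;
/// ```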
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the completion model
    prompt: Message,
    /// Optional chat history; when set, it is updated in place as the
    /// conversation progresses
    chat_history: Option<&'a mut Vec<Message>>,
    /// The maximum number of turns allowed for multi-turn tool calling
    max_depth: usize,
    /// The agent serving the request
    agent: &'a Agent<M>,
    /// Type-state marker ([`Standard`] or [`Extended`])
    state: PhantomData<S>,
    /// Optional hook invoked around completion and tool calls
    hook: Option<P>,
    /// Maximum number of tool calls executed concurrently per turn
    concurrency: usize,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new prompt request for the given agent and prompt, with no
    /// chat history, no hook, and the agent's default maximum depth (0 if
    /// unset).
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: agent.default_max_depth.unwrap_or_default(),
            agent,
            state: PhantomData,
            hook: None,
            concurrency: 1,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Convert this request into one that resolves to a [`PromptResponse`]
    /// with extended details (e.g. aggregated token usage) rather than a
    /// plain `String`.
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Set the maximum number of turns the agent may take while resolving
    /// tool calls before giving up with `PromptError::MaxDepthError`.
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

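    /// Set the maximum number of tool calls to run concurrently within a
    /// single turn (results are gathered with `buffer_unordered`).
    ///
    /// A sketch (assumes an `agent` with several registered tools; `ignore`d
    /// as it needs a live provider):
    ///
    /// ```ignore
    /// let answer = PromptRequest::new(&agent, "Compare the weather in five cities")
    ///     .multi_turn(2)
    ///     .with_tool_concurrency(5) // run up to 5 tool calls at once
    ///     .await?;
    /// ```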
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Attach a mutable chat history. The prompt and every intermediate
    /// message (assistant replies and tool results) are appended to it in
    /// place.
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

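    /// Attach a [`PromptHook`] that observes completion and tool calls and
    /// may cancel the request via the [`CancelSignal`] passed to each
    /// callback.
    ///
    /// A sketch (assumes a `MyHook` type implementing [`PromptHook`];
    /// `ignore`d as it is illustrative only):
    ///
    /// ```ignore
    /// let answer = PromptRequest::new(&agent, "What's the weather in Berlin?")
    ///     .multi_turn(3)
    ///     .with_hook(MyHook::default())
    ///     .await?;
    /// ```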
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
        }
    }
}

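/// A cloneable, thread-safe cancellation flag shared between the prompt loop
/// and any [`PromptHook`]. Cancellation is sticky: once set it cannot be
/// unset, and only the first reason recorded is kept.
///
/// A minimal sketch of the intended semantics (`new`, `is_cancelled`, and
/// `cancel_reason` are crate-private; they appear here only to illustrate
/// behaviour, hence the `ignore`):
///
/// ```ignore
/// let sig = CancelSignal::new();
/// let handle = sig.clone(); // clones share the same flag and reason
/// handle.cancel_with_reason("user pressed Ctrl-C");
/// assert!(sig.is_cancelled());
/// assert_eq!(sig.cancel_reason(), Some("user pressed Ctrl-C"));
/// ```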
#[derive(Clone)]
pub struct CancelSignal {
    sig: Arc<AtomicBool>,
    reason: Arc<OnceLock<String>>,
}

impl CancelSignal {
    fn new() -> Self {
        Self {
            sig: Arc::new(AtomicBool::new(false)),
            reason: Arc::new(OnceLock::new()),
        }
    }

    /// Cancel the in-flight prompt request without recording a reason.
    pub fn cancel(&self) {
        self.sig.store(true, Ordering::SeqCst);
    }

    /// Cancel the in-flight prompt request, recording a reason. Only the
    /// first reason set is kept.
    pub fn cancel_with_reason(&self, reason: &str) {
        let _ = self.reason.set(reason.to_string());
        self.cancel();
    }

    fn is_cancelled(&self) -> bool {
        self.sig.load(Ordering::SeqCst)
    }

    fn cancel_reason(&self) -> Option<&str> {
        self.reason.get().map(|x| x.as_str())
    }
}

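/// Lifecycle hooks for a [`PromptRequest`]. Every callback receives a
/// [`CancelSignal`] it can use to abort the in-flight request; all methods
/// have empty default bodies, so implementors override only what they need.
///
/// A minimal sketch of a logging hook that also enforces a tool policy
/// (`LoggingHook` is a hypothetical type; `ignore`d as it is illustrative
/// only):
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     async fn on_tool_call(
///         &self,
///         tool_name: &str,
///         _tool_call_id: Option<String>,
///         args: &str,
///         cancel_sig: CancelSignal,
///     ) {
///         println!("calling tool {tool_name} with args {args}");
///         if tool_name == "dangerous_tool" {
///             // The prompt loop will return a cancellation error.
///             cancel_sig.cancel_with_reason("tool blocked by policy");
///         }
///     }
/// }
/// ```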
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called just before a completion request is sent to the model.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called just after a completion response is received from the model.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called just before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called just after a tool returns its result.
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

/// The no-op hook: every callback keeps its empty default body.
impl<M> PromptHook<M> for () where M: CompletionModel {}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Send the request and return only the final output text, delegating to
    /// the [`Extended`] implementation and discarding the extra details.
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

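/// The result of an [`Extended`] prompt request: the final output text plus
/// token usage aggregated across every turn of the conversation.
///
/// A usage sketch (assumes an `agent` built elsewhere; `ignore`d since it
/// needs a live provider):
///
/// ```ignore
/// let resp = PromptRequest::new(&agent, "Summarize this repository")
///     .extended_details()
///     .await?;
/// println!(
///     "{} ({} input / {} output tokens)",
///     resp.output, resp.total_usage.input_tokens, resp.total_usage.output_tokens
/// );
/// ```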
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The final text produced by the agent
    pub output: String,
    /// Token usage summed over all completion calls made while serving the
    /// request
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drive the prompt to completion, looping through tool calls (up to
    /// `max_depth` turns) and aggregating token usage along the way.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's span if one is active; otherwise open a fresh
        // `invoke_agent` span for this request.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;

        // Append the prompt to the caller's history if one was provided;
        // otherwise start a fresh history containing only the prompt.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_depth = 0;
        let mut usage = Usage::new();
        // Id of the most recent chat/tool span, used to chain spans across
        // turns via `follows_from`.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // Multi-turn loop: call the model, execute any requested tools, feed
        // the results back, and repeat until the model answers with text only
        // or the depth limit is exceeded.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_depth > self.max_depth + 1 {
                break prompt;
            }

            current_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(
                        chat_history.to_vec(),
                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                    ));
                }
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous one so consecutive turns
            // form a chain rather than appearing unrelated.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(
                        chat_history.to_vec(),
                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                    ));
                }
            }

            // Split the model's reply into tool calls and plain text.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls requested: the text reply is the final answer.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();

            // Execute all requested tool calls, at most `self.concurrency` at
            // a time, keeping failures so cancellation can propagate.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);

                            if let Some(hook) = hook1 {
                                hook.on_tool_call(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    cancel_sig1.clone(),
                                )
                                .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }

                            // Tool failures are reported back to the model as
                            // the tool's result rather than aborting the loop.
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };

                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );

                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    // A hook-triggered interruption becomes a cancellation
                    // error carrying the history so far; anything else is a
                    // plain tool error.
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(
                            chat_history.to_vec(),
                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                        )
                    } else {
                        e.into()
                    }
                })?;

            // Feed the tool results back to the model as a user message and
            // continue the loop.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // The loop only exits by exceeding the depth limit.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}