pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};

use futures::{StreamExt, stream};
use tracing::{Instrument, info_span, span::Id};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

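/// Typestate marker implemented by [`Standard`] and [`Extended`], determining
/// what a [`PromptRequest`] resolves to when awaited.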
pub trait PromptType {}

/// Marker type for a request that resolves to a plain `String`.
pub struct Standard;
/// Marker type for a request that additionally reports aggregated token usage.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

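/// Builder for a single agent prompt.
///
/// By default the request resolves to a `String` and permits a single round
/// of tool calls; use [`PromptRequest::multi_turn`] to allow more rounds and
/// [`PromptRequest::extended_details`] to receive a [`PromptResponse`] with
/// aggregated token usage.
///
/// A minimal sketch (`ignore`d since it assumes an `agent` built elsewhere
/// whose `prompt` entry point returns this builder):
///
/// ```ignore
/// let mut history = Vec::new();
/// let answer = agent
///     .prompt("What is the weather in Berlin?")
///     .with_history(&mut history)
///     .multi_turn(5)
///     .await?;
/// ```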
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional externally owned chat history to read from and append to.
    chat_history: Option<&'a mut Vec<Message>>,
    /// The maximum number of turns allowed when the model requests tool calls.
    max_depth: usize,
    /// The agent that executes the request.
    agent: &'a Agent<M>,
    /// Typestate marker selecting the response type.
    state: PhantomData<S>,
    /// Optional hook invoked around completion calls and tool calls.
    hook: Option<P>,
    /// Maximum number of tool calls executed concurrently.
    concurrency: usize,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Creates a new prompt request for the given agent and prompt message,
    /// with no history, no hook, a tool concurrency of 1, and a multi-turn
    /// depth of 0.
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
            hook: None,
            concurrency: 1,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switches the request to the extended state, so that awaiting it
    /// returns a [`PromptResponse`] with aggregated token usage instead of a
    /// plain `String`.
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Sets the maximum number of turns the agent may take when the model
    /// keeps requesting tool calls. Exceeding the limit yields
    /// [`PromptError::MaxDepthError`].
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Sets how many tool calls from a single response may run concurrently.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Attaches a mutable chat history. The prompt and all intermediate
    /// messages (assistant responses and tool results) are appended to it.
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Attaches a [`PromptHook`] that is invoked around completion calls and
    /// tool calls.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
        }
    }
}

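/// A shared flag that a [`PromptHook`] can set to abort an in-flight prompt;
/// the conversation loop then stops and returns a cancellation error carrying
/// the chat history accumulated so far.
///
/// A minimal sketch of cancelling from a hook method (hypothetical hook,
/// `ignore`d since it needs a concrete model type):
///
/// ```ignore
/// fn on_tool_call(
///     &self,
///     tool_name: &str,
///     _tool_call_id: Option<String>,
///     _args: &str,
///     cancel_sig: CancelSignal,
/// ) -> impl Future<Output = ()> + WasmCompatSend {
///     async move {
///         // Refuse to run any tool we don't recognize.
///         if tool_name != "calculator" {
///             cancel_sig.cancel();
///         }
///     }
/// }
/// ```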
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Requests cancellation of the current prompt request.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

impl Clone for CancelSignal {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

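/// Hooks called at each stage of a prompt's lifecycle: before and after each
/// completion call, and before and after each tool call. Every method
/// defaults to a no-op, so implementors override only what they need. Attach
/// a hook with [`PromptRequest::with_hook`].
///
/// A minimal sketch of a logging hook (hypothetical type, `ignore`d since it
/// needs a concrete [`CompletionModel`]):
///
/// ```ignore
/// #[derive(Clone)]
/// struct LogHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LogHook {
///     fn on_tool_call(
///         &self,
///         tool_name: &str,
///         _tool_call_id: Option<String>,
///         args: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + WasmCompatSend {
///         async move {
///             println!("calling tool {tool_name} with args {args}");
///         }
///     }
/// }
/// ```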
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before each completion request is sent to the model.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after each completion response is received from the model.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before each tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after each tool invocation completes, with the tool's result.
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

/// The no-op hook, used when no hook has been attached.
impl<M> PromptHook<M> for () where M: CompletionModel {}

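// Awaiting a `Standard` request resolves to the model's final `String`
// output, while an `Extended` request resolves to a `PromptResponse`
// carrying aggregated token usage. A hypothetical sketch:
//
//     let detailed = agent.prompt("hi").extended_details().await?;
//     println!(
//         "{} (in: {}, out: {})",
//         detailed.output,
//         detailed.total_usage.input_tokens,
//         detailed.total_usage.output_tokens,
//     );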
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

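/// The result of an extended prompt request: the agent's final text output
/// together with the token usage aggregated across every completion call made
/// during the (possibly multi-turn) conversation.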
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drives the multi-turn conversation loop: calls the model, executes any
    /// requested tool calls (up to `concurrency` at a time), appends the
    /// results to the chat history, and repeats until the model responds with
    /// text only or `max_depth` is exceeded.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous turn's span, if any.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls means the model has produced its final answer:
            // merge the text fragments and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();

            // Execute the requested tool calls, at most `self.concurrency` at
            // a time, invoking the hook before and after each call.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    cancel_sig1.clone(),
                                )
                                .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            // A failed tool call is reported back to the model
                            // as its error message rather than aborting the
                            // conversation.
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}