1pub mod streaming;
2
3pub use streaming::StreamingPromptHook;
4
5use std::{
6 future::IntoFuture,
7 marker::PhantomData,
8 sync::atomic::{AtomicU64, Ordering},
9};
10use tracing::{Instrument, span::Id};
11
12use futures::{StreamExt, stream};
13use tracing::info_span;
14
15use crate::{
16 OneOrMany,
17 completion::{Completion, CompletionModel, Message, PromptError, Usage},
18 json_utils,
19 message::{AssistantContent, ToolResultContent, UserContent},
20 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
21};
22
23use super::Agent;
24
/// Marker trait for the type-state parameter of [`PromptRequest`].
///
/// Implemented only by the zero-sized markers [`Standard`] and [`Extended`];
/// the marker selects what awaiting the request yields.
pub trait PromptType {}
/// Type-state marker: awaiting the request yields `Result<String, PromptError>`
/// (just the final text output).
pub struct Standard;
/// Type-state marker: awaiting the request yields
/// `Result<PromptResponse, PromptError>` (text plus aggregated usage).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
31
/// Decision returned by [`PromptHook::on_tool_call`], controlling how a
/// pending tool invocation is handled by the agent loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallHookAction {
    /// Execute the tool call as planned.
    Continue,
    /// Do not execute the tool; `reason` is reported back as the tool result.
    Skip { reason: String },
    /// Abort the entire prompt run, surfacing `reason` in the error.
    Terminate { reason: String },
}

impl ToolCallHookAction {
    /// Convenience constructor for [`ToolCallHookAction::Continue`].
    pub fn cont() -> Self {
        Self::Continue
    }

    /// Convenience constructor for [`ToolCallHookAction::Skip`].
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Skip { reason }
    }

    /// Convenience constructor for [`ToolCallHookAction::Terminate`].
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Terminate { reason }
    }
}
63
/// Decision returned by the non-tool-call hook points
/// ([`PromptHook::on_completion_call`], [`PromptHook::on_completion_response`],
/// [`PromptHook::on_tool_result`]): either keep going or abort the run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Proceed with the prompt loop.
    Continue,
    /// Abort the entire prompt run, surfacing `reason` in the error.
    Terminate { reason: String },
}

impl HookAction {
    /// Convenience constructor for [`HookAction::Continue`].
    pub fn cont() -> Self {
        Self::Continue
    }

    /// Convenience constructor for [`HookAction::Terminate`].
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Terminate { reason }
    }
}
86
/// Builder describing a single agent prompt run.
///
/// Type parameters:
/// - `S`: type-state marker ([`Standard`] or [`Extended`]) selecting the
///   result type produced when the request is awaited.
/// - `M`: the completion model driving the agent.
/// - `P`: the [`PromptHook`] used to observe/intercept the run (`()` when
///   no hook is installed).
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The initial user prompt appended to the history when the run starts.
    prompt: Message,
    // Caller-owned history; when `None`, a fresh history is created internally.
    chat_history: Option<&'a mut Vec<Message>>,
    // Turn budget for the completion/tool loop; exceeding it yields
    // `PromptError::MaxTurnsError`.
    max_turns: usize,
    // The agent whose model, preamble and tool server handle are used.
    agent: &'a Agent<M>,
    // Zero-sized type-state marker (no runtime data).
    state: PhantomData<S>,
    // Optional hook invoked around completions and tool calls.
    hook: Option<P>,
    // Max number of tool calls executed concurrently (`buffer_unordered`).
    concurrency: usize,
}
117
118impl<'a, M> PromptRequest<'a, Standard, M, ()>
119where
120 M: CompletionModel,
121{
122 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
124 Self {
125 prompt: prompt.into(),
126 chat_history: None,
127 max_turns: agent.default_max_turns.unwrap_or_default(),
128 agent,
129 state: PhantomData,
130 hook: None,
131 concurrency: 1,
132 }
133 }
134}
135
136impl<'a, S, M, P> PromptRequest<'a, S, M, P>
137where
138 S: PromptType,
139 M: CompletionModel,
140 P: PromptHook<M>,
141{
142 pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
148 PromptRequest {
149 prompt: self.prompt,
150 chat_history: self.chat_history,
151 max_turns: self.max_turns,
152 agent: self.agent,
153 state: PhantomData,
154 hook: self.hook,
155 concurrency: self.concurrency,
156 }
157 }
158 pub fn max_turns(self, depth: usize) -> PromptRequest<'a, S, M, P> {
161 PromptRequest {
162 prompt: self.prompt,
163 chat_history: self.chat_history,
164 max_turns: depth,
165 agent: self.agent,
166 state: PhantomData,
167 hook: self.hook,
168 concurrency: self.concurrency,
169 }
170 }
171
172 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
175 self.concurrency = concurrency;
176 self
177 }
178
179 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
181 PromptRequest {
182 prompt: self.prompt,
183 chat_history: Some(history),
184 max_turns: self.max_turns,
185 agent: self.agent,
186 state: PhantomData,
187 hook: self.hook,
188 concurrency: self.concurrency,
189 }
190 }
191
192 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
194 where
195 P2: PromptHook<M>,
196 {
197 PromptRequest {
198 prompt: self.prompt,
199 chat_history: self.chat_history,
200 max_turns: self.max_turns,
201 agent: self.agent,
202 state: PhantomData,
203 hook: Some(hook),
204 concurrency: self.concurrency,
205 }
206 }
207}
208
/// Observer/interceptor for the agent prompt loop.
///
/// All methods have default implementations that continue unconditionally,
/// so implementors only override the events they care about. Hooks are
/// cloned per tool call, hence the `Clone` bound.
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before each completion request is sent to the model.
    /// Returning [`HookAction::Terminate`] cancels the run.
    fn on_completion_call(
        &self,
        _prompt: &Message,
        _history: &[Message],
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::cont() }
    }

    /// Called after each completion response is received.
    /// Returning [`HookAction::Terminate`] cancels the run.
    fn on_completion_response(
        &self,
        _prompt: &Message,
        _response: &crate::completion::CompletionResponse<M::Response>,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::cont() }
    }

    /// Called before a tool is executed. May continue, skip the call
    /// (reporting a reason as the tool result), or terminate the run.
    fn on_tool_call(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
    ) -> impl Future<Output = ToolCallHookAction> + WasmCompatSend {
        async { ToolCallHookAction::cont() }
    }

    /// Called after a tool has produced a result (or an error string).
    /// Returning [`HookAction::Terminate`] cancels the run.
    fn on_tool_result(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
        _result: &str,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::cont() }
    }
}
260
261impl<M> PromptHook<M> for () where M: CompletionModel {}
262
263impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
267where
268 M: CompletionModel,
269 P: PromptHook<M> + 'static,
270{
271 type Output = Result<String, PromptError>;
272 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
275 Box::pin(self.send())
276 }
277}
278
279impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
280where
281 M: CompletionModel,
282 P: PromptHook<M> + 'static,
283{
284 type Output = Result<PromptResponse, PromptError>;
285 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
288 Box::pin(self.send())
289 }
290}
291
292impl<M, P> PromptRequest<'_, Standard, M, P>
293where
294 M: CompletionModel,
295 P: PromptHook<M>,
296{
297 async fn send(self) -> Result<String, PromptError> {
298 self.extended_details().send().await.map(|resp| resp.output)
299 }
300}
301
/// Result of an [`Extended`] prompt run: the agent's final text plus the
/// token usage accumulated over every completion call in the loop.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    // Final assistant text (text parts of the last response, joined with '\n').
    pub output: String,
    // Sum of the usage reported by each completion response during the run.
    pub total_usage: Usage,
}
307
308impl PromptResponse {
309 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
310 Self {
311 output: output.into(),
312 total_usage,
313 }
314 }
315}
316
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drive the full agent loop: repeatedly call the model, execute any
    /// requested tools (up to `self.concurrency` at once), and append every
    /// exchange to the chat history, until the model answers with plain text
    /// or the turn budget is exhausted.
    ///
    /// Returns the final text plus aggregated usage, or an error when a hook
    /// terminates the run, a completion/tool layer fails, or `max_turns` is
    /// exceeded (`PromptError::MaxTurnsError`).
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse an already-active span if the caller set one up; otherwise
        // open a fresh "invoke_agent" span for the whole run.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        // Append the prompt to the caller's history, or start a local one.
        // The `&mut vec![...]` arm borrows a temporary that lives for the
        // rest of this function body.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Span id of the most recent chat/tool span, used to chain spans via
        // `follows_from`. Atomic so the concurrent tool closures can share it.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // The loop breaks with the last prompt only when the turn budget is
        // exhausted; a plain-text answer returns from inside the loop.
        let last_prompt = loop {
            // The last history entry is the message to respond to next
            // (initial prompt, or the tool results pushed last iteration).
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // NOTE(review): `> max_turns + 1` allows max_turns + 1 iterations
            // before giving up — presumably intentional so a final tool-result
            // round can still be answered; confirm against MaxTurnsError docs.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Hook checkpoint: before sending the completion request. The
            // history slice excludes the prompt just appended at the end.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } = hook
                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this span to the previous turn's span (0 means "none yet").
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Send the completion request; history again excludes the prompt
            // itself, which is passed separately.
            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            // Accumulate token usage across every turn.
            usage += resp.usage;

            // Hook checkpoint: after receiving the completion response.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            // Split the response into tool calls vs everything else (text).
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Record the full assistant message (tool calls included).
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls => the model has produced its final answer.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();

            // Execute every requested tool, up to `self.concurrency` in
            // flight at once; results may complete in any order
            // (`buffer_unordered`).
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    // One clone per hook checkpoint inside the async block.
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain this tool span after the most recent span id.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // Snapshot of the history for error reporting; the async
                    // block can't borrow `chat_history` across tasks.
                    // NOTE(review): `.clone().to_vec()` — the `.to_vec()` is
                    // redundant after `.clone()`; left as-is.
                    let cloned_chat_history = chat_history.clone().to_vec();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            // Internal id correlating on_tool_call/on_tool_result
                            // pairs, independent of the provider's call id.
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Hook checkpoint: before executing the tool.
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_chat_history,
                                        reason,
                                    ));
                                }

                                // Skip: don't run the tool; feed the reason
                                // back to the model as the tool result.
                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool failures are not fatal: the error text is
                            // returned to the model as the tool's output.
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            // Hook checkpoint: after the tool produced output.
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_chat_history,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider call id when one exists.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                // First tool error (hook termination) aborts the whole turn.
                .collect::<Result<Vec<_>, _>>()?;

            // Feed all tool results back as a single user message; the next
            // iteration prompts the model with it.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Loop exhausted the turn budget without a plain-text answer.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}