1pub mod hooks;
2pub mod streaming;
3
4use super::{
5 Agent,
6 completion::{DynamicContextStore, build_completion_request},
7};
8use crate::{
9 OneOrMany,
10 completion::{CompletionModel, Document, Message, PromptError, Usage},
11 json_utils,
12 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
13 tool::server::ToolServerHandle,
14 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
15};
16use futures::{StreamExt, stream};
17use hooks::{HookAction, PromptHook, ToolCallHookAction};
18use std::{
19 future::IntoFuture,
20 marker::PhantomData,
21 sync::{
22 Arc,
23 atomic::{AtomicU64, Ordering},
24 },
25};
26use tracing::info_span;
27use tracing::{Instrument, span::Id};
28
/// Typestate marker trait for [`PromptRequest`]: determines what the request
/// resolves to when awaited.
pub trait PromptType {}
/// Typestate: the request resolves to the final text output as a plain `String`.
pub struct Standard;
/// Typestate: the request resolves to a [`PromptResponse`] carrying token
/// usage and the full message trace in addition to the output text.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
35
/// A builder for a single (possibly multi-turn) agent invocation.
///
/// Snapshot of an [`Agent`]'s configuration plus per-request overrides.
/// Awaiting the request (via `IntoFuture`) drives the completion/tool loop.
/// The `S` typestate selects the output shape ([`Standard`] → `String`,
/// [`Extended`] → [`PromptResponse`]).
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The user prompt that starts (or continues) the conversation.
    prompt: Message,
    // Prior conversation messages, if resuming an existing chat.
    chat_history: Option<Vec<Message>>,
    // Maximum number of tool-use turns before failing with `MaxTurnsError`.
    max_turns: usize,

    // Completion model shared with the agent.
    model: Arc<M>,
    // Agent name, used for tracing spans; falls back to `UNKNOWN_AGENT_NAME`.
    agent_name: Option<String>,
    // System instructions sent with every completion request.
    preamble: Option<String>,
    // Documents always attached as context.
    static_context: Vec<Document>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    // Provider-specific extra parameters, merged into the request.
    additional_params: Option<serde_json::Value>,
    // Handle used to resolve and execute tool calls.
    tool_server_handle: ToolServerHandle,
    // Store for context retrieved dynamically per request.
    dynamic_context: DynamicContextStore,
    tool_choice: Option<ToolChoice>,

    // Zero-sized typestate tag (Standard vs Extended output).
    state: PhantomData<S>,
    // Optional lifecycle hook invoked around completions and tool calls.
    hook: Option<P>,
    // How many tool calls may execute concurrently within one turn.
    concurrency: usize,
    // JSON schema for structured output, when requested.
    output_schema: Option<schemars::Schema>,
}
88
89impl<M, P> PromptRequest<Standard, M, P>
90where
91 M: CompletionModel,
92 P: PromptHook<M>,
93{
94 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
96 PromptRequest {
97 prompt: prompt.into(),
98 chat_history: None,
99 max_turns: agent.default_max_turns.unwrap_or_default(),
100 model: agent.model.clone(),
101 agent_name: agent.name.clone(),
102 preamble: agent.preamble.clone(),
103 static_context: agent.static_context.clone(),
104 temperature: agent.temperature,
105 max_tokens: agent.max_tokens,
106 additional_params: agent.additional_params.clone(),
107 tool_server_handle: agent.tool_server_handle.clone(),
108 dynamic_context: agent.dynamic_context.clone(),
109 tool_choice: agent.tool_choice.clone(),
110 state: PhantomData,
111 hook: agent.hook.clone(),
112 concurrency: 1,
113 output_schema: agent.output_schema.clone(),
114 }
115 }
116}
117
118impl<S, M, P> PromptRequest<S, M, P>
119where
120 S: PromptType,
121 M: CompletionModel,
122 P: PromptHook<M>,
123{
124 pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
131 PromptRequest {
132 prompt: self.prompt,
133 chat_history: self.chat_history,
134 max_turns: self.max_turns,
135 model: self.model,
136 agent_name: self.agent_name,
137 preamble: self.preamble,
138 static_context: self.static_context,
139 temperature: self.temperature,
140 max_tokens: self.max_tokens,
141 additional_params: self.additional_params,
142 tool_server_handle: self.tool_server_handle,
143 dynamic_context: self.dynamic_context,
144 tool_choice: self.tool_choice,
145 state: PhantomData,
146 hook: self.hook,
147 concurrency: self.concurrency,
148 output_schema: self.output_schema,
149 }
150 }
151
152 pub fn max_turns(mut self, depth: usize) -> Self {
155 self.max_turns = depth;
156 self
157 }
158
159 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
162 self.concurrency = concurrency;
163 self
164 }
165
166 pub fn with_history<I, T>(mut self, history: I) -> Self
168 where
169 I: IntoIterator<Item = T>,
170 T: Into<Message>,
171 {
172 self.chat_history = Some(history.into_iter().map(Into::into).collect());
173 self
174 }
175
176 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
179 where
180 P2: PromptHook<M>,
181 {
182 PromptRequest {
183 prompt: self.prompt,
184 chat_history: self.chat_history,
185 max_turns: self.max_turns,
186 model: self.model,
187 agent_name: self.agent_name,
188 preamble: self.preamble,
189 static_context: self.static_context,
190 temperature: self.temperature,
191 max_tokens: self.max_tokens,
192 additional_params: self.additional_params,
193 tool_server_handle: self.tool_server_handle,
194 dynamic_context: self.dynamic_context,
195 tool_choice: self.tool_choice,
196 state: PhantomData,
197 hook: Some(hook),
198 concurrency: self.concurrency,
199 output_schema: self.output_schema,
200 }
201 }
202}
203
/// Lets a standard request be `.await`ed directly; it resolves to the agent's
/// final text output.
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Boxed for wasm compatibility; the future just drives `send`.
        Box::pin(self.send())
    }
}
219
/// Lets an extended request be `.await`ed directly; it resolves to a
/// [`PromptResponse`] with output text, usage, and the message trace.
impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Boxed for wasm compatibility; the future just drives `send`.
        Box::pin(self.send())
    }
}
232
233impl<M, P> PromptRequest<Standard, M, P>
234where
235 M: CompletionModel,
236 P: PromptHook<M>,
237{
238 async fn send(self) -> Result<String, PromptError> {
239 self.extended_details().send().await.map(|resp| resp.output)
240 }
241}
242
/// The result of an extended prompt request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The final text produced by the model.
    pub output: String,
    /// Token usage accumulated over every turn of the request.
    pub usage: Usage,
    /// The messages generated during this request (prompt, assistant
    /// replies, tool results), when attached via [`PromptResponse::with_messages`].
    pub messages: Option<Vec<Message>>,
}
250
251impl std::fmt::Display for PromptResponse {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 self.output.fmt(f)
254 }
255}
256
257impl PromptResponse {
258 pub fn new(output: impl Into<String>, usage: Usage) -> Self {
259 Self {
260 output: output.into(),
261 usage,
262 messages: None,
263 }
264 }
265
266 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
267 self.messages = Some(messages);
268 self
269 }
270}
271
/// The result of an extended typed prompt request: the deserialized
/// structured output plus accumulated token usage.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The model's output, deserialized into `T`.
    pub output: T,
    /// Token usage accumulated over every turn of the request.
    pub usage: Usage,
}

impl<T> TypedPromptResponse<T> {
    /// Pairs a deserialized output with its usage.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
283
/// Fallback agent name recorded in tracing spans when the agent is unnamed.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
285
286fn build_history_for_request(
288 chat_history: Option<&[Message]>,
289 new_messages: &[Message],
290) -> Vec<Message> {
291 let input = chat_history.unwrap_or(&[]);
292 input.iter().chain(new_messages.iter()).cloned().collect()
293}
294
295fn build_full_history(
297 chat_history: Option<&[Message]>,
298 new_messages: Vec<Message>,
299) -> Vec<Message> {
300 let input = chat_history.unwrap_or(&[]);
301 input.iter().cloned().chain(new_messages).collect()
302}
303
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The agent's name for tracing, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Drives the multi-turn agent loop: sends a completion request, executes
    /// any tool calls in the response (concurrently, up to `self.concurrency`),
    /// appends the tool results as the next user message, and repeats until
    /// the model replies with plain text or the turn budget is exhausted.
    ///
    /// # Errors
    /// - `PromptError::MaxTurnsError` if the model is still requesting tools
    ///   after `max_turns` (+1 for the final answer) turns.
    /// - A prompt-cancelled error when any hook returns a `Terminate` action.
    /// - Any error bubbled up from building or sending the completion request.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh `invoke_agent` span whose fields are recorded as they become
        // known (prompt now, completion/usage at the end).
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        let chat_history = self.chat_history;
        // Messages produced during this request; starts with the user prompt.
        // The last entry is always the "current" prompt for the next turn.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        // Usage is accumulated across every completion turn.
        let mut usage = Usage::new();
        // Id of the most recent chat/tool span, used to chain spans with
        // `follows_from`. Stored as a u64 (0 = none yet) so it can be shared
        // with the concurrent tool-call closures.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // The loop only `break`s with the last prompt once the turn budget is
        // exceeded; success and cancellation paths return from inside.
        let last_prompt = loop {
            let prompt = new_messages
                .last()
                .expect("there should always be at least one message")
                .clone();

            // `max_turns + 1` leaves one extra turn for the final (text-only)
            // answer after the last allowed tool-use turn.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // History excludes the last entry, which is sent as the prompt.
            let history_for_hook = build_history_for_request(
                chat_history.as_deref(),
                &new_messages[..new_messages.len().saturating_sub(1)],
            );

            // Give the hook a chance to cancel before the completion call.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Chain this turn's span after the previous turn's span (if any).
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let history_for_request = build_history_for_request(
                chat_history.as_deref(),
                &new_messages[..new_messages.len().saturating_sub(1)],
            );

            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            usage += resp.usage;

            // Give the hook a chance to cancel after seeing the response.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            // Split the assistant content into tool calls and text pieces.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            new_messages.push(Message::Assistant {
                id: resp.message_id.clone(),
                content: resp.choice.clone(),
            });

            // No tool calls: the model has answered — merge the text pieces,
            // record final span fields, and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );

                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // Snapshot of the history for error reporting from inside the
            // concurrent tool-call futures (each gets its own clone).
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            // Execute every tool call, up to `self.concurrency` at a time.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    // One clone per hook call site (pre- and post-execution).
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain the tool span after the most recent span, same
                    // scheme as the chat span above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            // Internal id correlating the pre/post hook calls
                            // for this specific invocation.
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                // Terminate aborts the whole prompt request.
                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                // Skip replaces the tool's output with the
                                // hook's reason, without executing the tool.
                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool execution errors are not fatal: the error
                            // text is fed back to the model as the result.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            // Feed all tool results back as a single user message; this
            // becomes the prompt for the next turn.
            new_messages.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // Loop exhausted the turn budget without a text-only answer.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
645
646use crate::completion::StructuredOutputError;
651use schemars::{JsonSchema, schema_for};
652use serde::de::DeserializeOwned;
653
/// A prompt request whose response is deserialized into a typed value `T`.
///
/// Wraps a [`PromptRequest`] and forces the output schema to the JSON schema
/// derived from `T`; awaiting it yields `T` ([`Standard`]) or a
/// [`TypedPromptResponse<T>`] ([`Extended`]).
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The underlying request; its `output_schema` is set from `T`.
    inner: PromptRequest<S, M, P>,
    // Carries the target type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
680
681impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
682where
683 T: JsonSchema + DeserializeOwned + WasmCompatSend,
684 M: CompletionModel,
685 P: PromptHook<M>,
686{
687 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
691 let mut inner = PromptRequest::from_agent(agent, prompt);
692 inner.output_schema = Some(schema_for!(T));
694 Self {
695 inner,
696 _phantom: std::marker::PhantomData,
697 }
698 }
699}
700
701impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
702where
703 T: JsonSchema + DeserializeOwned + WasmCompatSend,
704 S: PromptType,
705 M: CompletionModel,
706 P: PromptHook<M>,
707{
708 pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
714 TypedPromptRequest {
715 inner: self.inner.extended_details(),
716 _phantom: std::marker::PhantomData,
717 }
718 }
719
720 pub fn max_turns(mut self, depth: usize) -> Self {
726 self.inner = self.inner.max_turns(depth);
727 self
728 }
729
730 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
734 self.inner = self.inner.with_tool_concurrency(concurrency);
735 self
736 }
737
738 pub fn with_history<I, H>(mut self, history: I) -> Self
740 where
741 I: IntoIterator<Item = H>,
742 H: Into<Message>,
743 {
744 self.inner = self.inner.with_history(history);
745 self
746 }
747
748 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
752 where
753 P2: PromptHook<M>,
754 {
755 TypedPromptRequest {
756 inner: self.inner.with_hook(hook),
757 _phantom: std::marker::PhantomData,
758 }
759 }
760}
761
762impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
763where
764 T: JsonSchema + DeserializeOwned + WasmCompatSend,
765 M: CompletionModel,
766 P: PromptHook<M>,
767{
768 async fn send(self) -> Result<T, StructuredOutputError> {
770 let response = self.inner.send().await.map_err(Box::new)?;
771
772 if response.is_empty() {
773 return Err(StructuredOutputError::EmptyResponse);
774 }
775
776 let parsed: T = serde_json::from_str(&response)?;
777 Ok(parsed)
778 }
779}
780
781impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
782where
783 T: JsonSchema + DeserializeOwned + WasmCompatSend,
784 M: CompletionModel,
785 P: PromptHook<M>,
786{
787 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
789 let response = self.inner.send().await.map_err(Box::new)?;
790
791 if response.output.is_empty() {
792 return Err(StructuredOutputError::EmptyResponse);
793 }
794
795 let parsed: T = serde_json::from_str(&response.output)?;
796 Ok(TypedPromptResponse::new(parsed, response.usage))
797 }
798}
799
/// Lets a standard typed request be `.await`ed directly; it resolves to the
/// deserialized value `T`.
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Boxed for wasm compatibility; the future just drives `send`.
        Box::pin(self.send())
    }
}
813
/// Lets an extended typed request be `.await`ed directly; it resolves to a
/// [`TypedPromptResponse<T>`] with the value and token usage.
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Boxed for wasm compatibility; the future just drives `send`.
        Box::pin(self.send())
    }
}