1pub mod hooks;
2pub mod streaming;
3
4use hooks::{HookAction, PromptHook, ToolCallHookAction};
5use std::{
6 future::IntoFuture,
7 marker::PhantomData,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12};
13use tracing::{Instrument, span::Id};
14
15use futures::{StreamExt, stream};
16use tracing::info_span;
17
18use crate::{
19 OneOrMany,
20 completion::{CompletionModel, Document, Message, PromptError, Usage},
21 json_utils,
22 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
23 tool::server::ToolServerHandle,
24 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
25};
26
27use super::{
28 Agent,
29 completion::{DynamicContextStore, build_completion_request},
30};
31
/// Typestate marker trait for [`PromptRequest`]: the state parameter selects
/// whether awaiting the request yields a plain `String` ([`Standard`]) or a
/// [`PromptResponse`] with usage/message details ([`Extended`]).
pub trait PromptType {}
/// Typestate: the request resolves to just the final output text.
pub struct Standard;
/// Typestate: the request resolves to extended details (output, usage, messages).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
38
/// Builder for a single (possibly multi-turn) prompt against a completion model.
///
/// Created with [`PromptRequest::from_agent`], configured via the builder
/// methods, and driven by awaiting it (`IntoFuture`). The `S` typestate decides
/// the output shape; `M` is the completion model and `P` the hook type.
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The message that starts (or continues) the conversation.
    prompt: Message,
    // Externally owned chat history; the prompt and all intermediate messages
    // are appended to it in place when present.
    chat_history: Option<&'a mut Vec<Message>>,
    // Maximum number of model turns before `MaxTurnsError` is returned.
    max_turns: usize,

    // Completion model used for every turn.
    model: Arc<M>,
    // Agent name recorded on tracing spans (placeholder used when `None`).
    agent_name: Option<String>,
    // System instructions forwarded with each completion request.
    preamble: Option<String>,
    // Documents always attached to the request context.
    static_context: Vec<Document>,
    // Sampling temperature forwarded to the model, if set.
    temperature: Option<f64>,
    // Output token cap forwarded to the model, if set.
    max_tokens: Option<u64>,
    // Extra provider-specific parameters forwarded with the request.
    additional_params: Option<serde_json::Value>,
    // Handle used to execute tool calls requested by the model.
    tool_server_handle: ToolServerHandle,
    // Dynamically retrieved context, resolved per request in
    // `build_completion_request`.
    dynamic_context: DynamicContextStore,
    // Tool-choice directive forwarded to the model, if set.
    tool_choice: Option<ToolChoice>,

    // Typestate marker (`Standard` vs `Extended`); zero-sized.
    state: PhantomData<S>,
    // Optional hook invoked around completion calls and tool calls.
    hook: Option<P>,
    // Number of tool calls executed concurrently within one turn (default 1).
    concurrency: usize,
    // JSON schema for structured output, set by `TypedPromptRequest`.
    output_schema: Option<schemars::Schema>,
}
92
impl<'a, M, P> PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds a [`Standard`] prompt request from an agent's configuration.
    ///
    /// Copies the agent's model, preamble, contexts, sampling settings, tool
    /// handle, hook and output schema. Starts with no attached chat history and
    /// a tool concurrency of 1; `max_turns` defaults to the agent's
    /// `default_max_turns` (0 when unset).
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
        }
    }
}
121
122impl<'a, S, M, P> PromptRequest<'a, S, M, P>
123where
124 S: PromptType,
125 M: CompletionModel,
126 P: PromptHook<M>,
127{
128 pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
135 PromptRequest {
136 prompt: self.prompt,
137 chat_history: self.chat_history,
138 max_turns: self.max_turns,
139 model: self.model,
140 agent_name: self.agent_name,
141 preamble: self.preamble,
142 static_context: self.static_context,
143 temperature: self.temperature,
144 max_tokens: self.max_tokens,
145 additional_params: self.additional_params,
146 tool_server_handle: self.tool_server_handle,
147 dynamic_context: self.dynamic_context,
148 tool_choice: self.tool_choice,
149 state: PhantomData,
150 hook: self.hook,
151 concurrency: self.concurrency,
152 output_schema: self.output_schema,
153 }
154 }
155
156 pub fn max_turns(mut self, depth: usize) -> Self {
159 self.max_turns = depth;
160 self
161 }
162
163 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
166 self.concurrency = concurrency;
167 self
168 }
169
170 pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
172 self.chat_history = Some(history);
173 self
174 }
175
176 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
179 where
180 P2: PromptHook<M>,
181 {
182 PromptRequest {
183 prompt: self.prompt,
184 chat_history: self.chat_history,
185 max_turns: self.max_turns,
186 model: self.model,
187 agent_name: self.agent_name,
188 preamble: self.preamble,
189 static_context: self.static_context,
190 temperature: self.temperature,
191 max_tokens: self.max_tokens,
192 additional_params: self.additional_params,
193 tool_server_handle: self.tool_server_handle,
194 dynamic_context: self.dynamic_context,
195 tool_choice: self.tool_choice,
196 state: PhantomData,
197 hook: Some(hook),
198 concurrency: self.concurrency,
199 output_schema: self.output_schema,
200 }
201 }
202}
203
/// Awaiting a [`Standard`] request runs the full agent loop and yields only the
/// final output text.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
219
/// Awaiting an [`Extended`] request runs the full agent loop and yields a
/// [`PromptResponse`] carrying output, accumulated usage, and messages.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
232
233impl<M, P> PromptRequest<'_, Standard, M, P>
234where
235 M: CompletionModel,
236 P: PromptHook<M>,
237{
238 async fn send(self) -> Result<String, PromptError> {
239 self.extended_details().send().await.map(|resp| resp.output)
240 }
241}
242
/// Result of an [`Extended`] prompt request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The model's final text output (all text parts joined).
    pub output: String,
    /// Token usage accumulated across every turn of the request.
    pub usage: Usage,
    /// Full message history produced while sending, when attached.
    pub messages: Option<Vec<Message>>,
}
252
253impl std::fmt::Display for PromptResponse {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 self.output.fmt(f)
256 }
257}
258
259impl PromptResponse {
260 pub fn new(output: impl Into<String>, usage: Usage) -> Self {
261 Self {
262 output: output.into(),
263 usage,
264 messages: None,
265 }
266 }
267
268 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
269 self.messages = Some(messages);
270 self
271 }
272}
273
/// Result of an [`Extended`] typed prompt request: the deserialized structured
/// output plus the usage accumulated while producing it.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The model's output parsed into `T`.
    pub output: T,
    /// Token usage accumulated across every turn of the request.
    pub usage: Usage,
}
279
280impl<T> TypedPromptResponse<T> {
281 pub fn new(output: T, usage: Usage) -> Self {
282 Self { output, usage }
283 }
284}
285
/// Fallback agent name recorded on tracing spans when the agent is unnamed.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
287
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Agent name for tracing, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Drives the full multi-turn agent loop.
    ///
    /// Each iteration sends the chat history to the model; if the model replies
    /// with tool calls they are executed (concurrently, up to
    /// `self.concurrency` at a time) and the results are appended as a user
    /// message, then the loop continues. A reply with no tool calls ends the
    /// loop and its text parts (joined with `\n`) become the output.
    ///
    /// # Errors
    /// - `PromptError::prompt_cancelled` when a hook returns `Terminate` at any
    ///   of the four hook points (completion call/response, tool call/result).
    /// - `PromptError::MaxTurnsError` when the turn budget is exhausted.
    async fn send(mut self) -> Result<PromptResponse, PromptError> {
        // Create a top-level `invoke_agent` span (GenAI semantic-convention
        // fields) only when there is no enabled span already; otherwise record
        // onto the caller's current span.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Cloned once so per-turn spans can use it without borrowing `self`.
        let agent_name_for_span = self.agent_name.clone();

        // Work against the caller's history when one was attached, otherwise a
        // local one; either way the prompt becomes the first/next entry.
        let chat_history = if let Some(history) = self.chat_history.as_mut() {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Id of the most recent chat/tool span; used to chain successive spans
        // via `follows_from`. Atomic so the tool-call closures can update it.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        let last_prompt = loop {
            // The most recent message (assistant tool results become a user
            // message, so the tail is always the next thing to respond to).
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // NOTE(review): the budget check allows up to `max_turns + 2`
            // iterations (counter is compared before increment, with `+ 1`
            // slack) — confirm this off-by-design is intended.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Pre-completion hook: may cancel the whole request. The hook sees
            // the history *excluding* the prompt just pushed.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } = hook
                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            // Per-turn `chat` span, parented to the current span; most fields
            // are filled in later by the provider layer.
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this span to the previous turn's span (if any) so traces
            // show the causal chain across turns.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            // Remember this span's id for the next span in the chain.
            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Build and send the completion request for this turn. The history
            // argument excludes the tail message, which is passed as `prompt`.
            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                chat_history[..chat_history.len() - 1].to_vec(),
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            // Accumulate usage across turns.
            usage += resp.usage;

            // Post-response hook: may cancel the whole request.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            // Split the model's reply into tool calls and plain text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // The assistant message is recorded regardless of its content.
            chat_history.push(Message::Assistant {
                id: resp.message_id.clone(),
                content: resp.choice.clone(),
            });

            // No tool calls: the reply is final. Join its text parts and
            // return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(
                    PromptResponse::new(merged_texts, usage).with_messages(chat_history.to_vec())
                );
            }

            // Clones moved into the per-tool-call closures below.
            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // Execute all requested tool calls, at most `self.concurrency` at a
            // time; the first hook-driven cancellation aborts the whole batch.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    // Per-call `execute_tool` span; name/id/args/result are
                    // recorded once the call is underway.
                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain onto the previous span, same as the chat span.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // Snapshot of the history for cancellation errors raised
                    // from inside the async tool task.
                    let cloned_chat_history = chat_history.clone().to_vec();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            // Internal correlation id handed to the hooks (the
                            // provider call_id may be absent).
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Pre-call hook: may cancel the request or skip
                            // this one call (the skip reason becomes the tool
                            // result sent back to the model).
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_chat_history,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Execution errors are not fatal: the error text is
                            // fed back to the model as the tool's output.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            // Post-result hook: may still cancel the request.
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_chat_history,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider call_id when one exists.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            // Feed all tool results back to the model as one user message.
            // NOTE(review): expect message has a typo ("atleast").
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Loop exhausted the turn budget without a tool-free reply.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}
603
604use crate::completion::StructuredOutputError;
609use schemars::{JsonSchema, schema_for};
610use serde::de::DeserializeOwned;
611
/// A [`PromptRequest`] whose output is deserialized into `T` using the model's
/// structured-output support (the JSON schema of `T` is sent with the request).
pub struct TypedPromptRequest<'a, T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The underlying request; its `output_schema` is set to `schema_for!(T)`.
    inner: PromptRequest<'a, S, M, P>,
    // Carries `T` in the type without storing a value.
    _phantom: std::marker::PhantomData<T>,
}
638
639impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
640where
641 T: JsonSchema + DeserializeOwned + WasmCompatSend,
642 M: CompletionModel,
643 P: PromptHook<M>,
644{
645 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
649 let mut inner = PromptRequest::from_agent(agent, prompt);
650 inner.output_schema = Some(schema_for!(T));
652 Self {
653 inner,
654 _phantom: std::marker::PhantomData,
655 }
656 }
657}
658
659impl<'a, T, S, M, P> TypedPromptRequest<'a, T, S, M, P>
660where
661 T: JsonSchema + DeserializeOwned + WasmCompatSend,
662 S: PromptType,
663 M: CompletionModel,
664 P: PromptHook<M>,
665{
666 pub fn extended_details(self) -> TypedPromptRequest<'a, T, Extended, M, P> {
672 TypedPromptRequest {
673 inner: self.inner.extended_details(),
674 _phantom: std::marker::PhantomData,
675 }
676 }
677
678 pub fn max_turns(mut self, depth: usize) -> Self {
684 self.inner = self.inner.max_turns(depth);
685 self
686 }
687
688 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
692 self.inner = self.inner.with_tool_concurrency(concurrency);
693 self
694 }
695
696 pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
698 self.inner = self.inner.with_history(history);
699 self
700 }
701
702 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<'a, T, S, M, P2>
706 where
707 P2: PromptHook<M>,
708 {
709 TypedPromptRequest {
710 inner: self.inner.with_hook(hook),
711 _phantom: std::marker::PhantomData,
712 }
713 }
714}
715
716impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
717where
718 T: JsonSchema + DeserializeOwned + WasmCompatSend,
719 M: CompletionModel,
720 P: PromptHook<M>,
721{
722 async fn send(self) -> Result<T, StructuredOutputError> {
724 let response = self.inner.send().await?;
725
726 if response.is_empty() {
727 return Err(StructuredOutputError::EmptyResponse);
728 }
729
730 let parsed: T = serde_json::from_str(&response)?;
731 Ok(parsed)
732 }
733}
734
735impl<'a, T, M, P> TypedPromptRequest<'a, T, Extended, M, P>
736where
737 T: JsonSchema + DeserializeOwned + WasmCompatSend,
738 M: CompletionModel,
739 P: PromptHook<M>,
740{
741 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
743 let response = self.inner.send().await?;
744
745 if response.output.is_empty() {
746 return Err(StructuredOutputError::EmptyResponse);
747 }
748
749 let parsed: T = serde_json::from_str(&response.output)?;
750 Ok(TypedPromptResponse::new(parsed, response.usage))
751 }
752}
753
/// Awaiting a [`Standard`] typed request yields only the parsed `T`.
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
767
/// Awaiting an [`Extended`] typed request yields the parsed `T` plus usage.
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}