1pub mod hooks;
2pub mod streaming;
3
4use hooks::{HookAction, PromptHook, ToolCallHookAction};
5use std::{
6 future::IntoFuture,
7 marker::PhantomData,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12};
13use tracing::{Instrument, span::Id};
14
15use futures::{StreamExt, stream};
16use tracing::info_span;
17
18use crate::{
19 OneOrMany,
20 completion::{CompletionModel, Document, Message, PromptError, Usage},
21 json_utils,
22 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
23 tool::server::ToolServerHandle,
24 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
25};
26
27use super::{
28 Agent,
29 completion::{DynamicContextStore, build_completion_request},
30};
31
/// Typestate marker trait for [`PromptRequest`]: a request is either
/// [`Standard`] (awaiting it yields the final text as a `String`) or
/// [`Extended`] (awaiting it yields a [`PromptResponse`] with usage details).
pub trait PromptType {}
/// Typestate: the request resolves to just the final output text.
pub struct Standard;
/// Typestate: the request resolves to extended details (output + usage + messages).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
38
/// A lazily-executed, builder-style prompt request against an agent.
///
/// Created via [`PromptRequest::from_agent`], configured with the builder
/// methods, and executed by `.await`-ing it (through `IntoFuture`). The `S`
/// typestate selects whether awaiting yields a `String` or a [`PromptResponse`].
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The initial user prompt for this request.
    prompt: Message,
    /// Caller-owned chat history; when set, the prompt and all produced
    /// messages are appended to it in place.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of model turns (tool-use round-trips) before erroring.
    max_turns: usize,

    /// The completion model to send requests to.
    model: Arc<M>,
    /// Agent display name, used for tracing spans.
    agent_name: Option<String>,
    /// System preamble / instructions, if any.
    preamble: Option<String>,
    /// Documents always attached to every completion request.
    static_context: Vec<Document>,
    /// Sampling temperature forwarded to the model, if set.
    temperature: Option<f64>,
    /// Output token limit forwarded to the model, if set.
    max_tokens: Option<u64>,
    /// Provider-specific extra parameters, if any.
    additional_params: Option<serde_json::Value>,
    /// Handle used to resolve and execute tool calls.
    tool_server_handle: ToolServerHandle,
    /// Store for context documents resolved dynamically per request.
    dynamic_context: DynamicContextStore,
    /// Tool-choice constraint forwarded to the model, if set.
    tool_choice: Option<ToolChoice>,

    /// Typestate marker (Standard vs Extended); no runtime data.
    state: PhantomData<S>,
    /// Optional lifecycle hook invoked around completions and tool calls.
    hook: Option<P>,
    /// How many tool calls may run concurrently within one turn.
    concurrency: usize,
    /// JSON schema for structured output, when requested.
    output_schema: Option<schemars::Schema>,
}
92
93impl<'a, M, P> PromptRequest<'a, Standard, M, P>
94where
95 M: CompletionModel,
96 P: PromptHook<M>,
97{
98 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
100 PromptRequest {
101 prompt: prompt.into(),
102 chat_history: None,
103 max_turns: agent.default_max_turns.unwrap_or_default(),
104 model: agent.model.clone(),
105 agent_name: agent.name.clone(),
106 preamble: agent.preamble.clone(),
107 static_context: agent.static_context.clone(),
108 temperature: agent.temperature,
109 max_tokens: agent.max_tokens,
110 additional_params: agent.additional_params.clone(),
111 tool_server_handle: agent.tool_server_handle.clone(),
112 dynamic_context: agent.dynamic_context.clone(),
113 tool_choice: agent.tool_choice.clone(),
114 state: PhantomData,
115 hook: agent.hook.clone(),
116 concurrency: 1,
117 output_schema: agent.output_schema.clone(),
118 }
119 }
120}
121
122impl<'a, S, M, P> PromptRequest<'a, S, M, P>
123where
124 S: PromptType,
125 M: CompletionModel,
126 P: PromptHook<M>,
127{
128 pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
135 PromptRequest {
136 prompt: self.prompt,
137 chat_history: self.chat_history,
138 max_turns: self.max_turns,
139 model: self.model,
140 agent_name: self.agent_name,
141 preamble: self.preamble,
142 static_context: self.static_context,
143 temperature: self.temperature,
144 max_tokens: self.max_tokens,
145 additional_params: self.additional_params,
146 tool_server_handle: self.tool_server_handle,
147 dynamic_context: self.dynamic_context,
148 tool_choice: self.tool_choice,
149 state: PhantomData,
150 hook: self.hook,
151 concurrency: self.concurrency,
152 output_schema: self.output_schema,
153 }
154 }
155
156 pub fn max_turns(mut self, depth: usize) -> Self {
159 self.max_turns = depth;
160 self
161 }
162
163 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
166 self.concurrency = concurrency;
167 self
168 }
169
170 pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
172 self.chat_history = Some(history);
173 self
174 }
175
176 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
179 where
180 P2: PromptHook<M>,
181 {
182 PromptRequest {
183 prompt: self.prompt,
184 chat_history: self.chat_history,
185 max_turns: self.max_turns,
186 model: self.model,
187 agent_name: self.agent_name,
188 preamble: self.preamble,
189 static_context: self.static_context,
190 temperature: self.temperature,
191 max_tokens: self.max_tokens,
192 additional_params: self.additional_params,
193 tool_server_handle: self.tool_server_handle,
194 dynamic_context: self.dynamic_context,
195 tool_choice: self.tool_choice,
196 state: PhantomData,
197 hook: Some(hook),
198 concurrency: self.concurrency,
199 output_schema: self.output_schema,
200 }
201 }
202}
203
204impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
208where
209 M: CompletionModel + 'a,
210 P: PromptHook<M> + 'static,
211{
212 type Output = Result<String, PromptError>;
213 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
216 Box::pin(self.send())
217 }
218}
219
220impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
221where
222 M: CompletionModel + 'a,
223 P: PromptHook<M> + 'static,
224{
225 type Output = Result<PromptResponse, PromptError>;
226 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
229 Box::pin(self.send())
230 }
231}
232
233impl<M, P> PromptRequest<'_, Standard, M, P>
234where
235 M: CompletionModel,
236 P: PromptHook<M>,
237{
238 async fn send(self) -> Result<String, PromptError> {
239 self.extended_details().send().await.map(|resp| resp.output)
240 }
241}
242
/// The result of a fully-resolved (extended) prompt request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The final text produced by the model.
    pub output: String,
    /// Token usage accumulated across every turn of the request.
    pub usage: Usage,
    /// The complete message trace (prompt, assistant turns, tool results),
    /// when captured.
    pub messages: Option<Vec<Message>>,
}
252
253impl std::fmt::Display for PromptResponse {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 self.output.fmt(f)
256 }
257}
258
259impl PromptResponse {
260 pub fn new(output: impl Into<String>, usage: Usage) -> Self {
261 Self {
262 output: output.into(),
263 usage,
264 messages: None,
265 }
266 }
267
268 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
269 self.messages = Some(messages);
270 self
271 }
272}
273
/// The result of a typed (structured-output) prompt request: the
/// deserialized value together with the token usage spent producing it.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The model's final response, deserialized into `T`.
    pub output: T,
    /// Token usage accumulated across every turn of the request.
    pub usage: Usage,
}
279
280impl<T> TypedPromptResponse<T> {
281 pub fn new(output: T, usage: Usage) -> Self {
282 Self { output, usage }
283 }
284}
285
286const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
287
288impl<M, P> PromptRequest<'_, Extended, M, P>
289where
290 M: CompletionModel,
291 P: PromptHook<M>,
292{
293 fn agent_name(&self) -> &str {
294 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
295 }
296
297 async fn send(mut self) -> Result<PromptResponse, PromptError> {
298 let agent_span = if tracing::Span::current().is_disabled() {
299 info_span!(
300 "invoke_agent",
301 gen_ai.operation.name = "invoke_agent",
302 gen_ai.agent.name = self.agent_name(),
303 gen_ai.system_instructions = self.preamble,
304 gen_ai.prompt = tracing::field::Empty,
305 gen_ai.completion = tracing::field::Empty,
306 gen_ai.usage.input_tokens = tracing::field::Empty,
307 gen_ai.usage.output_tokens = tracing::field::Empty,
308 gen_ai.usage.cached_tokens = tracing::field::Empty,
309 )
310 } else {
311 tracing::Span::current()
312 };
313
314 if let Some(text) = self.prompt.rag_text() {
315 agent_span.record("gen_ai.prompt", text);
316 }
317
318 let agent_name_for_span = self.agent_name.clone();
320
321 let chat_history = if let Some(history) = self.chat_history.as_mut() {
322 history.push(self.prompt.to_owned());
323 history
324 } else {
325 &mut vec![self.prompt.to_owned()]
326 };
327
328 let mut current_max_turns = 0;
329 let mut usage = Usage::new();
330 let current_span_id: AtomicU64 = AtomicU64::new(0);
331
332 let last_prompt = loop {
334 let prompt = chat_history
335 .last()
336 .cloned()
337 .expect("there should always be at least one message in the chat history");
338
339 if current_max_turns > self.max_turns + 1 {
340 break prompt;
341 }
342
343 current_max_turns += 1;
344
345 if self.max_turns > 1 {
346 tracing::info!(
347 "Current conversation depth: {}/{}",
348 current_max_turns,
349 self.max_turns
350 );
351 }
352
353 if let Some(ref hook) = self.hook
354 && let HookAction::Terminate { reason } = hook
355 .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
356 .await
357 {
358 return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
359 }
360
361 let span = tracing::Span::current();
362 let chat_span = info_span!(
363 target: "rig::agent_chat",
364 parent: &span,
365 "chat",
366 gen_ai.operation.name = "chat",
367 gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
368 gen_ai.system_instructions = self.preamble,
369 gen_ai.provider.name = tracing::field::Empty,
370 gen_ai.request.model = tracing::field::Empty,
371 gen_ai.response.id = tracing::field::Empty,
372 gen_ai.response.model = tracing::field::Empty,
373 gen_ai.usage.output_tokens = tracing::field::Empty,
374 gen_ai.usage.input_tokens = tracing::field::Empty,
375 gen_ai.usage.cached_tokens = tracing::field::Empty,
376 gen_ai.input.messages = tracing::field::Empty,
377 gen_ai.output.messages = tracing::field::Empty,
378 );
379
380 let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
381 let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
382 chat_span.follows_from(id).to_owned()
383 } else {
384 chat_span
385 };
386
387 if let Some(id) = chat_span.id() {
388 current_span_id.store(id.into_u64(), Ordering::SeqCst);
389 };
390
391 let resp = build_completion_request(
392 &self.model,
393 prompt.clone(),
394 chat_history[..chat_history.len() - 1].to_vec(),
395 self.preamble.as_deref(),
396 &self.static_context,
397 self.temperature,
398 self.max_tokens,
399 self.additional_params.as_ref(),
400 self.tool_choice.as_ref(),
401 &self.tool_server_handle,
402 &self.dynamic_context,
403 self.output_schema.as_ref(),
404 )
405 .await?
406 .send()
407 .instrument(chat_span.clone())
408 .await?;
409
410 usage += resp.usage;
411
412 if let Some(ref hook) = self.hook
413 && let HookAction::Terminate { reason } =
414 hook.on_completion_response(&prompt, &resp).await
415 {
416 return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
417 }
418
419 let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
420 .choice
421 .iter()
422 .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
423
424 chat_history.push(Message::Assistant {
425 id: resp.message_id.clone(),
426 content: resp.choice.clone(),
427 });
428
429 if tool_calls.is_empty() {
430 let merged_texts = texts
431 .into_iter()
432 .filter_map(|content| {
433 if let AssistantContent::Text(text) = content {
434 Some(text.text.clone())
435 } else {
436 None
437 }
438 })
439 .collect::<Vec<_>>()
440 .join("\n");
441
442 if self.max_turns > 1 {
443 tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
444 }
445
446 agent_span.record("gen_ai.completion", &merged_texts);
447 agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
448 agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
449 agent_span.record("gen_ai.usage.cached_tokens", usage.cached_input_tokens);
450
451 return Ok(
453 PromptResponse::new(merged_texts, usage).with_messages(chat_history.to_vec())
454 );
455 }
456
457 let hook = self.hook.clone();
458 let tool_server_handle = self.tool_server_handle.clone();
459
460 let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
461 let tool_content = stream::iter(tool_calls)
462 .map(|choice| {
463 let hook1 = hook.clone();
464 let hook2 = hook.clone();
465 let tool_server_handle = tool_server_handle.clone();
466
467 let tool_span = info_span!(
468 "execute_tool",
469 gen_ai.operation.name = "execute_tool",
470 gen_ai.tool.type = "function",
471 gen_ai.tool.name = tracing::field::Empty,
472 gen_ai.tool.call.id = tracing::field::Empty,
473 gen_ai.tool.call.arguments = tracing::field::Empty,
474 gen_ai.tool.call.result = tracing::field::Empty
475 );
476
477 let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
478 let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
479 tool_span.follows_from(id).to_owned()
480 } else {
481 tool_span
482 };
483
484 if let Some(id) = tool_span.id() {
485 current_span_id.store(id.into_u64(), Ordering::SeqCst);
486 };
487
488 let cloned_chat_history = chat_history.clone().to_vec();
489
490 async move {
491 if let AssistantContent::ToolCall(tool_call) = choice {
492 let tool_name = &tool_call.function.name;
493 let args =
494 json_utils::value_to_json_string(&tool_call.function.arguments);
495 let internal_call_id = nanoid::nanoid!();
496 let tool_span = tracing::Span::current();
497 tool_span.record("gen_ai.tool.name", tool_name);
498 tool_span.record("gen_ai.tool.call.id", &tool_call.id);
499 tool_span.record("gen_ai.tool.call.arguments", &args);
500 if let Some(hook) = hook1 {
501 let action = hook
502 .on_tool_call(
503 tool_name,
504 tool_call.call_id.clone(),
505 &internal_call_id,
506 &args,
507 )
508 .await;
509
510 if let ToolCallHookAction::Terminate { reason } = action {
511 return Err(PromptError::prompt_cancelled(
512 cloned_chat_history,
513 reason,
514 ));
515 }
516
517 if let ToolCallHookAction::Skip { reason } = action {
518 tracing::info!(
520 tool_name = tool_name,
521 reason = reason,
522 "Tool call rejected"
523 );
524 if let Some(call_id) = tool_call.call_id.clone() {
525 return Ok(UserContent::tool_result_with_call_id(
526 tool_call.id.clone(),
527 call_id,
528 OneOrMany::one(reason.into()),
529 ));
530 } else {
531 return Ok(UserContent::tool_result(
532 tool_call.id.clone(),
533 OneOrMany::one(reason.into()),
534 ));
535 }
536 }
537 }
538 let output = match tool_server_handle.call_tool(tool_name, &args).await
539 {
540 Ok(res) => res,
541 Err(e) => {
542 tracing::warn!("Error while executing tool: {e}");
543 e.to_string()
544 }
545 };
546 if let Some(hook) = hook2
547 && let HookAction::Terminate { reason } = hook
548 .on_tool_result(
549 tool_name,
550 tool_call.call_id.clone(),
551 &internal_call_id,
552 &args,
553 &output.to_string(),
554 )
555 .await
556 {
557 return Err(PromptError::prompt_cancelled(
558 cloned_chat_history,
559 reason,
560 ));
561 }
562
563 tool_span.record("gen_ai.tool.call.result", &output);
564 tracing::info!(
565 "executed tool {tool_name} with args {args}. result: {output}"
566 );
567 if let Some(call_id) = tool_call.call_id.clone() {
568 Ok(UserContent::tool_result_with_call_id(
569 tool_call.id.clone(),
570 call_id,
571 ToolResultContent::from_tool_output(output),
572 ))
573 } else {
574 Ok(UserContent::tool_result(
575 tool_call.id.clone(),
576 ToolResultContent::from_tool_output(output),
577 ))
578 }
579 } else {
580 unreachable!(
581 "This should never happen as we already filtered for `ToolCall`"
582 )
583 }
584 }
585 .instrument(tool_span)
586 })
587 .buffer_unordered(self.concurrency)
588 .collect::<Vec<Result<UserContent, PromptError>>>()
589 .await
590 .into_iter()
591 .collect::<Result<Vec<_>, _>>()?;
592
593 chat_history.push(Message::User {
594 content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
595 });
596 };
597
598 Err(PromptError::MaxTurnsError {
600 max_turns: self.max_turns,
601 chat_history: Box::new(chat_history.clone()),
602 prompt: Box::new(last_prompt),
603 })
604 }
605}
606
607use crate::completion::StructuredOutputError;
612use schemars::{JsonSchema, schema_for};
613use serde::de::DeserializeOwned;
614
/// A prompt request whose final response is deserialized into `T` via the
/// model's structured-output support.
///
/// Wraps a [`PromptRequest`] whose output schema has been set to the JSON
/// schema derived from `T` (see [`TypedPromptRequest::from_agent`]).
pub struct TypedPromptRequest<'a, T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The underlying untyped request that runs the actual prompt loop.
    inner: PromptRequest<'a, S, M, P>,
    /// Carries `T` in the type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
641
642impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
643where
644 T: JsonSchema + DeserializeOwned + WasmCompatSend,
645 M: CompletionModel,
646 P: PromptHook<M>,
647{
648 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
652 let mut inner = PromptRequest::from_agent(agent, prompt);
653 inner.output_schema = Some(schema_for!(T));
655 Self {
656 inner,
657 _phantom: std::marker::PhantomData,
658 }
659 }
660}
661
662impl<'a, T, S, M, P> TypedPromptRequest<'a, T, S, M, P>
663where
664 T: JsonSchema + DeserializeOwned + WasmCompatSend,
665 S: PromptType,
666 M: CompletionModel,
667 P: PromptHook<M>,
668{
669 pub fn extended_details(self) -> TypedPromptRequest<'a, T, Extended, M, P> {
675 TypedPromptRequest {
676 inner: self.inner.extended_details(),
677 _phantom: std::marker::PhantomData,
678 }
679 }
680
681 pub fn max_turns(mut self, depth: usize) -> Self {
687 self.inner = self.inner.max_turns(depth);
688 self
689 }
690
691 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
695 self.inner = self.inner.with_tool_concurrency(concurrency);
696 self
697 }
698
699 pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
701 self.inner = self.inner.with_history(history);
702 self
703 }
704
705 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<'a, T, S, M, P2>
709 where
710 P2: PromptHook<M>,
711 {
712 TypedPromptRequest {
713 inner: self.inner.with_hook(hook),
714 _phantom: std::marker::PhantomData,
715 }
716 }
717}
718
719impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
720where
721 T: JsonSchema + DeserializeOwned + WasmCompatSend,
722 M: CompletionModel,
723 P: PromptHook<M>,
724{
725 async fn send(self) -> Result<T, StructuredOutputError> {
727 let response = self.inner.send().await?;
728
729 if response.is_empty() {
730 return Err(StructuredOutputError::EmptyResponse);
731 }
732
733 let parsed: T = serde_json::from_str(&response)?;
734 Ok(parsed)
735 }
736}
737
738impl<'a, T, M, P> TypedPromptRequest<'a, T, Extended, M, P>
739where
740 T: JsonSchema + DeserializeOwned + WasmCompatSend,
741 M: CompletionModel,
742 P: PromptHook<M>,
743{
744 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
746 let response = self.inner.send().await?;
747
748 if response.output.is_empty() {
749 return Err(StructuredOutputError::EmptyResponse);
750 }
751
752 let parsed: T = serde_json::from_str(&response.output)?;
753 Ok(TypedPromptResponse::new(parsed, response.usage))
754 }
755}
756
// Awaiting a standard typed request resolves directly to the deserialized `T`.
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
770
// Awaiting an extended typed request resolves to a `TypedPromptResponse<T>`
// carrying both the deserialized value and the accumulated usage.
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}