1pub mod hooks;
2pub mod streaming;
3
4use hooks::{HookAction, PromptHook, ToolCallHookAction};
5use std::{
6 future::IntoFuture,
7 marker::PhantomData,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12};
13use tracing::{Instrument, span::Id};
14
15use futures::{StreamExt, stream};
16use tracing::info_span;
17
18use crate::{
19 OneOrMany,
20 completion::{CompletionModel, Document, Message, PromptError, Usage},
21 json_utils,
22 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
23 tool::server::ToolServerHandle,
24 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
25};
26
27use super::{
28 Agent,
29 completion::{DynamicContextStore, build_completion_request},
30};
31
/// Type-state marker trait for [`PromptRequest`]: selects what awaiting the
/// request resolves to.
pub trait PromptType {}
/// Type-state: awaiting the request yields only the final text (`String`).
pub struct Standard;
/// Type-state: awaiting the request yields a full [`PromptResponse`]
/// (final text plus aggregated token usage).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
38
/// An in-flight prompt builder for an [`Agent`].
///
/// `S` is a type-state marker ([`Standard`] or [`Extended`]) deciding whether
/// awaiting the request yields a plain `String` or a [`PromptResponse`].
/// `M` is the completion model and `P` the lifecycle hook type.
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The message that starts (or continues) the conversation.
    prompt: Message,
    /// Optional caller-owned history; when present, the prompt and all
    /// intermediate assistant/tool messages are pushed into it.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of model turns before the run fails with a
    /// max-turns error.
    max_turns: usize,

    /// The completion model used for every turn.
    model: Arc<M>,
    /// Agent name used in tracing spans (a fallback is used when `None`).
    agent_name: Option<String>,
    /// System preamble / instructions forwarded with each request.
    preamble: Option<String>,
    /// Context documents attached to every completion request.
    static_context: Vec<Document>,
    /// Sampling temperature forwarded to the model, if set.
    temperature: Option<f64>,
    /// Completion token cap forwarded to the model, if set.
    max_tokens: Option<u64>,
    /// Extra provider-specific parameters, if any.
    additional_params: Option<serde_json::Value>,
    /// Handle used to execute the model's tool calls.
    tool_server_handle: ToolServerHandle,
    /// Store of dynamically retrieved context for each request.
    dynamic_context: DynamicContextStore,
    /// Tool-choice constraint forwarded to the model, if set.
    tool_choice: Option<ToolChoice>,

    /// Zero-sized type-state marker (`Standard` vs `Extended`).
    state: PhantomData<S>,
    /// Optional hook consulted before/after completions and tool calls.
    hook: Option<P>,
    /// How many tool calls may execute concurrently in a single turn.
    concurrency: usize,
    /// JSON schema constraining the model output (structured output mode).
    output_schema: Option<schemars::Schema>,
}
92
impl<'a, M, P> PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Build a new `Standard` prompt request by snapshotting an agent's
    /// configuration (model, preamble, context, tool handles, hook, schema).
    ///
    /// Defaults: tool concurrency is 1, and `max_turns` falls back to 0
    /// when the agent has no `default_max_turns` configured.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            // Tool calls run sequentially unless raised via
            // `with_tool_concurrency`.
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
        }
    }
}
121
impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switch the request into the `Extended` type-state so that awaiting it
    /// yields a [`PromptResponse`] (text plus aggregated usage) instead of a
    /// plain `String`.
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        // Field-for-field move: only the `S` type parameter changes, so
        // struct-update syntax (`..self`) cannot be used across types.
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }

    /// Set the maximum number of model turns (tool-use round trips) allowed
    /// before the run fails with [`PromptError::MaxTurnsError`].
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.max_turns = depth;
        self
    }

    /// Set how many tool calls may execute concurrently within one turn.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Attach a caller-owned chat history; the prompt and all intermediate
    /// messages produced during the run are recorded into it.
    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Replace the lifecycle hook. Changes the `P` type parameter, so the
    /// request is rebuilt field by field rather than via struct update.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }
}
202
203impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
207where
208 M: CompletionModel + 'a,
209 P: PromptHook<M> + 'static,
210{
211 type Output = Result<String, PromptError>;
212 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
215 Box::pin(self.send())
216 }
217}
218
219impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
220where
221 M: CompletionModel + 'a,
222 P: PromptHook<M> + 'static,
223{
224 type Output = Result<PromptResponse, PromptError>;
225 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
228 Box::pin(self.send())
229 }
230}
231
232impl<M, P> PromptRequest<'_, Standard, M, P>
233where
234 M: CompletionModel,
235 P: PromptHook<M>,
236{
237 async fn send(self) -> Result<String, PromptError> {
238 self.extended_details().send().await.map(|resp| resp.output)
239 }
240}
241
/// The final result of a completed multi-turn prompt run.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The model's final text output.
    pub output: String,
    /// Token usage accumulated across every turn of the run.
    pub total_usage: Usage,
}
247
248impl PromptResponse {
249 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
250 Self {
251 output: output.into(),
252 total_usage,
253 }
254 }
255}
256
/// Fallback agent name used in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
258
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Agent name for tracing, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Run the multi-turn prompt loop.
    ///
    /// Each iteration sends the current prompt plus accumulated chat history
    /// to the model. If the response contains no tool calls, the text parts
    /// are joined and returned together with the aggregated token usage.
    /// Otherwise the requested tools are executed (up to `self.concurrency`
    /// at a time) and their results are pushed back into the history as the
    /// next turn's input.
    ///
    /// # Errors
    /// - A `prompt_cancelled` error when any hook returns `Terminate`.
    /// - [`PromptError::MaxTurnsError`] when the turn budget is exhausted.
    /// - Any request-building or completion error, propagated via `?`.
    async fn send(mut self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh OTel-style `invoke_agent` span for the whole run.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();

        // Record the prompt into the caller-provided history, or start a
        // local single-message history when none was attached.
        let chat_history = if let Some(history) = self.chat_history.as_mut() {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Holds the id of the most recent chat/tool span so the next span
        // can be linked via `follows_from`. Atomic because it is also
        // written from the concurrently-running tool-call futures below.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        let last_prompt = loop {
            // The newest history entry is the input for this turn.
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // NOTE(review): the `+ 1` grants one extra iteration beyond
            // `max_turns` — presumably to allow a final text-only turn after
            // the last tool turn; confirm the off-by-one is intended.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Pre-completion hook: may cancel the whole run.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } = hook
                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            // Per-turn `chat` span, parented to the current span.
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous turn's span, if any.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Build and send the completion request for this turn. The
            // history passed excludes the current prompt (the final entry).
            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                chat_history[..chat_history.len() - 1].to_vec(),
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            usage += resp.usage;

            // Post-completion hook: may cancel the whole run.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            // Split the response into tool calls vs. text content.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: resp.message_id.clone(),
                content: resp.choice.clone(),
            });

            // No tool calls: this is the final answer — join the text parts
            // and finish.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // Execute every requested tool call, at most `self.concurrency`
            // at a time, collecting their results (or failing on the first
            // hook-initiated cancellation).
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    // Per-call `execute_tool` span; detail fields are filled
                    // in once the call's name/args are known.
                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain this span after the most recently created one.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // NOTE(review): `.clone()` already yields an owned
                    // `Vec<Message>` here; the trailing `.to_vec()` copies
                    // it a second time.
                    let cloned_chat_history = chat_history.clone().to_vec();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            // Shadowing: this is the span installed by
                            // `.instrument(tool_span)` below.
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Pre-tool hook: may terminate the run or skip
                            // just this call (skips report the reason back
                            // to the model as the tool result).
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_chat_history,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool execution errors are not fatal: the error
                            // text is fed back to the model as the result.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            // Post-tool hook: may terminate the run.
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_chat_history,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider call id when one exists.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            // Feed the tool results back as the next turn's user message.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Loop exited via `break`: the turn budget was exhausted.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}
572
573use crate::completion::StructuredOutputError;
578use schemars::{JsonSchema, schema_for};
579use serde::de::DeserializeOwned;
580
/// A prompt request whose result is deserialized into a typed value `T`.
///
/// Wraps a [`PromptRequest`] in `Standard` mode, forcing the agent's output
/// schema to the JSON schema of `T` and parsing the final text as JSON.
pub struct TypedPromptRequest<'a, T, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The underlying standard request that performs the actual run.
    inner: PromptRequest<'a, Standard, M, P>,
    /// Marker tying the request to the output type `T`.
    _phantom: std::marker::PhantomData<T>,
}
603
604impl<'a, T, M, P> TypedPromptRequest<'a, T, M, P>
605where
606 T: JsonSchema + DeserializeOwned + WasmCompatSend,
607 M: CompletionModel,
608 P: PromptHook<M>,
609{
610 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
614 let mut inner = PromptRequest::from_agent(agent, prompt);
615 inner.output_schema = Some(schema_for!(T));
617 Self {
618 inner,
619 _phantom: std::marker::PhantomData,
620 }
621 }
622
623 pub fn max_turns(mut self, depth: usize) -> Self {
629 self.inner = self.inner.max_turns(depth);
630 self
631 }
632
633 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
637 self.inner = self.inner.with_tool_concurrency(concurrency);
638 self
639 }
640
641 pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
643 self.inner = self.inner.with_history(history);
644 self
645 }
646
647 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<'a, T, M, P2>
651 where
652 P2: PromptHook<M>,
653 {
654 TypedPromptRequest {
655 inner: self.inner.with_hook(hook),
656 _phantom: std::marker::PhantomData,
657 }
658 }
659
660 async fn send(self) -> Result<T, StructuredOutputError> {
662 let response = self.inner.send().await?;
663
664 if response.is_empty() {
665 return Err(StructuredOutputError::EmptyResponse);
666 }
667
668 let parsed: T = serde_json::from_str(&response)?;
669 Ok(parsed)
670 }
671}
672
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    /// Box the `send` future so a typed request can be `.await`ed directly,
    /// resolving to the deserialized `T`.
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}