rig/agent/prompt_request/mod.rs
1pub(crate) mod streaming;
2
3use std::{future::IntoFuture, marker::PhantomData};
4
5use futures::{FutureExt, StreamExt, future::BoxFuture, stream};
6
7use crate::{
8 OneOrMany,
9 completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
10 message::{AssistantContent, UserContent},
11 tool::ToolSetError,
12};
13
14use super::Agent;
15
/// Typestate marker trait for [`PromptRequest`]; implemented only by
/// [`Standard`] and [`Extended`].
pub trait PromptType {}
/// Typestate: the request resolves to a plain `String` on `.await`.
pub struct Standard;
/// Typestate: the request resolves to a [`PromptResponse`] (text plus
/// aggregated token usage) on `.await`.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
22
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use `.multi_turn()` to
/// raise the turn budget: it defaults to 0, which permits only a single tool
/// round-trip. Otherwise, awaiting the request (which sends the prompt) can
/// return [`crate::completion::request::PromptError::MaxDepthError`] if the
/// agent decides to call tools back to back.
pub struct PromptRequest<'a, S: PromptType, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional mutable chat history to include with the prompt; the request
    /// appends the prompt and all intermediate turns to it as it runs.
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn,
    /// i.e. a single tool round-trip)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data tracking the typestate (`Standard` or `Extended`)
    state: PhantomData<S>,
    #[cfg(feature = "hooks")]
    /// Optional per-request hook for completion and tool-call events
    hook: Option<&'a dyn crate::agent::PromptHook<M>>,
}
47
48impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
49 /// Create a new PromptRequest with the given prompt and model
50 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
51 Self {
52 prompt: prompt.into(),
53 chat_history: None,
54 max_depth: 0,
55 agent,
56 state: PhantomData,
57 #[cfg(feature = "hooks")]
58 hook: None,
59 }
60 }
61
62 /// Enable returning extended details for responses (includes aggregated token usage)
63 ///
64 /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
65 /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
66 /// of conversation.
67 pub fn extended_details(self) -> PromptRequest<'a, Extended, M> {
68 PromptRequest {
69 prompt: self.prompt,
70 chat_history: self.chat_history,
71 max_depth: self.max_depth,
72 agent: self.agent,
73 state: PhantomData,
74 #[cfg(feature = "hooks")]
75 hook: self.hook,
76 }
77 }
78}
79
80impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
81 /// Set the maximum depth for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
82 /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
83 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M> {
84 PromptRequest {
85 prompt: self.prompt,
86 chat_history: self.chat_history,
87 max_depth: depth,
88 agent: self.agent,
89 state: PhantomData,
90 #[cfg(feature = "hooks")]
91 hook: self.hook,
92 }
93 }
94
95 /// Add chat history to the prompt request
96 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M> {
97 PromptRequest {
98 prompt: self.prompt,
99 chat_history: Some(history),
100 max_depth: self.max_depth,
101 agent: self.agent,
102 state: PhantomData,
103 #[cfg(feature = "hooks")]
104 hook: self.hook,
105 }
106 }
107
108 #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
109 #[cfg(feature = "hooks")]
110 /// Attach a per-request hook for tool call events
111 pub fn with_hook(self, hook: &'a dyn crate::agent::PromptHook<M>) -> PromptRequest<'a, S, M> {
112 PromptRequest {
113 prompt: self.prompt,
114 chat_history: self.chat_history,
115 max_depth: self.max_depth,
116 agent: self.agent,
117 state: PhantomData,
118 #[cfg(feature = "hooks")]
119 hook: Some(hook),
120 }
121 }
122}
123
// Every method has an empty default body so implementors only override the
// events they care about; unused-variable warnings are silenced per method.
/// Trait for per-request hooks to observe completion and tool call events.
/// Usage:
/// ```rust
///
/// use std::env;
///
/// use rig::agent::PromptHook;
/// use rig::client::CompletionClient;
/// use rig::completion::{CompletionModel, CompletionResponse, Message, Prompt};
/// use rig::message::{AssistantContent, UserContent};
/// use rig::providers;
///
/// struct SessionIdHook<'a> {
///     session_id: &'a str,
/// }
///
/// #[async_trait::async_trait]
/// impl<'a, M: CompletionModel> PromptHook<M> for SessionIdHook<'a> {
///     async fn on_tool_call(&self, tool_name: &str, args: &str) {
///         println!(
///             "[Session {}] Calling tool: {} with args: {}",
///             self.session_id, tool_name, args
///         );
///     }
///     async fn on_tool_result(&self, tool_name: &str, args: &str, result: &str) {
///         println!(
///             "[Session {}] Tool result for {} (args: {}): {}",
///             self.session_id, tool_name, args, result
///         );
///     }
///
///     async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) {
///         println!(
///             "[Session {}] Sending prompt: {}",
///             self.session_id,
///             match prompt {
///                 Message::User { content } => content
///                     .iter()
///                     .filter_map(|c| {
///                         if let UserContent::Text(text_content) = c {
///                             Some(text_content.text.clone())
///                         } else {
///                             None
///                         }
///                     })
///                     .collect::<Vec<_>>()
///                     .join("\n"),
///                 Message::Assistant { content, .. } => content
///                     .iter()
///                     .filter_map(|c| if let AssistantContent::Text(text_content) = c {
///                         Some(text_content.text.clone())
///                     } else {
///                         None
///                     })
///                     .collect::<Vec<_>>()
///                     .join("\n"),
///             }
///         );
///     }
///
///     async fn on_completion_response(
///         &self,
///         _prompt: &Message,
///         response: &CompletionResponse<M::Response>,
///     ) {
///         if let Ok(resp) = serde_json::to_string(&response.raw_response) {
///             println!("[Session {}] Received response: {}", self.session_id, resp);
///         } else {
///             println!(
///                 "[Session {}] Received response: <non-serializable>",
///                 self.session_id
///             );
///         }
///     }
/// }
///
/// // Example main function (pseudo-code, as actual Agent/CompletionModel setup is project-specific)
/// #[tokio::main]
/// async fn main() -> Result<(), anyhow::Error> {
///     let client = providers::openai::Client::new(
///         &env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
///     );
///
///     // Create agent with a single context prompt
///     let comedian_agent = client
///         .agent("gpt-4o")
///         .preamble("You are a comedian here to entertain the user using humour and jokes.")
///         .build();
///
///     let session_id = "abc123";
///     let hook = SessionIdHook { session_id };
///
///     // Prompt the agent and print the response
///     comedian_agent
///         .prompt("Entertain me!")
///         .with_hook(&hook)
///         .await?;
///
///     Ok(())
/// }
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
#[cfg(feature = "hooks")]
#[async_trait::async_trait]
pub trait PromptHook<M: CompletionModel>: Send + Sync {
    #[allow(unused_variables)]
    /// Called before each prompt is sent to the model. `history` holds the
    /// conversation so far, excluding `prompt` itself.
    async fn on_completion_call(&self, prompt: &Message, history: &[Message]) {}

    #[allow(unused_variables)]
    /// Called after the prompt is sent to the model and a response is received.
    /// This function is for non-streamed responses. Please refer to `on_stream_completion_response_finish` for streamed responses.
    async fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
    ) {
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from their completion API to the client.
    async fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
    ) {
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked, with the tool's name and its
    /// JSON-encoded arguments.
    async fn on_tool_call(&self, tool_name: &str, args: &str) {}

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    async fn on_tool_result(&self, tool_name: &str, args: &str, result: &str) {}
}
261
262/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
263/// for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
264/// directly via the associated type.
265impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Standard, M> {
266 type Output = Result<String, PromptError>;
267 type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
268
269 fn into_future(self) -> Self::IntoFuture {
270 self.send().boxed()
271 }
272}
273
274impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Extended, M> {
275 type Output = Result<PromptResponse, PromptError>;
276 type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
277
278 fn into_future(self) -> Self::IntoFuture {
279 self.send().boxed()
280 }
281}
282
283impl<M: CompletionModel> PromptRequest<'_, Standard, M> {
284 async fn send(self) -> Result<String, PromptError> {
285 self.extended_details().send().await.map(|resp| resp.output)
286 }
287}
288
/// Result of an extended prompt request: the model's final text output
/// together with token usage aggregated over every turn taken.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    // Final text produced once all tool calls (if any) were resolved.
    pub output: String,
    // Token usage summed across every completion call made for this request.
    pub total_usage: Usage,
}
294
295impl PromptResponse {
296 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
297 Self {
298 output: output.into(),
299 total_usage,
300 }
301 }
302}
303
impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
    /// Drive the multi-turn prompt loop.
    ///
    /// Each iteration sends the newest message plus prior history to the
    /// model, accumulates token usage, then either returns the model's text
    /// answer or executes the requested tool calls and loops again. If the
    /// depth budget runs out first, a `MaxDepthError` is returned carrying
    /// the conversation so far.
    #[tracing::instrument(skip(self), fields(agent_name = self.agent.name()))]
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent = self.agent;
        // Use the caller-provided history if any (pushing the prompt onto it);
        // otherwise a request-local buffer kept alive by temporary lifetime
        // extension. Invariant: the message to send is always the last entry.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        let mut current_max_depth = 0;
        // Aggregated token usage across every completion call in this request.
        let mut usage = Usage::new();

        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
        let last_prompt = loop {
            // The user prompt on the first pass, tool results afterwards.
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // Checked before the increment below, so up to `max_depth + 2`
            // completion calls can happen: the initial call, the allowed tool
            // rounds, and the final text answer.
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            #[cfg(feature = "hooks")]
            if let Some(hook) = self.hook.as_ref() {
                // The history slice passed to hooks excludes the prompt itself.
                hook.on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await;
            }

            // Send the prompt with everything before it as context.
            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .await?;

            usage += resp.usage;

            #[cfg(feature = "hooks")]
            if let Some(hook) = self.hook.as_ref() {
                hook.on_completion_response(&prompt, &resp).await;
            }

            // Split the reply into tool invocations and everything else.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Record the full assistant reply (tool calls included) in history.
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                // No tools requested: join the text fragments and finish.
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant, we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            // Execute each requested tool sequentially (`then` preserves
            // order), firing the optional hooks around every call. The first
            // tool error aborts the whole request.
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let tool_name = &tool_call.function.name;
                        let args = tool_call.function.arguments.to_string();
                        #[cfg(feature = "hooks")]
                        if let Some(hook) = self.hook.as_ref() {
                            hook.on_tool_call(tool_name, &args).await;
                        }
                        let output = agent.tools.call(tool_name, args.clone()).await?;
                        #[cfg(feature = "hooks")]
                        if let Some(hook) = self.hook.as_ref() {
                            hook.on_tool_result(tool_name, &args, &output.to_string())
                                .await;
                        }
                        // Preserve the provider's call id when one was given.
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Tool results go back as a user message, which becomes the
            // prompt for the next loop iteration.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Depth budget exhausted without a final text answer: surface the
        // partial conversation so the caller can inspect or resume it.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}
443}