1pub mod streaming;
2
3pub use streaming::StreamingPromptHook;
4
5use std::{
6 future::IntoFuture,
7 marker::PhantomData,
8 sync::{
9 Arc,
10 atomic::{AtomicBool, AtomicU64, Ordering},
11 },
12};
13use tracing::{Instrument, span::Id};
14
15use futures::{StreamExt, stream};
16use tracing::info_span;
17
18use crate::{
19 OneOrMany,
20 completion::{Completion, CompletionModel, Message, PromptError, Usage},
21 message::{AssistantContent, UserContent},
22 tool::ToolSetError,
23 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
26use super::Agent;
27
/// Marker trait implemented by the type-states of [`PromptRequest`].
pub trait PromptType {}
/// Type-state: awaiting the request yields only the final output text.
pub struct Standard;
/// Type-state: awaiting the request yields a [`PromptResponse`] with usage totals.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
34
/// Builder for a single agent prompt, using the type-state pattern:
/// `S` selects the level of detail in the awaited result
/// ([`Standard`] or [`Extended`]).
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the completion model.
    prompt: Message,
    /// Optional caller-owned chat history; when present, the prompt and all
    /// intermediate assistant/tool messages are appended to it in place.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of multi-turn (tool-calling) rounds allowed.
    max_depth: usize,
    /// The agent that serves the completion.
    agent: &'a Agent<M>,
    /// Type-state marker; carries no runtime data.
    state: PhantomData<S>,
    /// Optional hook notified of completion and tool-call lifecycle events.
    hook: Option<P>,
}
63
64impl<'a, M> PromptRequest<'a, Standard, M, ()>
65where
66 M: CompletionModel,
67{
68 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
70 Self {
71 prompt: prompt.into(),
72 chat_history: None,
73 max_depth: 0,
74 agent,
75 state: PhantomData,
76 hook: None,
77 }
78 }
79}
80
81impl<'a, S, M, P> PromptRequest<'a, S, M, P>
82where
83 S: PromptType,
84 M: CompletionModel,
85 P: PromptHook<M>,
86{
87 pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
93 PromptRequest {
94 prompt: self.prompt,
95 chat_history: self.chat_history,
96 max_depth: self.max_depth,
97 agent: self.agent,
98 state: PhantomData,
99 hook: self.hook,
100 }
101 }
102 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
105 PromptRequest {
106 prompt: self.prompt,
107 chat_history: self.chat_history,
108 max_depth: depth,
109 agent: self.agent,
110 state: PhantomData,
111 hook: self.hook,
112 }
113 }
114
115 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
117 PromptRequest {
118 prompt: self.prompt,
119 chat_history: Some(history),
120 max_depth: self.max_depth,
121 agent: self.agent,
122 state: PhantomData,
123 hook: self.hook,
124 }
125 }
126
127 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
129 where
130 P2: PromptHook<M>,
131 {
132 PromptRequest {
133 prompt: self.prompt,
134 chat_history: self.chat_history,
135 max_depth: self.max_depth,
136 agent: self.agent,
137 state: PhantomData,
138 hook: Some(hook),
139 }
140 }
141}
142
/// A shared, cloneable cancellation flag handed to [`PromptHook`] callbacks.
///
/// All clones share one underlying atomic flag, so a hook can cancel an
/// in-flight prompt from any clone and every holder observes it.
#[derive(Clone)]
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    /// Creates a fresh, un-cancelled signal.
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Marks the operation as cancelled; visible to all clones.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once any clone has called [`CancelSignal::cancel`].
    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}
164
/// Lifecycle hooks invoked at key points while a [`PromptRequest`] runs.
///
/// All methods have no-op defaults, so implementors override only the events
/// they care about. Each callback receives a [`CancelSignal`]; calling
/// `cancel` on it aborts the surrounding prompt loop at its next checkpoint.
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before each completion-model request is sent. `history` is the
    /// conversation so far, excluding `prompt` itself.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after each completion-model response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked, with its name and JSON arguments.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool finishes, with its name, arguments, and result.
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

// The unit type is the "no hook" implementation: every callback is a no-op.
impl<M> PromptHook<M> for () where M: CompletionModel {}
218
219impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
223where
224 M: CompletionModel,
225 P: PromptHook<M> + 'static,
226{
227 type Output = Result<String, PromptError>;
228 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
231 Box::pin(self.send())
232 }
233}
234
235impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
236where
237 M: CompletionModel,
238 P: PromptHook<M> + 'static,
239{
240 type Output = Result<PromptResponse, PromptError>;
241 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
244 Box::pin(self.send())
245 }
246}
247
248impl<M, P> PromptRequest<'_, Standard, M, P>
249where
250 M: CompletionModel,
251 P: PromptHook<M>,
252{
253 async fn send(self) -> Result<String, PromptError> {
254 self.extended_details().send().await.map(|resp| resp.output)
255 }
256}
257
/// Result of an [`Extended`] prompt request: the final output text together
/// with token usage accumulated over every turn of the conversation.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The model's final text output (text fragments joined with `\n`).
    pub output: String,
    /// Token usage summed across all completion calls in the request.
    pub total_usage: Usage,
}
263
264impl PromptResponse {
265 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
266 Self {
267 output: output.into(),
268 total_usage,
269 }
270 }
271}
272
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drives the multi-turn prompt loop: repeatedly calls the completion
    /// model, executes any requested tools, and feeds results back into the
    /// history until the model replies with text only (success) or the
    /// configured `max_depth` is exhausted (`MaxDepthError`).
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse an already-active span if one exists; otherwise open a fresh
        // "invoke_agent" span with empty fields recorded as results arrive.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        // Append the prompt to the caller's history when one was supplied;
        // otherwise work on a fresh single-message Vec (the `else` arm relies
        // on temporary lifetime extension to keep the Vec alive).
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // One shared signal for the whole request; any hook can trip it.
        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
        // Tracks the most recent chat/tool span id so successive spans can be
        // linked with `follows_from`; 0 means "no previous span yet".
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // The loop breaks with the last prompt only when depth is exhausted;
        // a text-only model reply returns directly from inside the loop.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // `+ 1` allows one final turn beyond max_depth before giving up.
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // Cancellation checkpoint: before issuing the completion call.
            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous turn's span, if any.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            // Remember this span's id for the next turn's `follows_from`.
            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // History sent to the model excludes the prompt itself (last entry).
            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            // Cancellation checkpoint: after the completion response.
            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            // Split the model's reply into tool calls and plain text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls: the conversation is finished; merge the text
            // fragments, record telemetry, and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            // Execute every requested tool sequentially (`then` preserves
            // order), collecting one UserContent tool result per call.
            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    // Separate clones for the pre-call and post-call hooks,
                    // since each is moved into the async block below.
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Same follows_from chaining as the chat spans above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args = tool_call.function.arguments.to_string();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Cancellation checkpoint: before the tool runs.
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args, cancel_sig1.clone())
                                    .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            // Tool failures are not fatal: the error text is
                            // fed back to the model as the tool's output.
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            // Cancellation checkpoint: after the tool result.
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider's call_id when present so
                            // the result can be matched to the original call.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    // A hook-triggered interruption becomes a cancellation
                    // error carrying the history; anything else passes through.
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            // Feed the tool results back as a user message for the next turn.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Loop exhausted its depth budget without a text-only reply.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}