1pub(crate) mod streaming;
2
3use std::{
4 future::IntoFuture,
5 marker::PhantomData,
6 sync::{
7 Arc,
8 atomic::{AtomicBool, AtomicU64, Ordering},
9 },
10};
11use tracing::{Instrument, span::Id};
12
13use futures::{StreamExt, stream};
14use tracing::info_span;
15
16use crate::{
17 OneOrMany,
18 completion::{Completion, CompletionModel, Message, PromptError, Usage},
19 message::{AssistantContent, UserContent},
20 tool::ToolSetError,
21 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
22};
23
24use super::Agent;
25
/// Marker trait used to select, at the type level, how much detail a
/// [`PromptRequest`] resolves to when awaited.
pub trait PromptType {}
/// Type-state marker: awaiting the request yields a plain `String` output.
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`] that
/// also carries aggregated token usage.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
32
/// Builder for a prompt sent to an [`Agent`].
///
/// The `S` parameter is a type-state marker ([`Standard`] or [`Extended`])
/// that selects whether awaiting the request returns a `String` or a
/// [`PromptResponse`] with usage details. `P` is an optional hook type
/// observing completion and tool-call events.
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The prompt message to send to the model.
    prompt: Message,
    // Caller-provided chat history; mutated in place as turns complete.
    chat_history: Option<&'a mut Vec<Message>>,
    // Maximum number of multi-turn (tool-calling) iterations allowed.
    max_depth: usize,
    // The agent that performs completions and executes tool calls.
    agent: &'a Agent<M>,
    // Zero-sized type-state marker (Standard vs Extended).
    state: PhantomData<S>,
    // Optional hook invoked around completion and tool-call events.
    hook: Option<P>,
}
61
62impl<'a, M> PromptRequest<'a, Standard, M, ()>
63where
64 M: CompletionModel,
65{
66 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
68 Self {
69 prompt: prompt.into(),
70 chat_history: None,
71 max_depth: 0,
72 agent,
73 state: PhantomData,
74 hook: None,
75 }
76 }
77}
78
79impl<'a, S, M, P> PromptRequest<'a, S, M, P>
80where
81 S: PromptType,
82 M: CompletionModel,
83 P: PromptHook<M>,
84{
85 pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
91 PromptRequest {
92 prompt: self.prompt,
93 chat_history: self.chat_history,
94 max_depth: self.max_depth,
95 agent: self.agent,
96 state: PhantomData,
97 hook: self.hook,
98 }
99 }
100 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
103 PromptRequest {
104 prompt: self.prompt,
105 chat_history: self.chat_history,
106 max_depth: depth,
107 agent: self.agent,
108 state: PhantomData,
109 hook: self.hook,
110 }
111 }
112
113 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
115 PromptRequest {
116 prompt: self.prompt,
117 chat_history: Some(history),
118 max_depth: self.max_depth,
119 agent: self.agent,
120 state: PhantomData,
121 hook: self.hook,
122 }
123 }
124
125 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
127 where
128 P2: PromptHook<M>,
129 {
130 PromptRequest {
131 prompt: self.prompt,
132 chat_history: self.chat_history,
133 max_depth: self.max_depth,
134 agent: self.agent,
135 state: PhantomData,
136 hook: Some(hook),
137 }
138 }
139}
140
/// A cloneable cancellation flag shared between a prompt request and its
/// hooks. Cloning is cheap (an `Arc` refcount bump) and every clone observes
/// the same underlying flag.
//
// The hand-written `Clone` impl that simply cloned the inner `Arc` is
// replaced by `#[derive(Clone)]`, which generates identical code.
#[derive(Clone)]
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    /// Creates a new, not-yet-cancelled signal.
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Requests cancellation; all clones of this signal will observe it.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once [`cancel`](Self::cancel) was called on any clone.
    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}
162
/// Hook for observing the lifecycle of a prompt request: completion calls,
/// completion responses, tool calls, and tool results.
///
/// All methods have no-op defaults, so implementors only override the events
/// they care about. Each callback receives a [`CancelSignal`]; triggering it
/// aborts the prompt loop after the callback returns.
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before each completion request is sent to the model.
    /// `history` excludes the `prompt` message itself.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after each completion response is received from the model.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked with its name and raw JSON arguments.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool finishes, with the stringified result (which may
    /// be an error message — tool failures are fed back to the model).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

// No-op hook used when a request has no hook attached (the `()` default).
impl<M> PromptHook<M> for () where M: CompletionModel {}
216
217impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
221where
222 M: CompletionModel,
223 P: PromptHook<M> + 'static,
224{
225 type Output = Result<String, PromptError>;
226 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
229 Box::pin(self.send())
230 }
231}
232
233impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
234where
235 M: CompletionModel,
236 P: PromptHook<M> + 'static,
237{
238 type Output = Result<PromptResponse, PromptError>;
239 type IntoFuture = WasmBoxedFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
242 Box::pin(self.send())
243 }
244}
245
246impl<M, P> PromptRequest<'_, Standard, M, P>
247where
248 M: CompletionModel,
249 P: PromptHook<M>,
250{
251 async fn send(self) -> Result<String, PromptError> {
252 self.extended_details().send().await.map(|resp| resp.output)
253 }
254}
255
/// The result of an extended prompt request.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    // The final text produced by the model.
    pub output: String,
    // Token usage summed across every completion call in the turn loop.
    pub total_usage: Usage,
}
261
262impl PromptResponse {
263 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
264 Self {
265 output: output.into(),
266 total_usage,
267 }
268 }
269}
270
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drives the multi-turn prompt loop: sends the prompt, executes any tool
    /// calls the model requests, feeds the results back as a new user message,
    /// and repeats until the model answers with plain text or the depth limit
    /// is exceeded.
    ///
    /// Returns the final text plus token usage aggregated across all turns.
    /// Errors with `PromptError::MaxDepthError` when `max_depth` is exhausted,
    /// or a prompt-cancelled error if a hook trips the [`CancelSignal`].
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's active span when there is one; otherwise open a
        // fresh `invoke_agent` span carrying GenAI telemetry fields.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        // Use the caller-provided history (appending the new prompt to it), or
        // fall back to a fresh single-message history owned by this call.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // One shared cancellation flag handed to every hook callback.
        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        // Token usage accumulated across every completion call in the loop.
        let mut usage = Usage::new();
        // Id of the most recent chat/tool span, used to chain spans together
        // via `follows_from` across turns.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // The loop breaks with the last prompt once the depth limit is
        // exceeded; a plain-text model answer returns early inside the loop.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // Let the hook observe (and possibly cancel) before the
            // completion request is sent. History excludes the prompt itself.
            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous one (if any) with
            // `follows_from`, so turns appear as causally-related siblings
            // rather than nested spans.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Run the completion with everything before the last message as
            // history, instrumented under this turn's chat span.
            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            // Split the model's response into tool calls and text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls means the model produced its final answer: merge
            // the text parts, record telemetry on the agent span, and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            // Execute the requested tool calls one at a time (`then` runs the
            // futures sequentially and preserves order), each inside its own
            // `execute_tool` span.
            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    // Two hook/cancel clones: one pair for the pre-call hook,
                    // one for the post-result hook inside the async block.
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Same span-chaining scheme as the chat spans above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args = tool_call.function.arguments.to_string();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args, cancel_sig1.clone())
                                    .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            // Tool failures are not fatal: the error text is
                            // used as the tool output and fed back to the model.
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider's call id when present so
                            // the result can be matched to the original call.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    // A hook-triggered interrupt surfaces as a cancellation;
                    // any other tool error converts into a PromptError.
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            // Feed the tool results back as a user message for the next turn.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Only reachable when the loop broke on the depth limit.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}