pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};
use tracing::{Instrument, span::Id};

use futures::{StreamExt, stream};
use tracing::info_span;

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

/// Marker trait for the type-state of a [`PromptRequest`].
pub trait PromptType {}
/// Type-state marker: awaiting the request yields a plain `String`.
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`] with usage details.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

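/// A fluent request for prompting an [`Agent`]: configure it with the builder
/// methods below, then `.await` it (via `IntoFuture`) to run the multi-turn loop.
///
/// Illustrative sketch, not taken from the original source; it assumes `agent`
/// is an `Agent<M>` built elsewhere, `my_hook` implements [`PromptHook`], and the
/// surrounding function is `async` and returns a `Result` whose error covers `PromptError`:
///
/// ```rust,ignore
/// let mut history: Vec<Message> = Vec::new();
///
/// // A Standard request resolves to `Result<String, PromptError>`.
/// let answer: String = PromptRequest::new(&agent, "Summarize the report")
///     .with_history(&mut history) // reuse and extend an existing conversation
///     .multi_turn(5)              // allow up to 5 tool-calling rounds
///     .with_hook(my_hook)         // observe completion and tool calls
///     .await?;
/// ```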
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional externally-owned chat history; the prompt and all intermediate
    /// messages are appended to it while the request runs.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of multi-turn (tool-calling) rounds allowed.
    max_depth: usize,
    /// The agent that serves the request.
    agent: &'a Agent<M>,
    /// Type-state marker selecting the output type ([`Standard`] or [`Extended`]).
    state: PhantomData<S>,
    /// Optional hook invoked around completion requests and tool calls.
    hook: Option<P>,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new prompt request for `agent` with default settings and no hook.
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
            hook: None,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switch to the [`Extended`] state so that awaiting the request yields a
    /// [`PromptResponse`] (output text plus aggregated token usage) instead of a `String`.
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Set the maximum number of multi-turn (tool-calling) rounds.
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Attach an external chat history; the prompt and every intermediate message
    /// are appended to it as the request runs.
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Attach a [`PromptHook`] that is invoked around completion requests and tool calls.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
        }
    }
}

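/// A shared, cloneable flag that lets a [`PromptHook`] abort an in-flight prompt
/// request; once set, the prompt loop returns a cancellation error.
///
/// Illustrative sketch, not from the original source: a hypothetical hook method
/// that vetoes a dangerous tool before it runs.
///
/// ```rust,ignore
/// fn on_tool_call(
///     &self,
///     tool_name: &str,
///     _args: &str,
///     cancel_sig: CancelSignal,
/// ) -> impl Future<Output = ()> + WasmCompatSend {
///     let abort = tool_name == "delete_everything"; // hypothetical tool name
///     async move {
///         if abort {
///             // The prompt loop checks the flag after the hook runs and bails out.
///             cancel_sig.cancel();
///         }
///     }
/// }
/// ```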
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Request cancellation of the prompt request this signal belongs to.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

impl Clone for CancelSignal {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

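/// Hook into the prompt loop: the methods below are called before and after each
/// completion request and each tool invocation, and all of them default to no-ops.
///
/// Illustrative sketch, not from the original source: a minimal hook (hypothetical
/// `LoggingHook`) that only overrides the tool-related callbacks.
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     fn on_tool_call(
///         &self,
///         tool_name: &str,
///         args: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + WasmCompatSend {
///         let msg = format!("calling tool `{tool_name}` with args {args}");
///         async move { tracing::info!("{msg}") }
///     }
///
///     fn on_tool_result(
///         &self,
///         tool_name: &str,
///         _args: &str,
///         result: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + WasmCompatSend {
///         let msg = format!("tool `{tool_name}` returned: {result}");
///         async move { tracing::info!("{msg}") }
///     }
/// }
/// ```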
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before a completion request is sent to the model.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a completion response has been received from the model.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool has returned a result (or an error string).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

/// The unit type is the no-op hook used when no hook is attached.
impl<M> PromptHook<M> for () where M: CompletionModel {}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

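/// The result of an [`Extended`] prompt request: the final output text plus the
/// token usage accumulated over every turn of the loop.
///
/// Illustrative sketch, not from the original source; it assumes `agent` is an
/// `Agent<M>` built elsewhere and the surrounding function is `async`:
///
/// ```rust,ignore
/// // The same builder as a Standard request, but `extended_details()` switches
/// // the output type from `String` to `PromptResponse`.
/// let resp: PromptResponse = PromptRequest::new(&agent, "Summarize the report")
///     .multi_turn(3)
///     .extended_details()
///     .await?;
/// println!("{} ({} input tokens)", resp.output, resp.total_usage.input_tokens);
/// ```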
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The final text produced by the model.
    pub output: String,
    /// Token usage summed across every completion request made while resolving the prompt.
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

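// Outline of the multi-turn loop below: send the chat history to the model, add the
// returned usage to the running total, and split the response into text and tool
// calls. With no tool calls, the merged text is the final answer; otherwise each
// tool is executed (running the hooks and honoring the cancel signal), the results
// are pushed onto the history as a user message, and the loop repeats until the
// model stops calling tools or `max_depth` is exceeded.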
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls: the model has produced its final answer, so merge the
            // text parts, record the span fields, and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args, cancel_sig1.clone())
                                    .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // The loop ran out of turns without producing a final text answer.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}