pub mod streaming;
pub use streaming::StreamingPromptHook;
use std::{
future::IntoFuture,
marker::PhantomData,
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
};
use tracing::{Instrument, span::Id};
use futures::{StreamExt, stream};
use tracing::info_span;
use crate::{
OneOrMany,
completion::{Completion, CompletionModel, Message, PromptError, Usage},
json_utils,
message::{AssistantContent, UserContent},
tool::ToolSetError,
wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};
use super::Agent;
/// Marker trait for the type-state parameter `S` of [`PromptRequest`].
///
/// The marker chosen decides which `IntoFuture` implementation applies and
/// therefore what awaiting the request resolves to.
pub trait PromptType {}

/// Type-state marker: awaiting the request yields `Result<String, PromptError>`.
pub struct Standard;
impl PromptType for Standard {}

/// Type-state marker: awaiting the request yields `Result<PromptResponse, PromptError>`.
pub struct Extended;
impl PromptType for Extended {}
/// A builder describing one prompt to an [`Agent`]; awaiting it (via `IntoFuture`) runs the request.
///
/// Type parameters:
/// - `S`: type-state marker ([`Standard`] or [`Extended`]) selecting the output type.
/// - `M`: the completion model used by the agent.
/// - `P`: the [`PromptHook`] implementation (`()` when no hook is configured).
pub struct PromptRequest<'a, S, M, P>
where
S: PromptType,
M: CompletionModel,
P: PromptHook<M>,
{
/// The prompt message to send.
prompt: Message,
/// Optional caller-owned history. When set, `send` appends the prompt and every
/// message produced during the request to it.
chat_history: Option<&'a mut Vec<Message>>,
/// Depth budget for the tool-calling loop: `send` fails with `MaxDepthError` once its
/// iteration counter exceeds `max_depth + 1`.
max_depth: usize,
/// The agent that performs completions and tool calls.
agent: &'a Agent<M>,
/// Zero-sized carrier for the type-state marker `S`.
state: PhantomData<S>,
/// Optional hook whose callbacks are awaited around completion and tool calls.
hook: Option<P>,
/// Limit passed to `buffer_unordered` when executing the tool calls of one turn.
concurrency: usize,
}
impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Creates a request that sends `prompt` to `agent`.
    ///
    /// Defaults: no chat history, no hook, a `max_depth` of 0, a tool
    /// concurrency of 1, and the [`Standard`] output mode.
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        let prompt = prompt.into();

        Self {
            agent,
            prompt,
            hook: None,
            chat_history: None,
            max_depth: 0,
            concurrency: 1,
            state: PhantomData,
        }
    }
}
impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switches the request to the [`Extended`] output mode, so awaiting it
    /// yields a `PromptResponse` instead of only the output text.
    ///
    /// All other settings are carried over unchanged.
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Sets the depth budget (`max_depth`) used by the tool-calling loop.
    pub fn multi_turn(mut self, depth: usize) -> PromptRequest<'a, S, M, P> {
        self.max_depth = depth;
        self
    }

    /// Sets how many tool calls of a single turn may be executed concurrently.
    ///
    /// Values below 1 are raised to 1. The value is handed to
    /// `buffer_unordered`, and a limit of 0 would leave it with no capacity to
    /// start any tool call, so a request with tool calls would never complete.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency.max(1);
        self
    }

    /// Attaches a caller-owned chat history, replacing any previously attached one.
    ///
    /// The prompt and every message produced while running the request are
    /// appended to `history`.
    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        self.chat_history = Some(history);
        self
    }

    /// Attaches `hook`, replacing any previously configured hook (and hook type).
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
        }
    }
}
/// A cancellation flag handed to hook callbacks.
///
/// Clones share the same underlying flag, so cancelling through any clone is
/// visible through all of them.
#[derive(Clone)]
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    /// Creates a signal in the not-cancelled state.
    fn new() -> Self {
        let flag = AtomicBool::new(false);
        Self(Arc::new(flag))
    }

    /// Marks the signal (and every clone of it) as cancelled.
    pub fn cancel(&self) {
        let Self(flag) = self;
        flag.store(true, Ordering::SeqCst);
    }

    /// Reports whether `cancel` has been called on this signal or any clone.
    fn is_cancelled(&self) -> bool {
        let Self(flag) = self;
        flag.load(Ordering::SeqCst)
    }
}
/// Callbacks awaited by `PromptRequest::send` at fixed points of the agent loop.
///
/// Every method has a default no-op body, so implementors override only what they need.
/// Each callback receives a [`CancelSignal`]; `send` checks it right after the callback
/// returns and aborts the request with a prompt-cancelled error if it was cancelled.
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
M: CompletionModel,
{
/// Awaited before each completion request.
///
/// `prompt` is the message about to be sent and `history` the messages preceding it.
// The default body ignores its arguments, hence the lint allowance.
#[allow(unused_variables)]
fn on_completion_call(
&self,
prompt: &Message,
history: &[Message],
cancel_sig: CancelSignal,
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}
/// Awaited after a completion response has been received, before its content is processed.
///
/// `prompt` is the message that was sent and `response` the model's response.
// The default body ignores its arguments, hence the lint allowance.
#[allow(unused_variables)]
fn on_completion_response(
&self,
prompt: &Message,
response: &crate::completion::CompletionResponse<M::Response>,
cancel_sig: CancelSignal,
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}
/// Awaited before a tool is invoked.
///
/// `tool_name` is the requested function name, `tool_call_id` the call id carried by the
/// tool call (if any), and `args` the call arguments rendered as a JSON string.
// The default body ignores its arguments, hence the lint allowance.
#[allow(unused_variables)]
fn on_tool_call(
&self,
tool_name: &str,
tool_call_id: Option<String>,
args: &str,
cancel_sig: CancelSignal,
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}
/// Awaited after a tool invocation has finished.
///
/// Receives the same `tool_name`, `tool_call_id` and `args` as `on_tool_call`, plus
/// `result`: the tool's output, or the error's text when the tool call failed.
// The default body ignores its arguments, hence the lint allowance.
#[allow(unused_variables)]
fn on_tool_result(
&self,
tool_name: &str,
tool_call_id: Option<String>,
args: &str,
result: &str,
cancel_sig: CancelSignal,
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}
}
/// The unit type is a hook that does nothing; it is the hook type of requests built with
/// `PromptRequest::new`.
impl<M> PromptHook<M> for () where M: CompletionModel {}
/// Lets a [`Standard`] request be `.await`ed directly, resolving to the output text.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    /// Boxes and pins the future returned by `send`.
    fn into_future(self) -> Self::IntoFuture {
        let pending = self.send();
        Box::pin(pending)
    }
}
/// Lets an [`Extended`] request be `.await`ed directly, resolving to a [`PromptResponse`].
impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    /// Boxes and pins the future returned by `send`.
    fn into_future(self) -> Self::IntoFuture {
        let pending = self.send();
        Box::pin(pending)
    }
}
impl<M, P> PromptRequest<'_, Standard, M, P>
where
M: CompletionModel,
P: PromptHook<M>,
{
async fn send(self) -> Result<String, PromptError> {
self.extended_details().send().await.map(|resp| resp.output)
}
}
/// The result of an extended prompt request.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The text produced by the final assistant message.
    pub output: String,
    /// Token usage accumulated over the completion calls made for the request.
    pub total_usage: Usage,
}

impl PromptResponse {
    /// Builds a response from anything convertible into a `String` plus the usage total.
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        let output = output.into();
        Self {
            output,
            total_usage,
        }
    }
}
impl<M, P> PromptRequest<'_, Extended, M, P>
where
M: CompletionModel,
P: PromptHook<M>,
{
/// Runs the request: alternates completion calls and tool executions until the model
/// answers without requesting a tool, a hook cancels, or the depth budget is exhausted.
///
/// Each loop iteration sends the last message of the chat history as the prompt (with the
/// earlier messages as history) and appends the assistant's reply to the history. If the
/// reply contains no tool calls, its text parts are joined with `"\n"` and returned together
/// with the accumulated usage. Otherwise the requested tools are executed and their results
/// are appended as one user message, which becomes the prompt of the next iteration.
///
/// # Errors
/// - Errors from building or sending the completion request are propagated.
/// - A prompt-cancelled error (carrying a copy of the chat history) is returned when a hook
///   has cancelled the `CancelSignal` it was given.
/// - `PromptError::MaxDepthError` is returned once the iteration counter exceeds
///   `max_depth + 1`.
///
/// # Panics
/// The `expect` calls guard invariants of this function (the history is never empty and a
/// tool turn has at least one tool call); they are not expected to fire.
async fn send(self) -> Result<PromptResponse, PromptError> {
// Reuse the current span unless it is disabled, in which case open an `invoke_agent` span.
let agent_span = if tracing::Span::current().is_disabled() {
info_span!(
"invoke_agent",
gen_ai.operation.name = "invoke_agent",
gen_ai.agent.name = self.agent.name(),
gen_ai.system_instructions = self.agent.preamble,
gen_ai.prompt = tracing::field::Empty,
gen_ai.completion = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
)
} else {
tracing::Span::current()
};
let agent = self.agent;
// Append the prompt to the caller's history, or fall back to a temporary history that
// holds only the prompt when none was supplied.
let chat_history = if let Some(history) = self.chat_history {
history.push(self.prompt.to_owned());
history
} else {
&mut vec![self.prompt.to_owned()]
};
// Record the prompt text on the agent span when `rag_text()` provides one.
if let Some(text) = self.prompt.rag_text() {
agent_span.record("gen_ai.prompt", text);
}
// One signal shared (via clones) with every hook callback of this request.
let cancel_sig = CancelSignal::new();
let mut current_max_depth = 0;
let mut usage = Usage::new();
// Id of the most recently created chat/tool span (0 = none yet); each new span is
// linked to it with `follows_from`.
let current_span_id: AtomicU64 = AtomicU64::new(0);
let last_prompt = loop {
// The most recent message (the user prompt or the latest tool results) is this
// turn's prompt.
let prompt = chat_history
.last()
.cloned()
.expect("there should always be at least one message in the chat history");
// Depth budget exhausted: leave the loop and report `MaxDepthError` below.
if current_max_depth > self.max_depth + 1 {
break prompt;
}
current_max_depth += 1;
// Progress is only logged for requests configured with a depth above 1.
if self.max_depth > 1 {
tracing::info!(
"Current conversation depth: {}/{}",
current_max_depth,
self.max_depth
);
}
// Hook: before the completion call. The history slice excludes the prompt itself.
if let Some(ref hook) = self.hook {
hook.on_completion_call(
&prompt,
&chat_history[..chat_history.len() - 1],
cancel_sig.clone(),
)
.await;
if cancel_sig.is_cancelled() {
return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
}
}
let span = tracing::Span::current();
let chat_span = info_span!(
target: "rig::agent_chat",
parent: &span,
"chat",
gen_ai.operation.name = "chat",
gen_ai.system_instructions = self.agent.preamble,
gen_ai.provider.name = tracing::field::Empty,
gen_ai.request.model = tracing::field::Empty,
gen_ai.response.id = tracing::field::Empty,
gen_ai.response.model = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.input.messages = tracing::field::Empty,
gen_ai.output.messages = tracing::field::Empty,
);
// Link this chat span to the previously created span, if there is one.
let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
chat_span.follows_from(id).to_owned()
} else {
chat_span
};
// Remember this span's id (when it has one) for the next span to follow from.
if let Some(id) = chat_span.id() {
current_span_id.store(id.into_u64(), Ordering::SeqCst);
};
// Build the completion request from the prompt plus the history before it, then
// send it inside the chat span.
let resp = agent
.completion(
prompt.clone(),
chat_history[..chat_history.len() - 1].to_vec(),
)
.await?
.send()
.instrument(chat_span.clone())
.await?;
usage += resp.usage;
// Hook: after the completion response.
if let Some(ref hook) = self.hook {
hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
.await;
if cancel_sig.is_cancelled() {
return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
}
}
// Split the response content into tool calls and everything else.
let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
.choice
.iter()
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
// The assistant's full reply is recorded in the history in either case.
chat_history.push(Message::Assistant {
id: None,
content: resp.choice.clone(),
});
// No tool calls: this is the final answer. Non-text content is dropped and the text
// parts are joined with newlines.
if tool_calls.is_empty() {
let merged_texts = texts
.into_iter()
.filter_map(|content| {
if let AssistantContent::Text(text) = content {
Some(text.text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
if self.max_depth > 1 {
tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
}
agent_span.record("gen_ai.completion", &merged_texts);
agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
return Ok(PromptResponse::new(merged_texts, usage));
}
let hook = self.hook.clone();
let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
// Turn every tool call into a future and run them through `buffer_unordered` with
// `self.concurrency` as the limit.
let tool_content = stream::iter(tool_calls)
.map(|choice| {
// Separate clones for the pre-call and post-call hook invocations, moved into
// the async block below.
let hook1 = hook.clone();
let hook2 = hook.clone();
let cancel_sig1 = cancel_sig.clone();
let cancel_sig2 = cancel_sig.clone();
let tool_span = info_span!(
"execute_tool",
gen_ai.operation.name = "execute_tool",
gen_ai.tool.type = "function",
gen_ai.tool.name = tracing::field::Empty,
gen_ai.tool.call.id = tracing::field::Empty,
gen_ai.tool.call.arguments = tracing::field::Empty,
gen_ai.tool.call.result = tracing::field::Empty
);
// Same span chaining as for the chat span above.
let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
tool_span.follows_from(id).to_owned()
} else {
tool_span
};
if let Some(id) = tool_span.id() {
current_span_id.store(id.into_u64(), Ordering::SeqCst);
};
async move {
if let AssistantContent::ToolCall(tool_call) = choice {
let tool_name = &tool_call.function.name;
let args =
json_utils::value_to_json_string(&tool_call.function.arguments);
// This future is instrumented with the tool span (see `.instrument` below),
// so the current span is where the call details are recorded.
let tool_span = tracing::Span::current();
tool_span.record("gen_ai.tool.name", tool_name);
tool_span.record("gen_ai.tool.call.id", &tool_call.id);
tool_span.record("gen_ai.tool.call.arguments", &args);
// Hook: before the tool call; cancellation interrupts this tool future.
if let Some(hook) = hook1 {
hook.on_tool_call(
tool_name,
tool_call.call_id.clone(),
&args,
cancel_sig1.clone(),
)
.await;
if cancel_sig1.is_cancelled() {
return Err(ToolSetError::Interrupted);
}
}
// A failing tool does not abort the request: the error is logged and its
// text is used as the tool output.
let output =
match agent.tool_server_handle.call_tool(tool_name, &args).await {
Ok(res) => res,
Err(e) => {
tracing::warn!("Error while executing tool: {e}");
e.to_string()
}
};
// Hook: after the tool call, with the tool output.
if let Some(hook) = hook2 {
hook.on_tool_result(
tool_name,
tool_call.call_id.clone(),
&args,
&output.to_string(),
cancel_sig2.clone(),
)
.await;
if cancel_sig2.is_cancelled() {
return Err(ToolSetError::Interrupted);
}
}
tool_span.record("gen_ai.tool.call.result", &output);
tracing::info!(
"executed tool {tool_name} with args {args}. result: {output}"
);
// Attach the call id to the tool result when the tool call carried one.
if let Some(call_id) = tool_call.call_id.clone() {
Ok(UserContent::tool_result_with_call_id(
tool_call.id.clone(),
call_id,
OneOrMany::one(output.into()),
))
} else {
Ok(UserContent::tool_result(
tool_call.id.clone(),
OneOrMany::one(output.into()),
))
}
} else {
unreachable!(
"This should never happen as we already filtered for `ToolCall`"
)
}
}
.instrument(tool_span)
})
.buffer_unordered(self.concurrency)
.collect::<Vec<Result<UserContent, ToolSetError>>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.map_err(|e| {
// An interrupted tool future means a hook cancelled the request.
if matches!(e, ToolSetError::Interrupted) {
PromptError::prompt_cancelled(chat_history.to_vec())
} else {
e.into()
}
})?;
// The tool results form a single user message, which the next iteration uses as its
// prompt.
chat_history.push(Message::User {
content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
});
};
// Only reached through the `break` above, i.e. when the depth budget ran out.
Err(PromptError::MaxDepthError {
max_depth: self.max_depth,
chat_history: Box::new(chat_history.clone()),
prompt: Box::new(last_prompt),
})
}
}