pub mod hooks;
pub mod streaming;
use super::{
Agent,
completion::{DynamicContextStore, build_completion_request},
};
use crate::{
OneOrMany,
completion::{CompletionModel, Document, Message, PromptError, Usage},
json_utils,
message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
tool::server::ToolServerHandle,
wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{StreamExt, stream};
use hooks::{HookAction, PromptHook, ToolCallHookAction};
use std::{
future::IntoFuture,
marker::PhantomData,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
use tracing::info_span;
use tracing::{Instrument, span::Id};
/// Typestate marker trait for [`PromptRequest`].
///
/// Implemented only by [`Standard`] and [`Extended`]; the chosen state
/// decides what the request resolves to when awaited (plain output text vs.
/// a full [`PromptResponse`]).
pub trait PromptType {}
/// Typestate marker: awaiting the request yields only the final output `String`.
pub struct Standard;
/// Typestate marker: awaiting the request yields a [`PromptResponse`] that
/// also carries token usage and the messages produced during the run.
pub struct Extended;
impl PromptType for Standard {}
impl PromptType for Extended {}
/// Builder for a single (possibly multi-turn) agent prompt round-trip.
///
/// Created via [`PromptRequest::from_agent`], configured with the builder
/// methods, and driven to completion by awaiting it (`IntoFuture`). The `S`
/// typestate selects the result shape: [`Standard`] resolves to the final
/// output `String`, [`Extended`] to a [`PromptResponse`].
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The user prompt that starts this round-trip.
    prompt: Message,
    // Pre-existing conversation history, prepended before any new messages.
    chat_history: Option<Vec<Message>>,
    // Upper bound on tool-calling turns before `MaxTurnsError` is returned.
    max_turns: usize,
    // The completion model that executes each turn.
    model: Arc<M>,
    // Agent display name, used in tracing spans (falls back to a placeholder).
    agent_name: Option<String>,
    // System instructions sent with every completion request.
    preamble: Option<String>,
    // Documents always attached to the completion request.
    static_context: Vec<Document>,
    // Optional sampling temperature forwarded to the model.
    temperature: Option<f64>,
    // Optional completion token cap forwarded to the model.
    max_tokens: Option<u64>,
    // Provider-specific extra parameters, merged into the request.
    additional_params: Option<serde_json::Value>,
    // Handle used to resolve and execute tool calls.
    tool_server_handle: ToolServerHandle,
    // Store of dynamically retrieved context for each request.
    dynamic_context: DynamicContextStore,
    // Optional constraint on how the model may choose tools.
    tool_choice: Option<ToolChoice>,
    // Zero-sized typestate carrier for `S`.
    state: PhantomData<S>,
    // Optional hook invoked around completion and tool calls.
    hook: Option<P>,
    // How many tool calls may execute concurrently within one turn.
    concurrency: usize,
    // JSON schema for structured output, when requested.
    output_schema: Option<schemars::Schema>,
}
impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds a standard prompt request from an agent's configuration.
    ///
    /// Every completion setting (preamble, context, sampling parameters,
    /// tool handle, hook, output schema) is copied from `agent`. Tool
    /// concurrency defaults to 1 and `max_turns` to the agent's default
    /// (or 0 when unset).
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            state: PhantomData,
            concurrency: 1,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            hook: agent.hook.clone(),
            output_schema: agent.output_schema.clone(),
        }
    }
}
impl<S, M, P> PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Converts this request into the [`Extended`] typestate, so awaiting it
    /// yields a full [`PromptResponse`] instead of just the output text.
    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
        let Self {
            prompt,
            chat_history,
            max_turns,
            model,
            agent_name,
            preamble,
            static_context,
            temperature,
            max_tokens,
            additional_params,
            tool_server_handle,
            dynamic_context,
            tool_choice,
            state: _,
            hook,
            concurrency,
            output_schema,
        } = self;
        PromptRequest {
            prompt,
            chat_history,
            max_turns,
            model,
            agent_name,
            preamble,
            static_context,
            temperature,
            max_tokens,
            additional_params,
            tool_server_handle,
            dynamic_context,
            tool_choice,
            state: PhantomData,
            hook,
            concurrency,
            output_schema,
        }
    }
    /// Sets the maximum number of tool-calling turns.
    pub fn max_turns(self, depth: usize) -> Self {
        Self {
            max_turns: depth,
            ..self
        }
    }
    /// Sets how many tool calls may run concurrently within a single turn.
    pub fn with_tool_concurrency(self, concurrency: usize) -> Self {
        Self {
            concurrency,
            ..self
        }
    }
    /// Replaces the chat history with `history`, converting each item into a
    /// [`Message`].
    pub fn with_history<I, T>(self, history: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        let collected: Vec<Message> = history.into_iter().map(Into::into).collect();
        Self {
            chat_history: Some(collected),
            ..self
        }
    }
    /// Installs `hook`, to be invoked around completion and tool calls.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
    where
        P2: PromptHook<M>,
    {
        let Self {
            prompt,
            chat_history,
            max_turns,
            model,
            agent_name,
            preamble,
            static_context,
            temperature,
            max_tokens,
            additional_params,
            tool_server_handle,
            dynamic_context,
            tool_choice,
            state,
            hook: _,
            concurrency,
            output_schema,
        } = self;
        PromptRequest {
            prompt,
            chat_history,
            max_turns,
            model,
            agent_name,
            preamble,
            static_context,
            temperature,
            max_tokens,
            additional_params,
            tool_server_handle,
            dynamic_context,
            tool_choice,
            state,
            hook: Some(hook),
            concurrency,
            output_schema,
        }
    }
}
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a standard request yields only the final output text.
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}
impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting an extended request yields a full [`PromptResponse`].
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}
impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Runs the request and returns only the final output text, discarding
    /// usage and message details carried by the extended response.
    async fn send(self) -> Result<String, PromptError> {
        let response = self.extended_details().send().await?;
        Ok(response.output)
    }
}
/// Result of an [`Extended`] prompt request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The model's final text output.
    pub output: String,
    /// Token usage accumulated across every completion call in the run.
    pub usage: Usage,
    /// Messages produced during the run (prompt, assistant replies, tool
    /// results), when attached.
    pub messages: Option<Vec<Message>>,
}
impl std::fmt::Display for PromptResponse {
    /// Displays only the output text, delegating to `str`'s `Display` so the
    /// caller's formatter flags (width, precision, alignment) still apply.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.output, f)
    }
}
impl PromptResponse {
pub fn new(output: impl Into<String>, usage: Usage) -> Self {
Self {
output: output.into(),
usage,
messages: None,
}
}
pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
self.messages = Some(messages);
self
}
}
/// Result of an extended typed prompt request: the deserialized value plus
/// token usage.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The model output deserialized into `T`.
    pub output: T,
    /// Token usage accumulated across the request.
    pub usage: Usage,
}
impl<T> TypedPromptResponse<T> {
pub fn new(output: T, usage: Usage) -> Self {
Self { output, usage }
}
}
/// Fallback display name used in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
/// Builds the message history for a request: the (optional) pre-existing
/// `chat_history` followed by `new_messages`.
///
/// Generalized over any cloneable item (in this module it is used with
/// `Message`); the result is preallocated to the exact final length.
fn build_history_for_request<T: Clone>(
    chat_history: Option<&[T]>,
    new_messages: &[T],
) -> Vec<T> {
    let existing = chat_history.unwrap_or_default();
    let mut history = Vec::with_capacity(existing.len() + new_messages.len());
    history.extend_from_slice(existing);
    history.extend_from_slice(new_messages);
    history
}
/// Builds the complete conversation history: the (optional) pre-existing
/// `chat_history` followed by every owned message in `new_messages`.
///
/// Generalized over any cloneable item (in this module it is used with
/// `Message`); the result is preallocated to the exact final length and the
/// owned `new_messages` are moved in rather than cloned.
fn build_full_history<T: Clone>(
    chat_history: Option<&[T]>,
    new_messages: Vec<T>,
) -> Vec<T> {
    let existing = chat_history.unwrap_or_default();
    let mut history = Vec::with_capacity(existing.len() + new_messages.len());
    history.extend_from_slice(existing);
    history.extend(new_messages);
    history
}
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name reported in tracing spans, falling back to a placeholder for
    /// unnamed agents.
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }
    /// Drives the multi-turn prompt loop to completion.
    ///
    /// Each turn sends the most recent message (plus accumulated history) to
    /// the model. If the reply contains tool calls, they are executed (up to
    /// `self.concurrency` at a time) and their results become the next user
    /// message; if the reply is plain text, the loop ends and the merged text
    /// is returned as a [`PromptResponse`]. When the turn budget runs out,
    /// [`PromptError::MaxTurnsError`] is returned instead.
    ///
    /// If a hook is installed, it is consulted before/after each completion
    /// call and around each tool call; any `Terminate` action aborts the run
    /// with [`PromptError::prompt_cancelled`].
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's span when tracing is already active; otherwise
        // open a fresh `invoke_agent` span covering the whole loop. `Empty`
        // fields are recorded later, once their values are known.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        // Record the prompt text on the span when one can be extracted.
        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }
        let agent_name_for_span = self.agent_name.clone();
        let chat_history = self.chat_history;
        // Messages produced during this run: starts with the prompt, then
        // alternates assistant replies and tool-result user messages.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];
        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Id of the most recent chat/tool span, used to chain successive
        // spans via `follows_from`. Atomic because it is also written from
        // the concurrent tool-execution closures below.
        let current_span_id: AtomicU64 = AtomicU64::new(0);
        let last_prompt = loop {
            // The message to send this turn: the latest entry (the original
            // prompt on the first turn, tool results afterwards).
            let prompt = new_messages
                .last()
                .expect("there should always be at least one message")
                .clone();
            // NOTE(review): the bound is `max_turns + 1`, which allows up to
            // `max_turns + 2` completion calls before breaking — confirm this
            // off-by-one is intended.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }
            current_max_turns += 1;
            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }
            // History given to the pre-completion hook: prior chat history
            // plus everything produced so far except the message being sent.
            let history_for_hook = build_history_for_request(
                chat_history.as_deref(),
                &new_messages[..new_messages.len().saturating_sub(1)],
            );
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }
            let span = tracing::Span::current();
            // Per-turn span for the model call itself.
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );
            // Link this turn's span to the previous one, if there was one.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };
            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };
            // NOTE(review): identical to `history_for_hook` above; the two
            // computations could be shared.
            let history_for_request = build_history_for_request(
                chat_history.as_deref(),
                &new_messages[..new_messages.len().saturating_sub(1)],
            );
            // Build and send the completion request, instrumented with the
            // per-turn span.
            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;
            // Accumulate token usage across every turn.
            usage += resp.usage;
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }
            // Split the model's reply into tool calls and plain-text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
            new_messages.push(Message::Assistant {
                id: resp.message_id.clone(),
                content: resp.choice.clone(),
            });
            // No tool calls means the model produced its final answer: merge
            // the text parts, record telemetry on the agent span, and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }
                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );
                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }
            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();
            // Snapshot of the full history, attached to any error raised from
            // inside the tool-execution closures.
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            // Execute the requested tools, at most `self.concurrency` at a
            // time; `buffer_unordered` yields results in completion order.
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();
                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );
                    // Chain tool spans after the previous span, as above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };
                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };
                    let cloned_history_for_error = full_history_for_errors.clone();
                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Pre-call hook: may cancel the whole prompt run
                            // or skip just this tool call.
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;
                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }
                                // A skipped call still yields a tool result
                                // (the skip reason), so the model sees a
                                // reply for every call it issued.
                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool failures are not fatal: the error text is
                            // passed back to the model as the tool's output.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider's call id when present.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;
            // Feed all tool results back to the model as one user message.
            // `many` cannot fail here: this branch is only reached when at
            // least one tool call was present.
            new_messages.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };
        // Turn budget exhausted without a final text answer.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
use crate::completion::StructuredOutputError;
use schemars::{JsonSchema, schema_for};
use serde::de::DeserializeOwned;
/// A [`PromptRequest`] whose final output is deserialized into `T`.
///
/// Constructed via [`TypedPromptRequest::from_agent`], which registers the
/// JSON schema of `T` as the underlying request's output schema so the model
/// is asked for matching structured output.
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The untyped request that performs the actual round-trip.
    inner: PromptRequest<S, M, P>,
    // Carries `T` in the type without storing a value.
    _phantom: std::marker::PhantomData<T>,
}
impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
where
T: JsonSchema + DeserializeOwned + WasmCompatSend,
M: CompletionModel,
P: PromptHook<M>,
{
pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
let mut inner = PromptRequest::from_agent(agent, prompt);
inner.output_schema = Some(schema_for!(T));
Self {
inner,
_phantom: std::marker::PhantomData,
}
}
}
impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switches to the [`Extended`] typestate, so awaiting yields a
    /// [`TypedPromptResponse`] carrying usage alongside the parsed value.
    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
        let TypedPromptRequest { inner, _phantom } = self;
        TypedPromptRequest {
            inner: inner.extended_details(),
            _phantom,
        }
    }
    /// Sets the maximum number of tool-calling turns.
    pub fn max_turns(self, depth: usize) -> Self {
        Self {
            inner: self.inner.max_turns(depth),
            _phantom: std::marker::PhantomData,
        }
    }
    /// Sets how many tool calls may run concurrently within a single turn.
    pub fn with_tool_concurrency(self, concurrency: usize) -> Self {
        Self {
            inner: self.inner.with_tool_concurrency(concurrency),
            _phantom: std::marker::PhantomData,
        }
    }
    /// Replaces the chat history with `history`, converting each item into a
    /// [`Message`].
    pub fn with_history<I, H>(self, history: I) -> Self
    where
        I: IntoIterator<Item = H>,
        H: Into<Message>,
    {
        Self {
            inner: self.inner.with_history(history),
            _phantom: std::marker::PhantomData,
        }
    }
    /// Installs `hook`, to be invoked around completion and tool calls.
    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        TypedPromptRequest {
            inner: self.inner.with_hook(hook),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Runs the underlying request and deserializes its output text into `T`.
    ///
    /// Returns [`StructuredOutputError::EmptyResponse`] when the model
    /// produced no text, and a deserialization error when the text does not
    /// match `T`.
    async fn send(self) -> Result<T, StructuredOutputError> {
        let raw = self.inner.send().await.map_err(Box::new)?;
        if raw.is_empty() {
            Err(StructuredOutputError::EmptyResponse)
        } else {
            Ok(serde_json::from_str::<T>(&raw)?)
        }
    }
}
impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
where
T: JsonSchema + DeserializeOwned + WasmCompatSend,
M: CompletionModel,
P: PromptHook<M>,
{
async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
let response = self.inner.send().await.map_err(Box::new)?;
if response.output.is_empty() {
return Err(StructuredOutputError::EmptyResponse);
}
let parsed: T = serde_json::from_str(&response.output)?;
Ok(TypedPromptResponse::new(parsed, response.usage))
}
}
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a standard typed request yields the parsed value only.
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting an extended typed request yields the parsed value plus usage.
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}