use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
use crate::{
agent::prompt_request::streaming::StreamingPromptRequest,
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
},
message::ToolChoice,
streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
tool::server::ToolServerHandle,
vector_store::{VectorStoreError, request::VectorSearchRequest},
wasm_compat::WasmCompatSend,
};
use futures::{StreamExt, TryStreamExt, stream};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock as TokioRwLock;
/// Fallback name returned by `Agent::name` when the agent has no `name` configured.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
/// Shared, async-lockable list of vector-store indices used for dynamic (RAG) context.
///
/// Each entry pairs the number of samples to request from the index (`usize`) with the
/// index itself. When a request has RAG text, every index is queried for that many top
/// results and the hits are attached to the request as documents.
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            usize,
            Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
/// Assembles a [`CompletionRequestBuilder`] for `model` from an agent's configuration.
///
/// The returned builder carries:
/// - `chat_history`, prefixed with `preamble` as a system message when one is given;
/// - `temperature`, `max_tokens`, `additional_params` and `output_schema` (each optional),
///   plus `tool_choice` when set;
/// - the `static_context` documents, followed by documents retrieved from every index in
///   `dynamic_context` when a RAG text is available;
/// - the tool definitions returned by `tool_server_handle`, queried with that same RAG text
///   (or with `None` when there is none).
///
/// The RAG text comes from `prompt`, falling back to the most recent message in
/// `chat_history` that provides one.
///
/// # Errors
///
/// Returns [`CompletionError::RequestError`] if a dynamic-context index query fails or if the
/// tool definitions cannot be retrieved.
// The parameters mirror the `Agent` fields one-to-one, hence the long argument list.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
    model: &Arc<M>,
    prompt: Message,
    chat_history: &[Message],
    preamble: Option<&str>,
    static_context: &[Document],
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<&serde_json::Value>,
    tool_choice: Option<&ToolChoice>,
    tool_server_handle: &ToolServerHandle,
    dynamic_context: &DynamicContextStore,
    output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
    // Prefer the prompt's own RAG text; otherwise use the latest history message that has one.
    let rag_text = prompt.rag_text().or_else(|| {
        chat_history
            .iter()
            .rev()
            .find_map(|message| message.rag_text())
    });

    let chat_history: Vec<Message> = if let Some(preamble) = preamble {
        std::iter::once(Message::system(preamble.to_owned()))
            .chain(chat_history.iter().cloned())
            .collect()
    } else {
        chat_history.to_vec()
    };

    let completion_request = model
        .completion_request(prompt)
        .messages(chat_history)
        .temperature_opt(temperature)
        .max_tokens_opt(max_tokens)
        .additional_params_opt(additional_params.cloned())
        .output_schema_opt(output_schema.cloned())
        .documents(static_context.to_vec());
    let completion_request = if let Some(tool_choice) = tool_choice {
        completion_request.tool_choice(tool_choice.clone())
    } else {
        completion_request
    };

    let completion_request = match &rag_text {
        Some(text) => {
            // Snapshot the registered indices (cheap: `usize` + `Arc` clones) so the read lock
            // is released before the vector-store queries run. Holding the guard across those
            // awaits would stall any writer -- and, because tokio's `RwLock` is fair, every
            // reader queued behind that writer -- for the whole retrieval.
            let indices = dynamic_context.read().await.clone();
            let fetched_context = stream::iter(indices.iter())
                .then(|(num_sample, index)| async {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(*num_sample as u64)
                        .build()
                        .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                    Ok::<_, VectorStoreError>(
                        index
                            .top_n(req)
                            .await?
                            .into_iter()
                            .map(|(_, id, doc)| {
                                // Pretty JSON for the model; fall back to the raw rendering.
                                let text = serde_json::to_string_pretty(&doc)
                                    .unwrap_or_else(|_| doc.to_string());
                                Document {
                                    id,
                                    text,
                                    additional_props: HashMap::new(),
                                }
                            })
                            .collect::<Vec<_>>(),
                    )
                })
                .try_fold(vec![], |mut acc, docs| async {
                    acc.extend(docs);
                    Ok(acc)
                })
                .await
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
            completion_request.documents(fetched_context)
        }
        None => completion_request,
    };

    // Tool definitions are requested with the same RAG text (or `None` when there is none);
    // this runs after the dynamic-context retrieval, as before.
    let tooldefs = tool_server_handle
        .get_tool_defs(rag_text.map(|text| text.to_string()))
        .await
        .map_err(|_| CompletionError::RequestError("Failed to get tool definitions".into()))?;

    Ok(completion_request.tools(tooldefs))
}
/// An agent: a completion model together with the configuration used to build every
/// completion request it sends (preamble, context documents, sampling settings, tools and
/// an optional output schema).
///
/// `P` is the prompt-hook type and defaults to `()`.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name; when `None`, `Agent::name` reports `"Unnamed Agent"`.
    pub name: Option<String>,
    /// Optional free-form description of the agent (not read in this module).
    pub description: Option<String>,
    /// The completion model requests are built against.
    pub model: Arc<M>,
    /// System prompt; when set it is prepended to the chat history as a system message.
    pub preamble: Option<String>,
    /// Documents attached to every completion request.
    pub static_context: Vec<Document>,
    /// Sampling temperature forwarded to the request builder, if set.
    pub temperature: Option<f64>,
    /// Maximum token count forwarded to the request builder, if set.
    pub max_tokens: Option<u64>,
    /// Extra provider parameters forwarded to the request builder, if set.
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch the tool definitions attached to each request.
    pub tool_server_handle: ToolServerHandle,
    /// Vector-store indices queried with the request's RAG text for per-request documents.
    pub dynamic_context: DynamicContextStore,
    /// Tool-choice setting forwarded to the request builder, if set.
    pub tool_choice: Option<ToolChoice>,
    /// NOTE(review): not read in this module; presumably the default turn limit for
    /// multi-turn prompt requests -- confirm in `prompt_request`.
    pub default_max_turns: Option<usize>,
    /// NOTE(review): not read in this module; presumably the hook invoked by prompt
    /// requests built from this agent -- confirm in `prompt_request`.
    pub hook: Option<P>,
    /// JSON schema for structured output, forwarded to the request builder, if set.
    pub output_schema: Option<schemars::Schema>,
}
impl<M, P> Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Returns the configured agent name, or `"Unnamed Agent"` when none was set.
    pub(crate) fn name(&self) -> &str {
        match &self.name {
            Some(name) => name.as_str(),
            None => UNKNOWN_AGENT_NAME,
        }
    }
}
impl<M, P> Completion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds a completion request for `prompt` on top of `chat_history`, using this agent's
    /// preamble, static and dynamic context, sampling settings, tools and output schema.
    async fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        // Borrow only the configuration fields the request builder needs.
        let Self {
            model,
            preamble,
            static_context,
            temperature,
            max_tokens,
            additional_params,
            tool_choice,
            tool_server_handle,
            dynamic_context,
            output_schema,
            ..
        } = self;

        // Normalise the caller's history into owned messages so it can be passed as a slice.
        let history = chat_history
            .into_iter()
            .map(Into::<Message>::into)
            .collect::<Vec<_>>();
        let prompt: Message = prompt.into();

        build_completion_request(
            model,
            prompt,
            &history,
            preamble.as_deref(),
            static_context,
            *temperature,
            *max_tokens,
            additional_params.as_ref(),
            tool_choice.as_ref(),
            tool_server_handle,
            dynamic_context,
            output_schema.as_ref(),
        )
        .await
    }
}
// The impl returns the concrete `PromptRequest` type, which is more specific than the
// trait's declared return type; `refining_impl_trait` would otherwise warn about that.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Creates a standard [`PromptRequest`] for `prompt` bound to this agent.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}
// Same as the owned impl, but for a borrowed agent. The concrete return type refines the
// trait's signature, hence the `refining_impl_trait` allowance.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Creates a standard [`PromptRequest`] for `prompt` bound to the borrowed agent.
    ///
    /// The tracing span records the agent name as `agent_name`; `prompt` is not recorded.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        // `self` is `&&Agent`; dereference once to hand `from_agent` the `&Agent`.
        PromptRequest::from_agent(*self, prompt)
    }
}
// The concrete `async fn` signature refines the trait's declared return type.
#[allow(refining_impl_trait)]
impl<M, P> Chat for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Sends `prompt` with `chat_history` through a standard prompt request and awaits its
    /// `String` result.
    ///
    /// The tracing span records the agent name as `agent_name`; the prompt and history are
    /// not recorded.
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<String, PromptError>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        PromptRequest::from_agent(self, prompt)
            .with_history(chat_history)
            .await
    }
}
impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds the request for a streaming completion.
    ///
    /// Request construction is identical to `Completion::completion`, to which this
    /// delegates directly.
    async fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>,
    {
        self.completion(prompt, chat_history).await
    }
}
// NOTE(review): the `GetTokenUsage` bound is presumably needed so streamed responses can
// report token usage -- confirm in `StreamingPromptRequest`.
impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;
    /// Creates a [`StreamingPromptRequest`] for `prompt` bound to this agent, with no history.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
    }
}
// NOTE(review): the `GetTokenUsage` bound is presumably needed so streamed responses can
// report token usage -- confirm in `StreamingPromptRequest`.
impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;
    /// Creates a [`StreamingPromptRequest`] for `prompt` bound to this agent, seeded with
    /// `chat_history`.
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, P>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
    }
}
use crate::agent::prompt_request::TypedPromptRequest;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
// `prompt_typed` returns the concrete `TypedPromptRequest` type, refining the trait's
// signature; hence the `refining_impl_trait` allowance.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
    /// Creates a [`TypedPromptRequest`] for `prompt` bound to this agent.
    ///
    /// NOTE(review): presumably the JSON schema of `T` drives structured output and the
    /// response is deserialized into `T` -- confirm in `TypedPromptRequest`.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}
// Same as the owned impl, but for a borrowed agent; the concrete return type refines the
// trait's signature, hence the `refining_impl_trait` allowance.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
    /// Creates a [`TypedPromptRequest`] for `prompt` bound to the borrowed agent.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        // `self` is `&&Agent`; dereference once to hand `from_agent` the `&Agent`.
        TypedPromptRequest::from_agent(*self, prompt)
    }
}