use super::prompt_request::{self, PromptRequest};
use crate::{
agent::prompt_request::streaming::StreamingPromptRequest,
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
GetTokenUsage, Message, Prompt, PromptError,
},
message::ToolChoice,
streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
tool::server::ToolServerHandle,
vector_store::{VectorStoreError, request::VectorSearchRequest},
wasm_compat::WasmCompatSend,
};
use futures::{StreamExt, TryStreamExt, stream};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
/// Fallback display name used when an [`Agent`] has no explicit `name` set.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
/// Shared, async-guarded list of dynamic (RAG) context sources.
///
/// Each entry pairs the number of samples to retrieve (`usize`) with a
/// type-erased vector store index that is queried with the prompt's RAG text
/// when building a completion request.
pub type DynamicContextStore = Arc<
    RwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
/// A configured LLM agent: a completion model together with its system
/// preamble, static and dynamic context documents, tool access, and
/// sampling parameters.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Optional human-readable name; `name()` falls back to
    /// [`UNKNOWN_AGENT_NAME`] when unset.
    pub name: Option<String>,
    /// Optional free-form description of the agent's purpose.
    pub description: Option<String>,
    /// The underlying completion model, shared via `Arc` so cloning the
    /// agent is cheap.
    pub model: Arc<M>,
    /// System prompt attached to completion requests when set.
    pub preamble: Option<String>,
    /// Documents attached to every completion request unconditionally.
    pub static_context: Vec<Document>,
    /// Sampling temperature forwarded to the model when set.
    pub temperature: Option<f64>,
    /// Maximum number of tokens to generate, forwarded when set.
    pub max_tokens: Option<u64>,
    /// Provider-specific extra parameters merged into each request.
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch tool definitions (optionally filtered by the
    /// prompt's RAG text) at request-build time.
    pub tool_server_handle: ToolServerHandle,
    /// RAG vector-store indices queried with the prompt's RAG text.
    pub dynamic_context: DynamicContextStore,
    /// Forces or restricts the model's tool selection when set.
    pub tool_choice: Option<ToolChoice>,
}
impl<M> Agent<M>
where
    M: CompletionModel,
{
    /// Returns the agent's display name, or a placeholder when unnamed.
    pub(crate) fn name(&self) -> &str {
        match self.name {
            Some(ref name) => name,
            None => UNKNOWN_AGENT_NAME,
        }
    }
}
impl<M> Completion<M> for Agent<M>
where
    M: CompletionModel,
{
    /// Builds a completion request from the prompt, chat history, and the
    /// agent's configuration (preamble, context documents, tools, and
    /// sampling parameters).
    ///
    /// When the prompt (or, failing that, the most recent history message)
    /// yields RAG text, every dynamic-context index is queried with that text
    /// and the hits are attached as extra documents; the same text is used to
    /// filter the tool definitions fetched from the tool server.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] when a vector-store query
    /// fails or tool definitions cannot be fetched.
    async fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();
        // Prefer RAG text from the prompt itself; otherwise scan the chat
        // history from newest to oldest for the first message that has any.
        let rag_text = prompt.rag_text();
        let rag_text = rag_text.or_else(|| {
            chat_history
                .iter()
                .rev()
                .find_map(|message| message.rag_text())
        });
        // Base request: model, history, sampling params, static documents.
        let completion_request = self
            .model
            .completion_request(prompt)
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());
        // Attach the preamble (system prompt) only when configured.
        let completion_request = if let Some(preamble) = &self.preamble {
            completion_request.preamble(preamble.to_owned())
        } else {
            completion_request
        };
        // Forward the tool-choice constraint only when configured.
        let completion_request = if let Some(tool_choice) = &self.tool_choice {
            completion_request.tool_choice(tool_choice.clone())
        } else {
            completion_request
        };
        let agent = match &rag_text {
            Some(text) => {
                // Query each dynamic-context index sequentially (`then`) with
                // the RAG text and flatten all hits into one document list.
                // NOTE: the read guard on `dynamic_context` is held across the
                // awaits below; this is a `tokio::sync::RwLock`, which is safe
                // to hold across await points.
                let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
                    .then(|(num_sample, index)| async {
                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(req)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty-print the JSON document; fall back
                                    // to its plain string form if serialization
                                    // fails.
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());
                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    // Accumulate per-index result batches, short-circuiting on
                    // the first vector-store error.
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
                // Tool definitions filtered by the RAG text.
                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(Some(text.to_string()))
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;
                completion_request
                    .documents(dynamic_context)
                    .tools(tooldefs)
            }
            None => {
                // No RAG text: skip the vector search and fetch the
                // unfiltered tool definitions.
                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(None)
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;
                completion_request.tools(tooldefs)
            }
        };
        Ok(agent)
    }
}
#[allow(refining_impl_trait)]
impl<M> Prompt for Agent<M>
where
M: CompletionModel,
{
fn prompt(
&self,
prompt: impl Into<Message> + WasmCompatSend,
) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
PromptRequest::new(self, prompt)
}
}
#[allow(refining_impl_trait)]
impl<M> Prompt for &Agent<M>
where
    M: CompletionModel,
{
    /// Creates a standard (non-streaming) prompt request, borrowing the
    /// agent through the reference.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        // `self` is `&&Agent<M>`; deref-coerce down to the agent borrow.
        let agent: &Agent<M> = self;
        PromptRequest::new(agent, prompt)
    }
}
#[allow(refining_impl_trait)]
impl<M> Chat for Agent<M>
where
    M: CompletionModel,
{
    /// Runs a single chat turn against this agent with the given history,
    /// returning the model's text response.
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        mut chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        // Build the request, thread the mutable history through, and await it.
        let request = PromptRequest::new(self, prompt);
        request.with_history(&mut chat_history).await
    }
}
impl<M> StreamingCompletion<M> for Agent<M>
where
M: CompletionModel,
{
async fn stream_completion(
&self,
prompt: impl Into<Message> + WasmCompatSend,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
self.completion(prompt, chat_history).await
}
}
impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
    /// Starts a streaming prompt request against this agent.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()> {
        // The streaming request owns its agent, so hand it a shared clone.
        StreamingPromptRequest::new(Arc::new(self.clone()), prompt)
    }
}
impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
    /// Starts a streaming chat request against this agent with the given
    /// history.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()> {
        // The streaming request owns its agent, so hand it a shared clone,
        // then attach the conversation history.
        let request = StreamingPromptRequest::new(Arc::new(self.clone()), prompt);
        request.with_history(chat_history)
    }
}