use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
/// A configured LLM agent: a completion model together with the preamble, context
/// documents, tools and sampling settings that accompany every request it makes.
///
/// Use `AgentBuilder` to construct one.
pub struct Agent<M: CompletionModel> {
    /// Completion model that requests are built against.
    model: M,
    /// System preamble attached to every request.
    preamble: String,

    /// Optional temperature forwarded with every request.
    temperature: Option<f64>,
    /// Optional maximum-token limit forwarded with every request.
    max_tokens: Option<u64>,
    /// Optional extra JSON parameters forwarded with every request.
    additional_params: Option<serde_json::Value>,

    /// Documents included in every request, regardless of the prompt.
    static_context: Vec<Document>,
    /// `(sample count, index)` pairs queried with the prompt to retrieve extra documents.
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,

    /// Names of tools whose definitions are sent with every request.
    static_tools: Vec<String>,
    /// `(sample count, index)` pairs queried with the prompt to select tool names.
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Every tool implementation the agent can look up by name (static and dynamic).
    pub tools: ToolSet,
}
impl<M: CompletionModel> Completion<M> for Agent<M> {
    /// Builds a completion request for `prompt`, appended after `chat_history`.
    ///
    /// The request carries the agent's preamble, temperature, max-token limit and
    /// additional params, plus:
    /// - documents: the static context first, then the documents returned by each
    ///   dynamic-context index (`top_n` with that index's sample count), in
    ///   registration order;
    /// - tools: the definitions of the static tools first, then the definitions of the
    ///   tools whose ids each dynamic-tools index returns (`top_n_ids`).
    ///
    /// Tool names with no implementation in `self.tools` are skipped with a `tracing`
    /// warning instead of failing the request.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::RequestError` wrapping the `VectorStoreError` if any
    /// vector-store lookup fails.
    async fn completion(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Retrieve documents relevant to the prompt from every dynamic-context index.
        let dynamic_context = stream::iter(self.dynamic_context.iter())
            .then(|(num_sample, index)| async {
                Ok::<_, VectorStoreError>(
                    index
                        .top_n(prompt, *num_sample)
                        .await?
                        .into_iter()
                        .map(|(_, id, doc)| {
                            // Prefer pretty-printed JSON as the document text; fall back to
                            // `to_string()` if pretty-printing fails.
                            let text = serde_json::to_string_pretty(&doc)
                                .unwrap_or_else(|_| doc.to_string());
                            Document {
                                id,
                                text,
                                additional_props: HashMap::new(),
                            }
                        })
                        .collect::<Vec<_>>(),
                )
            })
            .try_fold(vec![], |mut acc, docs| async {
                acc.extend(docs);
                Ok(acc)
            })
            .await
            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

        // Select tool ids relevant to the prompt, then resolve each id to a definition.
        let dynamic_tools = stream::iter(self.dynamic_tools.iter())
            .then(|(num_sample, index)| async {
                Ok::<_, VectorStoreError>(
                    index
                        .top_n_ids(prompt, *num_sample)
                        .await?
                        .into_iter()
                        .map(|(_, id)| id)
                        .collect::<Vec<_>>(),
                )
            })
            .try_fold(vec![], |mut acc, ids| async {
                for id in ids {
                    if let Some(tool) = self.tools.get(&id) {
                        acc.push(tool.definition(prompt.into()).await)
                    } else {
                        tracing::warn!("Tool implementation not found in toolset: {}", id);
                    }
                }
                Ok(acc)
            })
            .await
            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

        // Definitions for the always-available tools; unknown names are skipped.
        let static_tools = stream::iter(self.static_tools.iter())
            .filter_map(|toolname| async move {
                if let Some(tool) = self.tools.get(toolname) {
                    Some(tool.definition(prompt.into()).await)
                } else {
                    tracing::warn!("Tool implementation not found in toolset: {}", toolname);
                    None
                }
            })
            .collect::<Vec<_>>()
            .await;

        // Static entries come first, then dynamic ones. Extend the owned vectors
        // instead of `[a, b].concat()`, which would clone every element again.
        let mut documents = self.static_context.clone();
        documents.extend(dynamic_context);
        let mut tool_definitions = static_tools;
        tool_definitions.extend(dynamic_tools);

        Ok(self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .documents(documents)
            .tools(tool_definitions)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone()))
    }
}
impl<M: CompletionModel> Prompt for Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel> Prompt for &Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}
impl<M: CompletionModel> Chat for Agent<M> {
    /// Sends `prompt` after `chat_history` and returns a single string reply.
    ///
    /// If the model answers with a message, that text is returned unchanged. If the
    /// model requests a tool call, the named tool is invoked from `self.tools` with the
    /// call arguments rendered via `to_string()`, and the tool's output is returned.
    ///
    /// # Errors
    ///
    /// Propagates, via `?` conversion into `PromptError`, failures from building or
    /// sending the completion request and from the tool call.
    async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
        let request = self.completion(prompt, chat_history).await?;
        let CompletionResponse { choice, .. } = request.send().await?;

        let reply = match choice {
            ModelChoice::Message(text) => text,
            ModelChoice::ToolCall(tool_name, tool_args) => {
                self.tools.call(&tool_name, tool_args.to_string()).await?
            }
        };
        Ok(reply)
    }
}
/// Incrementally configures an `Agent`; finish with `build()`.
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model the built agent will use.
    model: M,
    /// Preamble, if one has been set or appended.
    preamble: Option<String>,

    /// Optional temperature for every request.
    temperature: Option<f64>,
    /// Optional maximum-token limit for every request.
    max_tokens: Option<u64>,
    /// Optional extra JSON parameters for every request.
    additional_params: Option<serde_json::Value>,

    /// Documents always included in requests.
    static_context: Vec<Document>,
    /// `(sample count, index)` pairs for prompt-dependent document retrieval.
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,

    /// Names of tools always offered to the model.
    static_tools: Vec<String>,
    /// `(sample count, index)` pairs for prompt-dependent tool selection.
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// All registered tool implementations.
    tools: ToolSet,
}
impl<M: CompletionModel> AgentBuilder<M> {
    /// Creates a builder for an agent backed by `model`, with no preamble, context,
    /// tools or sampling overrides.
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            static_context: vec![],
            static_tools: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Sets the preamble, replacing any text previously set or appended.
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Appends `doc` to the preamble on a new line.
    ///
    /// If no preamble has been set yet, `doc` becomes the preamble as-is. Previously
    /// this case produced a spurious leading newline.
    pub fn append_preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(match self.preamble.take() {
            Some(existing) => format!("{existing}\n{doc}"),
            None => doc.into(),
        });
        self
    }

    /// Adds a static context document that is included in every request.
    ///
    /// Documents receive sequential ids: `static_doc_0`, `static_doc_1`, ...
    pub fn context(mut self, doc: &str) -> Self {
        let id = format!("static_doc_{}", self.static_context.len());
        self.static_context.push(Document {
            id,
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Registers `tool` and offers its definition to the model on every request.
    ///
    /// Registering a tool whose name is already in the static list does not add a
    /// second entry.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let toolname = tool.name();
        self.tools.add_tool(tool);
        // Tools are resolved by name at request time, so a repeated name would make the
        // same tool be looked up, and its definition sent, twice per request.
        if !self.static_tools.contains(&toolname) {
            self.static_tools.push(toolname);
        }
        self
    }

    /// Adds a vector-store index queried with each prompt; its top `sample` results are
    /// added to the request as context documents.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context.push((sample, Box::new(dynamic_context)));
        self
    }

    /// Adds a vector-store index queried with each prompt to pick up to `sample` tool
    /// ids. Each id is looked up by name among the agent's tools, and every tool in
    /// `toolset` is merged into them.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Sets the temperature forwarded with every completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum-token limit forwarded with every completion request.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets extra JSON parameters forwarded with every completion request.
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Consumes the builder and returns the configured `Agent`.
    ///
    /// If no preamble was provided, the agent uses an empty preamble.
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}