use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
/// A completion model bundled with the preamble, context documents, tools and
/// request options that are applied to every completion request it builds.
pub struct Agent<M: CompletionModel> {
/// Completion model used to create each request.
model: M,
/// Preamble set on every request.
preamble: String,
/// Documents attached to every request.
static_context: Vec<Document>,
/// Names of the tools in `tools` whose definitions are attached to every request.
static_tools: Vec<String>,
/// Optional temperature forwarded to the request builder.
temperature: Option<f64>,
/// Optional `max_tokens` value forwarded to the request builder.
max_tokens: Option<u64>,
/// Optional extra parameters forwarded to the request builder.
additional_params: Option<serde_json::Value>,
/// `(sample count, index)` pairs. When the prompt has RAG text, each index is queried
/// for its top `sample count` documents, which are attached to the request.
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// `(sample count, index)` pairs. When the prompt has RAG text, each index is queried
/// for its top `sample count` ids, which are then looked up as tool names in `tools`.
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Tool implementations, looked up by name to obtain definitions and to run tool calls.
pub tools: ToolSet,
}
impl<M: CompletionModel> Completion<M> for Agent<M> {
    /// Builds a completion request for `prompt` using the agent's configuration.
    ///
    /// The request always carries the preamble, `chat_history`, the optional
    /// temperature / max tokens / additional parameters, the static context
    /// documents and the definitions of the static tools.
    ///
    /// When the prompt yields RAG text, every dynamic context index is queried
    /// for its top-n documents and every dynamic tool index for its top-n tool
    /// ids; the resulting documents and tool definitions are added as well.
    /// Tool names that are missing from the toolset are skipped with a warning.
    /// Tool definitions are requested with the RAG text, or with an empty
    /// string when the prompt has none.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::RequestError` wrapping the underlying
    /// `VectorStoreError` when querying a dynamic context or dynamic tool
    /// index fails.
    async fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();
        // Grab the retrieval query before `prompt` is moved into the request builder.
        let rag_text = prompt.rag_text().clone();

        let completion_request = self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());

        // Dynamic lookups only happen when there is text to search with.
        let (dynamic_context, dynamic_tools) = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty-print the JSON document for prompting; fall back
                                    // to its compact form if pretty serialization fails.
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());
                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
                    .then(|(num_sample, index)| async {
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n_ids(text, *num_sample)
                                .await?
                                .into_iter()
                                .map(|(_, id)| id)
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        for doc in docs {
                            if let Some(tool) = self.tools.get(&doc) {
                                acc.push(tool.definition(text.into()).await)
                            } else {
                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
                            }
                        }
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                (Some(dynamic_context), dynamic_tools)
            }
            None => (None, Vec::new()),
        };

        // Static tools are resolved the same way in both cases; only the text
        // handed to `definition` differs (empty when there is no RAG text).
        let tool_prompt = rag_text.as_deref().unwrap_or_default();
        let mut tools = stream::iter(self.static_tools.iter())
            .filter_map(|toolname| async move {
                if let Some(tool) = self.tools.get(toolname) {
                    Some(tool.definition(tool_prompt.into()).await)
                } else {
                    tracing::warn!(
                        "Tool implementation not found in toolset: {}",
                        toolname
                    );
                    None
                }
            })
            .collect::<Vec<_>>()
            .await;
        // Static definitions first, dynamic ones after; moving instead of
        // clone + concat avoids copying every definition.
        tools.extend(dynamic_tools);

        let completion_request = match dynamic_context {
            Some(documents) => completion_request.documents(documents),
            None => completion_request,
        };
        Ok(completion_request.tools(tools))
    }
}
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Sends `prompt` as a chat with an empty history and returns whatever
    /// `chat` returns for it.
    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
        let no_history = Vec::new();
        self.chat(prompt, no_history).await
    }
}
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Same as prompting the referenced agent directly: `prompt` is sent as a
    /// chat with an empty history.
    async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
        let no_history = Vec::new();
        self.chat(prompt, no_history).await
    }
}
impl<M: CompletionModel> Chat for Agent<M> {
    /// Builds and sends a completion request for `prompt` with `chat_history`,
    /// then resolves the first choice of the response to a string.
    ///
    /// A text choice is returned as-is. A tool-call choice is executed through
    /// the agent's toolset and the tool's output is returned instead.
    ///
    /// # Errors
    ///
    /// Propagates, converted to `PromptError`, any failure from building or
    /// sending the request, or from calling the requested tool.
    async fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        let request = self.completion(prompt, chat_history).await?;
        let response = request.send().await?;

        match response.choice.first() {
            AssistantContent::Text(text) => Ok(text.text.clone()),
            AssistantContent::ToolCall(tool_call) => {
                let name = &tool_call.function.name;
                let arguments = tool_call.function.arguments.to_string();
                let output = self.tools.call(name, arguments).await?;
                Ok(output)
            }
        }
    }
}
/// Builder that accumulates the configuration of an `Agent` and produces it via `build`.
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model the built agent will use.
model: M,
/// Preamble, if one has been set; `build` falls back to an empty string.
preamble: Option<String>,
/// Context documents added with `context`.
static_context: Vec<Document>,
/// Names of the tools added with `tool`.
static_tools: Vec<String>,
/// Optional extra parameters set with `additional_params`.
additional_params: Option<serde_json::Value>,
/// Optional value set with `max_tokens`.
max_tokens: Option<u64>,
/// `(sample count, index)` pairs added with `dynamic_context`.
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// `(sample count, index)` pairs added with `dynamic_tools`.
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Optional value set with `temperature`.
temperature: Option<f64>,
/// Tool implementations collected from `tool` and `dynamic_tools`.
tools: ToolSet,
}
impl<M: CompletionModel> AgentBuilder<M> {
    /// Creates a builder for `model` with no preamble, no context, no tools
    /// and no request options set.
    pub fn new(model: M) -> Self {
        Self {
            model,
            preamble: None,
            static_context: vec![],
            static_tools: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Sets the preamble, replacing any previously configured one.
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Appends `doc` to the preamble on a new line.
    ///
    /// When no preamble has been set yet (or it is empty), `doc` becomes the
    /// preamble as-is, so the result never starts with a stray newline.
    pub fn append_preamble(mut self, doc: &str) -> Self {
        let existing = self.preamble.take().filter(|current| !current.is_empty());
        self.preamble = Some(match existing {
            Some(current) => format!("{}\n{}", current, doc),
            None => doc.to_owned(),
        });
        self
    }

    /// Adds `doc` as a static context document.
    ///
    /// The document id is `static_doc_<n>`, where `<n>` is the number of
    /// static context documents added before this one.
    pub fn context(mut self, doc: &str) -> Self {
        let id = format!("static_doc_{}", self.static_context.len());
        self.static_context.push(Document {
            id,
            text: doc.into(),
            additional_props: HashMap::new(),
        });
        self
    }

    /// Registers `tool` in the toolset and records its name as a static tool.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        // The name must be read before `tool` is moved into the toolset.
        let toolname = tool.name();
        self.tools.add_tool(tool);
        self.static_tools.push(toolname);
        self
    }

    /// Adds a dynamic context index together with the number of samples to
    /// request from it.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Box::new(dynamic_context)));
        self
    }

    /// Adds a dynamic tool index together with the number of samples to
    /// request from it, and merges `toolset` into the builder's toolset.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Sets the temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the `max_tokens` value.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the additional parameters, replacing any previously configured ones.
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Consumes the builder and produces the configured `Agent`.
    ///
    /// A preamble that was never set becomes the empty string.
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}
impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
    /// Builds the request for a streamed completion. The request is the same
    /// one `completion` builds for `prompt` and `chat_history`.
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let request = self.completion(prompt, chat_history).await?;
        Ok(request)
    }
}
impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
    /// Streams the response to `prompt`, using an empty chat history.
    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
        let no_history = Vec::new();
        self.stream_chat(prompt, no_history).await
    }
}
impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
    /// Builds the streaming completion request for `prompt` with
    /// `chat_history` and starts streaming it.
    ///
    /// # Errors
    ///
    /// Returns the `CompletionError` raised while building the request or
    /// while starting the stream.
    async fn stream_chat(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<StreamingResult, CompletionError> {
        let request = self.stream_completion(prompt, chat_history).await?;
        request.stream().await
    }
}