use std::collections::HashMap;
use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Completion, CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse,
Document, Message, ModelChoice, Prompt, PromptError,
},
tool::{Tool, ToolSet, ToolSetError},
vector_store::{NoIndex, VectorStoreError, VectorStoreIndex},
};
/// An agent that enriches every prompt with static and retrieved (RAG) context and tools.
///
/// Besides a fixed preamble, static documents and static tools, the agent holds vector store
/// indices that are queried with the prompt when a request is built:
/// - each `dynamic_context` entry contributes up to `n` retrieved documents;
/// - each `dynamic_tools` entry contributes up to `n` tool ids, which are resolved against
///   [`ToolSet`] `tools`.
///
/// Construct one with [`RagAgentBuilder`].
pub struct RagAgent<M, C, T>
where
    M: CompletionModel,
    C: VectorStoreIndex,
    T: VectorStoreIndex,
{
    /// Model that executes the completion requests.
    model: M,
    /// Preamble attached to every completion request.
    preamble: String,
    /// Optional temperature forwarded to the model.
    temperature: Option<f64>,
    /// Optional extra provider parameters forwarded to the model.
    additional_params: Option<serde_json::Value>,
    /// Documents included in every request.
    static_context: Vec<Document>,
    /// Names of tools (looked up in `tools`) whose definitions are sent with every request.
    static_tools: Vec<String>,
    /// `(sample size, index)` pairs queried for context documents.
    dynamic_context: Vec<(usize, C)>,
    /// `(sample size, index)` pairs queried for tool ids.
    dynamic_tools: Vec<(usize, T)>,
    /// Every tool implementation available to this agent, static or dynamic.
    pub tools: ToolSet,
}
/// A [`RagAgent`] whose dynamic-context slot is filled with [`NoIndex`], i.e. one that is
/// intended to use only dynamic tools (NOTE(review): assumes `NoIndex` is a placeholder index).
pub type ToolRagAgent<M, T> = RagAgent<M, NoIndex, T>;
/// A [`RagAgent`] whose dynamic-tools slot is filled with [`NoIndex`], i.e. one that is
/// intended to use only dynamic context (NOTE(review): assumes `NoIndex` is a placeholder index).
pub type ContextRagAgent<M, C> = RagAgent<M, C, NoIndex>;
impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> Completion<M>
    for RagAgent<M, C, T>
{
    /// Builds a completion request for `prompt`, enriched with RAG context and tools.
    ///
    /// The returned builder carries:
    /// - the agent's preamble and `chat_history`;
    /// - the static context documents, followed by the documents retrieved from every dynamic
    ///   context index (`top_n_from_query(prompt, n)` for each `(n, index)` pair);
    /// - the definitions of the static tools, followed by the definitions of the tools whose
    ///   ids were retrieved from every dynamic tool index;
    /// - the optional temperature and additional parameters.
    ///
    /// Tool names or ids that have no implementation in `self.tools` are skipped with a
    /// `tracing` warning rather than failing the request.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] if querying any vector store index fails.
    async fn completion(
        &self,
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Query each dynamic context index in turn (`then` awaits one at a time) and
        // flatten all hits into a single list of documents.
        let dynamic_context = stream::iter(self.dynamic_context.iter())
            .then(|(num_sample, index)| async {
                Ok::<_, VectorStoreError>(
                    index
                        .top_n_from_query(prompt, *num_sample)
                        .await?
                        .into_iter()
                        .map(|(_, doc)| {
                            // Prefer pretty-printed JSON; if that fails, fall back to the
                            // value's `to_string()` form instead of dropping the document.
                            let doc_text = serde_json::to_string_pretty(&doc.document)
                                .unwrap_or_else(|_| doc.document.to_string());
                            Document {
                                id: doc.id,
                                text: doc_text,
                                additional_props: HashMap::new(),
                            }
                        })
                        .collect::<Vec<_>>(),
                )
            })
            .try_fold(vec![], |mut acc, docs| async {
                acc.extend(docs);
                Ok(acc)
            })
            .await
            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

        // Query each dynamic tool index for tool ids and resolve them to definitions.
        let dynamic_tools = stream::iter(self.dynamic_tools.iter())
            .then(|(num_sample, index)| async {
                Ok::<_, VectorStoreError>(
                    index
                        .top_n_ids_from_query(prompt, *num_sample)
                        .await?
                        .into_iter()
                        .map(|(_, doc)| doc)
                        .collect::<Vec<_>>(),
                )
            })
            .try_fold(vec![], |mut acc, docs| async {
                for doc in docs {
                    if let Some(tool) = self.tools.get(&doc) {
                        acc.push(tool.definition(prompt.into()).await)
                    } else {
                        tracing::warn!("Tool implementation not found in toolset: {}", doc);
                    }
                }
                Ok(acc)
            })
            .await
            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

        // Resolve the statically configured tool names to definitions.
        let static_tools = stream::iter(self.static_tools.iter())
            .filter_map(|toolname| async move {
                if let Some(tool) = self.tools.get(toolname) {
                    Some(tool.definition(prompt.into()).await)
                } else {
                    tracing::warn!("Tool implementation not found in toolset: {}", toolname);
                    None
                }
            })
            .collect::<Vec<_>>()
            .await;

        // Static entries come first, then the dynamically retrieved ones. Extending the
        // owned vectors avoids the redundant clone and the second copy that
        // `[a, b].concat()` performed.
        let mut documents = self.static_context.clone();
        documents.extend(dynamic_context);
        let mut tools = static_tools;
        tools.extend(dynamic_tools);

        Ok(self
            .model
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .documents(documents)
            .tools(tools)
            .temperature_opt(self.temperature)
            .additional_params_opt(self.additional_params.clone()))
    }
}
impl<M, C, T> Prompt for RagAgent<M, C, T>
where
    M: CompletionModel,
    C: VectorStoreIndex,
    T: VectorStoreIndex,
{
    /// Sends `prompt` (with `chat_history`) through this agent and returns the textual result.
    ///
    /// If the model replies with a message, that message is returned as-is. If it asks for a
    /// tool call instead, the named tool is invoked from `self.tools` with the arguments'
    /// `to_string()` form, and the tool's output is returned in place of a message.
    ///
    /// # Errors
    /// Propagates failures from building or sending the request and from the tool call,
    /// converted into [`PromptError`].
    async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
        let request = self.completion(prompt, chat_history).await?;
        let CompletionResponse { choice, .. } = request.send().await?;

        match choice {
            ModelChoice::Message(text) => Ok(text),
            ModelChoice::ToolCall(name, arguments) => {
                let output = self.tools.call(&name, arguments.to_string()).await?;
                Ok(output)
            }
        }
    }
}
impl<M, C, T> RagAgent<M, C, T>
where
    M: CompletionModel,
    C: VectorStoreIndex,
    T: VectorStoreIndex,
{
    /// Invokes the tool named `toolname` from this agent's [`ToolSet`] with the raw `args`
    /// string and returns the tool's output.
    ///
    /// # Errors
    /// Returns whatever [`ToolSetError`] the tool set reports for the call.
    pub async fn call_tool(&self, toolname: &str, args: &str) -> Result<String, ToolSetError> {
        let owned_args = String::from(args);
        self.tools.call(toolname, owned_args).await
    }
}
/// Incremental, by-value builder for [`RagAgent`].
///
/// Start with [`RagAgentBuilder::new`], chain configuration calls, and finish with
/// [`RagAgentBuilder::build`].
pub struct RagAgentBuilder<M, C, T>
where
    M: CompletionModel,
    C: VectorStoreIndex,
    T: VectorStoreIndex,
{
    /// Model the built agent will use.
    model: M,
    /// Preamble; `None` becomes an empty string at build time.
    preamble: Option<String>,
    /// Optional temperature for the built agent.
    temperature: Option<f64>,
    /// Optional extra provider parameters for the built agent.
    additional_params: Option<serde_json::Value>,
    /// Static context documents collected so far.
    static_context: Vec<Document>,
    /// Names of tools exposed on every request.
    static_tools: Vec<String>,
    /// `(sample size, index)` pairs for dynamic context retrieval.
    dynamic_context: Vec<(usize, C)>,
    /// `(sample size, index)` pairs for dynamic tool retrieval.
    dynamic_tools: Vec<(usize, T)>,
    /// All tool implementations registered so far.
    tools: ToolSet,
}
impl<M: CompletionModel, C: VectorStoreIndex, T: VectorStoreIndex> RagAgentBuilder<M, C, T> {
pub fn new(model: M) -> Self {
Self {
model,
preamble: None,
static_context: vec![],
static_tools: vec![],
temperature: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}
pub fn static_context(mut self, doc: &str) -> Self {
self.static_context.push(Document {
id: format!("static_doc_{}", self.static_context.len()),
text: doc.into(),
additional_props: HashMap::new(),
});
self
}
pub fn static_tool(mut self, tool: impl Tool + 'static) -> Self {
let toolname = tool.name();
self.tools.add_tool(tool);
self.static_tools.push(toolname);
self
}
pub fn dynamic_context(mut self, sample: usize, dynamic_context: C) -> Self {
self.dynamic_context.push((sample, dynamic_context));
self
}
pub fn dynamic_tools(mut self, sample: usize, dynamic_tools: T, toolset: ToolSet) -> Self {
self.dynamic_tools.push((sample, dynamic_tools));
self.tools.add_tools(toolset);
self
}
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
pub fn build(self) -> RagAgent<M, C, T> {
RagAgent {
model: self.model,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}