use crate::tool::ToolBox;
use anyhow::{anyhow, Result};
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, JsonSpec, MessageContent, ToolResponse};
use genai::resolver::{AuthData, Endpoint, ServiceTargetResolver};
use genai::{Client, ClientBuilder, ModelIden, ServiceTarget};
use log::{debug, trace};
use schemars::{schema_for, JsonSchema};
use serde::de::DeserializeOwned;
use serde_json::{from_str, json, Value};
use std::any::TypeId;
use std::sync::Arc;
#[derive(Clone)]
pub struct Agent {
client: Client,
history: Vec<ChatMessage>,
}
impl Agent {
pub fn new(system: &str) -> Self {
let client = Client::default();
Self::new_with_client(client, system)
}
pub fn new_with_client(client: Client, system: &str) -> Self {
Self {
client,
history: vec![ChatMessage::system(system.trim())],
}
}
pub fn new_with_url(base_url: &str, api_key: &str, system: &str) -> Self {
let endpoint = Endpoint::from_owned(Arc::from(base_url));
let auth = AuthData::from_single(api_key);
let target_resolver = ServiceTargetResolver::from_resolver_fn(
|service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
let ServiceTarget { model, .. } = service_target;
let model = ModelIden::new(AdapterKind::OpenAI, model.model_name);
Ok(ServiceTarget {
endpoint,
auth,
model,
})
},
);
let client = ClientBuilder::default()
.with_service_target_resolver(target_resolver)
.build();
Self::new_with_client(client, system)
}
pub async fn run<D>(
&mut self,
model: &str,
prompt: &str,
toolbox: Option<&dyn ToolBox>,
) -> Result<D>
where
D: DeserializeOwned + JsonSchema + 'static,
{
debug!("Agent Question: {prompt}");
self.history.push(ChatMessage::user(prompt));
let mut chat_opts = ChatOptions::default().with_temperature(0.2);
let is_answer_string = TypeId::of::<String>() == TypeId::of::<D>();
if !is_answer_string {
let mut response_schema = serde_json::to_value(schema_for!(D))?;
let obj = response_schema.as_object_mut().unwrap();
obj.remove("$schema");
obj.remove("title");
chat_opts = chat_opts.with_response_format(JsonSpec::new("ResponseFormat", json!(obj)));
}
let max_iterations = 5;
for iteration in 0..max_iterations {
debug!("Agent iteration: {iteration}");
let mut chat_req = ChatRequest::new(self.history.clone());
if let Some(toolbox) = toolbox {
chat_req = chat_req.with_tools(toolbox.tools_definitions()?);
}
let chat_resp = self
.client
.exec_chat(model, chat_req, Some(&chat_opts))
.await?;
match chat_resp.content {
Some(MessageContent::Text(text)) => {
let mut resp = text;
debug!("Agent Answer: {resp}");
self.history.push(ChatMessage::assistant(resp.clone()));
if is_answer_string {
resp = Value::String(resp).to_string();
}
let resp = from_str(&resp)?;
return Ok(resp);
}
Some(MessageContent::ToolCalls(tools_call)) => {
self.history.push(ChatMessage::from(tools_call.clone()));
for tool_request in tools_call {
trace!(
"Tool request: {} with arguments: {}",
tool_request.fn_name,
tool_request.fn_arguments
);
if let Some(tool) = toolbox {
match tool
.call_tool(tool_request.fn_name, tool_request.fn_arguments)
.await
{
Ok(result) => {
trace!("Tool result: {result}");
self.history.push(ChatMessage::from(ToolResponse::new(
tool_request.call_id.clone(),
result,
)));
}
Err(err) => {
trace!("Error: {err}");
self.history.push(ChatMessage::from(ToolResponse::new(
tool_request.call_id.clone(),
err.to_string(),
)));
}
};
} else {
todo!("No tool found for {}", tool_request.fn_name);
}
}
}
Some(msg_content) => {
return Err(anyhow!(format!(
"Unsupported message content {:?}",
msg_content
)));
}
None => {}
};
}
Err(anyhow!(format!(
"Unable to get response in {max_iterations} tries"
)))
}
}