use std::collections::HashMap;
use std::fmt::Debug;
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use toolbox::Toolbox;
use tracing::warn;
use crate::tools::invocation::{ExtractedInvocations, InvocationError};
pub mod invocation;
pub mod toolbox;
/// Description of a single field inside a [`Format`].
#[derive(Debug, Clone)]
pub struct FieldFormat {
    /// Field name; used as the map key when the enclosing `Format` is serialized.
    pub name: String,
    /// Type label of the field, rendered between `<` and `>` when serialized.
    pub r#type: String,
    /// Whether the field is optional; optional fields get an ` (optional)` suffix
    /// in the serialized description.
    pub optional: bool,
    /// Human-readable explanation of the field.
    pub description: String,
}
/// Types that can describe their own field layout as a [`Format`].
///
/// NOTE(review): presumably implemented by tool parameter / response types so their
/// `Format` can be placed into a `ToolDescription` — confirm against implementors.
pub trait Describe {
    /// Returns the field-by-field [`Format`] of the implementing type.
    fn describe() -> Format;
}
/// An ordered list of field descriptions.
///
/// It has a hand-written `Serialize` impl that emits a map from each field's name to a
/// `"<type> description"` string, rather than the derived struct layout.
#[derive(Debug, Clone, Default)]
pub struct Format {
    /// The described fields, in the order they are emitted when serialized.
    pub fields: Vec<FieldFormat>,
}
impl Serialize for Format {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let n = self.fields.len();
let mut map = serializer.serialize_map(Some(n))?;
for field in &self.fields {
let description = if field.optional {
format!("<{}> {} (optional)", field.r#type, field.description)
} else {
format!("<{}> {}", field.r#type, field.description)
};
map.serialize_entry(&field.name, &description)?;
}
map.end()
}
}
impl From<Vec<FieldFormat>> for Format {
fn from(fields: Vec<FieldFormat>) -> Self {
Format { fields }
}
}
/// Serializable summary of a tool: its name, a description, and the formats of its
/// parameters and response content.
#[derive(Debug, Serialize, Clone)]
pub struct ToolDescription {
    /// Name of the tool.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Format of the tool's parameters.
    pub parameters: Format,
    /// Format of the content the tool responds with.
    pub responses_content: Format,
}
impl ToolDescription {
pub fn new(
name: &str,
description: &str,
parameters: Format,
responses_content: Format,
) -> Self {
ToolDescription {
name: name.to_string(),
description: description.to_string(),
parameters,
responses_content,
}
}
}
/// Errors that can occur while looking up or invoking a tool.
///
/// Every variant carries a message string. The `Display` text comes from the
/// `#[error]` attributes. The type is `Clone` and (de)serializable, so errors can be
/// copied and round-tripped through serde.
#[derive(Debug, thiserror::Error, Clone, Serialize, Deserialize)]
pub enum ToolUseError {
    /// The requested tool could not be found.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),
    /// The tool was found but its invocation failed.
    #[error("Tool invocation failed: {0}")]
    InvocationFailed(String),
    /// The tool's output could not be serialized.
    #[error("Failed to serialize the output: {0}")]
    InvalidOutput(String),
    /// The supplied parameters could not be deserialized.
    #[error("Failed to deserialize the parameters: {0}")]
    InvalidInput(String),
}
/// A single tool invocation as (de)serialized from YAML: which tool to call and with
/// what parameters.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ToolInvocationInput {
    /// Name of the tool to invoke.
    tool_name: String,
    /// Parameters for the tool, kept as an untyped YAML value.
    parameters: serde_yaml::Value,
    /// Any top-level keys other than `tool_name` and `parameters`.
    ///
    /// They are captured via `#[serde(flatten)]` and skipped when serializing if the
    /// map is empty.
    #[serde(skip_serializing_if = "HashMap::is_empty", flatten)]
    junk: HashMap<String, serde_yaml::Value>,
}
/// The "describe" half of a tool.
///
/// A `Sync + Send` type implementing both this and [`ProtoToolInvoke`] automatically
/// implements [`Tool`] through the blanket impl in this module.
pub trait ProtoToolDescribe {
    /// Returns the tool's [`ToolDescription`].
    fn description(&self) -> ToolDescription;
}
/// The "invoke" half of a tool.
///
/// A `Sync + Send` type implementing both this and [`ProtoToolDescribe`] automatically
/// implements [`Tool`] through the blanket impl in this module.
#[async_trait::async_trait]
pub trait ProtoToolInvoke {
    /// Runs the tool on a YAML `input`, returning YAML output or a [`ToolUseError`].
    async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, ToolUseError>;
}
/// A tool that can describe itself and be invoked asynchronously on YAML input.
///
/// Implementors must be `Sync + Send`. Types implementing both [`ProtoToolDescribe`]
/// and [`ProtoToolInvoke`] get this trait from the blanket impl in this module.
#[async_trait::async_trait]
pub trait Tool: Sync + Send {
    /// Returns the tool's name, description, and parameter / response formats.
    fn description(&self) -> ToolDescription;
    /// Runs the tool on `input`, returning its YAML output or a [`ToolUseError`].
    async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, ToolUseError>;
}
#[async_trait::async_trait]
impl<T: Sync + Send> Tool for T
where
T: ProtoToolDescribe + ProtoToolInvoke,
{
fn description(&self) -> ToolDescription {
self.description()
}
async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, ToolUseError> {
self.invoke(input).await
}
}
/// Message handed back by a [`TerminalTool`] via `take_done`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminationMessage {
    /// The conclusion text.
    pub conclusion: String,
    /// The question that was originally posed.
    pub original_question: String,
}
/// A [`Tool`] that can report that it is done and hand back a [`TerminationMessage`].
///
/// Both methods have defaults (`false` / `None`), so implementors only override what
/// they need.
/// NOTE(review): presumably an overriding tool reports `true` / `Some(..)` after it has
/// been invoked, to end the surrounding run — confirm against implementors.
/// (The `Sync + Send` supertraits are already implied by `Tool`.)
#[async_trait::async_trait]
pub trait TerminalTool: Tool + Sync + Send {
    /// Whether this tool has finished. Default: always `false`.
    async fn is_done(&self) -> bool {
        false
    }
    /// Returns the termination message if one is available. Default: always `None`.
    /// NOTE(review): the name suggests the message is consumed by this call — confirm.
    async fn take_done(&self) -> Option<TerminationMessage> {
        None
    }
}
/// A [`Tool`] whose invocation additionally receives a [`Toolbox`].
///
/// NOTE(review): presumably so the tool can look up or call other tools — confirm.
#[async_trait::async_trait]
pub trait AdvancedTool: Tool {
    /// Invokes the tool on `input`, given a `toolbox` passed by value.
    ///
    /// Returns the YAML output or a [`ToolUseError`].
    async fn invoke_with_toolbox(
        &self,
        toolbox: Toolbox,
        input: serde_yaml::Value,
    ) -> Result<serde_yaml::Value, ToolUseError>;
}
/// Selects the tool invocation to run from the invocations extracted from a response.
///
/// Only the first extracted invocation is returned; any others are dropped. If it
/// carries unexpected top-level keys (`junk`), they are logged with `warn!` and removed.
///
/// # Errors
/// - [`InvocationError::TooManyYamlBlocks`] when more than one YAML block was found
///   (checked first).
/// - [`InvocationError::NoInvocationFound`] when no invocation was extracted.
async fn choose_invocation(
    tool_invocations: ExtractedInvocations,
) -> Result<ToolInvocationInput, InvocationError> {
    let block_count = tool_invocations.yaml_block_count;
    if block_count > 1 {
        return Err(InvocationError::TooManyYamlBlocks(block_count));
    }

    let mut chosen = match tool_invocations.invocations.into_iter().next() {
        Some(first) => first,
        None => return Err(InvocationError::NoInvocationFound),
    };

    if !chosen.junk.is_empty() {
        let junk_keys = chosen
            .junk
            .keys()
            .map(String::as_str)
            .collect::<Vec<&str>>()
            .join(", ");
        warn!(
            ?junk_keys,
            "The Action should not have fields: {}.", junk_keys
        );
        chosen.junk.clear();
    }

    Ok(chosen)
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use insta::assert_display_snapshot;
    use serde::{Deserialize, Serialize};

    /// Parameters of a pretend search tool.
    #[derive(Debug, Serialize, Deserialize)]
    struct FakeToolInput {
        q: String,
        excluded_terms: Option<String>,
        num_results: Option<u32>,
    }

    /// Result of a pretend search tool.
    #[derive(Debug, Serialize, Deserialize)]
    struct FakeToolOutput {
        items: Vec<String>,
    }

    /// Snapshots the YAML for an invocation carrying a stray `output` key.
    ///
    /// This pins how the flattened `junk` map is written next to `tool_name` and
    /// `parameters`.
    #[tokio::test]
    async fn test_serializing_tool_invocation() {
        let parameters = serde_yaml::to_value(FakeToolInput {
            q: "Marcel Deneuve".to_string(),
            excluded_terms: Some("Resident Evil".to_string()),
            num_results: Some(10),
        })
        .unwrap();

        let snippets = [
            "Marcel Deneuve is a character in the Resident Evil film series,",
            "playing a minor role in Resident Evil: Apocalypse and a much larger",
            " role in Resident Evil: Extinction. Explore historical records and ",
            "family tree profiles about Marcel Deneuve on MyHeritage, the world's largest family network.",
        ];
        let output = FakeToolOutput {
            items: snippets.iter().map(|snippet| snippet.to_string()).collect(),
        };

        // Anything besides `tool_name` / `parameters` goes into the flattened `junk` map.
        let mut junk = HashMap::new();
        junk.insert("output".to_string(), serde_yaml::to_value(output).unwrap());

        let invocation = super::ToolInvocationInput {
            tool_name: "Search".to_string(),
            parameters,
            junk,
        };
        let serialized = serde_yaml::to_string(&invocation).unwrap();
        assert_display_snapshot!(serialized);
    }
}