use std::{fmt, sync::Arc, time::Duration};
use serde::{Deserialize, Serialize};
use tokio::time::Instant;
use crate::{
annotations::ToolAnnotations, core::config::Tool, tenant::TenantId, transform::ResponseFormat,
};
/// Classification tag carried by every `ToolEntry` (see `ToolEntry::category`).
///
/// The default is `ToolCategory::Static`, which is what `ToolEntry::new` assigns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ToolCategory {
/// Default category; used for entries built with `ToolEntry::new` unless changed.
#[default]
Static,
/// Set by `ToolEntry::with_alias` when an alias target is attached to an entry.
Alias,
// NOTE(review): presumably tools discovered at runtime — not set anywhere in this file; confirm.
Dynamic,
// NOTE(review): presumably user-defined tools — not set anywhere in this file; confirm.
Custom,
// NOTE(review): presumably tools provided by the host itself — not set anywhere in this file; confirm.
Builtin,
}
/// A tool name qualified by the key of the server that owns it.
///
/// The textual form is `server_key:tool_name` (see the `Display` impl and
/// [`QualifiedToolName::parse`]). Both parts are stored as `Arc<str>`, so
/// cloning a name only bumps reference counts.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QualifiedToolName {
    /// Key of the owning server (the part before the first `:`).
    server_key: Arc<str>,
    /// Name of the tool (the part after the first `:`).
    tool_name: Arc<str>,
}

impl QualifiedToolName {
    /// Builds a qualified name from its two components.
    ///
    /// The inputs are copied verbatim; no validation is performed, so empty
    /// strings and strings containing `:` are accepted.
    pub fn new(server_key: impl AsRef<str>, tool_name: impl AsRef<str>) -> Self {
        let server_key: Arc<str> = server_key.as_ref().into();
        let tool_name: Arc<str> = tool_name.as_ref().into();
        Self {
            server_key,
            tool_name,
        }
    }

    /// Parses the `server:tool` textual form.
    ///
    /// The split happens at the *first* `:`: everything before it becomes the
    /// server key and everything after it (including further colons) becomes
    /// the tool name. Returns `None` when `s` contains no `:` at all.
    pub fn parse(s: &str) -> Option<Self> {
        s.split_once(':')
            .map(|(server, tool)| Self::new(server, tool))
    }

    /// Returns the server-key component.
    pub fn server_key(&self) -> &str {
        self.server_key.as_ref()
    }

    /// Returns the tool-name component.
    pub fn tool_name(&self) -> &str {
        self.tool_name.as_ref()
    }
}

impl fmt::Display for QualifiedToolName {
    /// Writes the name as `server_key:tool_name`.
    ///
    /// Width, fill and precision flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.server_key)?;
        f.write_str(":")?;
        f.write_str(&self.tool_name)
    }
}
impl Serialize for QualifiedToolName {
    /// Serializes the name as a single string in its `server:tool` display form.
    ///
    /// `collect_str` hands the `Display` impl straight to the serializer, so
    /// serializers that can stream formatted output skip the temporary
    /// `String` that `serialize_str(&self.to_string())` would always allocate.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.collect_str(self)
    }
}
impl<'de> Deserialize<'de> for QualifiedToolName {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Self::parse(&s).ok_or_else(|| serde::de::Error::custom("expected format: server:tool"))
}
}
/// Target of an aliased tool: the qualified name it refers to plus an optional
/// argument mapping.
///
/// Attached to a `ToolEntry` through `ToolEntry::with_alias`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AliasTarget {
/// Qualified name of the tool this alias refers to.
pub target: QualifiedToolName,
/// Optional argument mapping; `None` when built with `AliasTarget::new`.
// NOTE(review): presumably applied when forwarding a call to `target` — the
// code that consumes it is not in this file; confirm.
pub arg_mapping: Option<ArgMapping>,
}
impl AliasTarget {
    /// Creates an alias target for `target` with no argument mapping.
    pub fn new(target: QualifiedToolName) -> Self {
        Self {
            arg_mapping: None,
            target,
        }
    }

    /// Returns this alias target with `mapping` installed, replacing any
    /// mapping that was set before.
    #[must_use]
    pub fn with_arg_mapping(self, mapping: ArgMapping) -> Self {
        Self {
            arg_mapping: Some(mapping),
            ..self
        }
    }
}
/// Collection of argument adjustments: renames, defaults and overrides.
///
/// This type only stores the entries (in insertion order); the code that
/// applies them to a tool call is not part of this file. The default value
/// has all three lists empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ArgMapping {
/// `(from, to)` argument-name pairs, appended by `with_rename`.
pub renames: Vec<(String, String)>,
/// `(name, value)` pairs appended by `with_default`.
// NOTE(review): presumably used only when the caller omits the argument — confirm.
pub defaults: Vec<(String, serde_json::Value)>,
/// `(name, value)` pairs appended by `with_override`.
// NOTE(review): presumably replaces any caller-supplied value — confirm.
pub overrides: Vec<(String, serde_json::Value)>,
}
impl ArgMapping {
    /// Creates an empty mapping (same as `ArgMapping::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a `(from, to)` rename pair and returns the mapping.
    ///
    /// Existing entries are kept; duplicates are not checked.
    #[must_use]
    pub fn with_rename(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
        let rename = (from.into(), to.into());
        self.renames.push(rename);
        self
    }

    /// Appends a `(name, value)` default entry and returns the mapping.
    ///
    /// Existing entries are kept; duplicates are not checked.
    #[must_use]
    pub fn with_default(mut self, name: impl Into<String>, value: serde_json::Value) -> Self {
        let default_entry = (name.into(), value);
        self.defaults.push(default_entry);
        self
    }

    /// Appends a `(name, value)` override entry and returns the mapping.
    ///
    /// Existing entries are kept; duplicates are not checked.
    #[must_use]
    pub fn with_override(mut self, name: impl Into<String>, value: serde_json::Value) -> Self {
        let override_entry = (name.into(), value);
        self.overrides.push(override_entry);
        self
    }
}
/// A tool together with the metadata this module tracks for it: qualified
/// name, category, annotations, optional tenant/alias/argument mapping,
/// creation timestamp with optional TTL, and response format.
///
/// Build one with `ToolEntry::new` or `ToolEntry::from_server_tool` and refine
/// it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ToolEntry {
/// `server:tool` identity of this entry.
pub qualified_name: QualifiedToolName,
/// The tool definition itself.
pub tool: Tool,
/// Category tag; `ToolCategory::Static` by default, `Alias` after `with_alias`.
pub category: ToolCategory,
/// Annotations for the tool; `ToolAnnotations::default()` unless set.
pub annotations: ToolAnnotations,
/// Optional tenant identifier; `None` unless set via `with_tenant`.
// NOTE(review): presumably scopes the entry to a single tenant — confirm against callers.
pub tenant_id: Option<TenantId>,
/// Alias target, present once `with_alias` has been called.
pub alias_target: Option<AliasTarget>,
/// Entry-level argument mapping, set via `with_arg_mapping`.
// NOTE(review): this is separate from `AliasTarget::arg_mapping`; how the two
// interact is not visible in this file — confirm.
pub arg_mapping: Option<ArgMapping>,
/// Timestamp taken with `Instant::now()` when the entry was constructed.
pub cached_at: Instant,
/// Optional time-to-live measured from `cached_at`; `None` means `is_expired`
/// always returns `false`.
pub ttl: Option<Duration>,
/// Response format for the tool; `ResponseFormat::default()` unless set.
pub response_format: ResponseFormat,
}
impl ToolEntry {
    /// Creates an entry for `tool` under `qualified_name` with default metadata.
    ///
    /// The category, annotations and response format take their `Default`
    /// values, all optional fields are `None`, and `cached_at` is set to the
    /// current instant.
    pub fn new(qualified_name: QualifiedToolName, tool: Tool) -> Self {
        let cached_at = Instant::now();
        Self {
            qualified_name,
            tool,
            category: ToolCategory::default(),
            annotations: ToolAnnotations::default(),
            response_format: ResponseFormat::default(),
            cached_at,
            ttl: None,
            tenant_id: None,
            alias_target: None,
            arg_mapping: None,
        }
    }

    /// Creates an entry whose qualified name is `server_key` plus the tool's
    /// own name.
    pub fn from_server_tool(server_key: impl AsRef<str>, tool: Tool) -> Self {
        let qualified_name = QualifiedToolName::new(server_key, tool.name.to_string());
        Self::new(qualified_name, tool)
    }

    /// Returns the entry with its category replaced by `category`.
    #[must_use]
    pub fn with_category(self, category: ToolCategory) -> Self {
        Self { category, ..self }
    }

    /// Returns the entry with its annotations replaced by `annotations`.
    #[must_use]
    pub fn with_annotations(self, annotations: ToolAnnotations) -> Self {
        Self {
            annotations,
            ..self
        }
    }

    /// Returns the entry associated with `tenant_id`.
    #[must_use]
    pub fn with_tenant(self, tenant_id: TenantId) -> Self {
        Self {
            tenant_id: Some(tenant_id),
            ..self
        }
    }

    /// Returns the entry marked as an alias of `target`.
    ///
    /// Besides storing the target this also forces the category to
    /// `ToolCategory::Alias`; a later `with_category` call can change it again.
    #[must_use]
    pub fn with_alias(self, target: AliasTarget) -> Self {
        Self {
            alias_target: Some(target),
            category: ToolCategory::Alias,
            ..self
        }
    }

    /// Returns the entry with its entry-level argument mapping set to `mapping`.
    #[must_use]
    pub fn with_arg_mapping(self, mapping: ArgMapping) -> Self {
        Self {
            arg_mapping: Some(mapping),
            ..self
        }
    }

    /// Returns the entry with a time-to-live of `ttl`, measured from `cached_at`.
    #[must_use]
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self {
            ttl: Some(ttl),
            ..self
        }
    }

    /// Returns the entry with its response format replaced by `response_format`.
    #[must_use]
    pub fn with_response_format(self, response_format: ResponseFormat) -> Self {
        Self {
            response_format,
            ..self
        }
    }

    /// Reports whether the entry has outlived its TTL.
    ///
    /// An entry without a TTL never expires. With a TTL, the entry is expired
    /// once strictly more than `ttl` has elapsed since `cached_at`.
    pub fn is_expired(&self) -> bool {
        match self.ttl {
            Some(ttl) => self.cached_at.elapsed() > ttl,
            None => false,
        }
    }

    /// Server-key component of the entry's qualified name.
    pub fn server_key(&self) -> &str {
        self.qualified_name.server_key()
    }

    /// Tool-name component of the entry's qualified name.
    pub fn tool_name(&self) -> &str {
        self.qualified_name.tool_name()
    }
}
#[cfg(test)]
mod tests {
    use std::{borrow::Cow, sync::Arc};

    use super::*;

    /// Builds a minimal `Tool` named `name` with an empty object input schema.
    fn create_test_tool(name: &str) -> Tool {
        let schema_obj = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let schema_map = if let serde_json::Value::Object(m) = schema_obj {
            m
        } else {
            serde_json::Map::new()
        };
        Tool {
            name: Cow::Owned(name.to_string()),
            title: None,
            description: Some(Cow::Owned(format!("Test tool: {name}"))),
            input_schema: Arc::new(schema_map),
            output_schema: None,
            annotations: None,
            icons: None,
        }
    }

    #[test]
    fn test_qualified_tool_name() {
        let name = QualifiedToolName::new("server", "tool");
        assert_eq!(name.server_key(), "server");
        assert_eq!(name.tool_name(), "tool");
        assert_eq!(format!("{name}"), "server:tool");
    }

    #[test]
    fn test_qualified_tool_name_parse() {
        let parsed = QualifiedToolName::parse("brave:web_search").unwrap();
        assert_eq!(parsed.server_key(), "brave");
        assert_eq!(parsed.tool_name(), "web_search");
        assert!(QualifiedToolName::parse("no_colon").is_none());
    }

    /// The serde impls must agree with `Display`/`parse`: a name serializes to
    /// its `server:tool` string and deserializes back to an equal value.
    #[test]
    fn test_qualified_tool_name_serde_round_trip() {
        let name = QualifiedToolName::new("brave", "web_search");
        let json = serde_json::to_string(&name).unwrap();
        assert_eq!(json, "\"brave:web_search\"");
        let back: QualifiedToolName = serde_json::from_str(&json).unwrap();
        assert_eq!(back, name);
        assert!(serde_json::from_str::<QualifiedToolName>("\"no_colon\"").is_err());
    }

    #[test]
    fn test_tool_entry_creation() {
        let tool = create_test_tool("my_tool");
        let entry = ToolEntry::from_server_tool("my_server", tool);
        assert_eq!(entry.server_key(), "my_server");
        assert_eq!(entry.tool_name(), "my_tool");
        assert_eq!(entry.category, ToolCategory::Static);
        assert!(!entry.is_expired());
    }

    #[test]
    fn test_tool_entry_with_alias() {
        let tool = create_test_tool("web_search");
        let target = AliasTarget::new(QualifiedToolName::new("brave", "brave_web_search"));
        let entry = ToolEntry::from_server_tool("aliases", tool).with_alias(target);
        assert_eq!(entry.category, ToolCategory::Alias);
        assert!(entry.alias_target.is_some());
        assert_eq!(
            entry.alias_target.unwrap().target.tool_name(),
            "brave_web_search"
        );
    }

    #[test]
    fn test_tool_entry_expiration() {
        // The "still fresh" check uses a long TTL on its own entry. Checking it
        // against a 1ms TTL is racy: on a loaded machine more than 1ms can pass
        // between construction and the assertion, making the test flaky.
        let fresh = ToolEntry::from_server_tool("server", create_test_tool("fresh_tool"))
            .with_ttl(Duration::from_secs(3600));
        assert!(!fresh.is_expired());

        // The sleep is a lower bound, so at least 5ms > 1ms has elapsed here.
        let stale = ToolEntry::from_server_tool("server", create_test_tool("expiring_tool"))
            .with_ttl(Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(5));
        assert!(stale.is_expired());
    }

    #[test]
    fn test_arg_mapping() {
        let mapping = ArgMapping::new()
            .with_rename("query", "search_query")
            .with_default("limit", serde_json::json!(10))
            .with_override("safe_search", serde_json::json!(true));
        assert_eq!(mapping.renames.len(), 1);
        assert_eq!(mapping.defaults.len(), 1);
        assert_eq!(mapping.overrides.len(), 1);
    }
}