use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
/// Kind of a tool, from which [`ToolKind::is_read_only`] and
/// [`ToolKind::mutation_class`] are derived.
///
/// Serde encodes each variant as its snake_case name (`"read"`, `"edit"`,
/// `"delete"`, `"move"`, `"search"`, `"execute"`, `"think"`, `"fetch"`,
/// `"other"`). `Other` is the `Default`.
///
/// Variant order is kept stable on purpose: it fixes the discriminants and the
/// variant indices used by non-self-describing serde formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolKind {
    /// Mutation class `"read_only"`.
    Read,
    /// Mutation class `"workspace_write"`.
    Edit,
    /// Mutation class `"destructive"`.
    Delete,
    /// Mutation class `"destructive"`.
    Move,
    /// Mutation class `"read_only"`.
    Search,
    /// Mutation class `"ambient_side_effect"`.
    Execute,
    /// Mutation class `"read_only"`.
    Think,
    /// Mutation class `"read_only"`.
    Fetch,
    /// Fallback kind; mutation class `"other"`, not read-only.
    #[default]
    Other,
}
impl ToolKind {
    /// Returns `true` when this kind's mutation class is `"read_only"`.
    ///
    /// That is the case for `Read`, `Search`, `Think` and `Fetch`. Every other
    /// kind returns `false`, including the default `Other`.
    pub fn is_read_only(&self) -> bool {
        // Derived from `mutation_class` so the read-only set is listed once.
        self.mutation_class() == "read_only"
    }

    /// Returns the mutation-class label for this kind:
    ///
    /// - `"read_only"`: `Read`, `Search`, `Think`, `Fetch`
    /// - `"workspace_write"`: `Edit`
    /// - `"destructive"`: `Delete`, `Move`
    /// - `"ambient_side_effect"`: `Execute`
    /// - `"other"`: `Other`
    pub fn mutation_class(&self) -> &'static str {
        match *self {
            Self::Edit => "workspace_write",
            Self::Delete | Self::Move => "destructive",
            Self::Execute => "ambient_side_effect",
            Self::Other => "other",
            Self::Read | Self::Search | Self::Think | Self::Fetch => "read_only",
        }
    }
}
/// Side-effect level of a tool.
///
/// Serde encodes each variant as its snake_case name (`"none"`, `"read_only"`,
/// `"workspace_write"`, `"process_exec"`, `"network"`). `None` is the `Default`.
///
/// Variants are declared in the same order as their [`SideEffectLevel::rank`]
/// values (0 through 4). The derived `PartialOrd`/`Ord` therefore agree with
/// `rank`, so levels can be compared directly (`a < b`) or combined with
/// `Iterator::max`. Keep the declaration order in sync with `rank` when adding
/// variants.
#[derive(
    Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum SideEffectLevel {
    /// Rank 0; the default.
    #[default]
    None,
    /// Rank 1.
    ReadOnly,
    /// Rank 2.
    WorkspaceWrite,
    /// Rank 3.
    ProcessExec,
    /// Rank 4.
    Network,
}
impl SideEffectLevel {
    /// Returns the numeric rank of this level: `None` = 0, `ReadOnly` = 1,
    /// `WorkspaceWrite` = 2, `ProcessExec` = 3, `Network` = 4.
    ///
    /// Presumably a higher rank means broader side effects; the ordering itself
    /// is all this method guarantees.
    pub fn rank(&self) -> usize {
        match self {
            Self::None => 0,
            Self::ReadOnly => 1,
            Self::WorkspaceWrite => 2,
            Self::ProcessExec => 3,
            Self::Network => 4,
        }
    }

    /// Returns the canonical snake_case name of this level.
    ///
    /// This is the inverse of [`SideEffectLevel::try_parse`].
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::ReadOnly => "read_only",
            Self::WorkspaceWrite => "workspace_write",
            Self::ProcessExec => "process_exec",
            Self::Network => "network",
        }
    }

    /// Strictly parses a canonical level name, as produced by
    /// [`SideEffectLevel::as_str`].
    ///
    /// Matching is exact and case-sensitive. Any other input, including the
    /// empty string and misspellings, returns `Option::None`. This lets callers
    /// reject bad values instead of silently downgrading them.
    pub fn try_parse(value: &str) -> Option<Self> {
        let level = match value {
            "none" => Self::None,
            "read_only" => Self::ReadOnly,
            "workspace_write" => Self::WorkspaceWrite,
            "process_exec" => Self::ProcessExec,
            "network" => Self::Network,
            // Spelled out to avoid confusion with the `Self::None` variant.
            _ => return Option::None,
        };
        Some(level)
    }

    /// Leniently parses a level name. Any unrecognised input falls back to
    /// `SideEffectLevel::None`.
    ///
    /// NOTE(review): this fails open. A typo such as `"netowrk"` becomes the
    /// lowest level rather than an error. Use [`SideEffectLevel::try_parse`]
    /// wherever an unknown value should be rejected. The fallback is kept here
    /// for backward compatibility.
    pub fn parse(value: &str) -> Self {
        Self::try_parse(value).unwrap_or(Self::None)
    }
}
/// Argument metadata for a tool.
///
/// Every field defaults to empty. Because of `#[serde(default)]`, any field
/// missing from serialized input is filled from `Default`. Field order is left
/// unchanged because it determines serialized field order.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ToolArgSchema {
    /// Argument names that presumably carry filesystem paths.
    /// NOTE(review): confirm against callers.
    pub path_params: Vec<String>,
    /// Mapping between argument names.
    /// NOTE(review): the direction (alias -> canonical?) is not established in
    /// this file; confirm against callers.
    pub arg_aliases: BTreeMap<String, String>,
    /// Argument names presumably required by the tool.
    /// NOTE(review): confirm how callers enforce this.
    pub required: Vec<String>,
}
/// Annotation bundle for a tool: its [`ToolKind`], declared
/// [`SideEffectLevel`], argument schema and capability map.
///
/// Because of `#[serde(default)]`, missing fields deserialize to their
/// defaults:
///
/// - `kind`: `ToolKind::Other`
/// - `side_effect_level`: `SideEffectLevel::None`
/// - `arg_schema`: an empty schema
/// - `capabilities`: an empty map
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ToolAnnotations {
    /// Kind of tool; defaults to `ToolKind::Other`.
    pub kind: ToolKind,
    /// Declared side-effect level; defaults to `SideEffectLevel::None`.
    pub side_effect_level: SideEffectLevel,
    /// Argument metadata.
    pub arg_schema: ToolArgSchema,
    /// Capability lists keyed by name.
    /// NOTE(review): key and value semantics are not defined in this file;
    /// confirm against callers.
    pub capabilities: BTreeMap<String, Vec<String>>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ToolKind` variant, in declaration order.
    const ALL_KINDS: [ToolKind; 9] = [
        ToolKind::Read,
        ToolKind::Edit,
        ToolKind::Delete,
        ToolKind::Move,
        ToolKind::Search,
        ToolKind::Execute,
        ToolKind::Think,
        ToolKind::Fetch,
        ToolKind::Other,
    ];

    /// Every `SideEffectLevel` variant, in declaration order.
    const ALL_LEVELS: [SideEffectLevel; 5] = [
        SideEffectLevel::None,
        SideEffectLevel::ReadOnly,
        SideEffectLevel::WorkspaceWrite,
        SideEffectLevel::ProcessExec,
        SideEffectLevel::Network,
    ];

    #[test]
    fn tool_kind_serde_roundtrip() {
        for (kind, expected) in [
            (ToolKind::Read, "\"read\""),
            (ToolKind::Edit, "\"edit\""),
            (ToolKind::Delete, "\"delete\""),
            (ToolKind::Move, "\"move\""),
            (ToolKind::Search, "\"search\""),
            (ToolKind::Execute, "\"execute\""),
            (ToolKind::Think, "\"think\""),
            (ToolKind::Fetch, "\"fetch\""),
            (ToolKind::Other, "\"other\""),
        ] {
            let encoded = serde_json::to_string(&kind).unwrap();
            assert_eq!(encoded, expected);
            let decoded: ToolKind = serde_json::from_str(expected).unwrap();
            assert_eq!(decoded, kind);
        }
    }

    #[test]
    fn tool_kind_rejects_unknown_name() {
        // No `#[serde(other)]` fallback exists, so unknown names must error.
        assert!(serde_json::from_str::<ToolKind>("\"unknown\"").is_err());
    }

    #[test]
    fn only_read_search_think_fetch_are_read_only() {
        assert!(ToolKind::Read.is_read_only());
        assert!(ToolKind::Search.is_read_only());
        assert!(ToolKind::Think.is_read_only());
        assert!(ToolKind::Fetch.is_read_only());
        assert!(!ToolKind::Other.is_read_only());
        assert!(!ToolKind::Edit.is_read_only());
        assert!(!ToolKind::Delete.is_read_only());
        assert!(!ToolKind::Move.is_read_only());
        assert!(!ToolKind::Execute.is_read_only());
    }

    #[test]
    fn mutation_class_derived_from_kind() {
        assert_eq!(ToolKind::Read.mutation_class(), "read_only");
        assert_eq!(ToolKind::Search.mutation_class(), "read_only");
        assert_eq!(ToolKind::Think.mutation_class(), "read_only");
        assert_eq!(ToolKind::Fetch.mutation_class(), "read_only");
        assert_eq!(ToolKind::Edit.mutation_class(), "workspace_write");
        assert_eq!(ToolKind::Delete.mutation_class(), "destructive");
        assert_eq!(ToolKind::Move.mutation_class(), "destructive");
        assert_eq!(ToolKind::Execute.mutation_class(), "ambient_side_effect");
        assert_eq!(ToolKind::Other.mutation_class(), "other");
    }

    #[test]
    fn is_read_only_agrees_with_mutation_class() {
        for kind in ALL_KINDS {
            assert_eq!(
                kind.is_read_only(),
                kind.mutation_class() == "read_only",
                "{kind:?}"
            );
        }
    }

    #[test]
    fn side_effect_level_round_trip() {
        for level in ALL_LEVELS {
            assert_eq!(SideEffectLevel::parse(level.as_str()), level);
            let encoded = serde_json::to_string(&level).unwrap();
            let decoded: SideEffectLevel = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, level);
        }
    }

    #[test]
    fn side_effect_level_as_str_matches_serde_name() {
        for level in ALL_LEVELS {
            let encoded = serde_json::to_string(&level).unwrap();
            assert_eq!(encoded, format!("\"{}\"", level.as_str()));
        }
    }

    #[test]
    fn side_effect_level_parse_unknown_falls_back_to_none() {
        assert_eq!(SideEffectLevel::parse(""), SideEffectLevel::None);
        assert_eq!(SideEffectLevel::parse("netowrk"), SideEffectLevel::None);
    }

    #[test]
    fn side_effect_level_rank_orders() {
        assert!(SideEffectLevel::None.rank() < SideEffectLevel::ReadOnly.rank());
        assert!(SideEffectLevel::ReadOnly.rank() < SideEffectLevel::WorkspaceWrite.rank());
        assert!(SideEffectLevel::WorkspaceWrite.rank() < SideEffectLevel::ProcessExec.rank());
        assert!(SideEffectLevel::ProcessExec.rank() < SideEffectLevel::Network.rank());
    }

    #[test]
    fn arg_schema_defaults_empty() {
        let schema = ToolArgSchema::default();
        assert!(schema.path_params.is_empty());
        assert!(schema.arg_aliases.is_empty());
        assert!(schema.required.is_empty());
    }

    #[test]
    fn annotations_fill_missing_fields_with_defaults() {
        let empty: ToolAnnotations = serde_json::from_str("{}").unwrap();
        assert_eq!(empty, ToolAnnotations::default());
        assert_eq!(empty.kind, ToolKind::Other);
        assert_eq!(empty.side_effect_level, SideEffectLevel::None);

        let partial: ToolAnnotations =
            serde_json::from_str(r#"{"kind":"edit","arg_schema":{"required":["path"]}}"#)
                .unwrap();
        assert_eq!(partial.kind, ToolKind::Edit);
        assert_eq!(partial.side_effect_level, SideEffectLevel::None);
        assert_eq!(partial.arg_schema.required, vec!["path".to_string()]);
        assert!(partial.arg_schema.path_params.is_empty());
        assert!(partial.arg_schema.arg_aliases.is_empty());
        assert!(partial.capabilities.is_empty());
    }
}