use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::Result;
/// Coarse classification of what a [`Tool`] does.
///
/// Serialized by serde in snake_case (e.g. `NetworkWrite` <-> `"network_write"`);
/// the `Display` impl for this type writes the same snake_case names.
/// A tool that does not override [`Tool::category`] reports [`ToolCategory::Shell`].
// NOTE(review): how each category is gated or enforced is decided by code outside
// this file — confirm there before relying on a particular meaning for a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCategory {
FilesystemRead,
FilesystemWrite,
NetworkRead,
NetworkWrite,
Shell,
Hardware,
Memory,
Messaging,
Destructive,
}
impl ToolCategory {
    /// Every category, in the same order as the enum declaration.
    ///
    /// The compiler does not check this list against the enum: when a variant is
    /// added it must be appended here by hand and the array length bumped.
    pub const ALL: [ToolCategory; 9] = [
        ToolCategory::FilesystemRead,
        ToolCategory::FilesystemWrite,
        ToolCategory::NetworkRead,
        ToolCategory::NetworkWrite,
        ToolCategory::Shell,
        ToolCategory::Hardware,
        ToolCategory::Memory,
        ToolCategory::Messaging,
        ToolCategory::Destructive,
    ];

    /// Returns every category, in declaration order (a copy of [`Self::ALL`]).
    ///
    /// `const` so it can also be evaluated in constant contexts.
    #[must_use]
    pub const fn all() -> [ToolCategory; 9] {
        Self::ALL
    }
}
impl std::fmt::Display for ToolCategory {
    /// Writes the category's snake_case name (e.g. `filesystem_read`).
    ///
    /// The name is emitted through [`std::fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags behave as they do for a plain `str`
    /// (`format!("{:>8}", ToolCategory::Shell)` yields `"   shell"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::FilesystemRead => "filesystem_read",
            Self::FilesystemWrite => "filesystem_write",
            Self::NetworkRead => "network_read",
            Self::NetworkWrite => "network_write",
            Self::Shell => "shell",
            Self::Hardware => "hardware",
            Self::Memory => "memory",
            Self::Messaging => "messaging",
            Self::Destructive => "destructive",
        };
        // `write!(f, "...")` would ignore padding flags; `pad` honours them.
        f.pad(name)
    }
}
/// Value returned by a tool's `execute` on the `Ok` path.
///
/// Carries separate text for the LLM and, optionally, for the user, plus status flags.
/// Normally built with the constructors on `ToolOutput` (`llm_only`, `user_visible`,
/// `error`, `async_task`, `split`) and optionally `with_pause`.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolOutput {
/// Text handed back to the LLM; every constructor sets it.
pub for_llm: String,
/// Text for the user, if any. `None` from `llm_only`, `error` and `async_task`;
/// `Some` from `user_visible` (same text as `for_llm`) and `split` (separate text).
pub for_user: Option<String>,
/// `true` only when built via `ToolOutput::error`.
pub is_error: bool,
/// `true` only when built via `ToolOutput::async_task`.
// NOTE(review): presumably signals work continuing in the background — confirm with the consumer.
pub is_async: bool,
/// `false` from every constructor; set to `true` by `ToolOutput::with_pause`.
// NOTE(review): presumably asks the caller to wait for user input — confirm with the consumer.
pub pause_for_input: bool,
}
impl ToolOutput {
    /// Shared constructor: a non-error, non-async output that does not pause for input.
    fn base(for_llm: String, for_user: Option<String>) -> Self {
        Self {
            for_llm,
            for_user,
            is_error: false,
            is_async: false,
            pause_for_input: false,
        }
    }

    /// Output seen only by the LLM; `for_user` is `None`.
    #[must_use]
    pub fn llm_only(content: impl Into<String>) -> Self {
        Self::base(content.into(), None)
    }

    /// Output whose text goes to both the LLM and the user (identical in both fields).
    #[must_use]
    pub fn user_visible(content: impl Into<String>) -> Self {
        let text = content.into();
        // Both fields own the text, so one copy is unavoidable.
        Self::base(text.clone(), Some(text))
    }

    /// Error output: `is_error` is `true`, the text goes to the LLM only.
    #[must_use]
    pub fn error(content: impl Into<String>) -> Self {
        Self {
            is_error: true,
            ..Self::base(content.into(), None)
        }
    }

    /// Output flagged with `is_async == true`; the text goes to the LLM only.
    #[must_use]
    pub fn async_task(content: impl Into<String>) -> Self {
        Self {
            is_async: true,
            ..Self::base(content.into(), None)
        }
    }

    /// Output with distinct texts for the LLM (`for_llm`) and the user (`for_user`).
    #[must_use]
    pub fn split(for_llm: impl Into<String>, for_user: impl Into<String>) -> Self {
        Self::base(for_llm.into(), Some(for_user.into()))
    }

    /// Builder step: consumes `self`, sets `pause_for_input`, and returns the result.
    ///
    /// All other fields are left untouched.
    #[must_use = "with_pause consumes the output and returns the modified one"]
    pub fn with_pause(mut self) -> Self {
        self.pause_for_input = true;
        self
    }
}
/// An invocable tool.
///
/// Implementors must be `Send + Sync`. `execute` is an async method, made object-safe
/// through the `#[async_trait]` attribute.
#[async_trait]
pub trait Tool: Send + Sync {
/// The tool's identifier.
// NOTE(review): presumably unique within a registry and the name the LLM calls — confirm at call sites.
fn name(&self) -> &str;
/// Full human-readable description of the tool.
fn description(&self) -> &str;
/// Description of the accepted arguments, as a JSON value.
// NOTE(review): likely a JSON Schema object — confirm against the code that consumes it.
fn parameters(&self) -> Value;
/// Runs the tool with JSON `args` in the invocation context `ctx`.
///
/// Failure can be reported either as `Err` (the crate's error type) or as
/// `Ok(ToolOutput)` with `is_error` set; how callers treat each case is defined
/// outside this file.
async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput>;
/// Shorter description; defaults to returning [`Tool::description`] unchanged.
fn compact_description(&self) -> &str {
self.description()
}
/// The tool's category; defaults to [`ToolCategory::Shell`] unless overridden.
fn category(&self) -> ToolCategory {
ToolCategory::Shell
}
}
/// Per-invocation context passed to a tool's `execute`.
///
/// Every field starts unset (`None` / `false`); populate it with the consuming
/// `with_*` builder methods.
#[derive(Debug, Clone, Default)]
pub struct ToolContext {
    /// Channel the request arrived on (the tests use values like `"telegram"`), if known.
    pub channel: Option<String>,
    /// Chat/conversation id within `channel`; set together with `channel` by `with_channel`.
    pub chat_id: Option<String>,
    /// Workspace path (as a string), if any.
    pub workspace: Option<String>,
    /// Whether the invocation is part of a batch; `false` by default.
    pub is_batch: bool,
}

impl ToolContext {
    /// Creates an empty context; identical to `ToolContext::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets both `channel` and `chat_id`, returning the updated context.
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_channel(mut self, channel: &str, chat_id: &str) -> Self {
        self.channel = Some(channel.to_string());
        self.chat_id = Some(chat_id.to_string());
        self
    }

    /// Sets `workspace`, returning the updated context.
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_workspace(mut self, workspace: &str) -> Self {
        self.workspace = Some(workspace.to_string());
        self
    }

    /// Sets `is_batch`, returning the updated context.
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_batch(mut self, is_batch: bool) -> Self {
        self.is_batch = is_batch;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_context_new() {
        let ctx = ToolContext::new();
        assert!(ctx.channel.is_none());
        assert!(ctx.chat_id.is_none());
        assert!(ctx.workspace.is_none());
    }

    #[test]
    fn test_tool_context_default() {
        let ctx = ToolContext::default();
        assert!(ctx.channel.is_none());
        assert!(ctx.chat_id.is_none());
        assert!(ctx.workspace.is_none());
    }

    #[test]
    fn test_tool_context_is_batch_default() {
        let ctx = ToolContext::new();
        assert!(!ctx.is_batch);
    }

    #[test]
    fn test_tool_context_with_batch() {
        let ctx = ToolContext::new().with_batch(true);
        assert!(ctx.is_batch);
    }

    #[test]
    fn test_tool_context_with_channel() {
        let ctx = ToolContext::new().with_channel("telegram", "123456");
        assert_eq!(ctx.channel.as_deref(), Some("telegram"));
        assert_eq!(ctx.chat_id.as_deref(), Some("123456"));
        assert!(ctx.workspace.is_none());
    }

    #[test]
    fn test_tool_context_with_workspace() {
        let ctx = ToolContext::new().with_workspace("/home/user/project");
        assert!(ctx.channel.is_none());
        assert!(ctx.chat_id.is_none());
        assert_eq!(ctx.workspace.as_deref(), Some("/home/user/project"));
    }

    #[test]
    fn test_tool_context_builder_chain() {
        let ctx = ToolContext::new()
            .with_channel("discord", "abc123")
            .with_workspace("/tmp/workspace");
        assert_eq!(ctx.channel.as_deref(), Some("discord"));
        assert_eq!(ctx.chat_id.as_deref(), Some("abc123"));
        assert_eq!(ctx.workspace.as_deref(), Some("/tmp/workspace"));
    }

    #[test]
    fn test_tool_context_debug() {
        let ctx = ToolContext::new().with_channel("cli", "test");
        let debug_str = format!("{:?}", ctx);
        assert!(debug_str.contains("ToolContext"));
        assert!(debug_str.contains("cli"));
    }

    #[test]
    fn test_tool_context_clone() {
        let ctx1 = ToolContext::new()
            .with_channel("telegram", "123")
            .with_workspace("/test");
        let ctx2 = ctx1.clone();
        assert_eq!(ctx1.channel, ctx2.channel);
        assert_eq!(ctx1.chat_id, ctx2.chat_id);
        assert_eq!(ctx1.workspace, ctx2.workspace);
    }

    #[test]
    fn test_tool_category_display() {
        assert_eq!(ToolCategory::FilesystemRead.to_string(), "filesystem_read");
        assert_eq!(ToolCategory::Shell.to_string(), "shell");
        assert_eq!(ToolCategory::Hardware.to_string(), "hardware");
        assert_eq!(ToolCategory::Destructive.to_string(), "destructive");
    }

    #[test]
    fn test_tool_category_serde_roundtrip() {
        let cat = ToolCategory::NetworkWrite;
        let json = serde_json::to_string(&cat).unwrap();
        assert_eq!(json, "\"network_write\"");
        let back: ToolCategory = serde_json::from_str(&json).unwrap();
        assert_eq!(back, cat);
    }

    /// `ToolCategory::all()` itself must list nine distinct categories; previously this
    /// test built its own list, so `all()` was never exercised.
    #[test]
    fn test_tool_category_all_variants() {
        use std::collections::HashSet;
        let all = ToolCategory::all();
        let set: HashSet<_> = all.iter().collect();
        assert_eq!(set.len(), all.len(), "ToolCategory::all() contains duplicates");
        assert_eq!(set.len(), 9);
    }

    /// For every category, the `Display` name must equal the serde name, and the
    /// serde form must round-trip.
    #[test]
    fn test_tool_category_display_matches_serde_for_all() {
        for &cat in &ToolCategory::all() {
            let json = serde_json::to_string(&cat).unwrap();
            assert_eq!(
                json,
                format!("\"{}\"", cat),
                "Display and serde disagree for {:?}",
                cat
            );
            let back: ToolCategory = serde_json::from_str(&json).unwrap();
            assert_eq!(back, cat);
        }
    }

    #[test]
    fn test_tool_default_category() {
        let tool = super::super::EchoTool;
        assert_eq!(tool.category(), ToolCategory::Shell);
    }

    #[test]
    fn test_tool_output_llm_only() {
        let out = ToolOutput::llm_only("internal");
        assert_eq!(out.for_llm, "internal");
        assert!(out.for_user.is_none());
        assert!(!out.is_error);
        assert!(!out.is_async);
    }

    #[test]
    fn test_tool_output_user_visible() {
        let out = ToolOutput::user_visible("hello");
        assert_eq!(out.for_llm, "hello");
        assert_eq!(out.for_user.as_deref(), Some("hello"));
    }

    #[test]
    fn test_tool_output_error() {
        let out = ToolOutput::error("something broke");
        assert!(out.is_error);
        assert!(out.for_user.is_none());
    }

    #[test]
    fn test_tool_output_async_task() {
        let out = ToolOutput::async_task("running in background");
        assert!(out.is_async);
        assert!(!out.is_error);
    }

    #[test]
    fn test_tool_output_split() {
        let out = ToolOutput::split("llm sees this", "user sees that");
        assert_eq!(out.for_llm, "llm sees this");
        assert_eq!(out.for_user.as_deref(), Some("user sees that"));
    }

    #[test]
    fn test_tool_output_default_pause_false() {
        let out = ToolOutput::llm_only("test");
        assert!(!out.pause_for_input);
        let out2 = ToolOutput::user_visible("test");
        assert!(!out2.pause_for_input);
        let out3 = ToolOutput::error("test");
        assert!(!out3.pause_for_input);
        let out4 = ToolOutput::split("a", "b");
        assert!(!out4.pause_for_input);
        let out5 = ToolOutput::async_task("test");
        assert!(!out5.pause_for_input);
    }

    #[test]
    fn test_tool_output_with_pause() {
        let out = ToolOutput::llm_only("test").with_pause();
        assert!(out.pause_for_input);
        assert_eq!(out.for_llm, "test");
    }

    #[test]
    fn test_tool_output_split_with_pause() {
        let out = ToolOutput::split("llm", "user").with_pause();
        assert!(out.pause_for_input);
        assert_eq!(out.for_user.as_deref(), Some("user"));
    }
}