pub mod schema_based;
pub mod simple;
pub mod types;
pub mod validation;
pub use schema_based::SchemaBasedTool;
pub use simple::__simple_async_trait;
pub use types::{ToolInput, ToolOutput};
pub use validation::{Format, ValidateArgs};
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use cognis_core::{CognisError, Result};
/// An invocable capability that can be registered in a [`ToolRegistry`].
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn Tool>` and dispatches to them by name.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier for the tool; [`ToolRegistry::register`] stores the tool under this name.
    fn name(&self) -> &str;

    /// Human-readable summary of what the tool does.
    fn description(&self) -> &str;

    /// JSON description of the accepted arguments, or `None` if the tool declares none.
    fn args_schema(&self) -> Option<serde_json::Value>;

    /// Executes the tool on `input` and produces its output or an error.
    async fn _run(&self, input: ToolInput) -> Result<ToolOutput>;

    /// Caller-facing flag; `false` unless overridden.
    ///
    /// NOTE(review): presumably signals that the tool's output should be handed
    /// back without further processing — confirm against callers.
    fn return_direct(&self) -> bool {
        false
    }
}
pub use Tool as BaseTool;
/// Owned, serializable snapshot of a tool's advertised metadata.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name under which the tool is exposed.
    pub name: String,
    /// Free-text description of what the tool does.
    pub description: String,
    /// Argument schema as raw JSON, if the tool declares one
    /// (presumably JSON Schema — confirm with tool implementors).
    pub parameters: Option<serde_json::Value>,
}
impl ToolDefinition {
    /// Builds a definition by copying `t`'s name and description and taking
    /// whatever `args_schema()` reports.
    pub fn from_tool(t: &dyn Tool) -> Self {
        let name = String::from(t.name());
        let description = String::from(t.description());
        let parameters = t.args_schema();
        Self {
            name,
            description,
            parameters,
        }
    }
}
/// Access predicate consulted with the calling agent's id.
type PermissionFn = dyn Fn(&str) -> bool + Send + Sync;

/// Per-name bookkeeping held by [`ToolRegistry`].
#[derive(Default)]
struct ToolEntry {
    /// The implementation. `register`/`register_alias` always set this;
    /// `None` only arises from `Default`.
    tool: Option<Arc<dyn Tool>>,
    /// Disabled entries stay registered but are hidden from dispatch and listings.
    enabled: bool,
    /// Number of dispatches that reached the implementation.
    calls: std::sync::atomic::AtomicUsize,
    /// Optional gate evaluated against the caller's agent id.
    permission: Option<Arc<PermissionFn>>,
}
impl Clone for ToolEntry {
    /// Duplicates the entry. `AtomicUsize` is not `Clone`, so the call counter
    /// is copied as a point-in-time snapshot and counts independently afterwards;
    /// the tool and permission predicate are shared via `Arc`.
    fn clone(&self) -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let snapshot = self.calls.load(Ordering::Relaxed);
        Self {
            tool: self.tool.clone(),
            enabled: self.enabled,
            calls: AtomicUsize::new(snapshot),
            permission: self.permission.clone(),
        }
    }
}
/// Name-indexed collection of tools with per-entry enable flags, call counters,
/// and optional per-agent permission predicates.
///
/// Cloning shares tools and predicates (they are behind `Arc`) and copies each
/// entry's call counter as a snapshot.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Registry key (tool name or alias) -> entry.
    entries: HashMap<String, ToolEntry>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, tool: Arc<dyn Tool>) {
let name = tool.name().to_string();
self.entries.insert(
name,
ToolEntry {
tool: Some(tool),
enabled: true,
calls: std::sync::atomic::AtomicUsize::new(0),
permission: None,
},
);
}
pub fn register_alias(&mut self, alias: impl Into<String>, name: &str) {
if let Some(t) = self.entries.get(name).and_then(|e| e.tool.clone()) {
self.entries.insert(
alias.into(),
ToolEntry {
tool: Some(t),
enabled: true,
calls: std::sync::atomic::AtomicUsize::new(0),
permission: None,
},
);
}
}
pub fn unregister(&mut self, name: &str) -> bool {
self.entries.remove(name).is_some()
}
pub fn retain<F>(&mut self, mut predicate: F) -> Vec<String>
where
F: FnMut(&str) -> bool,
{
let mut removed = Vec::new();
self.entries.retain(|k, _| {
let keep = predicate(k);
if !keep {
removed.push(k.clone());
}
keep
});
removed
}
pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
let e = self.entries.get(name)?;
if !e.enabled {
return None;
}
e.tool.as_ref()
}
pub fn contains(&self, name: &str) -> bool {
self.entries.contains_key(name)
}
pub fn is_enabled(&self, name: &str) -> bool {
self.entries.get(name).is_some_and(|e| e.enabled)
}
pub fn disable(&mut self, name: &str) -> bool {
match self.entries.get_mut(name) {
Some(e) => {
e.enabled = false;
true
}
None => false,
}
}
pub fn enable(&mut self, name: &str) -> bool {
match self.entries.get_mut(name) {
Some(e) => {
e.enabled = true;
true
}
None => false,
}
}
pub fn set_permission<F>(&mut self, name: &str, predicate: F) -> bool
where
F: Fn(&str) -> bool + Send + Sync + 'static,
{
match self.entries.get_mut(name) {
Some(e) => {
e.permission = Some(Arc::new(predicate));
true
}
None => false,
}
}
pub fn clear_permission(&mut self, name: &str) {
if let Some(e) = self.entries.get_mut(name) {
e.permission = None;
}
}
pub fn is_allowed(&self, name: &str, agent_id: &str) -> bool {
let Some(e) = self.entries.get(name) else {
return false;
};
if !e.enabled {
return false;
}
match &e.permission {
Some(p) => p(agent_id),
None => true,
}
}
pub fn call_count(&self, name: &str) -> usize {
self.entries
.get(name)
.map(|e| e.calls.load(std::sync::atomic::Ordering::Relaxed))
.unwrap_or(0)
}
pub fn tool_names(&self) -> Vec<&str> {
self.entries
.iter()
.filter(|(_, e)| e.enabled)
.map(|(k, _)| k.as_str())
.collect()
}
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.entries
.values()
.filter(|e| e.enabled)
.filter_map(|e| e.tool.as_ref())
.map(|t| ToolDefinition::from_tool(t.as_ref()))
.collect()
}
pub async fn execute(&self, name: &str, input: ToolInput) -> Result<ToolOutput> {
let entry = self.entries.get(name).ok_or_else(|| CognisError::Tool {
name: name.to_string(),
reason: "not registered".into(),
})?;
if !entry.enabled {
return Err(CognisError::Tool {
name: name.to_string(),
reason: "disabled".into(),
});
}
let t = entry.tool.as_ref().ok_or_else(|| CognisError::Tool {
name: name.to_string(),
reason: "no implementation".into(),
})?;
entry
.calls
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
t._run(input).await
}
pub async fn execute_for(
&self,
name: &str,
agent_id: &str,
input: ToolInput,
) -> Result<ToolOutput> {
let entry = self.entries.get(name).ok_or_else(|| CognisError::Tool {
name: name.to_string(),
reason: "not registered".into(),
})?;
if !entry.enabled {
return Err(CognisError::Tool {
name: name.to_string(),
reason: "disabled".into(),
});
}
let allowed = entry
.permission
.as_ref()
.map(|p| p(agent_id))
.unwrap_or(true);
if !allowed {
return Err(CognisError::Tool {
name: name.to_string(),
reason: format!("not allowed for agent `{agent_id}`"),
});
}
self.execute(name, input).await
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool that hands its input straight back as JSON content.
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "echoes input"
        }

        fn args_schema(&self) -> Option<serde_json::Value> {
            let schema = json!({"type": "object", "properties": {"text": {"type": "string"}}});
            Some(schema)
        }

        async fn _run(&self, input: ToolInput) -> Result<ToolOutput> {
            let echoed = input.into_json();
            Ok(ToolOutput::Content(echoed))
        }
    }

    /// Fresh registry holding a single enabled `Echo` tool.
    fn registry_with_echo() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(Echo));
        registry
    }

    /// Shorthand for a plain-text tool input.
    fn text(s: &str) -> ToolInput {
        ToolInput::Text(s.into())
    }

    #[tokio::test]
    async fn registry_register_get_execute() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(Arc::new(Echo));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("echo"));

        let mut args = HashMap::new();
        args.insert("text".into(), json!("hi"));
        let output = registry
            .execute("echo", ToolInput::Structured(args))
            .await
            .unwrap();
        let ToolOutput::Content(value) = output else {
            panic!("wrong variant");
        };
        assert_eq!(value["text"], "hi");
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ToolRegistry::new();
        let outcome = registry.execute("missing", text("x")).await;
        let err = outcome.unwrap_err();
        assert_eq!(err.category(), "tool");
    }

    #[test]
    fn definition_from_tool() {
        let def = ToolDefinition::from_tool(&Echo);
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "echoes input");
        assert!(def.parameters.is_some());
    }

    #[tokio::test]
    async fn disable_hides_from_dispatch_and_listing() {
        let mut registry = registry_with_echo();
        assert!(registry.disable("echo"));
        assert!(registry.contains("echo"), "still registered");
        assert!(!registry.is_enabled("echo"));
        assert!(registry.tool_names().is_empty());
        assert!(registry.definitions().is_empty());

        let err = registry.execute("echo", text("x")).await.unwrap_err();
        assert!(err.to_string().contains("disabled"), "got: {err}");
    }

    #[tokio::test]
    async fn enable_restores() {
        let mut registry = registry_with_echo();
        registry.disable("echo");
        registry.enable("echo");
        assert!(registry.is_enabled("echo"));

        let result = registry.execute("echo", text("x")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn call_count_increments_on_execute() {
        let registry = registry_with_echo();
        assert_eq!(registry.call_count("echo"), 0);

        for _ in 0..3 {
            registry.execute("echo", text("hi")).await.unwrap();
        }

        assert_eq!(registry.call_count("echo"), 3);
        assert_eq!(registry.call_count("missing"), 0);
    }

    #[tokio::test]
    async fn permission_predicate_blocks_disallowed_agents() {
        let mut registry = registry_with_echo();
        registry.set_permission("echo", |agent_id: &str| agent_id == "writer");
        assert!(registry.is_allowed("echo", "writer"));
        assert!(!registry.is_allowed("echo", "intruder"));

        let permitted = registry.execute_for("echo", "writer", text("hi")).await;
        assert!(permitted.is_ok());

        let denied = registry
            .execute_for("echo", "intruder", text("hi"))
            .await
            .unwrap_err();
        assert!(denied.to_string().contains("not allowed"), "got: {denied}");
    }

    #[tokio::test]
    async fn execute_for_reports_not_registered_before_permission() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute_for("ghost", "writer", text("x"))
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("not registered"),
            "wrong error: {err}"
        );
    }

    #[tokio::test]
    async fn execute_for_reports_disabled_before_permission() {
        let mut registry = registry_with_echo();
        registry.disable("echo");
        registry.set_permission("echo", |_| false);

        let err = registry
            .execute_for("echo", "writer", text("x"))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("disabled"), "wrong error: {err}");
    }

    #[tokio::test]
    async fn clear_permission_reopens_dispatch() {
        let mut registry = registry_with_echo();
        registry.set_permission("echo", |_: &str| false);
        assert!(!registry.is_allowed("echo", "anyone"));

        registry.clear_permission("echo");
        assert!(registry.is_allowed("echo", "anyone"));
    }
}