use std::collections::HashMap;
use std::pin::Pin;
use serde_json::Value;
use crate::behaviors::tool::{ToolConnector, ToolResult};
use crate::connector::BaseConnector;
use crate::error::{ConnectorError, Result};
use crate::types::{ConnectorBehavior, TaskTypeSchema};
/// Adapts a [`ToolConnector`] into a [`BaseConnector`] with TOOL behavior.
///
/// Every tool reported by the wrapped connector is advertised as one
/// capability, and invocations are routed to a tool by the `capability_id`
/// they carry. Version, category, icon, the default output schema and any
/// extra metadata are configured through the builder methods.
pub struct ToolAdapter<T: ToolConnector + 'static> {
    // --- identity -------------------------------------------------------
    /// Identifier reported by `connector_type()`.
    connector_type: String,
    /// Version string reported by `version()`.
    version: String,

    // --- presentation shared by every advertised capability ------------
    /// Category attached to each capability (empty string when `None`).
    category: Option<String>,
    /// Icon attached to each capability (empty string when `None`).
    icon: Option<String>,
    /// Output schema JSON string attached to each capability.
    output_schema_default: String,

    // --- metadata -------------------------------------------------------
    /// Caller-provided entries returned from `metadata()`.
    extra_metadata: HashMap<String, String>,

    // --- wrapped connector ---------------------------------------------
    /// The tool connector that actually executes tools.
    inner: T,
}
impl<T: ToolConnector + 'static> ToolAdapter<T> {
    /// Creates an adapter of type `connector_type` around `inner`.
    ///
    /// The adapter starts with version `"0.1.0"`, no category or icon, an
    /// output schema of `"{}"` and no extra metadata. The `with_*` builders
    /// override these defaults.
    #[must_use]
    pub fn new(connector_type: impl Into<String>, inner: T) -> Self {
        Self {
            inner,
            connector_type: connector_type.into(),
            version: "0.1.0".to_string(),
            category: None,
            icon: None,
            output_schema_default: "{}".to_string(),
            extra_metadata: HashMap::new(),
        }
    }

    /// Sets the version string reported by the connector.
    #[must_use]
    pub fn with_version(mut self, v: impl Into<String>) -> Self {
        self.version = v.into();
        self
    }

    /// Sets the category attached to every advertised capability.
    #[must_use]
    pub fn with_category(mut self, c: impl Into<String>) -> Self {
        self.category = Some(c.into());
        self
    }

    /// Sets the icon attached to every advertised capability.
    #[must_use]
    pub fn with_icon(mut self, i: impl Into<String>) -> Self {
        self.icon = Some(i.into());
        self
    }

    /// Sets the output-schema JSON string attached to every advertised
    /// capability (defaults to `"{}"`).
    #[must_use]
    pub fn with_output_schema_default(mut self, schema: impl Into<String>) -> Self {
        self.output_schema_default = schema.into();
        self
    }

    /// Adds an extra metadata entry.
    ///
    /// Setting the same `key` twice keeps the last value. Entries added here
    /// take precedence over the `tool_count` / `timeout_ms` values that
    /// `metadata()` derives from the inner connector.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra_metadata.insert(key.into(), value.into());
        self
    }

    /// Returns a reference to the wrapped tool connector.
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Builds one [`TaskTypeSchema`] per tool exposed by the inner connector.
    ///
    /// - The tool name is used as both the task type id and the display name.
    /// - The tool's parameter schema is serialized as the input schema,
    ///   falling back to `"{}"` if serialization fails.
    /// - Category, icon and output schema come from the adapter configuration
    ///   and are shared by all tools. An unset category or icon becomes an
    ///   empty string.
    fn build_capabilities(&self) -> Vec<TaskTypeSchema> {
        let category = self.category.clone().unwrap_or_default();
        let icon = self.icon.clone().unwrap_or_default();
        let output = self.output_schema_default.clone();
        self.inner
            .tools()
            .into_iter()
            .map(|tool| {
                let parameters_json =
                    serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
                TaskTypeSchema {
                    task_type_id: tool.name.clone(),
                    name: tool.name,
                    description: tool.description,
                    category: category.clone(),
                    icon: icon.clone(),
                    input_schema_json: parameters_json,
                    output_schema_json: output.clone(),
                }
            })
            .collect()
    }
}
impl<T: ToolConnector + 'static> BaseConnector for ToolAdapter<T> {
fn connector_type(&self) -> &str {
&self.connector_type
}
fn version(&self) -> &str {
&self.version
}
fn behavior(&self) -> ConnectorBehavior {
ConnectorBehavior::Tool
}
fn metadata(&self) -> HashMap<String, String> {
let mut m = self.extra_metadata.clone();
m.entry("tool_count".to_string())
.or_insert_with(|| self.inner.tools().len().to_string());
m.entry("timeout_ms".to_string())
.or_insert_with(|| self.inner.timeout_ms().to_string());
m
}
fn capabilities(&self) -> Vec<TaskTypeSchema> {
self.build_capabilities()
}
fn execute(
&self,
request: Value,
capability_id: Option<&str>,
) -> Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send + '_>> {
let tool_name = capability_id.map(|s| s.to_string());
Box::pin(async move {
let tool = match tool_name.as_deref() {
Some(n) if !n.is_empty() => n.to_string(),
_ => {
return Err(ConnectorError::InvalidConfig(
"TOOL connector invocation is missing capability_id (tool name)"
.to_string(),
));
}
};
tool_result_to_value(self.inner.execute(&tool, request).await, &tool)
})
}
fn execute_with_context<'a>(
&'a self,
request: Value,
capability_id: Option<&'a str>,
context: &'a HashMap<String, String>,
) -> Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send + 'a>> {
let tool_name = capability_id.map(|s| s.to_string());
Box::pin(async move {
let tool = match tool_name.as_deref() {
Some(n) if !n.is_empty() => n.to_string(),
_ => {
return Err(ConnectorError::InvalidConfig(
"TOOL connector invocation is missing capability_id (tool name)"
.to_string(),
));
}
};
tool_result_to_value(
self.inner
.execute_with_context(&tool, request, context)
.await,
&tool,
)
})
}
}
/// Converts a [`ToolResult`] produced by tool `tool` into the connector-level
/// `Result<Value>`.
///
/// - **Success:** returns the tool's payload. A success without a payload
///   becomes an empty JSON object (`{}`).
/// - **Failure:** returns [`ConnectorError::InvokeFailed`] with a message of
///   the form `"[CODE] tool 'NAME': MESSAGE"`. The `[CODE] ` prefix is omitted
///   when the error code is missing or empty.
///
/// A missing *or empty* error message is reported as the generic
/// `"tool error"`. This mirrors how an empty error code is treated as absent
/// and avoids messages that end in a bare `": "`.
fn tool_result_to_value(result: ToolResult, tool: &str) -> Result<Value> {
    if result.success {
        return Ok(result.result.unwrap_or_else(|| serde_json::json!({})));
    }
    let msg = result
        .error
        .filter(|m| !m.is_empty())
        .unwrap_or_else(|| "tool error".to_string());
    let formatted = match result.error_code {
        Some(code) if !code.is_empty() => format!("[{code}] tool '{tool}': {msg}"),
        _ => format!("tool '{tool}': {msg}"),
    };
    Err(ConnectorError::InvokeFailed(formatted))
}
// Unit tests for `ToolAdapter`:
// - identity and behavior passthrough;
// - capability and metadata derivation;
// - dispatch to named tools and mapping of tool errors to `ConnectorError`;
// - forwarding of the caller context to the inner `ToolConnector`.
#[cfg(test)]
mod tests {
use super::*;
use crate::behaviors::tool::{ParamType, ToolSchema};
use async_trait::async_trait;
use serde_json::json;
// Fake connector that advertises two tools (`add`, `greet`). Its `execute`
// also handles the unadvertised names `fail_no_code` / `fail_with_code` so
// the error-mapping paths can be exercised. Any other name yields an
// `UNKNOWN_TOOL` error result.
struct DemoTools;
#[async_trait]
impl ToolConnector for DemoTools {
fn tools(&self) -> Vec<ToolSchema> {
vec![
ToolSchema::new("add", "Add two numbers")
.param("a", ParamType::Number, "First operand", true)
.param("b", ParamType::Number, "Second operand", true),
ToolSchema::new("greet", "Greet a person").param(
"name",
ParamType::String,
"Name",
true,
),
]
}
async fn execute(&self, tool: &str, params: Value) -> ToolResult {
match tool {
"add" => {
// Missing or non-numeric operands default to 0.0.
let a = params["a"].as_f64().unwrap_or(0.0);
let b = params["b"].as_f64().unwrap_or(0.0);
ToolResult::success(json!({ "result": a + b }))
}
"greet" => {
let name = params["name"].as_str().unwrap_or("World");
ToolResult::success(json!({ "message": format!("Hello, {name}!") }))
}
// Failure without an error code: the adapter must not emit "[]".
"fail_no_code" => ToolResult::error("plain failure"),
// Failure with an error code: the adapter must include the code.
"fail_with_code" => ToolResult::error_with_code("validation broke", "BAD_INPUT"),
_ => ToolResult::error_with_code("unknown tool", "UNKNOWN_TOOL"),
}
}
// Distinctive value so the metadata test can check it is surfaced.
fn timeout_ms(&self) -> u64 {
7000
}
}
// Adapter fixture with every builder option set except the output-schema
// default, plus one user metadata key.
fn adapter() -> ToolAdapter<DemoTools> {
ToolAdapter::new("demo_tools", DemoTools)
.with_version("1.2.3")
.with_category("demo")
.with_icon("flask")
.with_metadata("docs_url", "https://example.com/demo")
}
// Both the single `behavior()` and the `behaviors()` list report only TOOL.
#[test]
fn adapter_advertises_tool_behavior() {
let a = adapter();
assert_eq!(a.behavior(), ConnectorBehavior::Tool);
assert_eq!(a.behaviors(), vec![ConnectorBehavior::Tool]);
}
// connector_type and version are returned exactly as configured.
#[test]
fn adapter_passes_through_identity() {
let a = adapter();
assert_eq!(a.connector_type(), "demo_tools");
assert_eq!(a.version(), "1.2.3");
}
// One capability per advertised tool, carrying the tool's name/description,
// the adapter's category/icon, and the parameters serialized as a JSON
// object schema.
#[test]
fn adapter_capabilities_match_tools() {
let a = adapter();
let caps = a.capabilities();
assert_eq!(caps.len(), 2);
let add = caps.iter().find(|c| c.task_type_id == "add").unwrap();
assert_eq!(add.name, "add");
assert_eq!(add.description, "Add two numbers");
assert_eq!(add.category, "demo");
assert_eq!(add.icon, "flask");
let parsed: Value = serde_json::from_str(&add.input_schema_json).unwrap();
assert_eq!(parsed["type"], "object");
assert!(parsed["properties"]["a"].is_object());
assert!(parsed["properties"]["b"].is_object());
let greet = caps.iter().find(|c| c.task_type_id == "greet").unwrap();
assert_eq!(greet.description, "Greet a person");
}
// metadata() merges user keys with derived tool_count/timeout_ms. It does
// not itself contain `tool_schemas`; that key is checked separately in
// `build_registration_metadata_injects_tool_schemas_for_adapter`.
#[test]
fn adapter_metadata_includes_user_keys_and_derived_counts() {
let a = adapter();
let m = a.metadata();
assert_eq!(
m.get("docs_url").map(|s| s.as_str()),
Some("https://example.com/demo")
);
assert_eq!(m.get("tool_count").map(|s| s.as_str()), Some("2"));
assert_eq!(m.get("timeout_ms").map(|s| s.as_str()), Some("7000"));
assert!(!m.contains_key("tool_schemas"));
}
// The capability_id selects the tool, and a successful result is returned
// unchanged.
#[tokio::test]
async fn adapter_dispatches_to_named_tool_on_success() {
let a = adapter();
let v = a
.execute(json!({ "a": 2.5, "b": 3.0 }), Some("add"))
.await
.expect("add tool should succeed");
assert_eq!(v["result"], json!(5.5));
let g = a
.execute(json!({ "name": "Sreejith" }), Some("greet"))
.await
.expect("greet tool should succeed");
assert_eq!(g["message"], json!("Hello, Sreejith!"));
}
// No capability_id -> InvalidConfig whose message mentions capability_id.
#[tokio::test]
async fn adapter_returns_error_when_capability_id_missing() {
let a = adapter();
let err = a.execute(json!({}), None).await.unwrap_err();
match err {
ConnectorError::InvalidConfig(msg) => assert!(msg.contains("capability_id")),
other => panic!("expected InvalidConfig, got {other:?}"),
}
}
// An empty capability_id is rejected the same way as a missing one.
#[tokio::test]
async fn adapter_returns_error_when_capability_id_empty() {
let a = adapter();
let err = a.execute(json!({}), Some("")).await.unwrap_err();
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
}
// A failing tool with an error code -> InvokeFailed whose message carries
// the code, the tool's message and the tool name.
#[tokio::test]
async fn adapter_propagates_tool_error_with_code() {
let a = adapter();
let err = a
.execute(json!({}), Some("fail_with_code"))
.await
.unwrap_err();
match err {
ConnectorError::InvokeFailed(msg) => {
assert!(msg.contains("BAD_INPUT"), "missing error code in: {msg}");
assert!(
msg.contains("validation broke"),
"missing message in: {msg}"
);
assert!(
msg.contains("fail_with_code"),
"missing tool name in: {msg}"
);
}
other => panic!("expected InvokeFailed, got {other:?}"),
}
}
// A failing tool without a code -> InvokeFailed with no empty "[]" prefix.
#[tokio::test]
async fn adapter_propagates_tool_error_without_code() {
let a = adapter();
let err = a
.execute(json!({}), Some("fail_no_code"))
.await
.unwrap_err();
match err {
ConnectorError::InvokeFailed(msg) => {
assert!(msg.contains("plain failure"), "missing message in: {msg}");
assert!(!msg.contains("[]"), "stray empty code brackets in: {msg}");
}
other => panic!("expected InvokeFailed, got {other:?}"),
}
}
// The adapter does not pre-validate tool names against tools(). An unknown
// name reaches the inner connector, and its error result becomes
// InvokeFailed.
#[tokio::test]
async fn adapter_treats_unknown_tool_as_invoke_failed() {
let a = adapter();
let err = a
.execute(json!({}), Some("does_not_exist"))
.await
.unwrap_err();
assert!(matches!(err, ConnectorError::InvokeFailed(_)));
}
use std::sync::Mutex as StdMutex;
// Records the context it receives through `execute_with_context`. Its bare
// `execute` hits `unreachable!`, so a test fails if the adapter takes the
// context-less path.
struct CtxCapturingTool {
seen: std::sync::Arc<StdMutex<Option<HashMap<String, String>>>>,
}
#[async_trait]
impl ToolConnector for CtxCapturingTool {
fn tools(&self) -> Vec<ToolSchema> {
vec![ToolSchema::new("noop", "Capture the calling context")]
}
async fn execute(&self, _tool: &str, _params: Value) -> ToolResult {
unreachable!("ToolAdapter must dispatch to execute_with_context, not bare execute")
}
async fn execute_with_context(
&self,
_tool: &str,
_params: Value,
context: &HashMap<String, String>,
) -> ToolResult {
*self.seen.lock().unwrap() = Some(context.clone());
ToolResult::ok()
}
}
// Implements only `execute`, so it relies on the trait's default
// `execute_with_context`. The test below asserts that this default ends up
// calling `execute`.
struct LegacyTool {
called: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
#[async_trait]
impl ToolConnector for LegacyTool {
fn tools(&self) -> Vec<ToolSchema> {
vec![ToolSchema::new(
"noop",
"Legacy tool with no context support",
)]
}
async fn execute(&self, _tool: &str, _params: Value) -> ToolResult {
self.called.store(true, std::sync::atomic::Ordering::SeqCst);
ToolResult::ok()
}
}
// The exact context map passed to the adapter reaches the inner tool.
#[tokio::test]
async fn tool_adapter_forwards_context_to_execute_with_context() {
let seen = std::sync::Arc::new(StdMutex::new(None));
let adapter = ToolAdapter::new("ctx_tool", CtxCapturingTool { seen: seen.clone() });
let mut context = HashMap::new();
context.insert("tenant_id".into(), "tenant-acme".into());
context.insert("user_id".into(), "subject-99".into());
adapter
.execute_with_context(json!({}), Some("noop"), &context)
.await
.expect("execute_with_context must succeed");
let captured = seen
.lock()
.unwrap()
.clone()
.expect("ToolAdapter must invoke execute_with_context on the inner tool");
assert_eq!(
captured, context,
"context must round-trip through ToolAdapter to the inner ToolConnector"
);
}
// Connectors that ignore context still run when called with a context.
#[tokio::test]
async fn tool_adapter_default_path_keeps_working_for_legacy_tool() {
let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let adapter = ToolAdapter::new(
"legacy_tool",
LegacyTool {
called: called.clone(),
},
);
let mut context = HashMap::new();
context.insert("tenant_id".into(), "ignored-by-legacy".into());
adapter
.execute_with_context(json!({}), Some("noop"), &context)
.await
.expect("legacy tool must still execute through the default path");
assert!(
called.load(std::sync::atomic::Ordering::SeqCst),
"default ToolConnector::execute_with_context must delegate to execute"
);
}
// A missing capability_id is rejected before the inner tool is called:
// nothing is recorded in `seen`.
#[tokio::test]
async fn tool_adapter_execute_with_context_requires_capability_id() {
let seen = std::sync::Arc::new(StdMutex::new(None));
let adapter = ToolAdapter::new("ctx_tool", CtxCapturingTool { seen: seen.clone() });
let context = HashMap::new();
let err = adapter
.execute_with_context(json!({}), None, &context)
.await
.unwrap_err();
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
assert!(seen.lock().unwrap().is_none());
}
// For a TOOL-behavior connector, `build_registration_metadata` adds a
// `tool_schemas` entry: a JSON array with one non-empty name/description
// object per tool.
#[test]
fn build_registration_metadata_injects_tool_schemas_for_adapter() {
let a = adapter();
let merged = crate::connector::build_registration_metadata(&a);
let tool_schemas_json = merged
.get("tool_schemas")
.expect("tool_schemas must be auto-injected for TOOL behavior");
let parsed: Value = serde_json::from_str(tool_schemas_json).unwrap();
let arr = parsed.as_array().expect("tool_schemas must be an array");
assert_eq!(arr.len(), 2);
for entry in arr {
assert!(entry["name"].is_string());
assert!(entry["description"].is_string());
assert!(!entry["name"].as_str().unwrap().is_empty());
assert!(!entry["description"].as_str().unwrap().is_empty());
}
}
}