use std::sync::Arc;
use futures::future::BoxFuture;
use rig::completion::ToolDefinition;
use rig::tool::{ToolDyn, ToolError};
use super::BuiltinTool;
use crate::event::{EventKind, EventLog};
/// Adapter that exposes a Nika [`BuiltinTool`] through rig's [`ToolDyn`] interface.
///
/// The wrapped tool is advertised under the prefixed name `nika_<builtin name>`.
/// An [`EventLog`] and a task id can optionally be attached. When they are, the
/// responses of the `log` and `emit` builtins are recorded as events.
pub struct NikaBuiltinToolAdapter {
    /// `nika_`-prefixed tool name, computed once at construction.
    full_name: String,
    /// The builtin tool being wrapped.
    tool: Arc<dyn BuiltinTool>,
    /// Sink for `log`/`emit` events; `None` means no events are emitted.
    event_log: Option<Arc<EventLog>>,
    /// Task id attached to emitted events. It is set together with `event_log`.
    task_id: Option<Arc<str>>,
}
impl NikaBuiltinToolAdapter {
    /// Wraps `tool` so it can be registered with rig under the name `nika_<tool name>`.
    ///
    /// The returned adapter has no event log attached, so calling it never emits events.
    #[must_use]
    pub fn new(tool: Arc<dyn BuiltinTool>) -> Self {
        let full_name = format!("nika_{}", tool.name());
        Self {
            tool,
            full_name,
            event_log: None,
            task_id: None,
        }
    }

    /// Attaches an event log and the id of the task this tool runs for.
    ///
    /// This is a builder-style method: it consumes the adapter and returns it with both
    /// fields set. It is `#[must_use]` because discarding the result would silently drop
    /// the event log configuration.
    #[must_use]
    pub fn with_event_log(mut self, event_log: Arc<EventLog>, task_id: Arc<str>) -> Self {
        self.event_log = Some(event_log);
        self.task_id = Some(task_id);
        self
    }
}
impl std::fmt::Debug for NikaBuiltinToolAdapter {
    /// Formats the adapter for diagnostics.
    ///
    /// Shows the exposed name, the task id and whether an event log is attached.
    /// The wrapped tool and the event log contents are not printed. `finish_non_exhaustive`
    /// marks the output as partial.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NikaBuiltinToolAdapter")
            .field("name", &self.full_name)
            .field("task_id", &self.task_id)
            .field("has_event_log", &self.event_log.is_some())
            .finish_non_exhaustive()
    }
}
impl ToolDyn for NikaBuiltinToolAdapter {
    /// Returns the prefixed name (`nika_<builtin name>`) under which the tool is exposed.
    fn name(&self) -> String {
        self.full_name.clone()
    }

    /// Builds the rig tool definition from the wrapped builtin's description and
    /// parameter schema. The prompt argument is ignored.
    fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
        let def = ToolDefinition {
            name: self.full_name.clone(),
            description: self.tool.description().to_string(),
            parameters: self.tool.parameters_schema(),
        };
        Box::pin(async move { def })
    }

    /// Invokes the wrapped builtin with the raw `args` string and returns its raw output.
    ///
    /// If an event log is attached, the response of the `log` builtin is recorded as an
    /// `EventKind::Log` and the response of the `emit` builtin as an `EventKind::Custom`.
    /// Both events are tagged with this adapter's task id. Recording is best-effort: a
    /// response that does not parse as JSON is still returned, just without an event.
    ///
    /// # Errors
    ///
    /// If the builtin call fails, returns `ToolError::ToolCallError` wrapping the
    /// builtin's error message. No event is emitted in that case.
    fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
        Box::pin(async move {
            // `args` is already owned, so move it straight into the builtin; no copy needed.
            let result = self.tool.call(args).await.map_err(|e| {
                ToolError::ToolCallError(Box::new(BuiltinToolError(e.to_string())))
            })?;

            if let Some(event_log) = &self.event_log {
                match self.tool.name() {
                    "log" => {
                        if let Ok(response) = serde_json::from_str::<serde_json::Value>(&result)
                        {
                            event_log.emit(EventKind::Log {
                                level: response["level"].as_str().unwrap_or("info").to_string(),
                                message: response["message"].as_str().unwrap_or("").to_string(),
                                task_id: self.task_id.clone(),
                            });
                        }
                    }
                    "emit" => {
                        if let Ok(mut response) =
                            serde_json::from_str::<serde_json::Value>(&result)
                        {
                            // Move the payload out of the parsed response instead of deep-cloning
                            // it. `get_mut` (unlike `IndexMut`) returns `None` rather than
                            // panicking on non-object JSON; missing payloads become `Null`.
                            let payload = response
                                .get_mut("payload")
                                .map(serde_json::Value::take)
                                .unwrap_or_default();
                            event_log.emit(EventKind::Custom {
                                name: response["name"].as_str().unwrap_or("unknown").to_string(),
                                payload,
                                task_id: self.task_id.clone(),
                            });
                        }
                    }
                    // Other builtins have no event side effect.
                    _ => {}
                }
            }
            Ok(result)
        })
    }
}
/// Error type holding the message of a failed builtin tool call.
///
/// Only the rendered message is kept. That makes it a plain `std::error::Error` that can
/// be boxed into a tool-call error without carrying the original error type.
#[derive(Debug)]
struct BuiltinToolError(String);

impl std::fmt::Display for BuiltinToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write the stored message verbatim.
        f.write_str(&self.0)
    }
}

// No underlying cause is retained, so the default `source()` (None) is correct.
impl std::error::Error for BuiltinToolError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::NikaError;

    /// Minimal builtin that echoes the `value` argument back as `{"received":"<value>"}`.
    /// It fails on non-JSON input.
    struct TestTool;

    impl BuiltinTool for TestTool {
        fn name(&self) -> &'static str {
            "test"
        }

        fn description(&self) -> &'static str {
            "A test tool for unit tests"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": {"type": "string"}
                },
                "required": ["value"]
            })
        }

        fn call<'a>(
            &'a self,
            args: String,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<String, NikaError>> + Send + 'a>,
        > {
            Box::pin(async move {
                let params: serde_json::Value =
                    serde_json::from_str(&args).map_err(|e| NikaError::BuiltinToolError {
                        tool: "test".into(),
                        reason: format!("Invalid JSON: {}", e),
                    })?;
                let value = params["value"].as_str().unwrap_or("");
                Ok(format!(r#"{{"received":"{}"}}"#, value))
            })
        }
    }

    #[test]
    fn test_adapter_name() {
        let tool = Arc::new(TestTool);
        let adapter = NikaBuiltinToolAdapter::new(tool);
        assert_eq!(adapter.name(), "nika_test");
    }

    #[tokio::test]
    async fn test_adapter_definition() {
        let tool = Arc::new(TestTool);
        let adapter = NikaBuiltinToolAdapter::new(tool);
        let def = adapter.definition("test".to_string()).await;
        assert_eq!(def.name, "nika_test");
        assert_eq!(def.description, "A test tool for unit tests");
        assert_eq!(
            def.parameters,
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": {"type": "string"}
                },
                "required": ["value"]
            })
        );
    }

    #[tokio::test]
    async fn test_adapter_call_success() {
        let tool = Arc::new(TestTool);
        let adapter = NikaBuiltinToolAdapter::new(tool);
        let result = adapter.call(r#"{"value": "hello"}"#.to_string()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), r#"{"received":"hello"}"#);
    }

    #[tokio::test]
    async fn test_adapter_call_invalid_json() {
        let tool = Arc::new(TestTool);
        let adapter = NikaBuiltinToolAdapter::new(tool);
        let result = adapter.call("not json".to_string()).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_adapter_debug() {
        let tool = Arc::new(TestTool);
        let adapter = NikaBuiltinToolAdapter::new(tool);
        let debug_str = format!("{:?}", adapter);
        assert!(debug_str.contains("nika_test"));
    }

    #[test]
    fn test_adapter_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<NikaBuiltinToolAdapter>();
    }

    #[tokio::test]
    async fn test_log_tool_emits_event() {
        use super::super::LogTool;
        let event_log = Arc::new(EventLog::new());
        let task_id: Arc<str> = "test-task-1".into();
        let adapter = NikaBuiltinToolAdapter::new(Arc::new(LogTool))
            .with_event_log(Arc::clone(&event_log), Arc::clone(&task_id));
        let result = adapter
            .call(r#"{"level": "info", "message": "Test log message"}"#.to_string())
            .await;
        assert!(result.is_ok());
        let events = event_log.events();
        assert_eq!(events.len(), 1);
        if let EventKind::Log {
            level,
            message,
            task_id: tid,
        } = &events[0].kind
        {
            assert_eq!(level, "info");
            assert_eq!(message, "Test log message");
            assert_eq!(tid.as_ref().map(|t| t.as_ref()), Some("test-task-1"));
        } else {
            panic!("Expected EventKind::Log, got {:?}", events[0].kind);
        }
    }

    #[tokio::test]
    async fn test_emit_tool_emits_custom_event() {
        use super::super::EmitTool;
        let event_log = Arc::new(EventLog::new());
        let task_id: Arc<str> = "test-task-2".into();
        let adapter = NikaBuiltinToolAdapter::new(Arc::new(EmitTool))
            .with_event_log(Arc::clone(&event_log), Arc::clone(&task_id));
        let result = adapter
            .call(r#"{"name": "user_action", "payload": {"action": "click"}}"#.to_string())
            .await;
        assert!(result.is_ok());
        let events = event_log.events();
        assert_eq!(events.len(), 1);
        if let EventKind::Custom {
            name,
            payload,
            task_id: tid,
        } = &events[0].kind
        {
            assert_eq!(name, "user_action");
            assert_eq!(payload["action"], "click");
            assert_eq!(tid.as_ref().map(|t| t.as_ref()), Some("test-task-2"));
        } else {
            panic!("Expected EventKind::Custom, got {:?}", events[0].kind);
        }
    }

    #[tokio::test]
    async fn test_adapter_without_event_log_does_not_emit() {
        use super::super::LogTool;
        // With no event log attached there is nothing to observe. This only checks that the
        // `log` builtin still succeeds when event recording is disabled.
        let adapter = NikaBuiltinToolAdapter::new(Arc::new(LogTool));
        let result = adapter
            .call(r#"{"level": "info", "message": "No event expected"}"#.to_string())
            .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_with_event_log_builder() {
        let event_log = Arc::new(EventLog::new());
        let task_id: Arc<str> = "test-task".into();
        let adapter = NikaBuiltinToolAdapter::new(Arc::new(TestTool))
            .with_event_log(Arc::clone(&event_log), Arc::clone(&task_id));
        assert_eq!(adapter.name(), "nika_test");
        // Attaching the log must not emit anything on its own.
        assert!(event_log.events().is_empty());
    }

    #[tokio::test]
    async fn test_non_event_tool_with_event_log_does_not_emit() {
        // Covers the fall-through branch: only `log`/`emit` builtins produce events.
        let event_log = Arc::new(EventLog::new());
        let task_id: Arc<str> = "test-task-3".into();
        let adapter = NikaBuiltinToolAdapter::new(Arc::new(TestTool))
            .with_event_log(Arc::clone(&event_log), task_id);
        let result = adapter.call(r#"{"value": "hello"}"#.to_string()).await;
        assert_eq!(result.unwrap(), r#"{"received":"hello"}"#);
        assert!(event_log.events().is_empty());
    }

    #[tokio::test]
    async fn test_failed_call_with_event_log_does_not_emit() {
        // A failing builtin returns the error before any event handling runs.
        let event_log = Arc::new(EventLog::new());
        let task_id: Arc<str> = "test-task-4".into();
        let adapter = NikaBuiltinToolAdapter::new(Arc::new(TestTool))
            .with_event_log(Arc::clone(&event_log), task_id);
        let result = adapter.call("not json".to_string()).await;
        assert!(result.is_err());
        assert!(event_log.events().is_empty());
    }
}