use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
#[cfg(feature = "native")]
pub use tokio_util::sync::CancellationToken;
#[cfg(not(feature = "native"))]
mod cancellation {
    //! Dependency-free fallback for `tokio_util::sync::CancellationToken`,
    //! compiled only when the `native` feature is disabled.
    //!
    //! It offers the subset of the tokio-util API used in this file: `new`,
    //! `cancel`, `is_cancelled`, an awaitable `cancelled()`, plus `Clone`,
    //! `Default` and `Debug`. Only the standard library is used.

    use std::fmt;
    use std::future::Future;
    use std::mem;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex, MutexGuard};
    use std::task::{Context, Poll, Waker};

    /// State shared by every clone of a token.
    struct Shared {
        /// Set to `true` by `cancel`; never reset.
        cancelled: AtomicBool,
        /// Wakers of tasks currently parked in a [`Cancelled`] future.
        waiters: Mutex<Vec<Waker>>,
    }

    impl Shared {
        /// Locks the waiter list, ignoring poisoning: the list holds plain
        /// wakers, so it stays usable even if another thread panicked while
        /// holding the lock.
        fn lock_waiters(&self) -> MutexGuard<'_, Vec<Waker>> {
            self.waiters
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }
    }

    /// Cloneable cancellation flag; all clones observe the same state.
    #[derive(Clone)]
    pub struct CancellationToken {
        shared: Arc<Shared>,
    }

    impl CancellationToken {
        /// Creates a token that is not cancelled.
        pub fn new() -> Self {
            Self {
                shared: Arc::new(Shared {
                    cancelled: AtomicBool::new(false),
                    waiters: Mutex::new(Vec::new()),
                }),
            }
        }

        /// Marks the token (and every clone of it) as cancelled and wakes all
        /// tasks waiting in [`CancellationToken::cancelled`].
        ///
        /// Calling this more than once is harmless.
        pub fn cancel(&self) {
            // Publish the flag first; `Cancelled::poll` re-checks it while
            // holding the waiter lock, so no wake-up can be lost.
            self.shared.cancelled.store(true, Ordering::SeqCst);
            // Drain under the lock, wake after releasing it so a waker that
            // re-enters this token cannot deadlock.
            let parked = mem::take(&mut *self.shared.lock_waiters());
            for waker in parked {
                waker.wake();
            }
        }

        /// Returns `true` once `cancel` has been called on this token or on
        /// any of its clones.
        pub fn is_cancelled(&self) -> bool {
            self.shared.cancelled.load(Ordering::SeqCst)
        }

        /// Returns a future that completes once the token is cancelled.
        ///
        /// The future resolves immediately if the token is already cancelled.
        pub fn cancelled(&self) -> Cancelled {
            Cancelled {
                shared: Arc::clone(&self.shared),
            }
        }
    }

    impl Default for CancellationToken {
        fn default() -> Self {
            Self::new()
        }
    }

    impl fmt::Debug for CancellationToken {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("CancellationToken")
                .field("is_cancelled", &self.is_cancelled())
                .finish()
        }
    }

    /// Future returned by [`CancellationToken::cancelled`].
    #[must_use = "futures do nothing unless polled or awaited"]
    pub struct Cancelled {
        shared: Arc<Shared>,
    }

    impl fmt::Debug for Cancelled {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Cancelled").finish()
        }
    }

    impl Future for Cancelled {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Fast path: no locking once the flag is set.
            if self.shared.cancelled.load(Ordering::SeqCst) {
                return Poll::Ready(());
            }
            let mut waiters = self.shared.lock_waiters();
            // Re-check under the lock: `cancel` sets the flag before draining
            // the list, so either we see the flag here or `cancel` will see
            // the waker we are about to register.
            if self.shared.cancelled.load(Ordering::SeqCst) {
                return Poll::Ready(());
            }
            // Avoid piling up duplicates when the same task polls repeatedly.
            if !waiters.iter().any(|known| known.will_wake(cx.waker())) {
                waiters.push(cx.waker().clone());
            }
            Poll::Pending
        }
    }
}
#[cfg(not(feature = "native"))]
pub use cancellation::CancellationToken;
use crate::error::PluginError;
use crate::message::MessagePayload;
/// A tool that can be invoked with JSON parameters.
///
/// Implementors must be `Send + Sync`. `execute` is an `async fn` declared
/// through `#[async_trait]`.
#[async_trait]
pub trait Tool: Send + Sync {
/// Returns the tool's name.
fn name(&self) -> &str;
/// Returns a human-readable description of the tool.
fn description(&self) -> &str;
/// Returns a JSON value describing the parameters accepted by `execute`.
///
/// NOTE(review): the test mock in this file returns a JSON-Schema-like object
/// (`"type": "object"` with `"properties"`); confirm the exact format the host
/// expects.
fn parameters_schema(&self) -> serde_json::Value;
/// Runs the tool.
///
/// * `params` - JSON arguments for this invocation.
/// * `ctx` - context giving access to a key-value store and the plugin and
///   agent identifiers (see `ToolContext`).
///
/// Returns the tool's JSON result, or a `PluginError` on failure.
async fn execute(
&self,
params: serde_json::Value,
ctx: &dyn ToolContext,
) -> Result<serde_json::Value, PluginError>;
}
/// Adapter for a messaging channel.
///
/// Implementors must be `Send + Sync`. `start` and `send` are `async fn`s
/// declared through `#[async_trait]`.
#[async_trait]
pub trait ChannelAdapter: Send + Sync {
/// Returns the adapter's name.
fn name(&self) -> &str;
/// Returns the name to show to users.
fn display_name(&self) -> &str;
/// Reports whether the channel supports threads.
fn supports_threads(&self) -> bool;
/// Reports whether the channel supports media.
fn supports_media(&self) -> bool;
/// Starts the adapter.
///
/// * `host` - shared handle the adapter can use to deliver inbound messages
///   (see `ChannelAdapterHost::deliver_inbound`).
/// * `cancel` - cancellation token handed to the adapter.
///
/// NOTE(review): the test mock in this file waits for `cancel` to be cancelled
/// and then returns `Ok(())`; presumably real adapters run until cancellation
/// as well -- confirm against the host.
async fn start(
&self,
host: Arc<dyn ChannelAdapterHost>,
cancel: CancellationToken,
) -> Result<(), PluginError>;
/// Sends `payload` to `target`.
///
/// Returns a `String` on success, or a `PluginError` on failure.
/// NOTE(review): the test mock returns `"msg-001"`, which suggests the string
/// is a message identifier -- confirm.
async fn send(
&self,
target: &str,
payload: &MessagePayload,
) -> Result<String, PluginError>;
}
/// Host-side interface handed to a `ChannelAdapter` in `ChannelAdapter::start`.
///
/// Implementors must be `Send + Sync`; adapters receive it as
/// `Arc<dyn ChannelAdapterHost>`.
#[async_trait]
pub trait ChannelAdapterHost: Send + Sync {
/// Delivers an inbound message to the host.
///
/// * `channel` - channel the message belongs to (presumably the adapter's
///   `name()`; confirm with the host implementation).
/// * `sender_id` - identifier of the sender.
/// * `chat_id` - identifier of the chat.
/// * `payload` - the message content, passed by value.
/// * `metadata` - additional string-keyed JSON values attached to the message.
///
/// Returns `Ok(())` on success, or a `PluginError` on failure.
async fn deliver_inbound(
&self,
channel: &str,
sender_id: &str,
chat_id: &str,
payload: MessagePayload,
metadata: HashMap<String, serde_json::Value>,
) -> Result<(), PluginError>;
}
/// Kind of a pipeline stage, as reported by `PipelineStage::stage_type`.
///
/// Values are (de)serialized by serde as snake_case strings: `"pre_process"`,
/// `"process"`, `"post_process"` and `"observer"`.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions outside this crate
/// need a wildcard arm.
///
/// Being a fieldless enum it also derives `Copy` and `Hash`, so values can be
/// passed by value and used as `HashMap`/`HashSet` keys.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PipelineStageType {
    /// Serialized as `"pre_process"`.
    PreProcess,
    /// Serialized as `"process"`.
    Process,
    /// Serialized as `"post_process"`.
    PostProcess,
    /// Serialized as `"observer"`.
    Observer,
}
/// A single stage of a processing pipeline that maps one JSON value to another.
///
/// Implementors must be `Send + Sync`. `process` is an `async fn` declared
/// through `#[async_trait]`.
#[async_trait]
pub trait PipelineStage: Send + Sync {
/// Returns the stage's name.
fn name(&self) -> &str;
/// Returns which kind of stage this is.
fn stage_type(&self) -> PipelineStageType;
/// Processes `input` and returns the resulting JSON value, or a `PluginError`
/// on failure.
async fn process(
&self,
input: serde_json::Value,
) -> Result<serde_json::Value, PluginError>;
}
/// A skill: descriptive metadata, instructions, and the ability to execute
/// tools by name.
///
/// Implementors must be `Send + Sync`. `execute_tool` is an `async fn` declared
/// through `#[async_trait]`.
#[async_trait]
pub trait Skill: Send + Sync {
/// Returns the skill's name.
fn name(&self) -> &str;
/// Returns a human-readable description of the skill.
fn description(&self) -> &str;
/// Returns the skill's version string (the test mock uses `"1.0.0"`).
fn version(&self) -> &str;
/// Returns the skill's variables as an owned name-to-value map.
fn variables(&self) -> HashMap<String, String>;
/// Returns the names of the tools this skill is allowed to use.
fn allowed_tools(&self) -> Vec<String>;
/// Returns the skill's instruction text.
fn instructions(&self) -> &str;
/// Reports whether the skill can be invoked by a user.
fn is_user_invocable(&self) -> bool;
/// Executes the tool called `tool_name` on behalf of this skill.
///
/// * `tool_name` - name of the tool to run.
/// * `params` - JSON arguments for the tool.
/// * `ctx` - context giving access to a key-value store and the plugin and
///   agent identifiers (see `ToolContext`).
///
/// Returns the tool's JSON result, or a `PluginError` on failure.
/// NOTE(review): whether implementations must reject names missing from
/// `allowed_tools()` is not visible here -- confirm.
async fn execute_tool(
&self,
tool_name: &str,
params: serde_json::Value,
ctx: &dyn ToolContext,
) -> Result<serde_json::Value, PluginError>;
}
/// Storage backend for string values addressed by key and optional namespace,
/// with a query-based search.
///
/// Implementors must be `Send + Sync`. All methods are `async fn`s declared
/// through `#[async_trait]`.
#[async_trait]
pub trait MemoryBackend: Send + Sync {
/// Stores `value` under `key`.
///
/// * `namespace` - optional namespace for the key.
/// * `ttl_seconds` - optional time-to-live in seconds (per the parameter
///   name); how expiry is enforced is up to the implementation.
/// * `tags` - optional tags to attach to the entry.
async fn store(
&self,
key: &str,
value: &str,
namespace: Option<&str>,
ttl_seconds: Option<u64>,
tags: Option<Vec<String>>,
) -> Result<(), PluginError>;
/// Looks up `key` in the optional `namespace`.
///
/// Returns `Ok(Some(value))` or `Ok(None)`; presumably `None` means the key is
/// absent -- confirm with the implementation.
async fn retrieve(
&self,
key: &str,
namespace: Option<&str>,
) -> Result<Option<String>, PluginError>;
/// Searches the optional `namespace` for entries matching `query`, with an
/// optional `limit`.
///
/// Returns a list of `(String, String, f64)` tuples.
/// NOTE(review): the test mock returns `("key", "value", 0.95)`, which suggests
/// `(key, value, score)` -- confirm the tuple layout and the score range.
async fn search(
&self,
query: &str,
namespace: Option<&str>,
limit: Option<usize>,
) -> Result<Vec<(String, String, f64)>, PluginError>;
/// Deletes `key` from the optional `namespace`.
///
/// Returns a `bool` on success; presumably `true` when an entry was removed --
/// confirm with the implementation.
async fn delete(
&self,
key: &str,
namespace: Option<&str>,
) -> Result<bool, PluginError>;
}
/// Converts between audio data and text.
///
/// Implementors must be `Send + Sync`. Both methods are `async fn`s declared
/// through `#[async_trait]`.
#[async_trait]
pub trait VoiceHandler: Send + Sync {
/// Processes `audio_data`, whose format is given by `mime_type`, and returns a
/// `String`.
///
/// NOTE(review): the test mock returns `"transcribed text"`, so the result is
/// presumably a transcription -- confirm.
async fn process_audio(
&self,
audio_data: &[u8],
mime_type: &str,
) -> Result<String, PluginError>;
/// Synthesizes audio for `text`.
///
/// Returns a `(Vec<u8>, String)` pair.
/// NOTE(review): the test mock returns raw bytes together with `"audio/wav"`,
/// which suggests `(audio bytes, MIME type)` -- confirm.
async fn synthesize(
&self,
text: &str,
) -> Result<(Vec<u8>, String), PluginError>;
}
/// Asynchronous string key-value store, reachable by tools through
/// `ToolContext::key_value_store`.
///
/// Implementors must be `Send + Sync`. All methods are `async fn`s declared
/// through `#[async_trait]`.
#[async_trait]
pub trait KeyValueStore: Send + Sync {
/// Looks up `key`; returns `Ok(Some(value))` or `Ok(None)` (presumably `None`
/// when the key is absent -- confirm with the implementation).
async fn get(&self, key: &str) -> Result<Option<String>, PluginError>;
/// Stores `value` under `key`.
async fn set(&self, key: &str, value: &str) -> Result<(), PluginError>;
/// Deletes `key`; returns a `bool` on success (presumably whether an entry
/// was removed -- confirm with the implementation).
async fn delete(&self, key: &str) -> Result<bool, PluginError>;
/// Lists keys, optionally restricted by `prefix`.
async fn list_keys(&self, prefix: Option<&str>) -> Result<Vec<String>, PluginError>;
}
/// Context passed by reference to `Tool::execute` and `Skill::execute_tool`.
///
/// Implementors must be `Send + Sync`. Unlike the other traits in this file it
/// has no `async` methods and is not declared through `#[async_trait]`.
pub trait ToolContext: Send + Sync {
/// Returns the key-value store available in this context.
fn key_value_store(&self) -> &dyn KeyValueStore;
/// Returns the plugin identifier.
fn plugin_id(&self) -> &str;
/// Returns the agent identifier.
fn agent_id(&self) -> &str;
}
#[cfg(test)]
mod tests {
use super::*;
/// Compile-time probe: instantiating it with `T` only type-checks when `T`
/// (possibly unsized) is both `Send` and `Sync`. It does nothing at runtime.
fn assert_send_sync<T>()
where
    T: ?Sized + Send + Sync,
{
}
/// Every plugin trait object must be `Send + Sync`; this is checked purely at
/// compile time through `assert_send_sync`.
#[test]
fn test_traits_are_send_sync() {
    // Host-facing traits.
    assert_send_sync::<dyn ChannelAdapterHost>();
    assert_send_sync::<dyn ToolContext>();
    assert_send_sync::<dyn KeyValueStore>();

    // Plugin-implemented traits.
    assert_send_sync::<dyn Tool>();
    assert_send_sync::<dyn ChannelAdapter>();
    assert_send_sync::<dyn PipelineStage>();
    assert_send_sync::<dyn Skill>();
    assert_send_sync::<dyn MemoryBackend>();
    assert_send_sync::<dyn VoiceHandler>();
}
/// Each `PipelineStageType` variant must survive a JSON encode/decode cycle.
#[test]
fn test_pipeline_stage_type_serde_roundtrip() {
    let variants = [
        PipelineStageType::PreProcess,
        PipelineStageType::Process,
        PipelineStageType::PostProcess,
        PipelineStageType::Observer,
    ];
    for original in variants.iter() {
        let encoded = serde_json::to_string(original).unwrap();
        let decoded: PipelineStageType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(&decoded, original);
    }
}
/// Pins the exact snake_case JSON string produced for each variant.
#[test]
fn test_pipeline_stage_type_json_values() {
    let expectations = [
        (PipelineStageType::PreProcess, "\"pre_process\""),
        (PipelineStageType::Process, "\"process\""),
        (PipelineStageType::PostProcess, "\"post_process\""),
        (PipelineStageType::Observer, "\"observer\""),
    ];
    for (variant, expected_json) in expectations.iter() {
        let actual_json = serde_json::to_string(variant).unwrap();
        assert_eq!(actual_json, *expected_json);
    }
}
/// Stateless `KeyValueStore` stub: it never holds any data.
struct MockKvStore;

#[async_trait]
impl KeyValueStore for MockKvStore {
    /// Always reports no value.
    async fn get(&self, _key: &str) -> Result<Option<String>, PluginError> {
        let missing: Option<String> = None;
        Ok(missing)
    }

    /// Accepts the write and discards it.
    async fn set(&self, _key: &str, _value: &str) -> Result<(), PluginError> {
        Ok(())
    }

    /// Always returns `false`.
    async fn delete(&self, _key: &str) -> Result<bool, PluginError> {
        Ok(false)
    }

    /// Always returns an empty list.
    async fn list_keys(&self, _prefix: Option<&str>) -> Result<Vec<String>, PluginError> {
        Ok(Vec::new())
    }
}
struct MockToolContext;
impl ToolContext for MockToolContext {
fn key_value_store(&self) -> &dyn KeyValueStore {
&MockKvStore
}
fn plugin_id(&self) -> &str {
"mock-plugin"
}
fn agent_id(&self) -> &str {
"mock-agent"
}
}
/// `Tool` stub that echoes its parameters back inside a `"result"` string.
struct MockTool;

#[async_trait]
impl Tool for MockTool {
    fn name(&self) -> &str {
        "mock_tool"
    }

    fn description(&self) -> &str {
        "A mock tool for testing"
    }

    /// Describes a single string parameter called `input`.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "input": { "type": "string" }
            }
        })
    }

    /// Ignores the context and returns `{"result": "processed: <params>"}`.
    async fn execute(
        &self,
        params: serde_json::Value,
        _ctx: &dyn ToolContext,
    ) -> Result<serde_json::Value, PluginError> {
        let rendered = format!("processed: {}", params);
        Ok(serde_json::json!({ "result": rendered }))
    }
}
/// `ChannelAdapter` stub: media but no threads, and a fixed message id.
struct MockChannelAdapter;

#[async_trait]
impl ChannelAdapter for MockChannelAdapter {
    fn name(&self) -> &str {
        "mock"
    }

    fn display_name(&self) -> &str {
        "Mock Channel"
    }

    fn supports_threads(&self) -> bool {
        false
    }

    fn supports_media(&self) -> bool {
        true
    }

    /// Ignores the host and simply waits until `cancel` is cancelled.
    async fn start(
        &self,
        _host: Arc<dyn ChannelAdapterHost>,
        cancel: CancellationToken,
    ) -> Result<(), PluginError> {
        let until_cancelled = cancel.cancelled();
        until_cancelled.await;
        Ok(())
    }

    /// Ignores target and payload; always answers with `"msg-001"`.
    async fn send(
        &self,
        _target: &str,
        _payload: &MessagePayload,
    ) -> Result<String, PluginError> {
        Ok(String::from("msg-001"))
    }
}
/// `PipelineStage` stub: a pre-process stage that passes its input through.
struct MockPipelineStage;

#[async_trait]
impl PipelineStage for MockPipelineStage {
    fn name(&self) -> &str {
        "mock_stage"
    }

    fn stage_type(&self) -> PipelineStageType {
        PipelineStageType::PreProcess
    }

    /// Returns `input` unchanged.
    async fn process(&self, input: serde_json::Value) -> Result<serde_json::Value, PluginError> {
        let passthrough = input;
        Ok(passthrough)
    }
}
/// `Skill` stub with fixed metadata and a single allowed tool.
struct MockSkill;

#[async_trait]
impl Skill for MockSkill {
    fn name(&self) -> &str {
        "mock-skill"
    }

    fn description(&self) -> &str {
        "A mock skill"
    }

    fn version(&self) -> &str {
        "1.0.0"
    }

    /// No variables are defined.
    fn variables(&self) -> HashMap<String, String> {
        HashMap::new()
    }

    /// Only `"mock_tool"` is allowed.
    fn allowed_tools(&self) -> Vec<String> {
        vec![String::from("mock_tool")]
    }

    fn instructions(&self) -> &str {
        "Do mock things."
    }

    fn is_user_invocable(&self) -> bool {
        true
    }

    /// Ignores params and context; reports the requested tool name with
    /// status `"ok"`.
    async fn execute_tool(
        &self,
        tool_name: &str,
        _params: serde_json::Value,
        _ctx: &dyn ToolContext,
    ) -> Result<serde_json::Value, PluginError> {
        let report = serde_json::json!({ "tool": tool_name, "status": "ok" });
        Ok(report)
    }
}
/// `MemoryBackend` stub returning canned answers; nothing is actually stored.
struct MockMemoryBackend;

#[async_trait]
impl MemoryBackend for MockMemoryBackend {
    /// Accepts the write and discards it.
    async fn store(
        &self,
        _key: &str,
        _value: &str,
        _namespace: Option<&str>,
        _ttl_seconds: Option<u64>,
        _tags: Option<Vec<String>>,
    ) -> Result<(), PluginError> {
        Ok(())
    }

    /// Always yields `"stored-value"`, whatever the key.
    async fn retrieve(
        &self,
        _key: &str,
        _namespace: Option<&str>,
    ) -> Result<Option<String>, PluginError> {
        Ok(Some(String::from("stored-value")))
    }

    /// Always yields the single hit `("key", "value", 0.95)`.
    async fn search(
        &self,
        _query: &str,
        _namespace: Option<&str>,
        _limit: Option<usize>,
    ) -> Result<Vec<(String, String, f64)>, PluginError> {
        let hit = (String::from("key"), String::from("value"), 0.95);
        Ok(vec![hit])
    }

    /// Always returns `true`.
    async fn delete(
        &self,
        _key: &str,
        _namespace: Option<&str>,
    ) -> Result<bool, PluginError> {
        Ok(true)
    }
}
/// `VoiceHandler` stub returning fixed text and 100 zero bytes of "audio".
struct MockVoiceHandler;

#[async_trait]
impl VoiceHandler for MockVoiceHandler {
    /// Ignores the input; always answers `"transcribed text"`.
    async fn process_audio(
        &self,
        _audio_data: &[u8],
        _mime_type: &str,
    ) -> Result<String, PluginError> {
        Ok(String::from("transcribed text"))
    }

    /// Ignores the text; returns 100 zero bytes tagged `"audio/wav"`.
    async fn synthesize(&self, _text: &str) -> Result<(Vec<u8>, String), PluginError> {
        let samples = vec![0u8; 100];
        let mime = String::from("audio/wav");
        Ok((samples, mime))
    }
}
/// Exercises every `Tool` method on `MockTool`.
#[tokio::test]
async fn test_tool_trait_implementation() {
    let (tool, ctx) = (MockTool, MockToolContext);

    // Metadata accessors.
    assert_eq!(tool.name(), "mock_tool");
    assert_eq!(tool.description(), "A mock tool for testing");
    assert!(tool.parameters_schema().is_object());

    // Execution echoes the parameters into the "result" string.
    let args = serde_json::json!({"input": "test"});
    let output = tool.execute(args, &ctx).await.unwrap();
    let rendered = output["result"].as_str().unwrap();
    assert!(rendered.contains("test"));
}
/// Exercises the metadata accessors and `send` on `MockChannelAdapter`.
#[tokio::test]
async fn test_channel_adapter_trait_implementation() {
    let channel = MockChannelAdapter;

    // Metadata accessors.
    assert_eq!(channel.name(), "mock");
    assert_eq!(channel.display_name(), "Mock Channel");
    assert!(!channel.supports_threads());
    assert!(channel.supports_media());

    // Sending returns the mock's fixed message id.
    let outgoing = MessagePayload::text("hello");
    let sent_id = channel.send("target", &outgoing).await.unwrap();
    assert_eq!(sent_id, "msg-001");
}
/// `MockPipelineStage` reports its metadata and echoes its input.
#[tokio::test]
async fn test_pipeline_stage_trait_implementation() {
    let stage = MockPipelineStage;
    assert_eq!(stage.name(), "mock_stage");
    assert_eq!(stage.stage_type(), PipelineStageType::PreProcess);

    let original = serde_json::json!({"data": "test"});
    let echoed = stage.process(original.clone()).await.unwrap();
    assert_eq!(echoed, original);
}
/// Exercises every `Skill` method on `MockSkill`.
#[tokio::test]
async fn test_skill_trait_implementation() {
    let (skill, ctx) = (MockSkill, MockToolContext);

    // Static metadata.
    assert_eq!(skill.name(), "mock-skill");
    assert_eq!(skill.description(), "A mock skill");
    assert_eq!(skill.version(), "1.0.0");
    assert!(skill.variables().is_empty());
    assert_eq!(skill.allowed_tools(), vec!["mock_tool"]);
    assert_eq!(skill.instructions(), "Do mock things.");
    assert!(skill.is_user_invocable());

    // Tool execution reports the requested tool and an "ok" status.
    let no_params = serde_json::json!({});
    let report = skill
        .execute_tool("mock_tool", no_params, &ctx)
        .await
        .unwrap();
    assert_eq!(report["tool"], "mock_tool");
    assert_eq!(report["status"], "ok");
}
/// Walks through store, retrieve, search and delete on `MockMemoryBackend`.
#[tokio::test]
async fn test_memory_backend_trait_implementation() {
    let memory = MockMemoryBackend;

    memory.store("key", "value", None, None, None).await.unwrap();

    let fetched = memory.retrieve("key", None).await.unwrap();
    assert_eq!(fetched.as_deref(), Some("stored-value"));

    let hits = memory.search("query", None, Some(10)).await.unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, "key");

    let removed = memory.delete("key", None).await.unwrap();
    assert!(removed);
}
/// Exercises both directions of `MockVoiceHandler`.
#[tokio::test]
async fn test_voice_handler_trait_implementation() {
    let voice = MockVoiceHandler;

    // Audio in, text out.
    let silence = [0u8; 100];
    let transcript = voice.process_audio(&silence, "audio/wav").await.unwrap();
    assert_eq!(transcript, "transcribed text");

    // Text in, audio bytes plus MIME type out.
    let (samples, mime_type) = voice.synthesize("hello").await.unwrap();
    assert!(!samples.is_empty());
    assert_eq!(mime_type, "audio/wav");
}
/// `MockKvStore` holds nothing: reads miss, deletes report `false`, and the
/// key list stays empty even after a `set`.
#[tokio::test]
async fn test_key_value_store_trait_implementation() {
    let kv = MockKvStore;

    let lookup = kv.get("missing").await.unwrap();
    assert!(lookup.is_none());

    kv.set("key", "value").await.unwrap();

    let removed = kv.delete("key").await.unwrap();
    assert!(!removed);

    let listed = kv.list_keys(None).await.unwrap();
    assert!(listed.is_empty());
}
/// `MockToolContext` exposes its fixed identifiers and a key-value store.
#[test]
fn test_tool_context_trait_implementation() {
    let context = MockToolContext;
    // Only checks that a store can be obtained.
    let _store = context.key_value_store();
    assert_eq!(context.agent_id(), "mock-agent");
    assert_eq!(context.plugin_id(), "mock-plugin");
}
/// Compile-time check that each mock coerces to a boxed trait object.
#[test]
fn test_trait_objects_can_be_boxed() {
    let _tool = Box::new(MockTool) as Box<dyn Tool>;
    let _channel = Box::new(MockChannelAdapter) as Box<dyn ChannelAdapter>;
    let _stage = Box::new(MockPipelineStage) as Box<dyn PipelineStage>;
    let _skill = Box::new(MockSkill) as Box<dyn Skill>;
    let _memory = Box::new(MockMemoryBackend) as Box<dyn MemoryBackend>;
    let _voice = Box::new(MockVoiceHandler) as Box<dyn VoiceHandler>;
    let _kv = Box::new(MockKvStore) as Box<dyn KeyValueStore>;
    let _ctx = Box::new(MockToolContext) as Box<dyn ToolContext>;
}
/// Compile-time check that each mock coerces to an `Arc`-shared trait object.
///
/// Covers the same mocks as `test_trait_objects_can_be_boxed`, including
/// `ToolContext`.
#[test]
fn test_trait_objects_can_be_arced() {
    let _tool: Arc<dyn Tool> = Arc::new(MockTool);
    let _channel: Arc<dyn ChannelAdapter> = Arc::new(MockChannelAdapter);
    let _stage: Arc<dyn PipelineStage> = Arc::new(MockPipelineStage);
    let _skill: Arc<dyn Skill> = Arc::new(MockSkill);
    let _memory: Arc<dyn MemoryBackend> = Arc::new(MockMemoryBackend);
    let _voice: Arc<dyn VoiceHandler> = Arc::new(MockVoiceHandler);
    let _kv: Arc<dyn KeyValueStore> = Arc::new(MockKvStore);
    let _ctx: Arc<dyn ToolContext> = Arc::new(MockToolContext);
}
}