use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{json, Value};
use crate::bus::{MessageBus, OutboundMessage};
use crate::error::{Result, ZeptoError};
use super::{Tool, ToolContext};
/// Tool that lets the agent send a message to a chat by publishing an
/// [`OutboundMessage`] on the [`MessageBus`] it was constructed with.
pub struct MessageTool {
/// Bus that `execute` publishes outbound messages on (shared via `Arc`).
bus: Arc<MessageBus>,
}
impl MessageTool {
pub fn new(bus: Arc<MessageBus>) -> Self {
Self { bus }
}
}
#[async_trait]
impl Tool for MessageTool {
fn name(&self) -> &str {
"message"
}
fn description(&self) -> &str {
"Send a proactive message to a chat."
}
fn parameters(&self) -> Value {
json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "Message text to send"
},
"channel": {
"type": "string",
"description": "Destination channel. Optional when context already has channel."
},
"chat_id": {
"type": "string",
"description": "Destination chat ID. Optional when context already has chat_id."
}
},
"required": ["content"]
})
}
async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<String> {
let content = args
.get("content")
.and_then(|v| v.as_str())
.map(str::trim)
.filter(|s| !s.is_empty())
.ok_or_else(|| ZeptoError::Tool("Missing 'content' parameter".to_string()))?;
let channel = args
.get("channel")
.and_then(|v| v.as_str())
.map(str::to_string)
.or_else(|| ctx.channel.clone())
.ok_or_else(|| ZeptoError::Tool("No target channel specified".to_string()))?;
let chat_id = args
.get("chat_id")
.and_then(|v| v.as_str())
.map(str::to_string)
.or_else(|| ctx.chat_id.clone())
.ok_or_else(|| ZeptoError::Tool("No target chat_id specified".to_string()))?;
const ALLOWED_CHANNELS: &[&str] = &["telegram", "slack", "discord", "webhook"];
if !ALLOWED_CHANNELS
.iter()
.any(|c| c.eq_ignore_ascii_case(&channel))
{
return Err(ZeptoError::Tool(format!(
"Unknown channel '{}'. Allowed: {}",
channel,
ALLOWED_CHANNELS.join(", ")
)));
}
self.bus
.publish_outbound(OutboundMessage::new(&channel, &chat_id, content))
.await
.map_err(|e| ZeptoError::Tool(format!("Failed to publish message: {}", e)))?;
Ok(format!("Message sent to {}:{}", channel, chat_id))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_tool_properties() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus);
        assert_eq!(tool.name(), "message");
        assert!(tool.description().contains("proactive"));
        // Only the message text is mandatory; the target may come from context.
        assert_eq!(tool.parameters()["required"], json!(["content"]));
    }

    #[tokio::test]
    async fn test_message_tool_with_context_target() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus.clone());
        let ctx = ToolContext::new().with_channel("telegram", "12345");
        let result = tool.execute(json!({"content": "Hello"}), &ctx).await;
        assert!(result.is_ok());
        let outbound = bus.consume_outbound().await.expect("outbound message");
        assert_eq!(outbound.channel, "telegram");
        assert_eq!(outbound.chat_id, "12345");
        assert_eq!(outbound.content, "Hello");
    }

    #[tokio::test]
    async fn test_message_tool_explicit_target_overrides_context() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus.clone());
        let ctx = ToolContext::new().with_channel("telegram", "12345");
        let result = tool
            .execute(
                json!({"content": "Hello", "channel": "slack", "chat_id": "999"}),
                &ctx,
            )
            .await;
        assert!(result.is_ok());
        let outbound = bus.consume_outbound().await.expect("outbound message");
        assert_eq!(outbound.channel, "slack");
        assert_eq!(outbound.chat_id, "999");
    }

    #[tokio::test]
    async fn test_message_tool_missing_content() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus);
        let result = tool
            .execute(
                json!({"channel": "telegram", "chat_id": "12345"}),
                &ToolContext::new(),
            )
            .await;
        let err = result.expect_err("missing content must be rejected").to_string();
        assert!(err.contains("content"), "unexpected error: {}", err);
    }

    #[tokio::test]
    async fn test_message_tool_blank_content() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus);
        let result = tool
            .execute(
                json!({"content": "   ", "channel": "telegram", "chat_id": "12345"}),
                &ToolContext::new(),
            )
            .await;
        let err = result.expect_err("blank content must be rejected").to_string();
        assert!(err.contains("content"), "unexpected error: {}", err);
    }

    #[tokio::test]
    async fn test_message_tool_missing_target() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus);
        let result = tool
            .execute(json!({"content": "Hello"}), &ToolContext::new())
            .await;
        let err = result.expect_err("missing target must be rejected").to_string();
        assert!(err.contains("No target channel"), "unexpected error: {}", err);
    }

    #[tokio::test]
    async fn test_message_tool_missing_chat_id() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus);
        let result = tool
            .execute(
                json!({"content": "Hello", "channel": "telegram"}),
                &ToolContext::new(),
            )
            .await;
        let err = result.expect_err("missing chat_id must be rejected").to_string();
        assert!(err.contains("No target chat_id"), "unexpected error: {}", err);
    }

    #[tokio::test]
    async fn test_message_tool_rejects_unknown_channel() {
        let bus = Arc::new(MessageBus::new());
        let tool = MessageTool::new(bus);
        let result = tool
            .execute(
                json!({"content": "Hello", "channel": "evil-channel", "chat_id": "123"}),
                &ToolContext::new(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Unknown channel"));
    }

    #[tokio::test]
    async fn test_message_tool_allows_known_channels() {
        for channel in &["telegram", "slack", "discord", "webhook"] {
            let bus = Arc::new(MessageBus::new());
            let tool = MessageTool::new(bus.clone());
            let result = tool
                .execute(
                    json!({"content": "Hi", "channel": channel, "chat_id": "123"}),
                    &ToolContext::new(),
                )
                .await;
            assert!(result.is_ok(), "Channel '{}' should be allowed", channel);
            // The message must actually reach the bus with the requested target.
            let outbound = bus.consume_outbound().await.expect("outbound message");
            assert_eq!(outbound.channel, *channel);
            assert_eq!(outbound.chat_id, "123");
            assert_eq!(outbound.content, "Hi");
        }
    }
}