use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::tools::wasm::{Capabilities as ToolCapabilities, RateLimitConfig};
/// Lower bound, in milliseconds, for a channel's polling interval (30 000 ms).
///
/// `ChannelCapabilities::with_polling` raises smaller requests to this value, and it is the
/// default interval for both `ChannelCapabilities::default` and `PollConfig::default`.
pub const MIN_POLL_INTERVAL_MS: u32 = 30_000;
/// Default `messages_per_minute` used by `EmitRateLimitConfig::default`.
pub const DEFAULT_EMIT_RATE_PER_MINUTE: u32 = 100;
/// Default `messages_per_hour` used by `EmitRateLimitConfig::default`.
pub const DEFAULT_EMIT_RATE_PER_HOUR: u32 = 5000;
/// Permissions and limits granted to a WASM channel.
///
/// Typically built with `ChannelCapabilities::for_channel` followed by the `with_*` builder
/// methods. The `Default` value grants no HTTP paths and disables polling.
#[derive(Debug, Clone)]
pub struct ChannelCapabilities {
/// Capabilities shared with WASM tools (`Capabilities` from `crate::tools::wasm`).
pub tool_capabilities: ToolCapabilities,
/// HTTP paths the channel may use; `is_path_allowed` requires an exact string match.
pub allowed_paths: Vec<String>,
/// Whether polling is permitted at all; `validate_poll_interval` errors when this is `false`.
pub allow_polling: bool,
/// Smallest polling interval (ms) that `validate_poll_interval` will hand back.
pub min_poll_interval_ms: u32,
/// String prepended to workspace paths (e.g. `channels/<name>/`); empty means no prefix.
pub workspace_prefix: String,
/// Emit rate limits. NOTE(review): enforcement is not in this file — confirm where it happens.
pub emit_rate_limit: EmitRateLimitConfig,
/// Maximum message size (default `64 * 1024`, presumably bytes). Enforced elsewhere — TODO confirm.
pub max_message_size: usize,
/// Timeout for channel callbacks (default 30 s). Presumably applied by the runtime — confirm.
pub callback_timeout: Duration,
}
impl Default for ChannelCapabilities {
    /// Restrictive baseline capabilities.
    ///
    /// No HTTP paths are allowed and polling is disabled; the polling floor is still preset to
    /// [`MIN_POLL_INTERVAL_MS`]. There is no workspace prefix and emit limits use their defaults.
    /// `max_message_size` is `64 * 1024` and `callback_timeout` is 30 seconds.
    fn default() -> Self {
        let max_message_size: usize = 64 * 1024;
        let callback_timeout = Duration::from_secs(30);
        Self {
            tool_capabilities: Default::default(),
            allowed_paths: vec![],
            allow_polling: false,
            min_poll_interval_ms: MIN_POLL_INTERVAL_MS,
            workspace_prefix: String::default(),
            emit_rate_limit: Default::default(),
            max_message_size,
            callback_timeout,
        }
    }
}
impl ChannelCapabilities {
    /// Creates capabilities for the channel `name`.
    ///
    /// Workspace paths are scoped under `channels/<name>/`. Every other field takes its
    /// [`Default`] value: no HTTP paths, and polling disabled.
    pub fn for_channel(name: &str) -> Self {
        Self {
            workspace_prefix: format!("channels/{}/", name),
            ..Default::default()
        }
    }

    /// Adds `path` to the HTTP paths accepted by [`Self::is_path_allowed`].
    pub fn with_path(mut self, path: impl Into<String>) -> Self {
        self.allowed_paths.push(path.into());
        self
    }

    /// Enables polling with a minimum interval of `min_interval_ms`.
    ///
    /// The interval is raised to at least [`MIN_POLL_INTERVAL_MS`].
    pub fn with_polling(mut self, min_interval_ms: u32) -> Self {
        self.allow_polling = true;
        self.min_poll_interval_ms = min_interval_ms.max(MIN_POLL_INTERVAL_MS);
        self
    }

    /// Replaces the emit rate limit.
    pub fn with_emit_rate_limit(mut self, rate_limit: EmitRateLimitConfig) -> Self {
        self.emit_rate_limit = rate_limit;
        self
    }

    /// Replaces the callback timeout.
    pub fn with_callback_timeout(mut self, timeout: Duration) -> Self {
        self.callback_timeout = timeout;
        self
    }

    /// Replaces the tool capabilities.
    pub fn with_tool_capabilities(mut self, capabilities: ToolCapabilities) -> Self {
        self.tool_capabilities = capabilities;
        self
    }

    /// Returns `true` if `path` exactly equals one of the allowed HTTP paths.
    ///
    /// No prefix, trailing-slash, or pattern matching is performed.
    pub fn is_path_allowed(&self, path: &str) -> bool {
        self.allowed_paths.iter().any(|p| p == path)
    }

    /// Returns the polling interval to use for a requested `interval_ms`.
    ///
    /// The result is raised to at least `min_poll_interval_ms`.
    ///
    /// # Errors
    ///
    /// Returns an error if polling is not enabled for this channel.
    pub fn validate_poll_interval(&self, interval_ms: u32) -> Result<u32, String> {
        if !self.allow_polling {
            return Err("Polling not allowed for this channel".to_string());
        }
        Ok(interval_ms.max(self.min_poll_interval_ms))
    }

    /// Prepends `workspace_prefix` to `path` without any validation.
    ///
    /// Paths coming from the channel should go through [`Self::validate_workspace_path`].
    pub fn prefix_workspace_path(&self, path: &str) -> String {
        if self.workspace_prefix.is_empty() {
            path.to_string()
        } else {
            format!("{}{}", self.workspace_prefix, path)
        }
    }

    /// Validates a channel-supplied relative workspace path.
    ///
    /// On success, returns the path with the channel's prefix applied.
    ///
    /// # Errors
    ///
    /// Returns an error if `path`:
    /// - is absolute on Unix or Windows: it starts with `/` or `\` (which includes UNC
    ///   `\\host\share`), or with a drive prefix such as `C:`;
    /// - contains `..` anywhere. This is deliberately conservative, so names like `a..b` are
    ///   rejected too;
    /// - contains a NUL byte.
    pub fn validate_workspace_path(&self, path: &str) -> Result<String, String> {
        if Self::looks_absolute(path) {
            return Err("Absolute paths not allowed".to_string());
        }
        if path.contains("..") {
            return Err("Parent directory references not allowed".to_string());
        }
        if path.contains('\0') {
            return Err("Null bytes not allowed".to_string());
        }
        Ok(self.prefix_workspace_path(path))
    }

    /// Returns `true` if `path` would be interpreted as absolute on Unix or Windows.
    fn looks_absolute(path: &str) -> bool {
        // A bare `/` check misses Windows roots (`\foo`, UNC `\\host\share`) and drive paths
        // (`C:\foo`, drive-relative `C:foo`). Such paths would otherwise pass through
        // `prefix_workspace_path` unchanged whenever `workspace_prefix` is empty.
        let has_drive_prefix =
            matches!(path.as_bytes(), [drive, b':', ..] if drive.is_ascii_alphabetic());
        path.starts_with('/') || path.starts_with('\\') || has_drive_prefix
    }
}
/// Describes an HTTP endpoint exposed by a channel (serde-serializable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpEndpointConfig {
/// URL path of the endpoint, e.g. `/webhook/slack`.
pub path: String,
/// Accepted HTTP method names, e.g. `"POST"`.
pub methods: Vec<String>,
/// Whether the endpoint requires a secret. NOTE(review): how the secret is checked is not
/// visible here — confirm in the HTTP layer.
pub require_secret: bool,
}
impl HttpEndpointConfig {
pub fn post_webhook(path: impl Into<String>) -> Self {
Self {
path: path.into(),
methods: vec!["POST".to_string()],
require_secret: true,
}
}
}
/// Polling settings for a channel (serde-serializable).
///
/// NOTE(review): `interval_ms` is not clamped on deserialization; presumably callers run it
/// through `ChannelCapabilities::validate_poll_interval` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PollConfig {
/// Interval between polls in milliseconds; defaults to `MIN_POLL_INTERVAL_MS`.
pub interval_ms: u32,
/// Whether polling is turned on; defaults to `false`.
pub enabled: bool,
}
impl Default for PollConfig {
fn default() -> Self {
Self {
interval_ms: MIN_POLL_INTERVAL_MS,
enabled: false,
}
}
}
/// Per-channel message emission limits (serde-serializable).
///
/// NOTE(review): enforcement of these limits is not in this file — confirm where it happens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmitRateLimitConfig {
/// Maximum messages per minute; defaults to `DEFAULT_EMIT_RATE_PER_MINUTE`.
pub messages_per_minute: u32,
/// Maximum messages per hour; defaults to `DEFAULT_EMIT_RATE_PER_HOUR`.
pub messages_per_hour: u32,
}
impl Default for EmitRateLimitConfig {
fn default() -> Self {
Self {
messages_per_minute: DEFAULT_EMIT_RATE_PER_MINUTE,
messages_per_hour: DEFAULT_EMIT_RATE_PER_HOUR,
}
}
}
impl From<RateLimitConfig> for EmitRateLimitConfig {
fn from(config: RateLimitConfig) -> Self {
Self {
messages_per_minute: config.requests_per_minute,
messages_per_hour: config.requests_per_hour,
}
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    // `super` is the module defining these items, so the tests do not hard-code where this
    // file lives in the crate's module tree.
    use super::{
        ChannelCapabilities, EmitRateLimitConfig, HttpEndpointConfig, PollConfig,
        DEFAULT_EMIT_RATE_PER_HOUR, DEFAULT_EMIT_RATE_PER_MINUTE, MIN_POLL_INTERVAL_MS,
    };

    #[test]
    fn test_default_capabilities() {
        let caps = ChannelCapabilities::default();
        assert!(caps.allowed_paths.is_empty());
        assert!(!caps.allow_polling);
        assert_eq!(caps.min_poll_interval_ms, MIN_POLL_INTERVAL_MS);
        assert!(caps.workspace_prefix.is_empty());
        assert_eq!(caps.max_message_size, 64 * 1024);
        assert_eq!(caps.callback_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_for_channel() {
        let caps = ChannelCapabilities::for_channel("slack");
        assert_eq!(caps.workspace_prefix, "channels/slack/");
        // Everything other than the prefix keeps its default.
        assert!(caps.allowed_paths.is_empty());
        assert!(!caps.allow_polling);
    }

    #[test]
    fn test_path_allowed() {
        let caps = ChannelCapabilities::default()
            .with_path("/webhook/slack")
            .with_path("/webhook/slack/events");
        assert!(caps.is_path_allowed("/webhook/slack"));
        assert!(caps.is_path_allowed("/webhook/slack/events"));
        assert!(!caps.is_path_allowed("/webhook/telegram"));
        // Matching is exact: no trailing-slash or prefix matching.
        assert!(!caps.is_path_allowed("/webhook/slack/"));
        assert!(!caps.is_path_allowed("/webhook"));
    }

    #[test]
    fn test_poll_interval_validation() {
        let caps = ChannelCapabilities::default().with_polling(60_000);
        assert_eq!(caps.validate_poll_interval(90_000).unwrap(), 90_000);
        assert_eq!(caps.validate_poll_interval(1000).unwrap(), 60_000);
        let no_poll_caps = ChannelCapabilities::default();
        assert!(no_poll_caps.validate_poll_interval(60_000).is_err());
    }

    #[test]
    fn test_with_polling_clamps_to_minimum() {
        let caps = ChannelCapabilities::default().with_polling(1_000);
        assert!(caps.allow_polling);
        assert_eq!(caps.min_poll_interval_ms, MIN_POLL_INTERVAL_MS);
        assert_eq!(caps.validate_poll_interval(0).unwrap(), MIN_POLL_INTERVAL_MS);
    }

    #[test]
    fn test_workspace_path_validation() {
        let caps = ChannelCapabilities::for_channel("slack");
        let result = caps.validate_workspace_path("state.json");
        assert_eq!(result.unwrap(), "channels/slack/state.json");
        let result = caps.validate_workspace_path("data/users.json");
        assert_eq!(result.unwrap(), "channels/slack/data/users.json");
        let result = caps.validate_workspace_path("/etc/passwd");
        assert!(result.is_err());
        let result = caps.validate_workspace_path("../secrets/key.txt");
        assert!(result.is_err());
        let result = caps.validate_workspace_path("file\0.txt");
        assert!(result.is_err());
    }

    #[test]
    fn test_workspace_path_without_prefix() {
        let caps = ChannelCapabilities::default();
        assert_eq!(caps.prefix_workspace_path("a/b.json"), "a/b.json");
        assert_eq!(caps.validate_workspace_path("state.json").unwrap(), "state.json");
        assert!(caps.validate_workspace_path("/etc/passwd").is_err());
    }

    #[test]
    fn test_builder_setters() {
        let caps = ChannelCapabilities::default()
            .with_callback_timeout(Duration::from_secs(5))
            .with_emit_rate_limit(EmitRateLimitConfig {
                messages_per_minute: 1,
                messages_per_hour: 2,
            });
        assert_eq!(caps.callback_timeout, Duration::from_secs(5));
        assert_eq!(caps.emit_rate_limit.messages_per_minute, 1);
        assert_eq!(caps.emit_rate_limit.messages_per_hour, 2);
    }

    #[test]
    fn test_http_endpoint_config() {
        let endpoint = HttpEndpointConfig::post_webhook("/webhook/slack");
        assert_eq!(endpoint.path, "/webhook/slack");
        assert_eq!(endpoint.methods, vec!["POST"]);
        assert!(endpoint.require_secret);
    }

    #[test]
    fn test_poll_config_default() {
        let config = PollConfig::default();
        assert_eq!(config.interval_ms, MIN_POLL_INTERVAL_MS);
        assert!(!config.enabled);
    }

    #[test]
    fn test_emit_rate_limit_default() {
        let limit = EmitRateLimitConfig::default();
        assert_eq!(limit.messages_per_minute, 100);
        assert_eq!(limit.messages_per_hour, 5000);
        assert_eq!(limit.messages_per_minute, DEFAULT_EMIT_RATE_PER_MINUTE);
        assert_eq!(limit.messages_per_hour, DEFAULT_EMIT_RATE_PER_HOUR);
    }
}