use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::secrets::CredentialMapping;
#[derive(Debug, Clone, Default)]
pub struct Capabilities {
pub workspace_read: Option<WorkspaceCapability>,
pub http: Option<HttpCapability>,
pub tool_invoke: Option<ToolInvokeCapability>,
pub secrets: Option<SecretsCapability>,
pub webhook: Option<WebhookCapability>,
pub websocket: Option<serde_json::Value>,
}
impl Capabilities {
pub fn none() -> Self {
Self::default()
}
pub fn with_workspace_read(mut self, prefixes: Vec<String>) -> Self {
self.workspace_read = Some(WorkspaceCapability {
allowed_prefixes: prefixes,
reader: None,
});
self
}
pub fn with_http(mut self, http: HttpCapability) -> Self {
self.http = Some(http);
self
}
pub fn with_tool_invoke(mut self, aliases: HashMap<String, String>) -> Self {
self.tool_invoke = Some(ToolInvokeCapability {
aliases,
rate_limit: RateLimitConfig::default(),
});
self
}
pub fn with_secrets(mut self, allowed: Vec<String>) -> Self {
self.secrets = Some(SecretsCapability {
allowed_names: allowed,
});
self
}
}
/// Workspace read capability: a list of allowed path prefixes plus an
/// optional backend that performs the reads.
///
/// No prefix check is performed in this module; the prefixes are only stored.
#[derive(Clone, Default)]
pub struct WorkspaceCapability {
    /// Path prefixes that workspace reads are meant to be limited to.
    pub allowed_prefixes: Vec<String>,
    /// Backend used to perform reads, if one has been attached.
    pub reader: Option<Arc<dyn WorkspaceReader>>,
}

impl std::fmt::Debug for WorkspaceCapability {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn WorkspaceReader` has no `Debug` bound, so only whether a
        // reader is attached is reported.
        let has_reader = self.reader.is_some();
        let mut out = f.debug_struct("WorkspaceCapability");
        out.field("allowed_prefixes", &self.allowed_prefixes);
        out.field("reader", &has_reader);
        out.finish()
    }
}

/// Backend that reads workspace content.
///
/// Implementors must be `Send + Sync` so a reader can be shared through the
/// `Arc` stored in [`WorkspaceCapability::reader`].
pub trait WorkspaceReader: Send + Sync {
    /// Returns the content at `path`, or `None` when nothing is returned.
    ///
    /// NOTE(review): whether `None` means "missing", "denied" or both is up to
    /// each implementor; confirm against the concrete readers.
    fn read(&self, path: &str) -> Option<String>;
}
/// Outbound HTTP capability: the endpoints a tool may call, the credentials
/// available to it and the limits applied to its requests.
#[derive(Debug, Clone)]
pub struct HttpCapability {
    /// Endpoint patterns that are permitted.
    pub allowlist: Vec<EndpointPattern>,
    /// Credential mappings, keyed by name.
    pub credentials: HashMap<String, CredentialMapping>,
    /// Rate-limit settings for HTTP calls.
    pub rate_limit: RateLimitConfig,
    /// Maximum request size in bytes (default 1 MiB).
    pub max_request_bytes: usize,
    /// Maximum response size in bytes (default 10 MiB).
    pub max_response_bytes: usize,
    /// Request timeout (default 30 seconds).
    pub timeout: Duration,
}

impl Default for HttpCapability {
    /// Empty allowlist, no credentials, default rate limit, 1 MiB request
    /// limit, 10 MiB response limit and a 30 second timeout.
    fn default() -> Self {
        Self {
            allowlist: Vec::new(),
            credentials: HashMap::new(),
            rate_limit: RateLimitConfig::default(),
            max_request_bytes: Self::DEFAULT_MAX_REQUEST_BYTES,
            max_response_bytes: Self::DEFAULT_MAX_RESPONSE_BYTES,
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }
}

impl HttpCapability {
    /// Default request-size limit: 1 MiB.
    const DEFAULT_MAX_REQUEST_BYTES: usize = 1024 * 1024;
    /// Default response-size limit: 10 MiB.
    const DEFAULT_MAX_RESPONSE_BYTES: usize = 10 * 1024 * 1024;
    /// Default request timeout: 30 seconds.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates an HTTP capability for `allowlist`, with every other setting at
    /// its default.
    pub fn new(allowlist: Vec<EndpointPattern>) -> Self {
        let defaults = Self::default();
        Self {
            allowlist,
            ..defaults
        }
    }

    /// Registers `mapping` under `name`, replacing any mapping already stored
    /// under that name.
    pub fn with_credential(mut self, name: impl Into<String>, mapping: CredentialMapping) -> Self {
        let key: String = name.into();
        self.credentials.insert(key, mapping);
        self
    }

    /// Replaces the rate-limit settings.
    pub fn with_rate_limit(self, rate_limit: RateLimitConfig) -> Self {
        Self { rate_limit, ..self }
    }

    /// Replaces the request timeout.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Replaces the maximum request size, in bytes.
    pub fn with_max_request_bytes(self, bytes: usize) -> Self {
        Self {
            max_request_bytes: bytes,
            ..self
        }
    }

    /// Replaces the maximum response size, in bytes.
    pub fn with_max_response_bytes(self, bytes: usize) -> Self {
        Self {
            max_response_bytes: bytes,
            ..self
        }
    }
}
/// One allowlist entry describing which HTTP requests are permitted.
///
/// A request must satisfy the host, and (when set) the path prefix and the
/// method list; see `EndpointPattern::matches`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointPattern {
    /// Exact host (`api.example.com`) or wildcard of the form `*.example.com`.
    pub host: String,
    /// Path prefix a request path must start with; `None` allows any path.
    pub path_prefix: Option<String>,
    /// Allowed HTTP methods, compared case-insensitively; an empty list allows
    /// any method.
    pub methods: Vec<String>,
}
impl EndpointPattern {
pub fn host(host: impl Into<String>) -> Self {
Self {
host: host.into(),
path_prefix: None,
methods: Vec::new(),
}
}
pub fn with_path_prefix(mut self, prefix: impl Into<String>) -> Self {
self.path_prefix = Some(prefix.into());
self
}
pub fn with_methods(mut self, methods: Vec<String>) -> Self {
self.methods = methods;
self
}
pub fn matches(&self, url_host: &str, url_path: &str, method: &str) -> bool {
if !self.host_matches(url_host) {
return false;
}
if let Some(ref prefix) = self.path_prefix
&& !url_path.starts_with(prefix)
{
return false;
}
if !self.methods.is_empty() {
let method_upper = method.to_uppercase();
if !self
.methods
.iter()
.any(|m| m.to_uppercase() == method_upper)
{
return false;
}
}
true
}
pub fn host_matches(&self, url_host: &str) -> bool {
if self.host == url_host {
return true;
}
if let Some(suffix) = self.host.strip_prefix("*.")
&& url_host.ends_with(suffix)
&& url_host.len() > suffix.len()
{
let prefix = &url_host[..url_host.len() - suffix.len()];
if prefix.ends_with('.') || prefix.is_empty() {
return true;
}
}
false
}
}
#[derive(Debug, Clone, Default)]
pub struct ToolInvokeCapability {
pub aliases: HashMap<String, String>,
pub rate_limit: RateLimitConfig,
}
impl ToolInvokeCapability {
pub fn new(aliases: HashMap<String, String>) -> Self {
Self {
aliases,
rate_limit: RateLimitConfig::default(),
}
}
pub fn resolve_alias(&self, alias: &str) -> Option<&str> {
self.aliases.get(alias).map(|s| s.as_str())
}
}
/// Capability granting access to secrets by name.
#[derive(Debug, Clone, Default)]
pub struct SecretsCapability {
    /// Allowed secret names. An entry ending in `*` is a prefix pattern
    /// (`openai_*` allows `openai_key`); any other entry must match exactly.
    pub allowed_names: Vec<String>,
}

impl SecretsCapability {
    /// Returns `true` if `name` is covered by at least one entry of
    /// `allowed_names`.
    ///
    /// Only a single trailing `*` acts as a wildcard; a `*` anywhere else is
    /// compared literally. The entry `*` on its own allows every name.
    pub fn is_allowed(&self, name: &str) -> bool {
        self.allowed_names
            .iter()
            .any(|pattern| match pattern.strip_suffix('*') {
                Some(prefix) => pattern == name || name.starts_with(prefix),
                None => pattern == name,
            })
    }
}
pub use crate::tools::tool::ToolRateLimitConfig as RateLimitConfig;
/// Webhook configuration for a tool.
///
/// All fields are optional and default to `None`. No behavior is attached to
/// them in this module; how they are interpreted is decided elsewhere.
/// NOTE(review): the per-field descriptions below are inferred from field
/// names only. Confirm them against the webhook handler.
#[derive(Debug, Clone, Default)]
pub struct WebhookCapability {
    /// Presumably the request header that carries a shared secret.
    pub secret_header: Option<String>,
    /// Presumably the name of the stored secret checked against `secret_header`.
    pub secret_name: Option<String>,
    /// Presumably the name of the stored secret holding a signature-verification key.
    pub signature_key_secret_name: Option<String>,
    /// Presumably the name of the stored secret used as the HMAC key.
    pub hmac_secret_name: Option<String>,
    /// Presumably the request header carrying the HMAC signature.
    pub hmac_signature_header: Option<String>,
    /// Presumably the request header carrying a timestamp covered by the HMAC.
    pub hmac_timestamp_header: Option<String>,
    /// Presumably a prefix preceding the signature value (e.g. `sha256=`). TODO confirm.
    pub hmac_prefix: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::{Capabilities, EndpointPattern, SecretsCapability};

    #[test]
    fn test_capabilities_default_is_none() {
        // Destructure exhaustively so adding a field forces this test to be revisited.
        let Capabilities {
            workspace_read,
            http,
            tool_invoke,
            secrets,
            webhook,
            websocket,
        } = Capabilities::default();
        assert!(workspace_read.is_none());
        assert!(http.is_none());
        assert!(tool_invoke.is_none());
        assert!(secrets.is_none());
        assert!(webhook.is_none());
        assert!(websocket.is_none());
    }

    #[test]
    fn test_endpoint_pattern_exact_host() {
        let pattern = EndpointPattern::host("api.example.com");
        let cases = [("api.example.com", true), ("other.example.com", false)];
        for (host, expected) in cases {
            assert_eq!(pattern.matches(host, "/", "GET"), expected, "host: {host}");
        }
    }

    #[test]
    fn test_endpoint_pattern_wildcard_host() {
        let pattern = EndpointPattern::host("*.example.com");
        let cases = [
            ("api.example.com", true),
            ("sub.api.example.com", true),
            ("example.com", false),
            ("notexample.com", false),
        ];
        for (host, expected) in cases {
            assert_eq!(pattern.matches(host, "/", "GET"), expected, "host: {host}");
        }
    }

    #[test]
    fn test_endpoint_pattern_path_prefix() {
        let pattern = EndpointPattern::host("api.example.com").with_path_prefix("/v1/");
        let cases = [
            ("/v1/users", true),
            ("/v1/", true),
            ("/v2/users", false),
            ("/", false),
        ];
        for (path, expected) in cases {
            assert_eq!(
                pattern.matches("api.example.com", path, "GET"),
                expected,
                "path: {path}"
            );
        }
    }

    #[test]
    fn test_endpoint_pattern_methods() {
        let pattern = EndpointPattern::host("api.example.com")
            .with_methods(vec!["GET".to_string(), "POST".to_string()]);
        let cases = [("GET", true), ("get", true), ("POST", true), ("DELETE", false)];
        for (method, expected) in cases {
            assert_eq!(
                pattern.matches("api.example.com", "/", method),
                expected,
                "method: {method}"
            );
        }
    }

    #[test]
    fn test_secrets_capability_exact_match() {
        let cap = SecretsCapability {
            allowed_names: vec!["openai_key".to_string()],
        };
        assert!(cap.is_allowed("openai_key"));
        assert!(!cap.is_allowed("anthropic_key"));
    }

    #[test]
    fn test_secrets_capability_glob() {
        let cap = SecretsCapability {
            allowed_names: vec!["openai_*".to_string()],
        };
        let cases = [("openai_key", true), ("openai_org", true), ("anthropic_key", false)];
        for (name, expected) in cases {
            assert_eq!(cap.is_allowed(name), expected, "secret: {name}");
        }
    }

    #[test]
    fn test_capabilities_builder() {
        let caps = Capabilities::none()
            .with_workspace_read(vec!["context/".to_string()])
            .with_secrets(vec!["test_*".to_string()]);
        assert!(caps.workspace_read.is_some());
        assert!(caps.secrets.is_some());
        assert!(caps.http.is_none());
    }
}