use std::time::{SystemTime, UNIX_EPOCH};
use crate::tools::wasm::capabilities::Capabilities;
use crate::tools::wasm::error::WasmError;
const MAX_LOG_ENTRIES: usize = 1000;
const MAX_LOG_MESSAGE_BYTES: usize = 4096;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
Trace,
Debug,
Info,
Warn,
Error,
}
impl std::fmt::Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LogLevel::Trace => write!(f, "TRACE"),
LogLevel::Debug => write!(f, "DEBUG"),
LogLevel::Info => write!(f, "INFO"),
LogLevel::Warn => write!(f, "WARN"),
LogLevel::Error => write!(f, "ERROR"),
}
}
}
#[derive(Debug, Clone)]
pub struct LogEntry {
pub level: LogLevel,
pub message: String,
pub timestamp_millis: u64,
}
pub struct HostState {
logs: Vec<LogEntry>,
logging_enabled: bool,
capabilities: Capabilities,
logs_dropped: usize,
user_id: Option<String>,
http_request_count: u32,
tool_invoke_count: u32,
}
impl std::fmt::Debug for HostState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HostState")
.field("logs_count", &self.logs.len())
.field("logging_enabled", &self.logging_enabled)
.field("logs_dropped", &self.logs_dropped)
.field("user_id", &self.user_id)
.field("http_request_count", &self.http_request_count)
.field("tool_invoke_count", &self.tool_invoke_count)
.finish()
}
}
impl HostState {
pub fn new(capabilities: Capabilities) -> Self {
Self {
logs: Vec::new(),
logging_enabled: true,
capabilities,
logs_dropped: 0,
user_id: None,
http_request_count: 0,
tool_invoke_count: 0,
}
}
pub fn new_with_user(capabilities: Capabilities, user_id: impl Into<String>) -> Self {
Self {
logs: Vec::new(),
logging_enabled: true,
capabilities,
logs_dropped: 0,
user_id: Some(user_id.into()),
http_request_count: 0,
tool_invoke_count: 0,
}
}
pub fn minimal() -> Self {
Self::new(Capabilities::default())
}
pub fn user_id(&self) -> Option<&str> {
self.user_id.as_deref()
}
pub fn capabilities(&self) -> &Capabilities {
&self.capabilities
}
pub fn log(&mut self, level: LogLevel, message: String) -> Result<(), WasmError> {
if !self.logging_enabled {
self.logs_dropped += 1;
return Ok(()); }
if self.logs.len() >= MAX_LOG_ENTRIES {
self.logging_enabled = false;
self.logs_dropped += 1;
tracing::warn!(
"WASM log limit reached ({} entries), further logs dropped",
MAX_LOG_ENTRIES
);
return Ok(());
}
let message = if message.len() > MAX_LOG_MESSAGE_BYTES {
let mut truncated = message[..MAX_LOG_MESSAGE_BYTES].to_string();
truncated.push_str("... (truncated)");
truncated
} else {
message
};
let timestamp_millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
self.logs.push(LogEntry {
level,
message,
timestamp_millis,
});
Ok(())
}
pub fn now_millis(&self) -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0)
}
pub fn workspace_read(&self, path: &str) -> Result<Option<String>, WasmError> {
let capability = match &self.capabilities.workspace_read {
Some(cap) => cap,
None => return Ok(None), };
validate_workspace_path(path)?;
if !capability.allowed_prefixes.is_empty() {
let allowed = capability
.allowed_prefixes
.iter()
.any(|prefix| path.starts_with(prefix));
if !allowed {
tracing::debug!(
path = path,
allowed = ?capability.allowed_prefixes,
"WASM workspace read denied: path not in allowed prefixes"
);
return Ok(None);
}
}
match &capability.reader {
Some(reader) => Ok(reader.read(path)),
None => Ok(None), }
}
pub fn take_logs(&mut self) -> Vec<LogEntry> {
std::mem::take(&mut self.logs)
}
pub fn logs_dropped(&self) -> usize {
self.logs_dropped
}
pub fn secret_exists(&self, name: &str) -> bool {
let capability = match &self.capabilities.secrets {
Some(cap) => cap,
None => return false,
};
capability.is_allowed(name)
}
pub fn check_http_allowed(&self, url: &str, method: &str) -> Result<(), String> {
let capability = self
.capabilities
.http
.as_ref()
.ok_or_else(|| "HTTP capability not granted".to_string())?;
use crate::tools::wasm::allowlist::AllowlistValidator;
let validator = AllowlistValidator::new(capability.allowlist.clone());
let result = validator.validate(url, method);
if result.is_allowed() {
Ok(())
} else {
Err(format!("HTTP request not allowed: {:?}", result))
}
}
pub fn check_tool_invoke_allowed(&self, alias: &str) -> Result<String, String> {
let capability = self
.capabilities
.tool_invoke
.as_ref()
.ok_or_else(|| "Tool invocation capability not granted".to_string())?;
capability
.resolve_alias(alias)
.map(|s| s.to_string())
.ok_or_else(|| format!("Unknown tool alias: {}", alias))
}
pub fn record_http_request(&mut self) -> Result<(), String> {
let _capability = self
.capabilities
.http
.as_ref()
.ok_or_else(|| "HTTP capability not granted".to_string())?;
self.http_request_count += 1;
const MAX_REQUESTS_PER_EXECUTION: u32 = 50;
if self.http_request_count > MAX_REQUESTS_PER_EXECUTION {
return Err(format!(
"Too many HTTP requests in single execution (max {})",
MAX_REQUESTS_PER_EXECUTION
));
}
Ok(())
}
pub fn record_tool_invoke(&mut self) -> Result<(), String> {
self.tool_invoke_count += 1;
const MAX_INVOKES_PER_EXECUTION: u32 = 20;
if self.tool_invoke_count > MAX_INVOKES_PER_EXECUTION {
return Err(format!(
"Too many tool invocations in single execution (max {})",
MAX_INVOKES_PER_EXECUTION
));
}
Ok(())
}
pub fn http_request_count(&self) -> u32 {
self.http_request_count
}
pub fn tool_invoke_count(&self) -> u32 {
self.tool_invoke_count
}
}
fn validate_workspace_path(path: &str) -> Result<(), WasmError> {
if path.starts_with('/') {
return Err(WasmError::PathTraversalBlocked(
"absolute paths not allowed".to_string(),
));
}
if path.contains("..") {
return Err(WasmError::PathTraversalBlocked(
"parent directory references not allowed".to_string(),
));
}
if path.contains('\0') {
return Err(WasmError::PathTraversalBlocked(
"null bytes not allowed".to_string(),
));
}
if path.len() >= 2 && path.chars().nth(1) == Some(':') {
return Err(WasmError::PathTraversalBlocked(
"Windows-style paths not allowed".to_string(),
));
}
Ok(())
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::tools::wasm::capabilities::{
Capabilities, SecretsCapability, WorkspaceCapability, WorkspaceReader,
};
use crate::tools::wasm::host::{
HostState, LogLevel, MAX_LOG_ENTRIES, MAX_LOG_MESSAGE_BYTES, validate_workspace_path,
};
struct MockReader {
content: String,
}
impl WorkspaceReader for MockReader {
fn read(&self, _path: &str) -> Option<String> {
Some(self.content.clone())
}
}
#[test]
fn test_logging_basic() {
let mut state = HostState::minimal();
state
.log(LogLevel::Info, "test message".to_string())
.unwrap();
let logs = state.take_logs();
assert_eq!(logs.len(), 1);
assert_eq!(logs[0].level, LogLevel::Info);
assert_eq!(logs[0].message, "test message");
}
#[test]
fn test_logging_rate_limit() {
let mut state = HostState::minimal();
for i in 0..MAX_LOG_ENTRIES {
state
.log(LogLevel::Debug, format!("message {}", i))
.unwrap();
}
state
.log(LogLevel::Info, "should be dropped".to_string())
.unwrap();
assert_eq!(state.take_logs().len(), MAX_LOG_ENTRIES);
assert_eq!(state.logs_dropped(), 1);
}
#[test]
fn test_logging_truncation() {
let mut state = HostState::minimal();
let long_message = "x".repeat(MAX_LOG_MESSAGE_BYTES + 1000);
state.log(LogLevel::Info, long_message).unwrap();
let logs = state.take_logs();
assert!(logs[0].message.len() <= MAX_LOG_MESSAGE_BYTES + 20); assert!(logs[0].message.ends_with("... (truncated)"));
}
#[test]
fn test_now_millis() {
let state = HostState::minimal();
let now = state.now_millis();
assert!(now > 1577836800000); }
#[test]
fn test_workspace_read_no_capability() {
let state = HostState::minimal();
let result = state.workspace_read("context/test.md").unwrap();
assert!(result.is_none());
}
#[test]
fn test_workspace_read_with_capability() {
let reader = Arc::new(MockReader {
content: "test content".to_string(),
});
let capabilities = Capabilities {
workspace_read: Some(WorkspaceCapability {
allowed_prefixes: vec![],
reader: Some(reader),
}),
..Default::default()
};
let state = HostState::new(capabilities);
let result = state.workspace_read("context/test.md").unwrap();
assert_eq!(result, Some("test content".to_string()));
}
#[test]
fn test_workspace_read_prefix_restriction() {
let reader = Arc::new(MockReader {
content: "test content".to_string(),
});
let capabilities = Capabilities {
workspace_read: Some(WorkspaceCapability {
allowed_prefixes: vec!["context/".to_string()],
reader: Some(reader),
}),
..Default::default()
};
let state = HostState::new(capabilities);
let result = state.workspace_read("context/test.md").unwrap();
assert!(result.is_some());
let result = state.workspace_read("secrets/api_key.txt").unwrap();
assert!(result.is_none());
}
#[test]
fn test_path_validation_blocks_traversal() {
assert!(validate_workspace_path("../etc/passwd").is_err());
assert!(validate_workspace_path("context/../secrets").is_err());
assert!(validate_workspace_path("context/test/../../secrets").is_err());
}
#[test]
fn test_path_validation_blocks_absolute() {
assert!(validate_workspace_path("/etc/passwd").is_err());
assert!(validate_workspace_path("/context/test.md").is_err());
}
#[test]
fn test_path_validation_blocks_null_bytes() {
assert!(validate_workspace_path("context/test\0.md").is_err());
}
#[test]
fn test_path_validation_blocks_windows_paths() {
assert!(validate_workspace_path("C:\\Windows\\System32").is_err());
assert!(validate_workspace_path("D:secrets").is_err());
}
#[test]
fn test_path_validation_allows_valid_paths() {
assert!(validate_workspace_path("context/test.md").is_ok());
assert!(validate_workspace_path("daily/2024-01-15.md").is_ok());
assert!(validate_workspace_path("projects/alpha/notes.md").is_ok());
assert!(validate_workspace_path("MEMORY.md").is_ok());
}
#[test]
fn test_secret_exists_no_capability() {
let state = HostState::minimal();
assert!(!state.secret_exists("any_secret"));
}
#[test]
fn test_secret_exists_with_capability() {
let capabilities = Capabilities {
secrets: Some(SecretsCapability {
allowed_names: vec!["openai_*".to_string(), "exact_name".to_string()],
}),
..Default::default()
};
let state = HostState::new(capabilities);
assert!(state.secret_exists("openai_key"));
assert!(state.secret_exists("openai_org"));
assert!(state.secret_exists("exact_name"));
assert!(!state.secret_exists("stripe_key"));
}
#[test]
fn test_http_request_rate_limit() {
let capabilities = Capabilities {
http: Some(crate::tools::wasm::capabilities::HttpCapability::default()),
..Default::default()
};
let mut state = HostState::new(capabilities);
for _ in 0..50 {
assert!(state.record_http_request().is_ok());
}
assert!(state.record_http_request().is_err());
}
#[test]
fn test_tool_invoke_rate_limit() {
let capabilities = Capabilities {
tool_invoke: Some(crate::tools::wasm::capabilities::ToolInvokeCapability::default()),
..Default::default()
};
let mut state = HostState::new(capabilities);
for _ in 0..20 {
assert!(state.record_tool_invoke().is_ok());
}
assert!(state.record_tool_invoke().is_err());
}
#[test]
fn test_new_with_user() {
let state = HostState::new_with_user(Capabilities::default(), "user123");
assert_eq!(state.user_id(), Some("user123"));
}
}