use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::{RwLock, mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
use wasmtime::Store;
use wasmtime::component::{Component, Linker};
use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiView};
use crate::channels::wasm::capabilities::ChannelCapabilities;
use crate::channels::wasm::error::WasmChannelError;
use crate::channels::wasm::host::{ChannelEmitRateLimiter, ChannelHostState, EmittedMessage};
use crate::channels::wasm::router::RegisteredEndpoint;
use crate::channels::wasm::runtime::{PreparedChannelModule, WasmChannelRuntime};
use crate::channels::wasm::schema::ChannelConfig;
use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate};
use crate::error::ChannelError;
use crate::pairing::PairingStore;
use crate::safety::LeakDetector;
use crate::tools::wasm::LogLevel;
use crate::tools::wasm::WasmResourceLimiter;
wasmtime::component::bindgen!({
path: "wit/channel.wit",
world: "sandboxed-channel",
async: false,
with: {
},
});
struct ChannelStoreData {
limiter: WasmResourceLimiter,
host_state: ChannelHostState,
wasi: WasiCtx,
table: ResourceTable,
credentials: HashMap<String, String>,
pairing_store: Arc<PairingStore>,
http_runtime: Option<tokio::runtime::Runtime>,
}
impl ChannelStoreData {
fn new(
memory_limit: u64,
channel_name: &str,
capabilities: ChannelCapabilities,
credentials: HashMap<String, String>,
pairing_store: Arc<PairingStore>,
) -> Self {
let wasi = WasiCtxBuilder::new().build();
Self {
limiter: WasmResourceLimiter::new(memory_limit),
host_state: ChannelHostState::new(channel_name, capabilities),
wasi,
table: ResourceTable::new(),
credentials,
pairing_store,
http_runtime: None,
}
}
fn inject_credentials(&self, input: &str, context: &str) -> String {
let mut result = input.to_string();
tracing::debug!(
input_preview = %input.chars().take(100).collect::<String>(),
context = %context,
credential_count = self.credentials.len(),
credential_names = ?self.credentials.keys().collect::<Vec<_>>(),
"Injecting credentials"
);
for (name, value) in &self.credentials {
let placeholder = format!("{{{}}}", name);
if result.contains(&placeholder) {
tracing::debug!(
placeholder = %placeholder,
context = %context,
"Found and replacing credential placeholder"
);
result = result.replace(&placeholder, value);
}
}
if result.contains('{') && result.contains('}') {
let brace_pattern = regex::Regex::new(r"\{[A-Z_]+\}").ok();
if let Some(re) = brace_pattern
&& re.is_match(&result)
{
tracing::warn!(
context = %context,
"String may contain unresolved credential placeholders"
);
}
}
result
}
fn redact_credentials(&self, text: &str) -> String {
let mut result = text.to_string();
for (name, value) in &self.credentials {
if !value.is_empty() {
result = result.replace(value, &format!("[REDACTED:{}]", name));
}
}
result
}
}
impl WasiView for ChannelStoreData {
fn ctx(&mut self) -> &mut WasiCtx {
&mut self.wasi
}
fn table(&mut self) -> &mut ResourceTable {
&mut self.table
}
}
impl near::agent::channel_host::Host for ChannelStoreData {
fn log(&mut self, level: near::agent::channel_host::LogLevel, message: String) {
let log_level = match level {
near::agent::channel_host::LogLevel::Trace => LogLevel::Trace,
near::agent::channel_host::LogLevel::Debug => LogLevel::Debug,
near::agent::channel_host::LogLevel::Info => LogLevel::Info,
near::agent::channel_host::LogLevel::Warn => LogLevel::Warn,
near::agent::channel_host::LogLevel::Error => LogLevel::Error,
};
let _ = self.host_state.log(log_level, message);
}
fn now_millis(&mut self) -> u64 {
self.host_state.now_millis()
}
fn workspace_read(&mut self, path: String) -> Option<String> {
self.host_state.workspace_read(&path).ok().flatten()
}
fn workspace_write(&mut self, path: String, content: String) -> Result<(), String> {
self.host_state
.workspace_write(&path, content)
.map_err(|e| e.to_string())
}
fn http_request(
&mut self,
method: String,
url: String,
headers_json: String,
body: Option<Vec<u8>>,
timeout_ms: Option<u32>,
) -> Result<near::agent::channel_host::HttpResponse, String> {
tracing::info!(
method = %method,
original_url = %url,
body_len = body.as_ref().map(|b| b.len()).unwrap_or(0),
"WASM http_request called"
);
let injected_url = self.inject_credentials(&url, "url");
let url_changed = injected_url != url;
tracing::info!(url_changed = url_changed, "URL after credential injection");
self.host_state
.check_http_allowed(&injected_url, &method)
.map_err(|e| {
tracing::error!(error = %e, "HTTP not allowed");
format!("HTTP not allowed: {}", e)
})?;
self.host_state.record_http_request().map_err(|e| {
tracing::error!(error = %e, "Rate limit exceeded");
format!("Rate limit exceeded: {}", e)
})?;
let raw_headers: std::collections::HashMap<String, String> =
serde_json::from_str(&headers_json).unwrap_or_default();
let headers: std::collections::HashMap<String, String> = raw_headers
.into_iter()
.map(|(k, v)| {
(
k.clone(),
self.inject_credentials(&v, &format!("header:{}", k)),
)
})
.collect();
let headers_changed = headers
.values()
.any(|v| v.contains("Bearer ") && !v.contains('{'));
tracing::debug!(
header_count = headers.len(),
headers_changed = headers_changed,
"Parsed and injected request headers"
);
let url = injected_url;
let leak_detector = LeakDetector::new();
let header_vec: Vec<(String, String)> = headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
leak_detector
.scan_http_request(&url, &header_vec, body.as_deref())
.map_err(|e| format!("Potential secret leak blocked: {}", e))?;
let max_response_bytes = self
.host_state
.capabilities()
.tool_capabilities
.http
.as_ref()
.map(|h| h.max_response_bytes)
.unwrap_or(10 * 1024 * 1024);
if self.http_runtime.is_none() {
self.http_runtime = Some(
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| format!("Failed to create HTTP runtime: {e}"))?,
);
}
let rt = self.http_runtime.as_ref().expect("just initialized");
let result = rt.block_on(async {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| format!("Failed to build HTTP client: {e}"))?;
let mut request = match method.to_uppercase().as_str() {
"GET" => client.get(&url),
"POST" => client.post(&url),
"PUT" => client.put(&url),
"DELETE" => client.delete(&url),
"PATCH" => client.patch(&url),
"HEAD" => client.head(&url),
_ => return Err(format!("Unsupported HTTP method: {}", method)),
};
for (key, value) in headers {
request = request.header(&key, &value);
}
if let Some(body_bytes) = body {
request = request.body(body_bytes);
}
let timeout_ms = timeout_ms.unwrap_or(30_000).min(300_000) as u64;
let timeout = std::time::Duration::from_millis(timeout_ms);
let response = request.timeout(timeout).send().await.map_err(|e| {
let mut chain = format!("HTTP request failed: {}", e);
let mut source = std::error::Error::source(&e);
while let Some(cause) = source {
chain.push_str(&format!(" -> {}", cause));
source = cause.source();
}
chain
})?;
let status = response.status().as_u16();
let response_headers: std::collections::HashMap<String, String> = response
.headers()
.iter()
.filter_map(|(k, v)| {
v.to_str()
.ok()
.map(|v| (k.as_str().to_string(), v.to_string()))
})
.collect();
let headers_json = serde_json::to_string(&response_headers).unwrap_or_default();
let max_response = max_response_bytes;
if let Some(cl) = response.content_length()
&& cl as usize > max_response
{
return Err(format!(
"Response body too large: {} bytes exceeds limit of {} bytes",
cl, max_response
));
}
let body = response
.bytes()
.await
.map_err(|e| format!("Failed to read response body: {}", e))?;
if body.len() > max_response {
return Err(format!(
"Response body too large: {} bytes exceeds limit of {} bytes",
body.len(),
max_response
));
}
let body = body.to_vec();
tracing::info!(
status = status,
body_len = body.len(),
"HTTP response received"
);
if let Ok(body_str) = std::str::from_utf8(&body) {
let truncated = if body_str.chars().count() > 500 {
format!("{}...", body_str.chars().take(500).collect::<String>())
} else {
body_str.to_string()
};
tracing::debug!(body = %truncated, "Response body");
}
if let Ok(body_str) = std::str::from_utf8(&body) {
leak_detector
.scan_and_clean(body_str)
.map_err(|e| format!("Potential secret leak in response: {}", e))?;
}
Ok(near::agent::channel_host::HttpResponse {
status,
headers_json,
body,
})
});
let result = result.map_err(|e| self.redact_credentials(&e));
match &result {
Ok(resp) => {
tracing::info!(status = resp.status, "http_request completed successfully");
}
Err(e) => {
tracing::error!(error = %e, "http_request failed");
}
}
result
}
fn secret_exists(&mut self, name: String) -> bool {
self.host_state.secret_exists(&name)
}
fn emit_message(&mut self, msg: near::agent::channel_host::EmittedMessage) {
tracing::info!(
user_id = %msg.user_id,
user_name = ?msg.user_name,
content_len = msg.content.len(),
"WASM emit_message called"
);
let mut emitted = EmittedMessage::new(msg.user_id.clone(), msg.content.clone());
if let Some(name) = msg.user_name {
emitted = emitted.with_user_name(name);
}
if let Some(tid) = msg.thread_id {
emitted = emitted.with_thread_id(tid);
}
emitted = emitted.with_metadata(msg.metadata_json);
match self.host_state.emit_message(emitted) {
Ok(()) => {
tracing::info!("Message emitted to host state successfully");
}
Err(e) => {
tracing::error!(error = %e, "Failed to emit message to host state");
}
}
}
fn pairing_upsert_request(
&mut self,
channel: String,
id: String,
meta_json: String,
) -> Result<near::agent::channel_host::PairingUpsertResult, String> {
let meta = if meta_json.is_empty() {
None
} else {
serde_json::from_str(&meta_json).ok()
};
match self.pairing_store.upsert_request(&channel, &id, meta) {
Ok(r) => Ok(near::agent::channel_host::PairingUpsertResult {
code: r.code,
created: r.created,
}),
Err(e) => Err(e.to_string()),
}
}
fn pairing_is_allowed(
&mut self,
channel: String,
id: String,
username: Option<String>,
) -> Result<bool, String> {
self.pairing_store
.is_sender_allowed(&channel, &id, username.as_deref())
.map_err(|e| e.to_string())
}
fn pairing_read_allow_from(&mut self, channel: String) -> Result<Vec<String>, String> {
self.pairing_store
.read_allow_from(&channel)
.map_err(|e| e.to_string())
}
}
#[allow(dead_code)]
pub struct WasmChannel {
name: String,
runtime: Arc<WasmChannelRuntime>,
prepared: Arc<PreparedChannelModule>,
capabilities: ChannelCapabilities,
config_json: RwLock<String>,
channel_config: RwLock<Option<ChannelConfig>>,
message_tx: Arc<RwLock<Option<mpsc::Sender<IncomingMessage>>>>,
pending_responses: RwLock<HashMap<Uuid, oneshot::Sender<String>>>,
rate_limiter: Arc<RwLock<ChannelEmitRateLimiter>>,
shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
poll_shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
endpoints: RwLock<Vec<RegisteredEndpoint>>,
credentials: Arc<RwLock<HashMap<String, String>>>,
typing_task: RwLock<Option<tokio::task::JoinHandle<()>>>,
pairing_store: Arc<PairingStore>,
}
impl WasmChannel {
pub fn new(
runtime: Arc<WasmChannelRuntime>,
prepared: Arc<PreparedChannelModule>,
capabilities: ChannelCapabilities,
config_json: String,
pairing_store: Arc<PairingStore>,
) -> Self {
let name = prepared.name.clone();
let rate_limiter = ChannelEmitRateLimiter::new(capabilities.emit_rate_limit.clone());
Self {
name,
runtime,
prepared,
capabilities,
config_json: RwLock::new(config_json),
channel_config: RwLock::new(None),
message_tx: Arc::new(RwLock::new(None)),
pending_responses: RwLock::new(HashMap::new()),
rate_limiter: Arc::new(RwLock::new(rate_limiter)),
shutdown_tx: RwLock::new(None),
poll_shutdown_tx: RwLock::new(None),
endpoints: RwLock::new(Vec::new()),
credentials: Arc::new(RwLock::new(HashMap::new())),
typing_task: RwLock::new(None),
pairing_store,
}
}
pub async fn update_config(&self, updates: HashMap<String, serde_json::Value>) {
let mut config_guard = self.config_json.write().await;
let mut config: HashMap<String, serde_json::Value> =
serde_json::from_str(&config_guard).unwrap_or_default();
for (key, value) in updates {
config.insert(key, value);
}
*config_guard = serde_json::to_string(&config).unwrap_or_else(|_| "{}".to_string());
tracing::debug!(
channel = %self.name,
config = %*config_guard,
"Updated channel config"
);
}
pub async fn set_credential(&self, name: &str, value: String) {
self.credentials
.write()
.await
.insert(name.to_string(), value);
}
pub async fn get_credentials(&self) -> HashMap<String, String> {
self.credentials.read().await.clone()
}
pub fn channel_name(&self) -> &str {
&self.name
}
pub fn capabilities(&self) -> &ChannelCapabilities {
&self.capabilities
}
pub async fn endpoints(&self) -> Vec<RegisteredEndpoint> {
self.endpoints.read().await.clone()
}
fn add_host_functions(linker: &mut Linker<ChannelStoreData>) -> Result<(), WasmChannelError> {
wasmtime_wasi::add_to_linker_sync(linker).map_err(|e| {
WasmChannelError::Config(format!("Failed to add WASI functions: {}", e))
})?;
near::agent::channel_host::add_to_linker(linker, |state| state).map_err(|e| {
WasmChannelError::Config(format!("Failed to add host functions: {}", e))
})?;
Ok(())
}
fn create_store(
runtime: &WasmChannelRuntime,
prepared: &PreparedChannelModule,
capabilities: &ChannelCapabilities,
credentials: HashMap<String, String>,
pairing_store: Arc<PairingStore>,
) -> Result<Store<ChannelStoreData>, WasmChannelError> {
let engine = runtime.engine();
let limits = &prepared.limits;
let store_data = ChannelStoreData::new(
limits.memory_bytes,
&prepared.name,
capabilities.clone(),
credentials,
pairing_store,
);
let mut store = Store::new(engine, store_data);
if runtime.config().fuel_config.enabled {
store
.set_fuel(limits.fuel)
.map_err(|e| WasmChannelError::Config(format!("Failed to set fuel: {}", e)))?;
}
store.epoch_deadline_trap();
store.set_epoch_deadline(1);
store.limiter(|data| &mut data.limiter);
Ok(store)
}
fn instantiate_component(
runtime: &WasmChannelRuntime,
prepared: &PreparedChannelModule,
store: &mut Store<ChannelStoreData>,
) -> Result<SandboxedChannel, WasmChannelError> {
let engine = runtime.engine();
let component = Component::new(engine, prepared.component_bytes())
.map_err(|e| WasmChannelError::Compilation(e.to_string()))?;
let mut linker = Linker::new(engine);
Self::add_host_functions(&mut linker)?;
let instance = SandboxedChannel::instantiate(store, &component, &linker)
.map_err(|e| WasmChannelError::Instantiation(e.to_string()))?;
Ok(instance)
}
fn map_wasm_error(e: anyhow::Error, name: &str, fuel_limit: u64) -> WasmChannelError {
let error_str = e.to_string();
if error_str.contains("out of fuel") {
WasmChannelError::FuelExhausted {
name: name.to_string(),
limit: fuel_limit,
}
} else if error_str.contains("unreachable") {
WasmChannelError::Trapped {
name: name.to_string(),
reason: "unreachable code executed".to_string(),
}
} else {
WasmChannelError::Trapped {
name: name.to_string(),
reason: error_str,
}
}
}
fn extract_host_state(
store: &mut Store<ChannelStoreData>,
channel_name: &str,
capabilities: &ChannelCapabilities,
) -> ChannelHostState {
std::mem::replace(
&mut store.data_mut().host_state,
ChannelHostState::new(channel_name, capabilities.clone()),
)
}
async fn call_on_start(&self) -> Result<ChannelConfig, WasmChannelError> {
if self.prepared.component_bytes.is_empty() {
tracing::info!(
channel = %self.name,
"WASM channel on_start called (no WASM module, returning defaults)"
);
return Ok(ChannelConfig {
display_name: self.prepared.description.clone(),
http_endpoints: Vec::new(),
poll: None,
});
}
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let config_json = self.config_json.read().await.clone();
let timeout = self.runtime.config().callback_timeout;
let channel_name = self.name.clone();
let credentials = self.get_credentials().await;
let pairing_store = self.pairing_store.clone();
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials,
pairing_store,
)?;
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let channel_iface = instance.near_agent_channel();
let wasm_result = channel_iface
.call_on_start(&mut store, &config_json)
.map_err(|e| Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel))?;
let config = match wasm_result {
Ok(wit_config) => convert_channel_config(wit_config),
Err(err_msg) => {
return Err(WasmChannelError::CallbackFailed {
name: prepared.name.clone(),
reason: err_msg,
});
}
};
let host_state =
Self::extract_host_state(&mut store, &prepared.name, &capabilities);
Ok((config, host_state))
})
.await
.map_err(|e| WasmChannelError::ExecutionPanicked {
name: channel_name.clone(),
reason: e.to_string(),
})?
})
.await;
match result {
Ok(Ok((config, mut host_state))) => {
for entry in host_state.take_logs() {
match entry.level {
crate::tools::wasm::LogLevel::Error => {
tracing::error!(channel = %self.name, "{}", entry.message);
}
crate::tools::wasm::LogLevel::Warn => {
tracing::warn!(channel = %self.name, "{}", entry.message);
}
_ => {
tracing::debug!(channel = %self.name, "{}", entry.message);
}
}
}
tracing::info!(
channel = %self.name,
display_name = %config.display_name,
endpoints = config.http_endpoints.len(),
"WASM channel on_start completed"
);
Ok(config)
}
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: self.name.clone(),
callback: "on_start".to_string(),
}),
}
}
pub async fn call_on_http_request(
&self,
method: &str,
path: &str,
headers: &HashMap<String, String>,
query: &HashMap<String, String>,
body: &[u8],
secret_validated: bool,
) -> Result<HttpResponse, WasmChannelError> {
tracing::info!(
channel = %self.name,
method = method,
path = path,
body_len = body.len(),
secret_validated = secret_validated,
"call_on_http_request invoked (webhook received)"
);
if let Ok(body_str) = std::str::from_utf8(body) {
let truncated = if body_str.chars().count() > 1000 {
format!("{}...", body_str.chars().take(1000).collect::<String>())
} else {
body_str.to_string()
};
tracing::debug!(body = %truncated, "Webhook request body");
}
let creds = self.get_credentials().await;
tracing::info!(
credential_count = creds.len(),
credential_names = ?creds.keys().collect::<Vec<_>>(),
"Credentials available for on_http_request"
);
if self.prepared.component_bytes.is_empty() {
tracing::debug!(
channel = %self.name,
method = method,
path = path,
"WASM channel on_http_request called (no WASM module)"
);
return Ok(HttpResponse::ok());
}
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let timeout = self.runtime.config().callback_timeout;
let credentials = self.get_credentials().await;
let pairing_store = self.pairing_store.clone();
let method = method.to_string();
let path = path.to_string();
let headers_json = serde_json::to_string(&headers).unwrap_or_default();
let query_json = serde_json::to_string(&query).unwrap_or_default();
let body = body.to_vec();
let channel_name = self.name.clone();
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials,
pairing_store,
)?;
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let wit_request = wit_channel::IncomingHttpRequest {
method,
path,
headers_json,
query_json,
body,
secret_validated,
};
let channel_iface = instance.near_agent_channel();
let wit_response = channel_iface
.call_on_http_request(&mut store, &wit_request)
.map_err(|e| Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel))?;
let response = convert_http_response(wit_response);
let host_state =
Self::extract_host_state(&mut store, &prepared.name, &capabilities);
Ok((response, host_state))
})
.await
.map_err(|e| WasmChannelError::ExecutionPanicked {
name: channel_name.clone(),
reason: e.to_string(),
})?
})
.await;
let channel_name = self.name.clone();
match result {
Ok(Ok((response, mut host_state))) => {
let emitted = host_state.take_emitted_messages();
self.process_emitted_messages(emitted).await?;
tracing::debug!(
channel = %channel_name,
status = response.status,
"WASM channel on_http_request completed"
);
Ok(response)
}
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: channel_name,
callback: "on_http_request".to_string(),
}),
}
}
pub async fn call_on_poll(&self) -> Result<(), WasmChannelError> {
if self.prepared.component_bytes.is_empty() {
tracing::debug!(
channel = %self.name,
"WASM channel on_poll called (no WASM module)"
);
return Ok(());
}
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let timeout = self.runtime.config().callback_timeout;
let channel_name = self.name.clone();
let credentials = self.get_credentials().await;
let pairing_store = self.pairing_store.clone();
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials,
pairing_store,
)?;
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let channel_iface = instance.near_agent_channel();
channel_iface
.call_on_poll(&mut store)
.map_err(|e| Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel))?;
let host_state =
Self::extract_host_state(&mut store, &prepared.name, &capabilities);
Ok(((), host_state))
})
.await
.map_err(|e| WasmChannelError::ExecutionPanicked {
name: channel_name.clone(),
reason: e.to_string(),
})?
})
.await;
let channel_name = self.name.clone();
match result {
Ok(Ok(((), mut host_state))) => {
let emitted = host_state.take_emitted_messages();
self.process_emitted_messages(emitted).await?;
tracing::debug!(
channel = %channel_name,
"WASM channel on_poll completed"
);
Ok(())
}
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: channel_name,
callback: "on_poll".to_string(),
}),
}
}
pub async fn call_on_respond(
&self,
message_id: Uuid,
content: &str,
thread_id: Option<&str>,
metadata_json: &str,
) -> Result<(), WasmChannelError> {
tracing::info!(
channel = %self.name,
message_id = %message_id,
content_len = content.len(),
thread_id = ?thread_id,
"call_on_respond invoked"
);
let creds = self.get_credentials().await;
tracing::info!(
credential_count = creds.len(),
credential_names = ?creds.keys().collect::<Vec<_>>(),
"Credentials available for on_respond"
);
if self.prepared.component_bytes.is_empty() {
tracing::debug!(
channel = %self.name,
message_id = %message_id,
"WASM channel on_respond called (no WASM module)"
);
return Ok(());
}
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let timeout = self.runtime.config().callback_timeout;
let channel_name = self.name.clone();
let credentials = self.get_credentials().await;
let pairing_store = self.pairing_store.clone();
let message_id_str = message_id.to_string();
let content = content.to_string();
let thread_id = thread_id.map(|s| s.to_string());
let metadata_json = metadata_json.to_string();
tracing::info!(channel = %channel_name, "Starting on_respond WASM execution");
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
tracing::info!("Creating WASM store for on_respond");
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials,
pairing_store,
)?;
tracing::info!("Instantiating WASM component for on_respond");
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let wit_response = wit_channel::AgentResponse {
message_id: message_id_str,
content: content.clone(),
thread_id,
metadata_json,
};
let content_preview: String = content.chars().take(50).collect();
tracing::info!(
content_preview = %content_preview,
"Calling WASM on_respond"
);
let channel_iface = instance.near_agent_channel();
let wasm_result = channel_iface
.call_on_respond(&mut store, &wit_response)
.map_err(|e| {
tracing::error!(error = %e, "WASM on_respond call failed");
Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel)
})?;
tracing::info!(wasm_result = ?wasm_result, "WASM on_respond returned");
if let Err(ref err_msg) = wasm_result {
tracing::error!(error = %err_msg, "WASM on_respond returned error");
return Err(WasmChannelError::CallbackFailed {
name: prepared.name.clone(),
reason: err_msg.clone(),
});
}
let host_state =
Self::extract_host_state(&mut store, &prepared.name, &capabilities);
tracing::info!("on_respond WASM execution completed successfully");
Ok(((), host_state))
})
.await
.map_err(|e| {
tracing::error!(error = %e, "spawn_blocking panicked");
WasmChannelError::ExecutionPanicked {
name: channel_name.clone(),
reason: e.to_string(),
}
})?
})
.await;
let channel_name = self.name.clone();
match result {
Ok(Ok(((), _host_state))) => {
tracing::debug!(
channel = %channel_name,
message_id = %message_id,
"WASM channel on_respond completed"
);
Ok(())
}
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: channel_name,
callback: "on_respond".to_string(),
}),
}
}
pub async fn call_on_status(
&self,
status: &StatusUpdate,
metadata: &serde_json::Value,
) -> Result<(), WasmChannelError> {
if self.prepared.component_bytes.is_empty() {
return Ok(());
}
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let timeout = self.runtime.config().callback_timeout;
let channel_name = self.name.clone();
let credentials = self.get_credentials().await;
let pairing_store = self.pairing_store.clone();
let wit_update = status_to_wit(status, metadata);
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials,
pairing_store,
)?;
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let channel_iface = instance.near_agent_channel();
channel_iface
.call_on_status(&mut store, &wit_update)
.map_err(|e| Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel))?;
Ok(())
})
.await
.map_err(|e| WasmChannelError::ExecutionPanicked {
name: channel_name.clone(),
reason: e.to_string(),
})?
})
.await;
match result {
Ok(Ok(())) => {
tracing::debug!(
channel = %self.name,
"WASM channel on_status completed"
);
Ok(())
}
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: self.name.clone(),
callback: "on_status".to_string(),
}),
}
}
#[allow(clippy::too_many_arguments)]
async fn execute_status(
channel_name: &str,
runtime: &Arc<WasmChannelRuntime>,
prepared: &Arc<PreparedChannelModule>,
capabilities: &ChannelCapabilities,
credentials: &RwLock<HashMap<String, String>>,
pairing_store: Arc<PairingStore>,
timeout: Duration,
wit_update: wit_channel::StatusUpdate,
) -> Result<(), WasmChannelError> {
if prepared.component_bytes.is_empty() {
return Ok(());
}
let runtime = Arc::clone(runtime);
let prepared = Arc::clone(prepared);
let capabilities = capabilities.clone();
let credentials_snapshot = credentials.read().await.clone();
let channel_name_owned = channel_name.to_string();
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials_snapshot,
pairing_store,
)?;
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let channel_iface = instance.near_agent_channel();
channel_iface
.call_on_status(&mut store, &wit_update)
.map_err(|e| Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel))?;
Ok(())
})
.await
.map_err(|e| WasmChannelError::ExecutionPanicked {
name: channel_name_owned.clone(),
reason: e.to_string(),
})?
})
.await;
match result {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: channel_name.to_string(),
callback: "on_status".to_string(),
}),
}
}
async fn cancel_typing_task(&self) {
if let Some(handle) = self.typing_task.write().await.take() {
handle.abort();
}
}
async fn handle_status_update(
&self,
status: StatusUpdate,
metadata: &serde_json::Value,
) -> Result<(), ChannelError> {
match &status {
StatusUpdate::Thinking(_) => {
self.cancel_typing_task().await;
if let Err(e) = self.call_on_status(&status, metadata).await {
tracing::debug!(
channel = %self.name,
error = %e,
"on_status(Thinking) failed (best-effort)"
);
}
let channel_name = self.name.clone();
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let credentials = self.credentials.clone();
let pairing_store = self.pairing_store.clone();
let callback_timeout = self.runtime.config().callback_timeout;
let wit_update = status_to_wit(&status, metadata);
let handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(4));
interval.tick().await;
loop {
interval.tick().await;
let wit_update_clone = clone_wit_status_update(&wit_update);
if let Err(e) = Self::execute_status(
&channel_name,
&runtime,
&prepared,
&capabilities,
&credentials,
pairing_store.clone(),
callback_timeout,
wit_update_clone,
)
.await
{
tracing::debug!(
channel = %channel_name,
error = %e,
"Typing repeat on_status failed (best-effort)"
);
}
}
});
*self.typing_task.write().await = Some(handle);
}
StatusUpdate::StreamChunk(_) => {
}
_ => {
self.cancel_typing_task().await;
if let Err(e) = self.call_on_status(&status, metadata).await {
tracing::debug!(
channel = %self.name,
error = %e,
"on_status failed (best-effort)"
);
}
}
}
Ok(())
}
async fn process_emitted_messages(
&self,
messages: Vec<EmittedMessage>,
) -> Result<(), WasmChannelError> {
tracing::info!(
channel = %self.name,
message_count = messages.len(),
"Processing emitted messages from WASM callback"
);
if messages.is_empty() {
tracing::debug!(channel = %self.name, "No messages emitted");
return Ok(());
}
let tx_guard = self.message_tx.read().await;
let Some(tx) = tx_guard.as_ref() else {
tracing::error!(
channel = %self.name,
count = messages.len(),
"Messages emitted but no sender available - channel may not be started!"
);
return Ok(());
};
let mut rate_limiter = self.rate_limiter.write().await;
for emitted in messages {
if !rate_limiter.check_and_record() {
tracing::warn!(
channel = %self.name,
"Message emission rate limited"
);
return Err(WasmChannelError::EmitRateLimited {
name: self.name.clone(),
});
}
let mut msg = IncomingMessage::new(&self.name, &emitted.user_id, &emitted.content);
if let Some(name) = emitted.user_name {
msg = msg.with_user_name(name);
}
if let Some(thread_id) = emitted.thread_id {
msg = msg.with_thread(thread_id);
}
if let Ok(metadata) = serde_json::from_str(&emitted.metadata_json) {
msg = msg.with_metadata(metadata);
}
tracing::info!(
channel = %self.name,
user_id = %emitted.user_id,
content_len = emitted.content.len(),
"Sending emitted message to agent"
);
if tx.send(msg).await.is_err() {
tracing::error!(
channel = %self.name,
"Failed to send emitted message, channel closed"
);
break;
}
tracing::info!(
channel = %self.name,
"Message successfully sent to agent queue"
);
}
Ok(())
}
fn start_polling(&self, interval: Duration, shutdown_rx: oneshot::Receiver<()>) {
let channel_name = self.name.clone();
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let message_tx = self.message_tx.clone();
let rate_limiter = self.rate_limiter.clone();
let credentials = self.credentials.clone();
let pairing_store = self.pairing_store.clone();
let callback_timeout = self.runtime.config().callback_timeout;
tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
let mut shutdown = std::pin::pin!(shutdown_rx);
loop {
tokio::select! {
_ = interval_timer.tick() => {
tracing::debug!(
channel = %channel_name,
"Polling tick - calling on_poll"
);
let result = Self::execute_poll(
&channel_name,
&runtime,
&prepared,
&capabilities,
&credentials,
pairing_store.clone(),
callback_timeout,
).await;
match result {
Ok(emitted_messages) => {
if !emitted_messages.is_empty()
&& let Err(e) = Self::dispatch_emitted_messages(
&channel_name,
emitted_messages,
&message_tx,
&rate_limiter,
).await {
tracing::warn!(
channel = %channel_name,
error = %e,
"Failed to dispatch emitted messages from poll"
);
}
}
Err(e) => {
tracing::warn!(
channel = %channel_name,
error = %e,
"Polling callback failed"
);
}
}
}
_ = &mut shutdown => {
tracing::info!(
channel = %channel_name,
"Polling stopped"
);
break;
}
}
}
});
}
async fn execute_poll(
channel_name: &str,
runtime: &Arc<WasmChannelRuntime>,
prepared: &Arc<PreparedChannelModule>,
capabilities: &ChannelCapabilities,
credentials: &RwLock<HashMap<String, String>>,
pairing_store: Arc<PairingStore>,
timeout: Duration,
) -> Result<Vec<EmittedMessage>, WasmChannelError> {
if prepared.component_bytes.is_empty() {
tracing::debug!(
channel = %channel_name,
"WASM channel on_poll called (no WASM module)"
);
return Ok(Vec::new());
}
let runtime = Arc::clone(runtime);
let prepared = Arc::clone(prepared);
let capabilities = capabilities.clone();
let credentials_snapshot = credentials.read().await.clone();
let channel_name_owned = channel_name.to_string();
let result = tokio::time::timeout(timeout, async move {
tokio::task::spawn_blocking(move || {
let mut store = Self::create_store(
&runtime,
&prepared,
&capabilities,
credentials_snapshot,
pairing_store,
)?;
let instance = Self::instantiate_component(&runtime, &prepared, &mut store)?;
let channel_iface = instance.near_agent_channel();
channel_iface
.call_on_poll(&mut store)
.map_err(|e| Self::map_wasm_error(e, &prepared.name, prepared.limits.fuel))?;
let host_state =
Self::extract_host_state(&mut store, &prepared.name, &capabilities);
Ok(host_state)
})
.await
.map_err(|e| WasmChannelError::ExecutionPanicked {
name: channel_name_owned.clone(),
reason: e.to_string(),
})?
})
.await;
match result {
Ok(Ok(mut host_state)) => {
let emitted = host_state.take_emitted_messages();
tracing::debug!(
channel = %channel_name,
emitted_count = emitted.len(),
"WASM channel on_poll completed"
);
Ok(emitted)
}
Ok(Err(e)) => Err(e),
Err(_) => Err(WasmChannelError::Timeout {
name: channel_name.to_string(),
callback: "on_poll".to_string(),
}),
}
}
async fn dispatch_emitted_messages(
channel_name: &str,
messages: Vec<EmittedMessage>,
message_tx: &RwLock<Option<mpsc::Sender<IncomingMessage>>>,
rate_limiter: &RwLock<ChannelEmitRateLimiter>,
) -> Result<(), WasmChannelError> {
tracing::info!(
channel = %channel_name,
message_count = messages.len(),
"Processing emitted messages from polling callback"
);
let tx_guard = message_tx.read().await;
let Some(tx) = tx_guard.as_ref() else {
tracing::error!(
channel = %channel_name,
count = messages.len(),
"Messages emitted but no sender available - channel may not be started!"
);
return Ok(());
};
let mut limiter = rate_limiter.write().await;
for emitted in messages {
if !limiter.check_and_record() {
tracing::warn!(
channel = %channel_name,
"Message emission rate limited"
);
return Err(WasmChannelError::EmitRateLimited {
name: channel_name.to_string(),
});
}
let mut msg = IncomingMessage::new(channel_name, &emitted.user_id, &emitted.content);
if let Some(name) = emitted.user_name {
msg = msg.with_user_name(name);
}
if let Some(thread_id) = emitted.thread_id {
msg = msg.with_thread(thread_id);
}
if let Ok(metadata) = serde_json::from_str(&emitted.metadata_json) {
msg = msg.with_metadata(metadata);
}
tracing::info!(
channel = %channel_name,
user_id = %emitted.user_id,
content_len = emitted.content.len(),
"Sending polled message to agent"
);
if tx.send(msg).await.is_err() {
tracing::error!(
channel = %channel_name,
"Failed to send polled message, channel closed"
);
break;
}
tracing::info!(
channel = %channel_name,
"Message successfully sent to agent queue"
);
}
Ok(())
}
}
#[async_trait]
impl Channel for WasmChannel {
fn name(&self) -> &str {
&self.name
}
async fn start(&self) -> Result<MessageStream, ChannelError> {
let (tx, rx) = mpsc::channel(256);
*self.message_tx.write().await = Some(tx);
let (shutdown_tx, _shutdown_rx) = oneshot::channel();
*self.shutdown_tx.write().await = Some(shutdown_tx);
let config = self
.call_on_start()
.await
.map_err(|e| ChannelError::StartupFailed {
name: self.name.clone(),
reason: e.to_string(),
})?;
*self.channel_config.write().await = Some(config.clone());
let mut endpoints = Vec::new();
for endpoint in &config.http_endpoints {
if !self.capabilities.is_path_allowed(&endpoint.path) {
tracing::warn!(
channel = %self.name,
path = %endpoint.path,
"HTTP endpoint path not allowed by capabilities"
);
continue;
}
endpoints.push(RegisteredEndpoint {
channel_name: self.name.clone(),
path: endpoint.path.clone(),
methods: endpoint.methods.clone(),
require_secret: endpoint.require_secret,
});
}
*self.endpoints.write().await = endpoints;
if let Some(poll_config) = &config.poll
&& poll_config.enabled
{
let interval = self
.capabilities
.validate_poll_interval(poll_config.interval_ms)
.map_err(|e| ChannelError::StartupFailed {
name: self.name.clone(),
reason: e,
})?;
let (poll_shutdown_tx, poll_shutdown_rx) = oneshot::channel();
*self.poll_shutdown_tx.write().await = Some(poll_shutdown_tx);
self.start_polling(Duration::from_millis(interval as u64), poll_shutdown_rx);
}
tracing::info!(
channel = %self.name,
display_name = %config.display_name,
endpoints = config.http_endpoints.len(),
"WASM channel started"
);
Ok(Box::pin(ReceiverStream::new(rx)))
}
async fn respond(
&self,
msg: &IncomingMessage,
response: OutgoingResponse,
) -> Result<(), ChannelError> {
self.cancel_typing_task().await;
if let Some(tx) = self.pending_responses.write().await.remove(&msg.id) {
let _ = tx.send(response.content.clone());
}
let metadata_json = serde_json::to_string(&msg.metadata).unwrap_or_default();
self.call_on_respond(
msg.id,
&response.content,
response.thread_id.as_deref(),
&metadata_json,
)
.await
.map_err(|e| ChannelError::SendFailed {
name: self.name.clone(),
reason: e.to_string(),
})?;
Ok(())
}
async fn send_status(
&self,
status: StatusUpdate,
metadata: &serde_json::Value,
) -> Result<(), ChannelError> {
self.handle_status_update(status, metadata).await
}
async fn health_check(&self) -> Result<(), ChannelError> {
if self.message_tx.read().await.is_some() {
Ok(())
} else {
Err(ChannelError::HealthCheckFailed {
name: self.name.clone(),
})
}
}
async fn shutdown(&self) -> Result<(), ChannelError> {
self.cancel_typing_task().await;
if let Some(tx) = self.shutdown_tx.write().await.take() {
let _ = tx.send(());
}
let _ = self.poll_shutdown_tx.write().await.take();
*self.message_tx.write().await = None;
tracing::info!(
channel = %self.name,
"WASM channel shut down"
);
Ok(())
}
}
impl std::fmt::Debug for WasmChannel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WasmChannel")
.field("name", &self.name)
.field("prepared", &self.prepared.name)
.finish()
}
}
pub struct SharedWasmChannel {
inner: Arc<WasmChannel>,
}
impl SharedWasmChannel {
pub fn new(channel: Arc<WasmChannel>) -> Self {
Self { inner: channel }
}
pub fn inner(&self) -> &Arc<WasmChannel> {
&self.inner
}
}
impl std::fmt::Debug for SharedWasmChannel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SharedWasmChannel")
.field("inner", &self.inner)
.finish()
}
}
#[async_trait]
impl Channel for SharedWasmChannel {
fn name(&self) -> &str {
self.inner.name()
}
async fn start(&self) -> Result<MessageStream, ChannelError> {
self.inner.start().await
}
async fn respond(
&self,
msg: &IncomingMessage,
response: OutgoingResponse,
) -> Result<(), ChannelError> {
self.inner.respond(msg, response).await
}
async fn send_status(
&self,
status: StatusUpdate,
metadata: &serde_json::Value,
) -> Result<(), ChannelError> {
self.inner.send_status(status, metadata).await
}
async fn health_check(&self) -> Result<(), ChannelError> {
self.inner.health_check().await
}
async fn shutdown(&self) -> Result<(), ChannelError> {
self.inner.shutdown().await
}
}
use exports::near::agent::channel as wit_channel;
fn convert_channel_config(wit: wit_channel::ChannelConfig) -> ChannelConfig {
ChannelConfig {
display_name: wit.display_name,
http_endpoints: wit
.http_endpoints
.into_iter()
.map(
|ep| crate::channels::wasm::schema::HttpEndpointConfigSchema {
path: ep.path,
methods: ep.methods,
require_secret: ep.require_secret,
},
)
.collect(),
poll: wit
.poll
.map(|p| crate::channels::wasm::schema::PollConfigSchema {
interval_ms: p.interval_ms,
enabled: p.enabled,
}),
}
}
fn convert_http_response(wit: wit_channel::OutgoingHttpResponse) -> HttpResponse {
let headers = serde_json::from_str(&wit.headers_json).unwrap_or_default();
HttpResponse {
status: wit.status,
headers,
body: wit.body,
}
}
fn status_to_wit(status: &StatusUpdate, metadata: &serde_json::Value) -> wit_channel::StatusUpdate {
let metadata_json = serde_json::to_string(metadata).unwrap_or_default();
match status {
StatusUpdate::Thinking(msg) => wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: msg.clone(),
metadata_json,
},
StatusUpdate::ToolStarted { name } => wit_channel::StatusUpdate {
status: wit_channel::StatusType::ToolStarted,
message: name.clone(),
metadata_json,
},
StatusUpdate::ToolCompleted { name, success } => wit_channel::StatusUpdate {
status: wit_channel::StatusType::ToolCompleted,
message: format!("{}: {}", name, if *success { "ok" } else { "failed" }),
metadata_json,
},
StatusUpdate::ToolResult { name, preview } => wit_channel::StatusUpdate {
status: wit_channel::StatusType::ToolCompleted,
message: format!("{}: {}", name, preview),
metadata_json,
},
StatusUpdate::StreamChunk(chunk) => wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: chunk.clone(),
metadata_json,
},
StatusUpdate::Status(msg) => {
let status_type = match msg.as_str() {
"Done" => wit_channel::StatusType::Done,
"Interrupted" => wit_channel::StatusType::Interrupted,
_ => wit_channel::StatusType::Thinking,
};
wit_channel::StatusUpdate {
status: status_type,
message: msg.clone(),
metadata_json,
}
}
StatusUpdate::ApprovalNeeded {
tool_name,
description,
..
} => wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: format!("Approval needed: {} - {}", tool_name, description),
metadata_json,
},
StatusUpdate::JobStarted { job_id, title, .. } => wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: format!("Job started: {} ({})", title, job_id),
metadata_json,
},
StatusUpdate::AuthRequired { extension_name, .. } => wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: format!("Auth required: {}", extension_name),
metadata_json,
},
StatusUpdate::AuthCompleted {
extension_name,
success,
..
} => wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: format!(
"Auth {}: {}",
if *success { "completed" } else { "failed" },
extension_name
),
metadata_json,
},
}
}
fn clone_wit_status_update(update: &wit_channel::StatusUpdate) -> wit_channel::StatusUpdate {
wit_channel::StatusUpdate {
status: match update.status {
wit_channel::StatusType::Thinking => wit_channel::StatusType::Thinking,
wit_channel::StatusType::Done => wit_channel::StatusType::Done,
wit_channel::StatusType::Interrupted => wit_channel::StatusType::Interrupted,
wit_channel::StatusType::ToolStarted => wit_channel::StatusType::ToolStarted,
wit_channel::StatusType::ToolCompleted => wit_channel::StatusType::ToolCompleted,
},
message: update.message.clone(),
metadata_json: update.metadata_json.clone(),
}
}
#[derive(Debug, Clone)]
pub struct HttpResponse {
pub status: u16,
pub headers: HashMap<String, String>,
pub body: Vec<u8>,
}
impl HttpResponse {
pub fn ok() -> Self {
Self {
status: 200,
headers: HashMap::new(),
body: Vec::new(),
}
}
pub fn json(value: serde_json::Value) -> Self {
let body = serde_json::to_vec(&value).unwrap_or_default();
let mut headers = HashMap::new();
headers.insert("Content-Type".to_string(), "application/json".to_string());
Self {
status: 200,
headers,
body,
}
}
pub fn error(status: u16, message: &str) -> Self {
Self {
status,
headers: HashMap::new(),
body: message.as_bytes().to_vec(),
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::channels::Channel;
use crate::channels::wasm::capabilities::ChannelCapabilities;
use crate::channels::wasm::runtime::{
PreparedChannelModule, WasmChannelRuntime, WasmChannelRuntimeConfig,
};
use crate::channels::wasm::wrapper::{HttpResponse, WasmChannel};
use crate::pairing::PairingStore;
use crate::tools::wasm::ResourceLimits;
fn create_test_channel() -> WasmChannel {
let config = WasmChannelRuntimeConfig::for_testing();
let runtime = Arc::new(WasmChannelRuntime::new(config).unwrap());
let prepared = Arc::new(PreparedChannelModule {
name: "test".to_string(),
description: "Test channel".to_string(),
component_bytes: Vec::new(),
limits: ResourceLimits::default(),
});
let capabilities = ChannelCapabilities::for_channel("test").with_path("/webhook/test");
WasmChannel::new(
runtime,
prepared,
capabilities,
"{}".to_string(),
Arc::new(PairingStore::new()),
)
}
#[test]
fn test_channel_name() {
let channel = create_test_channel();
assert_eq!(channel.name(), "test");
}
#[test]
fn test_http_response_ok() {
let response = HttpResponse::ok();
assert_eq!(response.status, 200);
assert!(response.body.is_empty());
}
#[test]
fn test_http_response_json() {
let response = HttpResponse::json(serde_json::json!({"key": "value"}));
assert_eq!(response.status, 200);
assert_eq!(
response.headers.get("Content-Type"),
Some(&"application/json".to_string())
);
}
#[test]
fn test_http_response_error() {
let response = HttpResponse::error(400, "Bad request");
assert_eq!(response.status, 400);
assert_eq!(response.body, b"Bad request");
}
#[tokio::test]
async fn test_channel_start_and_shutdown() {
let channel = create_test_channel();
let stream = channel.start().await;
assert!(stream.is_ok());
assert!(channel.health_check().await.is_ok());
assert!(channel.shutdown().await.is_ok());
assert!(channel.health_check().await.is_err());
}
#[tokio::test]
async fn test_execute_poll_no_wasm_returns_empty() {
let config = WasmChannelRuntimeConfig::for_testing();
let runtime = Arc::new(WasmChannelRuntime::new(config).unwrap());
let prepared = Arc::new(PreparedChannelModule {
name: "poll-test".to_string(),
description: "Test channel".to_string(),
component_bytes: Vec::new(), limits: ResourceLimits::default(),
});
let capabilities = ChannelCapabilities::for_channel("poll-test").with_polling(1000);
let credentials = Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new()));
let timeout = std::time::Duration::from_secs(5);
let result = WasmChannel::execute_poll(
"poll-test",
&runtime,
&prepared,
&capabilities,
&credentials,
Arc::new(PairingStore::new()),
timeout,
)
.await;
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[tokio::test]
async fn test_dispatch_emitted_messages_sends_to_channel() {
use crate::channels::wasm::host::EmittedMessage;
let (tx, mut rx) = tokio::sync::mpsc::channel(10);
let message_tx = Arc::new(tokio::sync::RwLock::new(Some(tx)));
let rate_limiter = Arc::new(tokio::sync::RwLock::new(
crate::channels::wasm::host::ChannelEmitRateLimiter::new(
crate::channels::wasm::capabilities::EmitRateLimitConfig::default(),
),
));
let messages = vec![
EmittedMessage::new("user1", "Hello from polling!"),
EmittedMessage::new("user2", "Another message"),
];
let result = WasmChannel::dispatch_emitted_messages(
"test-channel",
messages,
&message_tx,
&rate_limiter,
)
.await;
assert!(result.is_ok());
let msg1 = rx.try_recv().expect("Should receive first message");
assert_eq!(msg1.user_id, "user1");
assert_eq!(msg1.content, "Hello from polling!");
let msg2 = rx.try_recv().expect("Should receive second message");
assert_eq!(msg2.user_id, "user2");
assert_eq!(msg2.content, "Another message");
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn test_dispatch_emitted_messages_no_sender_returns_ok() {
use crate::channels::wasm::host::EmittedMessage;
let message_tx = Arc::new(tokio::sync::RwLock::new(None));
let rate_limiter = Arc::new(tokio::sync::RwLock::new(
crate::channels::wasm::host::ChannelEmitRateLimiter::new(
crate::channels::wasm::capabilities::EmitRateLimitConfig::default(),
),
));
let messages = vec![EmittedMessage::new("user1", "Hello!")];
let result = WasmChannel::dispatch_emitted_messages(
"test-channel",
messages,
&message_tx,
&rate_limiter,
)
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_channel_with_polling_stores_shutdown_sender() {
let config = WasmChannelRuntimeConfig::for_testing();
let runtime = Arc::new(WasmChannelRuntime::new(config).unwrap());
let prepared = Arc::new(PreparedChannelModule {
name: "poll-channel".to_string(),
description: "Polling test channel".to_string(),
component_bytes: Vec::new(),
limits: ResourceLimits::default(),
});
let capabilities = ChannelCapabilities::for_channel("poll-channel")
.with_path("/webhook/poll")
.with_polling(1000);
let channel = WasmChannel::new(
runtime,
prepared,
capabilities,
"{}".to_string(),
Arc::new(PairingStore::new()),
);
let _stream = channel.start().await.expect("Channel should start");
assert!(channel.health_check().await.is_ok());
channel.shutdown().await.expect("Shutdown should succeed");
assert!(channel.health_check().await.is_err());
}
#[tokio::test]
async fn test_call_on_poll_no_wasm_succeeds() {
let channel = create_test_channel();
let _stream = channel.start().await.expect("Channel should start");
let result = channel.call_on_poll().await;
assert!(result.is_ok());
channel.shutdown().await.expect("Shutdown should succeed");
}
#[tokio::test]
async fn test_typing_task_starts_on_thinking() {
let channel = create_test_channel();
let _stream = channel.start().await.expect("Channel should start");
let metadata = serde_json::json!({"chat_id": 123});
let result = channel
.send_status(
crate::channels::StatusUpdate::Thinking("Processing...".into()),
&metadata,
)
.await;
assert!(result.is_ok());
assert!(channel.typing_task.read().await.is_some());
channel.shutdown().await.expect("Shutdown should succeed");
assert!(channel.typing_task.read().await.is_none());
}
#[tokio::test]
async fn test_typing_task_cancelled_on_done() {
let channel = create_test_channel();
let _stream = channel.start().await.expect("Channel should start");
let metadata = serde_json::json!({"chat_id": 123});
let _ = channel
.send_status(
crate::channels::StatusUpdate::Thinking("Processing...".into()),
&metadata,
)
.await;
assert!(channel.typing_task.read().await.is_some());
let _ = channel
.send_status(
crate::channels::StatusUpdate::Status("Done".into()),
&metadata,
)
.await;
assert!(channel.typing_task.read().await.is_none());
channel.shutdown().await.expect("Shutdown should succeed");
}
#[tokio::test]
async fn test_typing_task_replaced_on_new_thinking() {
let channel = create_test_channel();
let _stream = channel.start().await.expect("Channel should start");
let metadata = serde_json::json!({"chat_id": 123});
let _ = channel
.send_status(
crate::channels::StatusUpdate::Thinking("First...".into()),
&metadata,
)
.await;
let first_handle = {
let guard = channel.typing_task.read().await;
guard.as_ref().map(|h| h.id())
};
assert!(first_handle.is_some());
let _ = channel
.send_status(
crate::channels::StatusUpdate::Thinking("Second...".into()),
&metadata,
)
.await;
let second_handle = {
let guard = channel.typing_task.read().await;
guard.as_ref().map(|h| h.id())
};
assert!(second_handle.is_some());
assert_ne!(first_handle, second_handle);
channel.shutdown().await.expect("Shutdown should succeed");
}
#[tokio::test]
async fn test_respond_cancels_typing_task() {
use crate::channels::IncomingMessage;
let channel = create_test_channel();
let _stream = channel.start().await.expect("Channel should start");
let metadata = serde_json::json!({"chat_id": 123});
let _ = channel
.send_status(
crate::channels::StatusUpdate::Thinking("Processing...".into()),
&metadata,
)
.await;
assert!(channel.typing_task.read().await.is_some());
let msg = IncomingMessage::new("test", "user1", "hello").with_metadata(metadata);
let _ = channel
.respond(&msg, crate::channels::OutgoingResponse::text("response"))
.await;
assert!(channel.typing_task.read().await.is_none());
channel.shutdown().await.expect("Shutdown should succeed");
}
#[tokio::test]
async fn test_stream_chunk_is_noop() {
let channel = create_test_channel();
let _stream = channel.start().await.expect("Channel should start");
let metadata = serde_json::json!({"chat_id": 123});
let result = channel
.send_status(
crate::channels::StatusUpdate::StreamChunk("chunk".into()),
&metadata,
)
.await;
assert!(result.is_ok());
assert!(channel.typing_task.read().await.is_none());
channel.shutdown().await.expect("Shutdown should succeed");
}
#[test]
fn test_status_to_wit_thinking() {
use super::status_to_wit;
let metadata = serde_json::json!({"chat_id": 42});
let wit = status_to_wit(
&crate::channels::StatusUpdate::Thinking("Processing...".into()),
&metadata,
);
assert!(matches!(
wit.status,
super::wit_channel::StatusType::Thinking
));
assert_eq!(wit.message, "Processing...");
assert!(wit.metadata_json.contains("42"));
}
#[test]
fn test_status_to_wit_done() {
use super::status_to_wit;
let metadata = serde_json::json!(null);
let wit = status_to_wit(
&crate::channels::StatusUpdate::Status("Done".into()),
&metadata,
);
assert!(matches!(wit.status, super::wit_channel::StatusType::Done));
}
#[test]
fn test_status_to_wit_interrupted() {
use super::status_to_wit;
let metadata = serde_json::json!(null);
let wit = status_to_wit(
&crate::channels::StatusUpdate::Status("Interrupted".into()),
&metadata,
);
assert!(matches!(
wit.status,
super::wit_channel::StatusType::Interrupted
));
}
#[test]
fn test_clone_wit_status_update() {
use super::{clone_wit_status_update, wit_channel};
let original = wit_channel::StatusUpdate {
status: wit_channel::StatusType::Thinking,
message: "hello".to_string(),
metadata_json: "{\"a\":1}".to_string(),
};
let cloned = clone_wit_status_update(&original);
assert!(matches!(cloned.status, wit_channel::StatusType::Thinking));
assert_eq!(cloned.message, "hello");
assert_eq!(cloned.metadata_json, "{\"a\":1}");
}
#[test]
fn test_redact_credentials_replaces_values() {
use super::ChannelStoreData;
let mut creds = std::collections::HashMap::new();
creds.insert(
"TELEGRAM_BOT_TOKEN".to_string(),
"8218490433:AAEZeUxwqZ5OO3mOCXv7fKvpdhDgsmBBNis".to_string(),
);
creds.insert("OTHER_SECRET".to_string(), "s3cret".to_string());
let store = ChannelStoreData::new(
1024 * 1024,
"test",
ChannelCapabilities::default(),
creds,
Arc::new(PairingStore::new()),
);
let error = "HTTP request failed: error sending request for url \
(https://api.telegram.org/bot8218490433:AAEZeUxwqZ5OO3mOCXv7fKvpdhDgsmBBNis/getUpdates)";
let redacted = store.redact_credentials(error);
assert!(
!redacted.contains("8218490433:AAEZeUxwqZ5OO3mOCXv7fKvpdhDgsmBBNis"),
"credential value should be redacted"
);
assert!(
redacted.contains("[REDACTED:TELEGRAM_BOT_TOKEN]"),
"redacted text should contain placeholder name"
);
assert!(
!redacted.contains("s3cret"),
"other credentials should also be redacted"
);
}
#[test]
fn test_redact_credentials_no_op_without_credentials() {
use super::ChannelStoreData;
let store = ChannelStoreData::new(
1024 * 1024,
"test",
ChannelCapabilities::default(),
std::collections::HashMap::new(),
Arc::new(PairingStore::new()),
);
let input = "some error message";
assert_eq!(store.redact_credentials(input), input);
}
#[test]
fn test_redact_credentials_skips_empty_values() {
use super::ChannelStoreData;
let mut creds = std::collections::HashMap::new();
creds.insert("EMPTY_TOKEN".to_string(), String::new());
let store = ChannelStoreData::new(
1024 * 1024,
"test",
ChannelCapabilities::default(),
creds,
Arc::new(PairingStore::new()),
);
let input = "should not match anything";
assert_eq!(store.redact_credentials(input), input);
}
#[tokio::test]
async fn test_dedicated_runtime_inside_spawn_blocking() {
let result = tokio::task::spawn_blocking(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to build runtime");
rt.block_on(async { 42 })
})
.await
.expect("spawn_blocking panicked");
assert_eq!(result, 42);
}
#[tokio::test]
#[ignore] async fn test_dedicated_runtime_real_http() {
let result = tokio::task::spawn_blocking(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to build runtime");
rt.block_on(async {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.expect("failed to build client");
let resp = client
.get("https://api.telegram.org/bot000/getMe")
.timeout(std::time::Duration::from_secs(10))
.send()
.await;
match resp {
Ok(r) => r.status().as_u16(),
Err(e) if e.is_timeout() => panic!("request timed out: {e}"),
Err(e) => panic!("unexpected error: {e}"),
}
})
})
.await
.expect("spawn_blocking panicked");
assert_eq!(result, 404);
}
}