use std::collections::HashMap;
use std::sync::Arc;
use axum::{
Json, Router,
body::Bytes,
extract::{Path, Query, State},
http::{HeaderMap, Method, StatusCode},
response::IntoResponse,
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::channels::wasm::wrapper::WasmChannel;
/// Description of one HTTP endpoint handed to `WasmChannelRouter::register`.
///
/// Within this file only `path` (used as the lookup key) and `methods`
/// (logged at registration) are read; the channel name used for routing is
/// taken from the channel object itself.
#[derive(Debug, Clone)]
pub struct RegisteredEndpoint {
/// Name of the channel the endpoint is declared for.
pub channel_name: String,
/// Request path that should resolve to the channel (the tests use `/webhook/<name>`).
pub path: String,
/// HTTP methods declared for the endpoint.
/// NOTE(review): these are only logged here, not enforced by the router in this file.
pub methods: Vec<String>,
/// Whether the endpoint declares that a secret is required.
/// NOTE(review): not read in this file; the router decides based on whether a secret was registered.
pub require_secret: bool,
}
/// Registry that maps webhook paths to WASM channels and stores the
/// per-channel webhook secret configuration.
///
/// Every map is guarded by its own `tokio::sync::RwLock`.
pub struct WasmChannelRouter {
/// Registered channels, keyed by channel name.
channels: RwLock<HashMap<String, Arc<WasmChannel>>>,
/// Endpoint path -> name of the channel that owns it.
path_to_channel: RwLock<HashMap<String, String>>,
/// Channel name -> expected webhook secret; a channel without an entry does not require one.
secrets: RwLock<HashMap<String, String>>,
/// Channel name -> header that carries the secret; missing entries fall back to `X-Webhook-Secret`.
secret_headers: RwLock<HashMap<String, String>>,
}
impl WasmChannelRouter {
    /// Creates a router with no channels, paths, secrets, or secret headers.
    pub fn new() -> Self {
        Self {
            channels: RwLock::new(HashMap::new()),
            path_to_channel: RwLock::new(HashMap::new()),
            secrets: RwLock::new(HashMap::new()),
            secret_headers: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `channel` under its own `channel_name()` and maps every
    /// endpoint path to it.
    ///
    /// * `endpoints` - paths to route to the channel; an existing mapping for
    ///   the same path is overwritten.
    /// * `secret` - when `Some`, stored as the channel's expected webhook
    ///   secret. `None` leaves any previously stored secret untouched.
    /// * `secret_header` - when `Some`, stored as the header carrying the
    ///   secret. `None` leaves any previously stored header untouched.
    pub async fn register(
        &self,
        channel: Arc<WasmChannel>,
        endpoints: Vec<RegisteredEndpoint>,
        secret: Option<String>,
        secret_header: Option<String>,
    ) {
        let name = channel.channel_name().to_string();
        self.channels.write().await.insert(name.clone(), channel);

        // This write guard is intentionally kept alive until the end of the
        // function: `get_channel_for_path` needs a read lock on the same map,
        // so the new paths cannot be resolved before the secret and header
        // below have been stored.
        let mut path_map = self.path_to_channel.write().await;
        for endpoint in endpoints {
            path_map.insert(endpoint.path.clone(), name.clone());
            tracing::info!(
                channel = %name,
                path = %endpoint.path,
                methods = ?endpoint.methods,
                "Registered WASM channel HTTP endpoint"
            );
        }

        if let Some(s) = secret {
            self.secrets.write().await.insert(name.clone(), s);
        }
        if let Some(h) = secret_header {
            self.secret_headers.write().await.insert(name, h);
        }
    }

    /// Returns the header name that carries the secret for `channel_name`,
    /// or `X-Webhook-Secret` when no custom header was registered.
    pub async fn get_secret_header(&self, channel_name: &str) -> String {
        self.secret_headers
            .read()
            .await
            .get(channel_name)
            .cloned()
            .unwrap_or_else(|| "X-Webhook-Secret".to_string())
    }

    /// Removes the channel, its secret, its secret header, and every path
    /// that maps to it. Unknown channel names are a no-op apart from the log
    /// line. Each map is locked and released in turn.
    pub async fn unregister(&self, channel_name: &str) {
        self.channels.write().await.remove(channel_name);
        self.secrets.write().await.remove(channel_name);
        self.secret_headers.write().await.remove(channel_name);
        self.path_to_channel
            .write()
            .await
            .retain(|_, name| name != channel_name);
        tracing::info!(
            channel = %channel_name,
            "Unregistered WASM channel"
        );
    }

    /// Resolves `path` to its registered channel.
    ///
    /// Returns `None` when the path is unknown or when the path maps to a
    /// channel name that is no longer present in the channel map.
    pub async fn get_channel_for_path(&self, path: &str) -> Option<Arc<WasmChannel>> {
        let path_map = self.path_to_channel.read().await;
        let channel_name = path_map.get(path)?;
        self.channels.read().await.get(channel_name).cloned()
    }

    /// Checks `provided` against the secret stored for `channel_name`.
    ///
    /// Returns `true` when the channel has no stored secret (nothing to
    /// validate) or when the secrets match. The comparison does not stop at
    /// the first differing byte, so response timing does not reveal how much
    /// of a guessed secret was correct.
    pub async fn validate_secret(&self, channel_name: &str, provided: &str) -> bool {
        let secrets = self.secrets.read().await;
        match secrets.get(channel_name) {
            Some(expected) => Self::constant_time_eq(expected.as_bytes(), provided.as_bytes()),
            None => true,
        }
    }

    /// Returns `true` when a secret is stored for `channel_name`.
    pub async fn requires_secret(&self, channel_name: &str) -> bool {
        self.secrets.read().await.contains_key(channel_name)
    }

    /// Returns the names of all registered channels, in unspecified order.
    pub async fn list_channels(&self) -> Vec<String> {
        self.channels.read().await.keys().cloned().collect()
    }

    /// Returns all registered endpoint paths, in unspecified order.
    pub async fn list_paths(&self) -> Vec<String> {
        self.path_to_channel.read().await.keys().cloned().collect()
    }

    /// Compares two byte strings without short-circuiting on the first
    /// mismatch.
    ///
    /// Only the lengths are compared early, so a length difference is still
    /// observable; the contents are always scanned in full.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        // Accumulate every byte difference instead of returning early.
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        // `black_box` discourages the optimizer from reintroducing an early exit.
        std::hint::black_box(diff) == 0
    }
}
impl Default for WasmChannelRouter {
fn default() -> Self {
Self::new()
}
}
/// State shared with the axum handlers defined in this file.
#[allow(dead_code)]
#[derive(Clone)]
pub struct RouterState {
/// Router used by the handlers to resolve paths, secrets, and channel lists.
router: Arc<WasmChannelRouter>,
/// Optional extension manager attached via `with_extension_manager`.
/// NOTE(review): no handler in this file reads it, which is presumably why `dead_code` is allowed.
extension_manager: Option<Arc<crate::extensions::ExtensionManager>>,
}
impl RouterState {
    /// Wraps `router` in handler state with no extension manager attached.
    pub fn new(router: Arc<WasmChannelRouter>) -> Self {
        let extension_manager = None;
        Self {
            router,
            extension_manager,
        }
    }

    /// Consumes the state and returns it with `manager` attached, replacing
    /// any manager that was set before.
    pub fn with_extension_manager(
        self,
        manager: Arc<crate::extensions::ExtensionManager>,
    ) -> Self {
        Self {
            extension_manager: Some(manager),
            ..self
        }
    }
}
/// Deserializable webhook request shape carrying an optional secret.
/// NOTE(review): not referenced by any handler in this file; the webhook
/// handler reads the secret from the query map and headers instead.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct WasmWebhookRequest {
/// Optional secret; defaults to `None` when the field is absent.
#[serde(default)]
pub secret: Option<String>,
}
/// JSON body returned by `health_handler`.
#[allow(dead_code)]
#[derive(Debug, Serialize)]
struct HealthResponse {
/// Health indicator; `health_handler` always sets it to `"healthy"`.
status: String,
/// Names of the channels currently registered with the router.
channels: Vec<String>,
}
/// Health endpoint: reports `"healthy"` together with the names of all
/// channels currently registered with the router.
#[allow(dead_code)]
async fn health_handler(State(state): State<RouterState>) -> impl IntoResponse {
    let payload = HealthResponse {
        status: String::from("healthy"),
        channels: state.router.list_channels().await,
    };
    Json(payload)
}
async fn webhook_handler(
State(state): State<RouterState>,
method: Method,
Path(path): Path<String>,
Query(query): Query<HashMap<String, String>>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let full_path = format!("/webhook/{}", path);
tracing::info!(
method = %method,
path = %full_path,
body_len = body.len(),
"Webhook request received"
);
let channel = match state.router.get_channel_for_path(&full_path).await {
Some(c) => c,
None => {
tracing::warn!(
path = %full_path,
"No channel registered for webhook path"
);
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "Channel not found for path",
"path": full_path
})),
);
}
};
tracing::info!(
channel = %channel.channel_name(),
"Found channel for webhook"
);
let channel_name = channel.channel_name();
if state.router.requires_secret(channel_name).await {
let secret_header_name = state.router.get_secret_header(channel_name).await;
let provided_secret = query
.get("secret")
.cloned()
.or_else(|| {
headers
.get(&secret_header_name)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
})
.or_else(|| {
if secret_header_name != "X-Webhook-Secret" {
headers
.get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
} else {
None
}
});
tracing::debug!(
channel = %channel_name,
has_provided_secret = provided_secret.is_some(),
provided_secret_len = provided_secret.as_ref().map(|s| s.len()),
"Checking webhook secret"
);
match provided_secret {
Some(secret) => {
if !state.router.validate_secret(channel_name, &secret).await {
tracing::warn!(
channel = %channel_name,
"Webhook secret validation failed"
);
return (
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": "Invalid webhook secret"
})),
);
}
tracing::debug!(channel = %channel_name, "Webhook secret validated");
}
None => {
tracing::warn!(
channel = %channel_name,
"Webhook secret required but not provided"
);
return (
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": "Webhook secret required"
})),
);
}
}
}
let headers_map: HashMap<String, String> = headers
.iter()
.filter_map(|(k, v)| {
v.to_str()
.ok()
.map(|v| (k.as_str().to_string(), v.to_string()))
})
.collect();
let secret_validated = state.router.requires_secret(channel_name).await;
tracing::info!(
channel = %channel_name,
secret_validated = secret_validated,
"Calling WASM channel on_http_request"
);
match channel
.call_on_http_request(
method.as_str(),
&full_path,
&headers_map,
&query,
&body,
secret_validated,
)
.await
{
Ok(response) => {
let status =
StatusCode::from_u16(response.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
tracing::info!(
channel = %channel_name,
status = %status,
body_len = response.body.len(),
"WASM channel on_http_request completed successfully"
);
let body_json: serde_json::Value = serde_json::from_slice(&response.body)
.unwrap_or_else(|_| {
serde_json::json!({
"raw": String::from_utf8_lossy(&response.body).to_string()
})
});
(status, Json(body_json))
}
Err(e) => {
tracing::error!(
channel = %channel_name,
error = %e,
"WASM channel callback failed"
);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "Channel callback failed",
"details": e.to_string()
})),
)
}
}
}
/// Landing page for the OAuth redirect.
///
/// Responds with a "Connected!" page (`200`) when a non-empty `code` query
/// parameter is present. Otherwise responds with an "Authorization Failed"
/// page (`400`) showing the `error` query parameter, or `unknown` when it is
/// absent. The `state` query parameter is not used by this handler.
///
/// The `error` value comes straight from the request URL, so it is
/// HTML-escaped before being placed in the page to prevent reflected
/// script injection.
#[allow(dead_code)]
async fn oauth_callback_handler(
    State(_state): State<RouterState>,
    Query(params): Query<HashMap<String, String>>,
) -> impl IntoResponse {
    // Escapes the characters that are significant in HTML text and attribute values.
    fn escape_html(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#x27;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    // A missing `code` and an empty `code` are treated the same way.
    let code_missing = params.get("code").map_or(true, |code| code.is_empty());
    if code_missing {
        let error = params
            .get("error")
            .map(String::as_str)
            .unwrap_or("unknown");
        return (
            StatusCode::BAD_REQUEST,
            axum::response::Html(format!(
                "<!DOCTYPE html><html><body style=\"font-family: sans-serif; \
                display: flex; justify-content: center; align-items: center; \
                height: 100vh; margin: 0; background: #191919; color: white;\">\
                <div style=\"text-align: center;\">\
                <h1>Authorization Failed</h1>\
                <p>Error: {}</p>\
                </div></body></html>",
                escape_html(error)
            )),
        );
    }

    (
        StatusCode::OK,
        axum::response::Html(
            "<!DOCTYPE html><html><body style=\"font-family: sans-serif; \
            display: flex; justify-content: center; align-items: center; \
            height: 100vh; margin: 0; background: #191919; color: white;\">\
            <div style=\"text-align: center;\">\
            <h1>Connected!</h1>\
            <p>You can close this window and return to IronClaw.</p>\
            </div></body></html>"
                .to_string(),
        ),
    )
}
/// Builds the axum `Router` for WASM channels.
///
/// Routes:
/// * `GET /wasm-channels/health` - `health_handler`
/// * `GET /oauth/callback` - `oauth_callback_handler`
/// * `GET` and `POST /webhook/{*path}` - `webhook_handler`
///
/// `extension_manager`, when provided, is attached to the shared handler state.
pub fn create_wasm_channel_router(
    router: Arc<WasmChannelRouter>,
    extension_manager: Option<Arc<crate::extensions::ExtensionManager>>,
) -> Router {
    let base_state = RouterState::new(router);
    let state = match extension_manager {
        Some(manager) => base_state.with_extension_manager(manager),
        None => base_state,
    };

    // Both methods share one handler, so register them on a single method router.
    let webhook_routes = post(webhook_handler).get(webhook_handler);

    Router::new()
        .route("/wasm-channels/health", get(health_handler))
        .route("/oauth/callback", get(oauth_callback_handler))
        .route("/webhook/{*path}", webhook_routes)
        .with_state(state)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::channels::wasm::capabilities::ChannelCapabilities;
    use crate::channels::wasm::router::{RegisteredEndpoint, WasmChannelRouter};
    use crate::channels::wasm::runtime::{
        PreparedChannelModule, WasmChannelRuntime, WasmChannelRuntimeConfig,
    };
    use crate::channels::wasm::wrapper::WasmChannel;
    use crate::pairing::PairingStore;
    use crate::tools::wasm::ResourceLimits;

    /// Builds a channel named `name` on a testing runtime with an empty
    /// component and a `/webhook/<name>` capability path.
    fn create_test_channel(name: &str) -> Arc<WasmChannel> {
        let runtime = WasmChannelRuntime::new(WasmChannelRuntimeConfig::for_testing()).unwrap();
        let module = PreparedChannelModule {
            name: name.to_string(),
            description: format!("Test channel: {}", name),
            component_bytes: Vec::new(),
            limits: ResourceLimits::default(),
        };
        let caps = ChannelCapabilities::for_channel(name).with_path(format!("/webhook/{}", name));
        let channel = WasmChannel::new(
            Arc::new(runtime),
            Arc::new(module),
            caps,
            "{}".to_string(),
            Arc::new(PairingStore::new()),
        );
        Arc::new(channel)
    }

    /// Builds the single POST endpoint `/webhook/<channel>` used by the tests.
    fn webhook_endpoint(channel: &str, require_secret: bool) -> RegisteredEndpoint {
        RegisteredEndpoint {
            channel_name: channel.to_string(),
            path: format!("/webhook/{}", channel),
            methods: vec!["POST".to_string()],
            require_secret,
        }
    }

    /// A registered path resolves to its channel; an unknown path does not.
    #[tokio::test]
    async fn test_router_register_and_lookup() {
        let router = WasmChannelRouter::new();
        router
            .register(
                create_test_channel("slack"),
                vec![webhook_endpoint("slack", true)],
                Some("secret123".to_string()),
                None,
            )
            .await;

        let hit = router.get_channel_for_path("/webhook/slack").await;
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().channel_name(), "slack");

        let miss = router.get_channel_for_path("/webhook/telegram").await;
        assert!(miss.is_none());
    }

    /// Secrets must match when registered; channels without one accept anything.
    #[tokio::test]
    async fn test_router_secret_validation() {
        let router = WasmChannelRouter::new();
        router
            .register(
                create_test_channel("slack"),
                vec![],
                Some("secret123".to_string()),
                None,
            )
            .await;
        assert!(router.validate_secret("slack", "secret123").await);
        assert!(!router.validate_secret("slack", "wrong").await);

        router
            .register(create_test_channel("telegram"), vec![], None, None)
            .await;
        assert!(router.validate_secret("telegram", "anything").await);
    }

    /// Unregistering a channel removes its path mapping.
    #[tokio::test]
    async fn test_router_unregister() {
        let router = WasmChannelRouter::new();
        router
            .register(
                create_test_channel("slack"),
                vec![webhook_endpoint("slack", false)],
                None,
                None,
            )
            .await;
        let before = router.get_channel_for_path("/webhook/slack").await;
        assert!(before.is_some());

        router.unregister("slack").await;
        let after = router.get_channel_for_path("/webhook/slack").await;
        assert!(after.is_none());
    }

    /// Every registered channel name is listed.
    #[tokio::test]
    async fn test_router_list_channels() {
        let router = WasmChannelRouter::new();
        let first = create_test_channel("slack");
        let second = create_test_channel("telegram");
        router.register(first, vec![], None, None).await;
        router.register(second, vec![], None, None).await;

        let names = router.list_channels().await;
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"slack".to_string()));
        assert!(names.contains(&"telegram".to_string()));
    }

    /// A custom secret header is returned when set; otherwise the default is used.
    #[tokio::test]
    async fn test_router_secret_header() {
        let router = WasmChannelRouter::new();
        router
            .register(
                create_test_channel("telegram"),
                vec![],
                Some("secret123".to_string()),
                Some("X-Telegram-Bot-Api-Secret-Token".to_string()),
            )
            .await;
        assert_eq!(
            router.get_secret_header("telegram").await,
            "X-Telegram-Bot-Api-Secret-Token"
        );

        router
            .register(
                create_test_channel("slack"),
                vec![],
                Some("secret456".to_string()),
                None,
            )
            .await;
        assert_eq!(router.get_secret_header("slack").await, "X-Webhook-Secret");
    }
}