use crate::{
AppConfig, CertificateAuthority, WitmProxy,
config::{confique_app_config_layer::AppConfigLayer, expand_home_in_path},
db::Db,
plugins::registry::PluginRegistry,
wasm::Runtime,
};
use plugin::PluginCommands;
use proxy::ProxyCommands;
use trust::TrustCommands;
use anyhow::Result;
use clap::{Parser, Subcommand};
use confique::Config;
use notify::{Event as NotifyEvent, RecommendedWatcher, RecursiveMode, Watcher, event::ModifyKind};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use tokio::sync::{RwLock, mpsc};
use tracing::{error, info, warn};
mod plugin;
mod proxy;
mod trust;
#[cfg(test)]
mod tests;
// Top-level command-line interface for `witmproxy`.
//
// When no subcommand is given, `Cli::run` starts the proxy itself.
// The struct carries `//` comments (not `///`) so the explicit `about` text
// below stays the only description clap derives for the command.
#[derive(Parser)]
#[command(name = "witmproxy")]
#[command(about = "A WASM-in-the-middle proxy")]
pub struct Cli {
    // Optional subcommand; `None` means "run the proxy".
    #[command(subcommand)]
    command: Option<Commands>,
    /// Path to the configuration file (`$HOME` is expanded)
    #[arg(short, long, default_value = "$HOME/.witmproxy/config.toml")]
    config_path: PathBuf,
    // Configuration values given on the command line. They are handed to the
    // config builder before the environment and the config file.
    #[command(flatten)]
    config: AppConfigLayer,
    /// Enable debug-level logging
    #[arg(short, long)]
    verbose: bool,
    /// Directory of `.wasm` plugins to load at startup and watch for changes (requires plugins to be enabled in the configuration)
    #[arg(long)]
    plugin_dir: Option<PathBuf>,
    /// Install the CA root certificate and enable the system proxy automatically, disabling it again on shutdown
    #[arg(long)]
    auto: bool,
}
/// A [`Cli`] after configuration resolution.
///
/// All config layers are merged into an [`AppConfig`], and `$HOME` is expanded
/// in the user-supplied plugin directory.
pub struct ResolvedCli {
    /// Merged configuration, produced by `with_resolved_paths()`.
    config: AppConfig,
    /// Subcommand to run; `None` means the proxy itself is run.
    command: Option<Commands>,
    /// `--plugin-dir`, with `$HOME` expanded.
    plugin_dir: Option<PathBuf>,
    /// `--verbose`.
    verbose: bool,
    /// `--auto`.
    auto: bool,
}
// Subcommands accepted by `witmproxy`.
//
// Each variant only wraps the nested subcommand enum of its child module
// (`plugin`, `trust`, `proxy`). `ResolvedCli::handle_command` dispatches each
// one to that module's handler.
// Plain `//` comments are used on purpose: `///` doc comments on clap-derived
// items would become `--help` text.
#[derive(Subcommand)]
enum Commands {
    // Handled by `plugin::PluginHandler`, built from the config and `--verbose`.
    Plugin {
        #[command(subcommand)]
        command: PluginCommands,
    },
    // Handled by `trust::TrustHandler`, built from the config.
    Trust {
        #[command(subcommand)]
        command: TrustCommands,
    },
    // Handled by `proxy::ProxyHandler`, built from the config.
    Proxy {
        #[command(subcommand)]
        command: ProxyCommands,
    },
}
/// Listen addresses of a running proxy.
///
/// `ResolvedCli::run_proxy` writes this as pretty-printed JSON to
/// `services.json`, in the parent of the TLS certificate directory.
/// The field names and their order define the JSON layout.
#[derive(Serialize, Deserialize)]
struct Services {
    /// String form of the proxy listener address (`proxy_listen_addr()`).
    proxy: String,
    /// String form of the web listener address (`web_listen_addr()`).
    web: String,
}
impl Cli {
pub async fn run(self) -> Result<()> {
let log_level = if self.verbose { "debug" } else { "info" };
tracing_subscriber::fmt()
.with_env_filter(format!("witmproxy={},{}", log_level, log_level))
.init();
let resolved_cli = self.resolve_config().await?;
if let Some(ref command) = resolved_cli.command {
return resolved_cli.handle_command(command).await;
}
resolved_cli.run_proxy().await
}
async fn resolve_config(self) -> Result<ResolvedCli> {
let config_path = expand_home_in_path(&self.config_path)?;
let config = AppConfig::builder()
.preloaded(self.config)
.env()
.file(&config_path)
.load()?
.with_resolved_paths()?;
let plugin_dir = if let Some(ref dir) = self.plugin_dir {
Some(expand_home_in_path(dir)?)
} else {
None
};
Ok(ResolvedCli {
command: self.command,
config,
verbose: self.verbose,
plugin_dir,
auto: self.auto,
})
}
}
impl ResolvedCli {
async fn handle_command(&self, command: &Commands) -> Result<()> {
match command {
Commands::Plugin { command } => {
let plugin_handler = plugin::PluginHandler::new(self.config.clone(), self.verbose);
plugin_handler.handle(command).await
}
Commands::Trust { command } => {
let trust_handler = trust::TrustHandler::new(self.config.clone());
trust_handler.handle(command).await
}
Commands::Proxy { command } => {
let proxy_handler = proxy::ProxyHandler::new(self.config.clone());
proxy_handler.handle(command).await
}
}
}
async fn run_proxy(&self) -> Result<()> {
let app_dir = self
.config
.tls
.cert_dir
.parent()
.unwrap_or(&PathBuf::from("."))
.to_path_buf();
std::fs::create_dir_all(&app_dir)?;
info!("Loaded proxy configuration");
std::fs::create_dir_all(&self.config.tls.cert_dir)?;
let ca = CertificateAuthority::new(self.config.tls.cert_dir.clone()).await?;
info!("Certificate Authority initialized");
if self.auto {
info!("Auto mode enabled: checking CA trust status");
ca.install_root_certificate(true, false).await?;
}
if let Some(parent) = self.config.db.db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let db = Db::from_path(self.config.db.db_path.clone(), &self.config.db.db_password).await?;
db.migrate().await?;
info!(
"Database initialized and migrated at: {}",
self.config.db.db_path.display()
);
let plugin_registry = if self.config.plugins.enabled {
let runtime = Runtime::try_default()?;
let mut registry = PluginRegistry::new(db, runtime)?;
registry.load_plugins().await?;
info!("Number of plugins loaded: {}", registry.plugins().len());
Some(Arc::new(RwLock::new(registry)))
} else {
None
};
let plugin_registry = plugin_registry;
if let Some(ref plugin_dir) = self.plugin_dir {
if let Some(ref registry) = plugin_registry {
info!("Loading plugins from directory: {:?}", plugin_dir);
std::fs::create_dir_all(plugin_dir)?;
load_plugins_from_directory(plugin_dir, registry.clone()).await?;
} else {
warn!("--plugin-dir specified but plugins are disabled in configuration");
}
}
let ca_for_proxy = CertificateAuthority::new(self.config.tls.cert_dir.clone()).await?;
let mut proxy = WitmProxy::new(ca_for_proxy, plugin_registry.clone(), self.config.clone());
proxy.start().await?;
let proxy_addr = proxy
.proxy_listen_addr()
.ok_or_else(|| anyhow::anyhow!("Failed to get proxy listen address"))?;
let web_addr = proxy
.web_listen_addr()
.ok_or_else(|| anyhow::anyhow!("Failed to get web listen address"))?;
let services = Services {
proxy: proxy_addr.to_string(),
web: web_addr.to_string(),
};
let services_path = app_dir.join("services.json");
let services_json = serde_json::to_string_pretty(&services)?;
std::fs::write(&services_path, services_json)?;
info!("Services information written to: {:?}", services_path);
if self.auto {
info!("Auto mode: enabling system proxy");
let proxy_handler = proxy::ProxyHandler::new(self.config.clone());
proxy_handler.enable_proxy_internal(false).await?;
}
let _watcher = if let Some(ref plugin_dir) = self.plugin_dir {
if let Some(ref registry) = plugin_registry {
Some(setup_plugin_dir_watcher(
plugin_dir.clone(),
registry.clone(),
)?)
} else {
None
}
} else {
None
};
proxy.join().await?;
if self.auto {
info!("Auto mode: disabling system proxy on shutdown");
let proxy_handler = proxy::ProxyHandler::new(self.config.clone());
proxy_handler.disable_proxy_internal(false).await?;
}
proxy.shutdown().await;
Ok(())
}
}
/// Loads every `.wasm` file directly inside `dir` (non-recursive) into `registry`.
///
/// `read_dir` order is platform-dependent, so files are loaded in sorted path
/// order to make startup deterministic. A file that fails to load is logged
/// and skipped; it does not stop the remaining files from loading.
///
/// # Errors
/// Returns an error only if `dir` or one of its entries cannot be read. In that
/// case no plugins are loaded.
pub async fn load_plugins_from_directory(
    dir: &std::path::Path,
    registry: Arc<RwLock<PluginRegistry>>,
) -> Result<()> {
    let mut wasm_files = Vec::new();
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_file() && path.extension().is_some_and(|ext| ext == "wasm") {
            wasm_files.push(path);
        }
    }
    wasm_files.sort();

    for path in wasm_files {
        match load_plugin_from_file(&path, &registry).await {
            Ok(plugin_id) => {
                info!("Loaded plugin from file: {:?} ({})", path, plugin_id);
            }
            Err(e) => {
                warn!("Failed to load plugin from {:?}: {}", path, e);
            }
        }
    }
    Ok(())
}
/// Reads the component at `path`, compiles it into a plugin, and registers it.
///
/// Returns the id of the registered plugin (`plugin.id()`).
///
/// Compiling the component only needs shared access to the registry, so it
/// runs under a read lock. The write lock is taken just for `register_plugin`,
/// which keeps a potentially slow compilation from blocking other readers.
///
/// # Errors
/// Fails if the file cannot be read, the component cannot be turned into a
/// plugin, or registration fails.
async fn load_plugin_from_file(
    path: &std::path::Path,
    registry: &Arc<RwLock<PluginRegistry>>,
) -> Result<String> {
    let component_bytes = std::fs::read(path)?;
    let plugin = registry
        .read()
        .await
        .plugin_from_component(component_bytes)
        .await?;
    let plugin_id = plugin.id();
    registry.write().await.register_plugin(plugin).await?;
    Ok(plugin_id)
}
fn setup_plugin_dir_watcher(
plugin_dir: PathBuf,
registry: Arc<RwLock<PluginRegistry>>,
) -> Result<RecommendedWatcher> {
let (tx, mut rx) = mpsc::channel::<notify::Result<NotifyEvent>>(100);
let mut watcher = notify::recommended_watcher(move |res| {
let _ = tx.blocking_send(res);
})?;
watcher.watch(&plugin_dir, RecursiveMode::NonRecursive)?;
info!("Watching plugin directory for changes: {:?}", plugin_dir);
let file_plugin_map: Arc<RwLock<HashMap<PathBuf, String>>> =
Arc::new(RwLock::new(HashMap::new()));
let registry_clone = registry.clone();
let plugin_dir_clone = plugin_dir.clone();
let file_plugin_map_clone = file_plugin_map.clone();
tokio::spawn(async move {
if let Ok(entries) = std::fs::read_dir(&plugin_dir_clone) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_file() && path.extension().is_some_and(|ext| ext == "wasm") {
if let Ok(component_bytes) = std::fs::read(&path) {
let reg = registry_clone.read().await;
if let Ok(plugin) = reg.plugin_from_component(component_bytes).await {
let mut map = file_plugin_map_clone.write().await;
map.insert(path, plugin.id());
}
}
}
}
}
});
let registry_for_handler = registry.clone();
let file_plugin_map_for_handler = file_plugin_map;
tokio::spawn(async move {
while let Some(res) = rx.recv().await {
match res {
Ok(event) => {
handle_plugin_file_event(
event,
®istry_for_handler,
&file_plugin_map_for_handler,
)
.await;
}
Err(e) => {
error!("File watcher error: {}", e);
}
}
}
});
Ok(watcher)
}
async fn handle_plugin_file_event(
event: NotifyEvent,
registry: &Arc<RwLock<PluginRegistry>>,
file_plugin_map: &Arc<RwLock<HashMap<PathBuf, String>>>,
) {
use notify::EventKind;
for path in event.paths {
if !path.extension().is_some_and(|ext| ext == "wasm") {
continue;
}
match event.kind {
EventKind::Create(_) | EventKind::Modify(ModifyKind::Data(_)) => {
info!("Plugin file created/modified: {:?}", path);
{
let map = file_plugin_map.read().await;
if let Some(old_plugin_id) = map.get(&path) {
let parts: Vec<&str> = old_plugin_id.split('/').collect();
if parts.len() == 2 {
let mut reg = registry.write().await;
match reg.remove_plugin(parts[1], Some(parts[0])).await {
Ok(removed) => {
if !removed.is_empty() {
info!("Removed old plugin version: {}", old_plugin_id);
}
}
Err(e) => {
warn!("Failed to remove old plugin {}: {}", old_plugin_id, e);
}
}
}
}
}
match load_plugin_from_file(&path, registry).await {
Ok(plugin_id) => {
info!("Loaded/updated plugin: {} from {:?}", plugin_id, path);
let mut map = file_plugin_map.write().await;
map.insert(path.clone(), plugin_id);
}
Err(e) => {
warn!("Failed to load plugin from {:?}: {}", path, e);
}
}
}
EventKind::Remove(_) => {
info!("Plugin file removed: {:?}", path);
let plugin_id = {
let mut map = file_plugin_map.write().await;
map.remove(&path)
};
if let Some(plugin_id) = plugin_id {
let parts: Vec<&str> = plugin_id.split('/').collect();
if parts.len() == 2 {
let mut reg = registry.write().await;
match reg.remove_plugin(parts[1], Some(parts[0])).await {
Ok(removed) => {
if !removed.is_empty() {
info!("Removed plugin: {}", plugin_id);
}
}
Err(e) => {
warn!("Failed to remove plugin {}: {}", plugin_id, e);
}
}
}
}
}
_ => {}
}
}
}