use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::Mutex;
use crate::filter::Filter;
/// Runtime options for [`run`].
#[derive(Debug)]
pub struct Config {
    /// Command line of the wrapped server: the first element is the program,
    /// the rest are its arguments. Must not be empty.
    pub target: Vec<String>,
    /// Forwarded to `Filter::new`; presumably the maximum number of chunks kept
    /// per filtered response — confirm in `crate::filter`.
    pub max_chunks: usize,
    /// Forwarded to `Filter::new`; presumably a relevance cut-off — confirm in
    /// `crate::filter`.
    pub threshold: f32,
    /// Age, in days, passed to `Cache::evict_expired` once at startup.
    pub ttl_days: u64,
    /// Forwarded to `Filter::new`; presumably a database size cap in MB —
    /// confirm in `crate::filter`.
    pub max_db_mb: u64,
    /// Forwarded to `Filter::new`.
    pub dry_run: bool,
    /// Forwarded to `Filter::new`.
    pub verbose: bool,
    /// Seconds between cache statistics reports; `0` disables them.
    pub stats_interval: u64,
    /// Location of the cache database, shared by the cache, the filter and the
    /// stats reporter.
    pub cache_db: String,
}
pub async fn run(config: Config) -> Result<()> {
let (cmd, args) = config
.target
.split_first()
.ok_or_else(|| anyhow::anyhow!("No target command specified"))?;
{
let cache = crate::cache::Cache::new(&config.cache_db)?;
let n = cache.evict_expired(config.ttl_days)?;
if n > 0 {
eprintln!(
"[mcpkill] evicted {n} expired cache entries (ttl={}d)",
config.ttl_days
);
}
}
let mut child = Command::new(cmd)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;
let child_stdin = Arc::new(Mutex::new(child.stdin.take().unwrap()));
let child_stdout = child.stdout.take().unwrap();
let pending: Arc<Mutex<HashMap<String, String>>> = Arc::new(Mutex::new(HashMap::new()));
let filter = Arc::new(Filter::new(
&config.cache_db,
config.max_chunks,
config.threshold,
config.dry_run,
config.verbose,
config.max_db_mb,
)?);
let pending_w = Arc::clone(&pending);
let child_stdin_w = Arc::clone(&child_stdin);
let stdin_task = tokio::spawn(async move {
let reader = BufReader::new(tokio::io::stdin());
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
if let Ok(msg) = serde_json::from_str::<Value>(&line) {
if msg.get("method").and_then(|m| m.as_str()) == Some("tools/call") {
if let Some(id) = msg.get("id") {
let query = format!(
"{} {}",
msg["params"]["name"].as_str().unwrap_or(""),
msg["params"]["arguments"]
);
pending_w.lock().await.insert(id.to_string(), query);
}
}
}
let mut w = child_stdin_w.lock().await;
let _ = w.write_all(line.as_bytes()).await;
let _ = w.write_all(b"\n").await;
}
});
let pending_r = Arc::clone(&pending);
let filter_r = Arc::clone(&filter);
let stdout_task = tokio::spawn(async move {
let reader = BufReader::new(child_stdout);
let mut lines = reader.lines();
let mut stdout = tokio::io::stdout();
while let Ok(Some(line)) = lines.next_line().await {
let output = process_line(&line, &pending_r, &filter_r).await;
let _ = stdout.write_all(output.as_bytes()).await;
let _ = stdout.write_all(b"\n").await;
}
});
let db_path = config.cache_db.clone();
let stats_task = tokio::spawn(async move {
if config.stats_interval == 0 {
std::future::pending::<()>().await;
return;
}
let mut interval = tokio::time::interval(Duration::from_secs(config.stats_interval));
interval.tick().await; loop {
interval.tick().await;
let _ = crate::stats::print_stats(&db_path);
}
});
tokio::select! {
_ = stdin_task => {},
_ = stdout_task => {},
_ = stats_task => {},
_ = shutdown_signal() => {
let _ = child.kill().await;
}
status = child.wait() => {
std::process::exit(status?.code().unwrap_or(1));
}
}
Ok(())
}
/// Resolves once the process is asked to shut down and logs which signal
/// arrived: Ctrl-C (SIGINT) on every platform, plus SIGTERM on Unix.
///
/// If a handler cannot be installed, the failure is logged and that source is
/// treated as never firing. Previously a failed Ctrl-C registration resolved
/// immediately, triggering a spurious shutdown, and a failed SIGTERM
/// registration panicked.
async fn shutdown_signal() {
    let ctrl_c = async {
        if let Err(e) = tokio::signal::ctrl_c().await {
            eprintln!("[mcpkill] cannot listen for Ctrl-C: {e}");
            std::future::pending::<()>().await;
        }
    };
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};
        let sigterm = async {
            match signal(SignalKind::terminate()) {
                Ok(mut stream) => {
                    stream.recv().await;
                }
                Err(e) => {
                    eprintln!("[mcpkill] cannot listen for SIGTERM: {e}");
                    std::future::pending::<()>().await;
                }
            }
        };
        tokio::select! {
            _ = ctrl_c => eprintln!("[mcpkill] SIGINT — shutting down"),
            _ = sigterm => eprintln!("[mcpkill] SIGTERM — shutting down"),
        }
    }
    #[cfg(not(unix))]
    {
        ctrl_c.await;
        eprintln!("[mcpkill] Ctrl-C — shutting down");
    }
}
/// Rewrites one line of child output before it is forwarded to the client.
///
/// The line is passed through unchanged unless all of these hold:
/// - it parses as JSON;
/// - it has no `method`, so it is a response rather than a server-initiated
///   request or notification;
/// - its `id` matches a pending `tools/call` recorded from the client side.
///
/// For such a response the pending entry is consumed. The message and the
/// recorded query are then passed to `Filter::process` on a blocking thread. If
/// filtering fails (error, panic, or re-serialization failure), the error is
/// logged where applicable and the original line is returned.
async fn process_line(
    line: &str,
    pending: &Arc<Mutex<HashMap<String, String>>>,
    filter: &Arc<Filter>,
) -> String {
    let Ok(msg) = serde_json::from_str::<Value>(line) else {
        return line.to_string();
    };
    // Server-initiated requests use their own id space; only responses (which
    // have no `method`) may answer a pending client `tools/call`.
    if msg.get("method").is_some() {
        return line.to_string();
    }
    let Some(id) = msg.get("id").map(ToString::to_string) else {
        return line.to_string();
    };
    let Some(query) = pending.lock().await.remove(&id) else {
        return line.to_string();
    };
    let filter = Arc::clone(filter);
    // `msg` is not needed afterwards, so move it rather than deep-cloning the payload.
    match tokio::task::spawn_blocking(move || filter.process(&msg, &query)).await {
        Ok(Ok(filtered)) => serde_json::to_string(&filtered).unwrap_or_else(|_| line.to_string()),
        Ok(Err(e)) => {
            eprintln!("[mcpkill] filter error: {e}");
            line.to_string()
        }
        Err(e) => {
            eprintln!("[mcpkill] spawn_blocking panic: {e}");
            line.to_string()
        }
    }
}