use crate::event_bus::{CoreEvent, event_sender};
use crate::ssh::SshClient;
use crate::tools::ToolsError;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio_util::sync::CancellationToken;
/// One line of output from a tcpdump capture, as an owned value.
///
/// The field set mirrors `CoreEvent::TcpdumpLine`. Equality is derived so events can be
/// compared directly (e.g. in tests or de-duplication).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpdumpEvent {
    /// Id of the capture that produced the line (presumably an id handed out by
    /// `TcpdumpRegistry::start` — confirm at the construction site).
    pub capture_id: u64,
    /// The output line text.
    pub line: String,
    /// `true` when the line was reported on the error stream rather than normal output.
    pub is_stderr: bool,
}
/// Shared table from capture id to the token that cancels that capture's stream.
///
/// It is wrapped in `Arc` so background reader tasks can remove their own entry.
type CaptureMap = Arc<Mutex<HashMap<u64, CancellationToken>>>;

/// Tracks running tcpdump captures so they can be stopped by id.
pub struct TcpdumpRegistry {
    /// Source of capture ids; incremented once per started capture.
    next_id: AtomicU64,
    /// Cancellation tokens of captures that are still registered.
    captures: CaptureMap,
}
impl Default for TcpdumpRegistry {
fn default() -> Self {
Self::new()
}
}
impl TcpdumpRegistry {
pub fn new() -> Self {
Self {
next_id: AtomicU64::new(1),
captures: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn global() -> &'static Self {
static REGISTRY: OnceLock<TcpdumpRegistry> = OnceLock::new();
REGISTRY.get_or_init(TcpdumpRegistry::new)
}
pub async fn start(
&self,
client: &SshClient,
interface: &str,
filter: &str,
snaplen: Option<u32>,
) -> Result<u64, ToolsError> {
validate_interface(interface)?;
validate_filter(filter)?;
let snap = snaplen.unwrap_or(96);
let cmd = if filter.trim().is_empty() {
format!("sudo -n tcpdump -lnn -s {snap} -i {interface}")
} else {
format!("sudo -n tcpdump -lnn -s {snap} -i {interface} '{filter}'")
};
let (mut rx, cancel) = client
.execute_command_streaming(&cmd)
.await
.map_err(|e| ToolsError::SshExec(e.to_string()))?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
if let Ok(mut map) = self.captures.lock() {
map.insert(id, cancel.clone());
}
let captures = self.captures.clone();
tokio::spawn(async move {
let tx = event_sender();
while let Some(line) = rx.recv().await {
let (is_stderr, payload) = if let Some(rest) = line.strip_prefix('!') {
(true, rest.to_string())
} else {
(false, line)
};
if let Some(ref tx) = tx {
let _ = tx.send(CoreEvent::TcpdumpLine {
capture_id: id,
line: payload,
is_stderr,
});
}
}
if let Ok(mut map) = captures.lock() {
map.remove(&id);
}
});
Ok(id)
}
pub fn stop(&self, id: u64) -> Result<(), ToolsError> {
let token = {
let mut map = self
.captures
.lock()
.map_err(|e| ToolsError::Parse(format!("lock poisoned: {e}")))?;
map.remove(&id)
};
match token {
Some(t) => {
t.cancel();
Ok(())
}
None => Err(ToolsError::CaptureNotFound(id)),
}
}
}
/// Accepts `iface` only if it looks like a safe network interface name.
///
/// `TcpdumpRegistry::start` interpolates the name unquoted into a shell command. So
/// it must be 1 to 32 bytes long and consist solely of ASCII letters, digits, `.`,
/// `-`, `_` or `:`.
///
/// # Errors
///
/// Returns `ToolsError::Parse` naming the rejected interface otherwise.
fn validate_interface(iface: &str) -> Result<(), ToolsError> {
    let permitted = |ch: char| ch.is_ascii_alphanumeric() || matches!(ch, '.' | '-' | '_' | ':');
    let length_ok = (1..=32).contains(&iface.len());
    if length_ok && iface.chars().all(permitted) {
        Ok(())
    } else {
        Err(ToolsError::Parse(format!("invalid interface: {iface}")))
    }
}
/// Checks that `filter` can be wrapped in single quotes on a shell command line.
///
/// An empty filter is accepted.
///
/// # Errors
///
/// Returns `ToolsError::Parse` if the filter contains a `'` (which would end the
/// quoting) or is longer than 4096 bytes. The quote check takes precedence when both
/// apply.
fn validate_filter(filter: &str) -> Result<(), ToolsError> {
    const MAX_FILTER_BYTES: usize = 4096;
    match filter {
        f if f.contains('\'') => Err(ToolsError::Parse(
            "filter may not contain single quotes".into(),
        )),
        f if f.len() > MAX_FILTER_BYTES => Err(ToolsError::Parse("filter too long".into())),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_bad_interface() {
        for bad in ["eth0; rm -rf /", ""] {
            assert!(validate_interface(bad).is_err(), "expected rejection of {bad:?}");
        }
        for good in ["eth0", "any", "en0:vlan100"] {
            assert!(validate_interface(good).is_ok(), "expected acceptance of {good:?}");
        }
    }

    #[test]
    fn rejects_filter_with_quotes() {
        for good in ["port 80", ""] {
            assert!(validate_filter(good).is_ok(), "expected acceptance of {good:?}");
        }
        assert!(validate_filter("port 80 and 'evil'").is_err());
    }
}