use std::net::IpAddr;
use std::time::Duration;
use anyhow::{Context, Result};
use pcap::{Capture, Device};
use tokio::sync::mpsc;
use crate::capture::parser::{self, ParsedPacket};
/// Upper bound for the exponentially growing delay the capture thread sleeps
/// between retries after opening or reading the capture fails.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
/// A decoded packet sent from the capture thread over the channel passed to
/// [`start_capture`].
pub struct PacketEvent {
    /// Parser output for one captured packet.
    pub parsed: ParsedPacket,
}
/// Starts packet capture on a dedicated OS thread and streams parsed packets into `tx`.
///
/// The interface (or the default device when `interface` is `None`) is resolved once
/// up front so an unknown name is reported to the caller immediately. The spawned
/// thread then repeatedly opens the capture and runs the read loop. Whenever opening
/// or reading fails, it retries with exponential backoff, capped at [`MAX_BACKOFF`].
/// The thread stops once it observes that the receiving side of `tx` has been dropped.
///
/// # Errors
///
/// Returns an error if the interface cannot be resolved or the thread cannot be
/// spawned. Failures to open or read the capture after that point are retried
/// silently rather than reported.
pub fn start_capture(
    interface: Option<String>,
    filter: Option<String>,
    promiscuous: bool,
    local_net: Option<(IpAddr, u8)>,
    tx: mpsc::UnboundedSender<PacketEvent>,
) -> Result<CaptureHandle> {
    const INITIAL_BACKOFF: Duration = Duration::from_millis(250);
    // A capture session that stays up at least this long counts as healthy, so the
    // next failure starts retrying from `INITIAL_BACKOFF` again.
    const HEALTHY_SESSION: Duration = Duration::from_secs(10);

    // Fail fast on a bad interface name; the thread re-resolves on every (re)open.
    resolve_device(&interface)?;

    let handle = std::thread::Builder::new()
        .name("packet-capture".into())
        .spawn(move || {
            let mut backoff = INITIAL_BACKOFF;
            loop {
                if let Ok(cap) = open_capture(&interface, &filter, promiscuous) {
                    let started = std::time::Instant::now();
                    match run_capture_loop(cap.handle, cap.datalink, local_net, &tx) {
                        LoopExit::ReceiverDropped => return,
                        LoopExit::TransientError => {
                            // Reset only after a session that actually ran for a while.
                            // Resetting on every successful open means a device that
                            // opens fine but fails on its first read would be retried
                            // at a fixed 250 ms forever.
                            if started.elapsed() >= HEALTHY_SESSION {
                                backoff = INITIAL_BACKOFF;
                            }
                        }
                    }
                }
                // Opening or reading failed: stop if nobody is listening, else back off.
                if tx.is_closed() {
                    return;
                }
                std::thread::sleep(backoff);
                backoff = (backoff * 2).min(MAX_BACKOFF);
            }
        })
        .context("Failed to spawn capture thread")?;
    Ok(CaptureHandle { _thread: handle })
}
/// Why `run_capture_loop` returned control to the capture thread.
enum LoopExit {
    /// Sending a `PacketEvent` failed because the channel's receiver is gone;
    /// the capture thread returns without retrying.
    ReceiverDropped,
    /// pcap reported an error other than a read timeout; the capture thread
    /// backs off and reopens the capture.
    TransientError,
}
/// An activated pcap handle together with its data-link type, as produced by
/// `open_capture`.
struct OpenedCapture {
    /// The activated capture, with the BPF filter already applied if one was given.
    handle: Capture<pcap::Active>,
    /// Link-layer type reported by the handle; selects the packet parser.
    datalink: pcap::Linktype,
}
/// Finds the pcap device to capture on.
///
/// With `Some(name)`, searches the full device list for an exact name match.
/// With `None`, asks pcap for its default device.
///
/// # Errors
///
/// Fails if pcap cannot list or look up devices, if no device has the requested
/// name, or if there is no default device.
fn resolve_device(interface: &Option<String>) -> Result<Device> {
    match interface {
        None => Device::lookup()
            .context("Failed to lookup default device")?
            .context("No default device found"),
        Some(wanted) => {
            let devices = Device::list().context("Failed to list devices")?;
            devices
                .into_iter()
                .find(|dev| dev.name == *wanted)
                .with_context(|| format!("Interface '{}' not found", wanted))
        }
    }
}
/// Resolves the interface and opens an activated pcap capture on it.
///
/// The capture uses 256-byte snapshots and a 100 ms read timeout, with promiscuous
/// mode set from `promiscuous`. When `filter` is given, it is compiled (with
/// optimisation) and installed as a BPF filter.
///
/// # Errors
///
/// Fails if the device cannot be resolved, opened, or activated, or if the BPF
/// filter is rejected.
fn open_capture(
    interface: &Option<String>,
    filter: &Option<String>,
    promiscuous: bool,
) -> Result<OpenedCapture> {
    let inactive =
        Capture::from_device(resolve_device(interface)?).context("Failed to open device")?;

    // The read timeout lets `next_packet` return `TimeoutExpired` periodically on a
    // quiet link instead of blocking indefinitely.
    let mut active = inactive
        .promisc(promiscuous)
        .snaplen(256)
        .timeout(100)
        .open()
        .context("Failed to activate capture")?;

    if let Some(bpf) = filter {
        active
            .filter(bpf, true)
            .with_context(|| format!("Failed to set BPF filter: {}", bpf))?;
    }

    let datalink = active.get_datalink();
    Ok(OpenedCapture {
        handle: active,
        datalink,
    })
}
/// Reads packets from an activated capture until the channel closes or pcap fails.
///
/// Each packet is decoded with the parser matching `datalink`. When the parser yields
/// a packet, it is forwarded on `tx` as a [`PacketEvent`]. Read timeouts count as idle
/// ticks and are simply retried.
///
/// Returns [`LoopExit::ReceiverDropped`] once the receiving side of `tx` is gone, and
/// [`LoopExit::TransientError`] for any pcap error other than a read timeout.
fn run_capture_loop(
    mut cap: Capture<pcap::Active>,
    datalink: pcap::Linktype,
    local_net: Option<(IpAddr, u8)>,
    tx: &mpsc::UnboundedSender<PacketEvent>,
) -> LoopExit {
    loop {
        // Check on every iteration, not only when a send fails. On an idle interface
        // (only timeouts), or one carrying nothing the parser accepts, no send is ever
        // attempted, so the thread would otherwise never notice the receiver is gone.
        if tx.is_closed() {
            return LoopExit::ReceiverDropped;
        }
        match cap.next_packet() {
            Ok(packet) => {
                let parsed = match datalink {
                    pcap::Linktype::ETHERNET => parser::parse_ethernet(packet.data, local_net),
                    // 0 = NULL (BSD loopback encapsulation).
                    pcap::Linktype(0) => parser::parse_loopback(packet.data, local_net),
                    // 113 = LINUX_SLL (Linux "cooked" capture).
                    pcap::Linktype(113) => parser::parse_sll(packet.data, local_net),
                    // Raw IP (101 = LINKTYPE_RAW) and any unrecognised link type fall
                    // back to the raw-IP parser.
                    _ => parser::parse_raw(packet.data, local_net),
                };
                if let Some(p) = parsed
                    && tx.send(PacketEvent { parsed: p }).is_err()
                {
                    return LoopExit::ReceiverDropped;
                }
            }
            Err(pcap::Error::TimeoutExpired) => continue,
            Err(_) => return LoopExit::TransientError,
        }
    }
}
/// Handle returned by [`start_capture`] that owns the capture thread's `JoinHandle`.
///
/// Nothing in this struct joins the thread. Dropping a `JoinHandle` detaches the
/// thread rather than stopping it; the thread exits on its own when it notices the
/// channel receiver has been dropped.
pub struct CaptureHandle {
    _thread: std::thread::JoinHandle<()>,
}
/// Returns the names of all capture devices pcap can enumerate.
///
/// # Errors
///
/// Fails if pcap cannot list the devices.
pub fn list_interfaces() -> Result<Vec<String>> {
    let devices = Device::list().context("Failed to list devices")?;
    let names: Vec<String> = devices.into_iter().map(|dev| dev.name).collect();
    Ok(names)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Listing interfaces may legitimately fail (e.g. no pcap permissions in CI);
    /// when it succeeds, a typical host exposes at least one interface.
    #[test]
    fn list_interfaces_returns_non_empty_on_typical_host() {
        if let Ok(ifs) = list_interfaces() {
            assert!(!ifs.is_empty(), "expected at least one interface");
        }
    }

    #[test]
    fn resolve_device_unknown_name_is_err() {
        let result = resolve_device(&Some(
            "definitely_not_a_real_interface_xyz_12345".to_string(),
        ));
        assert!(result.is_err(), "expected error for unknown interface");
    }

    /// Smoke test only: whether a default device exists depends on the host, so
    /// this just exercises the lookup path without asserting on its outcome.
    #[test]
    fn resolve_device_default_when_none() {
        let _ = resolve_device(&None);
    }

    #[test]
    fn max_backoff_is_30_seconds() {
        assert_eq!(MAX_BACKOFF, Duration::from_secs(30));
    }

    /// Compile-time check that a `PacketEvent` can be built from a `ParsedPacket`.
    #[test]
    fn packet_event_holds_parsed_ref() {
        fn _shape_check(p: ParsedPacket) -> PacketEvent {
            PacketEvent { parsed: p }
        }
    }

    /// The match is exhaustive, so adding a `LoopExit` variant breaks compilation
    /// here and forces the capture thread's handling to be revisited.
    #[test]
    fn loop_exit_is_two_variants() {
        fn label(exit: LoopExit) -> &'static str {
            match exit {
                LoopExit::ReceiverDropped => "receiver-dropped",
                LoopExit::TransientError => "transient-error",
            }
        }
        assert_ne!(
            label(LoopExit::ReceiverDropped),
            label(LoopExit::TransientError)
        );
    }
}