use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crossbeam_channel::{Receiver, TryRecvError};
use crate::afpacket::rx::CaptureBuilder;
use crate::error::Error;
use crate::packet::OwnedPacket;
/// A live packet capture whose packets are forwarded, as [`OwnedPacket`]s, through a
/// bounded channel by a dedicated background thread.
pub struct ChannelCapture {
    /// Shutdown flag shared with the background thread; set to `true` to ask it to exit.
    stop: Arc<AtomicBool>,
    /// Join handle of the background thread; `None` once the thread has been joined.
    handle: Option<JoinHandle<()>>,
    /// Consumer end of the channel the background thread sends packets into.
    receiver: Receiver<OwnedPacket>,
}
impl ChannelCapture {
    /// Opens a capture on `interface` and starts a background thread that forwards every
    /// received packet, converted with `to_owned`, into a bounded channel with room for
    /// `capacity` packets.
    ///
    /// While the channel is full the worker blocks in `send` until the consumer makes room.
    /// The worker exits when the stop flag is set (checked between batches), when the
    /// receiving side has been dropped, or when the capture returns an error. In every case
    /// its sender is dropped, so consumers observe the end of the stream as a disconnected
    /// channel.
    ///
    /// # Errors
    ///
    /// Returns any error produced while building the capture for `interface`.
    pub fn spawn(interface: &str, capacity: usize) -> Result<Self, Error> {
        // Timeout handed to each `next_batch_blocking` call; it bounds how long the worker
        // can go without re-checking the stop flag when no batch arrives.
        const POLL_TIMEOUT: Duration = Duration::from_millis(10);

        let rx = CaptureBuilder::default().interface(interface).build()?;
        let (sender, receiver) = crossbeam_channel::bounded(capacity);
        let stop = Arc::new(AtomicBool::new(false));
        let stop_clone = Arc::clone(&stop);
        let handle = thread::spawn(move || {
            let mut rx = rx;
            while !stop_clone.load(Ordering::Relaxed) {
                match rx.next_batch_blocking(POLL_TIMEOUT) {
                    Ok(Some(batch)) => {
                        for pkt in &batch {
                            let owned = pkt.to_owned();
                            // Blocks while the channel is full; fails only once the receiver
                            // has been dropped, at which point nobody is listening any more.
                            if sender.send(owned).is_err() {
                                return;
                            }
                        }
                    }
                    Ok(None) => continue,
                    // NOTE(review): the capture error is discarded; consumers only see the
                    // channel disconnect. Consider surfacing it if callers need the cause.
                    Err(_) => return,
                }
            }
        });
        Ok(Self {
            receiver,
            handle: Some(handle),
            stop,
        })
    }

    /// Blocks until the next packet is available.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` once the worker has exited and every buffered packet has
    /// already been received.
    pub fn recv(&self) -> Result<OwnedPacket, crossbeam_channel::RecvError> {
        self.receiver.recv()
    }

    /// Returns the next buffered packet without blocking.
    ///
    /// # Errors
    ///
    /// `TryRecvError::Empty` if no packet is currently buffered, or
    /// `TryRecvError::Disconnected` if the worker has exited and the buffer is empty.
    pub fn try_recv(&self) -> Result<OwnedPacket, TryRecvError> {
        self.receiver.try_recv()
    }

    /// Asks the worker to stop, waits for it to exit, and returns every packet it
    /// delivered that had not yet been received.
    ///
    /// Packets are received *while* waiting for the worker rather than after joining it.
    /// With a full channel the worker sits blocked in `send` and never reaches its
    /// stop-flag check, so joining first would deadlock.
    pub fn stop_and_drain(mut self) -> Vec<OwnedPacket> {
        self.stop.store(true, Ordering::Relaxed);
        // `iter()` keeps yielding until the channel is empty *and* the worker's sender has
        // been dropped, which happens when the worker thread returns or unwinds.
        let drained: Vec<OwnedPacket> = self.receiver.iter().collect();
        if let Some(handle) = self.handle.take() {
            // The worker has already released its sender, so this join returns promptly.
            let _ = handle.join();
        }
        drained
    }
}
/// Lets a capture be consumed with `for pkt in &capture { ... }`; each step blocks for the
/// next packet via [`ChannelIter`].
impl<'cap> IntoIterator for &'cap ChannelCapture {
    type Item = OwnedPacket;
    type IntoIter = ChannelIter<'cap>;

    fn into_iter(self) -> Self::IntoIter {
        let cap = self;
        ChannelIter { cap }
    }
}
/// Blocking iterator over the packets of a borrowed [`ChannelCapture`], produced by
/// iterating `&ChannelCapture`.
pub struct ChannelIter<'c> {
    /// Capture whose channel is read on every call to `next`.
    cap: &'c ChannelCapture,
}
impl Iterator for ChannelIter<'_> {
    type Item = OwnedPacket;

    /// Blocks until the next packet arrives; yields `None` once receiving from the channel
    /// fails, i.e. the worker has exited and nothing is left buffered.
    fn next(&mut self) -> Option<Self::Item> {
        self.cap.recv().ok()
    }
}
impl Drop for ChannelCapture {
    /// Asks the worker to stop and waits for it to exit, discarding any undelivered packets.
    fn drop(&mut self) {
        self.stop.store(true, Ordering::Relaxed);
        if let Some(handle) = self.handle.take() {
            // Discard packets until the worker drops its sender. A worker blocked in `send`
            // on a full channel never re-checks the stop flag, so joining without draining
            // would hang forever.
            for _discarded in self.receiver.iter() {}
            let _ = handle.join();
        }
    }
}