pako_core/threaded_analyzer.rs

use std::{
    cmp::min,
    panic::AssertUnwindSafe,
    sync::{Arc, Barrier},
    thread,
};

use crossbeam_channel::{unbounded, Receiver, Sender};
use log::{debug, info, trace, warn};
use pako_tools::*;
use pcap_parser::data::PacketData;
use pnet_packet::ethernet::{EtherType, EtherTypes, EthernetPacket};

use crate::{
    analyzer::{handle_l3, run_plugins_v2_link, run_plugins_v2_physical, Analyzer},
    layers::LinkLayerType,
    plugin_registry::PluginRegistry,
};

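/// A unit of work sent to a worker thread over its job channel.
///
/// `New` carries a packet to analyze, `Wait` makes the worker synchronize on the
/// shared barrier, `PrintDebug` logs the worker's flow table size, and `Exit`
/// terminates the worker loop.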
pub enum Job<'a> {
    Exit,
    PrintDebug,
    New(Packet<'a>, ParseContext, &'a [u8], EtherType),
    Wait,
}

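/// Handle to a spawned worker thread, used to join it on shutdown.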
pub struct Worker {
    pub(crate) _id: usize,
    pub(crate) handler: thread::JoinHandle<()>,
}

/// Pcap/Pcap-ng multi-threaded analyzer
pub struct ThreadedAnalyzer<'a> {
    registry: Arc<PluginRegistry>,
    /// Local analyzer, so that L2 packets can be handled without
    /// dispatching them to threads
    analyzer: Analyzer,

    local_jobs: Vec<Sender<Job<'a>>>,
    workers: Vec<Worker>,
    barrier: Arc<Barrier>,
}

impl<'a> ThreadedAnalyzer<'a> {
    pub fn new(registry: PluginRegistry, config: &Config) -> Self {
        let n_workers = config
            .get_usize("num_threads")
            .map_or_else(num_cpus::get, |n| if n == 0 { num_cpus::get() } else { n });
        let barrier = Arc::new(Barrier::new(n_workers + 1));
        let registry = Arc::new(registry);
        let analyzer = Analyzer::new(registry.clone(), config);

        let mut workers = Vec::new();
        let mut local_jobs = Vec::new();
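        // spawn one worker per requested thread, each with its own Analyzer
        // instance and its own unbounded job channel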
        for idx in 0..n_workers {
            let n = format!("worker {idx}");
            let a = Analyzer::new(registry.clone(), config);
            let (sender, receiver) = unbounded();
            // NOTE: remove the job queue from lifetime management; it must be made
            // 'static so it can be sent to the worker thread
            let r: Receiver<Job<'static>> = unsafe { ::std::mem::transmute(receiver) };
            let barrier = barrier.clone();
            let builder = thread::Builder::new();
            let handler = builder
                .name(n)
                .spawn(move || {
                    worker(a, idx, r, barrier);
                })
                .unwrap();
            let worker = Worker { _id: idx, handler };
            workers.push(worker);
            local_jobs.push(sender);
        }

        ThreadedAnalyzer {
            registry,
            analyzer,
            local_jobs,
            workers,
            barrier,
        }
    }

    pub fn inner_analyzer(&self) -> &Analyzer {
        &self.analyzer
    }

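    /// Send a `Wait` job to every worker, then block on the shared barrier until
    /// all of them have drained their job queues.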
    fn wait_for_empty_jobs(&self) {
        trace!("waiting for threads to finish processing");
        for job in self.local_jobs.iter() {
            job.send(Job::Wait).expect("Error while sending job");
        }
        self.barrier.wait();
    }

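    /// Route a packet according to its data layer: L2 frames are decoded locally,
    /// L3 packets are fanned out to a worker thread.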
    fn dispatch(&mut self, packet: Packet<'static>, ctx: &ParseContext) -> Result<(), Error> {
        match packet.data {
            PacketData::L2(data) => self.handle_l2(packet, ctx, data),
            PacketData::L3(ethertype, data) => {
                extern_dispatch_l3(&self.local_jobs, packet, ctx, data, EtherType(ethertype))
            }
            PacketData::L4(_, _) => {
                warn!("Unsupported packet data layer 4");
                unimplemented!() // XXX
            }
            PacketData::Unsupported(_) => {
                warn!("Unsupported linktype");
                unimplemented!() // XXX
            }
        }
    }

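    /// Run physical- and link-layer plugins on an Ethernet frame locally, then
    /// dispatch its payload to a worker for L3 processing.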
    fn handle_l2(
        &mut self,
        packet: Packet<'static>,
        ctx: &ParseContext,
        data: &'static [u8],
    ) -> Result<(), Error> {
        trace!("handle_l2 (idx={})", ctx.pcap_index);
        // resize slice to remove padding
        let datalen = min(packet.caplen as usize, data.len());
        let data = &data[..datalen];

        // let start = ::std::time::Instant::now();
        run_plugins_v2_physical(&packet, ctx, data, &mut self.analyzer)?;
        // let elapsed = start.elapsed();
        // debug!("Time to run l2 plugins: {}.{}", elapsed.as_secs(), elapsed.as_millis());

        match EthernetPacket::new(data) {
            Some(eth) => {
                // debug!("    source: {}", eth.get_source());
                // debug!("    dest  : {}", eth.get_destination());
                match &data[..6] {
                    [0x01, 0x00, 0x0c, 0xcc, 0xcc, 0xcc] => {
                        info!("Cisco CDP/VTP/UDLD - ignoring");
                        // the 'ethertype' field is used for length
                        return Ok(());
                    }
                    [0x01, 0x00, 0x0c, 0xcd, 0xcd, 0xd0] => {
                        info!("Cisco Multicast address - ignoring");
                        return Ok(());
                    }
                    _ => {
                        info!("Ethernet broadcast (unknown type) (idx={})", ctx.pcap_index);
                    }
                }
                let ethertype = eth.get_ethertype();
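                // skip the fixed 14-byte Ethernet header to get the L3 payload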
                let payload = &data[14..];
                trace!("    ethertype: 0x{:x}", ethertype.0);
                run_plugins_v2_link(
                    &packet,
                    ctx,
                    LinkLayerType::Ethernet,
                    payload,
                    &mut self.analyzer,
                )?;
                extern_dispatch_l3(&self.local_jobs, packet, ctx, payload, ethertype)
            }
            None => {
                // packet too small to be ethernet
                Ok(())
            }
        }
    }
}

impl<'a> PcapAnalyzer for ThreadedAnalyzer<'a> {
    fn init(&mut self) -> Result<(), Error> {
        self.registry.run_plugins(|_| true, |p| p.pre_process());

        Ok(())
    }

    fn handle_packet(&mut self, packet: &Packet, ctx: &ParseContext) -> Result<(), Error> {
        // NOTE: remove the packet from lifetime management; it must be made
        // 'static to be sent to worker threads
        let packet: Packet<'static> = unsafe { ::std::mem::transmute(packet.clone()) };
        self.dispatch(packet, ctx)?;
        Ok(())
    }

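    /// Flush pending jobs, shut down and join every worker, then run plugin
    /// post-processing.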
    fn teardown(&mut self) {
        debug!("main: exit");
        self.wait_for_empty_jobs();
        for job in self.local_jobs.iter() {
            // XXX expire flows?
            job.send(Job::PrintDebug).expect("Error while sending job");
            job.send(Job::Exit).expect("Error while sending job");
        }
        while let Some(w) = self.workers.pop() {
            w.handler.join().expect("panic occurred in a thread");
        }
        self.local_jobs.clear();
        debug!("main: all workers ended");

        self.registry.run_plugins(|_| true, |p| p.post_process());
    }

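    /// Synchronize with the workers before the capture buffer is refilled, so that
    /// no dispatched packet still borrows data from the buffer being replaced.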
    fn before_refill(&mut self) {
        self.wait_for_empty_jobs();
        trace!("threads synchronized, refill");
    }
}

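/// Pick a worker with [`fan_out`] and enqueue the packet for L3 processing on it.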
pub(crate) fn extern_dispatch_l3<'a>(
    jobs: &[Sender<Job<'a>>],
    packet: Packet<'a>,
    ctx: &ParseContext,
    data: &'a [u8],
    ethertype: EtherType,
) -> Result<(), Error> {
    let n_workers = jobs.len();
    let i = fan_out(data, ethertype, n_workers);
    debug_assert!(i < n_workers);
    jobs[i]
        .send(Job::New(packet, ctx.clone(), data, ethertype))
        .or(Err(Error::Generic("Error while sending job")))
}

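/// Compute the index of the worker that should handle this packet.
///
/// For IPv4 and IPv6, the source and destination addresses are XORed before
/// hashing, so both directions of a flow map to the same worker. Packets too
/// short to contain the addresses go to the last worker; other ethertypes go
/// to worker 0.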
fn fan_out(data: &[u8], ethertype: EtherType, n_workers: usize) -> usize {
    match ethertype {
        EtherTypes::Ipv4 => {
            if data.len() >= 20 {
                // let src = &data[12..15];
                // let dst = &data[16..19];
                // let proto = data[9];
                // (src[0] ^ dst[0] ^ proto) as usize % n_workers
                let mut buf: [u8; 20] = [0; 20];
                let sz = 4;
                buf[0] = data[12] ^ data[16];
                buf[1] = data[13] ^ data[17];
                buf[2] = data[14] ^ data[18];
                buf[3] = data[15] ^ data[19];
                // we may append source and destination ports
                // XXX breaks fragmentation
                // if data[9] == crate::plugin::TRANSPORT_TCP || data[9] == crate::plugin::TRANSPORT_UDP {
                //     if data.len() >= 24 {
                //         // source port, in network-order
                //         buf[8] = data[20];
                //         buf[9] = data[21];
                //         // destination port, in network-order
                //         buf[10] = data[22];
                //         buf[11] = data[23];
                //         sz = 12;
                //     }
                // }
                // let hash = crate::toeplitz::toeplitz_hash(crate::toeplitz::KEY, &buf[..sz]);
                let hash = seahash::hash(&buf[..sz]);
                // debug!("{:?} -- hash --> 0x{:x}", buf, hash);
                // ((hash >> 24) ^ (hash & 0xff)) as usize % n_workers
                hash as usize % n_workers
            } else {
                n_workers - 1
            }
        }
        EtherTypes::Ipv6 => {
            if data.len() >= 40 {
                let mut buf: [u8; 40] = [0; 40];
                // let sz = 32;
                // source IP + destination IP, in network-order
                // buf[0..32].copy_from_slice(&data[8..40]);
                let sz = 16;
                for i in 0..16 {
                    buf[i] = data[8 + i] ^ data[24 + i];
                }
                // we may append source and destination ports
                // XXX breaks fragmentation
                // if data[6] == crate::plugin::TRANSPORT_TCP || data[6] == crate::plugin::TRANSPORT_UDP {
                //     if data.len() >= 44 {
                //         // source port, in network-order
                //         buf[33] = data[40];
                //         buf[34] = data[41];
                //         // destination port, in network-order
                //         buf[35] = data[42];
                //         buf[36] = data[43];
                //         sz += 4;
                //     }
                // }
                // let hash = crate::toeplitz::toeplitz_hash(crate::toeplitz::KEY, &buf[..sz]);
                let hash = seahash::hash(&buf[..sz]);
                // debug!("{:?} -- hash --> 0x{:x}", buf, hash);
                // ((hash >> 24) ^ (hash & 0xff)) as usize % n_workers
                hash as usize % n_workers
            } else {
                n_workers - 1
            }
        }
        _ => 0,
    }
}

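/// Worker thread loop: receive jobs from the channel, run the L3 handlers on
/// `New` jobs, synchronize on the barrier on `Wait`, and stop on `Exit`.
/// A panic in a handler is caught and terminates the whole process.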
fn worker(mut a: Analyzer, idx: usize, r: Receiver<Job>, barrier: Arc<Barrier>) {
    debug!("worker thread {} starting", idx);
    let mut pcap_index = 0;
    let res = ::std::panic::catch_unwind(AssertUnwindSafe(|| loop {
        if let Ok(msg) = r.recv() {
            match msg {
                Job::Exit => break,
                Job::PrintDebug => {
                    debug!("thread {}: hash table size: {}", idx, a.flows.len());
                }
                Job::New(packet, ctx, data, ethertype) => {
                    pcap_index = ctx.pcap_index;
                    trace!("thread {}: got a job", idx);
                    let h3_res = handle_l3(&packet, &ctx, data, ethertype, &mut a);
                    if h3_res.is_err() {
                        warn!("thread {}: handle_l3 failed", idx);
                    }
                }
                Job::Wait => {
                    trace!("Thread {}: waiting at barrier", idx);
                    barrier.wait();
                }
            }
        }
    }));
    if let Err(panic) = res {
        warn!("thread {} panicked (idx={})\n{:?}", idx, pcap_index, panic);
        // match panic.downcast::<String>() {
        //     Ok(panic_msg) => {
        //         println!("panic happened: {}", panic_msg);
        //     }
        //     Err(_) => {
        //         println!("panic happened: unknown type.");
        //     }
        // }
        // ::std::panic::resume_unwind(err);
        ::std::process::exit(1);
    }
}


#[cfg(test)]
mod tests {
    use std::mem;

    use pako_tools::{Flow, Packet, ParseContext};

    use super::Job;

    #[test]
    fn size_of_structs() {
        println!("sizeof ParseContext: {}", mem::size_of::<ParseContext>());
        println!("sizeof Packet: {}", mem::size_of::<Packet>());
        println!("sizeof Flow: {}", mem::size_of::<Flow>());
        println!("sizeof Job: {}", mem::size_of::<Job>());
    }
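
    // Sanity check: the IPv4 fan-out XORs source and destination addresses before
    // hashing, so both directions of a flow should be assigned to the same worker.
    #[test]
    fn fan_out_is_symmetric_for_ipv4() {
        use pnet_packet::ethernet::EtherTypes;

        use super::fan_out;

        // minimal 20-byte IPv4 header with only the address fields filled in
        let mut fwd = [0u8; 20];
        fwd[12..16].copy_from_slice(&[192, 168, 0, 1]); // source address
        fwd[16..20].copy_from_slice(&[10, 0, 0, 2]); // destination address

        // same addresses, with source and destination swapped
        let mut rev = [0u8; 20];
        rev[12..16].copy_from_slice(&[10, 0, 0, 2]);
        rev[16..20].copy_from_slice(&[192, 168, 0, 1]);

        let n_workers = 4;
        assert_eq!(
            fan_out(&fwd, EtherTypes::Ipv4, n_workers),
            fan_out(&rev, EtherTypes::Ipv4, n_workers)
        );
    }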
}