nettop 0.1.1

CLI network usage monitor by application — like NetLimiter for the terminal
use std::collections::HashMap;
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use sysinfo::{ProcessRefreshKind, RefreshKind, System};

use crate::procnet::read_proc_net_bytes;
use crate::types::{ProcessNet, Snapshot};

/// Maximum number of rate samples retained in each process's history; once the
/// history grows past this length the oldest sample is dropped.
const HISTORY_LEN: usize = 60;

/// Per-process bookkeeping carried from one `Collector::collect` call to the
/// next so that deltas and running totals can be computed.
#[derive(Debug, Clone)]
struct PidState {
    /// Raw "sent" counter value observed on the most recent collection.
    raw_sent: u64,
    /// Raw "received" counter value observed on the most recent collection.
    raw_recv: u64,
    /// Sum of all sent deltas accumulated since the process was first seen.
    sent_total: u64,
    /// Sum of all received deltas accumulated since the process was first seen.
    recv_total: u64,
    /// `true` when the raw counters came from `read_proc_net_bytes`, `false`
    /// when they came from the disk-usage fallback. Deltas are only computed
    /// between two samples that share the same source.
    has_net: bool,
    /// Recent combined (sent + received) rate samples, oldest first, capped at
    /// `HISTORY_LEN` entries.
    history: Vec<u64>,
}

/// Samples per-process byte counters and turns consecutive samples into
/// rates, running totals and a short rate history.
pub struct Collector {
    /// `sysinfo` handle used to enumerate and refresh processes.
    sys: System,
    /// State recorded for each PID during the previous `collect` call.
    prev: HashMap<u32, PidState>,
    /// Time of the previous `collect` call (or of construction before the
    /// first call); used to measure the sampling interval.
    last_tick: Instant,
}

impl Collector {
    pub fn new() -> Self {
        let sys = System::new_with_specifics(
            RefreshKind::new().with_processes(
                ProcessRefreshKind::new()
                    .with_disk_usage()
                    .with_cpu()
                    .with_memory(),
            ),
        );
        Self {
            sys,
            prev: HashMap::new(),
            last_tick: Instant::now(),
        }
    }

    pub fn collect(&mut self) -> Snapshot {
        let now = Instant::now();
        let elapsed_ms = now.duration_since(self.last_tick).as_millis() as u64;
        self.last_tick = now;

        self.sys.refresh_processes_specifics(
            ProcessRefreshKind::new()
                .with_disk_usage()
                .with_cpu()
                .with_memory(),
        );

        let scale = if elapsed_ms > 0 {
            1000.0 / elapsed_ms as f64
        } else {
            1.0
        };
        let mut entries: Vec<ProcessNet> = Vec::new();

        for (pid, proc) in self.sys.processes() {
            let pid_u32 = pid.as_u32();

            let (cur_sent, cur_recv, has_net) = if let Some((s, r)) = read_proc_net_bytes(pid_u32) {
                (s, r, true)
            } else {
                let d = proc.disk_usage();
                (d.written_bytes, d.read_bytes, false)
            };

            let (sent_delta, recv_delta, sent_total, recv_total, mut history) =
                if let Some(prev) = self.prev.get(&pid_u32) {
                    if prev.has_net == has_net {
                        let sd = cur_sent.saturating_sub(prev.raw_sent);
                        let rd = cur_recv.saturating_sub(prev.raw_recv);
                        let mut h = prev.history.clone();
                        let total_rate = ((sd + rd) as f64 * scale) as u64;
                        h.push(total_rate);
                        if h.len() > HISTORY_LEN {
                            h.remove(0);
                        }
                        (
                            sd,
                            rd,
                            prev.sent_total.saturating_add(sd),
                            prev.recv_total.saturating_add(rd),
                            h,
                        )
                    } else {
                        (0, 0, prev.sent_total, prev.recv_total, prev.history.clone())
                    }
                } else {
                    (0, 0, 0, 0, Vec::new())
                };

            let sent_rate = (sent_delta as f64 * scale) as u64;
            let recv_rate = (recv_delta as f64 * scale) as u64;

            // pad history on first appearance
            if history.is_empty() {
                history.push(sent_rate + recv_rate);
            }

            self.prev.insert(
                pid_u32,
                PidState {
                    raw_sent: cur_sent,
                    raw_recv: cur_recv,
                    sent_total,
                    recv_total,
                    has_net,
                    history: history.clone(),
                },
            );

            if sent_delta == 0 && recv_delta == 0 && sent_total == 0 && recv_total == 0 {
                continue;
            }

            let name = proc.name().to_string();
            let exe = proc
                .exe()
                .map(|p| p.to_string_lossy().into_owned())
                .unwrap_or_default();

            entries.push(ProcessNet {
                pid: pid_u32,
                name,
                exe,
                sent_rate,
                recv_rate,
                sent_total,
                recv_total,
                history,
            });
        }

        // Prune dead PIDs
        let live_pids: std::collections::HashSet<u32> =
            self.sys.processes().keys().map(|p| p.as_u32()).collect();
        self.prev.retain(|pid, _| live_pids.contains(pid));

        let total_sent = entries.iter().map(|e| e.sent_rate).sum();
        let total_recv = entries.iter().map(|e| e.recv_rate).sum();
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);

        Snapshot {
            entries,
            elapsed_ms,
            total_sent,
            total_recv,
            timestamp,
        }
    }
}

/// Keeps only the entries whose process name contains `filter`, compared
/// case-insensitively. When `filter` is `None` the list is left untouched.
pub fn apply_filter(entries: &mut Vec<ProcessNet>, filter: &Option<String>) {
    let needle = match filter {
        Some(pattern) => pattern.to_lowercase(),
        None => return,
    };
    entries.retain(|entry| entry.name.to_lowercase().contains(needle.as_str()));
}

// pub fn apply_sort(entries: &mut Vec<ProcessNet>, sort: &crate::args::SortBy, cumulative: bool) {
pub fn apply_sort(entries: &mut [ProcessNet], sort: &crate::args::SortBy, cumulative: bool) {
    use crate::args::SortBy::*;
    entries.sort_by(|a, b| match sort {
        Pid => a.pid.cmp(&b.pid),
        Name => a.name.cmp(&b.name),
        Sent => {
            if cumulative {
                b.sent_total.cmp(&a.sent_total)
            } else {
                b.sent_rate.cmp(&a.sent_rate)
            }
        }
        Recv => {
            if cumulative {
                b.recv_total.cmp(&a.recv_total)
            } else {
                b.recv_rate.cmp(&a.recv_rate)
            }
        }
        TotalRate => {
            if cumulative {
                b.total_cumulative().cmp(&a.total_cumulative())
            } else {
                b.total_rate().cmp(&a.total_rate())
            }
        }
        SentTotal => b.sent_total.cmp(&a.sent_total),
        RecvTotal => b.recv_total.cmp(&a.recv_total),
    });
}