//! sesame_cli 0.2.0
//!
//! P2P encrypted chat with deniable authentication, panic mode, and multi-peer mesh.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};

use tokio::sync::{mpsc, watch, Notify};

use crate::crypto::LockedBytes;
use crate::types::{ChatMessage, PeerAddr, PeerId, SessionInfo, FLAG_SYSTEM_ALONE, FLAG_SYSTEM_INFO};

/// Book-keeping for one live peer connection, stored by `SessionManager`
/// keyed on `peer_id`.
pub struct SessionHandle {
    /// Identity of the remote peer; also the key of this handle in the session map.
    pub peer_id: PeerId,
    /// Remote address; its `ip` is compared by `register_session` to enforce the
    /// per-IP session limit.
    pub peer_addr: PeerAddr,
    /// Outbound queue for data destined to this peer; the broadcast helpers push
    /// into it with non-blocking `try_send` (full queues drop the data).
    pub sender: mpsc::Sender<Vec<u8>>,
    /// When the session was established (surfaced through `SessionInfo`).
    pub connected_since: Instant,
    /// Last activity time; refreshed by `update_last_message` and compared
    /// against the inactivity timeout by the timeout checker.
    pub last_message: Instant,
    /// Notified (`notify_one`) when the manager removes this session.
    /// NOTE(review): presumably awaited by the session's I/O task so it can stop — confirm.
    pub cancel_notify: Arc<Notify>,
}

/// Registry of live peer sessions plus the local node's identity and the
/// bookkeeping needed for reconnects, discovery and panic shutdown.
///
/// All mutable state sits behind `std::sync::Mutex`es, so methods take `&self`.
pub struct SessionManager {
    /// Live sessions keyed by remote peer id.
    sessions: Mutex<HashMap<PeerId, SessionHandle>>,
    /// Shared secret phrase, returned by `phrase()`.
    phrase: LockedBytes,
    /// Maximum number of concurrent sessions (`new` sets 10).
    pub max_sessions: usize,
    /// Maximum concurrent sessions from the same IP (`new` sets 1).
    pub same_ip_limit: usize,
    /// Sink for chat and system messages, as `(peer, message)` pairs.
    pub message_tx: mpsc::Sender<(PeerId, ChatMessage)>,
    /// Sessions idle longer than this are removed by the timeout checker task.
    pub inactivity_timeout: Duration,
    /// Addresses of peers whose session was lost (filled by `remove_session`);
    /// cleared by `clear_sessions`, `panic_shutdown` and `clear_known_peers`.
    known_peers: Mutex<HashMap<PeerId, PeerAddr>>,
    /// Address this node listens on.
    pub my_listen_addr: PeerAddr,
    /// This node's peer id; behind a mutex because `set_my_peer_id` can replace it.
    my_peer_id: Mutex<PeerId>,
    /// Optional sink for discovered peer addresses, installed by `set_discovery_tx`.
    discovery_tx: Mutex<Option<mpsc::Sender<PeerAddr>>>,
    /// Panic-shutdown flag; `panic_shutdown` sets it to `true`, observers
    /// subscribe through `cancel_rx`.
    cancel_tx: watch::Sender<bool>,
    /// Display names announced by remote peers.
    display_names: Mutex<HashMap<PeerId, String>>,
    /// This node's own display name, if any.
    my_display_name: Mutex<Option<String>>,
}

impl SessionManager {
    /// Creates a manager with no sessions, no known peers and no discovery sink.
    ///
    /// Limits start at `max_sessions = 10` and `same_ip_limit = 1`; both are
    /// public fields.
    ///
    /// * `phrase` – shared secret, handed back by [`Self::phrase`].
    /// * `message_tx` – sink for chat and system messages, as `(peer, message)` pairs.
    /// * `inactivity_timeout` – idle limit enforced by [`Self::spawn_timeout_checker`].
    /// * `my_listen_addr`, `my_peer_id`, `my_display_name` – this node's own identity.
    pub fn new(
        phrase: LockedBytes,
        message_tx: mpsc::Sender<(PeerId, ChatMessage)>,
        inactivity_timeout: Duration,
        my_listen_addr: PeerAddr,
        my_peer_id: PeerId,
        my_display_name: Option<String>,
    ) -> Self {
        Self {
            sessions: Mutex::new(HashMap::new()),
            phrase,
            max_sessions: 10,
            same_ip_limit: 1,
            message_tx,
            inactivity_timeout,
            known_peers: Mutex::new(HashMap::new()),
            my_listen_addr,
            my_peer_id: Mutex::new(my_peer_id),
            discovery_tx: Mutex::new(None),
            // The initial receiver is dropped right away; observers subscribe via
            // `cancel_rx`, which is why `panic_shutdown` must use `send_replace`.
            cancel_tx: watch::channel(false).0,
            display_names: Mutex::new(HashMap::new()),
            my_display_name: Mutex::new(my_display_name),
        }
    }

    /// Subscribes to the panic-shutdown flag. It reads `true` once
    /// [`Self::panic_shutdown`] has run, even for receivers created afterwards.
    pub fn cancel_rx(&self) -> watch::Receiver<bool> {
        self.cancel_tx.subscribe()
    }

    /// Returns this node's current peer id.
    pub fn my_peer_id(&self) -> PeerId {
        *self.my_peer_id.lock().expect("my_peer_id poisoned")
    }

    /// Replaces this node's peer id.
    pub fn set_my_peer_id(&self, peer_id: PeerId) {
        *self.my_peer_id.lock().expect("my_peer_id poisoned") = peer_id;
    }

    /// Returns a copy of this node's display name, if one is set.
    pub fn my_display_name(&self) -> Option<String> {
        self.my_display_name.lock().expect("my_display_name poisoned").clone()
    }

    /// Records (or overwrites) the display name announced by `peer_id`.
    pub fn set_display_name(&self, peer_id: PeerId, name: String) {
        self.display_names.lock().expect("display_names poisoned").insert(peer_id, name);
    }

    /// Returns the display name recorded for `peer_id`, if any.
    pub fn get_display_name(&self, peer_id: &PeerId) -> Option<String> {
        self.display_names.lock().expect("display_names poisoned").get(peer_id).cloned()
    }

    /// Adds a new live session.
    ///
    /// # Errors
    /// * `"cannot connect to self"` – the handle carries this node's own peer id.
    /// * `"duplicate session"` – a session with that peer id already exists.
    /// * `"max sessions reached"` – `max_sessions` sessions are already live.
    /// * `"same-ip limit reached"` – `same_ip_limit` sessions already use that IP.
    pub fn register_session(&self, handle: SessionHandle) -> Result<(), &'static str> {
        if handle.peer_id == self.my_peer_id() {
            return Err("cannot connect to self");
        }

        let mut sessions = self.sessions.lock().expect("sessions poisoned");

        if sessions.contains_key(&handle.peer_id) {
            return Err("duplicate session");
        }

        if sessions.len() >= self.max_sessions {
            return Err("max sessions reached");
        }

        let ip_count = sessions
            .values()
            .filter(|s| s.peer_addr.ip == handle.peer_addr.ip)
            .count();
        if ip_count >= self.same_ip_limit {
            return Err("same-ip limit reached");
        }

        sessions.insert(handle.peer_id, handle);
        Ok(())
    }

    /// Drops a session that was lost (not user-initiated).
    ///
    /// This wakes the session's task via `cancel_notify` and remembers the peer's
    /// address in `known_peers`. NOTE(review): the address is presumably read back
    /// through `known_peers_list` by a reconnect loop — confirm. It then emits a
    /// "connection lost" system message. Does nothing if `peer_id` has no live session.
    pub fn remove_session(&self, peer_id: &PeerId) {
        let removed = self.sessions.lock().expect("sessions poisoned").remove(peer_id);
        if let Some(handle) = removed {
            handle.cancel_notify.notify_one();
            let addr = handle.peer_addr;
            self.known_peers.lock().expect("known_peers poisoned").insert(*peer_id, addr);
            let msg = ChatMessage {
                peer_id: *peer_id,
                text: format!("connection lost, reconnecting to {peer_id}..."),
                timestamp: 0,
                flags: FLAG_SYSTEM_INFO,
            };
            // Best effort: a full or closed message channel must not block teardown.
            let _ = self.message_tx.try_send((*peer_id, msg));
        }
    }

    /// Deliberately disconnects from `peer_id` and forgets it.
    ///
    /// Any live session is cancelled. The peer is always removed from
    /// `known_peers`, including when it is only waiting to reconnect after a lost
    /// connection. Otherwise a user-requested disconnect could be undone by a reconnect.
    pub fn disconnect_peer(&self, peer_id: &PeerId) {
        let removed = self.sessions.lock().expect("sessions poisoned").remove(peer_id);
        if let Some(handle) = removed {
            handle.cancel_notify.notify_one();
        }
        self.known_peers.lock().expect("known_peers poisoned").remove(peer_id);
    }

    /// Drops every session and every known peer.
    ///
    /// If at least one session was live, emits a `FLAG_SYSTEM_ALONE` message
    /// attributed to this node.
    pub fn clear_sessions(&self) {
        let had_sessions = self.cancel_all_sessions() > 0;
        self.known_peers.lock().expect("known_peers poisoned").clear();
        if had_sessions {
            let me = self.my_peer_id();
            let msg = ChatMessage {
                peer_id: me,
                text: String::new(),
                timestamp: 0,
                flags: FLAG_SYSTEM_ALONE,
            };
            let _ = self.message_tx.try_send((me, msg));
        }
    }

    /// Emergency teardown: raises the panic flag, cancels every session and
    /// forgets all known peers and peer display names. No system message is emitted.
    pub fn panic_shutdown(&self) {
        // `send` fails (and leaves the value unchanged) when no receiver is alive;
        // `send_replace` always stores `true`, so later `cancel_rx()` subscribers
        // still observe the shutdown.
        self.cancel_tx.send_replace(true);
        self.cancel_all_sessions();
        self.known_peers.lock().expect("known_peers poisoned").clear();
        self.display_names.lock().expect("display_names poisoned").clear();
    }

    /// Removes all live sessions and wakes each one's task via `cancel_notify`.
    /// Returns how many sessions were removed.
    fn cancel_all_sessions(&self) -> usize {
        // Drain under the lock, notify after releasing it.
        let handles: Vec<SessionHandle> = {
            let mut sessions = self.sessions.lock().expect("sessions poisoned");
            sessions.drain().map(|(_, h)| h).collect()
        };
        for handle in &handles {
            handle.cancel_notify.notify_one();
        }
        handles.len()
    }

    /// Returns a clone of the outbound sender for `peer_id`, if connected.
    // NOTE(review): `dead_code` allowed presumably because no in-crate caller exists yet.
    #[allow(dead_code)]
    pub fn get_sender(&self, peer_id: &PeerId) -> Option<mpsc::Sender<Vec<u8>>> {
        self.sessions
            .lock()
            .expect("sessions poisoned")
            .get(peer_id)
            .map(|s| s.sender.clone())
    }

    /// Queues `data` to every live session. Uses non-blocking `try_send`, so
    /// peers whose queue is full or closed silently miss the data.
    pub fn broadcast(&self, data: &[u8]) {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        for handle in sessions.values() {
            let _ = handle.sender.try_send(data.to_vec());
        }
    }

    /// Like [`Self::broadcast`], but skips the session of `exclude`.
    pub fn broadcast_except(&self, data: &[u8], exclude: &PeerId) {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        for handle in sessions.values().filter(|h| h.peer_id != *exclude) {
            let _ = handle.sender.try_send(data.to_vec());
        }
    }

    /// Builds the public snapshot of a session handle.
    fn session_info(handle: &SessionHandle) -> SessionInfo {
        SessionInfo {
            peer_id: handle.peer_id,
            peer_addr: handle.peer_addr.clone(),
            connected_since: handle.connected_since,
            last_message: handle.last_message,
        }
    }

    /// Returns a snapshot of the session with `peer_id`, if connected.
    // NOTE(review): `dead_code` allowed presumably because no in-crate caller exists yet.
    #[allow(dead_code)]
    pub fn get_session_info(&self, peer_id: &PeerId) -> Option<SessionInfo> {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        sessions.get(peer_id).map(Self::session_info)
    }

    /// Returns snapshots of all live sessions (in hash-map order).
    pub fn list_sessions(&self) -> Vec<SessionInfo> {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        sessions.values().map(Self::session_info).collect()
    }

    /// Returns the addresses of all live sessions except `exclude`'s.
    pub fn list_peer_addresses(&self, exclude: &PeerId) -> Vec<PeerAddr> {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        sessions
            .values()
            .filter(|s| s.peer_id != *exclude)
            .map(|s| s.peer_addr.clone())
            .collect()
    }

    /// Whether any live session's address equals `addr` (full `PeerAddr` equality).
    pub fn is_connected_to_addr(&self, addr: &PeerAddr) -> bool {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        sessions.values().any(|s| s.peer_addr == *addr)
    }

    /// Number of live sessions.
    pub fn peer_count(&self) -> usize {
        self.sessions.lock().expect("sessions poisoned").len()
    }

    /// Borrows the shared secret phrase bytes.
    pub fn phrase(&self) -> &[u8] {
        self.phrase.as_bytes()
    }

    /// Returns the addresses of peers whose sessions were lost and not yet forgotten.
    pub fn known_peers_list(&self) -> Vec<PeerAddr> {
        self.known_peers
            .lock()
            .expect("known_peers poisoned")
            .values()
            .cloned()
            .collect()
    }

    /// Forgets all known peers.
    pub fn clear_known_peers(&self) {
        self.known_peers.lock().expect("known_peers poisoned").clear();
    }

    /// Installs (or replaces) the sink used by [`Self::send_discovered`].
    pub fn set_discovery_tx(&self, tx: mpsc::Sender<PeerAddr>) {
        *self.discovery_tx.lock().expect("discovery_tx poisoned") = Some(tx);
    }

    /// Forwards a discovered address to the discovery sink, if installed.
    /// Best effort: dropped when the sink is full or closed.
    pub fn send_discovered(&self, addr: &PeerAddr) {
        if let Some(ref tx) = *self.discovery_tx.lock().expect("discovery_tx poisoned") {
            let _ = tx.try_send(addr.clone());
        }
    }

    /// Emits an informational system message attributed to the all-zero peer id.
    pub fn system_msg(&self, text: &str) {
        let msg = ChatMessage {
            peer_id: PeerId([0u8; 32]),
            text: text.to_string(),
            timestamp: 0,
            flags: FLAG_SYSTEM_INFO,
        };
        let _ = self.message_tx.try_send((PeerId([0u8; 32]), msg));
    }

    /// Returns the ids of sessions idle for longer than `inactivity_timeout` as of `now`.
    fn stale_sessions(&self, now: Instant) -> Vec<PeerId> {
        let sessions = self.sessions.lock().expect("sessions poisoned");
        sessions
            .iter()
            // `last_message` may be refreshed after `now` was taken; saturate
            // instead of relying on `duration_since`'s behavior in that case.
            .filter(|(_, h)| now.saturating_duration_since(h.last_message) > self.inactivity_timeout)
            .map(|(id, _)| *id)
            .collect()
    }

    /// Spawns a background task that every 30 seconds removes sessions idle
    /// longer than `inactivity_timeout` (through [`Self::remove_session`]).
    ///
    /// The task holds only a weak reference, so it does not keep the manager
    /// (and its secret phrase) alive. It exits once the manager is dropped or
    /// [`Self::panic_shutdown`] has been called.
    pub fn spawn_timeout_checker(self: &Arc<Self>) {
        let manager = Arc::downgrade(self);
        let cancel_rx = self.cancel_rx();
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(30)).await;
                if *cancel_rx.borrow() {
                    break;
                }
                let this = match manager.upgrade() {
                    Some(this) => this,
                    None => break,
                };
                for id in this.stale_sessions(Instant::now()) {
                    this.remove_session(&id);
                }
                // `this` is dropped here, before the next sleep.
            }
        });
    }

    /// Marks `peer_id`'s session as active now; no-op if it is not connected.
    pub fn update_last_message(&self, peer_id: &PeerId) {
        if let Some(handle) = self.sessions.lock().expect("sessions poisoned").get_mut(peer_id) {
            handle.last_message = Instant::now();
        }
    }
}