sfo_cmd_server/server/
peer_manager.rs1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use crate::{CmdTunnelRead, CmdTunnelWrite};
4use crate::peer_connection::PeerConnection;
5use crate::peer_id::PeerId;
6use crate::server::CmdServerEventListener;
7use crate::tunnel_id::{TunnelId, TunnelIdGenerator};
8
9#[derive(Clone)]
10pub struct CachedPeerInfo {
11    pub conn_list: Vec<TunnelId>,
12}
13
14pub struct PeerManager<R: CmdTunnelRead, W: CmdTunnelWrite> {
15    conn_cache: Mutex<HashMap<TunnelId, (PeerId, Arc<tokio::sync::Mutex<PeerConnection<R, W>>>)>>,
16    device_conn_map: Mutex<HashMap<PeerId, CachedPeerInfo>>,
17    conn_id_generator: TunnelIdGenerator,
18    listener: Arc<dyn CmdServerEventListener>,
19}
20pub type PeerManagerRef<R, W> = Arc<PeerManager<R, W>>;
21
22
23impl<R: CmdTunnelRead, W: CmdTunnelWrite> PeerManager<R, W> {
24    pub fn new(listener: Arc<dyn CmdServerEventListener>) -> PeerManagerRef<R, W> {
25        Arc::new(PeerManager {
26            conn_cache: Mutex::new(HashMap::new()),
27            device_conn_map: Mutex::new(HashMap::new()),
28            conn_id_generator: TunnelIdGenerator::new(),
29            listener,
30        })
31    }
32
33    pub fn generate_conn_id(&self) -> TunnelId {
34        self.conn_id_generator.generate()
35    }
36
37    pub async fn add_peer_connection(self: &Arc<Self>, mut conn: PeerConnection<R, W>) {
38        let recv_handle = conn.handle.take().unwrap();
39        let peer_id = conn.peer_id.clone();
40        let conn_id = conn.conn_id;
41        let conn_count = {
42            self.conn_cache.lock().unwrap().insert(conn_id, (peer_id.clone(), Arc::new(tokio::sync::Mutex::new(conn))));
43            let mut device_conn_map = self.device_conn_map.lock().unwrap();
44            let peer_info = device_conn_map.entry(peer_id.clone()).or_insert(CachedPeerInfo { conn_list: Vec::new() });
45            peer_info.conn_list.push(conn_id);
46            peer_info.conn_list.len()
47        };
48
49        let this = self.clone();
50        tokio::spawn(async move {
51            let _ = recv_handle.await;
52            this.remove_peer_connection(conn_id).await;
53        });
54        if conn_count == 1 {
55            let _ = self.listener.on_peer_connected(&peer_id).await;
56        }
57    }
58
59    pub async fn remove_peer_connection(&self, conn_id: TunnelId) {
60        let mut peer_id = None;
61        {
62            let mut conn_cache = self.conn_cache.lock().unwrap();
63            if let Some(conn) = conn_cache.remove(&conn_id) {
64                let mut device_conn_map = self.device_conn_map.lock().unwrap();
65                if let Some(peer_info) = device_conn_map.get_mut(&conn.0) {
66                    peer_info.conn_list.retain(|&id| id != conn_id);
67                    if peer_info.conn_list.is_empty() {
68                        device_conn_map.remove(&conn.0);
69                        peer_id = Some(conn.0.clone());
70                    }
71                }
72            }
73        }
74        if peer_id.is_some() {
75            let _ = self.listener.on_peer_disconnected(peer_id.as_ref().unwrap()).await;
76        }
77    }
78
79    pub fn find_connection(&self, conn_id: TunnelId) -> Option<Arc<tokio::sync::Mutex<PeerConnection<R, W>>>> {
80        let conn_cache = self.conn_cache.lock().unwrap();
81        conn_cache.get(&conn_id).map(|c| c.1.clone())
82    }
83
84    pub fn find_connections(&self, device_id: &PeerId) -> Vec<Arc<tokio::sync::Mutex<PeerConnection<R, W>>>> {
85        let conn_cache = self.conn_cache.lock().unwrap();
86        let device_conn_map = self.device_conn_map.lock().unwrap();
87        device_conn_map.get(device_id).map(|conns| {
88            conns.conn_list.iter().filter_map(|c| conn_cache.get(c).map(|c| c.1.clone())).collect()
89        }).unwrap_or_default()
90    }
91
92}