// btlightning 0.2.8
//
// QUIC transport layer for Bittensor.
// Documentation
use crate::types::{PeerAddr, QuicAxonInfo};
use quinn::Connection;
use std::collections::{HashMap, HashSet};
use std::time::Duration;
use tokio::time::Instant;

/// Reconnect bookkeeping that `MinerRegistry` keeps per peer address.
pub(crate) struct ReconnectState {
    /// Attempt counter that `try_start_reconnect` compares against `max_retries`.
    /// It is not incremented anywhere in this file's visible code; presumably
    /// callers update it through `reconnect_state_or_insert` — verify against callers.
    pub attempts: u32,
    /// Instant before which `try_start_reconnect` rejects a new attempt with
    /// `ReconnectRejection::Backoff`.
    pub next_retry_at: Instant,
    /// Set to `true` by `try_start_reconnect` when it admits an attempt; while it
    /// is set, further attempts are rejected with `ReconnectRejection::InProgress`.
    pub in_progress: bool,
}

impl ReconnectState {
    /// Builds a fresh state: zero attempts, nothing in flight, and a retry
    /// deadline equal to the moment of creation.
    pub fn new() -> Self {
        let created_at = Instant::now();
        Self {
            in_progress: false,
            next_retry_at: created_at,
            attempts: 0,
        }
    }
}

/// Reason `MinerRegistry::try_start_reconnect` refused to start a reconnect.
pub(crate) enum ReconnectRejection {
    /// The address's reconnect state already has `in_progress` set.
    InProgress,
    /// The attempt count reached `max_retries` and no slow-probe interval was
    /// supplied; `attempts` carries the count at the time of rejection.
    Exhausted { attempts: u32 },
    /// The current time is still before the address's `next_retry_at`;
    /// `next` carries that instant.
    Backoff { next: Instant },
}

/// In-memory registry of miners, their connections, and reconnect bookkeeping.
#[derive(Default)]
pub(crate) struct MinerRegistry {
    /// Registered miners keyed by hotkey (the `hotkey` field of the stored value).
    active_miners: HashMap<String, QuicAxonInfo>,
    /// Connections keyed by peer address. `register` and `deregister` do not
    /// touch this map; it is managed through the `*_connection` methods.
    established_connections: HashMap<PeerAddr, Connection>,
    /// Reconnect state keyed by peer address. Entries are dropped when an
    /// address loses its last hotkey via `register` or `deregister`.
    reconnect_states: HashMap<PeerAddr, ReconnectState>,
    /// Reverse index from address to the hotkeys registered at it. `register`
    /// and `deregister` remove an entry once its set becomes empty.
    addr_to_hotkeys: HashMap<PeerAddr, HashSet<String>>,
}

impl MinerRegistry {
    /// Creates an empty registry (equivalent to `MinerRegistry::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `miner` to the registry, replacing any entry already stored under
    /// the same hotkey.
    ///
    /// If the hotkey was previously registered at a different address, it is
    /// first removed from that address's hotkey set. When that leaves the old
    /// address with no hotkeys, its index entry and its reconnect state are
    /// dropped as well. Established connections are not modified here.
    pub fn register(&mut self, miner: QuicAxonInfo) {
        let hotkey = miner.hotkey.clone();
        let new_addr = miner.addr_key();

        // Address this hotkey is moving away from, if it was registered elsewhere.
        let vacated = self
            .active_miners
            .get(&hotkey)
            .map(|previous| previous.addr_key())
            .filter(|old_addr| *old_addr != new_addr);

        if let Some(old_addr) = vacated {
            let emptied = match self.addr_to_hotkeys.get_mut(&old_addr) {
                Some(residents) => {
                    residents.remove(&hotkey);
                    residents.is_empty()
                }
                None => false,
            };
            if emptied {
                self.addr_to_hotkeys.remove(&old_addr);
                self.reconnect_states.remove(&old_addr);
            }
        }

        self.active_miners.insert(hotkey.clone(), miner);
        self.addr_to_hotkeys
            .entry(new_addr)
            .or_default()
            .insert(hotkey);
    }

    /// Removes the miner registered under `hotkey` and returns it, or returns
    /// `None` if no such miner exists.
    ///
    /// The hotkey is also removed from its address's hotkey set; if that set
    /// becomes empty, the address's index entry and reconnect state are dropped.
    pub fn deregister(&mut self, hotkey: &str) -> Option<QuicAxonInfo> {
        let removed = self.active_miners.remove(hotkey)?;
        let addr = removed.addr_key();

        let emptied = match self.addr_to_hotkeys.get_mut(&addr) {
            Some(residents) => {
                residents.remove(hotkey);
                residents.is_empty()
            }
            None => false,
        };
        if emptied {
            self.addr_to_hotkeys.remove(&addr);
            self.reconnect_states.remove(&addr);
        }

        Some(removed)
    }

    /// Returns `true` if the address index has an entry for `addr_key`.
    ///
    /// `register` and `deregister` remove an entry once its hotkey set is empty,
    /// so a present entry is expected to hold at least one hotkey.
    pub fn addr_has_hotkeys(&self, addr_key: &PeerAddr) -> bool {
        self.addr_to_hotkeys.contains_key(addr_key)
    }

    /// Returns owned copies of the hotkeys indexed under `addr_key`, or an
    /// empty vector when the address has no index entry. The order follows
    /// the underlying `HashSet` iteration and is therefore unspecified.
    pub fn hotkeys_at_addr(&self, addr_key: &PeerAddr) -> Vec<String> {
        match self.addr_to_hotkeys.get(addr_key) {
            Some(residents) => residents.iter().cloned().collect(),
            None => Vec::new(),
        }
    }

    /// Borrows the miner registered under `hotkey`, if any.
    pub fn active_miner(&self, hotkey: &str) -> Option<&QuicAxonInfo> {
        self.active_miners.get(hotkey)
    }

    /// Returns `true` if a miner is registered under `hotkey`.
    pub fn contains_active_miner(&self, hotkey: &str) -> bool {
        self.active_miners.contains_key(hotkey)
    }

    /// Returns owned copies of every registered hotkey, in the map's
    /// (unspecified) iteration order.
    pub fn active_hotkeys(&self) -> Vec<String> {
        let mut hotkeys = Vec::with_capacity(self.active_miners.len());
        hotkeys.extend(self.active_miners.keys().cloned());
        hotkeys
    }

    /// Collects the distinct address keys of all registered miners, computed
    /// from each miner's `addr_key()`.
    pub fn active_addrs(&self) -> HashSet<PeerAddr> {
        let mut addrs = HashSet::new();
        for miner in self.active_miners.values() {
            addrs.insert(miner.addr_key());
        }
        addrs
    }

    /// Returns the number of registered miners.
    pub fn active_miner_count(&self) -> usize {
        self.active_miners.len()
    }

    /// Returns a clone of the connection stored for `addr`, if one exists.
    pub fn get_connection(&self, addr: &PeerAddr) -> Option<Connection> {
        self.established_connections.get(addr).cloned()
    }

    /// Stores `conn` for `addr`. Any connection previously stored for the same
    /// address is replaced and dropped without being returned to the caller.
    pub fn set_connection(&mut self, addr: PeerAddr, conn: Connection) {
        self.established_connections.insert(addr, conn);
    }

    /// Removes and returns the connection stored for `addr`, if any.
    pub fn remove_connection(&mut self, addr: &PeerAddr) -> Option<Connection> {
        self.established_connections.remove(addr)
    }

    /// Returns `true` if a connection is stored for `addr`.
    pub fn contains_connection(&self, addr: &PeerAddr) -> bool {
        self.established_connections.contains_key(addr)
    }

    /// Returns the number of stored connections.
    pub fn connection_count(&self) -> usize {
        self.established_connections.len()
    }

    /// Iterates over the addresses that currently have a stored connection,
    /// in the map's (unspecified) iteration order.
    pub fn connection_addrs(&self) -> impl Iterator<Item = &PeerAddr> {
        self.established_connections.keys()
    }

    /// Returns a mutable reference to the reconnect state for `addr`, inserting
    /// a fresh `ReconnectState::new()` first if none exists.
    ///
    /// No check is made that `addr` has any registered hotkeys.
    pub fn reconnect_state_or_insert(&mut self, addr: PeerAddr) -> &mut ReconnectState {
        self.reconnect_states
            .entry(addr)
            .or_insert_with(ReconnectState::new)
    }

    /// Attempts to claim the right to reconnect to `addr`.
    ///
    /// A reconnect state is created for `addr` if none exists, even when the
    /// call ends in a rejection. On success the state's `in_progress` flag is
    /// set and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// * `ReconnectRejection::InProgress` if a reconnect is already flagged.
    /// * `ReconnectRejection::Exhausted` if `attempts >= max_retries` and
    ///   `slow_probe_interval` is `None`.
    /// * `ReconnectRejection::Backoff` if the current time is before
    ///   `next_retry_at`. This applies to normal retries and to the slow-probe
    ///   case alike.
    ///
    /// Only the presence of `slow_probe_interval` is consulted; the contained
    /// `Duration` is not read by this method.
    pub fn try_start_reconnect(
        &mut self,
        addr: PeerAddr,
        max_retries: u32,
        slow_probe_interval: Option<Duration>,
    ) -> std::result::Result<(), ReconnectRejection> {
        let state = self
            .reconnect_states
            .entry(addr)
            .or_insert_with(ReconnectState::new);

        if state.in_progress {
            return Err(ReconnectRejection::InProgress);
        }

        // Out of retries with no slow-probe fallback: give up outright.
        let out_of_retries = state.attempts >= max_retries;
        if out_of_retries && slow_probe_interval.is_none() {
            return Err(ReconnectRejection::Exhausted {
                attempts: state.attempts,
            });
        }

        // Normal retries and slow probes are gated on the same deadline.
        let deadline = state.next_retry_at;
        if Instant::now() < deadline {
            return Err(ReconnectRejection::Backoff { next: deadline });
        }

        state.in_progress = true;
        Ok(())
    }

    /// Test-only: returns the number of addresses that have a reconnect state.
    #[cfg(test)]
    pub fn reconnect_state_count(&self) -> usize {
        self.reconnect_states.len()
    }

    /// Drops the reconnect state for `addr`; returns `true` if one existed.
    pub fn remove_reconnect_state(&mut self, addr: &PeerAddr) -> bool {
        self.reconnect_states.remove(addr).is_some()
    }

    /// Removes every stored connection, yielding each `(address, connection)`
    /// pair to the caller. Only the connection map is affected; miners, the
    /// address index, and reconnect states are left as they are.
    pub fn drain_connections(&mut self) -> impl Iterator<Item = (PeerAddr, Connection)> + '_ {
        self.established_connections.drain()
    }

    /// Empties all four internal maps: miners, the address index, reconnect
    /// states, and connections. Connections are dropped, not returned.
    pub fn clear(&mut self) {
        self.active_miners.clear();
        self.addr_to_hotkeys.clear();
        self.reconnect_states.clear();
        self.established_connections.clear();
    }

    /// Test-only consistency check between `active_miners` and `addr_to_hotkeys`.
    ///
    /// Panics unless all of the following hold:
    /// * every active miner's hotkey appears in the set indexed under the
    ///   miner's own address;
    /// * no address has an empty hotkey set;
    /// * every hotkey listed under an address belongs to an active miner
    ///   **whose address is that same address**.
    ///
    /// The last condition catches stale index entries: a hotkey that moved to
    /// a new address but was left behind in its old address's set is still an
    /// active miner, so a bare "is it registered?" check would not notice it.
    #[cfg(test)]
    fn assert_invariants(&self) {
        for (hotkey, miner) in &self.active_miners {
            let addr = miner.addr_key();
            let hotkeys_at = self.addr_to_hotkeys.get(&addr);
            assert!(
                hotkeys_at.is_some_and(|hs| hs.contains(hotkey)),
                "active miner {} at {} missing from addr_to_hotkeys",
                hotkey,
                addr
            );
        }

        for (addr, hotkeys) in &self.addr_to_hotkeys {
            assert!(!hotkeys.is_empty(), "empty hotkey set at addr {}", addr);
            for hk in hotkeys {
                let miner = match self.active_miners.get(hk) {
                    Some(miner) => miner,
                    None => panic!(
                        "addr_to_hotkeys references {} at {} but not in active_miners",
                        hk, addr
                    ),
                };
                // The index entry must point at the miner's current address,
                // not merely at some registered hotkey.
                let registered_at = miner.addr_key();
                assert!(
                    registered_at == *addr,
                    "addr_to_hotkeys lists {} at {} but the miner is registered at {}",
                    hk,
                    addr,
                    registered_at
                );
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing a `QuicAxonInfo` with a 4–8 letter lowercase hotkey,
    /// a dotted-quad IP whose first and last octets are in `1..=254`, and a
    /// port in `1024..=65535`.
    fn arb_miner() -> impl Strategy<Value = QuicAxonInfo> {
        let hotkeys = "[a-z]{4,8}";
        let octets = (1u8..=254, 0u8..=255, 0u8..=255, 1u8..=254);
        let ports = 1024u16..=65535;

        (hotkeys, octets, ports).prop_map(|(hotkey, (a, b, c, d), port)| {
            let ip = format!("{}.{}.{}.{}", a, b, c, d);
            QuicAxonInfo::new(hotkey, ip, port, 4)
        })
    }

    /// A single registry mutation used by the randomized invariant test.
    #[derive(Debug, Clone)]
    enum Op {
        /// Call `MinerRegistry::register` with this miner.
        Register(QuicAxonInfo),
        /// Call `MinerRegistry::deregister` with this hotkey.
        Deregister(String),
    }

    fn arb_op() -> impl Strategy<Value = Op> {
        prop_oneof![
            arb_miner().prop_map(Op::Register),
            "[a-z]{4,8}".prop_map(Op::Deregister),
        ]
    }

    // Property-based tests for the registry and related helpers.
    proptest! {
        // Applies up to 199 random register/deregister operations and checks
        // the index invariants after every single step.
        #[test]
        fn registry_invariants_hold_after_random_ops(ops in proptest::collection::vec(arb_op(), 1..200)) {
            let mut reg = MinerRegistry::new();
            for op in ops {
                match op {
                    Op::Register(miner) => reg.register(miner),
                    Op::Deregister(hk) => { reg.deregister(&hk); }
                }
                reg.assert_invariants();
            }
        }

        // Re-registering one hotkey at a different IP must keep exactly one
        // active miner and leave the indexes consistent.
        #[test]
        fn register_same_hotkey_different_addr_updates_correctly(
            hotkey in "[a-z]{4,8}",
            ip1 in "1\\.0\\.0\\.[1-9]",
            ip2 in "2\\.0\\.0\\.[1-9]",
            port in 1024u16..=65535,
        ) {
            let mut reg = MinerRegistry::new();
            let m1 = QuicAxonInfo::new(hotkey.clone(), ip1, port, 4);
            let m2 = QuicAxonInfo::new(hotkey.clone(), ip2, port, 4);

            reg.register(m1);
            reg.assert_invariants();
            prop_assert_eq!(reg.active_miner_count(), 1);

            reg.register(m2);
            reg.assert_invariants();
            prop_assert_eq!(reg.active_miner_count(), 1);
        }

        // Deregistering from an empty registry returns `None` and changes nothing.
        #[test]
        fn deregister_nonexistent_is_noop(hotkey in "[a-z]{4,8}") {
            let mut reg = MinerRegistry::new();
            prop_assert!(reg.deregister(&hotkey).is_none());
            reg.assert_invariants();
        }

        // The string form of a `PeerAddr` must contain the IP and the port and
        // must parse as a `std::net::SocketAddr`.
        #[test]
        fn addr_key_roundtrips((a, b, c, d) in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255), port in 1024u16..=65535) {
            let ip = format!("{}.{}.{}.{}", a, b, c, d);
            let addr = PeerAddr::new(&ip, port);
            let s: &str = addr.as_ref();
            prop_assert!(s.contains(&port.to_string()));
            prop_assert!(s.contains(&ip));
            prop_assert!(s.parse::<std::net::SocketAddr>().is_ok());
        }

        // NOTE(review): this checks a backoff formula written inline below; it
        // does not call any function from this module, so it only guards the
        // real implementation if that implementation uses the same expression.
        #[test]
        fn backoff_never_exceeds_max(
            initial_ms in 100u64..5000,
            max_ms in 5000u64..120_000,
            attempts in 0u32..63,
        ) {
            let initial = std::time::Duration::from_millis(initial_ms);
            let max = std::time::Duration::from_millis(max_ms);
            // Cap the exponent so `2u32.pow` cannot overflow.
            let shift = attempts.min(20);
            let backoff = initial
                .checked_mul(2u32.pow(shift))
                .map(|d| d.min(max))
                .unwrap_or(max);
            prop_assert!(backoff <= max);
        }
    }

    /// Moving a hotkey to a new address must drop the reconnect state that was
    /// recorded for the address it vacated.
    #[test]
    fn register_address_change_purges_stale_reconnect_state() {
        let stale_addr = PeerAddr::new("1.2.3.4", 8080);
        let mut registry = MinerRegistry::new();
        registry.register(QuicAxonInfo::new("hk1".into(), "1.2.3.4".into(), 8080, 4));

        // Seed some reconnect history for the original address.
        registry
            .reconnect_state_or_insert(stale_addr.clone())
            .attempts = 3;
        assert!(registry.reconnect_states.contains_key(&stale_addr));

        // Same hotkey, new address: the old address is now empty.
        registry.register(QuicAxonInfo::new("hk1".into(), "5.6.7.8".into(), 9090, 4));
        assert!(!registry.reconnect_states.contains_key(&stale_addr));
        registry.assert_invariants();
    }

    /// `clear` must leave no miners and no reconnect states behind.
    #[test]
    fn clear_resets_all_maps() {
        let mut registry = MinerRegistry::new();
        let first = QuicAxonInfo::new("hk1".into(), "1.2.3.4".into(), 8080, 4);
        let second = QuicAxonInfo::new("hk2".into(), "5.6.7.8".into(), 9090, 4);
        registry.register(first);
        registry.register(second);
        registry.reconnect_state_or_insert(PeerAddr::new("1.2.3.4", 8080));

        // Precondition: there is something to clear.
        assert_eq!(registry.active_miner_count(), 2);
        assert!(registry.reconnect_state_count() > 0);

        registry.clear();

        assert_eq!(registry.active_miner_count(), 0);
        assert_eq!(registry.reconnect_state_count(), 0);
        registry.assert_invariants();
    }
}