Skip to main content

antenna_protocol/mesh/
mod.rs

1use crate::{
2    HandshakeFSM, HandshakeInput, HandshakeMode, HandshakeState, HandshakeStrategy, Identity,
3    Input, MAX_RECONNECT_ATTEMPTS, MsgPayload, Output, PeerID, RECONNECT_INTERVAL_MS, RelayPayload,
4    Scheduled, SignalingPayload, UserMsgPayload,
5};
6use anyhow::{Result, anyhow};
7use std::collections::{HashMap, HashSet, VecDeque};
8
/// Reconnect bookkeeping for a peer whose connection was lost.
#[derive(Clone, Debug, Default)]
struct DroppedPeerState {
    /// How many reconnect attempts have been started for this peer so far.
    attempts: u32,
}
13
14pub struct HandshakeContext {
15    pub fsm: HandshakeFSM,
16    pub mode: HandshakeMode,
17}
18
/// Coarse mesh-membership state for the local node.
///
/// Recomputed after every [`MeshNodeFSM::process`] call from the per-peer
/// handshake states. Once `Left`, stays `Left`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FSMState {
    /// No peer is `Connected`: the node has not joined a mesh yet, or it has
    /// lost every connection it had.
    Init,
    /// At least one peer is `Connected`, but some relay-mode handshake is still
    /// in progress.
    Connected,
    /// At least one peer is `Connected` and no relay-mode handshake is in
    /// progress. Bootstrap handshakes that are still in flight do not prevent
    /// this state.
    Available,
    /// The node has issued [`Input::Leave`]; terminal.
    Left,
}
33
34/// Top-level state machine for a mesh node.
35///
36/// Owns the local [`Identity`], every per-peer [`HandshakeFSM`], and the
37/// coarse [`FSMState`]. [`MeshNodeFSM::process`] is the single entry point:
38/// feed it an [`Input`], get back a `Vec<Output>`. No I/O happens here —
39/// outputs are commands the driver must execute.
40pub struct MeshNodeFSM {
41    /// current peer ID
42    id: PeerID,
43
44    /// identity of current peer: id and key pair
45    identity: Identity,
46
47    /// Map of handshake automati, contains state of current handshakes with other sessions
48    connections: HashMap<PeerID, HandshakeContext>,
49
50    /// Pool of open-offer handshakes before the joiner's peer ID is known
51    pending_handshakes: VecDeque<HandshakeContext>,
52
53    /// Peers we lost abruptly and are trying to reconnect to. Cleared on successful
54    /// reconnect (`HandshakeState::Connected`) or on `Input::Leave`.
55    lost_peers: HashMap<PeerID, DroppedPeerState>,
56
57    /// Coarse mesh-membership state; recomputed after each `process` call.
58    /// Once `Left`, stays `Left`.
59    state: FSMState,
60}
61
62impl Default for MeshNodeFSM {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl MeshNodeFSM {
69    pub fn new() -> Self {
70        Self::with_identity(Identity::new())
71    }
72
73    pub fn with_identity(identity: Identity) -> Self {
74        Self {
75            id: identity.peer_id(),
76            identity,
77            connections: HashMap::new(),
78            pending_handshakes: VecDeque::new(),
79            lost_peers: HashMap::new(),
80            state: FSMState::Init,
81        }
82    }
83
84    pub fn state(&self) -> FSMState {
85        self.state
86    }
87
88    fn compute_state(&self) -> FSMState {
89        if self.state == FSMState::Left {
90            return FSMState::Left;
91        }
92        if self.connected_peers().is_empty() {
93            return FSMState::Init;
94        }
95        let in_progress_relays = self.connections.values().any(|ctx| {
96            matches!(ctx.mode, HandshakeMode::Relay(_))
97                && *ctx.fsm.state() != HandshakeState::Connected
98        });
99        if in_progress_relays {
100            FSMState::Connected
101        } else {
102            FSMState::Available
103        }
104    }
105
106    /// Outputs emitted for a state edge. Multi-step jumps (e.g. Init →
107    fn state_transition_outputs<Msg: UserMsgPayload>(
108        prev: FSMState,
109        new: FSMState,
110    ) -> Vec<Output<Msg>> {
111        match (prev, new) {
112            (FSMState::Init, FSMState::Connected) => vec![Output::Connected],
113            (FSMState::Init, FSMState::Available) => {
114                vec![Output::Connected, Output::Available]
115            }
116            (FSMState::Connected, FSMState::Available) => vec![Output::Available],
117            (FSMState::Available, FSMState::Connected) => vec![Output::Unavailable],
118            (FSMState::Connected | FSMState::Available, FSMState::Init) => {
119                vec![Output::Unavailable]
120            }
121            (_, FSMState::Left) => vec![Output::Disconnecting],
122            _ => vec![],
123        }
124    }
125
126    pub fn id(&self) -> &PeerID {
127        &self.id
128    }
129
130    pub fn is_connected(&self, peer: &PeerID) -> bool {
131        self.connections.contains_key(peer)
132            && *self.connections.get(peer).unwrap().fsm.state() == HandshakeState::Connected
133    }
134
135    /// Returns true if `peer` is allowed to deliver `msg` to us right now.
136    pub fn channel_open_for_msg<Msg: UserMsgPayload>(
137        &self,
138        peer: &PeerID,
139        msg: &MsgPayload<Msg>,
140    ) -> bool {
141        match msg {
142            MsgPayload::RelaySignalingTo { .. } | MsgPayload::RelaySignalingFrom { .. } => {
143                matches!(
144                    self.connections.get(peer).map(|c| c.fsm.state()),
145                    Some(HandshakeState::Connected | HandshakeState::WaitingForDataChannel)
146                )
147            }
148            MsgPayload::User(_) | MsgPayload::Disconnect => self.is_connected(peer),
149        }
150    }
151
152    pub fn connected_peers(&self) -> HashSet<PeerID> {
153        self.connections
154            .iter()
155            .filter(|x| *x.1.fsm.state() == HandshakeState::Connected)
156            .map(|x| x.0.clone())
157            .collect()
158    }
159
160    pub fn connected_number(&self) -> usize {
161        self.connections.iter().fold(0, |a, x| {
162            if *x.1.fsm.state() == HandshakeState::Connected {
163                a + 1
164            } else {
165                a
166            }
167        })
168    }
169
170    pub fn handle_init_handshake<Msg: UserMsgPayload>(
171        &mut self,
172        with: PeerID,
173        mode: HandshakeMode,
174        strategy: HandshakeStrategy,
175    ) -> Result<Vec<Output<Msg>>> {
176        self.connections.insert(
177            with,
178            HandshakeContext {
179                fsm: HandshakeFSM::new(strategy),
180                mode,
181            },
182        );
183        Ok(vec![])
184    }
185
186    pub fn handle_init_open_offer<Msg: UserMsgPayload>(&mut self) -> Result<Vec<Output<Msg>>> {
187        let mut ctx = HandshakeContext {
188            fsm: HandshakeFSM::new(HandshakeStrategy::Host),
189            mode: HandshakeMode::Bootstrap,
190        };
191        ctx.fsm.process(HandshakeInput::Init)?;
192        self.pending_handshakes.push_back(ctx);
193        Ok(vec![Output::InitOpenOffer])
194    }
195
196    pub fn handle_open_offer_created<Msg: UserMsgPayload>(
197        &mut self,
198        sdp: String,
199    ) -> Result<Vec<Output<Msg>>> {
200        let offer = SignalingPayload {
201            token: self.identity.create_token(&sdp)?,
202            pubkey: self.identity.pubkey(),
203        };
204        self.pending_handshakes
205            .back_mut()
206            .ok_or_else(|| anyhow!("No pending open offer"))?
207            .fsm
208            .process(HandshakeInput::OfferCreated(sdp))?;
209        Ok(vec![Output::OfferReady(offer)])
210    }
211
212    pub fn handle_send<Msg: UserMsgPayload>(
213        &mut self,
214        peer_to: PeerID,
215        data: MsgPayload<Msg>,
216    ) -> Result<Vec<Output<Msg>>> {
217        if self.is_connected(&peer_to) {
218            Ok(vec![Output::SendMessage { peer_to, data }])
219        } else {
220            Ok(vec![])
221        }
222    }
223
224    pub fn handle_broadcast<Msg: UserMsgPayload>(
225        &mut self,
226
227        data: MsgPayload<Msg>,
228    ) -> Result<Vec<Output<Msg>>> {
229        let mut out = vec![];
230        for peer in self.connections.keys() {
231            if !self.is_connected(peer) {
232                continue;
233            }
234            out.push(Output::SendMessage {
235                peer_to: peer.clone(),
236                data: data.clone(),
237            })
238        }
239        Ok(out)
240    }
241
242    pub fn process<Msg: UserMsgPayload>(&mut self, input: Input<Msg>) -> Result<Vec<Output<Msg>>> {
243        let prev_state = self.state;
244        let mut outputs = self.dispatch(input)?;
245        let new_state = self.compute_state();
246        if prev_state != new_state {
247            self.state = new_state;
248            outputs.extend(Self::state_transition_outputs::<Msg>(prev_state, new_state));
249        }
250        Ok(outputs)
251    }
252
253    fn dispatch<Msg: UserMsgPayload>(&mut self, input: Input<Msg>) -> Result<Vec<Output<Msg>>> {
254        match input {
255            Input::InitHandshake {
256                with,
257                mode,
258                strategy,
259            } => self.handle_init_handshake(with, mode, strategy),
260            Input::InitOpenOffer => self.handle_init_open_offer(),
261            Input::OpenOfferCreated(sdp) => self.handle_open_offer_created(sdp),
262            Input::Handshake { from, event } => self.handle_handshake(from, event),
263            Input::PeerLeaving { peer } => self.handle_peer_leaving(peer),
264            Input::MessageReceived { peer_from, data } => self.handle_message(peer_from, data),
265            Input::Send { peer_to, data } => self.handle_send(peer_to, data),
266            Input::Broadcast { data } => self.handle_broadcast(data),
267            Input::Leave => self.handle_leave(),
268            Input::TimerFired { kind } => self.handle_timer_fired(kind),
269        }
270    }
271
272    pub fn identity(&self) -> &Identity {
273        &self.identity
274    }
275
276    /// Debug: snapshot of every connection's `(peer, state, mode)`.
277    pub fn connections_snapshot(&self) -> Vec<(PeerID, HandshakeState, HandshakeMode)> {
278        self.connections
279            .iter()
280            .map(|(peer, ctx)| (peer.clone(), ctx.fsm.state().clone(), ctx.mode.clone()))
281            .collect()
282    }
283
284    /// Debug: number of contexts in `pending_handshakes`.
285    pub fn pending_handshakes_len(&self) -> usize {
286        self.pending_handshakes.len()
287    }
288
289    fn handle_leave<Msg: UserMsgPayload>(&mut self) -> Result<Vec<Output<Msg>>> {
290        if self.state == FSMState::Left {
291            return Ok(vec![]);
292        }
293        let mut out = vec![];
294        for peer in self.connections.keys() {
295            if self.is_connected(peer) {
296                out.push(Output::SendMessage {
297                    peer_to: peer.clone(),
298                    data: MsgPayload::Disconnect,
299                });
300            }
301        }
302        self.connections.clear();
303        self.pending_handshakes.clear();
304        self.lost_peers.clear();
305        self.state = FSMState::Left;
306        Ok(out)
307    }
308
309    fn handle_peer_leaving<Msg: UserMsgPayload>(
310        &mut self,
311        peer: PeerID,
312    ) -> Result<Vec<Output<Msg>>> {
313        let was_connected = self.connections.remove(&peer);
314
315        let mut out = Vec::new();
316        if was_connected.is_some() {
317            out.push(Output::PeerDisconnected { peer });
318        }
319        Ok(out)
320    }
321
322    pub(crate) fn handle_handshake<Msg: UserMsgPayload>(
323        &mut self,
324        peer: PeerID,
325        event: HandshakeInput,
326    ) -> Result<Vec<Output<Msg>>> {
327        let mut outputs: Vec<Output<Msg>> = vec![];
328
329        if !self.connections.contains_key(&peer) {
330            match &event {
331                HandshakeInput::Answer(_) => {
332                    let ctx = self
333                        .pending_handshakes
334                        .pop_front()
335                        .ok_or_else(|| anyhow!("Pending handshake not found"))?;
336                    self.connections.insert(peer.clone(), ctx);
337                }
338                HandshakeInput::ConnectionDropped => {
339                    return Ok(outputs);
340                }
341                _ => return Err(anyhow!("Handshake instance with peer not found")),
342            }
343        }
344
345        let side_effects_outs = self.handle_side_effects(&peer, &event)?;
346        outputs.extend(side_effects_outs);
347
348        let handshake_out = {
349            let ctx = self.connections.get_mut(&peer);
350            if let Some(ctx) = ctx {
351                ctx.fsm.process(event.clone())?
352            } else {
353                None
354            }
355        };
356
357        if let Some(event) = handshake_out {
358            outputs.push(Output::Handshake {
359                peer: peer.clone(),
360                event,
361            });
362        }
363
364        let ctx = self.connections.get(&peer);
365        if let Some(ctx) = ctx {
366            match ctx.fsm.state() {
367                HandshakeState::Connected => {
368                    self.lost_peers.remove(&peer);
369                    outputs.push(Output::PeerConnected { peer: peer.clone() });
370                    for existing in self.connections.keys() {
371                        if !self.is_connected(existing) || *existing == peer {
372                            continue;
373                        }
374                        outputs.push(Output::SendMessage {
375                            peer_to: existing.clone(),
376                            data: MsgPayload::RelaySignalingFrom {
377                                src: peer.clone(),
378                                data: RelayPayload::InitConnect(peer.clone()),
379                            },
380                        });
381                        outputs.push(Output::SendMessage {
382                            peer_to: peer.clone(),
383                            data: MsgPayload::RelaySignalingFrom {
384                                src: existing.clone(),
385                                data: RelayPayload::InitConnect(existing.clone()),
386                            },
387                        });
388                    }
389                }
390                HandshakeState::Closed => {
391                    self.connections.remove(&peer);
392                    self.lost_peers.entry(peer.clone()).or_default();
393                    outputs.push(Output::PeerLost { peer: peer.clone() });
394                    outputs.push(Output::ScheduleTimer {
395                        kind: Scheduled::ReconnectAttempt { peer: peer.clone() },
396                        after_ms: RECONNECT_INTERVAL_MS,
397                    });
398
399                    let orphans: Vec<PeerID> = self
400                        .connections
401                        .iter()
402                        .filter(|(_, c)| {
403                            matches!(&c.mode, HandshakeMode::Relay(via) if via == &peer)
404                                && *c.fsm.state() != HandshakeState::Connected
405                        })
406                        .map(|(id, _)| id.clone())
407                        .collect();
408                    for orphan in orphans {
409                        self.connections.remove(&orphan);
410                        self.lost_peers.entry(orphan.clone()).or_default();
411                        outputs.push(Output::ScheduleTimer {
412                            kind: Scheduled::ReconnectAttempt { peer: orphan },
413                            after_ms: RECONNECT_INTERVAL_MS,
414                        });
415                    }
416                }
417                _ => {}
418            }
419        }
420
421        Ok(outputs)
422    }
423
424    fn handle_side_effects<Msg: UserMsgPayload>(
425        &mut self,
426        peer: &PeerID,
427        event: &HandshakeInput,
428    ) -> Result<Vec<Output<Msg>>> {
429        let ctx = self.connections.get(peer).unwrap();
430        let mut outputs: Vec<Output<Msg>> = vec![];
431        match &event {
432            HandshakeInput::Offer(payload) | HandshakeInput::Answer(payload) => {
433                payload.get_sdp_verified(peer)?;
434            }
435            HandshakeInput::AnswerCreated(answer) => {
436                let answer = SignalingPayload {
437                    token: self.identity.create_token(answer)?,
438                    pubkey: self.identity.pubkey(),
439                };
440                match &ctx.mode {
441                    HandshakeMode::Bootstrap => outputs.push(Output::AnswerReady(answer)),
442                    HandshakeMode::Relay(via) => {
443                        outputs.push(Output::SendMessage {
444                            peer_to: via.clone(),
445                            data: MsgPayload::RelaySignalingTo {
446                                dst: peer.clone(),
447                                data: RelayPayload::Answer(answer),
448                            },
449                        });
450                    }
451                }
452            }
453            HandshakeInput::OfferCreated(offer) => {
454                let offer = SignalingPayload {
455                    token: self.identity.create_token(offer)?,
456                    pubkey: self.identity.pubkey(),
457                };
458                match &ctx.mode {
459                    HandshakeMode::Bootstrap => outputs.push(Output::OfferReady(offer)),
460                    HandshakeMode::Relay(via) => {
461                        outputs.push(Output::SendMessage {
462                            peer_to: via.clone(),
463                            data: MsgPayload::RelaySignalingTo {
464                                dst: peer.clone(),
465                                data: RelayPayload::Offer(offer),
466                            },
467                        });
468                    }
469                }
470            }
471            _ => {}
472        }
473        Ok(outputs)
474    }
475
476    pub(crate) fn handle_message<Msg: UserMsgPayload>(
477        &mut self,
478        peer: PeerID,
479        msg: MsgPayload<Msg>,
480    ) -> Result<Vec<Output<Msg>>> {
481        if !self.channel_open_for_msg(&peer, &msg) {
482            return Ok(vec![]);
483        }
484
485        match msg {
486            MsgPayload::RelaySignalingTo { dst, data } => {
487                self.handle_relay_signaling_to(peer, dst, data)
488            }
489            MsgPayload::RelaySignalingFrom { src, data } => {
490                self.handle_relay_signaling_from(peer, src, data)
491            }
492            MsgPayload::User(_) => Ok(vec![Output::ReceiveMessage {
493                peer_from: peer,
494                data: msg,
495            }]),
496            MsgPayload::Disconnect => self.handle_peer_leaving(peer),
497        }
498    }
499
500    fn handle_relay_signaling_to<Msg: UserMsgPayload>(
501        &mut self,
502        src: PeerID,
503        dst: PeerID,
504        data: RelayPayload,
505    ) -> Result<Vec<Output<Msg>>> {
506        Ok(vec![Output::SendMessage {
507            peer_to: dst,
508            data: MsgPayload::RelaySignalingFrom { src, data },
509        }])
510    }
511
512    fn handle_timer_fired<Msg: UserMsgPayload>(
513        &mut self,
514        kind: Scheduled,
515    ) -> Result<Vec<Output<Msg>>> {
516        match kind {
517            Scheduled::ReconnectAttempt { peer } => self.handle_reconnect_attempt(peer),
518        }
519    }
520
521    fn handle_reconnect_attempt<Msg: UserMsgPayload>(
522        &mut self,
523        peer: PeerID,
524    ) -> Result<Vec<Output<Msg>>> {
525        if !self.lost_peers.contains_key(&peer) {
526            return Ok(vec![]);
527        }
528        if self.is_connected(&peer) {
529            self.lost_peers.remove(&peer);
530            return Ok(vec![]);
531        }
532        if self.connections.contains_key(&peer) {
533            return Ok(vec![]);
534        }
535
536        let attempts = {
537            let state = self.lost_peers.get_mut(&peer).unwrap();
538            state.attempts += 1;
539            state.attempts
540        };
541
542        let mut outputs: Vec<Output<Msg>> = vec![];
543
544        if attempts > MAX_RECONNECT_ATTEMPTS {
545            self.lost_peers.remove(&peer);
546            return Ok(outputs);
547        }
548
549        let relay_peer = match self.connected_peers().into_iter().min() {
550            Some(peer) => peer,
551            None => {
552                self.lost_peers.remove(&peer);
553                return Ok(outputs);
554            }
555        };
556
557        let i_am_host = self.id < peer;
558        outputs.push(Output::SendMessage {
559            peer_to: relay_peer.clone(),
560            data: MsgPayload::RelaySignalingTo {
561                dst: peer.clone(),
562                data: RelayPayload::InitConnect(self.id.clone()),
563            },
564        });
565
566        if i_am_host {
567            let init_outs = self.process::<Msg>(Input::InitHandshake {
568                with: peer.clone(),
569                mode: HandshakeMode::Relay(relay_peer),
570                strategy: HandshakeStrategy::Host,
571            })?;
572            outputs.extend(init_outs);
573            let step_outs = self.process::<Msg>(Input::Handshake {
574                from: peer.clone(),
575                event: HandshakeInput::Init,
576            })?;
577            outputs.extend(step_outs);
578        }
579
580        outputs.push(Output::ScheduleTimer {
581            kind: Scheduled::ReconnectAttempt { peer },
582            after_ms: RECONNECT_INTERVAL_MS,
583        });
584
585        Ok(outputs)
586    }
587
588    fn handle_relay_signaling_from<Msg: UserMsgPayload>(
589        &mut self,
590        via: PeerID,
591        src: PeerID,
592        data: RelayPayload,
593    ) -> Result<Vec<Output<Msg>>> {
594        match data {
595            RelayPayload::InitConnect(_) => {
596                if self.connections.contains_key(&src) {
597                    return Ok(vec![]);
598                }
599                let strategy = if self.id < src {
600                    HandshakeStrategy::Host
601                } else {
602                    HandshakeStrategy::Joiner
603                };
604                self.process::<Msg>(Input::InitHandshake {
605                    with: src.clone(),
606                    mode: HandshakeMode::Relay(via),
607                    strategy: strategy.clone(),
608                })?;
609                match strategy {
610                    HandshakeStrategy::Host => self.process::<Msg>(Input::Handshake {
611                        from: src,
612                        event: HandshakeInput::Init,
613                    }),
614                    HandshakeStrategy::Joiner => Ok(vec![]),
615                }
616            }
617            RelayPayload::Offer(offer) => self.process::<Msg>(Input::Handshake {
618                from: src,
619                event: HandshakeInput::Offer(offer),
620            }),
621            RelayPayload::Answer(answer) => self.process::<Msg>(Input::Handshake {
622                from: src,
623                event: HandshakeInput::Answer(answer),
624            }),
625        }
626    }
627}