1use crate::{
2 HandshakeFSM, HandshakeInput, HandshakeMode, HandshakeState, HandshakeStrategy, Identity,
3 Input, MAX_RECONNECT_ATTEMPTS, MsgPayload, Output, PeerID, RECONNECT_INTERVAL_MS, RelayPayload,
4 Scheduled, SignalingPayload, UserMsgPayload,
5};
6use anyhow::{Result, anyhow};
7use std::collections::{HashMap, HashSet, VecDeque};
8
/// Reconnection bookkeeping kept for a peer whose connection was lost.
#[derive(Clone, Debug, Default)]
struct DroppedPeerState {
    /// How many reconnect attempts have been started so far (zero when first recorded).
    attempts: u32,
}
13
14pub struct HandshakeContext {
15 pub fsm: HandshakeFSM,
16 pub mode: HandshakeMode,
17}
18
/// Node-level connectivity state reported by [`MeshNodeFSM::state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FSMState {
    /// No peer has a completed handshake.
    Init,
    /// At least one peer is connected, but a relayed handshake is still in progress.
    Connected,
    /// At least one peer is connected and no relayed handshake is pending.
    Available,
    /// The node has left the mesh; this state is never recomputed away.
    Left,
}
33
34pub struct MeshNodeFSM {
41 id: PeerID,
43
44 identity: Identity,
46
47 connections: HashMap<PeerID, HandshakeContext>,
49
50 pending_handshakes: VecDeque<HandshakeContext>,
52
53 lost_peers: HashMap<PeerID, DroppedPeerState>,
56
57 state: FSMState,
60}
61
62impl Default for MeshNodeFSM {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl MeshNodeFSM {
69 pub fn new() -> Self {
70 Self::with_identity(Identity::new())
71 }
72
73 pub fn with_identity(identity: Identity) -> Self {
74 Self {
75 id: identity.peer_id(),
76 identity,
77 connections: HashMap::new(),
78 pending_handshakes: VecDeque::new(),
79 lost_peers: HashMap::new(),
80 state: FSMState::Init,
81 }
82 }
83
84 pub fn state(&self) -> FSMState {
85 self.state
86 }
87
88 fn compute_state(&self) -> FSMState {
89 if self.state == FSMState::Left {
90 return FSMState::Left;
91 }
92 if self.connected_peers().is_empty() {
93 return FSMState::Init;
94 }
95 let in_progress_relays = self.connections.values().any(|ctx| {
96 matches!(ctx.mode, HandshakeMode::Relay(_))
97 && *ctx.fsm.state() != HandshakeState::Connected
98 });
99 if in_progress_relays {
100 FSMState::Connected
101 } else {
102 FSMState::Available
103 }
104 }
105
106 fn state_transition_outputs<Msg: UserMsgPayload>(
108 prev: FSMState,
109 new: FSMState,
110 ) -> Vec<Output<Msg>> {
111 match (prev, new) {
112 (FSMState::Init, FSMState::Connected) => vec![Output::Connected],
113 (FSMState::Init, FSMState::Available) => {
114 vec![Output::Connected, Output::Available]
115 }
116 (FSMState::Connected, FSMState::Available) => vec![Output::Available],
117 (FSMState::Available, FSMState::Connected) => vec![Output::Unavailable],
118 (FSMState::Connected | FSMState::Available, FSMState::Init) => {
119 vec![Output::Unavailable]
120 }
121 (_, FSMState::Left) => vec![Output::Disconnecting],
122 _ => vec![],
123 }
124 }
125
126 pub fn id(&self) -> &PeerID {
127 &self.id
128 }
129
130 pub fn is_connected(&self, peer: &PeerID) -> bool {
131 self.connections.contains_key(peer)
132 && *self.connections.get(peer).unwrap().fsm.state() == HandshakeState::Connected
133 }
134
135 pub fn channel_open_for_msg<Msg: UserMsgPayload>(
137 &self,
138 peer: &PeerID,
139 msg: &MsgPayload<Msg>,
140 ) -> bool {
141 match msg {
142 MsgPayload::RelaySignalingTo { .. } | MsgPayload::RelaySignalingFrom { .. } => {
143 matches!(
144 self.connections.get(peer).map(|c| c.fsm.state()),
145 Some(HandshakeState::Connected | HandshakeState::WaitingForDataChannel)
146 )
147 }
148 MsgPayload::User(_) | MsgPayload::Disconnect => self.is_connected(peer),
149 }
150 }
151
152 pub fn connected_peers(&self) -> HashSet<PeerID> {
153 self.connections
154 .iter()
155 .filter(|x| *x.1.fsm.state() == HandshakeState::Connected)
156 .map(|x| x.0.clone())
157 .collect()
158 }
159
160 pub fn connected_number(&self) -> usize {
161 self.connections.iter().fold(0, |a, x| {
162 if *x.1.fsm.state() == HandshakeState::Connected {
163 a + 1
164 } else {
165 a
166 }
167 })
168 }
169
170 pub fn handle_init_handshake<Msg: UserMsgPayload>(
171 &mut self,
172 with: PeerID,
173 mode: HandshakeMode,
174 strategy: HandshakeStrategy,
175 ) -> Result<Vec<Output<Msg>>> {
176 self.connections.insert(
177 with,
178 HandshakeContext {
179 fsm: HandshakeFSM::new(strategy),
180 mode,
181 },
182 );
183 Ok(vec![])
184 }
185
186 pub fn handle_init_open_offer<Msg: UserMsgPayload>(&mut self) -> Result<Vec<Output<Msg>>> {
187 let mut ctx = HandshakeContext {
188 fsm: HandshakeFSM::new(HandshakeStrategy::Host),
189 mode: HandshakeMode::Bootstrap,
190 };
191 ctx.fsm.process(HandshakeInput::Init)?;
192 self.pending_handshakes.push_back(ctx);
193 Ok(vec![Output::InitOpenOffer])
194 }
195
196 pub fn handle_open_offer_created<Msg: UserMsgPayload>(
197 &mut self,
198 sdp: String,
199 ) -> Result<Vec<Output<Msg>>> {
200 let offer = SignalingPayload {
201 token: self.identity.create_token(&sdp)?,
202 pubkey: self.identity.pubkey(),
203 };
204 self.pending_handshakes
205 .back_mut()
206 .ok_or_else(|| anyhow!("No pending open offer"))?
207 .fsm
208 .process(HandshakeInput::OfferCreated(sdp))?;
209 Ok(vec![Output::OfferReady(offer)])
210 }
211
212 pub fn handle_send<Msg: UserMsgPayload>(
213 &mut self,
214 peer_to: PeerID,
215 data: MsgPayload<Msg>,
216 ) -> Result<Vec<Output<Msg>>> {
217 if self.is_connected(&peer_to) {
218 Ok(vec![Output::SendMessage { peer_to, data }])
219 } else {
220 Ok(vec![])
221 }
222 }
223
224 pub fn handle_broadcast<Msg: UserMsgPayload>(
225 &mut self,
226
227 data: MsgPayload<Msg>,
228 ) -> Result<Vec<Output<Msg>>> {
229 let mut out = vec![];
230 for peer in self.connections.keys() {
231 if !self.is_connected(peer) {
232 continue;
233 }
234 out.push(Output::SendMessage {
235 peer_to: peer.clone(),
236 data: data.clone(),
237 })
238 }
239 Ok(out)
240 }
241
242 pub fn process<Msg: UserMsgPayload>(&mut self, input: Input<Msg>) -> Result<Vec<Output<Msg>>> {
243 let prev_state = self.state;
244 let mut outputs = self.dispatch(input)?;
245 let new_state = self.compute_state();
246 if prev_state != new_state {
247 self.state = new_state;
248 outputs.extend(Self::state_transition_outputs::<Msg>(prev_state, new_state));
249 }
250 Ok(outputs)
251 }
252
253 fn dispatch<Msg: UserMsgPayload>(&mut self, input: Input<Msg>) -> Result<Vec<Output<Msg>>> {
254 match input {
255 Input::InitHandshake {
256 with,
257 mode,
258 strategy,
259 } => self.handle_init_handshake(with, mode, strategy),
260 Input::InitOpenOffer => self.handle_init_open_offer(),
261 Input::OpenOfferCreated(sdp) => self.handle_open_offer_created(sdp),
262 Input::Handshake { from, event } => self.handle_handshake(from, event),
263 Input::PeerLeaving { peer } => self.handle_peer_leaving(peer),
264 Input::MessageReceived { peer_from, data } => self.handle_message(peer_from, data),
265 Input::Send { peer_to, data } => self.handle_send(peer_to, data),
266 Input::Broadcast { data } => self.handle_broadcast(data),
267 Input::Leave => self.handle_leave(),
268 Input::TimerFired { kind } => self.handle_timer_fired(kind),
269 }
270 }
271
272 pub fn identity(&self) -> &Identity {
273 &self.identity
274 }
275
276 pub fn connections_snapshot(&self) -> Vec<(PeerID, HandshakeState, HandshakeMode)> {
278 self.connections
279 .iter()
280 .map(|(peer, ctx)| (peer.clone(), ctx.fsm.state().clone(), ctx.mode.clone()))
281 .collect()
282 }
283
284 pub fn pending_handshakes_len(&self) -> usize {
286 self.pending_handshakes.len()
287 }
288
289 fn handle_leave<Msg: UserMsgPayload>(&mut self) -> Result<Vec<Output<Msg>>> {
290 if self.state == FSMState::Left {
291 return Ok(vec![]);
292 }
293 let mut out = vec![];
294 for peer in self.connections.keys() {
295 if self.is_connected(peer) {
296 out.push(Output::SendMessage {
297 peer_to: peer.clone(),
298 data: MsgPayload::Disconnect,
299 });
300 }
301 }
302 self.connections.clear();
303 self.pending_handshakes.clear();
304 self.lost_peers.clear();
305 self.state = FSMState::Left;
306 Ok(out)
307 }
308
309 fn handle_peer_leaving<Msg: UserMsgPayload>(
310 &mut self,
311 peer: PeerID,
312 ) -> Result<Vec<Output<Msg>>> {
313 let was_connected = self.connections.remove(&peer);
314
315 let mut out = Vec::new();
316 if was_connected.is_some() {
317 out.push(Output::PeerDisconnected { peer });
318 }
319 Ok(out)
320 }
321
322 pub(crate) fn handle_handshake<Msg: UserMsgPayload>(
323 &mut self,
324 peer: PeerID,
325 event: HandshakeInput,
326 ) -> Result<Vec<Output<Msg>>> {
327 let mut outputs: Vec<Output<Msg>> = vec![];
328
329 if !self.connections.contains_key(&peer) {
330 match &event {
331 HandshakeInput::Answer(_) => {
332 let ctx = self
333 .pending_handshakes
334 .pop_front()
335 .ok_or_else(|| anyhow!("Pending handshake not found"))?;
336 self.connections.insert(peer.clone(), ctx);
337 }
338 HandshakeInput::ConnectionDropped => {
339 return Ok(outputs);
340 }
341 _ => return Err(anyhow!("Handshake instance with peer not found")),
342 }
343 }
344
345 let side_effects_outs = self.handle_side_effects(&peer, &event)?;
346 outputs.extend(side_effects_outs);
347
348 let handshake_out = {
349 let ctx = self.connections.get_mut(&peer);
350 if let Some(ctx) = ctx {
351 ctx.fsm.process(event.clone())?
352 } else {
353 None
354 }
355 };
356
357 if let Some(event) = handshake_out {
358 outputs.push(Output::Handshake {
359 peer: peer.clone(),
360 event,
361 });
362 }
363
364 let ctx = self.connections.get(&peer);
365 if let Some(ctx) = ctx {
366 match ctx.fsm.state() {
367 HandshakeState::Connected => {
368 self.lost_peers.remove(&peer);
369 outputs.push(Output::PeerConnected { peer: peer.clone() });
370 for existing in self.connections.keys() {
371 if !self.is_connected(existing) || *existing == peer {
372 continue;
373 }
374 outputs.push(Output::SendMessage {
375 peer_to: existing.clone(),
376 data: MsgPayload::RelaySignalingFrom {
377 src: peer.clone(),
378 data: RelayPayload::InitConnect(peer.clone()),
379 },
380 });
381 outputs.push(Output::SendMessage {
382 peer_to: peer.clone(),
383 data: MsgPayload::RelaySignalingFrom {
384 src: existing.clone(),
385 data: RelayPayload::InitConnect(existing.clone()),
386 },
387 });
388 }
389 }
390 HandshakeState::Closed => {
391 self.connections.remove(&peer);
392 self.lost_peers.entry(peer.clone()).or_default();
393 outputs.push(Output::PeerLost { peer: peer.clone() });
394 outputs.push(Output::ScheduleTimer {
395 kind: Scheduled::ReconnectAttempt { peer: peer.clone() },
396 after_ms: RECONNECT_INTERVAL_MS,
397 });
398
399 let orphans: Vec<PeerID> = self
400 .connections
401 .iter()
402 .filter(|(_, c)| {
403 matches!(&c.mode, HandshakeMode::Relay(via) if via == &peer)
404 && *c.fsm.state() != HandshakeState::Connected
405 })
406 .map(|(id, _)| id.clone())
407 .collect();
408 for orphan in orphans {
409 self.connections.remove(&orphan);
410 self.lost_peers.entry(orphan.clone()).or_default();
411 outputs.push(Output::ScheduleTimer {
412 kind: Scheduled::ReconnectAttempt { peer: orphan },
413 after_ms: RECONNECT_INTERVAL_MS,
414 });
415 }
416 }
417 _ => {}
418 }
419 }
420
421 Ok(outputs)
422 }
423
424 fn handle_side_effects<Msg: UserMsgPayload>(
425 &mut self,
426 peer: &PeerID,
427 event: &HandshakeInput,
428 ) -> Result<Vec<Output<Msg>>> {
429 let ctx = self.connections.get(peer).unwrap();
430 let mut outputs: Vec<Output<Msg>> = vec![];
431 match &event {
432 HandshakeInput::Offer(payload) | HandshakeInput::Answer(payload) => {
433 payload.get_sdp_verified(peer)?;
434 }
435 HandshakeInput::AnswerCreated(answer) => {
436 let answer = SignalingPayload {
437 token: self.identity.create_token(answer)?,
438 pubkey: self.identity.pubkey(),
439 };
440 match &ctx.mode {
441 HandshakeMode::Bootstrap => outputs.push(Output::AnswerReady(answer)),
442 HandshakeMode::Relay(via) => {
443 outputs.push(Output::SendMessage {
444 peer_to: via.clone(),
445 data: MsgPayload::RelaySignalingTo {
446 dst: peer.clone(),
447 data: RelayPayload::Answer(answer),
448 },
449 });
450 }
451 }
452 }
453 HandshakeInput::OfferCreated(offer) => {
454 let offer = SignalingPayload {
455 token: self.identity.create_token(offer)?,
456 pubkey: self.identity.pubkey(),
457 };
458 match &ctx.mode {
459 HandshakeMode::Bootstrap => outputs.push(Output::OfferReady(offer)),
460 HandshakeMode::Relay(via) => {
461 outputs.push(Output::SendMessage {
462 peer_to: via.clone(),
463 data: MsgPayload::RelaySignalingTo {
464 dst: peer.clone(),
465 data: RelayPayload::Offer(offer),
466 },
467 });
468 }
469 }
470 }
471 _ => {}
472 }
473 Ok(outputs)
474 }
475
476 pub(crate) fn handle_message<Msg: UserMsgPayload>(
477 &mut self,
478 peer: PeerID,
479 msg: MsgPayload<Msg>,
480 ) -> Result<Vec<Output<Msg>>> {
481 if !self.channel_open_for_msg(&peer, &msg) {
482 return Ok(vec![]);
483 }
484
485 match msg {
486 MsgPayload::RelaySignalingTo { dst, data } => {
487 self.handle_relay_signaling_to(peer, dst, data)
488 }
489 MsgPayload::RelaySignalingFrom { src, data } => {
490 self.handle_relay_signaling_from(peer, src, data)
491 }
492 MsgPayload::User(_) => Ok(vec![Output::ReceiveMessage {
493 peer_from: peer,
494 data: msg,
495 }]),
496 MsgPayload::Disconnect => self.handle_peer_leaving(peer),
497 }
498 }
499
500 fn handle_relay_signaling_to<Msg: UserMsgPayload>(
501 &mut self,
502 src: PeerID,
503 dst: PeerID,
504 data: RelayPayload,
505 ) -> Result<Vec<Output<Msg>>> {
506 Ok(vec![Output::SendMessage {
507 peer_to: dst,
508 data: MsgPayload::RelaySignalingFrom { src, data },
509 }])
510 }
511
512 fn handle_timer_fired<Msg: UserMsgPayload>(
513 &mut self,
514 kind: Scheduled,
515 ) -> Result<Vec<Output<Msg>>> {
516 match kind {
517 Scheduled::ReconnectAttempt { peer } => self.handle_reconnect_attempt(peer),
518 }
519 }
520
521 fn handle_reconnect_attempt<Msg: UserMsgPayload>(
522 &mut self,
523 peer: PeerID,
524 ) -> Result<Vec<Output<Msg>>> {
525 if !self.lost_peers.contains_key(&peer) {
526 return Ok(vec![]);
527 }
528 if self.is_connected(&peer) {
529 self.lost_peers.remove(&peer);
530 return Ok(vec![]);
531 }
532 if self.connections.contains_key(&peer) {
533 return Ok(vec![]);
534 }
535
536 let attempts = {
537 let state = self.lost_peers.get_mut(&peer).unwrap();
538 state.attempts += 1;
539 state.attempts
540 };
541
542 let mut outputs: Vec<Output<Msg>> = vec![];
543
544 if attempts > MAX_RECONNECT_ATTEMPTS {
545 self.lost_peers.remove(&peer);
546 return Ok(outputs);
547 }
548
549 let relay_peer = match self.connected_peers().into_iter().min() {
550 Some(peer) => peer,
551 None => {
552 self.lost_peers.remove(&peer);
553 return Ok(outputs);
554 }
555 };
556
557 let i_am_host = self.id < peer;
558 outputs.push(Output::SendMessage {
559 peer_to: relay_peer.clone(),
560 data: MsgPayload::RelaySignalingTo {
561 dst: peer.clone(),
562 data: RelayPayload::InitConnect(self.id.clone()),
563 },
564 });
565
566 if i_am_host {
567 let init_outs = self.process::<Msg>(Input::InitHandshake {
568 with: peer.clone(),
569 mode: HandshakeMode::Relay(relay_peer),
570 strategy: HandshakeStrategy::Host,
571 })?;
572 outputs.extend(init_outs);
573 let step_outs = self.process::<Msg>(Input::Handshake {
574 from: peer.clone(),
575 event: HandshakeInput::Init,
576 })?;
577 outputs.extend(step_outs);
578 }
579
580 outputs.push(Output::ScheduleTimer {
581 kind: Scheduled::ReconnectAttempt { peer },
582 after_ms: RECONNECT_INTERVAL_MS,
583 });
584
585 Ok(outputs)
586 }
587
588 fn handle_relay_signaling_from<Msg: UserMsgPayload>(
589 &mut self,
590 via: PeerID,
591 src: PeerID,
592 data: RelayPayload,
593 ) -> Result<Vec<Output<Msg>>> {
594 match data {
595 RelayPayload::InitConnect(_) => {
596 if self.connections.contains_key(&src) {
597 return Ok(vec![]);
598 }
599 let strategy = if self.id < src {
600 HandshakeStrategy::Host
601 } else {
602 HandshakeStrategy::Joiner
603 };
604 self.process::<Msg>(Input::InitHandshake {
605 with: src.clone(),
606 mode: HandshakeMode::Relay(via),
607 strategy: strategy.clone(),
608 })?;
609 match strategy {
610 HandshakeStrategy::Host => self.process::<Msg>(Input::Handshake {
611 from: src,
612 event: HandshakeInput::Init,
613 }),
614 HandshakeStrategy::Joiner => Ok(vec![]),
615 }
616 }
617 RelayPayload::Offer(offer) => self.process::<Msg>(Input::Handshake {
618 from: src,
619 event: HandshakeInput::Offer(offer),
620 }),
621 RelayPayload::Answer(answer) => self.process::<Msg>(Input::Handshake {
622 from: src,
623 event: HandshakeInput::Answer(answer),
624 }),
625 }
626 }
627}