1use crate::{
2 HandshakeFSM, HandshakeInput, HandshakeMode, HandshakeState, HandshakeStrategy, Identity,
3 Input, MAX_RECONNECT_ATTEMPTS, MsgPayload, Output, PeerID, RECONNECT_INTERVAL_MS, RelayPayload,
4 Scheduled, SignalingPayload, UserMsgPayload,
5};
6use anyhow::{Result, anyhow};
7use std::collections::{HashMap, HashSet, VecDeque};
8
/// Reconnection bookkeeping for a peer whose connection was lost.
#[derive(Clone, Debug, Default)]
struct DroppedPeerState {
    /// Number of reconnection attempts made so far (starts at 0).
    attempts: u32,
}
13
14pub struct HandshakeContext {
15 pub fsm: HandshakeFSM,
16 pub mode: HandshakeMode,
17}
18
19pub struct MeshNodeFSM {
22 id: PeerID,
24
25 identity: Identity,
27
28 connections: HashMap<PeerID, HandshakeContext>,
30
31 pending_handshakes: VecDeque<HandshakeContext>,
33
34 lost_peers: HashMap<PeerID, DroppedPeerState>,
37}
38
39impl Default for MeshNodeFSM {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl MeshNodeFSM {
46 pub fn new() -> Self {
47 Self::with_identity(Identity::new())
48 }
49
50 pub fn with_identity(identity: Identity) -> Self {
51 Self {
52 id: identity.peer_id(),
53 identity,
54 connections: HashMap::new(),
55 pending_handshakes: VecDeque::new(),
56 lost_peers: HashMap::new(),
57 }
58 }
59
60 pub fn id(&self) -> &PeerID {
61 &self.id
62 }
63
64 pub fn is_connected(&self, peer: &PeerID) -> bool {
65 self.connections.contains_key(peer)
66 && *self.connections.get(peer).unwrap().fsm.state() == HandshakeState::Connected
67 }
68
69 pub fn channel_open_for_msg<Msg: UserMsgPayload>(
71 &self,
72 peer: &PeerID,
73 msg: &MsgPayload<Msg>,
74 ) -> bool {
75 match msg {
76 MsgPayload::RelaySignalingTo { .. } | MsgPayload::RelaySignalingFrom { .. } => {
77 matches!(
78 self.connections.get(peer).map(|c| c.fsm.state()),
79 Some(HandshakeState::Connected | HandshakeState::WaitingForDataChannel)
80 )
81 }
82 MsgPayload::User(_) | MsgPayload::Disconnect => self.is_connected(peer),
83 }
84 }
85
86 pub fn connected_peers(&self) -> HashSet<PeerID> {
87 self.connections
88 .iter()
89 .filter(|x| *x.1.fsm.state() == HandshakeState::Connected)
90 .map(|x| x.0.clone())
91 .collect()
92 }
93
94 pub fn connected_number(&self) -> usize {
95 self.connections.iter().fold(0, |a, x| {
96 if *x.1.fsm.state() == HandshakeState::Connected {
97 a + 1
98 } else {
99 a
100 }
101 })
102 }
103
104 pub fn handle_init_handshake<Msg: UserMsgPayload>(
105 &mut self,
106 with: PeerID,
107 mode: HandshakeMode,
108 strategy: HandshakeStrategy,
109 ) -> Result<Vec<Output<Msg>>> {
110 self.connections.insert(
111 with,
112 HandshakeContext {
113 fsm: HandshakeFSM::new(strategy),
114 mode,
115 },
116 );
117 Ok(vec![])
118 }
119
120 pub fn handle_init_open_offer<Msg: UserMsgPayload>(&mut self) -> Result<Vec<Output<Msg>>> {
121 let mut ctx = HandshakeContext {
122 fsm: HandshakeFSM::new(HandshakeStrategy::Host),
123 mode: HandshakeMode::Bootstrap,
124 };
125 ctx.fsm.process(HandshakeInput::Init)?;
126 self.pending_handshakes.push_back(ctx);
127 Ok(vec![Output::InitOpenOffer])
128 }
129
130 pub fn handle_open_offer_created<Msg: UserMsgPayload>(
131 &mut self,
132 sdp: String,
133 ) -> Result<Vec<Output<Msg>>> {
134 let offer = SignalingPayload {
135 token: self.identity.create_token(&sdp)?,
136 pubkey: self.identity.pubkey(),
137 };
138 self.pending_handshakes
139 .back_mut()
140 .ok_or_else(|| anyhow!("No pending open offer"))?
141 .fsm
142 .process(HandshakeInput::OfferCreated(sdp))?;
143 Ok(vec![Output::OfferReady(offer)])
144 }
145
146 pub fn handle_send<Msg: UserMsgPayload>(
147 &mut self,
148 peer_to: PeerID,
149 data: MsgPayload<Msg>,
150 ) -> Result<Vec<Output<Msg>>> {
151 if self.is_connected(&peer_to) {
152 Ok(vec![Output::SendMessage { peer_to, data }])
153 } else {
154 Ok(vec![])
155 }
156 }
157
158 pub fn handle_broadcast<Msg: UserMsgPayload>(
159 &mut self,
160
161 data: MsgPayload<Msg>,
162 ) -> Result<Vec<Output<Msg>>> {
163 let mut out = vec![];
164 for peer in self.connections.keys() {
165 if !self.is_connected(peer) {
166 continue;
167 }
168 out.push(Output::SendMessage {
169 peer_to: peer.clone(),
170 data: data.clone(),
171 })
172 }
173 Ok(out)
174 }
175
176 pub fn process<Msg: UserMsgPayload>(&mut self, input: Input<Msg>) -> Result<Vec<Output<Msg>>> {
177 match input {
178 Input::InitHandshake {
179 with,
180 mode,
181 strategy,
182 } => self.handle_init_handshake(with, mode, strategy),
183 Input::InitOpenOffer => self.handle_init_open_offer(),
184 Input::OpenOfferCreated(sdp) => self.handle_open_offer_created(sdp),
185 Input::Handshake { from, event } => self.handle_handshake(from, event),
186 Input::PeerLeaving { peer } => self.handle_peer_leaving(peer),
187 Input::MessageReceived { peer_from, data } => self.handle_message(peer_from, data),
188 Input::Send { peer_to, data } => self.handle_send(peer_to, data),
189 Input::Broadcast { data } => self.handle_broadcast(data),
190 Input::Leave => self.handle_leave(),
191 Input::TimerFired { kind } => self.handle_timer_fired(kind),
192 }
193 }
194
195 pub fn identity(&self) -> &Identity {
196 &self.identity
197 }
198
199 pub fn connections_snapshot(&self) -> Vec<(PeerID, HandshakeState, HandshakeMode)> {
201 self.connections
202 .iter()
203 .map(|(peer, ctx)| (peer.clone(), ctx.fsm.state().clone(), ctx.mode.clone()))
204 .collect()
205 }
206
207 pub fn pending_handshakes_len(&self) -> usize {
209 self.pending_handshakes.len()
210 }
211
212 fn handle_leave<Msg: UserMsgPayload>(&mut self) -> Result<Vec<Output<Msg>>> {
213 let mut out = vec![Output::Disconnecting];
214 for peer in self.connections.keys() {
215 if self.is_connected(peer) {
216 out.push(Output::SendMessage {
217 peer_to: peer.clone(),
218 data: MsgPayload::Disconnect,
219 });
220 }
221 }
222 self.connections.clear();
223 self.pending_handshakes.clear();
224 self.lost_peers.clear();
225 Ok(out)
226 }
227
228 fn handle_peer_leaving<Msg: UserMsgPayload>(
229 &mut self,
230 peer: PeerID,
231 ) -> Result<Vec<Output<Msg>>> {
232 let was_connected = self.connections.remove(&peer);
233
234 let mut out = Vec::new();
235 if was_connected.is_some() {
236 out.push(Output::PeerDisconnected { peer });
237 if self
238 .connections
239 .values()
240 .any(|ctx| *ctx.fsm.state() != HandshakeState::Connected)
241 {
242 out.push(Output::Unavailable);
243 }
244 }
245 Ok(out)
246 }
247
248 pub(crate) fn handle_handshake<Msg: UserMsgPayload>(
249 &mut self,
250 peer: PeerID,
251 event: HandshakeInput,
252 ) -> Result<Vec<Output<Msg>>> {
253 let mut outputs: Vec<Output<Msg>> = vec![];
254
255 if !self.connections.contains_key(&peer) {
256 match &event {
257 HandshakeInput::Answer(_) => {
258 let ctx = self
259 .pending_handshakes
260 .pop_front()
261 .ok_or_else(|| anyhow!("Pending handshake not found"))?;
262 self.connections.insert(peer.clone(), ctx);
263 }
264 HandshakeInput::ConnectionDropped => {
265 return Ok(outputs);
266 }
267 _ => return Err(anyhow!("Handshake instance with peer not found")),
268 }
269 }
270
271 let side_effects_outs = self.handle_side_effects(&peer, &event)?;
272 outputs.extend(side_effects_outs);
273
274 let handshake_out = {
275 let ctx = self.connections.get_mut(&peer);
276 if let Some(ctx) = ctx {
277 ctx.fsm.process(event.clone())?
278 } else {
279 None
280 }
281 };
282
283 if let Some(event) = handshake_out {
284 outputs.push(Output::Handshake {
285 peer: peer.clone(),
286 event,
287 });
288 }
289
290 let ctx = self.connections.get(&peer);
291 if let Some(ctx) = ctx {
292 match ctx.fsm.state() {
293 HandshakeState::Connected => {
294 self.lost_peers.remove(&peer);
295 outputs.push(Output::PeerConnected { peer: peer.clone() });
296 for existing in self.connections.keys() {
297 if !self.is_connected(existing) || *existing == peer {
298 continue;
299 }
300 outputs.push(Output::SendMessage {
301 peer_to: existing.clone(),
302 data: MsgPayload::RelaySignalingFrom {
303 src: peer.clone(),
304 data: RelayPayload::InitConnect(peer.clone()),
305 },
306 });
307 outputs.push(Output::SendMessage {
308 peer_to: peer.clone(),
309 data: MsgPayload::RelaySignalingFrom {
310 src: existing.clone(),
311 data: RelayPayload::InitConnect(existing.clone()),
312 },
313 });
314 }
315 }
316 HandshakeState::Closed => {
317 self.connections.remove(&peer);
318 self.lost_peers.entry(peer.clone()).or_default();
319 outputs.push(Output::PeerLost { peer: peer.clone() });
320 outputs.push(Output::ScheduleTimer {
321 kind: Scheduled::ReconnectAttempt { peer: peer.clone() },
322 after_ms: RECONNECT_INTERVAL_MS,
323 });
324
325 let orphans: Vec<PeerID> = self
326 .connections
327 .iter()
328 .filter(|(_, c)| {
329 matches!(&c.mode, HandshakeMode::Relay(via) if via == &peer)
330 && *c.fsm.state() != HandshakeState::Connected
331 })
332 .map(|(id, _)| id.clone())
333 .collect();
334 for orphan in orphans {
335 self.connections.remove(&orphan);
336 self.lost_peers.entry(orphan.clone()).or_default();
337 outputs.push(Output::ScheduleTimer {
338 kind: Scheduled::ReconnectAttempt { peer: orphan },
339 after_ms: RECONNECT_INTERVAL_MS,
340 });
341 }
342 }
343 _ => {}
344 }
345 }
346
347 let in_progress_relays = self
348 .connections
349 .values()
350 .filter(|ctx| {
351 matches!(ctx.mode, HandshakeMode::Relay(_))
352 && *ctx.fsm.state() != HandshakeState::Connected
353 })
354 .count();
355 if in_progress_relays == 0 && !self.connected_peers().is_empty() {
356 outputs.push(Output::Available);
357 }
358
359 Ok(outputs)
360 }
361
362 fn handle_side_effects<Msg: UserMsgPayload>(
363 &mut self,
364 peer: &PeerID,
365 event: &HandshakeInput,
366 ) -> Result<Vec<Output<Msg>>> {
367 let ctx = self.connections.get(peer).unwrap();
368 let mut outputs: Vec<Output<Msg>> = vec![];
369 match &event {
370 HandshakeInput::Offer(payload) | HandshakeInput::Answer(payload) => {
371 payload.get_sdp_verified(peer)?;
372 }
373 HandshakeInput::AnswerCreated(answer) => {
374 let answer = SignalingPayload {
375 token: self.identity.create_token(answer)?,
376 pubkey: self.identity.pubkey(),
377 };
378 match &ctx.mode {
379 HandshakeMode::Bootstrap => outputs.push(Output::AnswerReady(answer)),
380 HandshakeMode::Relay(via) => {
381 outputs.push(Output::SendMessage {
382 peer_to: via.clone(),
383 data: MsgPayload::RelaySignalingTo {
384 dst: peer.clone(),
385 data: RelayPayload::Answer(answer),
386 },
387 });
388 }
389 }
390 }
391 HandshakeInput::OfferCreated(offer) => {
392 let offer = SignalingPayload {
393 token: self.identity.create_token(offer)?,
394 pubkey: self.identity.pubkey(),
395 };
396 match &ctx.mode {
397 HandshakeMode::Bootstrap => outputs.push(Output::OfferReady(offer)),
398 HandshakeMode::Relay(via) => {
399 outputs.push(Output::SendMessage {
400 peer_to: via.clone(),
401 data: MsgPayload::RelaySignalingTo {
402 dst: peer.clone(),
403 data: RelayPayload::Offer(offer),
404 },
405 });
406 }
407 }
408 }
409 _ => {}
410 }
411 Ok(outputs)
412 }
413
414 pub(crate) fn handle_message<Msg: UserMsgPayload>(
415 &mut self,
416 peer: PeerID,
417 msg: MsgPayload<Msg>,
418 ) -> Result<Vec<Output<Msg>>> {
419 if !self.channel_open_for_msg(&peer, &msg) {
420 return Ok(vec![]);
421 }
422
423 match msg {
424 MsgPayload::RelaySignalingTo { dst, data } => {
425 self.handle_relay_signaling_to(peer, dst, data)
426 }
427 MsgPayload::RelaySignalingFrom { src, data } => {
428 self.handle_relay_signaling_from(peer, src, data)
429 }
430 MsgPayload::User(_) => Ok(vec![Output::ReceiveMessage {
431 peer_from: peer,
432 data: msg,
433 }]),
434 MsgPayload::Disconnect => self.handle_peer_leaving(peer),
435 }
436 }
437
438 fn handle_relay_signaling_to<Msg: UserMsgPayload>(
439 &mut self,
440 src: PeerID,
441 dst: PeerID,
442 data: RelayPayload,
443 ) -> Result<Vec<Output<Msg>>> {
444 Ok(vec![Output::SendMessage {
445 peer_to: dst,
446 data: MsgPayload::RelaySignalingFrom { src, data },
447 }])
448 }
449
450 fn handle_timer_fired<Msg: UserMsgPayload>(
451 &mut self,
452 kind: Scheduled,
453 ) -> Result<Vec<Output<Msg>>> {
454 match kind {
455 Scheduled::ReconnectAttempt { peer } => self.handle_reconnect_attempt(peer),
456 }
457 }
458
459 fn handle_reconnect_attempt<Msg: UserMsgPayload>(
460 &mut self,
461 peer: PeerID,
462 ) -> Result<Vec<Output<Msg>>> {
463 if !self.lost_peers.contains_key(&peer) {
464 return Ok(vec![]);
465 }
466 if self.is_connected(&peer) {
467 self.lost_peers.remove(&peer);
468 return Ok(vec![]);
469 }
470 if self.connections.contains_key(&peer) {
471 return Ok(vec![]);
472 }
473
474 let attempts = {
475 let state = self.lost_peers.get_mut(&peer).unwrap();
476 state.attempts += 1;
477 state.attempts
478 };
479
480 let mut outputs: Vec<Output<Msg>> = vec![];
481
482 if attempts > MAX_RECONNECT_ATTEMPTS {
483 self.lost_peers.remove(&peer);
484 return Ok(outputs);
485 }
486
487 let relay_peer = match self.connected_peers().into_iter().min() {
488 Some(peer) => peer,
489 None => {
490 self.lost_peers.remove(&peer);
491 return Ok(outputs);
492 }
493 };
494
495 let i_am_host = self.id < peer;
496 outputs.push(Output::SendMessage {
497 peer_to: relay_peer.clone(),
498 data: MsgPayload::RelaySignalingTo {
499 dst: peer.clone(),
500 data: RelayPayload::InitConnect(self.id.clone()),
501 },
502 });
503
504 if i_am_host {
505 let init_outs = self.process::<Msg>(Input::InitHandshake {
506 with: peer.clone(),
507 mode: HandshakeMode::Relay(relay_peer),
508 strategy: HandshakeStrategy::Host,
509 })?;
510 outputs.extend(init_outs);
511 let step_outs = self.process::<Msg>(Input::Handshake {
512 from: peer.clone(),
513 event: HandshakeInput::Init,
514 })?;
515 outputs.extend(step_outs);
516 }
517
518 outputs.push(Output::ScheduleTimer {
519 kind: Scheduled::ReconnectAttempt { peer },
520 after_ms: RECONNECT_INTERVAL_MS,
521 });
522
523 Ok(outputs)
524 }
525
526 fn handle_relay_signaling_from<Msg: UserMsgPayload>(
527 &mut self,
528 via: PeerID,
529 src: PeerID,
530 data: RelayPayload,
531 ) -> Result<Vec<Output<Msg>>> {
532 match data {
533 RelayPayload::InitConnect(_) => {
534 if self.connections.contains_key(&src) {
535 return Ok(vec![]);
536 }
537 let strategy = if self.id < src {
538 HandshakeStrategy::Host
539 } else {
540 HandshakeStrategy::Joiner
541 };
542 self.process::<Msg>(Input::InitHandshake {
543 with: src.clone(),
544 mode: HandshakeMode::Relay(via),
545 strategy: strategy.clone(),
546 })?;
547 match strategy {
548 HandshakeStrategy::Host => self.process::<Msg>(Input::Handshake {
549 from: src,
550 event: HandshakeInput::Init,
551 }),
552 HandshakeStrategy::Joiner => Ok(vec![Output::Unavailable]),
553 }
554 }
555 RelayPayload::Offer(offer) => self.process::<Msg>(Input::Handshake {
556 from: src,
557 event: HandshakeInput::Offer(offer),
558 }),
559 RelayPayload::Answer(answer) => self.process::<Msg>(Input::Handshake {
560 from: src,
561 event: HandshakeInput::Answer(answer),
562 }),
563 }
564 }
565}