gwyh/
handshake.rs

1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use dryoc::dryocbox::PublicKey;
9use dryoc::generichash::GenericHash;
10use dryoc::rng::randombytes_buf;
11use dryoc::types::StackByteArray;
12use genserver::GenServer;
13use rand::rngs::StdRng;
14use rand::{Rng, SeedableRng};
15use serde::{Deserialize, Serialize};
16use tokio::time::sleep;
17use tracing::error;
18use uuid::Uuid;
19
20use crate::and_then::AndThen;
21use crate::delayed::Delayed;
22use crate::message::{Body, Message};
23use crate::packet::Payload;
24use crate::packet_handler::PacketHandlerMessage;
25use crate::peer::Peer;
26use crate::peer_manager::{PeerManagerRequest, PeerRequest, PeerState, PeerStatus};
27use crate::registry::Registry;
28
/// 64-byte digest type used for HMAC values and for the hashed pre-shared keys.
type Hash = StackByteArray<64>;
/// 64-byte key type supplied as the key type parameter of `GenericHash`.
type Key = StackByteArray<64>;
31
/// The individual messages exchanged during a handshake.
///
/// As driven by `Handshaker::handle_handshake`, the exchange is
/// `Hello` -> `Ohai` -> `OkBoss` -> `GoTeam`; `Dead` tears the state down.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum HandshakeStep {
    /// Opening message, built by `Handshaker::send_hello`.
    Hello {
        /// Public key the sender uses for this handshake.
        pk: PublicKey,
        /// Sender's zone, if it has one configured.
        zone: Option<String>,
        /// HMACs over random messages and `timestamp`, one per hashed key the
        /// sender holds; the receiver accepts the hello if any one validates.
        hmacs: Vec<Hmac>,
        /// Time at which the sender created the handshake; part of the HMAC input.
        timestamp: DateTime<Utc>,
    },
    /// Reply to an accepted `Hello`, built by `Handshaker::send_ohai`.
    Ohai {
        /// Public key the responder uses for this handshake.
        pk: PublicKey,
        /// Responder's zone, if it has one configured.
        zone: Option<String>,
        /// HMACs computed by the responder over the original hello timestamp.
        hmacs: Vec<Hmac>,
        /// Random bytes the initiator must hash together with a shared key.
        challenge: Vec<u8>,
    },
    /// Reply to `Ohai` carrying the challenge responses.
    OkBoss {
        /// One digest per key held by the sender (see `compute_response`).
        response: Vec<Vec<u8>>,
    },
    /// Sent by the responder once an `OkBoss` response matched.
    GoTeam,
    /// Produced by the timeout timer (`HandshakeState::start_timer`); removes
    /// any state held for the handshake id and socket address.
    Dead,
}
52
/// A single handshake message: the handshake it belongs to plus the step payload.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Handshake {
    /// Identifier chosen by the initiator and echoed in every later step.
    pub(crate) id: Uuid,
    /// The step this message carries.
    pub(crate) step: HandshakeStep,
}
58
/// A keyed hash over a random message and a timestamp, used to show knowledge
/// of a pre-shared key without sending the key itself.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Hmac {
    /// Keyed hash of `message` followed by the timestamp's nanosecond string.
    hmac: Hash,
    /// Random bytes that were hashed.
    message: Vec<u8>,
}
64
65impl Hmac {
66    pub fn new(hashed_key: &Hash, timestamp: &DateTime<Utc>, len: usize) -> Self {
67        let message = randombytes_buf(len);
68        let hmac = Self::calculate_hmac(hashed_key, &message, timestamp);
69        Self { hmac, message }
70    }
71
72    pub fn is_valid_timestamp_for(timestamp: &DateTime<Utc>, now: &DateTime<Utc>) -> bool {
73        // timestamp must be within handshake window
74        let difference = now.signed_duration_since(*timestamp);
75        if let Some(nanos) = difference.num_nanoseconds() {
76            nanos.abs() < 1000 * HANDSHAKE_TIMEOUT_MILLIS as i64
77        } else {
78            false
79        }
80    }
81
82    pub fn valid_for_key(&self, timestamp: &DateTime<Utc>, hashed_key: &Hash) -> bool {
83        let hmac = Self::calculate_hmac(hashed_key, &self.message, timestamp);
84        hmac == self.hmac
85    }
86
87    fn calculate_hmac(hashed_key: &Hash, message: &Vec<u8>, timestamp: &DateTime<Utc>) -> Hash {
88        let mut hasher = GenericHash::new(Some(hashed_key)).unwrap();
89        hasher.update(message.as_slice());
90        hasher.update(
91            timestamp
92                .timestamp_nanos_opt()
93                .expect("timestamp_nanos_opt failed")
94                .to_string()
95                .as_bytes(),
96        );
97        hasher.finalize().expect("HMAC failed")
98    }
99}
100
/// Local bookkeeping for one in-progress handshake.
///
/// The `State` type parameter is a marker (`Hello`, `Ohai` or `OkBoss`) that
/// records which step this side is currently waiting in.
struct HandshakeState<State> {
    /// Handshake identifier shared by both sides.
    id: Uuid,
    /// The remote peer as known so far.
    peer: Peer,
    /// Challenge we issued; set when the state is created in response to a hello.
    challenge: Option<Vec<u8>>,
    /// Challenge responses we computed; set once an ohai has been processed.
    response: Option<Vec<Vec<u8>>>,
    /// Timeout timer for this handshake; held only to keep it alive with the state.
    _timer: Option<Delayed>,
    /// Optional callback invoked by the handshaker when the handshake completes.
    and_then: Option<AndThen<()>>,
    /// Timestamp of the handshake's hello; used as HMAC input.
    timestamp: DateTime<Utc>,
    /// Zero-sized marker tying the value to its `State`.
    phantom: PhantomData<State>,
}
111
/// Marker: we sent a hello and are waiting for the ohai.
#[derive(Debug)]
struct Hello;
/// Marker: we answered a hello with an ohai and are waiting for the okboss.
#[derive(Debug)]
struct Ohai;
/// Marker: we sent an okboss and are waiting for the goteam.
#[derive(Debug)]
struct OkBoss;
118
/// Delay before a handshake's timeout timer delivers a `Dead` step; also used
/// by `Hmac::is_valid_timestamp_for` to bound the accepted timestamp difference.
const HANDSHAKE_TIMEOUT_MILLIS: u64 = 60_000;
/// Lower bound (inclusive) of the random delay applied before bootstrapping and
/// before answering a hello from an address that already has a handshake in flight.
const HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MIN: u64 = 10;
/// Upper bound (exclusive) of that random delay.
const HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MAX: u64 = 150;
122
123impl<S> HandshakeState<S> {
124    fn start_timer(registry: Registry, id: Uuid, sockaddr: SocketAddr) -> Delayed {
125        Delayed::new(
126            Duration::from_millis(HANDSHAKE_TIMEOUT_MILLIS),
127            async move {
128                registry
129                    .cast_handshaker(HandshakerMessage::Packet(
130                        sockaddr,
131                        Handshake {
132                            id,
133                            step: HandshakeStep::Dead,
134                        },
135                        Uuid::nil(),
136                    ))
137                    .await
138                    .ok();
139            },
140        )
141    }
142}
143
144impl HandshakeState<Hello> {
145    #[tracing::instrument(skip(registry))]
146    async fn new(registry: Registry, sockaddr: SocketAddr, and_then: Option<AndThen<()>>) -> Self {
147        let id = Uuid::new_v4();
148        let timestamp = Utc::now();
149        Self {
150            id,
151            peer: Peer::new(None, sockaddr, Uuid::nil(), Some(timestamp), None),
152            challenge: None,
153            response: None,
154            _timer: Some(Self::start_timer(registry, id, sockaddr)),
155            and_then,
156            timestamp,
157            phantom: PhantomData,
158        }
159    }
160}
161
162impl HandshakeState<Ohai> {
163    #[tracing::instrument(skip(registry))]
164    async fn from(
165        registry: Registry,
166        sockaddr: SocketAddr,
167        node_id: Uuid,
168        id: Uuid,
169        their_public_key: PublicKey,
170        timestamp: DateTime<Utc>,
171        zone: Option<String>,
172    ) -> Self {
173        Self {
174            id,
175            peer: Peer::new(
176                Some(their_public_key),
177                sockaddr,
178                node_id,
179                Some(timestamp),
180                zone,
181            ),
182            challenge: Some(randombytes_buf(69)),
183            response: None,
184            _timer: Some(Self::start_timer(registry, id, sockaddr)),
185            and_then: None,
186            timestamp,
187            phantom: PhantomData,
188        }
189    }
190}
191
192fn compute_response(
193    id: &Uuid,
194    public_key: &PublicKey,
195    challenge: &Vec<u8>,
196    keys: &[String],
197) -> Vec<Vec<u8>> {
198    keys.iter()
199        .map(|key| {
200            let mut hasher: GenericHash<64, 64> =
201                GenericHash::new::<Key>(None).expect("new failed");
202            hasher.update(key.as_bytes());
203            hasher.update(challenge);
204            hasher.update(public_key);
205            hasher.update(id.as_bytes());
206            hasher.finalize_to_vec().expect("finalize failed")
207        })
208        .collect()
209}
210
211impl HandshakeState<Hello> {
212    fn got_ohai(
213        self,
214        node_id: Uuid,
215        public_key: &PublicKey,
216        challenge: &Vec<u8>,
217        keys: &[String],
218        zone: Option<String>,
219    ) -> HandshakeState<OkBoss> {
220        let mut peer = self.peer;
221        peer.set_their_public_key(public_key.clone());
222        peer.set_id(node_id);
223        peer.set_zone(zone);
224        let response = compute_response(&self.id, public_key, challenge, keys);
225        HandshakeState {
226            id: self.id,
227            peer,
228            response: Some(response),
229            challenge: self.challenge,
230            _timer: None,
231            and_then: self.and_then,
232            timestamp: self.timestamp,
233            phantom: PhantomData,
234        }
235    }
236}
237
238impl HandshakeState<Ohai> {
239    fn got_okboss(self, response: &[Vec<u8>], keys: &[String]) -> Result<Self, Self> {
240        // check response matches what we expected
241        let pk = &self.peer.our_keypair().unwrap().public_key;
242        let expected_response =
243            compute_response(&self.id, pk, self.challenge.as_ref().unwrap(), keys);
244
245        let expected: HashSet<Vec<u8>> = HashSet::from_iter(expected_response.iter().cloned());
246        let received: HashSet<Vec<u8>> = HashSet::from_iter(response.iter().cloned());
247
248        let intersection: HashSet<_> = expected.intersection(&received).collect();
249
250        if intersection.is_empty() {
251            Err(self)
252        } else {
253            Ok(self)
254        }
255    }
256}
257
/// Server that drives handshakes with remote nodes.
///
/// In-progress handshakes are keyed by `(handshake id, socket address)` and
/// kept in one map per step.
pub struct Handshaker {
    /// Handle used to reach the other servers and the local node info.
    registry: Registry,
    /// Handshakes we initiated that are waiting for an ohai.
    hellos: HashMap<(Uuid, SocketAddr), HandshakeState<Hello>>,
    /// Handshakes we answered that are waiting for an okboss.
    ohais: HashMap<(Uuid, SocketAddr), HandshakeState<Ohai>>,
    /// Handshakes we initiated that are waiting for a goteam.
    okboss: HashMap<(Uuid, SocketAddr), HandshakeState<OkBoss>>,
    /// Addresses that currently have a handshake in progress; a hello from
    /// such an address gets a randomly delayed ohai.
    inflight: HashSet<SocketAddr>,
    /// Random source for delays and HMAC message lengths.
    rng: StdRng,
    /// Hashes of the pre-shared keys; filled in by `bootstrap`.
    hashed_keys: Option<Vec<Hash>>,
    /// Pending delayed bootstrap task, if any.
    bootstrap: Option<Delayed>,
}
268
/// Messages accepted by the `Handshaker` server.
#[derive(Debug)]
pub enum HandshakerMessage {
    /// Hash the pre-shared keys (once) and schedule hellos to the configured peers.
    Bootstrap,
    /// A handshake step from the given address, together with the sender's node id.
    Packet(SocketAddr, Handshake, Uuid),
    /// Start a handshake with the given address; the optional callback is stored
    /// on the handshake state and invoked when the handshake completes.
    SendHello(SocketAddr, Option<AndThen<()>>),
}
275
/// `GenServer` wiring: calls and casts are both routed to `handle_message`.
impl GenServer for Handshaker {
    type Message = HandshakerMessage;
    type Registry = Registry;
    type Response = ();

    // Opaque future types returned by the two handlers below.
    type CallResponse<'a> = impl Future<Output = Self::Response> + 'a;
    type CastResponse<'a> = impl Future<Output = ()> + 'a;

    /// Creates a handshaker with no handshakes in progress, an entropy-seeded
    /// RNG and no hashed keys (those are computed in `bootstrap`).
    fn new(registry: Self::Registry) -> Self {
        Self {
            registry,
            hellos: HashMap::new(),
            ohais: HashMap::new(),
            okboss: HashMap::new(),
            inflight: HashSet::new(),
            rng: SeedableRng::from_entropy(),
            hashed_keys: None,
            bootstrap: None,
        }
    }

    /// Handles a call by processing the message; the response carries no data.
    fn handle_call(&mut self, message: Self::Message) -> Self::CallResponse<'_> {
        async { self.handle_message(message).await }
    }

    /// Handles a cast exactly like a call.
    fn handle_cast(&mut self, message: Self::Message) -> Self::CastResponse<'_> {
        async { self.handle_message(message).await }
    }
}
305
306impl Handshaker {
307    async fn handle_message(&mut self, message: HandshakerMessage) {
308        match message {
309            HandshakerMessage::Bootstrap => self.bootstrap().await,
310            HandshakerMessage::Packet(sockaddr, handshake, node_id) => {
311                self.handle_handshake(sockaddr, handshake, node_id).await;
312            }
313            HandshakerMessage::SendHello(sockaddr, and_then) => {
314                self.send_hello(sockaddr, and_then).await
315            }
316        }
317    }
318
319    fn has_valid_hmac(&self, timestamp: &DateTime<Utc>, hmacs: &[Hmac]) -> bool {
320        let now = Utc::now();
321        if let Some(hashed_keys) = &self.hashed_keys {
322            Hmac::is_valid_timestamp_for(timestamp, &now)
323                && hmacs.iter().any(|hmac| {
324                    hashed_keys
325                        .iter()
326                        .any(|hashed_key| hmac.valid_for_key(timestamp, hashed_key))
327                })
328        } else {
329            false
330        }
331    }
332
333    fn compute_hmacs(&mut self, timestamp: &DateTime<Utc>) -> Vec<Hmac> {
334        if let Some(hashed_keys) = &self.hashed_keys {
335            hashed_keys
336                .iter()
337                .map(|hashed_key| Hmac::new(hashed_key, timestamp, self.rng.gen_range(69..420)))
338                .collect()
339        } else {
340            vec![]
341        }
342    }
343
344    #[tracing::instrument(skip(self))]
345    async fn handle_handshake(
346        &mut self,
347        sockaddr: SocketAddr,
348        handshake: Handshake,
349        node_id: Uuid,
350    ) {
351        match &handshake.step {
352            HandshakeStep::Hello {
353                pk,
354                zone,
355                hmacs,
356                timestamp,
357            } => {
358                // ignore messages if we're trying to handshake with ourselves.
359                // first, validate the HMACs to see if any are valid. at least
360                // one needs to match or else we ignore this hello.
361                if &node_id != self.registry.nodeinfo().id()
362                    && self.has_valid_hmac(timestamp, hmacs)
363                {
364                    let handshake = HandshakeState::from(
365                        self.registry.clone(),
366                        sockaddr,
367                        node_id,
368                        handshake.id,
369                        pk.clone(),
370                        *timestamp,
371                        zone.clone(),
372                    )
373                    .await;
374
375                    self.send_ohai(
376                        handshake.id,
377                        sockaddr,
378                        handshake.peer.our_keypair().unwrap().public_key.clone(),
379                        handshake.challenge.as_ref().unwrap().clone(),
380                        &handshake.timestamp,
381                        self.inflight.contains(&sockaddr),
382                    )
383                    .await;
384                    self.inflight.insert(sockaddr);
385                    self.ohais.insert((handshake.id, sockaddr), handshake);
386                }
387            }
388            HandshakeStep::Ohai {
389                pk,
390                zone,
391                hmacs,
392                challenge,
393            } => {
394                if let Some(handshake) = self.hellos.remove(&(handshake.id, sockaddr)) {
395                    if self.has_valid_hmac(&handshake.timestamp, hmacs) {
396                        let handshake = handshake.got_ohai(
397                            node_id,
398                            pk,
399                            challenge,
400                            self.registry.nodeinfo().keys(),
401                            zone.clone(),
402                        );
403                        self.send_okboss(
404                            handshake.id,
405                            sockaddr,
406                            handshake.response.as_ref().unwrap().clone(),
407                        )
408                        .await;
409
410                        assert!(handshake.peer.is_handshaken());
411                        self.okboss.insert((handshake.id, sockaddr), handshake);
412                    }
413                }
414            }
415            HandshakeStep::OkBoss { response } => {
416                if let Some(handshake) = self.ohais.remove(&(handshake.id, sockaddr)) {
417                    if let Ok(handshake) =
418                        handshake.got_okboss(response, self.registry.nodeinfo().keys())
419                    {
420                        self.send_peer_up(handshake.peer).await;
421                        self.send_goteam(handshake.id, sockaddr).await;
422                        if let Some(and_then) = handshake.and_then {
423                            and_then.call().await;
424                        }
425                    } else {
426                        error!("bad okboss from {sockaddr}");
427                    }
428                }
429            }
430            HandshakeStep::GoTeam => {
431                if let Some(handshake) = self.okboss.remove(&(handshake.id, sockaddr)) {
432                    self.send_peer_up(handshake.peer.clone()).await;
433                    if let Some(and_then) = handshake.and_then {
434                        and_then.call().await;
435                    }
436                }
437            }
438            HandshakeStep::Dead => {
439                self.hellos.remove(&(handshake.id, sockaddr));
440                self.ohais.remove(&(handshake.id, sockaddr));
441                self.okboss.remove(&(handshake.id, sockaddr));
442                self.inflight.remove(&sockaddr);
443            }
444        }
445    }
446
447    #[tracing::instrument(skip(self))]
448    pub async fn send_hello(&mut self, sockaddr: SocketAddr, and_then: Option<AndThen<()>>) {
449        let handshake = HandshakeState::new(self.registry.clone(), sockaddr, and_then).await;
450        let payload = Payload::H(Handshake {
451            id: handshake.id,
452            step: HandshakeStep::Hello {
453                pk: handshake.peer.our_keypair().unwrap().public_key.clone(),
454                zone: self.registry.nodeinfo().zone().clone(),
455                hmacs: self.compute_hmacs(&handshake.timestamp),
456                timestamp: handshake.timestamp,
457            },
458        });
459        self.inflight.insert(sockaddr);
460        self.hellos.insert((handshake.id, sockaddr), handshake);
461        self.registry
462            .cast_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
463            .await
464            .ok();
465    }
466
467    #[tracing::instrument(skip(self))]
468    pub async fn send_ohai(
469        &mut self,
470        id: Uuid,
471        sockaddr: SocketAddr,
472        public_key: PublicKey,
473        challenge: Vec<u8>,
474        timestamp: &DateTime<Utc>,
475        inflight: bool,
476    ) {
477        let payload = Payload::H(Handshake {
478            id,
479            step: HandshakeStep::Ohai {
480                pk: public_key,
481                zone: self.registry.nodeinfo().zone().clone(),
482                hmacs: self.compute_hmacs(timestamp),
483                challenge,
484            },
485        });
486
487        let registry = self.registry.clone();
488        let millis = if inflight {
489            self.rng.gen_range(
490                HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MIN..HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MAX,
491            )
492        } else {
493            0
494        };
495        tokio::spawn(async move {
496            sleep(Duration::from_millis(millis)).await;
497            registry
498                .call_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
499                .await
500                .ok();
501        });
502    }
503
504    #[tracing::instrument(skip(self))]
505    pub async fn send_okboss(&mut self, id: Uuid, sockaddr: SocketAddr, response: Vec<Vec<u8>>) {
506        let payload = Payload::H(Handshake {
507            id,
508            step: HandshakeStep::OkBoss { response },
509        });
510        self.registry
511            .call_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
512            .await
513            .unwrap();
514    }
515
516    #[tracing::instrument(skip(self))]
517    pub async fn send_goteam(&mut self, id: Uuid, sockaddr: SocketAddr) {
518        let payload = Payload::H(Handshake {
519            id,
520            step: HandshakeStep::GoTeam,
521        });
522        self.registry
523            .call_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
524            .await
525            .unwrap();
526    }
527
528    #[tracing::instrument(skip(self))]
529    pub async fn request_peerlist(&mut self, id: Uuid, sockaddr: SocketAddr) {
530        self.registry
531            .call_packet_handler(PacketHandlerMessage::SendMessage(
532                id,
533                Message {
534                    id: Uuid::new_v4(),
535                    body: Body::PeerRequest(Box::new(PeerRequest::PeerList)),
536                },
537            ))
538            .await
539            .ok();
540    }
541
542    #[tracing::instrument(skip(self))]
543    pub async fn send_peer_up(&self, peer: Peer) {
544        self.registry
545            .call_peer_manager(PeerManagerRequest::Request {
546                from_id: *self.registry.nodeinfo().id(),
547                request: Box::new(PeerRequest::PeerStatus(PeerStatus::new(
548                    peer,
549                    PeerState::Up,
550                ))),
551            })
552            .await
553            .ok();
554    }
555
556    #[tracing::instrument(skip(self))]
557    pub async fn bootstrap(&mut self) {
558        if self.hashed_keys.is_none() {
559            // hash pre-shared keys for handshaking
560            self.hashed_keys = Some(
561                self.registry
562                    .nodeinfo()
563                    .keys()
564                    .iter()
565                    .flat_map(|key| GenericHash::hash::<_, Key, _>(key.as_bytes(), None))
566                    .collect(),
567            );
568        }
569
570        let peers: Vec<_> = self.registry.nodeinfo().peers().to_vec();
571        let registry = self.registry.clone();
572        let millis = self
573            .rng
574            .gen_range(HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MIN..HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MAX);
575        self.bootstrap = Some(Delayed::new(Duration::from_millis(millis), async move {
576            use std::net::ToSocketAddrs;
577            for peer in peers.iter() {
578                match peer.to_socket_addrs() {
579                    Ok(mut sockaddr_iter) => {
580                        if let Some(sockaddr) = sockaddr_iter.next() {
581                            registry
582                                .call_handshaker(HandshakerMessage::SendHello(sockaddr, None))
583                                .await
584                                .ok();
585                        }
586                    }
587                    Err(err) => error!("failed to resolve {peer}: {:?}", err),
588                }
589            }
590        }));
591    }
592}