1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use dryoc::dryocbox::PublicKey;
9use dryoc::generichash::GenericHash;
10use dryoc::rng::randombytes_buf;
11use dryoc::types::StackByteArray;
12use genserver::GenServer;
13use rand::rngs::StdRng;
14use rand::{Rng, SeedableRng};
15use serde::{Deserialize, Serialize};
16use tokio::time::sleep;
17use tracing::error;
18use uuid::Uuid;
19
20use crate::and_then::AndThen;
21use crate::delayed::Delayed;
22use crate::message::{Body, Message};
23use crate::packet::Payload;
24use crate::packet_handler::PacketHandlerMessage;
25use crate::peer::Peer;
26use crate::peer_manager::{PeerManagerRequest, PeerRequest, PeerState, PeerStatus};
27use crate::registry::Registry;
28
/// 64-byte array holding a `GenericHash` digest: HMAC values and the hashed shared keys
/// used to key them.
type Hash = StackByteArray<64>;
/// 64-byte array type supplied as the key type parameter to `GenericHash`.
type Key = StackByteArray<64>;
31
/// One step of the peer handshake, carried inside a [`Handshake`] packet.
///
/// NOTE(review): this enum is serde-serialized; depending on the wire format in use,
/// variant and field order may be significant — confirm before reordering.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum HandshakeStep {
    /// Sent by the initiator to open a handshake.
    Hello {
        /// The initiator's public key.
        pk: PublicKey,
        /// The initiator's zone, if it has one.
        zone: Option<String>,
        /// One HMAC per hashed shared key, bound to `timestamp`.
        hmacs: Vec<Hmac>,
        /// The initiator's clock when it created the handshake.
        timestamp: DateTime<Utc>,
    },
    /// The responder's reply to a Hello whose HMACs verified.
    Ohai {
        /// The responder's public key.
        pk: PublicKey,
        /// The responder's zone, if it has one.
        zone: Option<String>,
        /// HMACs bound to the timestamp of the Hello being answered.
        hmacs: Vec<Hmac>,
        /// Random bytes the initiator must answer in its OkBoss.
        challenge: Vec<u8>,
    },
    /// The initiator's answer to the challenge.
    OkBoss {
        /// One response per configured shared key (see `compute_response`).
        response: Vec<Vec<u8>>,
    },
    /// The responder's confirmation that the OkBoss response verified.
    GoTeam,
    /// Sent to the handshaker by a handshake timeout timer to discard state. It is also
    /// deserializable, so it can arrive in a packet like any other step.
    Dead,
}
52
/// A handshake packet: the exchange it belongs to plus the step being performed.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Handshake {
    /// Exchange id. The initiator generates it with `Uuid::new_v4`, and the responder
    /// reuses the id from the Hello it answers.
    pub(crate) id: Uuid,
    /// The step this packet carries.
    pub(crate) step: HandshakeStep,
}
58
/// Proof that the sender holds a shared key, tied to a specific handshake timestamp.
///
/// `hmac` is a `GenericHash` keyed with a hashed shared key, computed over `message`
/// followed by the timestamp (see `Hmac::calculate_hmac`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Hmac {
    /// The keyed digest.
    hmac: Hash,
    /// Random bytes generated by the sender (`randombytes_buf`).
    message: Vec<u8>,
}
64
65impl Hmac {
66 pub fn new(hashed_key: &Hash, timestamp: &DateTime<Utc>, len: usize) -> Self {
67 let message = randombytes_buf(len);
68 let hmac = Self::calculate_hmac(hashed_key, &message, timestamp);
69 Self { hmac, message }
70 }
71
72 pub fn is_valid_timestamp_for(timestamp: &DateTime<Utc>, now: &DateTime<Utc>) -> bool {
73 let difference = now.signed_duration_since(*timestamp);
75 if let Some(nanos) = difference.num_nanoseconds() {
76 nanos.abs() < 1000 * HANDSHAKE_TIMEOUT_MILLIS as i64
77 } else {
78 false
79 }
80 }
81
82 pub fn valid_for_key(&self, timestamp: &DateTime<Utc>, hashed_key: &Hash) -> bool {
83 let hmac = Self::calculate_hmac(hashed_key, &self.message, timestamp);
84 hmac == self.hmac
85 }
86
87 fn calculate_hmac(hashed_key: &Hash, message: &Vec<u8>, timestamp: &DateTime<Utc>) -> Hash {
88 let mut hasher = GenericHash::new(Some(hashed_key)).unwrap();
89 hasher.update(message.as_slice());
90 hasher.update(
91 timestamp
92 .timestamp_nanos_opt()
93 .expect("timestamp_nanos_opt failed")
94 .to_string()
95 .as_bytes(),
96 );
97 hasher.finalize().expect("HMAC failed")
98 }
99}
100
/// State kept for one handshake exchange.
///
/// `State` is a zero-sized marker (`Hello`, `Ohai` or `OkBoss`) recording which step this
/// side has reached. Transitions are methods that consume one state and return the next,
/// so the type system checks the order of steps.
struct HandshakeState<State> {
    /// Exchange id, shared by both sides.
    id: Uuid,
    /// What is known so far about the remote peer.
    peer: Peer,
    /// Challenge we issued. Set only on the responder side (`Ohai`), `None` otherwise.
    challenge: Option<Vec<u8>>,
    /// Our answer to the remote's challenge. Set by the initiator on reaching `OkBoss`.
    response: Option<Vec<Vec<u8>>>,
    /// Timeout timer. Never read; it is kept so the timer lives as long as this state.
    /// NOTE(review): whether dropping a `Delayed` cancels it is not visible here.
    _timer: Option<Delayed>,
    /// Continuation run after the handshake completes successfully.
    and_then: Option<AndThen<()>>,
    /// Timestamp of the initiator's Hello. Both sides bind their HMACs to it.
    timestamp: DateTime<Utc>,
    phantom: PhantomData<State>,
}
111
/// Type-state marker: we sent a Hello and are waiting for the remote's Ohai.
#[derive(Debug)]
struct Hello;
/// Type-state marker: we answered a Hello with an Ohai and are waiting for an OkBoss.
#[derive(Debug)]
struct Ohai;
/// Type-state marker: we sent an OkBoss and are waiting for the remote's GoTeam.
#[derive(Debug)]
struct OkBoss;
118
/// Lifetime of a handshake before its timer discards the state, in milliseconds.
/// `Hmac::is_valid_timestamp_for` also references it when checking timestamp freshness.
const HANDSHAKE_TIMEOUT_MILLIS: u64 = 60 * 1_000;
/// Lower bound (inclusive), in milliseconds, of the random delay applied before
/// bootstrapping and before answering a Hello from an address already in flight.
const HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MIN: u64 = 10;
/// Upper bound (exclusive), in milliseconds, of that same random delay.
const HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MAX: u64 = 150;
122
123impl<S> HandshakeState<S> {
124 fn start_timer(registry: Registry, id: Uuid, sockaddr: SocketAddr) -> Delayed {
125 Delayed::new(
126 Duration::from_millis(HANDSHAKE_TIMEOUT_MILLIS),
127 async move {
128 registry
129 .cast_handshaker(HandshakerMessage::Packet(
130 sockaddr,
131 Handshake {
132 id,
133 step: HandshakeStep::Dead,
134 },
135 Uuid::nil(),
136 ))
137 .await
138 .ok();
139 },
140 )
141 }
142}
143
144impl HandshakeState<Hello> {
145 #[tracing::instrument(skip(registry))]
146 async fn new(registry: Registry, sockaddr: SocketAddr, and_then: Option<AndThen<()>>) -> Self {
147 let id = Uuid::new_v4();
148 let timestamp = Utc::now();
149 Self {
150 id,
151 peer: Peer::new(None, sockaddr, Uuid::nil(), Some(timestamp), None),
152 challenge: None,
153 response: None,
154 _timer: Some(Self::start_timer(registry, id, sockaddr)),
155 and_then,
156 timestamp,
157 phantom: PhantomData,
158 }
159 }
160}
161
162impl HandshakeState<Ohai> {
163 #[tracing::instrument(skip(registry))]
164 async fn from(
165 registry: Registry,
166 sockaddr: SocketAddr,
167 node_id: Uuid,
168 id: Uuid,
169 their_public_key: PublicKey,
170 timestamp: DateTime<Utc>,
171 zone: Option<String>,
172 ) -> Self {
173 Self {
174 id,
175 peer: Peer::new(
176 Some(their_public_key),
177 sockaddr,
178 node_id,
179 Some(timestamp),
180 zone,
181 ),
182 challenge: Some(randombytes_buf(69)),
183 response: None,
184 _timer: Some(Self::start_timer(registry, id, sockaddr)),
185 and_then: None,
186 timestamp,
187 phantom: PhantomData,
188 }
189 }
190}
191
192fn compute_response(
193 id: &Uuid,
194 public_key: &PublicKey,
195 challenge: &Vec<u8>,
196 keys: &[String],
197) -> Vec<Vec<u8>> {
198 keys.iter()
199 .map(|key| {
200 let mut hasher: GenericHash<64, 64> =
201 GenericHash::new::<Key>(None).expect("new failed");
202 hasher.update(key.as_bytes());
203 hasher.update(challenge);
204 hasher.update(public_key);
205 hasher.update(id.as_bytes());
206 hasher.finalize_to_vec().expect("finalize failed")
207 })
208 .collect()
209}
210
211impl HandshakeState<Hello> {
212 fn got_ohai(
213 self,
214 node_id: Uuid,
215 public_key: &PublicKey,
216 challenge: &Vec<u8>,
217 keys: &[String],
218 zone: Option<String>,
219 ) -> HandshakeState<OkBoss> {
220 let mut peer = self.peer;
221 peer.set_their_public_key(public_key.clone());
222 peer.set_id(node_id);
223 peer.set_zone(zone);
224 let response = compute_response(&self.id, public_key, challenge, keys);
225 HandshakeState {
226 id: self.id,
227 peer,
228 response: Some(response),
229 challenge: self.challenge,
230 _timer: None,
231 and_then: self.and_then,
232 timestamp: self.timestamp,
233 phantom: PhantomData,
234 }
235 }
236}
237
238impl HandshakeState<Ohai> {
239 fn got_okboss(self, response: &[Vec<u8>], keys: &[String]) -> Result<Self, Self> {
240 let pk = &self.peer.our_keypair().unwrap().public_key;
242 let expected_response =
243 compute_response(&self.id, pk, self.challenge.as_ref().unwrap(), keys);
244
245 let expected: HashSet<Vec<u8>> = HashSet::from_iter(expected_response.iter().cloned());
246 let received: HashSet<Vec<u8>> = HashSet::from_iter(response.iter().cloned());
247
248 let intersection: HashSet<_> = expected.intersection(&received).collect();
249
250 if intersection.is_empty() {
251 Err(self)
252 } else {
253 Ok(self)
254 }
255 }
256}
257
/// GenServer that runs the peer handshake protocol (see [`HandshakeStep`]).
///
/// Pending handshakes are keyed by `(handshake id, remote address)`.
pub struct Handshaker {
    registry: Registry,
    /// Outbound handshakes that sent a Hello and are waiting for an Ohai.
    hellos: HashMap<(Uuid, SocketAddr), HandshakeState<Hello>>,
    /// Inbound handshakes that sent an Ohai and are waiting for an OkBoss.
    ohais: HashMap<(Uuid, SocketAddr), HandshakeState<Ohai>>,
    /// Outbound handshakes that sent an OkBoss and are waiting for a GoTeam.
    okboss: HashMap<(Uuid, SocketAddr), HandshakeState<OkBoss>>,
    /// Addresses we have sent a Hello to or answered a Hello from. Entries are removed only
    /// when a `Dead` step is handled.
    inflight: HashSet<SocketAddr>,
    /// Randomness for HMAC message lengths and for reply/bootstrap jitter.
    rng: StdRng,
    /// Hashes of the configured shared keys. `None` until the first bootstrap.
    hashed_keys: Option<Vec<Hash>>,
    /// Handle of the most recently scheduled delayed bootstrap.
    bootstrap: Option<Delayed>,
}
268
/// Messages accepted by the [`Handshaker`].
#[derive(Debug)]
pub enum HandshakerMessage {
    /// Derive the hashed keys (first time only), then say Hello to every configured peer
    /// after a short random delay.
    Bootstrap,
    /// A handshake packet from the given address, with the sender's node id. Handshake
    /// timers also send `Dead` packets this way, using a nil node id.
    Packet(SocketAddr, Handshake, Uuid),
    /// Start an outbound handshake with the address. The optional continuation runs once
    /// the remote's GoTeam is received.
    SendHello(SocketAddr, Option<AndThen<()>>),
}
275
276impl GenServer for Handshaker {
277 type Message = HandshakerMessage;
278 type Registry = Registry;
279 type Response = ();
280
281 type CallResponse<'a> = impl Future<Output = Self::Response> + 'a;
282 type CastResponse<'a> = impl Future<Output = ()> + 'a;
283
284 fn new(registry: Self::Registry) -> Self {
285 Self {
286 registry,
287 hellos: HashMap::new(),
288 ohais: HashMap::new(),
289 okboss: HashMap::new(),
290 inflight: HashSet::new(),
291 rng: SeedableRng::from_entropy(),
292 hashed_keys: None,
293 bootstrap: None,
294 }
295 }
296
297 fn handle_call(&mut self, message: Self::Message) -> Self::CallResponse<'_> {
298 async { self.handle_message(message).await }
299 }
300
301 fn handle_cast(&mut self, message: Self::Message) -> Self::CastResponse<'_> {
302 async { self.handle_message(message).await }
303 }
304}
305
306impl Handshaker {
307 async fn handle_message(&mut self, message: HandshakerMessage) {
308 match message {
309 HandshakerMessage::Bootstrap => self.bootstrap().await,
310 HandshakerMessage::Packet(sockaddr, handshake, node_id) => {
311 self.handle_handshake(sockaddr, handshake, node_id).await;
312 }
313 HandshakerMessage::SendHello(sockaddr, and_then) => {
314 self.send_hello(sockaddr, and_then).await
315 }
316 }
317 }
318
319 fn has_valid_hmac(&self, timestamp: &DateTime<Utc>, hmacs: &[Hmac]) -> bool {
320 let now = Utc::now();
321 if let Some(hashed_keys) = &self.hashed_keys {
322 Hmac::is_valid_timestamp_for(timestamp, &now)
323 && hmacs.iter().any(|hmac| {
324 hashed_keys
325 .iter()
326 .any(|hashed_key| hmac.valid_for_key(timestamp, hashed_key))
327 })
328 } else {
329 false
330 }
331 }
332
333 fn compute_hmacs(&mut self, timestamp: &DateTime<Utc>) -> Vec<Hmac> {
334 if let Some(hashed_keys) = &self.hashed_keys {
335 hashed_keys
336 .iter()
337 .map(|hashed_key| Hmac::new(hashed_key, timestamp, self.rng.gen_range(69..420)))
338 .collect()
339 } else {
340 vec![]
341 }
342 }
343
344 #[tracing::instrument(skip(self))]
345 async fn handle_handshake(
346 &mut self,
347 sockaddr: SocketAddr,
348 handshake: Handshake,
349 node_id: Uuid,
350 ) {
351 match &handshake.step {
352 HandshakeStep::Hello {
353 pk,
354 zone,
355 hmacs,
356 timestamp,
357 } => {
358 if &node_id != self.registry.nodeinfo().id()
362 && self.has_valid_hmac(timestamp, hmacs)
363 {
364 let handshake = HandshakeState::from(
365 self.registry.clone(),
366 sockaddr,
367 node_id,
368 handshake.id,
369 pk.clone(),
370 *timestamp,
371 zone.clone(),
372 )
373 .await;
374
375 self.send_ohai(
376 handshake.id,
377 sockaddr,
378 handshake.peer.our_keypair().unwrap().public_key.clone(),
379 handshake.challenge.as_ref().unwrap().clone(),
380 &handshake.timestamp,
381 self.inflight.contains(&sockaddr),
382 )
383 .await;
384 self.inflight.insert(sockaddr);
385 self.ohais.insert((handshake.id, sockaddr), handshake);
386 }
387 }
388 HandshakeStep::Ohai {
389 pk,
390 zone,
391 hmacs,
392 challenge,
393 } => {
394 if let Some(handshake) = self.hellos.remove(&(handshake.id, sockaddr)) {
395 if self.has_valid_hmac(&handshake.timestamp, hmacs) {
396 let handshake = handshake.got_ohai(
397 node_id,
398 pk,
399 challenge,
400 self.registry.nodeinfo().keys(),
401 zone.clone(),
402 );
403 self.send_okboss(
404 handshake.id,
405 sockaddr,
406 handshake.response.as_ref().unwrap().clone(),
407 )
408 .await;
409
410 assert!(handshake.peer.is_handshaken());
411 self.okboss.insert((handshake.id, sockaddr), handshake);
412 }
413 }
414 }
415 HandshakeStep::OkBoss { response } => {
416 if let Some(handshake) = self.ohais.remove(&(handshake.id, sockaddr)) {
417 if let Ok(handshake) =
418 handshake.got_okboss(response, self.registry.nodeinfo().keys())
419 {
420 self.send_peer_up(handshake.peer).await;
421 self.send_goteam(handshake.id, sockaddr).await;
422 if let Some(and_then) = handshake.and_then {
423 and_then.call().await;
424 }
425 } else {
426 error!("bad okboss from {sockaddr}");
427 }
428 }
429 }
430 HandshakeStep::GoTeam => {
431 if let Some(handshake) = self.okboss.remove(&(handshake.id, sockaddr)) {
432 self.send_peer_up(handshake.peer.clone()).await;
433 if let Some(and_then) = handshake.and_then {
434 and_then.call().await;
435 }
436 }
437 }
438 HandshakeStep::Dead => {
439 self.hellos.remove(&(handshake.id, sockaddr));
440 self.ohais.remove(&(handshake.id, sockaddr));
441 self.okboss.remove(&(handshake.id, sockaddr));
442 self.inflight.remove(&sockaddr);
443 }
444 }
445 }
446
447 #[tracing::instrument(skip(self))]
448 pub async fn send_hello(&mut self, sockaddr: SocketAddr, and_then: Option<AndThen<()>>) {
449 let handshake = HandshakeState::new(self.registry.clone(), sockaddr, and_then).await;
450 let payload = Payload::H(Handshake {
451 id: handshake.id,
452 step: HandshakeStep::Hello {
453 pk: handshake.peer.our_keypair().unwrap().public_key.clone(),
454 zone: self.registry.nodeinfo().zone().clone(),
455 hmacs: self.compute_hmacs(&handshake.timestamp),
456 timestamp: handshake.timestamp,
457 },
458 });
459 self.inflight.insert(sockaddr);
460 self.hellos.insert((handshake.id, sockaddr), handshake);
461 self.registry
462 .cast_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
463 .await
464 .ok();
465 }
466
467 #[tracing::instrument(skip(self))]
468 pub async fn send_ohai(
469 &mut self,
470 id: Uuid,
471 sockaddr: SocketAddr,
472 public_key: PublicKey,
473 challenge: Vec<u8>,
474 timestamp: &DateTime<Utc>,
475 inflight: bool,
476 ) {
477 let payload = Payload::H(Handshake {
478 id,
479 step: HandshakeStep::Ohai {
480 pk: public_key,
481 zone: self.registry.nodeinfo().zone().clone(),
482 hmacs: self.compute_hmacs(timestamp),
483 challenge,
484 },
485 });
486
487 let registry = self.registry.clone();
488 let millis = if inflight {
489 self.rng.gen_range(
490 HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MIN..HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MAX,
491 )
492 } else {
493 0
494 };
495 tokio::spawn(async move {
496 sleep(Duration::from_millis(millis)).await;
497 registry
498 .call_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
499 .await
500 .ok();
501 });
502 }
503
504 #[tracing::instrument(skip(self))]
505 pub async fn send_okboss(&mut self, id: Uuid, sockaddr: SocketAddr, response: Vec<Vec<u8>>) {
506 let payload = Payload::H(Handshake {
507 id,
508 step: HandshakeStep::OkBoss { response },
509 });
510 self.registry
511 .call_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
512 .await
513 .unwrap();
514 }
515
516 #[tracing::instrument(skip(self))]
517 pub async fn send_goteam(&mut self, id: Uuid, sockaddr: SocketAddr) {
518 let payload = Payload::H(Handshake {
519 id,
520 step: HandshakeStep::GoTeam,
521 });
522 self.registry
523 .call_packet_handler(PacketHandlerMessage::Send(sockaddr, payload))
524 .await
525 .unwrap();
526 }
527
528 #[tracing::instrument(skip(self))]
529 pub async fn request_peerlist(&mut self, id: Uuid, sockaddr: SocketAddr) {
530 self.registry
531 .call_packet_handler(PacketHandlerMessage::SendMessage(
532 id,
533 Message {
534 id: Uuid::new_v4(),
535 body: Body::PeerRequest(Box::new(PeerRequest::PeerList)),
536 },
537 ))
538 .await
539 .ok();
540 }
541
542 #[tracing::instrument(skip(self))]
543 pub async fn send_peer_up(&self, peer: Peer) {
544 self.registry
545 .call_peer_manager(PeerManagerRequest::Request {
546 from_id: *self.registry.nodeinfo().id(),
547 request: Box::new(PeerRequest::PeerStatus(PeerStatus::new(
548 peer,
549 PeerState::Up,
550 ))),
551 })
552 .await
553 .ok();
554 }
555
556 #[tracing::instrument(skip(self))]
557 pub async fn bootstrap(&mut self) {
558 if self.hashed_keys.is_none() {
559 self.hashed_keys = Some(
561 self.registry
562 .nodeinfo()
563 .keys()
564 .iter()
565 .flat_map(|key| GenericHash::hash::<_, Key, _>(key.as_bytes(), None))
566 .collect(),
567 );
568 }
569
570 let peers: Vec<_> = self.registry.nodeinfo().peers().to_vec();
571 let registry = self.registry.clone();
572 let millis = self
573 .rng
574 .gen_range(HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MIN..HANDSHAKE_BOOTSTRAP_DELAY_MILLIS_MAX);
575 self.bootstrap = Some(Delayed::new(Duration::from_millis(millis), async move {
576 use std::net::ToSocketAddrs;
577 for peer in peers.iter() {
578 match peer.to_socket_addrs() {
579 Ok(mut sockaddr_iter) => {
580 if let Some(sockaddr) = sockaddr_iter.next() {
581 registry
582 .call_handshaker(HandshakerMessage::SendHello(sockaddr, None))
583 .await
584 .ok();
585 }
586 }
587 Err(err) => error!("failed to resolve {peer}: {:?}", err),
588 }
589 }
590 }));
591 }
592}