1use core::{fmt, marker::PhantomData, task::Poll};
6
7#[cfg(feature = "std")]
8use std::sync::Arc;
9
10#[cfg(not(feature = "std"))]
11use alloc::{
12 boxed::Box,
13 string::{String, ToString},
14 sync::Arc,
15 vec::Vec,
16};
17
18use async_lock::{Mutex, OnceCell};
19use hashbrown::HashMap;
20use x509_cert::{
21 Certificate,
22 der::{Decode, DecodePem, EncodePem, pem::LineEnding},
23};
24
25use crate::{
26 config::DeviceConfig,
27 io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
28 packet::{
29 NetworkPacket, NetworkPacketBody, NetworkPacketType, identity::IdentityPacket,
30 pair::PairPacket,
31 },
32 plugin::Plugin,
33 trust::TrustHandler,
34};
35
36use serde::{Deserialize, Serialize};
37
/// Maximum clock skew, in seconds, tolerated between the host clock and the
/// timestamp carried by a peer's pair request (30 minutes).
const ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS: u64 = 30 * 60;

/// Two-way sum type used to report which of two raced futures completed.
enum Either<A, B> {
    /// Output of the first future.
    A(A),
    /// Output of the second future.
    B(B),
}
44
45#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
47#[serde(rename_all = "lowercase")]
48pub enum PairState {
49 Paired,
51
52 Unpaired,
54
55 RequestedByPeer,
57
58 Requested,
60}
61
62#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
64#[serde(rename_all = "lowercase")]
65pub enum DeviceType {
66 Desktop,
68
69 Laptop,
71
72 Phone,
74
75 Tablet,
77
78 Tv,
80
81 #[serde(untagged)]
83 Other(String),
84}
85
86impl fmt::Display for DeviceType {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 write!(f, "{}", serde_json::to_string(self).unwrap())
89 }
90}
91
92#[derive(Debug, Clone)]
94pub struct Link {
95 pub info: IdentityPacket,
97
98 pub pair_state: PairState,
100
101 pub(crate) send_queue: async_channel::Sender<NetworkPacket>,
103
104 pub(crate) loaded_plugins: Vec<bool>,
107}
108
109impl Link {
110 pub async fn send(&self, packet: NetworkPacket) {
112 let _ = self.send_queue.send(packet).await;
113 }
114}
115
116#[allow(missing_debug_implementations)]
201pub struct Device<
202 Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream>,
203 UdpSocket: UdpSocketImpl,
204 TcpStream: TcpStreamImpl,
205 TcpListener: TcpListenerImpl<TcpStream>,
206 TlsStream: TlsStreamImpl,
207> {
208 pub(crate) my_tcp_port: OnceCell<u16>,
209 pub(crate) links: Arc<Mutex<HashMap<String, Link>>>,
210 pub(crate) config: DeviceConfig,
211 pub(crate) plugins: Vec<Box<dyn Plugin + Send + Sync>>,
212 pub(crate) trust_handler: Arc<Mutex<dyn TrustHandler + Send + Sync>>,
213 pub(crate) host_device_id: String,
214 accepted_pair: (
215 async_channel::Sender<String>,
216 async_channel::Receiver<String>,
217 ),
218 device_connected: (
219 async_channel::Sender<String>,
220 async_channel::Receiver<String>,
221 ),
222 pub(crate) io_impl: Io,
223
224 _phantom: PhantomData<fn() -> (UdpSocket, TcpStream, TcpListener, TlsStream)>,
225}
226
227impl<
228 Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
229 UdpSocket: UdpSocketImpl + Unpin + 'static,
230 TcpStream: TcpStreamImpl + Unpin + 'static,
231 TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
232 TlsStream: TlsStreamImpl + Unpin + 'static,
233> Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>
234{
235 pub fn new<T: TrustHandler + Send + Sync + 'static>(
237 config: DeviceConfig,
238 plugins: Vec<Box<dyn Plugin + Send + Sync>>,
239 trust_handler: T,
240 io_impl: Io,
241 ) -> Self {
242 Self {
243 my_tcp_port: OnceCell::new(),
244 links: Arc::new(Mutex::new(HashMap::new())),
245 plugins,
246 trust_handler: Arc::new(Mutex::new(trust_handler)),
247 host_device_id: crate::transport::tls::extract_device_id_from_cert(
248 &Certificate::from_pem(&config.cert).unwrap(),
249 )
250 .expect("failed to extract device ID from a malformed certificate"),
251 config,
252 accepted_pair: async_channel::bounded(16),
253 device_connected: async_channel::bounded(4),
254 io_impl,
255 _phantom: PhantomData,
256 }
257 }
258
259 pub(crate) fn get_identity_packet(&self) -> NetworkPacket {
260 let incoming_capabilities = self
261 .plugins
262 .iter()
263 .flat_map(|p| p.supported_incoming_packets());
264 let outgoing_capabilities = self
265 .plugins
266 .iter()
267 .flat_map(|p| p.supported_outgoing_packets());
268
269 NetworkPacket::new(NetworkPacketBody::Identity(
270 IdentityPacket::new(
271 &self.host_device_id,
272 &self.config.name,
273 self.config.device_type.clone(),
274 *self
275 .my_tcp_port
276 .get()
277 .expect("tcp server is not started yet"),
278 )
279 .with_incoming_capabilities(incoming_capabilities)
280 .with_outgoing_capabilities(outgoing_capabilities),
281 ))
282 }
283
284 pub(crate) fn new_link(
285 &self,
286 identity_packet: IdentityPacket,
287 pair_state: PairState,
288 send_queue: async_channel::Sender<NetworkPacket>,
289 ) -> Link {
290 Link {
291 info: identity_packet,
292 pair_state,
293 send_queue,
294 loaded_plugins: (0..self.plugins.len()).map(|_| true).collect(),
295 }
296 }
297
298 async fn reload_plugins(&self, link_id: &str) {
299 for (i, plugin) in self.plugins.iter().enumerate() {
300 if self.links.lock().await.get(link_id).unwrap().loaded_plugins[i]
301 && let Err(e) = plugin
302 .on_start(self.links.lock().await.get(link_id).unwrap())
303 .await
304 {
305 log::warn!("Failed to start plugin: {e}, unloading it");
306 self.links
307 .lock()
308 .await
309 .get_mut(link_id)
310 .unwrap()
311 .loaded_plugins[i] = false;
312 }
313 }
314 }
315
316 pub fn links(&self) -> &Arc<Mutex<HashMap<String, Link>>> {
318 &self.links
319 }
320
321 pub async fn pair_with(&self, link_id: &str) {
326 if let Some(link) = self.links.lock().await.get_mut(link_id) {
327 link.pair_state = PairState::Requested;
328 link.send(NetworkPacket::pair_request(
329 self.io_impl.get_current_timestamp().await,
330 ))
331 .await;
332 }
333 }
334
335 pub async fn unpair_with(&self, link_id: &str) {
340 if let Some(link) = self.links.lock().await.get_mut(link_id) {
341 if self
343 .trust_handler
344 .lock()
345 .await
346 .get_certificate(link_id)
347 .await
348 .is_some()
349 {
350 self.trust_handler
351 .lock()
352 .await
353 .untrust_device(link_id)
354 .await;
355 }
356 link.pair_state = PairState::Unpaired;
357 link.send(NetworkPacket::unpair_request()).await;
358 }
359 }
360
361 pub async fn accept_pair(&self, link_id: &str) {
366 let _ = self.accepted_pair.0.send(link_id.to_string()).await;
367 }
368
369 pub async fn wait_for_connection(&self) -> String {
373 self.device_connected
374 .1
375 .recv()
376 .await
377 .expect("channel should not close unexpectedly")
378 }
379
380 pub fn start_arced(self: Arc<Self>) {
384 Arc::clone(&self).io_impl.start(self);
385 }
386
387 pub fn start(self) {
389 Arc::new(self).start_arced();
390 }
391
392 #[allow(clippy::too_many_lines)]
393 async fn handle_pair_packet(
394 &self,
395 device_id: &str,
396 socket: &mut TlsStream,
397 pair_packet: &PairPacket,
398 ) {
399 if pair_packet.pair {
400 let lock = self.links.lock().await;
401 let pair_state = lock.get(device_id).unwrap().pair_state;
402 drop(lock);
403
404 match pair_state {
405 PairState::Paired | PairState::RequestedByPeer => {
406 }
408 PairState::Unpaired => {
409 log::debug!("Received pair request");
410
411 let current_timestamp = self.io_impl.get_current_timestamp().await;
412 let Some(packet_timestamp) = pair_packet.timestamp else {
413 log::warn!("Pair request without timestamp, closing connection");
414 return;
415 };
416
417 if current_timestamp.abs_diff(packet_timestamp)
418 > ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS
419 {
420 log::warn!("Pair packet timestamp mismatch, check device clocks");
421 return;
422 }
423
424 self.links
425 .lock()
426 .await
427 .get_mut(device_id)
428 .unwrap()
429 .pair_state = PairState::RequestedByPeer;
430
431 log::debug!("Waiting for host to accept {device_id}");
433
434 while self
435 .accepted_pair
436 .1
437 .recv()
438 .await
439 .is_ok_and(|d| d != device_id)
440 {
441 self.io_impl
442 .sleep(core::time::Duration::from_millis(100))
443 .await;
444 }
445
446 if let Some(pem_cert) = socket
447 .get_common_state()
448 .peer_certificates()
449 .and_then(|c| c.first())
450 .and_then(|c| Certificate::from_der(c).ok())
451 .and_then(|c| c.to_pem(LineEnding::default()).ok())
452 {
453 self.trust_handler
454 .lock()
455 .await
456 .trust_device(device_id.to_string(), pem_cert.into_bytes())
457 .await;
458 } else {
459 log::warn!("Failed to get peer certificate to store");
460 return;
461 }
462
463 NetworkPacket::pair_response().write_to_socket(socket).await;
464
465 log::info!("Paired successfully with {device_id}");
466
467 self.links
468 .lock()
469 .await
470 .get_mut(device_id)
471 .unwrap()
472 .pair_state = PairState::Paired;
473 self.reload_plugins(device_id).await;
474 }
475 PairState::Requested => {
476 log::debug!("Received pair response");
477
478 if let Some(pem_cert) = socket
479 .get_common_state()
480 .peer_certificates()
481 .and_then(|c| c.first())
482 .and_then(|c| Certificate::from_der(c).ok())
483 .and_then(|c| c.to_pem(LineEnding::default()).ok())
484 {
485 self.trust_handler
486 .lock()
487 .await
488 .trust_device(device_id.to_string(), pem_cert.into_bytes())
489 .await;
490 } else {
491 log::warn!("Failed to get peer certificate to store");
492 return;
493 }
494
495 log::info!("Paired successfully with {device_id}");
496
497 self.links
498 .lock()
499 .await
500 .get_mut(device_id)
501 .unwrap()
502 .pair_state = PairState::Paired;
503 self.reload_plugins(device_id).await;
504 }
505 }
506 } else {
507 let lock = self.links.lock().await;
508 let pair_state = lock.get(device_id).unwrap().pair_state;
509 drop(lock);
510
511 if pair_state != PairState::Unpaired {
512 log::debug!("Received unpair request");
513
514 if self
516 .trust_handler
517 .lock()
518 .await
519 .get_certificate(device_id)
520 .await
521 .is_some()
522 {
523 self.trust_handler
524 .lock()
525 .await
526 .untrust_device(device_id)
527 .await;
528 }
529 self.links
530 .lock()
531 .await
532 .get_mut(device_id)
533 .unwrap()
534 .pair_state = PairState::Unpaired;
535 NetworkPacket::unpair_response()
536 .write_to_socket(socket)
537 .await;
538 }
539 }
540 }
541
542 #[allow(clippy::too_many_lines)]
543 pub(crate) async fn on_conn_established(
544 self: Arc<Self>,
545 device_id: String,
546 mut socket: TlsStream,
547 send_queue: async_channel::Receiver<NetworkPacket>,
548 ) {
549 log::info!("New connection established with {device_id}");
550
551 if self.links.lock().await.get(&device_id).unwrap().pair_state == PairState::Paired {
552 self.reload_plugins(&device_id).await;
553 }
554
555 let mut i = 0;
556 let mut buf = [0u8; crate::config::TLS_APP_BUFFER_SIZE];
557 let link_incoming_capabilities = self
558 .links
559 .lock()
560 .await
561 .get(&device_id)
562 .unwrap()
563 .info
564 .incoming_capabilities
565 .clone();
566
567 self.device_connected
568 .0
569 .send(device_id.clone())
570 .await
571 .expect("channel should not close unexpectedly");
572
573 loop {
574 let bytes_read = loop {
575 let res = {
576 let mut future1 = Box::pin(socket.read(&mut buf[i..]));
577 let mut future2 = Box::pin(send_queue.recv());
578
579 core::future::poll_fn(|cx| {
580 if let Poll::Ready(r) = future1.as_mut().poll(cx) {
581 Poll::Ready(Either::A(r))
582 } else if let Poll::Ready(Ok(packet)) = future2.as_mut().poll(cx) {
583 if packet.body.get_type() != NetworkPacketType::Pair
584 && link_incoming_capabilities
585 .as_ref()
586 .is_some_and(|c| !c.contains(&packet.body.get_type()))
587 {
588 log::warn!(
589 "Refusing to send unsupported packet type: {:?}",
590 packet.body.get_type()
591 );
592 Poll::Pending
593 } else {
594 Poll::Ready(Either::B(packet))
595 }
596 } else {
597 Poll::Pending
598 }
599 })
600 .await
601 };
602
603 match res {
604 Either::A(b) => break b,
605 Either::B(packet) => packet.write_to_socket(&mut socket).await,
606 }
607 };
608
609 if bytes_read.is_err() || *bytes_read.as_ref().unwrap() == 0 {
610 break;
611 }
612
613 let bytes_read = bytes_read.unwrap();
614 i += bytes_read;
615
616 let mut last_index = 0;
617 for end in buf[..i]
618 .iter()
619 .enumerate()
620 .filter(|(_, c)| **c == b'\n')
621 .map(|c| c.0)
622 {
623 if end == 0 {
624 continue;
625 }
626
627 let packet_buf = &buf[last_index..end];
628 last_index = end + 1;
629
630 let packet = match NetworkPacket::try_read_from(packet_buf) {
631 Ok(p) => p,
632 Err(e) => {
633 log::warn!(
634 "Error while parsing incoming JSON packet: {e}\nOriginal packet:\n{}",
635 core::str::from_utf8(packet_buf)
636 .expect("packet is a valid UTF-8 string")
637 );
638 continue;
639 }
640 };
641
642 if let NetworkPacketBody::Pair(pair_packet) = &packet.body {
644 self.handle_pair_packet(&device_id, &mut socket, pair_packet)
645 .await;
646 }
647
648 if self.links.lock().await.get(&device_id).unwrap().pair_state == PairState::Paired
650 {
651 let packet_type = packet.body.get_type();
652
653 for (i, plugin) in self.plugins.iter().enumerate() {
654 if self
655 .links
656 .lock()
657 .await
658 .get(&device_id)
659 .unwrap()
660 .loaded_plugins[i]
661 && plugin.supported_incoming_packets().contains(&packet_type)
662 && let Err(e) = plugin
663 .on_packet_received(
664 &packet,
665 self.links.lock().await.get(&device_id).unwrap(),
666 )
667 .await
668 {
669 log::warn!("Error when handling a received packet: {e}");
670 }
671 }
672 }
673 }
674
675 i = 0;
676 }
677
678 log::info!("Disconnected from {device_id}");
679 self.links.lock().await.remove(&device_id);
680 }
681}