Skip to main content

kdeconnect_proto/
device.rs

1//! Define structures related to connected devices.
2//!
3//! - The [`Device`] structure represents the host device
4//! - The [`Link`] structure represents a peer device connected to the host device
5use core::{fmt, marker::PhantomData, task::Poll};
6
7#[cfg(feature = "std")]
8use std::sync::Arc;
9
10#[cfg(not(feature = "std"))]
11use alloc::{
12    boxed::Box,
13    string::{String, ToString},
14    sync::Arc,
15    vec::Vec,
16};
17
18use async_lock::{Mutex, OnceCell};
19use hashbrown::HashMap;
20use x509_cert::{
21    Certificate,
22    der::{Decode, DecodePem, EncodePem, pem::LineEnding},
23};
24
25use crate::{
26    config::DeviceConfig,
27    io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
28    packet::{
29        NetworkPacket, NetworkPacketBody, NetworkPacketType, identity::IdentityPacket,
30        pair::PairPacket,
31    },
32    plugin::Plugin,
33    trust::TrustHandler,
34};
35
36use serde::{Deserialize, Serialize};
37
/// Maximum tolerated gap, in seconds, between the host's current timestamp and the timestamp
/// carried by an incoming pair request. Requests outside of this window are rejected.
const ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS: u64 = 1800; // 30 min
39
/// A value that is one of two alternatives.
///
/// Used internally to report which of two concurrently polled futures completed first.
enum Either<A, B> {
    /// The first alternative.
    A(A),
    /// The second alternative.
    B(B),
}
44
/// The state of the pairing between the host device and a peer device.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum PairState {
    /// The host device and the linked device are paired.
    Paired,

    /// The host device and the linked device are not paired and **no** pair request has been
    /// initiated.
    Unpaired,

    /// The peer device has requested pairing and is waiting for the approval of the host.
    RequestedByPeer,

    /// The host device has requested pairing and is waiting for the approval of the peer
    /// device.
    Requested,
}
61
/// The physical type of the device.
///
/// The named variants are (de)serialized under their lowercase name.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum DeviceType {
    /// The device is a desktop.
    Desktop,

    /// The device is a laptop.
    Laptop,

    /// The device is a phone.
    Phone,

    /// The device is a tablet.
    Tablet,

    /// The device is a TV.
    Tv,

    /// The device is of another type, carried as the raw string describing it.
    #[serde(untagged)]
    Other(String),
}
85
86impl fmt::Display for DeviceType {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        write!(f, "{}", serde_json::to_string(self).unwrap())
89    }
90}
91
/// A structure representing a peer device linked to this device but not necessarily paired.
#[derive(Debug, Clone)]
pub struct Link {
    /// Device information about this linked remote device.
    pub info: IdentityPacket,

    /// The pairing state between the current device and the linked remote device.
    pub pair_state: PairState,

    /// The queue for sending [`NetworkPacket`] to the linked remote device.
    pub(crate) send_queue: async_channel::Sender<NetworkPacket>,

    /// One flag per entry of the [`Device::plugins`] field, at the same index: `true` while the
    /// plugin is usable for this device. A flag is cleared when the [`Plugin::on_start`] method
    /// of the matching plugin returns an error.
    pub(crate) loaded_plugins: Vec<bool>,
}
108
109impl Link {
110    /// Send a [`NetworkPacket`] to the linked device.
111    pub async fn send(&self, packet: NetworkPacket) {
112        let _ = self.send_queue.send(packet).await;
113    }
114}
115
116/// Structure representing the host device.
117///
/// To make one, you need a configuration (see [`DeviceConfig`]), the list of plugins to use,
/// a trust handler, i.e. a structure implementing [`TrustHandler`], and an IO implementation
/// of your choice (you can use [`TokioIoImpl`](`crate::io::TokioIoImpl`) as a default).
122///
123/// # Example
124///
125/// This code snippet launches a new device with no plugin defined and the trust handler
126/// defined in the example of [`TrustHandler`], assuming that an X.509 certificate exists at
127/// `cert.pem` and an X.509 private key exists at `private_key.pem` (see [`DeviceConfig`] to learn
128/// how to create them).
129///
130/// To learn how to make plugins, check [`Plugin`].
131///
132/// ```no_run
133/// use std::{fs, collections::HashMap, path::{Path, PathBuf}};
134/// use kdeconnect_proto::{
135///     config::DeviceConfig,
136///     device::{Device, DeviceType},
137///     trust::TrustHandler,
138///     io::TokioIoImpl,
139/// };
140///
141/// struct TrustHandlerImpl {
142///    path: PathBuf,
143///    trusted_devices: HashMap<String, Vec<u8>>,
144/// }
145///
146/// impl TrustHandlerImpl {
147///    pub fn new<P: AsRef<Path>>(path: P) -> Self {
148///        let path = path.as_ref();
149///
150///        let trusted_devices = if path.exists() {
151///             HashMap::from_iter(fs::read_dir(path).unwrap().filter_map(Result::ok).map(|f| {
152///                 let device_id = f.path().file_stem().unwrap().to_string_lossy().to_string();
153///                 let cert = fs::read(f.path()).expect("failed to read certificate");
154///                 (device_id, cert)
155///             }))
156///         } else {
157///             fs::create_dir_all(path).expect("failed to create directory for trusted devices");
158///             HashMap::new()
159///         };
160///
161///        Self {
162///            path: path.to_path_buf(),
163///            trusted_devices,
164///        }
165///    }
166/// }
167///
168/// #[kdeconnect_proto::async_trait]
169/// impl TrustHandler for TrustHandlerImpl {
170///     async fn trust_device(&mut self, device_id: String, cert: Vec<u8>) {
171///        fs::write(self.path.join(device_id.clone() + ".pem"), &cert).unwrap();
172///        self.trusted_devices.insert(device_id, cert);
173///    }
174///
175///    async fn untrust_device(&mut self, device_id: &str) {
176///        fs::remove_file(self.path.join(device_id.to_string() + ".pem")).unwrap();
177///        self.trusted_devices.remove(device_id);
178///    }
179///
180///    async fn get_certificate(&mut self, device_id: &str) -> Option<&[u8]> {
181///        self.trusted_devices.get(device_id).map(|v| &**v)
182///    }
183/// }
184///
185/// let config = DeviceConfig {
186///     name: String::from("kdeconnect client"),
187///     device_type: DeviceType::Desktop,
188///     cert: fs::read("cert.pem").expect("failed to read certificate"),
189///     private_key: fs::read("private_key.pem").expect("failed to read private key"),
190/// };
191///
192/// let device = Device::new(config, vec![], TrustHandlerImpl::new("trusted_devices"), TokioIoImpl);
193/// device.start();
194/// ```
195///
196/// # Pairing
197///
198/// You need to manually accept each pair request coming from a peer device using the
199/// [`Device::accept_pair`] method. You can use a background asynchronous task to do that.
#[allow(missing_debug_implementations)]
pub struct Device<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream>,
    UdpSocket: UdpSocketImpl,
    TcpStream: TcpStreamImpl,
    TcpListener: TcpListenerImpl<TcpStream>,
    TlsStream: TlsStreamImpl,
> {
    /// TCP port advertised in the identity packet; it must be set before an identity packet
    /// can be built.
    pub(crate) my_tcp_port: OnceCell<u16>,
    /// Currently linked peer devices, keyed by device ID.
    pub(crate) links: Arc<Mutex<HashMap<String, Link>>>,
    /// The configuration this device was created with.
    pub(crate) config: DeviceConfig,
    /// The plugins available to every link; [`Link::loaded_plugins`] is indexed like this list.
    pub(crate) plugins: Vec<Box<dyn Plugin + Send + Sync>>,
    /// Storage for the certificates of trusted (paired) peer devices.
    pub(crate) trust_handler: Arc<Mutex<dyn TrustHandler + Send + Sync>>,
    /// The ID of the host device, extracted from the configured certificate.
    pub(crate) host_device_id: String,
    /// Channel carrying the IDs of the peer devices whose pair request was accepted by the host.
    accepted_pair: (
        async_channel::Sender<String>,
        async_channel::Receiver<String>,
    ),
    /// Channel carrying the IDs of the peer devices that have just connected.
    device_connected: (
        async_channel::Sender<String>,
        async_channel::Receiver<String>,
    ),
    /// The IO implementation used by this device.
    pub(crate) io_impl: Io,

    // Ties the IO type parameters to the structure without storing a value of these types.
    _phantom: PhantomData<fn() -> (UdpSocket, TcpStream, TcpListener, TlsStream)>,
}
226
227impl<
228    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
229    UdpSocket: UdpSocketImpl + Unpin + 'static,
230    TcpStream: TcpStreamImpl + Unpin + 'static,
231    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
232    TlsStream: TlsStreamImpl + Unpin + 'static,
233> Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>
234{
235    /// Make a new [`Device`].
236    pub fn new<T: TrustHandler + Send + Sync + 'static>(
237        config: DeviceConfig,
238        plugins: Vec<Box<dyn Plugin + Send + Sync>>,
239        trust_handler: T,
240        io_impl: Io,
241    ) -> Self {
242        Self {
243            my_tcp_port: OnceCell::new(),
244            links: Arc::new(Mutex::new(HashMap::new())),
245            plugins,
246            trust_handler: Arc::new(Mutex::new(trust_handler)),
247            host_device_id: crate::transport::tls::extract_device_id_from_cert(
248                &Certificate::from_pem(&config.cert).unwrap(),
249            )
250            .expect("failed to extract device ID from a malformed certificate"),
251            config,
252            accepted_pair: async_channel::bounded(16),
253            device_connected: async_channel::bounded(4),
254            io_impl,
255            _phantom: PhantomData,
256        }
257    }
258
259    pub(crate) fn get_identity_packet(&self) -> NetworkPacket {
260        let incoming_capabilities = self
261            .plugins
262            .iter()
263            .flat_map(|p| p.supported_incoming_packets());
264        let outgoing_capabilities = self
265            .plugins
266            .iter()
267            .flat_map(|p| p.supported_outgoing_packets());
268
269        NetworkPacket::new(NetworkPacketBody::Identity(
270            IdentityPacket::new(
271                &self.host_device_id,
272                &self.config.name,
273                self.config.device_type.clone(),
274                *self
275                    .my_tcp_port
276                    .get()
277                    .expect("tcp server is not started yet"),
278            )
279            .with_incoming_capabilities(incoming_capabilities)
280            .with_outgoing_capabilities(outgoing_capabilities),
281        ))
282    }
283
284    pub(crate) fn new_link(
285        &self,
286        identity_packet: IdentityPacket,
287        pair_state: PairState,
288        send_queue: async_channel::Sender<NetworkPacket>,
289    ) -> Link {
290        Link {
291            info: identity_packet,
292            pair_state,
293            send_queue,
294            loaded_plugins: (0..self.plugins.len()).map(|_| true).collect(),
295        }
296    }
297
298    async fn reload_plugins(&self, link_id: &str) {
299        for (i, plugin) in self.plugins.iter().enumerate() {
300            if self.links.lock().await.get(link_id).unwrap().loaded_plugins[i]
301                && let Err(e) = plugin
302                    .on_start(self.links.lock().await.get(link_id).unwrap())
303                    .await
304            {
305                log::warn!("Failed to start plugin: {e}, unloading it");
306                self.links
307                    .lock()
308                    .await
309                    .get_mut(link_id)
310                    .unwrap()
311                    .loaded_plugins[i] = false;
312            }
313        }
314    }
315
316    /// Get an arc mutex'd map of `(Device ID, Link)` pairs.
317    pub fn links(&self) -> &Arc<Mutex<HashMap<String, Link>>> {
318        &self.links
319    }
320
321    /// Pair with a peer device.
322    ///
323    /// In most cases, the pair state will be advanced from `PairState::Unpaired` to
324    /// `PairState::Requested`.
325    pub async fn pair_with(&self, link_id: &str) {
326        if let Some(link) = self.links.lock().await.get_mut(link_id) {
327            link.pair_state = PairState::Requested;
328            link.send(NetworkPacket::pair_request(
329                self.io_impl.get_current_timestamp().await,
330            ))
331            .await;
332        }
333    }
334
335    /// Unpair from a peer device.
336    ///
337    /// In most cases, the pair state will be advanced from `PairState::Paired` to
338    /// `PairState::Unpaired`.
339    pub async fn unpair_with(&self, link_id: &str) {
340        if let Some(link) = self.links.lock().await.get_mut(link_id) {
341            // Only untrust the device if it was trusted before
342            if self
343                .trust_handler
344                .lock()
345                .await
346                .get_certificate(link_id)
347                .await
348                .is_some()
349            {
350                self.trust_handler
351                    .lock()
352                    .await
353                    .untrust_device(link_id)
354                    .await;
355            }
356            link.pair_state = PairState::Unpaired;
357            link.send(NetworkPacket::unpair_request()).await;
358        }
359    }
360
361    /// Accept the pairing with a peer device which has already requested pairing.
362    ///
363    /// In most cases, the pair state will be advanced from `PairState::RequestedByPeer` to
364    /// `PairState::Paired`.
365    pub async fn accept_pair(&self, link_id: &str) {
366        let _ = self.accepted_pair.0.send(link_id.to_string()).await;
367    }
368
369    /// Wait for a peer device to connect and return its device ID.
370    ///
371    /// The connected device may **or may not** be paired.
372    pub async fn wait_for_connection(&self) -> String {
373        self.device_connected
374            .1
375            .recv()
376            .await
377            .expect("channel should not close unexpectedly")
378    }
379
380    /// Start an internal task responsible for handling connections and managing state.
381    ///
382    /// This function takes an Arc of self for easier usage of the [`Device`] structure.
383    pub fn start_arced(self: Arc<Self>) {
384        Arc::clone(&self).io_impl.start(self);
385    }
386
387    /// Start an internal task responsible for handling connections and managing state.
388    pub fn start(self) {
389        Arc::new(self).start_arced();
390    }
391
392    #[allow(clippy::too_many_lines)]
393    async fn handle_pair_packet(
394        &self,
395        device_id: &str,
396        socket: &mut TlsStream,
397        pair_packet: &PairPacket,
398    ) {
399        if pair_packet.pair {
400            let lock = self.links.lock().await;
401            let pair_state = lock.get(device_id).unwrap().pair_state;
402            drop(lock);
403
404            match pair_state {
405                PairState::Paired | PairState::RequestedByPeer => {
406                    // Ignore packet
407                }
408                PairState::Unpaired => {
409                    log::debug!("Received pair request");
410
411                    let current_timestamp = self.io_impl.get_current_timestamp().await;
412                    let Some(packet_timestamp) = pair_packet.timestamp else {
413                        log::warn!("Pair request without timestamp, closing connection");
414                        return;
415                    };
416
417                    if current_timestamp.abs_diff(packet_timestamp)
418                        > ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS
419                    {
420                        log::warn!("Pair packet timestamp mismatch, check device clocks");
421                        return;
422                    }
423
424                    self.links
425                        .lock()
426                        .await
427                        .get_mut(device_id)
428                        .unwrap()
429                        .pair_state = PairState::RequestedByPeer;
430
431                    // TODO: add a timeout
432                    log::debug!("Waiting for host to accept {device_id}");
433
434                    while self
435                        .accepted_pair
436                        .1
437                        .recv()
438                        .await
439                        .is_ok_and(|d| d != device_id)
440                    {
441                        self.io_impl
442                            .sleep(core::time::Duration::from_millis(100))
443                            .await;
444                    }
445
446                    if let Some(pem_cert) = socket
447                        .get_common_state()
448                        .peer_certificates()
449                        .and_then(|c| c.first())
450                        .and_then(|c| Certificate::from_der(c).ok())
451                        .and_then(|c| c.to_pem(LineEnding::default()).ok())
452                    {
453                        self.trust_handler
454                            .lock()
455                            .await
456                            .trust_device(device_id.to_string(), pem_cert.into_bytes())
457                            .await;
458                    } else {
459                        log::warn!("Failed to get peer certificate to store");
460                        return;
461                    }
462
463                    NetworkPacket::pair_response().write_to_socket(socket).await;
464
465                    log::info!("Paired successfully with {device_id}");
466
467                    self.links
468                        .lock()
469                        .await
470                        .get_mut(device_id)
471                        .unwrap()
472                        .pair_state = PairState::Paired;
473                    self.reload_plugins(device_id).await;
474                }
475                PairState::Requested => {
476                    log::debug!("Received pair response");
477
478                    if let Some(pem_cert) = socket
479                        .get_common_state()
480                        .peer_certificates()
481                        .and_then(|c| c.first())
482                        .and_then(|c| Certificate::from_der(c).ok())
483                        .and_then(|c| c.to_pem(LineEnding::default()).ok())
484                    {
485                        self.trust_handler
486                            .lock()
487                            .await
488                            .trust_device(device_id.to_string(), pem_cert.into_bytes())
489                            .await;
490                    } else {
491                        log::warn!("Failed to get peer certificate to store");
492                        return;
493                    }
494
495                    log::info!("Paired successfully with {device_id}");
496
497                    self.links
498                        .lock()
499                        .await
500                        .get_mut(device_id)
501                        .unwrap()
502                        .pair_state = PairState::Paired;
503                    self.reload_plugins(device_id).await;
504                }
505            }
506        } else {
507            let lock = self.links.lock().await;
508            let pair_state = lock.get(device_id).unwrap().pair_state;
509            drop(lock);
510
511            if pair_state != PairState::Unpaired {
512                log::debug!("Received unpair request");
513
514                // Only untrust the device if it was trusted before
515                if self
516                    .trust_handler
517                    .lock()
518                    .await
519                    .get_certificate(device_id)
520                    .await
521                    .is_some()
522                {
523                    self.trust_handler
524                        .lock()
525                        .await
526                        .untrust_device(device_id)
527                        .await;
528                }
529                self.links
530                    .lock()
531                    .await
532                    .get_mut(device_id)
533                    .unwrap()
534                    .pair_state = PairState::Unpaired;
535                NetworkPacket::unpair_response()
536                    .write_to_socket(socket)
537                    .await;
538            }
539        }
540    }
541
542    #[allow(clippy::too_many_lines)]
543    pub(crate) async fn on_conn_established(
544        self: Arc<Self>,
545        device_id: String,
546        mut socket: TlsStream,
547        send_queue: async_channel::Receiver<NetworkPacket>,
548    ) {
549        log::info!("New connection established with {device_id}");
550
551        if self.links.lock().await.get(&device_id).unwrap().pair_state == PairState::Paired {
552            self.reload_plugins(&device_id).await;
553        }
554
555        let mut i = 0;
556        let mut buf = [0u8; crate::config::TLS_APP_BUFFER_SIZE];
557        let link_incoming_capabilities = self
558            .links
559            .lock()
560            .await
561            .get(&device_id)
562            .unwrap()
563            .info
564            .incoming_capabilities
565            .clone();
566
567        self.device_connected
568            .0
569            .send(device_id.clone())
570            .await
571            .expect("channel should not close unexpectedly");
572
573        loop {
574            let bytes_read = loop {
575                let res = {
576                    let mut future1 = Box::pin(socket.read(&mut buf[i..]));
577                    let mut future2 = Box::pin(send_queue.recv());
578
579                    core::future::poll_fn(|cx| {
580                        if let Poll::Ready(r) = future1.as_mut().poll(cx) {
581                            Poll::Ready(Either::A(r))
582                        } else if let Poll::Ready(Ok(packet)) = future2.as_mut().poll(cx) {
583                            if packet.body.get_type() != NetworkPacketType::Pair
584                                && link_incoming_capabilities
585                                    .as_ref()
586                                    .is_some_and(|c| !c.contains(&packet.body.get_type()))
587                            {
588                                log::warn!(
589                                    "Refusing to send unsupported packet type: {:?}",
590                                    packet.body.get_type()
591                                );
592                                Poll::Pending
593                            } else {
594                                Poll::Ready(Either::B(packet))
595                            }
596                        } else {
597                            Poll::Pending
598                        }
599                    })
600                    .await
601                };
602
603                match res {
604                    Either::A(b) => break b,
605                    Either::B(packet) => packet.write_to_socket(&mut socket).await,
606                }
607            };
608
609            if bytes_read.is_err() || *bytes_read.as_ref().unwrap() == 0 {
610                break;
611            }
612
613            let bytes_read = bytes_read.unwrap();
614            i += bytes_read;
615
616            let mut last_index = 0;
617            for end in buf[..i]
618                .iter()
619                .enumerate()
620                .filter(|(_, c)| **c == b'\n')
621                .map(|c| c.0)
622            {
623                if end == 0 {
624                    continue;
625                }
626
627                let packet_buf = &buf[last_index..end];
628                last_index = end + 1;
629
630                let packet = match NetworkPacket::try_read_from(packet_buf) {
631                    Ok(p) => p,
632                    Err(e) => {
633                        log::warn!(
634                            "Error while parsing incoming JSON packet: {e}\nOriginal packet:\n{}",
635                            core::str::from_utf8(packet_buf)
636                                .expect("packet is a valid UTF-8 string")
637                        );
638                        continue;
639                    }
640                };
641
642                // Special handling for the pairing packet
643                if let NetworkPacketBody::Pair(pair_packet) = &packet.body {
644                    self.handle_pair_packet(&device_id, &mut socket, pair_packet)
645                        .await;
646                }
647
648                // Prevent receiving other packets if the device is not paired
649                if self.links.lock().await.get(&device_id).unwrap().pair_state == PairState::Paired
650                {
651                    let packet_type = packet.body.get_type();
652
653                    for (i, plugin) in self.plugins.iter().enumerate() {
654                        if self
655                            .links
656                            .lock()
657                            .await
658                            .get(&device_id)
659                            .unwrap()
660                            .loaded_plugins[i]
661                            && plugin.supported_incoming_packets().contains(&packet_type)
662                            && let Err(e) = plugin
663                                .on_packet_received(
664                                    &packet,
665                                    self.links.lock().await.get(&device_id).unwrap(),
666                                )
667                                .await
668                        {
669                            log::warn!("Error when handling a received packet: {e}");
670                        }
671                    }
672                }
673            }
674
675            i = 0;
676        }
677
678        log::info!("Disconnected from {device_id}");
679        self.links.lock().await.remove(&device_id);
680    }
681}