use core::{fmt, marker::PhantomData, task::Poll};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{
boxed::Box,
string::{String, ToString},
sync::Arc,
vec::Vec,
};
use async_lock::{Mutex, OnceCell};
use hashbrown::HashMap;
use x509_cert::{
Certificate,
der::{Decode, DecodePem, EncodePem, pem::LineEnding},
};
use crate::{
config::DeviceConfig,
io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
packet::{
NetworkPacket, NetworkPacketBody, NetworkPacketType, identity::IdentityPacket,
pair::PairPacket,
},
plugin::Plugin,
trust::TrustHandler,
};
use serde::{Deserialize, Serialize};
/// Maximum difference, in seconds, tolerated between the local clock and the timestamp carried by
/// an incoming pair request (1800 s = 30 minutes); requests outside this window are not accepted.
const ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS: u64 = 1800;
/// Holds one of two values. The connection loop uses it to report which of two raced futures
/// completed first: `A` for the socket read, `B` for an outgoing packet from the send queue.
enum Either<A, B> {
A(A),
B(B),
}
/// Pairing status of a [`Link`] with a remote device.
///
/// Serialized with the variant name lowercased (e.g. `RequestedByPeer` -> `"requestedbypeer"`).
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum PairState {
/// The peer is trusted; incoming packets are dispatched to plugins only in this state.
Paired,
/// No pairing exists (never paired, or pairing was revoked).
Unpaired,
/// The peer sent a pair request that is waiting for the local user's acceptance.
RequestedByPeer,
/// We sent a pair request to the peer and are waiting for its response.
Requested,
}
/// Kind of device, as passed into this device's identity packet.
///
/// Known kinds serialize as lowercase strings (e.g. `"phone"`); other strings are carried by
/// [`DeviceType::Other`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum DeviceType {
Desktop,
Laptop,
Phone,
Tablet,
Tv,
/// Any other device type, serialized as the bare string it holds.
// `untagged` variants are tried after the named ones, so this must remain the last variant.
#[serde(untagged)]
Other(String),
}
impl fmt::Display for DeviceType {
    /// Writes the JSON serialization of the device type, i.e. a quoted string such as `"phone"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Serializing a unit/string variant to JSON cannot fail.
        let json = serde_json::to_string(self).unwrap();
        f.write_str(&json)
    }
}
/// A connected remote device together with the state needed to talk to it.
#[derive(Debug, Clone)]
pub struct Link {
/// Identity packet describing the remote device (its incoming capabilities restrict which
/// packet types are sent to it).
pub info: IdentityPacket,
/// Current pairing status with the remote device.
pub pair_state: PairState,
/// Sender side of the outgoing-packet queue for this device's connection.
pub(crate) send_queue: async_channel::Sender<NetworkPacket>,
/// One flag per entry of the device's plugin list (same indices): whether that plugin is active
/// for this link. A plugin is switched off when its `on_start` fails.
pub(crate) loaded_plugins: Vec<bool>,
}
impl Link {
    /// Queues `packet` for delivery to the remote device.
    ///
    /// Best-effort: if the queue's channel has been closed the packet is silently dropped.
    pub async fn send(&self, packet: NetworkPacket) {
        // The send only fails when the channel is closed; there is nobody to report that to.
        self.send_queue.send(packet).await.ok();
    }
}
/// A local device endpoint: configuration, plugins, trusted certificates and the set of links to
/// connected remote devices.
///
/// Generic over the I/O backend `Io` and the socket types it works with.
// Neither the generic backend/socket types nor the `dyn Plugin` / `dyn TrustHandler` objects are
// bound by `Debug`, so no `Debug` impl can be derived.
#[allow(missing_debug_implementations)]
pub struct Device<
Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream>,
UdpSocket: UdpSocketImpl,
TcpStream: TcpStreamImpl,
TcpListener: TcpListenerImpl<TcpStream>,
TlsStream: TlsStreamImpl,
> {
/// Port of the local TCP server; set once the server has started.
pub(crate) my_tcp_port: OnceCell<u16>,
/// Links to connected devices, keyed by device id.
pub(crate) links: Arc<Mutex<HashMap<String, Link>>>,
/// Local configuration (name, device type, PEM certificate).
pub(crate) config: DeviceConfig,
/// Registered plugins; their indices match `Link::loaded_plugins`.
pub(crate) plugins: Vec<Box<dyn Plugin + Send + Sync>>,
/// Store of trusted peer certificates, keyed by device id.
pub(crate) trust_handler: Arc<Mutex<dyn TrustHandler + Send + Sync>>,
/// This device's id, extracted from its own certificate.
pub(crate) host_device_id: String,
/// Channel of device ids whose pair request the local user accepted.
accepted_pair: (
async_channel::Sender<String>,
async_channel::Receiver<String>,
),
/// Channel of device ids for newly established connections.
device_connected: (
async_channel::Sender<String>,
async_channel::Receiver<String>,
),
/// I/O backend (networking, current time, sleeping).
pub(crate) io_impl: Io,
/// Uses the socket type parameters without storing them; `fn() -> (...)` keeps the struct
/// `Send`/`Sync` regardless of those types.
_phantom: PhantomData<fn() -> (UdpSocket, TcpStream, TcpListener, TlsStream)>,
}
impl<
Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
UdpSocket: UdpSocketImpl + Unpin + 'static,
TcpStream: TcpStreamImpl + Unpin + 'static,
TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
TlsStream: TlsStreamImpl + Unpin + 'static,
> Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>
{
/// Creates a device from its configuration, plugins, trust handler and I/O backend.
///
/// The device id is derived from the PEM certificate in `config.cert`. Call
/// [`Device::start`] (or [`Device::start_arced`]) to hand the device to the I/O backend.
///
/// # Panics
///
/// Panics if `config.cert` is not a valid PEM certificate, or if no device id can be
/// extracted from it.
pub fn new<T: TrustHandler + Send + Sync + 'static>(
    config: DeviceConfig,
    plugins: Vec<Box<dyn Plugin + Send + Sync>>,
    trust_handler: T,
    io_impl: Io,
) -> Self {
    // The certificate is user-supplied configuration: fail with an explicit message rather
    // than a bare `unwrap` panic.
    let certificate = Certificate::from_pem(&config.cert)
        .expect("config.cert is not a valid PEM-encoded certificate");
    let host_device_id = crate::transport::tls::extract_device_id_from_cert(&certificate)
        .expect("failed to extract device ID from a malformed certificate");
    Self {
        my_tcp_port: OnceCell::new(),
        links: Arc::new(Mutex::new(HashMap::new())),
        plugins,
        trust_handler: Arc::new(Mutex::new(trust_handler)),
        host_device_id,
        config,
        // Bounded queues: senders wait once this many notifications are pending.
        accepted_pair: async_channel::bounded(16),
        device_connected: async_channel::bounded(4),
        io_impl,
        _phantom: PhantomData,
    }
}
/// Builds this device's identity packet.
///
/// It carries the device id, configured name and type, the local TCP server port, and every
/// incoming/outgoing packet type declared by the registered plugins.
///
/// # Panics
///
/// Panics if the TCP server has not been started yet (its port is not known).
pub(crate) fn get_identity_packet(&self) -> NetworkPacket {
    let tcp_port = *self
        .my_tcp_port
        .get()
        .expect("tcp server is not started yet");
    let identity = IdentityPacket::new(
        &self.host_device_id,
        &self.config.name,
        self.config.device_type.clone(),
        tcp_port,
    )
    .with_incoming_capabilities(
        self.plugins
            .iter()
            .flat_map(|plugin| plugin.supported_incoming_packets()),
    )
    .with_outgoing_capabilities(
        self.plugins
            .iter()
            .flat_map(|plugin| plugin.supported_outgoing_packets()),
    );
    NetworkPacket::new(NetworkPacketBody::Identity(identity))
}
/// Creates a [`Link`] for a remote device, with every registered plugin initially enabled.
pub(crate) fn new_link(
    &self,
    identity_packet: IdentityPacket,
    pair_state: PairState,
    send_queue: async_channel::Sender<NetworkPacket>,
) -> Link {
    // One "enabled" flag per plugin, indexed like `self.plugins`.
    let loaded_plugins = core::iter::repeat_n(true, self.plugins.len()).collect();
    Link {
        info: identity_packet,
        pair_state,
        send_queue,
        loaded_plugins,
    }
}
async fn reload_plugins(&self, link_id: &str) {
for (i, plugin) in self.plugins.iter().enumerate() {
if self.links.lock().await.get(link_id).unwrap().loaded_plugins[i]
&& let Err(e) = plugin
.on_start(self.links.lock().await.get(link_id).unwrap())
.await
{
log::warn!("Failed to start plugin: {e}, unloading it");
self.links
.lock()
.await
.get_mut(link_id)
.unwrap()
.loaded_plugins[i] = false;
}
}
}
/// Shared map of links to connected devices, keyed by device id.
pub fn links(&self) -> &Arc<Mutex<HashMap<String, Link>>> {
    &self.links
}
/// Asks the device `link_id` to pair with this device.
///
/// The link is moved to [`PairState::Requested`] and a timestamped pair request is queued for
/// it. Does nothing if no link with that id exists.
pub async fn pair_with(&self, link_id: &str) {
    let timestamp = self.io_impl.get_current_timestamp().await;
    // Change the state under the lock, but queue the packet only after releasing it: the send
    // can wait for queue space, and the connection task that drains the queue also locks `links`.
    let send_queue = {
        let mut links = self.links.lock().await;
        let Some(link) = links.get_mut(link_id) else {
            return;
        };
        link.pair_state = PairState::Requested;
        link.send_queue.clone()
    };
    // Best-effort, like `Link::send`: a closed queue means the connection is gone.
    let _ = send_queue
        .send(NetworkPacket::pair_request(timestamp))
        .await;
}
/// Revokes pairing with the device `link_id`.
///
/// Forgets its stored certificate (if any), moves the link to [`PairState::Unpaired`] and queues
/// an unpair request for it. Does nothing if no link with that id exists.
pub async fn unpair_with(&self, link_id: &str) {
    let send_queue = {
        let mut links = self.links.lock().await;
        let Some(link) = links.get_mut(link_id) else {
            return;
        };
        {
            // One guard for both calls, so the certificate cannot change between check and removal.
            let mut trust_handler = self.trust_handler.lock().await;
            if trust_handler.get_certificate(link_id).await.is_some() {
                trust_handler.untrust_device(link_id).await;
            }
        }
        link.pair_state = PairState::Unpaired;
        link.send_queue.clone()
    };
    // Queue the packet only after `links` is released: the send can wait for queue space, and
    // the connection task that drains the queue also locks `links`.
    let _ = send_queue.send(NetworkPacket::unpair_request()).await;
}
/// Accepts a pending pair request from the device `link_id`.
///
/// The id is handed to the connection task waiting for that device's acceptance. Waits while the
/// bounded acceptance queue is full; a failed send is ignored.
pub async fn accept_pair(&self, link_id: &str) {
    let (accepted_tx, _) = &self.accepted_pair;
    accepted_tx.send(String::from(link_id)).await.ok();
}
/// Waits for the next established connection and returns the remote device's id.
///
/// # Panics
///
/// Panics if the notification channel is closed. Both of its ends are owned by `self`, so this
/// is not expected to happen.
pub async fn wait_for_connection(&self) -> String {
    let (_, connected_rx) = &self.device_connected;
    connected_rx
        .recv()
        .await
        .expect("channel should not close unexpectedly")
}
/// Hands this shared device to its I/O backend's `start`.
pub fn start_arced(self: Arc<Self>) {
    // A second handle is needed to borrow `io_impl` while `self` is moved into the call.
    let device = Arc::clone(&self);
    device.io_impl.start(self);
}
/// Wraps the device in an [`Arc`] and starts it; see [`Device::start_arced`].
pub fn start(self) {
    let device = Arc::new(self);
    device.start_arced();
}
/// Processes a pair/unpair packet received from `device_id` over `socket`.
///
/// * `pair: true` while [`PairState::Unpaired`]: a pair request from the peer. Its timestamp must
///   be within [`ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS`] of the local clock. The link then
///   waits in [`PairState::RequestedByPeer`] until the local user calls
///   [`Device::accept_pair`]. After that the peer certificate is trusted, a pair response is
///   sent and the link becomes paired. If that cannot be completed, the link returns to
///   `Unpaired`.
/// * `pair: true` while [`PairState::Requested`]: the peer accepted our request; its certificate
///   is trusted and the link becomes paired.
/// * `pair: true` while `Paired` or `RequestedByPeer` is ignored.
/// * `pair: false` in any state but `Unpaired`: the stored certificate (if any) is forgotten, the
///   link becomes `Unpaired` and an unpair response is written to `socket`.
///
/// Waiting for the local acceptance blocks the caller (the connection's read loop) until it
/// arrives.
///
/// # Panics
///
/// Panics if the link for `device_id` is missing from `links`.
async fn handle_pair_packet(
    &self,
    device_id: &str,
    socket: &mut TlsStream,
    pair_packet: &PairPacket,
) {
    let pair_state = self.links.lock().await.get(device_id).unwrap().pair_state;
    if !pair_packet.pair {
        if pair_state != PairState::Unpaired {
            log::debug!("Received unpair request");
            {
                // One guard for both calls, so the certificate cannot change in between.
                let mut trust_handler = self.trust_handler.lock().await;
                if trust_handler.get_certificate(device_id).await.is_some() {
                    trust_handler.untrust_device(device_id).await;
                }
            }
            self.links
                .lock()
                .await
                .get_mut(device_id)
                .unwrap()
                .pair_state = PairState::Unpaired;
            NetworkPacket::unpair_response()
                .write_to_socket(socket)
                .await;
        }
        return;
    }
    match pair_state {
        // Already paired, or this request already awaits local acceptance.
        PairState::Paired | PairState::RequestedByPeer => {}
        PairState::Unpaired => {
            log::debug!("Received pair request");
            let current_timestamp = self.io_impl.get_current_timestamp().await;
            let Some(packet_timestamp) = pair_packet.timestamp else {
                log::warn!("Pair request without timestamp, ignoring it");
                return;
            };
            if current_timestamp.abs_diff(packet_timestamp)
                > ALLOWED_TIMESTAMP_TIME_DIFFERENCE_SECONDS
            {
                log::warn!("Pair packet timestamp mismatch, check device clocks");
                return;
            }
            self.links
                .lock()
                .await
                .get_mut(device_id)
                .unwrap()
                .pair_state = PairState::RequestedByPeer;
            log::debug!("Waiting for host to accept {device_id}");
            if !(self.await_pair_acceptance(device_id).await
                && self.trust_peer_certificate(device_id, socket).await)
            {
                // Otherwise the link would stay in `RequestedByPeer` and ignore every future
                // pair request from this peer.
                self.links
                    .lock()
                    .await
                    .get_mut(device_id)
                    .unwrap()
                    .pair_state = PairState::Unpaired;
                return;
            }
            NetworkPacket::pair_response().write_to_socket(socket).await;
            self.complete_pairing(device_id).await;
        }
        PairState::Requested => {
            log::debug!("Received pair response");
            if self.trust_peer_certificate(device_id, socket).await {
                self.complete_pairing(device_id).await;
            }
        }
    }
}

/// Waits until the local user accepts pairing with `device_id` via [`Device::accept_pair`].
///
/// All connections share one acceptance channel. An id meant for another device is put back
/// when that device is still waiting (in [`PairState::RequestedByPeer`]), so its acceptance is
/// not lost; ids for devices that are not waiting are dropped as stale. Returns `false` if the
/// channel is closed.
async fn await_pair_acceptance(&self, device_id: &str) -> bool {
    loop {
        let Ok(accepted_id) = self.accepted_pair.1.recv().await else {
            log::warn!("Pair acceptance channel closed while waiting for {device_id}");
            return false;
        };
        if accepted_id == device_id {
            return true;
        }
        let other_is_waiting = self
            .links
            .lock()
            .await
            .get(&accepted_id)
            .is_some_and(|link| link.pair_state == PairState::RequestedByPeer);
        if other_is_waiting {
            let _ = self.accepted_pair.0.send(accepted_id).await;
        }
        // Give the other waiters a chance to take the requeued id before receiving again.
        self.io_impl
            .sleep(core::time::Duration::from_millis(100))
            .await;
    }
}

/// Stores the certificate the peer presented on `socket` as trusted for `device_id`.
///
/// Returns `false`, after logging a warning, if the peer certificate is missing or cannot be
/// re-encoded as PEM.
async fn trust_peer_certificate(&self, device_id: &str, socket: &mut TlsStream) -> bool {
    let Some(pem_cert) = socket
        .get_common_state()
        .peer_certificates()
        .and_then(|certs| certs.first())
        .and_then(|cert| Certificate::from_der(cert).ok())
        .and_then(|cert| cert.to_pem(LineEnding::default()).ok())
    else {
        log::warn!("Failed to get peer certificate to store");
        return false;
    };
    self.trust_handler
        .lock()
        .await
        .trust_device(device_id.to_string(), pem_cert.into_bytes())
        .await;
    true
}

/// Marks the link `device_id` as paired and restarts the plugins enabled for it.
async fn complete_pairing(&self, device_id: &str) {
    log::info!("Paired successfully with {device_id}");
    self.links
        .lock()
        .await
        .get_mut(device_id)
        .unwrap()
        .pair_state = PairState::Paired;
    self.reload_plugins(device_id).await;
}
/// Drives an established connection with `device_id` until it closes.
///
/// First restarts the plugins if the device is already paired and announces the connection to
/// [`Device::wait_for_connection`]. Then it serves two sources concurrently:
///
/// * packets queued in `send_queue` are written to `socket`, except non-pair packets whose type
///   the peer did not list in its incoming capabilities (those are logged and dropped);
/// * newline-delimited packets read from `socket` are handled one by one. A packet split across
///   reads is reassembled; one larger than the receive buffer is dropped.
///
/// On EOF or a read error, the link is removed from [`Device::links`].
///
/// NOTE(review): `device_connected` is bounded (capacity 4). If nobody calls
/// `wait_for_connection`, the notification send below blocks before the loop starts.
///
/// # Panics
///
/// Panics if the link for `device_id` is missing from `links` while the connection is active.
pub(crate) async fn on_conn_established(
    self: Arc<Self>,
    device_id: String,
    mut socket: TlsStream,
    send_queue: async_channel::Receiver<NetworkPacket>,
) {
    log::info!("New connection established with {device_id}");
    if self.links.lock().await.get(&device_id).unwrap().pair_state == PairState::Paired {
        self.reload_plugins(&device_id).await;
    }
    let link_incoming_capabilities = self
        .links
        .lock()
        .await
        .get(&device_id)
        .unwrap()
        .info
        .incoming_capabilities
        .clone();
    self.device_connected
        .0
        .send(device_id.clone())
        .await
        .expect("channel should not close unexpectedly");

    let mut buf = [0u8; crate::config::TLS_APP_BUFFER_SIZE];
    // `buf[..filled]` holds received bytes that do not yet form a complete packet.
    let mut filled = 0;
    // Set while skipping the remainder of a packet too large for `buf`.
    let mut discarding = false;
    loop {
        let read_result = loop {
            let event = {
                let mut read = Box::pin(socket.read(&mut buf[filled..]));
                let mut outgoing = Box::pin(send_queue.recv());
                core::future::poll_fn(|cx| {
                    if let Poll::Ready(r) = read.as_mut().poll(cx) {
                        Poll::Ready(Either::A(r))
                    } else if let Poll::Ready(Ok(packet)) = outgoing.as_mut().poll(cx) {
                        // Filtering happens outside: returning `Pending` for a consumed packet
                        // would leave no waker registered for the next queued one.
                        Poll::Ready(Either::B(packet))
                    } else {
                        Poll::Pending
                    }
                })
                .await
            };
            match event {
                Either::A(read_result) => break read_result,
                Either::B(packet) => {
                    let packet_type = packet.body.get_type();
                    if packet_type != NetworkPacketType::Pair
                        && link_incoming_capabilities
                            .as_ref()
                            .is_some_and(|c| !c.contains(&packet_type))
                    {
                        log::warn!("Refusing to send unsupported packet type: {packet_type:?}");
                    } else {
                        packet.write_to_socket(&mut socket).await;
                    }
                }
            }
        };
        let bytes_read = match read_result {
            Ok(0) | Err(_) => break,
            Ok(n) => n,
        };
        filled += bytes_read;
        let mut consumed = 0;
        if discarding {
            // Still inside an oversized packet: drop everything up to and including its newline.
            match buf[..filled].iter().position(|&b| b == b'\n') {
                Some(end) => {
                    consumed = end + 1;
                    discarding = false;
                }
                None => {
                    filled = 0;
                    continue;
                }
            }
        }
        while let Some(offset) = buf[consumed..filled].iter().position(|&b| b == b'\n') {
            let end = consumed + offset;
            let packet_buf = &buf[consumed..end];
            consumed = end + 1;
            if !packet_buf.is_empty() {
                self.handle_incoming_packet(&device_id, &mut socket, packet_buf)
                    .await;
            }
        }
        // Keep a trailing partial packet for the next read instead of dropping it.
        buf.copy_within(consumed..filled, 0);
        filled -= consumed;
        if filled == buf.len() {
            // A full buffer with no newline would make the next read zero-length, which looks
            // exactly like EOF; drop this packet instead of disconnecting.
            log::warn!("Incoming packet does not fit in the receive buffer, discarding it");
            filled = 0;
            discarding = true;
        }
    }
    log::info!("Disconnected from {device_id}");
    self.links.lock().await.remove(&device_id);
}

/// Parses one packet (without its trailing newline) received from `device_id` and routes it.
///
/// Pair packets go to `handle_pair_packet`. While the link is paired, every packet is then
/// given to the enabled plugins that support its type. Unparseable packets are logged and
/// dropped.
///
/// # Panics
///
/// Panics if the link for `device_id` is missing from `links`.
async fn handle_incoming_packet(
    &self,
    device_id: &str,
    socket: &mut TlsStream,
    packet_buf: &[u8],
) {
    let packet = match NetworkPacket::try_read_from(packet_buf) {
        Ok(packet) => packet,
        Err(e) => {
            // The bytes come from a possibly untrusted peer: never panic on invalid UTF-8.
            log::warn!(
                "Error while parsing incoming JSON packet: {e}\nOriginal packet:\n{}",
                String::from_utf8_lossy(packet_buf)
            );
            return;
        }
    };
    if let NetworkPacketBody::Pair(pair_packet) = &packet.body {
        self.handle_pair_packet(device_id, socket, pair_packet)
            .await;
    }
    if self.links.lock().await.get(device_id).unwrap().pair_state != PairState::Paired {
        return;
    }
    let packet_type = packet.body.get_type();
    for (index, plugin) in self.plugins.iter().enumerate() {
        let links = self.links.lock().await;
        let link = links.get(device_id).unwrap();
        if link.loaded_plugins[index]
            && plugin.supported_incoming_packets().contains(&packet_type)
            && let Err(e) = plugin.on_packet_received(&packet, link).await
        {
            log::warn!("Error when handling a received packet: {e}");
        }
    }
}
}