use core::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::Duration,
};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{sync::Arc, vec};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime, pem::PemObject};
use x509_cert::{Certificate, der::DecodePem};
use crate::{
device::Device, io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl}, packet::{NetworkPacket, NetworkPacketBody}
};
async fn broadcast_udp_identity<
Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
UdpSocket: UdpSocketImpl + Unpin + 'static,
TcpStream: TcpStreamImpl + Unpin + 'static,
TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
TlsStream: TlsStreamImpl + Unpin + 'static,
>(
device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
let Ok(mut socket) = device
.io_impl
.bind_udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0))
.await
else {
return;
};
let my_identity_packet = device.get_identity_packet();
let serialized_my_identity_packet = serde_json::to_string(&my_identity_packet).unwrap();
socket.set_broadcast(true).unwrap();
socket
.send_to(
serialized_my_identity_packet.as_bytes(),
SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), crate::config::UDP_PORT),
)
.await
.unwrap();
}
pub async fn setup_udp<
Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
UdpSocket: UdpSocketImpl + Unpin + 'static,
TcpStream: TcpStreamImpl + Unpin + 'static,
TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
TlsStream: TlsStreamImpl + Unpin + 'static,
>(
device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
broadcast_udp_identity(Arc::clone(&device)).await;
let Ok(udp_listener) = device
.io_impl
.bind_udp_reuse_v6(SocketAddr::new(
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
crate::config::UDP_PORT,
))
.await
else {
return;
};
let mut i = 0;
let mut buf = [0u8; crate::config::UDP_BUFFER_SIZE];
loop {
let (bytes_read, addr) = udp_listener.recv_from(&mut buf[i..]).await.unwrap();
i += bytes_read;
let device = Arc::clone(&device);
if let Ok(NetworkPacket {
body: NetworkPacketBody::Identity(packet),
..
}) = NetworkPacket::try_read_from(&buf[..i])
{
i = 0;
if packet.device_id == device.host_device_id {
continue;
}
log::debug!("UDP Identity packet received, upgrading connection");
upgrade_udp_connection(addr, packet, device).await;
}
if bytes_read == 0 {
log::warn!("udp: Failed to parse the received JSON");
i = 0;
}
}
}
/// Upgrades a peer discovered through a UDP identity packet to a TLS connection.
///
/// Steps:
/// 1. Skip the peer if `device.links` already contains its device id.
/// 2. Connect over TCP to the sender's IP, on the TCP port advertised in the
///    packet. Then send our own identity packet unencrypted.
/// 3. Look up a stored certificate for the peer. If it cannot be parsed or is
///    expired, untrust the device and continue as untrusted.
/// 4. Accept a TLS session as the *server*, using a client-certificate verifier
///    built from the stored certificate (if any). Then hand the stream to
///    `super::tls::per_tls_stream`.
///
/// `packet` is untrusted network input. A missing port, a failed connection or an
/// I/O error aborts the upgrade instead of panicking.
///
/// # Panics
///
/// Panics if `device.config.cert` / `device.config.private_key` are not valid
/// PEM, or if rustls rejects the resulting server configuration. These are local
/// configuration errors, not remote input.
async fn upgrade_udp_connection<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    mut addr: SocketAddr,
    packet: crate::packet::identity::IdentityPacket,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    if device.links.lock().await.contains_key(&packet.device_id) {
        log::debug!(
            "Device {} has already established connection, ignore the UDP identity packet",
            packet.device_id
        );
        return;
    }

    // The port comes from a remote datagram. A missing value (or 0, which cannot
    // be connected to) must not panic the discovery loop.
    let tcp_port = packet.get_tcp_port().unwrap_or_default();
    if tcp_port == 0 {
        log::warn!(
            "udp: Identity packet from {} has no usable TCP port, ignoring it",
            packet.device_id
        );
        return;
    }
    addr.set_port(tcp_port);

    let Ok(mut socket) = device.io_impl.connect_tcp(addr).await else {
        return;
    };
    if let Err(e) = socket.writable().await {
        log::warn!("udp: TCP socket to {addr} did not become writable: {e:?}");
        return;
    }
    device
        .get_identity_packet()
        .write_to_socket_unencrypted(&mut socket)
        .await;

    // Parse within the same statement: the trust-handler guard is a temporary
    // dropped at the end of it, and the returned PEM may borrow from it.
    let stored_cert = device
        .trust_handler
        .lock()
        .await
        .get_certificate(&packet.device_id)
        .await
        .map(|pem| Certificate::from_pem(pem));
    let had_stored_cert = stored_cert.is_some();
    let current_time = device.io_impl.get_current_timestamp().await;
    let now = UnixTime::since_unix_epoch(Duration::from_secs(current_time));

    let trusted_device_cert = match stored_cert {
        None => None,
        Some(Ok(cert)) => {
            if super::tls::is_expired(&cert, now) {
                log::warn!(
                    "SSL certificate of the trusted device `{}` is expired, make the device untrusted",
                    packet.device_name.as_ref().unwrap_or(&packet.device_id)
                );
                None
            } else {
                Some(cert)
            }
        }
        Some(Err(e)) => {
            // A corrupt stored certificate is handled like an expired one rather
            // than panicking.
            log::warn!(
                "SSL certificate of the trusted device `{}` could not be parsed ({}), make the device untrusted",
                packet.device_name.as_ref().unwrap_or(&packet.device_id),
                e
            );
            None
        }
    };
    if had_stored_cert && trusted_device_cert.is_none() {
        device
            .trust_handler
            .lock()
            .await
            .untrust_device(&packet.device_id)
            .await;
    }
    let is_trusted = trusted_device_cert.is_some();

    let certs = vec![
        CertificateDer::from_pem_slice(&device.config.cert)
            .expect("device.config.cert must be a valid PEM certificate"),
    ];
    let key = PrivateKeyDer::from_pem_slice(&device.config.private_key)
        .expect("device.config.private_key must be a valid PEM private key");
    #[cfg(feature = "embedded")]
    let config = rustls::ServerConfig::builder_with_details(
        Arc::new(rustls_rustcrypto::provider()),
        Arc::new(super::tls::CustomTimeProvider { current_time }),
    )
    .with_safe_default_protocol_versions()
    .unwrap();
    #[cfg(not(feature = "embedded"))]
    let config = rustls::ServerConfig::builder();
    let config = config
        .with_client_cert_verifier(Arc::new(super::tls::CertVerifier {
            trusted_device_cert,
            trusted_device_id: packet.device_id.clone(),
        }))
        .with_single_cert(certs, key)
        .unwrap();
    log::debug!("Upgrade to SSL connection as server");
    let Ok(socket_with_tls) = device.io_impl.accept_server_tls(config, socket).await else {
        return;
    };
    log::debug!("SSL connection established successfully");
    super::tls::per_tls_stream(packet, socket_with_tls, device, is_trusted, false).await;
}