use core::{
net::{IpAddr, Ipv6Addr, SocketAddr},
time::Duration,
};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{sync::Arc, vec};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime, pem::PemObject};
use x509_cert::{Certificate, der::DecodePem};
use crate::{
device::Device, io::{IoImpl, KnownFunctionName, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl}, packet::{NetworkPacket, NetworkPacketBody}
};
/// Binds the TCP listener, starts discovery and serves incoming connections.
///
/// Every port in `MIN_TCP_PORT..=MAX_TCP_PORT` is tried in ascending order on
/// the IPv6 unspecified address; the first one that can be listened on is
/// stored in `device.my_tcp_port`. If no port can be bound, an error is logged
/// and the function returns without starting discovery.
///
/// After binding, `super::start_discovering` is called and the function loops
/// forever, spawning a `KnownFunctionName::PerTcpStream` task for every
/// accepted socket.
pub async fn setup_tcp<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    let listen_addr = |port| SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);

    // Iterating an inclusive range (instead of a manual `port += 1`) cannot
    // overflow, even if the upper bound is the largest representable port.
    let mut bound = None;
    for port in crate::config::MIN_TCP_PORT..=crate::config::MAX_TCP_PORT {
        if let Ok(listener) = device.io_impl.listen_tcp(listen_addr(port)).await {
            bound = Some((listener, port));
            break;
        }
    }

    // `mut` is only exercised when the listener is re-created below
    // (feature "embedded"), hence the lint allowance.
    #[allow(unused_mut)]
    let Some((mut tcp_listener, tcp_port)) = bound else {
        log::error!(
            "No port available in {}..{}",
            crate::config::MIN_TCP_PORT,
            crate::config::MAX_TCP_PORT
        );
        return;
    };

    // The result of `set` is deliberately ignored, as in the original code.
    let _ = device.my_tcp_port.set(tcp_port).await;
    super::start_discovering(Arc::clone(&device));

    loop {
        let Ok(socket) = tcp_listener.accept().await else {
            // Back off before retrying so a failing accept does not spin.
            #[cfg(feature = "embedded")]
            device.io_impl.sleep(Duration::from_millis(500)).await;
            continue;
        };

        device
            .io_impl
            .spawn(KnownFunctionName::PerTcpStream(socket), Arc::clone(&device));

        // With the "embedded" feature the listener is re-created after every
        // accepted connection (presumably because accepting consumes it —
        // confirm against the embedded `TcpListenerImpl`). A failed re-listen
        // is retried instead of panicking the whole accept loop.
        #[cfg(feature = "embedded")]
        {
            tcp_listener = loop {
                match device.io_impl.listen_tcp(listen_addr(tcp_port)).await {
                    Ok(listener) => break listener,
                    Err(err) => {
                        log::error!("Failed to re-listen on TCP port {tcp_port}: {err:?}");
                        device.io_impl.sleep(Duration::from_millis(500)).await;
                    }
                }
            };
        }
    }
}
/// Handles one freshly accepted TCP connection.
///
/// Bytes are accumulated into a fixed buffer of `TCP_BUFFER_SIZE` bytes until
/// the buffered data parses as a [`NetworkPacket`] whose body is an identity
/// packet; the connection is then handed to `upgrade_tcp_connection`.
///
/// The function returns without upgrading when:
/// * reading from the socket fails (logged, no panic),
/// * the peer stops sending (a read returns `0` bytes), or
/// * the buffer is full and still does not contain an identity packet.
pub async fn per_tcp_stream<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    mut socket: TcpStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    // Number of valid bytes currently held at the start of `buf`.
    let mut filled = 0;
    let mut buf = [0; crate::config::TCP_BUFFER_SIZE];

    loop {
        // A read error on a network socket is an ordinary runtime event
        // (e.g. the peer reset the connection); it must not panic the task.
        let bytes_read = match socket.read(&mut buf[filled..]).await {
            Ok(n) => n,
            Err(err) => {
                log::warn!("tcp: Failed to read from the socket: {err:?}");
                return;
            }
        };
        filled += bytes_read;

        if let Ok(NetworkPacket {
            body: NetworkPacketBody::Identity(packet),
            ..
        }) = NetworkPacket::try_read_from(&buf[..filled])
        {
            log::debug!("TCP Identity packet received, upgrading connection");
            upgrade_tcp_connection(packet, socket, device).await;
            return;
        }

        // Nothing more can arrive (zero-length read) or nothing more fits
        // (buffer exhausted): give up rather than issuing a read into an
        // empty slice.
        if bytes_read == 0 || filled == buf.len() {
            log::warn!("tcp: Failed to parse the received JSON");
            return;
        }
    }
}
/// Upgrades an accepted TCP connection to TLS, acting as the TLS client.
///
/// Steps, in order:
/// 1. If `device.links` already has an entry for the peer's device id, the
///    request is ignored.
/// 2. The peer's stored certificate (if any) is fetched from the trust
///    handler. A certificate that cannot be parsed is ignored, and an expired
///    one causes the device to be untrusted; in both cases the connection
///    continues as untrusted.
/// 3. A rustls client configuration is built with this device's own
///    certificate and key for client authentication and
///    `super::tls::CertVerifier` as the server-certificate verifier.
/// 4. The TLS connection is established and handed to
///    `super::tls::per_tls_stream`.
///
/// Failures caused by the peer (an invalid device id, a failed TLS connection)
/// are logged and end the function; they do not panic.
///
/// # Panics
///
/// Panics if this device's own certificate or private key in `device.config`
/// cannot be parsed, or if the rustls client configuration cannot be built
/// from them.
async fn upgrade_tcp_connection<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet: crate::packet::identity::IdentityPacket,
    socket: TcpStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    if device.links.lock().await.contains_key(&packet.device_id) {
        log::debug!(
            "Device {} has already established connection, ignore the MDNS request",
            packet.device_id
        );
        return;
    }

    // A stored certificate that no longer parses is treated like "no trusted
    // certificate" instead of panicking on every connection from that device.
    let mut trusted_device_cert = device
        .trust_handler
        .lock()
        .await
        .get_certificate(&packet.device_id)
        .await
        .and_then(|pem| match Certificate::from_pem(pem) {
            Ok(cert) => Some(cert),
            Err(err) => {
                log::warn!(
                    "Stored certificate of device {} cannot be parsed, treating it as untrusted: {err:?}",
                    packet.device_id
                );
                None
            }
        });
    let mut is_trusted = trusted_device_cert.is_some();

    let current_time = device.io_impl.get_current_timestamp().await;
    if let Some(cert) = &trusted_device_cert
        && super::tls::is_expired(
            cert,
            UnixTime::since_unix_epoch(Duration::from_secs(current_time)),
        )
    {
        log::warn!(
            "SSL certificate of the trusted device `{}` is expired, make the device untrusted",
            packet.device_name.as_ref().unwrap_or(&packet.device_id)
        );
        device
            .trust_handler
            .lock()
            .await
            .untrust_device(&packet.device_id)
            .await;
        trusted_device_cert = None;
        is_trusted = false;
    }

    // This device's own credentials come from local configuration; failing to
    // parse them panics (see `# Panics`).
    let certs = vec![CertificateDer::from_pem_slice(&device.config.cert).unwrap()];
    let key = PrivateKeyDer::from_pem_slice(&device.config.private_key).unwrap();

    #[cfg(feature = "embedded")]
    let config = rustls::ClientConfig::builder_with_details(
        Arc::new(rustls_rustcrypto::provider()),
        Arc::new(super::tls::CustomTimeProvider { current_time }),
    )
    .with_safe_default_protocol_versions()
    .unwrap();
    #[cfg(not(feature = "embedded"))]
    let config = rustls::ClientConfig::builder();
    let config = config
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(super::tls::CertVerifier {
            trusted_device_cert,
            trusted_device_id: packet.device_id.clone(),
        }))
        .with_client_auth_cert(certs, key)
        .unwrap();

    log::debug!("Upgrade to SSL connection as client");

    // The device id is supplied by the remote peer, so it may not be a valid
    // TLS server name; reject such a peer instead of panicking.
    let server_name = match ServerName::try_from(packet.device_id.clone()) {
        Ok(name) => name,
        Err(err) => {
            log::warn!(
                "Device id `{}` is not a valid TLS server name, dropping the connection: {err:?}",
                packet.device_id
            );
            return;
        }
    };

    let Ok(socket_with_tls) = device
        .io_impl
        .connect_client_tls(config, server_name, socket)
        .await
    else {
        log::warn!(
            "Failed to establish the TLS connection with device {}",
            packet.device_id
        );
        return;
    };
    log::debug!("SSL connection established successfully");

    super::tls::per_tls_stream(packet, socket_with_tls, device, is_trusted, true).await;
}