use std::collections::HashMap;
use std::io::{Error, ErrorKind};
use std::pin::Pin;
use std::task::{Context, Poll};
use log::{debug, error, trace};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::sync::mpsc::channel;
use openssl::ssl::{SslFiletype, SslMethod, SslVerifyMode};
use super::{Ctl, Tokio};
use crate::{ClientOptions, COAP_MTU};
impl Tokio {
    /// Create a DTLS-secured client connection to `peer`.
    ///
    /// Connects a UDP socket via `Self::udp_connect`, builds an OpenSSL DTLS context from the
    /// TLS settings in `opts` (`tls_skip_verify`, `tls_ca`, `tls_cert`, `tls_key`), performs the
    /// handshake bounded by `opts.connect_timeout`, and then spawns a background task that
    /// services control messages and incoming data until it is told to exit or an I/O /
    /// handling error occurs.
    ///
    /// # Errors
    ///
    /// Returns an error if the UDP connect fails, if OpenSSL fails to create or configure the
    /// context / session / stream (previously these cases panicked), if the handshake does not
    /// complete within `opts.connect_timeout`, or if the handshake itself fails (reported as
    /// `ErrorKind::Other` with the message "DTLS connect failed").
    pub(crate) async fn new_dtls(peer: &str, opts: &ClientOptions) -> Result<Self, std::io::Error> {
        debug!("Creating DTLS listener");

        // Plain UDP transport, wrapped so it can sit underneath the TLS stream.
        let udp_socket = Self::udp_connect(peer).await?;
        let udp_stream = UdpStream::from(udp_socket);

        // OpenSSL setup failures are propagated to the caller instead of panicking.
        let mut ssl_builder = openssl::ssl::SslContext::builder(SslMethod::dtls())?;

        // NOTE(review): the verify mode is only touched when verification is explicitly
        // skipped. OpenSSL documents SSL_VERIFY_NONE as the default for a fresh context, so
        // peer verification may never actually be enabled on this path -- confirm, and
        // consider setting `SslVerifyMode::PEER` when `tls_skip_verify` is false.
        if opts.tls_skip_verify {
            ssl_builder.set_verify(SslVerifyMode::NONE);
        }
        if let Some(tls_ca) = &opts.tls_ca {
            ssl_builder.set_ca_file(tls_ca)?;
        }
        if let Some(tls_cert) = &opts.tls_cert {
            ssl_builder.set_certificate_file(tls_cert, SslFiletype::PEM)?;
        }
        if let Some(tls_key) = &opts.tls_key {
            ssl_builder.set_private_key_file(tls_key, SslFiletype::PEM)?;
        }
        let ssl_ctx = ssl_builder.build();

        let ssl_conn = openssl::ssl::Ssl::new(&ssl_ctx)?;
        let mut dtls_stream = tokio_openssl::SslStream::new(ssl_conn, udp_stream)?;

        // Client-side handshake; the outer `?` surfaces the timeout, the inner error is the
        // handshake failure itself.
        let connect = tokio_openssl::SslStream::connect(Pin::new(&mut dtls_stream));
        if let Err(e) = tokio::time::timeout(opts.connect_timeout, connect).await? {
            debug!("DTLS connect error: {:?}", e);
            return Err(Error::new(ErrorKind::Other, "DTLS connect failed"));
        };

        let (mut udp_rx, mut udp_tx) = tokio::io::split(dtls_stream);

        // Control channel: `ctl_tx` is handed to the returned client, a clone is kept by the
        // listener task so it can be passed to `handle_rx`.
        let (ctl_tx, mut ctl_rx) = channel::<Ctl>(1000);
        let l_ctl_tx = ctl_tx.clone();

        let _listener = tokio::task::spawn(async move {
            let mut buff = [0u8; COAP_MTU];
            // Registered handlers, keyed by token.
            let mut handles = HashMap::new();

            loop {
                tokio::select!(
                    // Control messages from the client side.
                    ctl = ctl_rx.recv() => {
                        match ctl {
                            Some(Ctl::Register(token, rx)) => {
                                debug!("Register handler: {:x}", token);
                                handles.insert(token, rx);
                            },
                            Some(Ctl::Deregister(token)) => {
                                debug!("Deregister handler: {:x}", token);
                                handles.remove(&token);
                            },
                            Some(Ctl::Send(data)) => {
                                trace!("Tx: {:02x?}", data);
                                // The number of bytes reported as written is not checked.
                                if let Err(e) = udp_tx.write(&data[..]).await {
                                    error!("net transmit error: {:?}", e);
                                    break;
                                }
                            },
                            Some(Ctl::Exit) => {
                                debug!("Exiting client");
                                break;
                            },
                            // Any other control message (or `None`) is ignored.
                            _ => (),
                        }
                    }
                    // Decrypted data from the peer.
                    r = udp_rx.read(&mut buff) => {
                        let data = match r {
                            Ok(n) => &buff[..n],
                            Err(e) => {
                                error!("net receive error: {:?}", e);
                                break;
                            }
                        };
                        trace!("Rx: {:02x?}", data);
                        if let Err(e) = Self::handle_rx(&mut handles, data, l_ctl_tx.clone()).await {
                            error!("net handle error: {:?}", e);
                            break;
                        }
                    },
                );
            }

            debug!("Exiting coap DTLS handler");

            // Rejoin the halves so the TLS session can be shut down before the task ends.
            let mut dtls_stream = udp_rx.unsplit(udp_tx);
            dtls_stream.shutdown().await?;
            Ok(())
        });

        Ok(Self { ctl_tx, _listener })
    }
}
/// Thin wrapper around a `tokio::net::UdpSocket`.
///
/// The trait implementations in this file (`std::io::Read` / `std::io::Write` and
/// `tokio::io::AsyncRead` / `tokio::io::AsyncWrite`) forward to the socket's receive and
/// send methods, which lets `new_dtls` pass the wrapped socket to `tokio_openssl::SslStream`
/// as its underlying transport.
pub struct UdpStream {
    /// The wrapped UDP socket that all reads and writes are forwarded to.
    socket: tokio::net::UdpSocket,
}
impl From<tokio::net::UdpSocket> for UdpStream {
fn from(socket: tokio::net::UdpSocket) -> Self {
Self { socket }
}
}
/// Synchronous read access, forwarded to the socket's `try_recv`.
impl std::io::Read for UdpStream {
    /// Receive into `buff` using `UdpSocket::try_recv`, returning whatever that call reports
    /// (byte count on success, the socket's error otherwise).
    fn read(&mut self, buff: &mut [u8]) -> std::io::Result<usize> {
        let socket = &self.socket;
        socket.try_recv(buff)
    }
}
/// Synchronous write access, forwarded to the socket's `try_send`.
impl std::io::Write for UdpStream {
    /// Send `buff` using `UdpSocket::try_send`, returning the byte count it reports or its
    /// error.
    fn write(&mut self, buff: &[u8]) -> std::io::Result<usize> {
        let sent = self.socket.try_send(buff)?;
        Ok(sent)
    }

    /// Nothing is buffered in this wrapper, so flushing always succeeds immediately.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// Asynchronous read access, forwarded to the socket's `poll_recv`.
impl tokio::io::AsyncRead for UdpStream {
    /// Poll the socket for incoming data, filling `buf` via `UdpSocket::poll_recv`.
    ///
    /// Readiness and errors are passed through as-is; any success value from the socket is
    /// discarded because `AsyncRead` reports progress through `buf` itself.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.socket
            .poll_recv(cx, buf)
            .map(|outcome| outcome.map(|_n| ()))
    }
}
/// Asynchronous write access, forwarded to the socket's `poll_send`.
impl tokio::io::AsyncWrite for UdpStream {
    /// Poll the socket to send `buf` via `UdpSocket::poll_send`, returning its result
    /// unchanged.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let socket = &self.socket;
        socket.poll_send(cx, buf)
    }

    /// Nothing is buffered in this wrapper, so a flush completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let done: std::io::Result<()> = Ok(());
        Poll::Ready(done)
    }

    /// Shutdown performs no action on the socket and completes immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let done: std::io::Result<()> = Ok(());
        Poll::Ready(done)
    }
}