use serde::{Deserialize, Serialize};
use socket2::Socket;
use std::{
collections::HashMap,
net::{SocketAddr, UdpSocket},
sync::Arc,
time::{Duration, Instant},
};
use structopt::StructOpt;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::{io::Result, join, select, sync::RwLock, task::JoinHandle};
use crate::session_key::SessionKey;
use super::{
new_endpoint_channel, EndpointMessage, RemotePublic, TransportRecvMessage,
TransportSendMessage, CONNECTING_WAITING,
};
/// Server name shared by both sides of every QUIC connection: it is the subject
/// name of the self-signed certificate generated for the listener and the server
/// name clients pass to `connect_with`.
const DOMAIN: &str = "chamomile.quic";
/// Maximum number of bytes accepted from a single incoming uni-stream (64 MiB).
const SIZE_LIMIT: usize = 67108864;
pub async fn start(
bind_addr: SocketAddr,
send: Sender<TransportRecvMessage>,
recv: Receiver<TransportSendMessage>,
both: bool,
) -> tokio::io::Result<SocketAddr> {
let config = InternalConfig::try_from_config(Default::default()).unwrap();
let udp_socket = UdpSocket::bind(&bind_addr)?;
let socket = Socket::from(udp_socket);
socket.set_reuse_address(true)?;
let new_udp_socket: UdpSocket = socket.into();
let endpoint = quinn::Endpoint::new(
Default::default(),
Some(config.server.clone()),
new_udp_socket,
Arc::new(quinn::TokioRuntime),
)
.unwrap();
let addr = endpoint.local_addr()?;
info!("QUIC listening at: {:?}", addr);
let out_send = send.clone();
let incoming = endpoint.clone();
let task = tokio::spawn(async move {
loop {
match incoming.accept().await {
Some(quinn_conn) => match quinn_conn.await {
Ok(conn) => {
if both {
let (self_sender, self_receiver) = new_endpoint_channel();
let (out_sender, out_receiver) = new_endpoint_channel();
tokio::spawn(process_stream(
conn,
out_sender,
self_receiver,
OutType::DHT(out_send.clone(), self_sender, out_receiver),
None,
None,
));
}
}
Err(err) => {
error!("An incoming failed because of an error: {:?}", err);
}
},
None => {
break;
}
}
}
});
tokio::spawn(run_self_recv(endpoint, config.client, recv, send, task));
Ok(addr)
}
/// Completes an outgoing QUIC connection and sends our handshake to the peer.
///
/// The handshake (`EndpointMessage::Handshake(remote_pk)`) is written on a fresh
/// uni-stream which is then finished. The established connection is returned so
/// the caller can keep driving it.
///
/// # Errors
/// A failed `connect` call is reported as a generic `Other` error; failures while
/// connecting, opening the stream, or writing the handshake are propagated.
async fn connect_to(
    connect: std::result::Result<quinn::Connecting, quinn::ConnectError>,
    remote_pk: RemotePublic,
) -> Result<quinn::Connection> {
    let connecting = match connect {
        Ok(connecting) => connecting,
        Err(_) => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "connecting failure.",
            ));
        }
    };
    let connection = connecting.await?;

    let mut handshake_stream = connection.open_uni().await?;
    let handshake = EndpointMessage::Handshake(remote_pk).to_bytes();
    handshake_stream.write_all(&handshake).await?;
    handshake_stream.finish().await?;

    Ok(connection)
}
/// Dials a peer for DHT use and drives the resulting connection.
///
/// After [`connect_to`] succeeds, two endpoint channels are created:
/// * an outbound one, whose receiver feeds `process_stream`'s writer and whose
///   sender is handed to the transport user;
/// * an inbound one, whose sender receives decoded peer messages and whose
///   receiver is handed to the transport user.
///
/// Both handed-over ends travel inside `OutType::DHT`, together with
/// `session_key` and the pending-connection map.
async fn dht_connect_to(
    connect: std::result::Result<quinn::Connecting, quinn::ConnectError>,
    out_send: Sender<TransportRecvMessage>,
    remote_pk: RemotePublic,
    session_key: SessionKey,
    connectiongs: Arc<RwLock<HashMap<SocketAddr, Instant>>>,
) -> Result<()> {
    let conn = connect_to(connect, remote_pk).await?;

    let (outbound_tx, outbound_rx) = new_endpoint_channel();
    let (inbound_tx, inbound_rx) = new_endpoint_channel();
    let announce = OutType::DHT(out_send, outbound_tx, inbound_rx);

    process_stream(
        conn,
        inbound_tx,
        outbound_rx,
        announce,
        Some(session_key),
        Some(connectiongs),
    )
    .await
}
/// Dials a peer for a stable (caller-owned) connection and drives it.
///
/// If dialing or the handshake fails, `EndpointMessage::Close` is sent on
/// `out_sender` (best effort) and `Ok(())` is returned. Otherwise the connection
/// is handed to `process_stream` in `OutType::Stable` mode.
async fn stable_connect_to(
    connect: std::result::Result<quinn::Connecting, quinn::ConnectError>,
    out_sender: Sender<EndpointMessage>,
    self_receiver: Receiver<EndpointMessage>,
    remote_pk: RemotePublic,
    connectiongs: Arc<RwLock<HashMap<SocketAddr, Instant>>>,
) -> Result<()> {
    let conn = if let Ok(conn) = connect_to(connect, remote_pk).await {
        conn
    } else {
        // The waiting side may already be gone, so a send failure is ignored.
        let _ = out_sender.send(EndpointMessage::Close).await;
        return Ok(());
    };

    process_stream(
        conn,
        out_sender,
        self_receiver,
        OutType::Stable,
        None,
        Some(connectiongs),
    )
    .await
}
/// Executes outgoing transport commands until `Stop` arrives or `recv` closes.
///
/// `Connect` and `StableConnect` start a dial to `addr` unless another attempt to
/// the same address started less than `CONNECTING_WAITING` seconds ago. In that
/// case the command is dropped.
///
/// When the loop ends, the accept task is aborted and the endpoint is closed. This
/// happens both on an explicit `Stop` and when every command sender has been
/// dropped (no one could ever send `Stop` then).
async fn run_self_recv(
    endpoint: quinn::Endpoint,
    client_cfg: quinn::ClientConfig,
    mut recv: Receiver<TransportSendMessage>,
    out_send: Sender<TransportRecvMessage>,
    task: JoinHandle<()>,
) -> Result<()> {
    // Start time of every pending outgoing attempt, keyed by remote address.
    // Entries are removed by `process_stream` once a handshake succeeds.
    let connecting: Arc<RwLock<HashMap<SocketAddr, Instant>>> =
        Arc::new(RwLock::new(HashMap::new()));
    while let Some(m) = recv.recv().await {
        match m {
            TransportSendMessage::Connect(addr, remote_pk, session_key) => {
                if !claim_connecting_slot(&connecting, addr).await {
                    continue;
                }
                let connect = endpoint.connect_with(client_cfg.clone(), addr, DOMAIN);
                info!("QUIC dht connect to: {:?}", addr);
                tokio::spawn(dht_connect_to(
                    connect,
                    out_send.clone(),
                    remote_pk,
                    session_key,
                    connecting.clone(),
                ));
            }
            TransportSendMessage::StableConnect(out_sender, self_receiver, addr, remote_pk) => {
                if !claim_connecting_slot(&connecting, addr).await {
                    continue;
                }
                let connect = endpoint.connect_with(client_cfg.clone(), addr, DOMAIN);
                info!("QUIC stable connect to: {:?}", addr);
                tokio::spawn(stable_connect_to(
                    connect,
                    out_sender,
                    self_receiver,
                    remote_pk,
                    connecting.clone(),
                ));
            }
            TransportSendMessage::Stop => break,
        }
    }
    // Shared shutdown path: without it, a closed command channel would leak the
    // accept task and keep the UDP socket open forever.
    task.abort();
    endpoint.close(0u8.into(), &[]);
    Ok(())
}

/// Records an outgoing attempt to `addr` unless one started less than
/// `CONNECTING_WAITING` seconds ago.
///
/// The check and the insert happen under a single write lock, so they are atomic.
/// Returns `true` when the caller should go ahead and dial.
async fn claim_connecting_slot(
    connecting: &RwLock<HashMap<SocketAddr, Instant>>,
    addr: SocketAddr,
) -> bool {
    let mut pending = connecting.write().await;
    let recently_started = pending
        .get(&addr)
        .map_or(false, |started| started.elapsed().as_secs() < CONNECTING_WAITING);
    if recently_started {
        return false;
    }
    pending.insert(addr, Instant::now());
    true
}
/// Selects how `process_stream` announces a connection once the peer's
/// handshake has been received.
enum OutType {
    /// DHT connection. `process_stream` sends a `TransportRecvMessage` on the
    /// first field, handing over the second (`Sender<EndpointMessage>`) and the
    /// third (`Receiver<EndpointMessage>`) channel ends to whoever receives it.
    DHT(
        Sender<TransportRecvMessage>,
        Sender<EndpointMessage>,
        Receiver<EndpointMessage>,
    ),
    /// Stable connection. The peer's handshake is forwarded on the connection's
    /// own `out_sender`; the channels are already owned by the caller.
    Stable,
}
/// Drives a single established QUIC connection until it ends.
///
/// 1. Waits up to 10 seconds for the peer to open a uni-stream carrying an
///    `EndpointMessage::Handshake`. On failure or timeout the connection is
///    dropped and `Ok(())` is returned.
/// 2. Announces the peer. For `OutType::Stable` the handshake is forwarded on
///    `out_sender`. For `OutType::DHT` a `TransportRecvMessage` describing the new
///    connection (including `has_session`) is sent to the transport receiver.
/// 3. Removes the peer's address from the pending-connection map, if one was given.
/// 4. Pumps messages in both directions:
///    * every message from `self_receiver` is written on its own uni-stream,
///      stopping after `EndpointMessage::Close`;
///    * every uni-stream received from the peer is decoded and forwarded to
///      `out_sender`.
///
///    Once both directions have finished, the connection is closed.
///
/// # Errors
/// Returns an error only if the announcement in step 2 cannot be delivered
/// because its channel is closed.
async fn process_stream(
    conn: quinn::Connection,
    out_sender: Sender<EndpointMessage>,
    mut self_receiver: Receiver<EndpointMessage>,
    out_type: OutType,
    has_session: Option<SessionKey>,
    connectiongs: Option<Arc<RwLock<HashMap<SocketAddr, Instant>>>>,
) -> tokio::io::Result<()> {
    // Deadline for the peer to deliver its handshake stream.
    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

    let addr = conn.remote_address();
    let remote_pk =
        match tokio::time::timeout(HANDSHAKE_TIMEOUT, read_handshake(&conn, addr)).await {
            Ok(Some(remote_pk)) => remote_pk,
            Ok(None) => {
                debug!("Transport: connect read publics failure, close it.");
                return Ok(());
            }
            Err(_elapsed) => {
                debug!("Transport: connect read publics timeout, close it.");
                return Ok(());
            }
        };

    match out_type {
        OutType::Stable => {
            out_sender
                .send(EndpointMessage::Handshake(remote_pk))
                .await
                .map_err(|_e| {
                    std::io::Error::new(std::io::ErrorKind::Other, "endpoint channel missing")
                })?;
        }
        OutType::DHT(sender, self_sender, out_receiver) => {
            sender
                .send(TransportRecvMessage(
                    addr,
                    remote_pk,
                    has_session,
                    out_sender.clone(),
                    out_receiver,
                    self_sender,
                ))
                .await
                .map_err(|_e| {
                    std::io::Error::new(std::io::ErrorKind::Other, "server channel missing")
                })?;
        }
    }

    // The attempt is complete; allow new attempts to this address right away.
    if let Some(connectiongs) = connectiongs {
        connectiongs.write().await.remove(&addr);
    }

    // Outbound pump: one uni-stream per message from the local side.
    let conn_send = conn.clone();
    let a = async move {
        while let Some(msg) = self_receiver.recv().await {
            let mut writer = conn_send.open_uni().await.map_err(|_e| ())?;
            let is_close = matches!(msg, EndpointMessage::Close);
            // Write errors are tolerated; a dead connection is detected by the
            // next `open_uni` or by the inbound pump.
            let _ = writer.write_all(&msg.to_bytes()).await;
            let _ = writer.finish().await;
            if is_close {
                break;
            }
        }
        Err::<(), ()>(())
    };

    // Inbound pump: decode each peer uni-stream and forward it to the local side.
    let b = async {
        loop {
            match conn.accept_uni().await {
                Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
                    debug!("Connection terminated by peer {:?}.", addr);
                    break;
                }
                Err(err) => {
                    debug!(
                        "Failed to read incoming message on uni-stream for peer {:?} with error: {:?}",
                        addr, err
                    );
                    break;
                }
                Ok(mut recv) => {
                    if let Ok(bytes) = recv.read_to_end(SIZE_LIMIT).await {
                        if let Ok(msg) = EndpointMessage::from_bytes(bytes) {
                            let _ = out_sender.send(msg).await;
                        }
                    }
                }
            }
        }
    };

    let _ = join!(a, b);
    info!("close stream: {}", addr);
    conn.close(0u8.into(), &[]);
    Ok(())
}

/// Reads the peer's handshake from the first uni-stream it opens on `conn`.
///
/// Returns `None` in three cases: the connection fails before a stream arrives,
/// the stream cannot be read within `SIZE_LIMIT` bytes, or the stream does not
/// decode to an `EndpointMessage::Handshake`.
async fn read_handshake(conn: &quinn::Connection, addr: SocketAddr) -> Option<RemotePublic> {
    let mut recv = match conn.accept_uni().await {
        Ok(recv) => recv,
        Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
            debug!("Connection terminated by peer {:?}.", addr);
            return None;
        }
        Err(err) => {
            debug!(
                "Failed to read incoming message on uni-stream for peer {:?} with error: {:?}",
                addr, err
            );
            return None;
        }
    };
    let bytes = recv.read_to_end(SIZE_LIMIT).await.ok()?;
    match EndpointMessage::from_bytes(bytes) {
        Ok(EndpointMessage::Handshake(remote_pk)) => Some(remote_pk),
        _ => None,
    }
}
/// Idle timeout applied to QUIC connections when `Config::idle_timeout` is not set.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
// User-tunable QUIC transport settings, loadable with serde or from the command
// line via structopt. Plain `//` comments are used deliberately: `///` doc
// comments would be picked up by structopt as help text.
#[derive(Clone, Debug, Default, Serialize, Deserialize, Eq, PartialEq, StructOpt)]
pub struct Config {
    // Passed to quinn as the transport's max idle timeout; `None` falls back to
    // `DEFAULT_IDLE_TIMEOUT`. On the command line the value is a whole number of
    // milliseconds (parsed by `parse_millis`).
    #[serde(default)]
    #[structopt(long, parse(try_from_str = parse_millis), value_name = "MILLIS")]
    pub idle_timeout: Option<Duration>,
}
/// Parses a whole number of milliseconds (e.g. `"1500"`) into a [`Duration`].
///
/// Used as the structopt parser for the `idle_timeout` option. Anything that is
/// not a valid `u64` is rejected with the underlying `ParseIntError`.
fn parse_millis(millis: &str) -> std::result::Result<Duration, std::num::ParseIntError> {
    millis.parse::<u64>().map(Duration::from_millis)
}
/// Ready-to-use quinn client and server configurations derived from a `Config`.
#[derive(Clone, Debug)]
pub(crate) struct InternalConfig {
    /// Used for outgoing connections (`Endpoint::connect_with`). It does not
    /// verify the server certificate.
    pub(crate) client: quinn::ClientConfig,
    /// Used by the listening endpoint. It presents a freshly generated
    /// self-signed certificate for `DOMAIN`.
    pub(crate) server: quinn::ServerConfig,
}
impl InternalConfig {
    /// Builds the client and server QUIC configurations from `config`.
    ///
    /// Both sides share one transport configuration whose max idle timeout is
    /// `config.idle_timeout`, or `DEFAULT_IDLE_TIMEOUT` when unset.
    ///
    /// # Errors
    /// Fails in two cases:
    /// * the idle timeout cannot be represented as a QUIC idle timeout;
    /// * the self-signed server certificate cannot be generated or installed.
    pub(crate) fn try_from_config(config: Config) -> Result<Self> {
        let idle_timeout =
            quinn::IdleTimeout::try_from(config.idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT))
                .map_err(|_e| {
                    std::io::Error::new(std::io::ErrorKind::Other, "idle timeout out of range.")
                })?;
        let mut tconfig = quinn::TransportConfig::default();
        tconfig.max_idle_timeout(Some(idle_timeout));
        let transport = Arc::new(tconfig);
        Ok(Self {
            client: Self::new_client_config(transport.clone()),
            server: Self::new_server_config(transport)?,
        })
    }

    /// Builds the client configuration used for outgoing connections.
    ///
    /// Server certificates are NOT verified (see `SkipCertificateVerification`).
    /// The listener only presents a throwaway self-signed certificate.
    // NOTE(review): peer identity presumably relies on the application-level
    // `EndpointMessage::Handshake` exchange rather than TLS; confirm before reuse.
    fn new_client_config(transport: Arc<quinn::TransportConfig>) -> quinn::ClientConfig {
        let mut client_crypto = rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth();
        client_crypto
            .dangerous()
            .set_certificate_verifier(Arc::new(SkipCertificateVerification));
        let mut config = quinn::ClientConfig::new(Arc::new(client_crypto));
        config.transport_config(transport);
        config
    }

    /// Builds the server configuration around a freshly generated self-signed
    /// certificate, without client authentication.
    ///
    /// # Errors
    /// Fails if the certificate cannot be generated or rustls rejects it.
    fn new_server_config(transport: Arc<quinn::TransportConfig>) -> Result<quinn::ServerConfig> {
        let (cert, key) = Self::generate_cert()?;
        let server_crypto = rustls::ServerConfig::builder()
            .with_safe_defaults()
            .with_no_client_auth()
            .with_single_cert(vec![cert], key)
            .map_err(|_e| {
                std::io::Error::new(std::io::ErrorKind::Other, "server config failure.")
            })?;
        let mut config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
        config.transport = transport;
        Ok(config)
    }

    /// Generates a self-signed certificate for `DOMAIN` and returns it together
    /// with its private key, both DER-encoded.
    fn generate_cert() -> Result<(rustls::Certificate, rustls::PrivateKey)> {
        let cert = rcgen::generate_simple_self_signed(vec![DOMAIN.to_string()]).map_err(|_e| {
            std::io::Error::new(std::io::ErrorKind::Other, "rcgen generate failure.")
        })?;
        let cert_der = cert.serialize_der().map_err(|_e| {
            std::io::Error::new(std::io::ErrorKind::Other, "cert serialize failure.")
        })?;
        let key_der = cert.serialize_private_key_der();
        Ok((rustls::Certificate(cert_der), rustls::PrivateKey(key_der)))
    }
}
/// rustls server-certificate verifier that accepts every certificate.
///
/// Installed by `InternalConfig::new_client_config`, because listeners present
/// throwaway self-signed certificates that no root store would trust.
// SECURITY(review): TLS here provides encryption but no server authentication;
// this is only safe if peer identity is established elsewhere (presumably by the
// application-level handshake). Confirm before reusing this verifier.
struct SkipCertificateVerification;
impl rustls::client::ServerCertVerifier for SkipCertificateVerification {
    /// Unconditionally reports the presented certificate chain as verified.
    fn verify_server_cert(
        &self,
        _: &rustls::Certificate,
        _: &[rustls::Certificate],
        _: &rustls::ServerName,
        _: &mut dyn Iterator<Item = &[u8]>,
        _: &[u8],
        _: std::time::SystemTime,
    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::ServerCertVerified::assertion())
    }
}