use std::{ffi::OsString, io::stdin, net::SocketAddr, sync::Arc};

use anyhow::{Context as _, Result};
use aws_lc_rs::{
    agreement::{EphemeralPrivateKey, X25519},
    rand::SystemRandom,
};
use clap::Parser as _;
use libmoshpit::{
    ConnectionReader, ConnectionWriter, Frame, MoshpitError, UdpState, init_tracing, load,
};
use tokio::{
    net::{TcpStream, UdpSocket},
    spawn,
    sync::mpsc::unbounded_channel,
    task::spawn_blocking,
};
use tracing::{error, info, trace};
use uuid::Uuid;

use crate::{
    cli::Cli,
    config::Config,
    tcp::{reader::FrameReader, sender::FrameSender},
    udp::{reader::UdpReader, sender::UdpSender},
};
pub(crate) async fn run<I, T>(args: Option<I>) -> Result<()>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let cli = if let Some(args) = args {
Cli::try_parse_from(args)?
} else {
Cli::try_parse()?
};
let config = load::<Cli, Config, Cli>(&cli, &cli).with_context(|| MoshpitError::ConfigLoad)?;
init_tracing(&config, config.tracing().file(), &cli, None)
.with_context(|| MoshpitError::TracingInit)?;
trace!("Configuration loaded");
trace!("Tracing initialized");
let socket_addr = config
.server_ip()
.parse::<SocketAddr>()
.with_context(|| MoshpitError::InvalidServerAddress)?;
let (key_bytes, hmac_key_bytes, uuid) = run_key_exchange(&socket_addr).await?;
let udp_listener = UdpSocket::bind("127.0.0.1:0").await?;
udp_listener.connect(socket_addr).await?;
let udp_recv = Arc::new(udp_listener);
let udp_send = udp_recv.clone();
let (tx, rx) = unbounded_channel::<Vec<u8>>();
let mut udp_reader = UdpReader::builder().socket(udp_recv).build();
let mut udp_sender = UdpSender::builder()
.socket(udp_send)
.rx(rx)
.id(uuid)
.hmac(hmac_key_bytes)
.rnk(key_bytes)?
.build();
let _udp_handle = spawn(async move {
if let Err(e) = udp_sender.handle_send().await {
error!("udp sender error {e}");
}
});
let _udp_reader_handle = spawn(async move {
if let Err(e) = udp_reader.handle_read().await {
error!("udp reader error {e}");
}
});
tx.send(b"Hello, world!".to_vec())?;
loop {
let mut input = String::new();
match stdin().read_line(&mut input) {
Ok(_n) => {
tx.send(input.into_bytes())?;
}
Err(error) => {
println!("error: {error}");
break;
}
}
}
Ok(())
}
async fn run_key_exchange(socket_addr: &SocketAddr) -> Result<([u8; 32], [u8; 64], Uuid)> {
let socket = TcpStream::connect(socket_addr).await?;
let (sock_read, sock_write) = socket.into_split();
let reader = ConnectionReader::builder().reader(sock_read).build();
let writer = ConnectionWriter::builder().writer(sock_write).build();
let (tx, rx) = unbounded_channel();
let (tx_udp_state, mut rx_udp_state) = unbounded_channel::<UdpState>();
info!("Connected to the server!");
let rng = SystemRandom::new();
let pk = EphemeralPrivateKey::generate(&X25519, &rng)?;
let my_public_key = pk.compute_public_key()?;
trace!("Generated ephemeral X25519 key pair");
let tx_c = tx.clone();
let tx_udp_state_c = tx_udp_state.clone();
let _read_handle = spawn(async move {
let mut frame_reader = FrameReader::builder()
.reader(reader)
.tx(tx_c)
.tx_udp(tx_udp_state_c)
.build();
if let Err(e) = frame_reader.handle_connection(pk).await {
error!("mps frame reader: {e}");
}
});
trace!("Spawned TCP frame reader task");
let _write_handle = spawn(async move {
let mut sender = FrameSender::builder().writer(writer).rx(rx).build();
if let Err(e) = sender.handle_tx().await {
error!("mp sender error {e}");
}
});
trace!("Spawned TCP frame sender task");
trace!("Sending initialize frame...");
let frame = Frame::Initialize(my_public_key.as_ref().to_vec());
tx.send(frame.clone())?;
let mut key_bytes = [0u8; 32];
let mut hmac_key_bytes = [0u8; 64];
let mut uuid = Uuid::nil();
while let Some(udp_state) = rx_udp_state.recv().await {
match udp_state {
UdpState::Key(key_b) => {
trace!("Received UDP key");
key_bytes = key_b;
}
UdpState::HmacKey(hmac_key_b) => {
trace!("Received UDP HMAC key");
hmac_key_bytes = hmac_key_b;
}
UdpState::Uuid(set_uuid) => {
trace!("Received UDP UUID: {}", set_uuid);
uuid = set_uuid;
break;
}
}
}
Ok((key_bytes, hmac_key_bytes, uuid))
}