use std::{
fmt::{self, Display, Formatter},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
};
use anyhow::Result;
use aws_lc_rs::{
aead::{
AES_128_GCM_SIV, AES_256_GCM, AES_256_GCM_SIV, CHACHA20_POLY1305, LessSafeKey, UnboundKey,
},
hmac::{HMAC_SHA256, HMAC_SHA512, Key},
};
use bon::Builder;
use getset::{CopyGetters, Getters};
use serde::{Deserialize, Serialize};
use socket2::SockRef;
use tokio::{
net::{
UdpSocket,
tcp::{OwnedReadHalf, OwnedWriteHalf},
},
spawn,
sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
task::JoinHandle,
};
use tracing::{debug, error, info, trace};
use uuid::Uuid;
use crate::{
ConnectionReader, ConnectionWriter, Frame, KexConfig, KexReader, KexSender, MoshpitError,
UuidWrapper, kex::negotiate::NegotiatedAlgorithms, load_identity_key, load_public_key,
udp::DiffMode,
};
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
fn fmt_hex(bytes: &[u8]) -> String {
    use std::fmt::Write as _;

    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        // Formatting into a `String` cannot fail, so the result is discarded.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
/// Shared, thread-safe callback that receives two string arguments and returns
/// a yes/no decision or an error.
///
/// NOTE(review): the name suggests a trust-on-first-use prompt for a host key
/// that has not been seen before; the meaning of the two arguments and of the
/// returned boolean is not visible in this file — confirm against the code
/// that invokes the callback.
pub type TofuFn = Arc<dyn Fn(&str, &str) -> Result<bool> + Send + Sync>;
/// Shared, thread-safe callback that receives three string arguments and
/// returns a yes/no decision or an error.
///
/// NOTE(review): the name suggests it is consulted when a presented host key
/// differs from a stored one; the meaning of the three arguments and of the
/// returned boolean is not visible in this file — confirm against the code
/// that invokes the callback.
pub type HostKeyMismatchFn = Arc<dyn Fn(&str, &str, &str) -> Result<bool> + Send + Sync>;
/// Bundles the two optional host-key callbacks so they can be handed to
/// `run_client_kex` as a single argument.
#[derive(Clone)]
struct HostKeyCallbacks {
// Optional `TofuFn`; forwarded to the client reader builder via `maybe_tofu_fn`.
tofu_fn: Option<TofuFn>,
// Optional `HostKeyMismatchFn`; forwarded to the client reader builder via
// `maybe_host_key_mismatch_fn`.
host_key_mismatch_fn: Option<HostKeyMismatchFn>,
}
pub(crate) mod negotiate;
/// Reports whether `name` is selected by at least one entry of `patterns`.
///
/// A pattern that ends in `*` matches every name starting with the text before
/// that final `*` (only one trailing `*` is stripped). Any other pattern must
/// equal `name` exactly. An empty pattern list matches nothing.
#[must_use]
pub fn env_var_matches(name: &str, patterns: &[String]) -> bool {
    patterns
        .iter()
        .any(|pattern| match pattern.strip_suffix('*') {
            Some(prefix) => name.starts_with(prefix),
            None => name == pattern.as_str(),
        })
}
pub(crate) mod reader;
pub(crate) mod sender;
/// Events delivered over the event channel to `KexStateMachine::handle_events`,
/// which folds them into a `Kex`.
///
/// NOTE(review): this enum derives `Debug` and two variants carry raw key
/// bytes, so debug-formatting an event would print key material.
#[derive(Clone, Debug)]
pub enum KexEvent {
/// The negotiated algorithm set; stored as the result's `negotiated_algorithms`.
NegotiatedAlgorithms(NegotiatedAlgorithms),
/// Key bytes stored as the result's `key` (later fed to `Kex::build_aead_key`).
KeyMaterial(Vec<u8>),
/// Key bytes stored as the result's `hmac_key` (later fed to `Kex::build_hmac`).
HMACKeyMaterial(Vec<u8>),
/// Identifier stored as the result's `uuid`.
Uuid(Uuid),
/// Socket address stored as the result's `moshpits_addr`; in client mode this
/// is the event that completes the exchange.
MoshpitsAddr(SocketAddr),
/// Session UUID and resume flag, stored as `session_uuid` and `is_resume`.
SessionInfo(Uuid, bool),
/// Has no dedicated arm in `handle_events`; it falls into the catch-all arm
/// and is reported as `MoshpitError::InvalidKexState`.
Failure,
/// Aborts `handle_events` with `MoshpitError::NoCommonAlgorithm` in any state.
NoCommonAlgorithm,
}
/// States of the key-exchange state machine, named after the event each one
/// is waiting for. `handle_events` advances through them in declaration order.
#[derive(Clone, Copy, Debug, Default)]
pub enum KexState {
/// Initial state: waiting for `KexEvent::NegotiatedAlgorithms`.
#[default]
AwaitingNegotiatedAlgorithms,
/// Waiting for `KexEvent::KeyMaterial`.
AwaitingKeyMaterial,
/// Waiting for `KexEvent::HMACKeyMaterial`.
AwaitingHMACKeyMaterial,
/// Waiting for `KexEvent::Uuid`; in server mode that event completes the exchange.
AwaitingUuid,
/// Client mode only: waiting for `KexEvent::SessionInfo`.
AwaitingSessionToken,
/// Client mode only: waiting for `KexEvent::MoshpitsAddr`.
AwaitingMoshpitsAddr,
/// Terminal state: every expected event has been received.
Complete,
}
/// Consumes `KexEvent`s from a channel and tracks how far the key exchange
/// has progressed. Constructed through the generated builder.
#[derive(Builder, CopyGetters, Debug)]
pub struct KexStateMachine {
/// Current state; starts at `KexState::default()` unless the builder overrides it.
#[getset(get_copy = "pub")]
#[builder(default = KexState::default())]
state: KexState,
/// Receiving end of the event channel drained by `handle_events`.
rx_event: UnboundedReceiver<KexEvent>,
}
/// Result of a completed key exchange, assembled by
/// `KexStateMachine::handle_events` from the events it receives.
///
/// NOTE(review): `Debug` is derived and `key` / `hmac_key` hold raw key bytes,
/// so debug-formatting a `Kex` would print key material.
#[derive(Clone, Debug, CopyGetters, Getters)]
pub struct Kex {
/// Key bytes passed to `UnboundKey::new` by `build_aead_key`.
#[getset(get = "pub")]
key: Vec<u8>,
/// Key bytes passed to `Key::new` by `build_hmac`.
#[getset(get = "pub")]
hmac_key: Vec<u8>,
/// Identifier taken from `KexEvent::Uuid`; nil in the default value.
#[getset(get_copy = "pub")]
uuid: Uuid,
/// Address taken from `KexEvent::MoshpitsAddr`; stays `None` in server mode,
/// where the exchange completes before that event is expected.
#[getset(get_copy = "pub")]
moshpits_addr: Option<SocketAddr>,
/// Session UUID taken from `KexEvent::SessionInfo` (client mode only).
#[getset(get_copy = "pub")]
session_uuid: Option<Uuid>,
/// Resume flag taken from `KexEvent::SessionInfo` (client mode only).
#[getset(get_copy = "pub")]
is_resume: bool,
/// Algorithm set taken from `KexEvent::NegotiatedAlgorithms`; selects the
/// AEAD and MAC used by the `build_*` methods.
#[getset(get = "pub")]
negotiated_algorithms: NegotiatedAlgorithms,
}
impl Kex {
#[must_use]
pub fn uuid_wrapper(&self) -> UuidWrapper {
UuidWrapper::new(self.uuid)
}
pub fn build_aead_key(&self) -> Result<LessSafeKey> {
use negotiate::{
AEAD_AES128_GCM_SIV, AEAD_AES256_GCM, AEAD_AES256_GCM_SIV, AEAD_CHACHA20_POLY1305,
};
let alg: &'static aws_lc_rs::aead::Algorithm =
match self.negotiated_algorithms.aead.as_str() {
AEAD_AES256_GCM_SIV => &AES_256_GCM_SIV,
AEAD_AES256_GCM => &AES_256_GCM,
AEAD_CHACHA20_POLY1305 => &CHACHA20_POLY1305,
AEAD_AES128_GCM_SIV => &AES_128_GCM_SIV,
_ => return Err(MoshpitError::NoCommonAlgorithm.into()),
};
debug!(
aead = %self.negotiated_algorithms.aead,
key_len = self.key.len(),
key_hex = %fmt_hex(&self.key),
"build_aead_key: constructing LessSafeKey"
);
Ok(LessSafeKey::new(UnboundKey::new(alg, &self.key)?))
}
#[must_use]
pub fn build_hmac(&self) -> Key {
use negotiate::MAC_HMAC_SHA256;
if self.negotiated_algorithms.mac.as_str() == MAC_HMAC_SHA256 {
Key::new(HMAC_SHA256, &self.hmac_key)
} else {
Key::new(HMAC_SHA512, &self.hmac_key)
}
}
#[must_use]
pub fn mac_tag_len(&self) -> usize {
use negotiate::MAC_HMAC_SHA256;
if self.negotiated_algorithms.mac.as_str() == MAC_HMAC_SHA256 {
32
} else {
64
}
}
}
impl Default for Kex {
    /// Produces an empty exchange result: no key bytes, a nil UUID, no address
    /// or session, not a resume, and default negotiated algorithms.
    fn default() -> Self {
        let (key, hmac_key) = (Vec::new(), Vec::new());
        Self {
            negotiated_algorithms: NegotiatedAlgorithms::default(),
            is_resume: false,
            session_uuid: None,
            moshpits_addr: None,
            uuid: Uuid::nil(),
            hmac_key,
            key,
        }
    }
}
/// Server-side outcome of a key exchange. `run_server_kex` obtains it from the
/// reader's `server_kex` call and returns it alongside the `Kex`.
///
/// NOTE(review): the field descriptions below are inferred from names and
/// types only — the code that fills them is not in this file; confirm there.
#[derive(Builder, Clone, Debug, CopyGetters, Getters)]
pub struct ServerKex {
/// Presumably the user the session runs as.
#[getset(get = "pub")]
user: String,
/// Presumably the shell to launch for that user.
#[getset(get = "pub")]
shell: String,
/// UUID identifying the session.
#[getset(get_copy = "pub")]
session_uuid: Uuid,
/// Presumably whether an existing session is being resumed; builder default `false`.
#[getset(get_copy = "pub")]
#[builder(default)]
is_resume: bool,
/// Diff mode for the session; builder default is `DiffMode::default()`.
#[getset(get_copy = "pub")]
#[builder(default)]
diff_mode: DiffMode,
/// Negotiated algorithm set; builder default is `NegotiatedAlgorithms::default()`.
#[getset(get = "pub")]
#[builder(default)]
negotiated_algorithms: NegotiatedAlgorithms,
/// Name/value pairs, presumably environment variables sent by the client;
/// empty by default.
#[getset(get = "pub")]
#[builder(default)]
client_env: Vec<(String, String)>,
/// Presumably extra `PATH` entries sent by the client; empty by default.
#[getset(get = "pub")]
#[builder(default)]
client_extra_path: Vec<String>,
}
impl KexStateMachine {
    /// Drains the event channel until the exchange completes, the channel
    /// closes, or an unexpected event arrives, and returns the assembled
    /// [`Kex`].
    ///
    /// Events must arrive in the order of the `KexState` variants. With
    /// `client_mode == false` the exchange completes right after the UUID;
    /// with `client_mode == true` it additionally requires the session info
    /// and the moshpits address.
    ///
    /// # Errors
    ///
    /// * `MoshpitError::NoCommonAlgorithm` when that event is received in any
    ///   state.
    /// * `MoshpitError::InvalidKexState` when an event does not match the
    ///   current state, or when the channel closes before `Complete` is
    ///   reached.
    pub async fn handle_events(&mut self, client_mode: bool) -> Result<Kex> {
        let mut kex = Kex::default();
        while let Some(event) = self.rx_event.recv().await {
            // Each arm records its payload and yields the state to move to.
            let next_state = match (self.state, event) {
                (KexState::AwaitingNegotiatedAlgorithms, KexEvent::NegotiatedAlgorithms(algos)) => {
                    kex.negotiated_algorithms = algos;
                    KexState::AwaitingKeyMaterial
                }
                (KexState::AwaitingKeyMaterial, KexEvent::KeyMaterial(bytes)) => {
                    kex.key = bytes;
                    KexState::AwaitingHMACKeyMaterial
                }
                (KexState::AwaitingHMACKeyMaterial, KexEvent::HMACKeyMaterial(bytes)) => {
                    kex.hmac_key = bytes;
                    KexState::AwaitingUuid
                }
                (KexState::AwaitingUuid, KexEvent::Uuid(id)) => {
                    kex.uuid = id;
                    if client_mode {
                        KexState::AwaitingSessionToken
                    } else {
                        KexState::Complete
                    }
                }
                (KexState::AwaitingSessionToken, KexEvent::SessionInfo(session_id, resumed)) => {
                    kex.session_uuid = Some(session_id);
                    kex.is_resume = resumed;
                    KexState::AwaitingMoshpitsAddr
                }
                (KexState::AwaitingMoshpitsAddr, KexEvent::MoshpitsAddr(addr)) => {
                    kex.moshpits_addr = Some(addr);
                    KexState::Complete
                }
                (_, KexEvent::NoCommonAlgorithm) => {
                    return Err(MoshpitError::NoCommonAlgorithm.into());
                }
                _ => {
                    return Err(MoshpitError::InvalidKexState.into());
                }
            };
            self.state = next_state;
            if matches!(next_state, KexState::Complete) {
                break;
            }
        }
        // Reached either by completing or by the channel closing early.
        if matches!(self.state, KexState::Complete) {
            Ok(kex)
        } else {
            Err(MoshpitError::InvalidKexState.into())
        }
    }
}
/// Selects which side of the key exchange `run_key_exchange` performs.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum KexMode {
/// Run the client side (the default).
#[default]
Client,
/// Run the server side; the address is handed to `run_server_kex` and on to
/// the reader's `server_kex` call.
Server(SocketAddr),
}
impl Display for KexMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
KexMode::Client => write!(f, "Client"),
KexMode::Server(addr) => write!(f, "Server({addr})"),
}
}
}
/// Runs a complete key exchange over an already-split TCP connection.
///
/// Spawns the state-machine task and the frame-sending task, then delegates to
/// the client or server routine according to `config.mode()`. In server mode a
/// `Frame::KexFailure` is queued for the peer when the server routine fails.
///
/// On success returns the exchange result, the UDP socket prepared by the
/// chosen routine, and — in server mode only — the server-side details.
///
/// # Errors
///
/// Propagates any error returned by the client or server routine.
pub async fn run_key_exchange<T: KexConfig>(
    config: T,
    sock_read: OwnedReadHalf,
    sock_write: OwnedWriteHalf,
    passphrase_fn: impl Fn() -> Result<Option<String>>,
    tofu_fn: Option<TofuFn>,
    host_key_mismatch_fn: Option<HostKeyMismatchFn>,
) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
    let mode = config.mode();
    let reader = ConnectionReader::builder().reader(sock_read).build();
    let writer = ConnectionWriter::builder().writer(sock_write).build();
    let (tx, rx) = unbounded_channel();
    let (tx_event, rx_event) = unbounded_channel::<KexEvent>();

    // Task 1: fold incoming events into the final `Kex`.
    let client_mode = mode == KexMode::Client;
    let mut state_machine = KexStateMachine::builder().rx_event(rx_event).build();
    let kex_handle = spawn(async move { state_machine.handle_events(client_mode).await });

    // Task 2: write queued frames to the peer; failures are only logged.
    let _write_handle = spawn(async move {
        let mut sender = KexSender::builder().writer(writer).rx(rx).build();
        let outcome = sender.handle_send_frames().await;
        if let Err(e) = outcome {
            error!("{e}");
        }
    });

    match mode {
        KexMode::Client => {
            let callbacks = HostKeyCallbacks {
                tofu_fn,
                host_key_mismatch_fn,
            };
            run_client_kex(
                config,
                tx,
                tx_event,
                reader,
                kex_handle,
                passphrase_fn,
                callbacks,
            )
            .await
        }
        KexMode::Server(socket_addr) => {
            // Keep a sender so the peer can still be told about a failure
            // after `tx` has been handed to the server routine.
            let failure_tx = tx.clone();
            let outcome =
                run_server_kex(config, socket_addr, tx, tx_event, reader, kex_handle).await;
            if outcome.is_err() {
                // Best effort: the writer task may already be gone.
                drop(failure_tx.send(Frame::KexFailure));
            }
            outcome
        }
    }
}
#[cfg_attr(nightly, allow(clippy::too_many_lines))]
async fn run_client_kex<T: KexConfig>(
config: T,
tx: UnboundedSender<Frame>,
tx_event: UnboundedSender<KexEvent>,
reader: ConnectionReader,
kex_handle: JoinHandle<Result<Kex>>,
passphrase_fn: impl Fn() -> Result<Option<String>>,
callbacks: HostKeyCallbacks,
) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
let (private_key_path, public_key_path) = config.key_pair_paths()?;
info!("Loading private key from {}", private_key_path.display());
info!("Loading public key from {}", public_key_path.display());
let (full_public_key_bytes, public_key_bytes) = load_public_key(&public_key_path)
.inspect_err(|e| {
error!(
"Failed to load public key from {}: {e}",
public_key_path.display()
);
})
.map_err(|_| MoshpitError::KeyFileMissing)?;
if !private_key_path.try_exists().unwrap_or(false) {
error!(
"Failed to load private key from {}: file does not exist",
private_key_path.display()
);
return Err(MoshpitError::KeyFileMissing.into());
}
let identity_key = if let Ok(identity_key) = load_identity_key(&private_key_path, None) {
info!("Private key is unencrypted — no passphrase needed");
identity_key
} else {
info!("Private key may be encrypted — invoking passphrase prompt");
let passphrase = passphrase_fn().map_err(|e| {
error!("Passphrase prompt failed: {e}");
e
})?;
let Some(passphrase) = passphrase else {
error!("Passphrase prompt returned no input — cannot decrypt key");
return Err(MoshpitError::KeyCorrupt.into());
};
load_identity_key(&private_key_path, Some(&passphrase))
.inspect_err(|e| error!("Private key validation failed: {e}"))
.map_err(|_| MoshpitError::KeyCorrupt)?
};
if identity_key.public_key().as_slice() != public_key_bytes.as_slice() {
error!(
"Computed public key does not match stored public key at {}",
public_key_path.display()
);
return Err(MoshpitError::KeyPairMismatch.into());
}
info!(
"Private identity key ({}) verified successfully",
identity_key.key_algorithm()
);
let tx_c = tx.clone();
let tx_event_c = tx_event.clone();
let requested = config.resume_session_uuid();
let server_id = config.server_id();
let HostKeyCallbacks {
tofu_fn,
host_key_mismatch_fn,
} = callbacks;
let diff_mode = config.diff_mode();
let client_algos = config.preferred_algorithms();
let user = config.user().unwrap_or_default();
let send_env_patterns = config.send_env();
let send_env: Vec<(String, String)> = std::env::vars()
.filter(|(k, _)| env_var_matches(k, &send_env_patterns))
.collect();
let send_path = config.send_path();
#[cfg(feature = "unstable")]
let client_identity_key_algorithm = identity_key.key_algorithm().clone();
#[cfg(feature = "unstable")]
let client_identity_private_key = identity_key.private_key().clone();
let _read_handle = spawn(async move {
#[cfg(feature = "unstable")]
let mut frame_reader = KexReader::builder()
.reader(reader)
.tx(tx_c)
.tx_event(tx_event_c)
.maybe_requested_session_uuid(requested)
.maybe_server_destination(server_id)
.maybe_tofu_fn(tofu_fn)
.maybe_host_key_mismatch_fn(host_key_mismatch_fn)
.diff_mode(diff_mode)
.client_algos(client_algos)
.user(user)
.full_public_key_bytes(full_public_key_bytes)
.client_identity_key_algorithm(client_identity_key_algorithm)
.client_identity_private_key(client_identity_private_key)
.send_env(send_env)
.send_path(send_path)
.build();
#[cfg(not(feature = "unstable"))]
let mut frame_reader = KexReader::builder()
.reader(reader)
.tx(tx_c)
.tx_event(tx_event_c)
.maybe_requested_session_uuid(requested)
.maybe_server_destination(server_id)
.maybe_tofu_fn(tofu_fn)
.maybe_host_key_mismatch_fn(host_key_mismatch_fn)
.diff_mode(diff_mode)
.client_algos(client_algos)
.user(user)
.full_public_key_bytes(full_public_key_bytes)
.send_env(send_env)
.send_path(send_path)
.build();
if let Err(e) = frame_reader.client_kex().await {
error!("client_kex failed: {e}");
}
});
tx.send(Frame::KexInit(config.preferred_algorithms()))?;
let kex = kex_handle.await??;
if let Some(moshpits_addr) = kex.moshpits_addr() {
trace!("Connecting to moshpits at {moshpits_addr}");
let bind_addr = if moshpits_addr.is_ipv6() {
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
} else {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
};
let udp_listener = UdpSocket::bind(bind_addr).await?;
let sock = SockRef::from(&udp_listener);
drop(sock.set_recv_buffer_size(4 * 1024 * 1024));
drop(sock.set_send_buffer_size(4 * 1024 * 1024));
#[cfg(any(target_os = "linux", target_os = "macos"))]
if bind_addr.is_ipv4() {
drop(sock.set_tos_v4(0xB8));
} else {
drop(sock.set_tclass_v6(0xB8));
}
udp_listener.connect(moshpits_addr).await?;
Ok((kex, Arc::new(udp_listener), None))
} else {
Err(MoshpitError::InvalidMoshpitsAddress.into())
}
}
/// Runs the server side of the key exchange.
///
/// Builds a frame reader with the server's preferred algorithms, lets it
/// perform `server_kex`, then waits for the state machine task and returns its
/// `Kex` together with the UDP socket and server-side details from the reader.
///
/// # Errors
///
/// Fails when `config.key_pair_paths()` fails, when the configuration provides
/// no port pool, when `server_kex` fails, or when the state machine task
/// fails or cannot be joined.
async fn run_server_kex<T: KexConfig>(
    config: T,
    socket_addr: SocketAddr,
    tx: UnboundedSender<Frame>,
    tx_event: UnboundedSender<KexEvent>,
    reader: ConnectionReader,
    kex_handle: JoinHandle<Result<Kex>>,
) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
    let maybe_port_pool = config.port_pool();
    let (_private_key_path, public_key_path) = config.key_pair_paths()?;
    let session_registry = config.session_registry();
    trace!(
        "Loading identity public key from {}",
        public_key_path.display()
    );

    // The reader gets its own clones; the originals stay alive in this scope.
    let mut frame_reader = KexReader::builder()
        .reader(reader)
        .tx(tx.clone())
        .tx_event(tx_event.clone())
        .server_preferred_algos(config.preferred_algorithms())
        .build();

    let Some(port_pool) = maybe_port_pool else {
        return Err(anyhow::anyhow!(
            "Port pool is required for server key exchange"
        ));
    };
    let (server_kex, udp_socket) = frame_reader
        .server_kex(socket_addr, port_pool, &public_key_path, session_registry)
        .await?;
    // Outer `?`: join failure; inner `?`: state machine failure.
    let kex = kex_handle.await??;
    Ok((kex, udp_socket, Some(server_kex)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an event channel plus a state machine reading from it.
    fn new_machine() -> (UnboundedSender<KexEvent>, KexStateMachine) {
        let (tx, rx) = unbounded_channel();
        let machine = KexStateMachine::builder().rx_event(rx).build();
        (tx, machine)
    }

    /// Queues the four events both modes expect first, in protocol order.
    fn queue_common_events(
        tx: &UnboundedSender<KexEvent>,
        key: &[u8],
        hmac_key: &[u8],
        uuid: Uuid,
    ) {
        use crate::kex::negotiate::NegotiatedAlgorithms;
        let events = [
            KexEvent::NegotiatedAlgorithms(NegotiatedAlgorithms::default()),
            KexEvent::KeyMaterial(key.to_vec()),
            KexEvent::HMACKeyMaterial(hmac_key.to_vec()),
            KexEvent::Uuid(uuid),
        ];
        for event in events {
            tx.send(event).unwrap();
        }
    }

    /// Asserts that `result` failed with `MoshpitError::InvalidKexState`.
    fn assert_invalid_kex_state(result: Result<Kex>) {
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .downcast_ref::<MoshpitError>()
                .is_some_and(|e| *e == MoshpitError::InvalidKexState),
        );
    }

    #[tokio::test]
    async fn kex_state_machine_server_mode_completes_after_uuid() {
        let (tx, mut sm) = new_machine();
        let key = vec![1u8; 32];
        let hmac_key = vec![2u8; 64];
        let uuid = Uuid::new_v4();

        queue_common_events(&tx, &key, &hmac_key, uuid);
        drop(tx);

        let kex = sm.handle_events(false).await.unwrap();
        assert_eq!(kex.key().as_slice(), key.as_slice());
        assert_eq!(kex.hmac_key().as_slice(), hmac_key.as_slice());
        assert_eq!(kex.uuid(), uuid);
        assert!(kex.moshpits_addr().is_none());
        assert!(kex.session_uuid().is_none());
    }

    #[tokio::test]
    async fn kex_state_machine_client_mode_full_sequence() {
        let (tx, mut sm) = new_machine();
        let key = vec![3u8; 32];
        let hmac_key = vec![4u8; 64];
        let uuid = Uuid::new_v4();
        let session_uuid = Uuid::new_v4();
        let addr: SocketAddr = "127.0.0.1:50001".parse().unwrap();

        queue_common_events(&tx, &key, &hmac_key, uuid);
        tx.send(KexEvent::SessionInfo(session_uuid, false)).unwrap();
        tx.send(KexEvent::MoshpitsAddr(addr)).unwrap();

        // The sender is intentionally left open: completion must not depend
        // on the channel closing.
        let kex = sm.handle_events(true).await.unwrap();
        assert_eq!(kex.key().as_slice(), key.as_slice());
        assert_eq!(kex.hmac_key().as_slice(), hmac_key.as_slice());
        assert_eq!(kex.uuid(), uuid);
        assert_eq!(kex.session_uuid(), Some(session_uuid));
        assert_eq!(kex.moshpits_addr(), Some(addr));
        assert!(!kex.is_resume());
    }

    #[tokio::test]
    async fn kex_state_machine_wrong_event_order_returns_invalid_state() {
        let (tx, mut sm) = new_machine();
        tx.send(KexEvent::Uuid(Uuid::new_v4())).unwrap();
        drop(tx);

        assert_invalid_kex_state(sm.handle_events(true).await);
    }

    #[tokio::test]
    async fn kex_state_machine_channel_dropped_returns_invalid_state() {
        let (tx, mut sm) = new_machine();
        drop(tx);

        assert_invalid_kex_state(sm.handle_events(true).await);
    }

    #[test]
    fn kex_mode_display_formatting() {
        assert_eq!(format!("{}", KexMode::Client), "Client");
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        assert_eq!(
            format!("{}", KexMode::Server(addr)),
            "Server(127.0.0.1:12345)"
        );
    }
}