use std::{
collections::BTreeSet,
fs::OpenOptions,
io::{BufRead, BufReader},
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
};
use anyhow::Result;
use aws_lc_rs::{
aead::{AES_256_GCM_SIV, Aad, Nonce, RandomizedNonceKey},
agreement::{ParsedPublicKey, PrivateKey, UnparsedPublicKey, X25519, agree},
cipher::AES_256_KEY_LEN,
digest::SHA512_OUTPUT_LEN,
error::Unspecified,
hkdf::{HKDF_SHA256, HKDF_SHA512, Salt},
rand::fill,
};
use bon::Builder;
use local_ip_address::local_ip;
use tokio::{
net::UdpSocket,
process::Command,
sync::{Mutex, mpsc::UnboundedSender},
};
use tracing::{error, trace};
use uuid::Uuid;
use crate::{
ConnectionReader, Frame, KexEvent, MoshpitError, ServerKex, UuidWrapper, load_private_key,
load_public_key, session::SessionRegistry,
};
const AEAD_KEY_INFO: &[u8] = b"AEAD KEY";
const HMAC_KEY_INFO: &[u8] = b"HMAC KEY";
/// State for running one side of the moshpit key exchange over an existing connection.
///
/// Frames from the peer are read from `reader`. Frames produced during the exchange are
/// queued on `tx`, and progress (derived key material, UUIDs, session info, failures) is
/// reported on `tx_event`. `client_kex` and `server_kex` drive the two sides.
#[derive(Builder, Debug)]
pub struct KexReader {
    /// Source of frames received from the peer.
    reader: ConnectionReader,
    /// Outgoing frames (e.g. `Check`, `PeerInitialize`, `SessionToken`); presumably
    /// forwarded to the peer by a writer task elsewhere — confirm with the caller.
    tx: UnboundedSender<Frame>,
    /// Key-exchange events (`KeyMaterial`, `HMACKeyMaterial`, `Uuid`, `SessionInfo`,
    /// `MoshpitsAddr`, `Failure`).
    tx_event: UnboundedSender<KexEvent>,
    /// Session the client asked to resume, if any. `client_kex` compares it with the
    /// server's `SessionToken` to report whether the session was resumed.
    requested_session_uuid: Option<Uuid>,
}
impl KexReader {
pub async fn client_kex(&mut self, epk: &PrivateKey) -> Result<()> {
if let Some(frame) = self.reader.read_frame().await? {
if let Frame::PeerInitialize(pk, salt_bytes) = frame {
let peer_public_key = UnparsedPublicKey::new(&X25519, &pk);
let salt = Salt::new(HKDF_SHA256, &salt_bytes);
agree(epk, peer_public_key, Unspecified, |key_material| {
let pseudo_random_key = salt.extract(key_material);
let mut check = b"Yoda".to_vec();
let okm_aes = pseudo_random_key.expand(&[AEAD_KEY_INFO], &AES_256_GCM_SIV)?;
let mut key_bytes = [0u8; AES_256_KEY_LEN];
okm_aes.fill(&mut key_bytes)?;
let okm_hmac =
pseudo_random_key.expand(&[HMAC_KEY_INFO], HKDF_SHA512.hmac_algorithm())?;
let mut hmac_key_bytes = [0u8; SHA512_OUTPUT_LEN];
okm_hmac.fill(&mut hmac_key_bytes)?;
self.tx_event
.send(KexEvent::KeyMaterial(key_bytes))
.map_err(|_| Unspecified)?;
self.tx_event
.send(KexEvent::HMACKeyMaterial(hmac_key_bytes))
.map_err(|_| Unspecified)?;
let rnk = RandomizedNonceKey::new(&AES_256_GCM_SIV, &key_bytes)?;
let nonce = rnk.seal_in_place_append_tag(Aad::empty(), &mut check)?;
self.tx
.send(Frame::Check(*nonce.as_ref(), check))
.map_err(|_| Unspecified)?;
Ok(())
})?;
} else {
self.tx_event
.send(KexEvent::Failure)
.map_err(|_| Unspecified)?;
return Err(MoshpitError::KeyNotEstablished.into());
}
}
if let Some(frame) = self.reader.read_frame().await?
&& let Frame::KeyAgreement(uuid) = frame
{
self.tx_event
.send(KexEvent::Uuid(*uuid.as_ref()))
.map_err(|_| Unspecified)?;
}
if let Some(frame) = self.reader.read_frame().await?
&& let Frame::SessionToken(session_uuid_wrapper) = frame
{
let session_uuid = *session_uuid_wrapper.as_ref();
let is_resume = self.requested_session_uuid == Some(session_uuid);
self.tx_event
.send(KexEvent::SessionInfo(session_uuid, is_resume))
.map_err(|_| Unspecified)?;
}
if let Some(frame) = self.reader.read_frame().await?
&& let Frame::MoshpitsAddr(addr) = frame
{
self.tx_event
.send(KexEvent::MoshpitsAddr(addr))
.map_err(|_| Unspecified)?;
}
Ok(())
}
pub async fn server_kex(
&mut self,
socket_addr: SocketAddr,
port_pool: Arc<Mutex<BTreeSet<u16>>>,
private_key_path: &PathBuf,
public_key_path: &PathBuf,
session_registry: Option<SessionRegistry>,
) -> Result<(ServerKex, Arc<UdpSocket>)> {
let (rnk, user_str, shell, requested_session_uuid_opt) =
if let Some(frame) = self.reader.read_frame().await? {
let (user, pk, fpk, req_uuid) = match frame {
Frame::Initialize(user, pk, fpk) => (user, pk, fpk, None),
Frame::ResumeRequest(session_uuid_wrapper, user, pk, fpk) => {
(user, pk, fpk, Some(*session_uuid_wrapper.as_ref()))
}
_ => {
error!("Expected initialize frame from mp");
return Err(MoshpitError::InvalidFrame.into());
}
};
let user_str = String::from_utf8_lossy(&user).to_string();
let (home_dir, shell) = if self.validate_user(&user_str).await? {
self.get_home_dir_shell(&user_str).await?
} else {
return Err(MoshpitError::KeyNotEstablished.into());
};
if !check_authorized_keys(&home_dir, &fpk)? {
return Err(MoshpitError::KeyNotEstablished.into());
}
let rnk = self.handle_initialize(
&pk,
&self.tx_event.clone(),
private_key_path,
public_key_path,
)?;
(rnk, user_str, shell, req_uuid)
} else {
error!("Expected initialize frame from mp");
return Err(MoshpitError::InvalidFrame.into());
};
if let Some(frame) = self.reader.read_frame().await? {
if let Frame::Check(nonce, enc) = frame {
self.handle_check(&rnk, nonce, enc, &self.tx_event.clone())?;
} else {
error!("Expected check frame from mp");
return Err(MoshpitError::InvalidFrame.into());
}
} else {
error!("Expected check frame from mp");
return Err(MoshpitError::InvalidFrame.into());
}
let (session_uuid, is_resume) = match (requested_session_uuid_opt, &session_registry) {
(Some(req_uuid), Some(registry)) => {
let reg = registry.lock().await;
if let Some(stored_user) = reg.get(&req_uuid) {
if *stored_user == user_str {
(req_uuid, true)
} else {
(Uuid::new_v4(), false)
}
} else {
(Uuid::new_v4(), false)
}
}
_ => (Uuid::new_v4(), false),
};
if !is_resume && let Some(ref registry) = session_registry {
let mut reg = registry.lock().await;
drop(reg.insert(session_uuid, user_str.clone()));
}
self.tx
.send(Frame::SessionToken(UuidWrapper::new(session_uuid)))?;
let udp_arc = self.handle_udp_setup(socket_addr, port_pool).await?;
if let Some(frame) = self.reader.read_frame().await? {
if let Frame::MoshpitAddr(moshpit_addr) = frame {
udp_arc.connect(moshpit_addr).await?;
} else {
error!("Expected moshpit address frame");
return Err(MoshpitError::InvalidFrame.into());
}
} else {
error!("Expected moshpit address frame");
return Err(MoshpitError::InvalidFrame.into());
}
let skex = ServerKex::builder()
.user(user_str)
.shell(shell)
.session_uuid(session_uuid)
.is_resume(is_resume)
.build();
Ok((skex, udp_arc))
}
fn handle_initialize(
&mut self,
pk: &[u8],
tx_event: &UnboundedSender<KexEvent>,
private_key_path: &PathBuf,
public_key_path: &PathBuf,
) -> Result<RandomizedNonceKey> {
let (unenc_key_pair_opt, _enc_key_pair_opt) = load_private_key(private_key_path)?;
let (_, public_key_bytes) = load_public_key(public_key_path)?;
let (private_key, public_key) = if let Some(unenc_key_pair) = unenc_key_pair_opt {
unenc_key_pair.take()
} else {
return Err(anyhow::anyhow!("No valid private key found"));
};
if public_key.as_ref() != public_key_bytes.as_slice() {
return Err(anyhow::anyhow!(
"public key from file does not match computed public key"
));
}
let unparsed_public_key = UnparsedPublicKey::new(&X25519, &pk);
let parsed_public_key = ParsedPublicKey::try_from(&unparsed_public_key)?;
let mut salt_bytes = [0u8; 32];
fill(&mut salt_bytes)?;
let peer_initialize =
Frame::PeerInitialize(public_key.as_ref().to_vec(), salt_bytes.to_vec());
self.tx.send(peer_initialize)?;
let salt = Salt::new(HKDF_SHA256, &salt_bytes);
let rnk = agree(
&private_key,
parsed_public_key,
Unspecified,
|key_material| {
let pseudo_random_key = salt.extract(key_material);
let okm = pseudo_random_key.expand(&[AEAD_KEY_INFO], &AES_256_GCM_SIV)?;
let mut key_bytes = [0u8; AES_256_KEY_LEN];
okm.fill(&mut key_bytes)?;
let okm_hmac =
pseudo_random_key.expand(&[HMAC_KEY_INFO], HKDF_SHA512.hmac_algorithm())?;
let mut hmac_key_bytes = [0u8; SHA512_OUTPUT_LEN];
okm_hmac.fill(&mut hmac_key_bytes)?;
tx_event
.send(KexEvent::KeyMaterial(key_bytes))
.map_err(|_| Unspecified)?;
tx_event
.send(KexEvent::HMACKeyMaterial(hmac_key_bytes))
.map_err(|_| Unspecified)?;
let rnk = RandomizedNonceKey::new(&AES_256_GCM_SIV, &key_bytes)?;
Ok(rnk)
},
)?;
Ok(rnk)
}
fn handle_check(
&mut self,
rnk: &RandomizedNonceKey,
nonce_bytes: [u8; 12],
mut check_bytes: Vec<u8>,
tx_event: &UnboundedSender<KexEvent>,
) -> Result<()> {
let nonce = Nonce::from(&nonce_bytes);
let decrypted_data = rnk
.open_in_place(nonce, Aad::empty(), &mut check_bytes)
.map_err(|_| MoshpitError::DecryptionFailed)?;
if decrypted_data == b"Yoda" {
let id = Uuid::new_v4();
tx_event.send(KexEvent::Uuid(id)).map_err(|_| Unspecified)?;
self.tx.send(Frame::KeyAgreement(UuidWrapper::new(id)))?;
} else {
error!("Check frame verification failed");
return Err(MoshpitError::DecryptionFailed.into());
}
Ok(())
}
async fn handle_udp_setup(
&mut self,
mut socket_addr: SocketAddr,
port_pool: Arc<Mutex<BTreeSet<u16>>>,
) -> Result<Arc<UdpSocket>> {
let mut port_p = port_pool.lock().await;
let next_port = port_p.pop_first().unwrap_or(49999);
socket_addr.set_port(next_port);
let my_local_ip = local_ip()?;
let udp_socket_addr = SocketAddr::new(my_local_ip, socket_addr.port());
trace!("binding moshpits socket at {udp_socket_addr}");
self.tx.send(Frame::MoshpitsAddr(udp_socket_addr))?;
let udp_listener = UdpSocket::bind(udp_socket_addr).await?;
Ok(Arc::new(udp_listener))
}
#[cfg(target_os = "linux")]
async fn validate_user(&self, user: &str) -> Result<bool> {
let mut is_valid_user = Command::new("id");
let _ = is_valid_user.arg(user);
let output = is_valid_user
.output()
.await
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
Ok(output.status.success())
}
#[cfg(target_os = "macos")]
async fn validate_user(&self, user: &str) -> Result<bool> {
let mut is_valid_user = Command::new("dscl");
let _ = is_valid_user.args([".", "-read", format!("/Users/{user}").as_str()]);
let output = is_valid_user
.output()
.await
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
Ok(output.status.success())
}
#[cfg(target_os = "windows")]
async fn validate_user(&self, user: &str) -> Result<bool> {
let mut is_valid_user = Command::new("net");
let _ = is_valid_user.args(["user", user]);
let output = is_valid_user
.output()
.await
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
Ok(output.status.success())
}
#[cfg(target_os = "linux")]
async fn get_home_dir_shell(&self, user: &str) -> Result<(String, String)> {
let mut cmd = Command::new("getent");
let _ = cmd.args(["passwd", user]);
let output = cmd
.output()
.await
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let parts: Vec<&str> = stdout.split(':').collect();
if parts.len() >= 7 {
let home_dir = parts[5].to_string();
let shell = parts[6].trim().to_string();
return Ok((home_dir, shell));
}
}
Err(MoshpitError::KeyNotEstablished.into())
}
#[cfg(target_os = "macos")]
async fn get_home_dir_shell(&self, user: &str) -> Result<(String, String)> {
let mut cmd = Command::new("dscl");
let _ = cmd.args([
".",
"-read",
format!("/Users/{user}").as_str(),
"NFSHomeDirectory",
"UserShell",
]);
let output = cmd
.output()
.await
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let mut home_dir = String::new();
let mut shell = String::new();
for line in stdout.lines() {
if let Some(stripped) = line.strip_prefix("NFSHomeDirectory:") {
home_dir = stripped.trim().to_string();
} else if let Some(stripped) = line.strip_prefix("UserShell:") {
shell = stripped.trim().to_string();
}
}
return Ok((home_dir, shell));
}
Err(MoshpitError::KeyNotEstablished.into())
}
#[cfg(target_os = "windows")]
async fn get_home_dir_shell(&self, user: &str) -> Result<(String, String)> {
let mut cmd = Command::new("net");
let _ = cmd.args(["user", user]);
let output = cmd
.output()
.await
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let mut home_dir = String::new();
for line in stdout.lines() {
if line.to_lowercase().starts_with("home directory") {
home_dir = line[14..].trim().to_string();
break;
}
}
if home_dir.is_empty() {
home_dir = format!("C:\\Users\\{user}");
}
return Ok((home_dir, String::from("cmd.exe")));
}
Err(MoshpitError::KeyNotEstablished.into())
}
}
fn check_authorized_keys(home_dir: &str, fpk: &[u8]) -> Result<bool> {
let moshpit_path = PathBuf::from(home_dir).join(".mp");
let authorized_keys_path = moshpit_path.join("authorized_keys");
if check_permissions(&moshpit_path, &authorized_keys_path)? {
let authorized_keys_file = OpenOptions::new()
.read(true)
.open(&authorized_keys_path)
.map_err(|_e| MoshpitError::KeyNotEstablished)?;
let buffered_reader = BufReader::new(authorized_keys_file);
let fpk_str = String::from_utf8_lossy(fpk);
for line in buffered_reader.lines().map_while(Result::ok) {
if line == fpk_str {
return Ok(true);
}
}
}
Ok(false)
}
/// Checks that the `.mp` directory and its `authorized_keys` file are private.
///
/// On Unix (symlinks are followed):
/// - `moshpit_path` must be a directory with permission bits exactly `0700`;
/// - `authorized_keys_path` must be a regular file with bits exactly `0600`.
///
/// On Windows, both DACLs may grant access only to the owner and `LocalSystem` (see
/// `windows_only_owner_has_access`).
///
/// # Errors
///
/// On Unix, returns an error if either path's metadata cannot be read (e.g. the path
/// does not exist).
#[cfg_attr(windows, allow(clippy::unnecessary_wraps))]
fn check_permissions(moshpit_path: &Path, authorized_keys_path: &Path) -> Result<bool> {
    #[cfg(target_family = "unix")]
    {
        use std::os::unix::fs::MetadataExt;
        let moshpit_metadata = moshpit_path.metadata()?;
        let authorized_keys_metadata = authorized_keys_path.metadata()?;
        // Mode bits alone would accept e.g. a FIFO, and opening a FIFO blocks the reader.
        if !moshpit_metadata.is_dir() || !authorized_keys_metadata.is_file() {
            return Ok(false);
        }
        if (moshpit_metadata.mode() & 0o777) != 0o700 {
            return Ok(false);
        }
        if (authorized_keys_metadata.mode() & 0o777) != 0o600 {
            return Ok(false);
        }
    }
    #[cfg(target_os = "windows")]
    {
        if !windows_only_owner_has_access(moshpit_path)
            || !windows_only_owner_has_access(authorized_keys_path)
        {
            return Ok(false);
        }
    }
    Ok(true)
}
/// Returns `true` when the DACL of `path` grants access only to the file's owner and
/// `LocalSystem`.
///
/// Fails closed, returning `false` for:
/// - any Win32 error;
/// - a NULL DACL (which grants everyone access);
/// - an ACE that cannot be retrieved;
/// - an allow-type object ACE, whose SID location this check does not parse.
///
/// Deny and audit ACEs are ignored, since they cannot grant access.
#[cfg(target_os = "windows")]
fn windows_only_owner_has_access(path: &Path) -> bool {
    use std::os::windows::ffi::OsStrExt;
    use windows::{
        Win32::Foundation::{HLOCAL, LocalFree},
        Win32::Security::Authorization::{GetNamedSecurityInfoW, SE_FILE_OBJECT},
        Win32::Security::{
            ACCESS_ALLOWED_ACE, ACL, ACL_SIZE_INFORMATION, AclSizeInformation, CreateWellKnownSid,
            DACL_SECURITY_INFORMATION, EqualSid, GetAce, GetAclInformation,
            OWNER_SECURITY_INFORMATION, PSECURITY_DESCRIPTOR, PSID, WinLocalSystemSid,
        },
        core::PCWSTR,
    };

    // ACE type codes from winnt.h.
    const ACCESS_ALLOWED_ACE_TYPE: u8 = 0x00;
    const ACCESS_ALLOWED_OBJECT_ACE_TYPE: u8 = 0x05;
    const ACCESS_ALLOWED_CALLBACK_ACE_TYPE: u8 = 0x09;
    const ACCESS_ALLOWED_CALLBACK_OBJECT_ACE_TYPE: u8 = 0x0B;
    // SECURITY_MAX_SID_SIZE: large enough for any well-known SID.
    const MAX_SID_SIZE: usize = 68;

    /// Frees the security descriptor returned by `GetNamedSecurityInfoW` on every exit.
    struct SecurityDescriptorGuard(PSECURITY_DESCRIPTOR);
    impl Drop for SecurityDescriptorGuard {
        fn drop(&mut self) {
            // SAFETY: the descriptor was allocated by GetNamedSecurityInfoW, whose
            // documented release function is LocalFree, and this guard is its only owner.
            unsafe {
                let _ = LocalFree(Some(HLOCAL(self.0.0)));
            }
        }
    }

    let wide: Vec<u16> = path
        .as_os_str()
        .encode_wide()
        .chain(std::iter::once(0))
        .collect();
    let mut p_dacl: *mut ACL = std::ptr::null_mut();
    let mut p_owner = PSID(std::ptr::null_mut());
    let mut p_sd = PSECURITY_DESCRIPTOR(std::ptr::null_mut());
    // SAFETY: `wide` is NUL-terminated and outlives the call; every out-pointer refers to
    // a live local.
    let err = unsafe {
        GetNamedSecurityInfoW(
            PCWSTR(wide.as_ptr()),
            SE_FILE_OBJECT,
            DACL_SECURITY_INFORMATION | OWNER_SECURITY_INFORMATION,
            Some(&raw mut p_owner),
            None,
            Some(&raw mut p_dacl),
            None,
            &raw mut p_sd,
        )
    };
    if err.0 != 0 {
        return false;
    }
    // `p_owner` and `p_dacl` point into the descriptor, so it must outlive all their uses.
    let _sd_guard = SecurityDescriptorGuard(p_sd);

    let mut system_sid_buf = [0u8; MAX_SID_SIZE];
    let mut system_sid_size = u32::try_from(MAX_SID_SIZE).expect("SID buffer size fits in u32");
    let system_sid = PSID(system_sid_buf.as_mut_ptr().cast());
    // SAFETY: `system_sid` points at a writable buffer of `system_sid_size` bytes.
    let ok = unsafe {
        CreateWellKnownSid(
            WinLocalSystemSid,
            None,
            Some(system_sid),
            &raw mut system_sid_size,
        )
    };
    if ok.is_err() || p_dacl.is_null() {
        return false;
    }

    let mut acl_info = ACL_SIZE_INFORMATION::default();
    // SAFETY: `p_dacl` is a non-null DACL kept alive by `_sd_guard`; the output buffer is
    // exactly one ACL_SIZE_INFORMATION.
    let ok = unsafe {
        GetAclInformation(
            p_dacl,
            std::ptr::addr_of_mut!(acl_info).cast::<core::ffi::c_void>(),
            u32::try_from(size_of::<ACL_SIZE_INFORMATION>())
                .expect("ACL_SIZE_INFORMATION fits in u32"),
            AclSizeInformation,
        )
    };
    if ok.is_err() {
        return false;
    }

    for i in 0..acl_info.AceCount {
        let mut p_ace: *mut core::ffi::c_void = std::ptr::null_mut();
        // SAFETY: `p_dacl` is valid (see above) and `i < AceCount`.
        if unsafe { GetAce(p_dacl, i, &raw mut p_ace) }.is_err() {
            // An unreadable ACE might grant access; do not assume it is harmless.
            return false;
        }
        // SAFETY: GetAce returned a pointer to an ACE inside the live DACL. Every ACE
        // starts with ACE_HEADER and is at least as large as ACCESS_ALLOWED_ACE.
        let ace = unsafe { &*(p_ace as *const ACCESS_ALLOWED_ACE) };
        match ace.Header.AceType {
            // Both types share the ACCESS_ALLOWED_ACE layout (Header, Mask, SidStart).
            ACCESS_ALLOWED_ACE_TYPE | ACCESS_ALLOWED_CALLBACK_ACE_TYPE => {
                let ace_sid = PSID(std::ptr::addr_of!(ace.SidStart) as *mut core::ffi::c_void);
                // SAFETY: all three SIDs point into memory that is alive for this loop.
                let is_owner = unsafe { EqualSid(ace_sid, p_owner) }.is_ok();
                let is_system = unsafe { EqualSid(ace_sid, system_sid) }.is_ok();
                if !is_owner && !is_system {
                    return false;
                }
            }
            // Object ACEs place the SID after optional GUIDs; treat them as unsafe.
            ACCESS_ALLOWED_OBJECT_ACE_TYPE | ACCESS_ALLOWED_CALLBACK_OBJECT_ACE_TYPE => {
                return false;
            }
            _ => {}
        }
    }
    true
}