1use std::{
10 fmt::{self, Display, Formatter},
11 net::SocketAddr,
12 sync::Arc,
13};
14
15use anyhow::Result;
16use aws_lc_rs::{
17 agreement::{PrivateKey, X25519},
18 cipher::AES_256_KEY_LEN,
19};
20use bon::Builder;
21use getset::{CopyGetters, Getters};
22use local_ip_address::local_ip;
23use serde::{Deserialize, Serialize};
24use tokio::{
25 net::{
26 UdpSocket,
27 tcp::{OwnedReadHalf, OwnedWriteHalf},
28 },
29 spawn,
30 sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
31 task::JoinHandle,
32};
33use tracing::{error, trace};
34use uuid::Uuid;
35
36use crate::{
37 ConnectionReader, ConnectionWriter, Frame, KexConfig, KexReader, KexSender, MoshpitError,
38 UuidWrapper, decrypt_private_key, load_private_key, load_public_key,
39};
40
41pub(crate) mod reader;
42pub(crate) mod sender;
43
/// Events produced during a key exchange and consumed, in order, by
/// [`KexStateMachine::handle_events`].
///
/// NOTE(review): the derived `Debug` prints raw key bytes for the key-material
/// variants; avoid logging these events.
#[derive(Clone, Copy, Debug)]
pub enum KexEvent {
    /// 32 bytes of key material; the first event expected by the state machine.
    KeyMaterial([u8; 32]),
    /// 64 bytes of HMAC key material; expected right after `KeyMaterial`.
    HMACKeyMaterial([u8; 64]),
    /// UUID recorded as the exchange's `uuid`. When the state machine is not in
    /// client mode, this event completes the exchange.
    Uuid(Uuid),
    /// Address of the moshpits endpoint; the final event of a client-mode exchange.
    MoshpitsAddr(SocketAddr),
    /// Session UUID and resume flag (`true` when resuming); client mode only.
    SessionInfo(Uuid, bool),
    /// Failure signal. No state accepts this event, so receiving it makes
    /// `handle_events` return `MoshpitError::InvalidKexState`.
    Failure,
}
60
/// States of [`KexStateMachine`].
///
/// Non-client mode: `AwaitingKeyMaterial` -> `AwaitingHMACKeyMaterial` -> `AwaitingUuid`
/// -> `Complete`. In client mode, `AwaitingUuid` continues to `AwaitingSessionToken` ->
/// `AwaitingMoshpitsAddr` -> `Complete`.
#[derive(Clone, Copy, Debug, Default)]
pub enum KexState {
    /// Initial state: waiting for [`KexEvent::KeyMaterial`].
    #[default]
    AwaitingKeyMaterial,
    /// Waiting for [`KexEvent::HMACKeyMaterial`].
    AwaitingHMACKeyMaterial,
    /// Waiting for [`KexEvent::Uuid`].
    AwaitingUuid,
    /// Waiting for [`KexEvent::SessionInfo`] (client mode only).
    AwaitingSessionToken,
    /// Waiting for [`KexEvent::MoshpitsAddr`] (client mode only).
    AwaitingMoshpitsAddr,
    /// All expected events have been received.
    Complete,
}
78
/// Consumes [`KexEvent`]s from a channel and assembles a [`Kex`] as the exchange moves
/// through its [`KexState`]s.
#[derive(Builder, CopyGetters, Debug)]
pub struct KexStateMachine {
    /// Current state; defaults to [`KexState::AwaitingKeyMaterial`] unless set through
    /// the builder.
    #[getset(get_copy = "pub")]
    #[builder(default = KexState::default())]
    state: KexState,
    /// Receiving end of the event channel whose senders are handed to the `KexReader`.
    rx_event: UnboundedReceiver<KexEvent>,
}
88
89#[derive(Clone, Copy, CopyGetters, Debug)]
91pub struct Kex {
92 #[getset(get_copy = "pub")]
94 key: [u8; 32],
95 #[getset(get_copy = "pub")]
97 hmac_key: [u8; 64],
98 #[getset(get_copy = "pub")]
100 uuid: Uuid,
101 #[getset(get_copy = "pub")]
103 moshpits_addr: Option<SocketAddr>,
104 #[getset(get_copy = "pub")]
106 session_uuid: Option<Uuid>,
107 #[getset(get_copy = "pub")]
109 is_resume: bool,
110}
111
112impl Kex {
113 #[must_use]
115 pub fn uuid_wrapper(&self) -> UuidWrapper {
116 UuidWrapper::new(self.uuid)
117 }
118}
119
120impl Default for Kex {
121 fn default() -> Self {
122 Self {
123 key: [0u8; 32],
124 hmac_key: [0u8; 64],
125 uuid: Uuid::nil(),
126 moshpits_addr: None,
127 session_uuid: None,
128 is_resume: false,
129 }
130 }
131}
132
/// Server-side details produced by `KexReader::server_kex` and returned from
/// [`run_key_exchange`] in server mode.
#[derive(Builder, Clone, Debug, CopyGetters, Getters)]
pub struct ServerKex {
    /// User name for the session (presumably the remote login user; confirm in `KexReader`).
    #[getset(get = "pub")]
    user: String,
    /// Shell for the session (presumably the program to launch; confirm in `KexReader`).
    #[getset(get = "pub")]
    shell: String,
    /// UUID of the session this exchange belongs to.
    #[getset(get_copy = "pub")]
    session_uuid: Uuid,
    /// Whether an existing session is being resumed; defaults to `false` in the builder.
    #[getset(get_copy = "pub")]
    #[builder(default)]
    is_resume: bool,
}
150
151impl KexStateMachine {
152 pub async fn handle_events(&mut self, client_mode: bool) -> Result<Kex> {
158 let mut kex = Kex::default();
159
160 while let Some(event) = self.rx_event.recv().await {
161 match (self.state, event) {
162 (KexState::AwaitingKeyMaterial, KexEvent::KeyMaterial(key_material)) => {
163 kex.key = key_material;
164 self.state = KexState::AwaitingHMACKeyMaterial;
165 }
166 (
167 KexState::AwaitingHMACKeyMaterial,
168 KexEvent::HMACKeyMaterial(hmac_key_material),
169 ) => {
170 kex.hmac_key = hmac_key_material;
171 self.state = KexState::AwaitingUuid;
172 }
173 (KexState::AwaitingUuid, KexEvent::Uuid(uuid)) => {
174 kex.uuid = uuid;
175 if client_mode {
176 self.state = KexState::AwaitingSessionToken;
177 } else {
178 self.state = KexState::Complete;
179 break;
180 }
181 }
182 (
183 KexState::AwaitingSessionToken,
184 KexEvent::SessionInfo(session_uuid, is_resume),
185 ) => {
186 kex.session_uuid = Some(session_uuid);
187 kex.is_resume = is_resume;
188 self.state = KexState::AwaitingMoshpitsAddr;
189 }
190 (KexState::AwaitingMoshpitsAddr, KexEvent::MoshpitsAddr(addr)) => {
191 self.state = KexState::Complete;
192 kex.moshpits_addr = Some(addr);
193 break;
194 }
195 _ => {
196 return Err(MoshpitError::InvalidKexState.into());
197 }
198 }
199 }
200
201 match self.state {
202 KexState::Complete => Ok(kex),
203 _ => Err(MoshpitError::InvalidKexState.into()),
204 }
205 }
206}
207
/// Which side of the key exchange this endpoint runs.
///
/// NOTE: this type is serialized with serde; reordering variants may change the
/// encoded form in index-based formats.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum KexMode {
    /// Connecting side (the default). It runs the client exchange and ends with a UDP
    /// socket connected to the moshpits address received from the server.
    #[default]
    Client,
    /// Accepting side. The address is passed to `KexReader::server_kex`.
    Server(SocketAddr),
}
217
218impl Display for KexMode {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 match self {
221 KexMode::Client => write!(f, "Client"),
222 KexMode::Server(addr) => write!(f, "Server({addr})"),
223 }
224 }
225}
226
227pub async fn run_key_exchange<T: KexConfig>(
232 config: T,
233 sock_read: OwnedReadHalf,
234 sock_write: OwnedWriteHalf,
235 passphrase_fn: impl Fn() -> Result<Option<String>>,
236) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
237 let mode = config.mode();
239 let reader = ConnectionReader::builder().reader(sock_read).build();
240 let writer = ConnectionWriter::builder().writer(sock_write).build();
241 let (tx, rx) = unbounded_channel();
242 let (tx_event, rx_event) = unbounded_channel::<KexEvent>();
243 let mut kex_sm = KexStateMachine::builder().rx_event(rx_event).build();
244 let kex_handle = spawn(async move { kex_sm.handle_events(mode == KexMode::Client).await });
245
246 let _write_handle = spawn(async move {
248 let mut sender = KexSender::builder().writer(writer).rx(rx).build();
249 if let Err(e) = sender.handle_send_frames().await {
250 error!("{e}");
251 }
252 });
253
254 Ok(match mode {
255 KexMode::Client => {
256 run_client_kex(config, tx, tx_event, reader, kex_handle, passphrase_fn).await?
257 }
258 KexMode::Server(socket_addr) => {
259 let tx_c = tx.clone();
260 match run_server_kex(config, socket_addr, tx, tx_event, reader, kex_handle).await {
261 Ok(result) => result,
262 Err(e) => {
263 let _blah = tx_c.send(Frame::KexFailure);
264 Err(e)?
265 }
266 }
267 }
268 })
269}
270
271async fn run_client_kex<T: KexConfig>(
272 config: T,
273 tx: UnboundedSender<Frame>,
274 tx_event: UnboundedSender<KexEvent>,
275 reader: ConnectionReader,
276 kex_handle: JoinHandle<Result<Kex>>,
277 passphrase_fn: impl Fn() -> Result<Option<String>>,
278) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
279 let (private_key_path, public_key_path) = config.key_pair_paths()?;
280 trace!("Loading private key from {}", private_key_path.display());
281 trace!("Loading public key from {}", public_key_path.display());
282
283 let (unenc_key_pair_opt, enc_key_pair_opt) = load_private_key(&private_key_path)?;
285 let (full_public_key_bytes, public_key_bytes) = load_public_key(&public_key_path)?;
286
287 let (pk, my_public_key) = if let Some(enc_key_pair) = enc_key_pair_opt {
288 if let Some(passphrase) = passphrase_fn()? {
290 let salt_bytes = enc_key_pair.salt_bytes();
291 let nonce_bytes = enc_key_pair.nonce_bytes();
292 let mut encrypted_private_key_bytes =
293 enc_key_pair.encrypted_private_key_bytes().clone();
294 decrypt_private_key(
295 &passphrase,
296 salt_bytes,
297 nonce_bytes,
298 &mut encrypted_private_key_bytes,
299 )?;
300
301 let private_key = PrivateKey::from_private_key(
302 &X25519,
303 &encrypted_private_key_bytes[..AES_256_KEY_LEN],
304 )?;
305 let public_key = private_key.compute_public_key()?;
306
307 if public_key.as_ref() != public_key_bytes.as_slice() {
308 return Err(anyhow::anyhow!("Public key does not match the private key"));
309 }
310 (private_key, public_key)
311 } else {
312 return Err(anyhow::anyhow!("No valid private key found"));
313 }
314 } else if let Some(unenc_key_pair) = unenc_key_pair_opt {
315 unenc_key_pair.take()
316 } else {
317 return Err(anyhow::anyhow!("No valid private key found"));
318 };
319
320 let tx_c = tx.clone();
322 let tx_event_c = tx_event.clone();
323 let requested = config.resume_session_uuid();
324 let _read_handle = spawn(async move {
325 let mut frame_reader = KexReader::builder()
326 .reader(reader)
327 .tx(tx_c)
328 .tx_event(tx_event_c)
329 .maybe_requested_session_uuid(requested)
330 .build();
331 if let Err(e) = frame_reader.client_kex(&pk).await {
332 trace!("{e}");
333 }
334 });
335
336 let frame = if let Some(session_uuid) = config.resume_session_uuid() {
338 Frame::ResumeRequest(
339 UuidWrapper::new(session_uuid),
340 config.user().unwrap_or_default().as_bytes().to_vec(),
341 my_public_key.as_ref().to_vec(),
342 full_public_key_bytes,
343 )
344 } else {
345 Frame::Initialize(
346 config.user().unwrap_or_default().as_bytes().to_vec(),
347 my_public_key.as_ref().to_vec(),
348 full_public_key_bytes,
349 )
350 };
351 tx.send(frame)?;
352
353 let kex = kex_handle.await??;
354
355 if let Some(moshpits_addr) = kex.moshpits_addr() {
356 trace!("Connecting to moshpits at {moshpits_addr}");
357 let my_local_ip = local_ip()?;
358 let socket_addr = SocketAddr::new(my_local_ip, 0);
359 let udp_listener = UdpSocket::bind(socket_addr).await?;
360 udp_listener.connect(moshpits_addr).await?;
361 let frame = Frame::MoshpitAddr(udp_listener.local_addr()?);
362 tx.send(frame.clone())?;
363 Ok((kex, Arc::new(udp_listener), None))
364 } else {
365 Err(MoshpitError::InvalidMoshpitsAddress.into())
366 }
367}
368
369async fn run_server_kex<T: KexConfig>(
370 config: T,
371 socket_addr: SocketAddr,
372 tx: UnboundedSender<Frame>,
373 tx_event: UnboundedSender<KexEvent>,
374 reader: ConnectionReader,
375 kex_handle: JoinHandle<Result<Kex>>,
376) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
377 let port_pool_opt = config.port_pool();
378 let (private_key_path, public_key_path) = config.key_pair_paths()?;
379 let session_registry = config.session_registry();
380 trace!("Loading private key from {}", private_key_path.display());
381 trace!("Loading public key from {}", public_key_path.display());
382
383 let tx_c = tx.clone();
385 let tx_event_c = tx_event.clone();
386 let mut frame_reader = KexReader::builder()
387 .reader(reader)
388 .tx(tx_c)
389 .tx_event(tx_event_c)
390 .build();
391 if let Some(port_pool) = port_pool_opt {
392 let (skex, udp_arc) = frame_reader
393 .server_kex(
394 socket_addr,
395 port_pool,
396 &private_key_path,
397 &public_key_path,
398 session_registry,
399 )
400 .await?;
401 Ok((kex_handle.await??, udp_arc, Some(skex)))
402 } else {
403 Err(anyhow::anyhow!(
404 "Port pool is required for server key exchange"
405 ))
406 }
407}