Skip to main content

libmoshpit/kex/
mod.rs

1// Copyright (c) 2025 moshpit developers
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use std::{
10    fmt::{self, Display, Formatter},
11    net::SocketAddr,
12    sync::Arc,
13};
14
15use anyhow::Result;
16use aws_lc_rs::{
17    agreement::{PrivateKey, X25519},
18    cipher::AES_256_KEY_LEN,
19};
20use bon::Builder;
21use getset::{CopyGetters, Getters};
22use local_ip_address::local_ip;
23use serde::{Deserialize, Serialize};
24use tokio::{
25    net::{
26        UdpSocket,
27        tcp::{OwnedReadHalf, OwnedWriteHalf},
28    },
29    spawn,
30    sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
31    task::JoinHandle,
32};
33use tracing::{error, trace};
34use uuid::Uuid;
35
36use crate::{
37    ConnectionReader, ConnectionWriter, Frame, KexConfig, KexReader, KexSender, MoshpitError,
38    UuidWrapper, decrypt_private_key, load_private_key, load_public_key,
39};
40
41pub(crate) mod reader;
42pub(crate) mod sender;
43
44/// The key exchange events
45#[derive(Clone, Copy, Debug)]
46pub enum KexEvent {
47    /// Key material for encrypting/decrypting UDP packets
48    KeyMaterial([u8; 32]),
49    /// HMAC key for signing UDP packets
50    HMACKeyMaterial([u8; 64]),
51    /// moshpit client UUID
52    Uuid(Uuid),
53    /// moshpits socket address
54    MoshpitsAddr(SocketAddr),
55    /// Session information: (stable session UUID, `is_resume` flag)
56    SessionInfo(Uuid, bool),
57    /// Key exchange failure
58    Failure,
59}
60
/// Progress of a key exchange, as tracked by [`KexStateMachine`].
///
/// The states are visited in declaration order. `AwaitingSessionToken` and
/// `AwaitingMoshpitsAddr` are only used in client mode; in server mode the exchange is
/// complete as soon as the UUID has been received.
#[derive(Debug, Default, Clone, Copy)]
pub enum KexState {
    /// Initial state: waiting for the key used to encrypt/decrypt UDP packets.
    #[default]
    AwaitingKeyMaterial,
    /// Waiting for the HMAC key used to sign UDP packets.
    AwaitingHMACKeyMaterial,
    /// Waiting for the moshpit client UUID.
    AwaitingUuid,
    /// Client mode only: waiting for the session token from moshpits, which arrives after
    /// the UUID and before the `MoshpitsAddr`.
    AwaitingSessionToken,
    /// Waiting for the moshpits socket address.
    AwaitingMoshpitsAddr,
    /// Every expected value has been received.
    Complete,
}
78
79/// The moshpit key exchange state machine
80#[derive(Builder, CopyGetters, Debug)]
81pub struct KexStateMachine {
82    /// The current key exchange state
83    #[getset(get_copy = "pub")]
84    #[builder(default = KexState::default())]
85    state: KexState,
86    rx_event: UnboundedReceiver<KexEvent>,
87}
88
89/// The moshpit key exchange result
90#[derive(Clone, Copy, CopyGetters, Debug)]
91pub struct Kex {
92    /// AES-256-GCM-SIV key material for encrypting/decrypting UDP packets
93    #[getset(get_copy = "pub")]
94    key: [u8; 32],
95    /// HMAC key for signing UDP packets
96    #[getset(get_copy = "pub")]
97    hmac_key: [u8; 64],
98    /// moshpit client UUID (per-connection, changes on every reconnect)
99    #[getset(get_copy = "pub")]
100    uuid: Uuid,
101    /// An optional moshpits socket address used by moshpit.
102    #[getset(get_copy = "pub")]
103    moshpits_addr: Option<SocketAddr>,
104    /// Stable session UUID, set for client mode after `SessionToken` received.
105    #[getset(get_copy = "pub")]
106    session_uuid: Option<Uuid>,
107    /// Whether this connection is resuming an existing session.
108    #[getset(get_copy = "pub")]
109    is_resume: bool,
110}
111
112impl Kex {
113    /// Get the wrapped UUID
114    #[must_use]
115    pub fn uuid_wrapper(&self) -> UuidWrapper {
116        UuidWrapper::new(self.uuid)
117    }
118}
119
120impl Default for Kex {
121    fn default() -> Self {
122        Self {
123            key: [0u8; 32],
124            hmac_key: [0u8; 64],
125            uuid: Uuid::nil(),
126            moshpits_addr: None,
127            session_uuid: None,
128            is_resume: false,
129        }
130    }
131}
132
133/// Extended key exchange for the moshpits side of the exchange
134#[derive(Builder, Clone, Debug, CopyGetters, Getters)]
135pub struct ServerKex {
136    /// The user associated with the key exchange
137    #[getset(get = "pub")]
138    user: String,
139    /// The shell associated with the key exchange
140    #[getset(get = "pub")]
141    shell: String,
142    /// The stable session UUID assigned to this connection
143    #[getset(get_copy = "pub")]
144    session_uuid: Uuid,
145    /// Whether this connection is resuming an existing session
146    #[getset(get_copy = "pub")]
147    #[builder(default)]
148    is_resume: bool,
149}
150
151impl KexStateMachine {
152    /// Handle key exchange events
153    ///
154    /// # Errors
155    /// Returns an error if the key exchange state is invalid
156    ///
157    pub async fn handle_events(&mut self, client_mode: bool) -> Result<Kex> {
158        let mut kex = Kex::default();
159
160        while let Some(event) = self.rx_event.recv().await {
161            match (self.state, event) {
162                (KexState::AwaitingKeyMaterial, KexEvent::KeyMaterial(key_material)) => {
163                    kex.key = key_material;
164                    self.state = KexState::AwaitingHMACKeyMaterial;
165                }
166                (
167                    KexState::AwaitingHMACKeyMaterial,
168                    KexEvent::HMACKeyMaterial(hmac_key_material),
169                ) => {
170                    kex.hmac_key = hmac_key_material;
171                    self.state = KexState::AwaitingUuid;
172                }
173                (KexState::AwaitingUuid, KexEvent::Uuid(uuid)) => {
174                    kex.uuid = uuid;
175                    if client_mode {
176                        self.state = KexState::AwaitingSessionToken;
177                    } else {
178                        self.state = KexState::Complete;
179                        break;
180                    }
181                }
182                (
183                    KexState::AwaitingSessionToken,
184                    KexEvent::SessionInfo(session_uuid, is_resume),
185                ) => {
186                    kex.session_uuid = Some(session_uuid);
187                    kex.is_resume = is_resume;
188                    self.state = KexState::AwaitingMoshpitsAddr;
189                }
190                (KexState::AwaitingMoshpitsAddr, KexEvent::MoshpitsAddr(addr)) => {
191                    self.state = KexState::Complete;
192                    kex.moshpits_addr = Some(addr);
193                    break;
194                }
195                _ => {
196                    return Err(MoshpitError::InvalidKexState.into());
197                }
198            }
199        }
200
201        match self.state {
202            KexState::Complete => Ok(kex),
203            _ => Err(MoshpitError::InvalidKexState.into()),
204        }
205    }
206}
207
208/// The key exchange mode
209#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
210pub enum KexMode {
211    /// Client mode
212    #[default]
213    Client,
214    /// Server mode
215    Server(SocketAddr),
216}
217
218impl Display for KexMode {
219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220        match self {
221            KexMode::Client => write!(f, "Client"),
222            KexMode::Server(addr) => write!(f, "Server({addr})"),
223        }
224    }
225}
226
227/// Run the client side of the key exchange
228///
229/// # Errors
230///
231pub async fn run_key_exchange<T: KexConfig>(
232    config: T,
233    sock_read: OwnedReadHalf,
234    sock_write: OwnedWriteHalf,
235    passphrase_fn: impl Fn() -> Result<Option<String>>,
236) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
237    // Setup the TCP connection to the server for key exchange
238    let mode = config.mode();
239    let reader = ConnectionReader::builder().reader(sock_read).build();
240    let writer = ConnectionWriter::builder().writer(sock_write).build();
241    let (tx, rx) = unbounded_channel();
242    let (tx_event, rx_event) = unbounded_channel::<KexEvent>();
243    let mut kex_sm = KexStateMachine::builder().rx_event(rx_event).build();
244    let kex_handle = spawn(async move { kex_sm.handle_events(mode == KexMode::Client).await });
245
246    // Setup the TCP frame sender
247    let _write_handle = spawn(async move {
248        let mut sender = KexSender::builder().writer(writer).rx(rx).build();
249        if let Err(e) = sender.handle_send_frames().await {
250            error!("{e}");
251        }
252    });
253
254    Ok(match mode {
255        KexMode::Client => {
256            run_client_kex(config, tx, tx_event, reader, kex_handle, passphrase_fn).await?
257        }
258        KexMode::Server(socket_addr) => {
259            let tx_c = tx.clone();
260            match run_server_kex(config, socket_addr, tx, tx_event, reader, kex_handle).await {
261                Ok(result) => result,
262                Err(e) => {
263                    let _blah = tx_c.send(Frame::KexFailure);
264                    Err(e)?
265                }
266            }
267        }
268    })
269}
270
271async fn run_client_kex<T: KexConfig>(
272    config: T,
273    tx: UnboundedSender<Frame>,
274    tx_event: UnboundedSender<KexEvent>,
275    reader: ConnectionReader,
276    kex_handle: JoinHandle<Result<Kex>>,
277    passphrase_fn: impl Fn() -> Result<Option<String>>,
278) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
279    let (private_key_path, public_key_path) = config.key_pair_paths()?;
280    trace!("Loading private key from {}", private_key_path.display());
281    trace!("Loading public key from {}", public_key_path.display());
282
283    // Load the moshpit public and private key
284    let (unenc_key_pair_opt, enc_key_pair_opt) = load_private_key(&private_key_path)?;
285    let (full_public_key_bytes, public_key_bytes) = load_public_key(&public_key_path)?;
286
287    let (pk, my_public_key) = if let Some(enc_key_pair) = enc_key_pair_opt {
288        // Get the passphrase
289        if let Some(passphrase) = passphrase_fn()? {
290            let salt_bytes = enc_key_pair.salt_bytes();
291            let nonce_bytes = enc_key_pair.nonce_bytes();
292            let mut encrypted_private_key_bytes =
293                enc_key_pair.encrypted_private_key_bytes().clone();
294            decrypt_private_key(
295                &passphrase,
296                salt_bytes,
297                nonce_bytes,
298                &mut encrypted_private_key_bytes,
299            )?;
300
301            let private_key = PrivateKey::from_private_key(
302                &X25519,
303                &encrypted_private_key_bytes[..AES_256_KEY_LEN],
304            )?;
305            let public_key = private_key.compute_public_key()?;
306
307            if public_key.as_ref() != public_key_bytes.as_slice() {
308                return Err(anyhow::anyhow!("Public key does not match the private key"));
309            }
310            (private_key, public_key)
311        } else {
312            return Err(anyhow::anyhow!("No valid private key found"));
313        }
314    } else if let Some(unenc_key_pair) = unenc_key_pair_opt {
315        unenc_key_pair.take()
316    } else {
317        return Err(anyhow::anyhow!("No valid private key found"));
318    };
319
320    // Setup the TCP frame reader
321    let tx_c = tx.clone();
322    let tx_event_c = tx_event.clone();
323    let requested = config.resume_session_uuid();
324    let _read_handle = spawn(async move {
325        let mut frame_reader = KexReader::builder()
326            .reader(reader)
327            .tx(tx_c)
328            .tx_event(tx_event_c)
329            .maybe_requested_session_uuid(requested)
330            .build();
331        if let Err(e) = frame_reader.client_kex(&pk).await {
332            trace!("{e}");
333        }
334    });
335
336    // Send the initialize or resume-request frame with our public key
337    let frame = if let Some(session_uuid) = config.resume_session_uuid() {
338        Frame::ResumeRequest(
339            UuidWrapper::new(session_uuid),
340            config.user().unwrap_or_default().as_bytes().to_vec(),
341            my_public_key.as_ref().to_vec(),
342            full_public_key_bytes,
343        )
344    } else {
345        Frame::Initialize(
346            config.user().unwrap_or_default().as_bytes().to_vec(),
347            my_public_key.as_ref().to_vec(),
348            full_public_key_bytes,
349        )
350    };
351    tx.send(frame)?;
352
353    let kex = kex_handle.await??;
354
355    if let Some(moshpits_addr) = kex.moshpits_addr() {
356        trace!("Connecting to moshpits at {moshpits_addr}");
357        let my_local_ip = local_ip()?;
358        let socket_addr = SocketAddr::new(my_local_ip, 0);
359        let udp_listener = UdpSocket::bind(socket_addr).await?;
360        udp_listener.connect(moshpits_addr).await?;
361        let frame = Frame::MoshpitAddr(udp_listener.local_addr()?);
362        tx.send(frame.clone())?;
363        Ok((kex, Arc::new(udp_listener), None))
364    } else {
365        Err(MoshpitError::InvalidMoshpitsAddress.into())
366    }
367}
368
369async fn run_server_kex<T: KexConfig>(
370    config: T,
371    socket_addr: SocketAddr,
372    tx: UnboundedSender<Frame>,
373    tx_event: UnboundedSender<KexEvent>,
374    reader: ConnectionReader,
375    kex_handle: JoinHandle<Result<Kex>>,
376) -> Result<(Kex, Arc<UdpSocket>, Option<ServerKex>)> {
377    let port_pool_opt = config.port_pool();
378    let (private_key_path, public_key_path) = config.key_pair_paths()?;
379    let session_registry = config.session_registry();
380    trace!("Loading private key from {}", private_key_path.display());
381    trace!("Loading public key from {}", public_key_path.display());
382
383    // Setup the TCP frame reader
384    let tx_c = tx.clone();
385    let tx_event_c = tx_event.clone();
386    let mut frame_reader = KexReader::builder()
387        .reader(reader)
388        .tx(tx_c)
389        .tx_event(tx_event_c)
390        .build();
391    if let Some(port_pool) = port_pool_opt {
392        let (skex, udp_arc) = frame_reader
393            .server_kex(
394                socket_addr,
395                port_pool,
396                &private_key_path,
397                &public_key_path,
398                session_registry,
399            )
400            .await?;
401        Ok((kex_handle.await??, udp_arc, Some(skex)))
402    } else {
403        Err(anyhow::anyhow!(
404            "Port pool is required for server key exchange"
405        ))
406    }
407}