#![cfg_attr(not(unix), allow(dead_code))]
use crate::NodeAddr;
use crate::transport::{TransportAddr, TransportId};
use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
use ring::aead::{Aad, LessSafeKey, Nonce};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc::UnboundedSender;
use tracing::{debug, trace, warn};
use crate::noise::ReplayWindow;
/// Capacity of each decrypt worker's bounded message channel.
///
/// When a worker's channel is full, `DecryptWorkerPool::dispatch_job` drops the
/// packet and `DecryptWorkerPool::register_session` reports failure, instead of
/// blocking the caller.
const WORKER_CHANNEL_CAP: usize = 32768;
/// Per-session decryption state owned by a single decrypt worker thread.
///
/// Delivered to a worker via `WorkerMsg::RegisterSession` and stored in that
/// worker's thread-local session table, so it is never shared across threads.
pub(crate) struct OwnedSessionState {
    /// AEAD key used to open FMP packets in place.
    pub fmp_cipher: LessSafeKey,
    /// Replay window: checked before decryption and advanced only after the
    /// packet authenticates.
    pub fmp_replay: ReplayWindow,
    /// Not consumed by the worker's decrypt path; presumably the peer's npub,
    /// carried for later use — TODO confirm.
    pub source_npub: Option<String>,
}
/// One inbound FMP packet queued for off-thread decryption.
///
/// Routed to a worker by `cache_key`. On success the decrypted packet is sent
/// back through `fallback_tx` as a [`DecryptFallback`].
pub(crate) struct DecryptJob {
    /// Complete wire packet; the ciphertext portion is decrypted in place.
    pub packet_data: Vec<u8>,
    /// Session lookup key; also selects which worker handles the job.
    pub cache_key: (TransportId, u32),
    /// Copied into `DecryptFallback::transport_id`.
    pub _transport_id: TransportId,
    /// Copied into `DecryptFallback::remote_addr`.
    pub _remote_addr: TransportAddr,
    /// Passed through unchanged (presumably the receive time) — TODO confirm.
    pub timestamp_ms: u64,
    /// Source node address, passed through unchanged.
    pub source_node_addr: NodeAddr,
    /// FMP packet counter: forms the AEAD nonce and is checked against the
    /// session's replay window.
    pub fmp_counter: u64,
    /// FMP header flags, passed through unchanged.
    pub fmp_flags: u8,
    /// Header bytes authenticated as AEAD associated data.
    pub fmp_header: [u8; 16],
    /// Offset into `packet_data` where the ciphertext (followed by the AEAD
    /// tag) begins.
    pub fmp_ciphertext_offset: usize,
    /// Channel on which the decrypted result is delivered.
    pub fallback_tx: UnboundedSender<DecryptFallback>,
}
/// Result of a successful worker-side FMP decryption, delivered on
/// `DecryptJob::fallback_tx`.
// NOTE(review): dead_code presumably allowed because not every field is read
// by the consumer on all builds — confirm.
#[allow(dead_code)]
pub(crate) struct DecryptFallback {
    pub source_node_addr: NodeAddr,
    pub transport_id: TransportId,
    pub remote_addr: TransportAddr,
    pub timestamp_ms: u64,
    /// Length of the packet buffer as received, before decryption.
    pub packet_len: usize,
    pub fmp_counter: u64,
    pub fmp_flags: u8,
    /// The packet buffer, decrypted in place. The plaintext occupies
    /// `fmp_plaintext_offset..fmp_plaintext_offset + fmp_plaintext_len` and
    /// begins with a 4-byte inner timestamp followed by the link message.
    pub packet_data: Vec<u8>,
    /// Start of the plaintext within `packet_data`.
    pub fmp_plaintext_offset: usize,
    /// Length of the plaintext (AEAD tag excluded).
    pub fmp_plaintext_len: usize,
}
/// Messages accepted by a decrypt worker thread.
///
/// All messages for one `cache_key` are routed to the same worker channel, so
/// they are handled in the order they were sent.
// NOTE(review): variants differ widely in size; presumably boxing was judged
// not worth a per-message allocation — confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum WorkerMsg {
    /// Decrypt one packet.
    Job(DecryptJob),
    /// Install (or replace) the session state for `cache_key`.
    RegisterSession {
        cache_key: (TransportId, u32),
        state: OwnedSessionState,
    },
    /// Drop the session state for `cache_key`, if present.
    UnregisterSession {
        cache_key: (TransportId, u32),
    },
}
/// Handle to a fixed set of decrypt worker OS threads.
///
/// Cloning is cheap: clones share the same per-worker channel senders through
/// an `Arc`. A worker thread exits once every sender to it has been dropped and
/// its channel has been drained.
#[derive(Clone)]
pub(crate) struct DecryptWorkerPool {
    /// One bounded channel sender per worker thread, indexed by worker id.
    senders: Arc<[Sender<WorkerMsg>]>,
}
impl DecryptWorkerPool {
    /// Spawns `n` decrypt worker OS threads (at least one), each fed by its own
    /// bounded channel of capacity `WORKER_CHANNEL_CAP`.
    ///
    /// # Panics
    ///
    /// Panics if the OS refuses to spawn a worker thread.
    pub fn spawn(n: usize) -> Self {
        let n = n.max(1);
        let mut senders = Vec::with_capacity(n);
        for i in 0..n {
            let (tx, rx) = bounded::<WorkerMsg>(WORKER_CHANNEL_CAP);
            std::thread::Builder::new()
                .name(format!("fips-decrypt-{i}"))
                .spawn(move || run_worker(i, rx))
                .expect("failed to spawn fips-decrypt OS thread");
            senders.push(tx);
        }
        Self {
            senders: senders.into(),
        }
    }

    /// Maps a session key to the index of the worker that owns it.
    ///
    /// `DefaultHasher::new()` always starts from the same state, so a key maps
    /// to the same worker for the life of the process. Registration, jobs and
    /// unregistration for one session therefore share one FIFO channel.
    fn worker_idx_for(&self, cache_key: (TransportId, u32)) -> usize {
        use std::hash::{Hash, Hasher};
        let mut h = std::collections::hash_map::DefaultHasher::new();
        cache_key.hash(&mut h);
        // Truncation on 32-bit targets is harmless: only used modulo the worker count.
        (h.finish() as usize) % self.senders.len()
    }

    /// Records one occurrence on `counter` and returns the running total when
    /// this occurrence should be logged: the first 8, then every 10 000th.
    fn rate_limited_total(counter: &AtomicU64) -> Option<u64> {
        let n = counter.fetch_add(1, Ordering::Relaxed);
        (n < 8 || n.is_multiple_of(10000)).then_some(n + 1)
    }

    /// Queues `job` on the worker that owns `job.cache_key`.
    ///
    /// Never blocks. If the worker's channel is full, the packet is dropped with a
    /// rate-limited warning. If the worker thread is gone, the job is discarded.
    pub fn dispatch_job(&self, job: DecryptJob) {
        if self.senders.is_empty() {
            return;
        }
        let idx = self.worker_idx_for(job.cache_key);
        match self.senders[idx].try_send(WorkerMsg::Job(job)) {
            Ok(()) => {}
            Err(TrySendError::Full(_)) => {
                static FULL_COUNT: AtomicU64 = AtomicU64::new(0);
                if let Some(total) = Self::rate_limited_total(&FULL_COUNT) {
                    warn!(
                        worker = idx,
                        drops = total,
                        "DecryptWorker channel full; dropping inbound packet"
                    );
                }
            }
            Err(TrySendError::Disconnected(_)) => {
                debug!(worker = idx, "DecryptWorker thread gone; dropping job");
            }
        }
    }

    /// Hands `state` to the worker that owns `cache_key`.
    ///
    /// Returns `true` only if the registration message was queued. On `false`,
    /// `state` has been dropped. The caller is expected to retry, which under
    /// sustained queue pressure happens on every packet, so the warning is
    /// rate-limited.
    #[must_use = "registration may have failed under queue pressure; caller must gate its own session-registered flag on the returned bool"]
    pub fn register_session(
        &self,
        cache_key: (TransportId, u32),
        state: OwnedSessionState,
    ) -> bool {
        if self.senders.is_empty() {
            return false;
        }
        let idx = self.worker_idx_for(cache_key);
        match self.senders[idx].try_send(WorkerMsg::RegisterSession { cache_key, state }) {
            Ok(()) => true,
            Err(TrySendError::Full(_)) => {
                static REGISTER_FULL_COUNT: AtomicU64 = AtomicU64::new(0);
                if let Some(total) = Self::rate_limited_total(&REGISTER_FULL_COUNT) {
                    warn!(
                        worker = idx,
                        failures = total,
                        "DecryptWorker channel full at session registration; will retry on next packet"
                    );
                }
                false
            }
            Err(TrySendError::Disconnected(_)) => {
                debug!(
                    worker = idx,
                    "DecryptWorker thread gone; ignoring registration"
                );
                false
            }
        }
    }

    /// Asks the worker that owns `cache_key` to drop its session state.
    ///
    /// Best-effort and non-blocking. If the channel is full, the worker keeps the
    /// stale session (including its key) until the same key is registered
    /// again or the worker exits, so the failure is logged rather than ignored.
    // No caller in this module; dead_code allowed to keep the API available.
    #[allow(dead_code)]
    pub fn unregister_session(&self, cache_key: (TransportId, u32)) {
        if self.senders.is_empty() {
            return;
        }
        let idx = self.worker_idx_for(cache_key);
        match self.senders[idx].try_send(WorkerMsg::UnregisterSession { cache_key }) {
            Ok(()) => {}
            Err(TrySendError::Full(_)) => {
                warn!(
                    worker = idx,
                    ?cache_key,
                    "DecryptWorker channel full at session unregistration; stale session retained"
                );
            }
            Err(TrySendError::Disconnected(_)) => {
                debug!(
                    worker = idx,
                    "DecryptWorker thread gone; ignoring unregistration"
                );
            }
        }
    }
}
/// Main loop of one decrypt worker thread.
///
/// Owns the worker-local session table and handles messages strictly in
/// arrival order. Returns once every `Sender` for `rx` is gone and the channel
/// has been drained.
fn run_worker(idx: usize, rx: Receiver<WorkerMsg>) {
    trace!(worker = idx, "FMP+FSP decrypt worker thread starting");
    let mut table: HashMap<(TransportId, u32), OwnedSessionState> = HashMap::new();
    // `iter()` blocks for each message and ends when the channel is empty and disconnected.
    for incoming in rx.iter() {
        handle_msg(idx, &mut table, incoming);
    }
    trace!(worker = idx, "FMP+FSP decrypt worker thread exiting");
}
/// Applies one [`WorkerMsg`] to this worker's session table.
///
/// Registrations and unregistrations mutate `sessions` directly. Decrypt jobs
/// are delegated to `handle_job`, whose errors are logged at debug level and
/// otherwise ignored.
fn handle_msg(
    idx: usize,
    sessions: &mut HashMap<(TransportId, u32), OwnedSessionState>,
    msg: WorkerMsg,
) {
    match msg {
        WorkerMsg::RegisterSession { cache_key, state } => {
            trace!(worker = idx, ?cache_key, "DecryptWorker: register session");
            // Re-registering an existing key replaces (and drops) the old state.
            sessions.insert(cache_key, state);
        }
        WorkerMsg::UnregisterSession { cache_key } => {
            trace!(
                worker = idx,
                ?cache_key,
                "DecryptWorker: unregister session"
            );
            sessions.remove(&cache_key);
        }
        WorkerMsg::Job(job) => match handle_job(sessions, job) {
            Ok(()) => {}
            Err(err) => {
                debug!(worker = idx, error = %err, "decrypt worker job failed");
            }
        },
    }
}
fn handle_job(
sessions: &mut HashMap<(TransportId, u32), OwnedSessionState>,
job: DecryptJob,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let DecryptJob {
mut packet_data,
cache_key,
_transport_id: transport_id,
_remote_addr: remote_addr,
timestamp_ms,
source_node_addr,
fmp_counter,
fmp_flags,
fmp_header,
fmp_ciphertext_offset,
fallback_tx,
} = job;
let packet_len = packet_data.len();
let state = match sessions.get_mut(&cache_key) {
Some(s) => s,
None => {
let _ = fallback_tx; let _ = source_node_addr;
let _ = packet_data;
return Ok(());
}
};
let _t_fmp = crate::perf_profile::Timer::start(crate::perf_profile::Stage::FmpDecrypt);
if !state.fmp_replay.check(fmp_counter) {
return Ok(()); }
let mut nonce_bytes = [0u8; 12];
nonce_bytes[4..12].copy_from_slice(&fmp_counter.to_le_bytes());
let nonce = Nonce::assume_unique_for_key(nonce_bytes);
let buf = &mut packet_data[fmp_ciphertext_offset..];
let plaintext_len = match state
.fmp_cipher
.open_in_place(nonce, Aad::from(&fmp_header), buf)
{
Ok(p) => p.len(),
Err(_) => return Ok(()), };
state.fmp_replay.accept(fmp_counter);
drop(_t_fmp);
let fmp_plaintext_start = fmp_ciphertext_offset;
let fmp_plaintext_end = fmp_ciphertext_offset + plaintext_len;
const INNER_TIMESTAMP_LEN: usize = 4;
if plaintext_len < INNER_TIMESTAMP_LEN + 1 {
return Ok(());
}
let link_msg_start = fmp_plaintext_start + INNER_TIMESTAMP_LEN;
let link_msg_end = fmp_plaintext_end;
let link_msg = &packet_data[link_msg_start..link_msg_end];
let _ = link_msg; let _ = fallback_tx.send(DecryptFallback {
source_node_addr,
transport_id,
remote_addr,
timestamp_ms,
packet_len,
fmp_counter,
fmp_flags,
packet_data,
fmp_plaintext_offset: fmp_plaintext_start,
fmp_plaintext_len: plaintext_len,
});
let _ = (link_msg_start, link_msg_end, &state.source_npub);
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::noise::ReplayWindow;
    use ring::aead::{LessSafeKey, UnboundKey};

    /// Seals a minimal FMP packet with a known key, runs it through
    /// `handle_job`, and checks that the header flags reach the resulting
    /// `DecryptFallback` unchanged.
    #[test]
    fn worker_preserves_fmp_flags_through_fallback() {
        // Two LessSafeKey instances built from the same all-zero key: one seals
        // the test packet, the other is handed to the worker session to open it.
        let key_bytes = [0u8; 32];
        let unbound = UnboundKey::new(&ring::aead::CHACHA20_POLY1305, &key_bytes).unwrap();
        let seal_cipher = LessSafeKey::new(unbound);
        let unbound2 = UnboundKey::new(&ring::aead::CHACHA20_POLY1305, &key_bytes).unwrap();
        let open_cipher = LessSafeKey::new(unbound2);
        let counter: u64 = 7;
        // `fmp_header: header` below only compiles if HDR == 16 (field is `[u8; 16]`).
        const HDR: usize = crate::node::wire::ESTABLISHED_HEADER_SIZE;
        // header + 4-byte inner timestamp + 1-byte link message + 16-byte AEAD tag
        let mut wire = Vec::with_capacity(HDR + 4 + 1 + 16);
        let flags_byte = crate::node::wire::FLAG_CE | crate::node::wire::FLAG_SP;
        let mut header = [0u8; HDR];
        // NOTE(review): assumes byte 1 is the flags position in the established header — confirm.
        header[1] = flags_byte;
        wire.extend_from_slice(&header);
        // Plaintext: zeroed 4-byte inner timestamp, then a single link-message byte.
        wire.extend_from_slice(&[0u8; 4]);
        wire.push(0xAB);
        // Same nonce layout as the worker: 4 zero bytes + little-endian counter.
        let mut nonce_bytes = [0u8; 12];
        nonce_bytes[4..12].copy_from_slice(&counter.to_le_bytes());
        let nonce = ring::aead::Nonce::assume_unique_for_key(nonce_bytes);
        // Seal everything after the header in place, authenticating the header as AAD.
        let (hdr_slice, payload_slice) = wire.split_at_mut(HDR);
        let tag = seal_cipher
            .seal_in_place_separate_tag(nonce, ring::aead::Aad::from(&*hdr_slice), payload_slice)
            .unwrap();
        wire.extend_from_slice(tag.as_ref());
        // Worker-local session table with the matching opening key.
        let cache_key = (TransportId::new(1), 99u32);
        let mut sessions: HashMap<(TransportId, u32), OwnedSessionState> = HashMap::new();
        sessions.insert(
            cache_key,
            OwnedSessionState {
                fmp_cipher: open_cipher,
                fmp_replay: ReplayWindow::new(),
                source_npub: None,
            },
        );
        let (fallback_tx, mut fallback_rx) =
            tokio::sync::mpsc::unbounded_channel::<DecryptFallback>();
        let job = DecryptJob {
            packet_data: wire,
            cache_key,
            _transport_id: TransportId::new(1),
            _remote_addr: crate::transport::TransportAddr::from_string("127.0.0.1:1234"),
            timestamp_ms: 1_000,
            source_node_addr: crate::NodeAddr::from_bytes([0u8; 16]),
            fmp_counter: counter,
            fmp_flags: flags_byte,
            fmp_header: header,
            fmp_ciphertext_offset: HDR,
            fallback_tx,
        };
        // Run the job synchronously; a successful decrypt must produce exactly one fallback.
        handle_job(&mut sessions, job).expect("worker job handled");
        let fallback = fallback_rx.try_recv().expect("fallback delivered");
        assert_eq!(
            fallback.fmp_flags, flags_byte,
            "fmp_flags must round-trip from DecryptJob to DecryptFallback"
        );
        assert!(
            fallback.fmp_flags & crate::node::wire::FLAG_CE != 0,
            "FLAG_CE bit lost on worker path"
        );
        assert!(
            fallback.fmp_flags & crate::node::wire::FLAG_SP != 0,
            "FLAG_SP bit lost on worker path"
        );
    }
}