use std::sync::Arc;
use crate::client::policy::TransportPolicy;
use crate::client::selfheal::SelfHealing;
use crate::config::SrxConfig;
use crate::crypto::{AeadPipeline, ReplayState};
use crate::error::Result;
use crate::pipeline::{Payload, SrxPipeline};
use crate::replay_storage::{
ReplayStoreMetricsSnapshot, decode_replay_envelope, merge_and_persist_replay_state,
replay_store_metrics_snapshot, storage_from_config,
};
use crate::session::{Handshake, Session};
use crate::signaling::inband::Signal;
use crate::transport::{TransportKind, TransportManager};
/// High-level SRX endpoint: owns the session-bound data pipeline together with
/// the self-healing tracker and transport policy layered on top of it.
pub struct SrxNode {
    /// Configuration the node was built from; its `replay` section controls
    /// whether replay state is persisted to and restored from storage.
    config: SrxConfig,
    /// Pipeline bound to the negotiated session; performs send/receive and
    /// holds the replay state.
    pipe: SrxPipeline,
    /// Decides (via `should_heal`) when `heal_if_needed` should act.
    self_healing: SelfHealing,
    /// Ranks transports when `heal_if_needed` builds a recommended order.
    policy: TransportPolicy,
}
impl SrxNode {
pub fn client_connect<F>(
config: SrxConfig,
transport_mgr: TransportManager,
exchange: F,
) -> Result<Self>
where
F: FnOnce(&[u8]) -> Result<(Vec<u8>, Vec<u8>)>,
{
let mut hs = Handshake::new_initiator();
let ch = hs.client_hello()?;
let (sh, _ack) = exchange(&ch)?;
let cf = hs.finalize(&sh)?;
let _ = cf;
let master = hs.master_secret().ok_or_else(|| {
crate::error::SrxError::Session(crate::error::SessionError::HandshakeFailed(
"master secret not available after finalize".into(),
))
})?;
Self::from_master_secret(config, master, transport_mgr)
}
pub fn from_master_secret(
config: SrxConfig,
master: [u8; 32],
transport_mgr: TransportManager,
) -> Result<Self> {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let session = Session::from_master_secret(0, &master, timestamp, b"srx-node")?;
let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
let pipe = SrxPipeline::from_config(&config, session, aead, transport_mgr);
let node = Self {
config,
pipe,
self_healing: SelfHealing::new(),
policy: TransportPolicy::default(),
};
node.restore_replay_state_from_disk()?;
Ok(node)
}
pub fn from_session(
config: SrxConfig,
session: Session,
aead: Arc<AeadPipeline>,
transport_mgr: TransportManager,
) -> Result<Self> {
let pipe = SrxPipeline::from_config(&config, session, aead, transport_mgr);
let node = Self {
config,
pipe,
self_healing: SelfHealing::new(),
policy: TransportPolicy::default(),
};
node.restore_replay_state_from_disk()?;
Ok(node)
}
pub async fn send(&mut self, payload: &[u8]) -> Result<TransportKind> {
self.pipe.send(payload).await
}
pub async fn recv_from(&mut self, kind: TransportKind) -> Result<Vec<u8>> {
let payload = self.pipe.recv_from(kind).await?;
self.persist_replay_state_to_disk()?;
Ok(payload)
}
pub fn process_incoming(&self, envelope: &[u8]) -> Result<Vec<u8>> {
let payload = self.pipe.process_incoming(envelope)?;
self.persist_replay_state_to_disk()?;
Ok(payload)
}
pub fn prepare_outgoing(&mut self, payload: &[u8]) -> Result<Vec<u8>> {
self.pipe.prepare_outgoing(payload)
}
pub fn pipeline(&self) -> &SrxPipeline {
&self.pipe
}
pub fn pipeline_mut(&mut self) -> &mut SrxPipeline {
&mut self.pipe
}
pub fn config(&self) -> &SrxConfig {
&self.config
}
fn replay_persistence_enabled(&self) -> bool {
self.config.replay.persist_enabled
}
fn restore_replay_state_from_disk(&self) -> Result<()> {
if !self.replay_persistence_enabled() {
return Ok(());
}
let storage = storage_from_config(&self.config.replay)?;
let Some(raw) = storage.load_raw(&self.config.replay)? else {
return Ok(());
};
let Some(state) =
decode_replay_envelope(&self.config.replay, &self.replay_session_binding(), &raw)?
else {
return Ok(());
};
self.pipe.set_replay_state(&state)
}
fn replay_session_binding(&self) -> String {
use sha2::{Digest, Sha256};
let seed = self.pipe.session.rng.seed_bytes();
let digest = Sha256::digest(seed);
let mut out = String::with_capacity(24);
for b in digest.iter().take(12) {
out.push(char::from(b"0123456789abcdef"[(b >> 4) as usize]));
out.push(char::from(b"0123456789abcdef"[(b & 0x0f) as usize]));
}
out
}
fn persist_replay_state_to_disk(&self) -> Result<()> {
if !self.replay_persistence_enabled() {
return Ok(());
}
let state = self.replay_state();
let storage = storage_from_config(&self.config.replay)?;
merge_and_persist_replay_state(
&self.config.replay,
storage.as_ref(),
&self.replay_session_binding(),
state,
)
}
pub fn replay_state(&self) -> ReplayState {
self.pipe.replay_state()
}
pub fn set_replay_state(&self, state: &ReplayState) -> Result<()> {
self.pipe.set_replay_state(state)?;
self.persist_replay_state_to_disk()
}
pub fn send_signal(&mut self, signal: &Signal) -> Result<Vec<u8>> {
self.pipe.prepare_signal(signal)
}
pub fn process_incoming_dispatched(&self, envelope: &[u8]) -> Result<Payload> {
let payload = self.process_incoming(envelope)?;
match self.pipe.try_decode_signal(&payload) {
Some(sig) => Ok(Payload::Signal(sig)),
None => Ok(Payload::Data(payload)),
}
}
pub fn heal_if_needed(&mut self) -> Option<Vec<TransportKind>> {
if !self.self_healing.should_heal(self.pipe.transport_mgr()) {
return None;
}
let healthy = self.pipe.transport_mgr().healthy_kinds();
let active = self.pipe.transport_mgr().active_kinds();
let blocked: Vec<_> = active
.iter()
.filter(|k| !healthy.contains(k))
.copied()
.collect();
self.self_healing.reseed_only(&mut self.pipe.session.rng);
let mut order = self.policy.recommend(&healthy);
order.extend(self.policy.recommend(&blocked));
Some(order)
}
pub fn record_success(&mut self) {
self.self_healing.record_success();
}
pub fn heal_count(&self) -> u32 {
self.self_healing.heal_count
}
pub fn set_environment(&mut self, env: crate::client::policy::NetworkEnvironment) {
self.policy.set_environment(env);
}
pub fn policy(&self) -> &TransportPolicy {
&self.policy
}
pub fn replay_store_metrics() -> ReplayStoreMetricsSnapshot {
replay_store_metrics_snapshot()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::policy::NetworkEnvironment;
    use crate::config::SrxConfig;
    use crate::pipeline::Payload;
    use crate::signaling::inband::Signal;
    use tempfile::TempDir;

    /// Master secret shared by the in-memory sender/receiver pair.
    const PAIR_MASTER: [u8; 32] = [0xBB; 32];

    /// Default config with replay persistence switched off.
    fn in_memory_config() -> SrxConfig {
        let mut cfg = SrxConfig::default();
        cfg.replay.persist_enabled = false;
        cfg
    }

    /// Builds a node from `master` over a fresh transport manager.
    fn node_with(cfg: SrxConfig, master: [u8; 32]) -> SrxNode {
        SrxNode::from_master_secret(cfg, master, TransportManager::new()).unwrap()
    }

    /// Two non-persistent nodes sharing one master secret (sender, receiver).
    fn make_pair() -> (SrxNode, SrxNode) {
        let cfg = in_memory_config();
        let sender = node_with(cfg.clone(), PAIR_MASTER);
        let receiver = node_with(cfg, PAIR_MASTER);
        (sender, receiver)
    }

    #[test]
    fn from_master_secret_builds_node() {
        let node = node_with(in_memory_config(), [0xAAu8; 32]);
        assert!(node.pipeline().session.active);
    }

    #[test]
    fn node_prepare_process_roundtrip() {
        let (mut tx, rx) = make_pair();
        let sealed = tx.prepare_outgoing(b"node-test").unwrap();
        let opened = rx.process_incoming(&sealed).unwrap();
        assert_eq!(opened, b"node-test");
    }

    #[test]
    fn signal_roundtrip_through_node() {
        let (mut tx, rx) = make_pair();
        let sealed = tx.send_signal(&Signal::SeedRotation).unwrap();
        let Payload::Signal(sig) = rx.process_incoming_dispatched(&sealed).unwrap() else {
            panic!("expected signal");
        };
        assert_eq!(sig, Signal::SeedRotation);
    }

    #[test]
    fn data_dispatched_as_data() {
        let (mut tx, rx) = make_pair();
        let sealed = tx.prepare_outgoing(b"app-data").unwrap();
        let Payload::Data(bytes) = rx.process_incoming_dispatched(&sealed).unwrap() else {
            panic!("expected data");
        };
        assert_eq!(bytes, b"app-data");
    }

    #[test]
    fn heal_if_needed_returns_none_when_healthy() {
        let (mut node, _) = make_pair();
        assert!(node.heal_if_needed().is_none());
        assert_eq!(node.heal_count(), 0);
    }

    #[test]
    fn record_success_resets_backoff() {
        let (mut node, _) = make_pair();
        node.record_success();
        assert_eq!(node.heal_count(), 0);
    }

    #[test]
    fn set_environment_updates_policy() {
        let (mut node, _) = make_pair();
        node.set_environment(NetworkEnvironment::Corporate);
        assert_eq!(node.policy().environment(), NetworkEnvironment::Corporate);
    }

    #[test]
    fn replay_state_snapshot_restore_on_node() {
        let (mut tx, rx) = make_pair();
        let first = tx.prepare_outgoing(b"r1").unwrap();
        let second = tx.prepare_outgoing(b"r2").unwrap();
        assert_eq!(rx.process_incoming(&first).unwrap(), b"r1");
        assert_eq!(rx.process_incoming(&second).unwrap(), b"r2");

        let snapshot = rx.replay_state();
        let restored = node_with(in_memory_config(), PAIR_MASTER);
        restored.set_replay_state(&snapshot).unwrap();

        let duplicate = restored.process_incoming(&second);
        assert!(duplicate.is_err(), "restored node must reject duplicate");
    }

    #[test]
    fn auto_persist_and_restore_replay_state() {
        let temp = TempDir::new().unwrap();
        let state_file = temp.path().join("replay_state.json");

        let mut cfg = SrxConfig::default();
        cfg.replay.persist_enabled = true;
        cfg.replay.state_file = state_file.clone();

        let master = [0xACu8; 32];
        let mut tx = node_with(cfg.clone(), master);
        let rx_before = node_with(cfg.clone(), master);

        let sealed = tx.prepare_outgoing(b"persisted").unwrap();
        assert_eq!(rx_before.process_incoming(&sealed).unwrap(), b"persisted");
        assert!(state_file.exists(), "replay state file must be created");

        let rx_after = node_with(cfg, master);
        assert!(
            rx_after.process_incoming(&sealed).is_err(),
            "restored node must reject duplicate from persisted state"
        );
    }
}