use std::sync::Arc;
use parking_lot::Mutex;
/// Errors that can occur while applying a settings change.
///
/// Each variant maps to an HTTP status code via `ApplyError::http_status`.
#[derive(Debug, thiserror::Error)]
pub enum ApplyError {
/// The submitted settings were rejected; the payload is appended to the message.
#[error("invalid settings: {0}")]
ValidationFailed(String),
/// The requested listen port could not be used; per the message, the previous
/// port (`existing`) is still the one being listened on.
#[error("listen port {attempted} in use (still listening on {existing})")]
ListenPortInUse {
// Port the change asked for.
attempted: u16,
// Port still listened on after the failed change.
existing: u16,
},
/// Restarting the DHT failed; the payload carries the reason as a string.
#[error("DHT restart failed: {0}")]
DhtRestartFailed(String),
/// Restarting LSD failed; the payload carries the reason as a string.
#[error("LSD restart failed: {0}")]
LsdRestartFailed(String),
/// The NAT refresh step failed; the payload carries the reason as a string.
#[error("NAT refresh failed: {0}")]
NatRefreshFailed(String),
/// Another reconfiguration already holds the in-flight slot; the caller may retry.
#[error("concurrent reconfig in flight, retry shortly")]
ConcurrentReconfig,
/// An I/O error occurred during reconfiguration, carried as a string.
#[error("I/O during reconfig: {0}")]
Io(String),
}
impl ApplyError {
    /// Returns the HTTP status code associated with this error.
    ///
    /// - `400` — [`ApplyError::ValidationFailed`].
    /// - `409` (Conflict) — [`ApplyError::ListenPortInUse`] and
    ///   [`ApplyError::ConcurrentReconfig`].
    /// - `200` — [`ApplyError::NatRefreshFailed`]. NOTE(review): presumably a NAT
    ///   refresh failure is treated as non-blocking for the request; confirm with the
    ///   HTTP handler.
    /// - `500` — DHT/LSD restart failures and [`ApplyError::Io`].
    #[must_use]
    pub fn http_status(&self) -> u16 {
        const OK: u16 = 200;
        const BAD_REQUEST: u16 = 400;
        const CONFLICT: u16 = 409;
        const INTERNAL_ERROR: u16 = 500;

        match self {
            Self::NatRefreshFailed(_) => OK,
            Self::ValidationFailed(_) => BAD_REQUEST,
            Self::ConcurrentReconfig | Self::ListenPortInUse { .. } => CONFLICT,
            Self::Io(_) | Self::LsdRestartFailed(_) | Self::DhtRestartFailed(_) => INTERNAL_ERROR,
        }
    }

    /// Reports whether this error is fatal.
    ///
    /// No variant is currently considered fatal, so this always returns `false`.
    #[must_use]
    pub const fn is_fatal(&self) -> bool {
        false
    }
}
/// Shared "a reconfiguration is in progress" flag.
///
/// The flag lives behind an [`Arc`], so every clone observes and contends for the
/// same slot. Claim it with [`ReconfigInFlight::try_lock`].
#[derive(Debug, Clone, Default)]
pub struct ReconfigInFlight {
    /// `true` while a [`ReconfigGuard`] obtained from this flag (or a clone) is alive.
    inner: Arc<Mutex<bool>>,
}

/// RAII guard for the reconfiguration slot; dropping it clears the parent flag.
#[derive(Debug)]
#[must_use = "the guard must be held for the duration of the apply call"]
pub struct ReconfigGuard<'g> {
    parent: &'g ReconfigInFlight,
}

impl ReconfigInFlight {
    /// Creates an unlocked flag; equivalent to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Attempts to claim the reconfiguration slot without blocking.
    ///
    /// Returns `Some(guard)` if no other guard is alive, `None` otherwise. The inner
    /// mutex is only held for the check-and-set, not for the guard's lifetime.
    #[must_use]
    pub fn try_lock(&self) -> Option<ReconfigGuard<'_>> {
        let mut held = self.inner.lock();
        // `replace` stores `true` and hands back the previous value; a previous
        // `true` means another guard already owns the slot.
        if std::mem::replace(&mut *held, true) {
            None
        } else {
            Some(ReconfigGuard { parent: self })
        }
    }
}

impl Drop for ReconfigGuard<'_> {
    /// Releases the slot so a subsequent `try_lock` can succeed.
    fn drop(&mut self) {
        *self.parent.inner.lock() = false;
    }
}
/// Forward action of a [`Phase`]: runs once against the state and may fail with an
/// [`ApplyError`].
pub type ForwardStep<S> = Box<dyn FnOnce(&mut S) -> Result<(), ApplyError> + Send>;
/// Rollback action of a [`Phase`]: runs once against the state, intended to undo the
/// matching forward action. It returns `()`, so it cannot report failure.
pub type RollbackStep<S> = Box<dyn FnOnce(&mut S) + Send>;
/// One named step of a transactional apply: a forward action plus the rollback that
/// `apply_phases_with_rollback` runs if a *later* phase fails.
pub struct Phase<S> {
/// Identifier used in log output.
pub name: &'static str,
/// Action to apply; returning `Err` aborts the apply.
pub forward: ForwardStep<S>,
/// Undo action; only run if `forward` succeeded and a later phase failed.
pub rollback: RollbackStep<S>,
}
pub fn apply_phases_with_rollback<S>(state: &mut S, phases: Vec<Phase<S>>) -> Result<(), ApplyError>
where
S: Send,
{
let mut applied: Vec<(&'static str, RollbackStep<S>)> = Vec::with_capacity(phases.len());
for phase in phases {
let Phase {
name,
forward,
rollback,
} = phase;
match forward(state) {
Ok(()) => {
applied.push((name, rollback));
}
Err(e) => {
while let Some((rb_name, rb)) = applied.pop() {
tracing::debug!(phase = rb_name, "rolling back phase");
rb(state);
}
tracing::warn!(phase = name, error = %e, "transactional apply failed, rolled back");
return Err(e);
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state for exercising the phase runner: a few settings plus an ordered
    /// log (`events`) of which forward/rollback closures ran.
    #[derive(Debug, Default)]
    struct MockState {
        rate_limit: u32,
        listen_port: u16,
        dht_enabled: bool,
        // Only written (by the explicit fixture in
        // `third_phase_fails_first_two_rollback_in_reverse`), never read, so the
        // dead_code lint would otherwise fire.
        #[allow(dead_code)]
        lsd_enabled: bool,
        events: Vec<&'static str>,
    }

    /// Boxes a forward/rollback closure pair into a [`Phase`] over [`MockState`].
    fn make_phase<F1, F2>(name: &'static str, forward: F1, rollback: F2) -> Phase<MockState>
    where
        F1: FnOnce(&mut MockState) -> Result<(), ApplyError> + Send + 'static,
        F2: FnOnce(&mut MockState) + Send + 'static,
    {
        Phase {
            name,
            forward: Box::new(forward),
            rollback: Box::new(rollback),
        }
    }

    #[test]
    fn empty_phase_list_succeeds_noop() {
        let mut state = MockState::default();
        let result = apply_phases_with_rollback(&mut state, Vec::new());
        assert!(result.is_ok());
        assert!(state.events.is_empty());
    }

    #[test]
    fn all_phases_succeed_in_order() {
        let mut state = MockState::default();
        let phases = vec![
            make_phase(
                "rate_limits",
                |s| {
                    s.rate_limit = 100;
                    s.events.push("rate:fwd");
                    Ok(())
                },
                |s| s.events.push("rate:rb"),
            ),
            make_phase(
                "listen_port",
                |s| {
                    s.listen_port = 6881;
                    s.events.push("port:fwd");
                    Ok(())
                },
                |s| s.events.push("port:rb"),
            ),
            make_phase(
                "dht",
                |s| {
                    s.dht_enabled = true;
                    s.events.push("dht:fwd");
                    Ok(())
                },
                |s| s.events.push("dht:rb"),
            ),
        ];
        let result = apply_phases_with_rollback(&mut state, phases);
        assert!(result.is_ok());
        assert_eq!(state.rate_limit, 100);
        assert_eq!(state.listen_port, 6881);
        assert!(state.dht_enabled);
        assert_eq!(state.events, vec!["rate:fwd", "port:fwd", "dht:fwd"]);
    }

    #[test]
    fn third_phase_fails_first_two_rollback_in_reverse() {
        let mut state = MockState {
            rate_limit: 1000,
            listen_port: 51413,
            dht_enabled: false,
            lsd_enabled: false,
            events: Vec::new(),
        };
        let pre_rate = state.rate_limit;
        let pre_port = state.listen_port;
        let phases = vec![
            make_phase(
                "rate_limits",
                |s| {
                    s.rate_limit = 100;
                    s.events.push("rate:fwd");
                    Ok(())
                },
                move |s| {
                    s.rate_limit = pre_rate;
                    s.events.push("rate:rb");
                },
            ),
            make_phase(
                "listen_port",
                |s| {
                    s.listen_port = 6881;
                    s.events.push("port:fwd");
                    Ok(())
                },
                move |s| {
                    s.listen_port = pre_port;
                    s.events.push("port:rb");
                },
            ),
            make_phase(
                "dht",
                |s| {
                    s.events.push("dht:fwd-fail");
                    Err(ApplyError::DhtRestartFailed("simulated".into()))
                },
                |_| panic!("dht rollback must NOT run if forward failed"),
            ),
        ];
        let result = apply_phases_with_rollback(&mut state, phases);
        assert!(matches!(result, Err(ApplyError::DhtRestartFailed(_))));
        assert_eq!(state.rate_limit, pre_rate);
        assert_eq!(state.listen_port, pre_port);
        assert_eq!(
            state.events,
            vec![
                "rate:fwd",
                "port:fwd",
                "dht:fwd-fail",
                // Rollbacks run newest-first.
                "port:rb",
                "rate:rb",
            ]
        );
    }

    #[test]
    fn first_phase_fails_no_rollback() {
        let mut state = MockState::default();
        let phases = vec![
            make_phase(
                "rate_limits",
                |s| {
                    s.events.push("rate:fwd-fail");
                    Err(ApplyError::ValidationFailed("rate too low".into()))
                },
                |_| panic!("rollback must NOT run for failed forward"),
            ),
            make_phase(
                "listen_port",
                |_| panic!("subsequent forward must NOT run after a failure"),
                |_| panic!("subsequent rollback must NOT run after a failure"),
            ),
        ];
        let result = apply_phases_with_rollback(&mut state, phases);
        assert!(matches!(result, Err(ApplyError::ValidationFailed(_))));
        assert_eq!(state.events, vec!["rate:fwd-fail"]);
    }

    #[test]
    fn reconfig_in_flight_single_lock_releases_on_drop() {
        let in_flight = ReconfigInFlight::new();
        {
            let _lock = in_flight.try_lock().expect("first lock should succeed");
            assert!(in_flight.try_lock().is_none(), "second concurrent lock must fail");
        }
        assert!(
            in_flight.try_lock().is_some(),
            "lock must be acquirable again after the guard is dropped"
        );
    }

    #[test]
    fn reconfig_in_flight_default_is_unlocked() {
        let g = ReconfigInFlight::default();
        assert!(g.try_lock().is_some(), "fresh guard should be unlocked");
    }

    #[test]
    fn reconfig_in_flight_clones_share_lock() {
        let original = ReconfigInFlight::new();
        let clone = original.clone();
        let held = original.try_lock().expect("first lock should succeed");
        assert!(
            clone.try_lock().is_none(),
            "a clone must observe the lock held through the original"
        );
        drop(held);
        assert!(
            clone.try_lock().is_some(),
            "releasing through the original must unlock the clone too"
        );
    }

    #[test]
    fn apply_error_http_status_classification() {
        assert_eq!(ApplyError::ValidationFailed("x".into()).http_status(), 400);
        assert_eq!(
            ApplyError::ListenPortInUse {
                attempted: 6881,
                existing: 51413
            }
            .http_status(),
            409
        );
        assert_eq!(ApplyError::ConcurrentReconfig.http_status(), 409);
        assert_eq!(ApplyError::DhtRestartFailed("d".into()).http_status(), 500);
        assert_eq!(ApplyError::LsdRestartFailed("l".into()).http_status(), 500);
        assert_eq!(ApplyError::NatRefreshFailed("n".into()).http_status(), 200);
        assert_eq!(ApplyError::Io("i".into()).http_status(), 500);
    }

    #[test]
    fn apply_error_is_fatal_returns_false_in_b1() {
        assert!(!ApplyError::ValidationFailed("x".into()).is_fatal());
        assert!(!ApplyError::ConcurrentReconfig.is_fatal());
    }
}