use std::net::IpAddr;
use std::path::Path;
use std::sync::{Arc, RwLock};
use serde::{Deserialize, Serialize};
pub const MAINTENANCE_FLAG_FILE: &str = "tmp/autumn-maintenance.json";
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct MaintenanceConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
#[serde(default)]
pub allow_ips: Vec<String>,
#[serde(default)]
pub readonly: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub bypass_header: Option<(String, String)>,
}
#[derive(Clone, Debug, Default)]
pub struct MaintenanceState(Arc<RwLock<Option<MaintenanceConfig>>>);
impl MaintenanceState {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn is_active(&self) -> bool {
self.0
.read()
.expect("maintenance state lock poisoned")
.is_some()
}
#[must_use]
pub fn get(&self) -> Option<MaintenanceConfig> {
self.0
.read()
.expect("maintenance state lock poisoned")
.clone()
}
pub fn enable(&self, config: MaintenanceConfig) {
*self.0.write().expect("maintenance state lock poisoned") = Some(config);
}
pub fn disable(&self) {
*self.0.write().expect("maintenance state lock poisoned") = None;
}
pub fn load_from_file(path: &Path) -> std::io::Result<Option<MaintenanceConfig>> {
match std::fs::read_to_string(path) {
Ok(s) => {
let config: MaintenanceConfig = serde_json::from_str(&s)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(Some(config))
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e),
}
}
pub fn save_to_file(path: &Path, config: &MaintenanceConfig) -> std::io::Result<()> {
if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
std::fs::create_dir_all(parent)?;
}
let json = serde_json::to_string_pretty(config).map_err(std::io::Error::other)?;
std::fs::write(path, json)
}
pub fn remove_flag_file(path: &Path) -> std::io::Result<bool> {
match std::fs::remove_file(path) {
Ok(()) => Ok(true),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
Err(e) => Err(e),
}
}
}
#[must_use]
pub fn ip_in_allow_list(client_ip: &IpAddr, allow_ips: &[String]) -> bool {
for entry in allow_ips {
if let Some((prefix, bits)) = entry.split_once('/') {
if let (Ok(network_ip), Ok(prefix_len)) = (prefix.parse::<IpAddr>(), bits.parse::<u8>())
&& ip_in_cidr(client_ip, &network_ip, prefix_len)
{
return true;
}
} else if let Ok(allowed) = entry.parse::<IpAddr>()
&& client_ip == &allowed
{
return true;
}
}
false
}
const fn ip_in_cidr(ip: &IpAddr, network: &IpAddr, prefix_len: u8) -> bool {
match (ip, network) {
(IpAddr::V4(ip), IpAddr::V4(net)) => {
if prefix_len > 32 {
return false;
}
let ip_bits = u32::from_be_bytes(ip.octets());
let net_bits = u32::from_be_bytes(net.octets());
let mask = if prefix_len == 0 {
0u32
} else {
!0u32 << (32 - prefix_len)
};
(ip_bits & mask) == (net_bits & mask)
}
(IpAddr::V6(ip), IpAddr::V6(net)) => {
if prefix_len > 128 {
return false;
}
let ip_bits = u128::from_be_bytes(ip.octets());
let net_bits = u128::from_be_bytes(net.octets());
let mask = if prefix_len == 0 {
0u128
} else {
!0u128 << (128 - prefix_len)
};
(ip_bits & mask) == (net_bits & mask)
}
_ => false, }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn maintenance_state_default_is_off() {
let state = MaintenanceState::new();
assert!(!state.is_active());
assert!(state.get().is_none());
}
#[test]
fn maintenance_state_enable_sets_active() {
let state = MaintenanceState::new();
state.enable(MaintenanceConfig::default());
assert!(state.is_active());
}
#[test]
fn maintenance_state_disable_clears_active() {
let state = MaintenanceState::new();
state.enable(MaintenanceConfig::default());
state.disable();
assert!(!state.is_active());
assert!(state.get().is_none());
}
#[test]
fn maintenance_state_get_returns_config() {
let state = MaintenanceState::new();
let config = MaintenanceConfig {
message: Some("deploying".into()),
readonly: true,
..Default::default()
};
state.enable(config.clone());
assert_eq!(state.get().unwrap(), config);
}
#[test]
fn maintenance_state_is_clone_safe() {
let state = MaintenanceState::new();
let clone = state.clone();
state.enable(MaintenanceConfig::default());
assert!(clone.is_active());
}
#[test]
fn load_from_file_returns_none_when_missing() {
let path = std::path::Path::new("/tmp/autumn_nonexistent_maintenance_test_xyz.json");
let result = MaintenanceState::load_from_file(path).unwrap();
assert!(result.is_none());
}
#[test]
fn save_and_load_round_trips() {
let tmp = tempfile::TempDir::new().unwrap();
let path = tmp.path().join("maintenance.json");
let config = MaintenanceConfig {
message: Some("Deploying new version".into()),
allow_ips: vec!["192.168.1.0/24".into()],
readonly: false,
bypass_header: Some(("X-Bypass".into(), "secret-token".into())),
};
MaintenanceState::save_to_file(&path, &config).unwrap();
let loaded = MaintenanceState::load_from_file(&path).unwrap().unwrap();
assert_eq!(loaded, config);
}
#[test]
fn save_creates_parent_directory() {
let tmp = tempfile::TempDir::new().unwrap();
let path = tmp.path().join("tmp").join("maintenance.json");
let config = MaintenanceConfig::default();
MaintenanceState::save_to_file(&path, &config).unwrap();
assert!(path.exists());
}
#[test]
fn remove_flag_file_returns_true_when_deleted() {
let tmp = tempfile::TempDir::new().unwrap();
let path = tmp.path().join("maintenance.json");
std::fs::write(&path, "{}").unwrap();
assert!(MaintenanceState::remove_flag_file(&path).unwrap());
}
#[test]
fn remove_flag_file_returns_false_when_already_absent() {
let tmp = tempfile::TempDir::new().unwrap();
let path = tmp.path().join("maintenance.json");
assert!(!MaintenanceState::remove_flag_file(&path).unwrap());
}
#[test]
fn maintenance_config_default_has_no_message() {
let c = MaintenanceConfig::default();
assert!(c.message.is_none());
assert!(c.allow_ips.is_empty());
assert!(!c.readonly);
assert!(c.bypass_header.is_none());
}
#[test]
fn ip_in_allow_list_exact_ipv4_match() {
let ip: IpAddr = "192.168.1.5".parse().unwrap();
assert!(ip_in_allow_list(&ip, &["192.168.1.5".into()]));
}
#[test]
fn ip_in_allow_list_exact_ipv4_no_match() {
let ip: IpAddr = "192.168.1.5".parse().unwrap();
assert!(!ip_in_allow_list(&ip, &["192.168.1.6".into()]));
}
#[test]
fn ip_in_allow_list_cidr_v4_in_range() {
let ip: IpAddr = "192.168.1.50".parse().unwrap();
assert!(ip_in_allow_list(&ip, &["192.168.1.0/24".into()]));
}
#[test]
fn ip_in_allow_list_cidr_v4_out_of_range() {
let ip: IpAddr = "192.168.2.50".parse().unwrap();
assert!(!ip_in_allow_list(&ip, &["192.168.1.0/24".into()]));
}
#[test]
fn ip_in_allow_list_cidr_v4_boundary() {
let ip: IpAddr = "10.0.0.255".parse().unwrap();
assert!(ip_in_allow_list(&ip, &["10.0.0.0/24".into()]));
let outside: IpAddr = "10.0.1.0".parse().unwrap();
assert!(!ip_in_allow_list(&outside, &["10.0.0.0/24".into()]));
}
#[test]
fn ip_in_allow_list_loopback() {
let ip: IpAddr = "127.0.0.1".parse().unwrap();
assert!(ip_in_allow_list(&ip, &["127.0.0.0/8".into()]));
assert!(ip_in_allow_list(&ip, &["127.0.0.1".into()]));
}
#[test]
fn ip_in_allow_list_empty_returns_false() {
let ip: IpAddr = "127.0.0.1".parse().unwrap();
assert!(!ip_in_allow_list(&ip, &[]));
}
#[test]
fn ip_in_allow_list_ipv6_exact() {
let ip: IpAddr = "::1".parse().unwrap();
assert!(ip_in_allow_list(&ip, &["::1".into()]));
}
#[test]
fn ip_in_allow_list_mixed_family_no_match() {
let ipv4: IpAddr = "127.0.0.1".parse().unwrap();
assert!(!ip_in_allow_list(&ipv4, &["::1".into()]));
}
#[test]
fn ip_in_allow_list_multiple_entries() {
let ip: IpAddr = "10.10.10.10".parse().unwrap();
let allow = vec!["192.168.1.1".into(), "10.10.0.0/16".into()];
assert!(ip_in_allow_list(&ip, &allow));
}
#[test]
fn ip_in_allow_list_invalid_entry_skipped() {
let ip: IpAddr = "192.168.1.1".parse().unwrap();
assert!(!ip_in_allow_list(
&ip,
&["not-an-ip".into(), "999.999.999.999".into()]
));
}
}