use crate::layers::l4::{tcp::TcpConfig, udp::UdpConfig};
use arc_swap::ArcSwap;
use dashmap::DashMap;
use once_cell::sync::Lazy;
use serde::Serialize;
use std::sync::Arc;
use tokio::{sync::oneshot, time::Instant};
/// Transport-layer protocol a listener is bound to.
///
/// Serializes to a lowercase JSON string: `"tcp"` or `"udp"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
pub enum Protocol {
    /// TCP; serialized as `"tcp"`.
    #[serde(rename = "tcp")]
    Tcp,
    /// UDP; serialized as `"udp"`.
    #[serde(rename = "udp")]
    Udp,
}
/// Serializable snapshot of a single port and the configuration attached to it.
///
/// Config fields that are `None` are left out of the serialized output entirely
/// rather than being emitted as `null`.
#[derive(Debug, Clone, Serialize)]
pub struct PortStatus {
    /// The port number this entry describes.
    pub port: u16,
    /// Activity flag for the port (NOTE(review): meaning is decided by whoever builds
    /// the snapshot — confirm against the producer).
    pub active: bool,
    /// Shared TCP configuration, if this port has one.
    #[serde(serialize_with = "serde_arc::serialize", skip_serializing_if = "Option::is_none")]
    pub tcp_config: Option<Arc<TcpConfig>>,
    /// Shared UDP configuration, if this port has one.
    #[serde(serialize_with = "serde_arc::serialize", skip_serializing_if = "Option::is_none")]
    pub udp_config: Option<Arc<UdpConfig>>,
}
/// Shared handle to the list of port snapshots.
///
/// The list sits behind an [`ArcSwap`]. Readers load the current `Vec` and writers
/// publish a replacement `Vec` as a whole instead of mutating it in place.
pub type PortState = Arc<ArcSwap<Vec<PortStatus>>>;

/// Process-wide port-status list. It is created lazily, as an empty `Vec`, on first access.
pub static CONFIG_STATE: Lazy<PortState> =
    Lazy::new(|| Arc::new(ArcSwap::from_pointee(Vec::new())));
/// Lifecycle state of a spawned listener.
///
/// The type derives `Debug` so that containing types and log statements can format it.
/// It is `Copy` because every variant holds only plain `Copy` data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ListenerState {
    /// The listener is running normally.
    Active,
    /// The listener is draining.
    ///
    /// `since` records the instant draining began. NOTE(review): presumably new
    /// connections are no longer accepted in this state — confirm in the listener loop.
    Draining { since: Instant },
}
/// Bookkeeping for one spawned listener, as stored in [`TASK_REGISTRY`].
pub struct RunningListener {
    /// Lifecycle state, held in an `Arc` behind an async (tokio) mutex.
    /// NOTE(review): presumably shared with the listener task itself — confirm.
    pub state: Arc<tokio::sync::Mutex<ListenerState>>,
    /// Sending half of a one-shot `()` channel.
    /// NOTE(review): presumably used to tell the listener to shut down; per tokio,
    /// dropping the sender also wakes the receiver with an error — confirm handling.
    pub shutdown_tx: tokio::sync::oneshot::Sender<()>,
}
/// Concurrent registry of running listeners, keyed by `(port, protocol)`.
///
/// It is created empty on first access and uses the default hasher.
pub static TASK_REGISTRY: Lazy<DashMap<(u16, Protocol), RunningListener>> =
    Lazy::new(DashMap::default);
/// Serde `with`-module for `Option<Arc<T>>` fields.
///
/// It serializes the pointee as if the field were `Option<&T>`, so callers do not
/// need serde's `rc` feature.
mod serde_arc {
    use serde::{Serialize, Serializer};
    use std::sync::Arc;

    /// Serializes `val` exactly like the equivalent `Option<&T>`.
    ///
    /// This delegates to `Option`'s own `Serialize` impl, so a present value goes
    /// through `serialize_some` and an absent one through `serialize_none`. Before
    /// this change, `Some` bypassed `serialize_some` while `None` still used
    /// `serialize_none`. That mismatch makes the output ambiguous for formats that
    /// tag options. For JSON the output is unchanged.
    ///
    /// # Errors
    /// Returns whatever error the serializer `s` reports.
    // serde's `with` / `serialize_with` contract fixes the first parameter to
    // `&FieldType`, i.e. `&Option<Arc<T>>`. The `Option<&T>` form that clippy
    // suggests is therefore not possible here.
    #[allow(clippy::ref_option)]
    pub(super) fn serialize<S, T>(val: &Option<Arc<T>>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: Serialize + ?Sized,
    {
        val.as_deref().serialize(s)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::l4::{
        legacy::{
            LegacyTcpConfig,
            tcp::{TcpDestination, TcpProtocolRule},
        },
        model::{Detect, DetectMethod, Forward, Strategy},
        tcp::TcpConfig,
    };
    use serde_json::json;

    /// Builds a legacy TCP config with a single fallback rule that forwards
    /// to an empty target list.
    fn sample_legacy_config() -> LegacyTcpConfig {
        let catch_all = TcpProtocolRule {
            name: "test".to_string(),
            priority: 1,
            detect: Detect {
                method: DetectMethod::Fallback,
                pattern: "any".to_string(),
            },
            session: None,
            destination: TcpDestination::Forward {
                forward: Forward {
                    strategy: Strategy::Random,
                    targets: vec![],
                    fallbacks: vec![],
                },
            },
        };
        LegacyTcpConfig {
            rules: vec![catch_all],
        }
    }

    /// Every `Protocol` variant serializes to its lowercase JSON string.
    #[test]
    fn test_protocol_serialization() {
        let cases = [(Protocol::Tcp, "\"tcp\""), (Protocol::Udp, "\"udp\"")];
        for (protocol, expected) in cases {
            assert_eq!(serde_json::to_string(&protocol).unwrap(), expected);
        }
    }

    /// Present configs are serialized in full; `None` configs are omitted.
    #[test]
    fn test_port_status_serialization() {
        let populated = PortStatus {
            port: 8080,
            active: true,
            tcp_config: Some(Arc::new(TcpConfig::Legacy(sample_legacy_config()))),
            udp_config: None,
        };
        assert_eq!(
            serde_json::to_value(&populated).unwrap(),
            json!({
                "port": 8080,
                "active": true,
                "tcp_config": {
                    "protocols": [{
                        "name": "test",
                        "priority": 1,
                        "detect": { "method": "fallback", "pattern": "any" },
                        "session": null,
                        "destination": {
                            "type": "forward",
                            "forward": {
                                "strategy": "random",
                                "targets": [],
                                "fallbacks": []
                            }
                        }
                    }]
                }
            })
        );

        // With no configs at all, only `port` and `active` are emitted.
        let bare = PortStatus {
            port: 9090,
            active: false,
            tcp_config: None,
            udp_config: None,
        };
        assert_eq!(
            serde_json::to_value(&bare).unwrap(),
            json!({
                "port": 9090,
                "active": false
            })
        );
    }
}