use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Address of a socket endpoint: either a Unix-domain socket path or a TCP address.
///
/// Serialization writes the bare path / address string with no scheme prefix, as an
/// untagged enum does. Deserialization instead goes through `From<String>`. Both
/// variants are plain strings on the wire, so the untagged derive always matched
/// `Unix` first and could never produce `Tcp`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged, from = "String")]
pub enum SocketAddr {
    /// Filesystem path of a Unix-domain socket.
    Unix(PathBuf),
    /// TCP address text, stored verbatim (e.g. `127.0.0.1:5432`).
    Tcp(String),
}

impl From<String> for SocketAddr {
    /// Classifies a textual socket address.
    ///
    /// An explicit `unix:` or `tcp:` prefix (the form produced by `Display`) selects the
    /// variant and is stripped. Without a prefix, text shaped like `host:port` (a non-empty
    /// host containing no `/` and a port that parses as `u16`) becomes
    /// [`SocketAddr::Tcp`]. Everything else is treated as a Unix socket path. A Unix path
    /// that happens to look like `host:port` must be written with the `unix:` prefix.
    fn from(text: String) -> Self {
        if let Some(path) = text.strip_prefix("unix:") {
            return SocketAddr::Unix(PathBuf::from(path));
        }
        if let Some(addr) = text.strip_prefix("tcp:") {
            return SocketAddr::Tcp(addr.to_string());
        }
        if looks_like_host_port(&text) {
            SocketAddr::Tcp(text)
        } else {
            SocketAddr::Unix(PathBuf::from(text))
        }
    }
}

/// Returns `true` when `text` has the shape `host:port`: a non-empty host without `/`
/// and a port that parses as `u16`. Bracketed IPv6 such as `[::1]:80` qualifies.
fn looks_like_host_port(text: &str) -> bool {
    match text.rsplit_once(':') {
        Some((host, port)) => {
            !host.is_empty() && !host.contains('/') && port.parse::<u16>().is_ok()
        }
        None => false,
    }
}
impl SocketAddr {
pub fn unix<P: Into<PathBuf>>(path: P) -> Self {
SocketAddr::Unix(path.into())
}
pub fn tcp<S: Into<String>>(addr: S) -> Self {
SocketAddr::Tcp(addr.into())
}
pub fn is_unix(&self) -> bool {
matches!(self, SocketAddr::Unix(_))
}
pub fn is_tcp(&self) -> bool {
matches!(self, SocketAddr::Tcp(_))
}
pub fn as_unix_path(&self) -> Option<&PathBuf> {
match self {
SocketAddr::Unix(p) => Some(p),
_ => None,
}
}
pub fn as_tcp_addr(&self) -> Option<&str> {
match self {
SocketAddr::Tcp(a) => Some(a),
_ => None,
}
}
}
/// Human-readable form with a scheme prefix: `unix:<path>` or `tcp:<addr>`.
impl std::fmt::Display for SocketAddr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SocketAddr::Unix(path) => {
                f.write_str("unix:")?;
                // `Path::display` handles non-UTF-8 paths lossily.
                write!(f, "{}", path.display())
            }
            SocketAddr::Tcp(addr) => {
                f.write_str("tcp:")?;
                f.write_str(addr)
            }
        }
    }
}
/// Configuration for one proxy instance: its name, the addresses it listens on,
/// the backend address, a service name, and timeout/connection options.
///
/// `XinetConfig::validate` requires `name` and `service` to be non-empty and at
/// least one `listen` address to be present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XinetConfig {
/// Identifier of this proxy; must be non-empty.
pub name: String,
/// Addresses to accept connections on; at least one is required.
pub listen: Vec<SocketAddr>,
/// Backend address. NOTE(review): presumably where accepted connections are
/// forwarded; the forwarding logic is not in this file — confirm.
pub backend: SocketAddr,
/// Service name; must be non-empty. NOTE(review): how it is used is not visible here.
pub service: String,
/// Connect timeout in seconds; filled by `default_connect_timeout` when the field
/// is absent from serialized input.
#[serde(default = "default_connect_timeout")]
pub connect_timeout: u64,
/// Idle timeout in seconds; `0` when absent from serialized input.
/// NOTE(review): whether `0` means "no idle timeout" is decided by the consumer — confirm.
#[serde(default)]
pub idle_timeout: u64,
/// `false` when absent from serialized input. NOTE(review): presumably restricts the
/// proxy to one connection at a time; semantics live elsewhere — confirm.
#[serde(default)]
pub single_connection: bool,
}
/// Default for `XinetConfig::connect_timeout`, in seconds.
///
/// Used both as the serde field default and by the `XinetConfig` constructors.
fn default_connect_timeout() -> u64 {
    const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_CONNECT_TIMEOUT_SECS
}
impl XinetConfig {
pub fn new<S: Into<String>>(
name: S,
listen: SocketAddr,
backend: SocketAddr,
service: S,
) -> Self {
Self {
name: name.into(),
listen: vec![listen],
backend,
service: service.into(),
connect_timeout: default_connect_timeout(),
idle_timeout: 0,
single_connection: false,
}
}
pub fn new_multi<S: Into<String>>(
name: S,
listen: Vec<SocketAddr>,
backend: SocketAddr,
service: S,
) -> Self {
Self {
name: name.into(),
listen,
backend,
service: service.into(),
connect_timeout: default_connect_timeout(),
idle_timeout: 0,
single_connection: false,
}
}
pub fn add_listen(mut self, addr: SocketAddr) -> Self {
self.listen.push(addr);
self
}
pub fn with_connect_timeout(mut self, seconds: u64) -> Self {
self.connect_timeout = seconds;
self
}
pub fn with_idle_timeout(mut self, seconds: u64) -> Self {
self.idle_timeout = seconds;
self
}
pub fn with_single_connection(mut self, single: bool) -> Self {
self.single_connection = single;
self
}
pub fn validate(&self) -> Result<(), String> {
if self.name.is_empty() {
return Err("name is required".to_string());
}
if self.listen.is_empty() {
return Err("at least one listen address is required".to_string());
}
if self.service.is_empty() {
return Err("service name is required".to_string());
}
Ok(())
}
pub fn listen_addrs_string(&self) -> String {
self.listen
.iter()
.map(|a| a.to_string())
.collect::<Vec<_>>()
.join(", ")
}
}
/// Serializable status report for one proxy: identity, addresses as text,
/// connection and byte counters, and a running flag.
///
/// NOTE(review): the in-file tests fill `listen`/`backend` with the `SocketAddr`
/// `Display` form (e.g. `tcp:127.0.0.1:8080`); confirm producers do the same.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyStatus {
/// Proxy name. NOTE(review): presumably mirrors `XinetConfig::name` — confirm.
pub name: String,
/// Listen address(es) as text. NOTE(review): possibly `XinetConfig::listen_addrs_string()` — confirm.
pub listen: String,
/// Backend address as text.
pub backend: String,
/// Service name.
pub service: String,
/// Connection counter (presumably cumulative since start — confirm with the producer).
pub total_connections: u64,
/// Connections currently open, as reported by the producer.
pub active_connections: usize,
/// Byte counter in the client-to-backend direction (as named; units presumably bytes — confirm).
pub bytes_to_backend: u64,
/// Byte counter in the backend-to-client direction (as named; units presumably bytes — confirm).
pub bytes_from_backend: u64,
/// Whether the proxy reports itself as running.
pub running: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_socket_addr_unix() {
        let addr = SocketAddr::unix("/tmp/test.sock");
        assert!(addr.is_unix());
        assert!(!addr.is_tcp());
        assert_eq!(addr.as_unix_path(), Some(&PathBuf::from("/tmp/test.sock")));
        assert_eq!(format!("{}", addr), "unix:/tmp/test.sock");
    }

    #[test]
    fn test_socket_addr_tcp() {
        let addr = SocketAddr::tcp("127.0.0.1:8080");
        assert!(!addr.is_unix());
        assert!(addr.is_tcp());
        assert_eq!(addr.as_tcp_addr(), Some("127.0.0.1:8080"));
        assert_eq!(format!("{}", addr), "tcp:127.0.0.1:8080");
    }

    /// Serialization writes the bare path / address, with no scheme prefix.
    #[test]
    fn test_socket_addr_serialize_is_plain_string() {
        let unix = serde_json::to_string(&SocketAddr::unix("/tmp/test.sock")).unwrap();
        assert_eq!(unix, "\"/tmp/test.sock\"");
        let tcp = serde_json::to_string(&SocketAddr::tcp("127.0.0.1:8080")).unwrap();
        assert_eq!(tcp, "\"127.0.0.1:8080\"");
    }

    #[test]
    fn test_xinet_config() {
        let config = XinetConfig::new(
            "myproxy",
            SocketAddr::unix("/tmp/frontend.sock"),
            SocketAddr::tcp("127.0.0.1:5432"),
            "postgres",
        )
        .with_idle_timeout(300)
        .with_single_connection(true);
        assert_eq!(config.name, "myproxy");
        assert_eq!(config.service, "postgres");
        assert_eq!(config.listen.len(), 1);
        assert_eq!(config.connect_timeout, 30);
        assert_eq!(config.idle_timeout, 300);
        assert!(config.single_connection);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_xinet_config_multi_listen() {
        let config = XinetConfig::new(
            "myproxy",
            SocketAddr::unix("/tmp/frontend.sock"),
            SocketAddr::tcp("127.0.0.1:5432"),
            "postgres",
        )
        .add_listen(SocketAddr::tcp("127.0.0.1:5433"));
        assert_eq!(config.listen.len(), 2);
        assert_eq!(
            config.listen_addrs_string(),
            "unix:/tmp/frontend.sock, tcp:127.0.0.1:5433"
        );
        assert!(config.validate().is_ok());
    }

    /// Each required field produces its own error message.
    #[test]
    fn test_xinet_config_validate_errors() {
        let base = XinetConfig::new(
            "myproxy",
            SocketAddr::unix("/tmp/frontend.sock"),
            SocketAddr::tcp("127.0.0.1:5432"),
            "postgres",
        );

        let mut no_name = base.clone();
        no_name.name.clear();
        assert_eq!(no_name.validate(), Err("name is required".to_string()));

        let no_listen = XinetConfig::new_multi(
            "myproxy",
            Vec::new(),
            SocketAddr::tcp("127.0.0.1:5432"),
            "postgres",
        );
        assert_eq!(
            no_listen.validate(),
            Err("at least one listen address is required".to_string())
        );

        let mut no_service = base;
        no_service.service.clear();
        assert_eq!(
            no_service.validate(),
            Err("service name is required".to_string())
        );
    }

    #[test]
    fn test_proxy_status_serialize() {
        let status = ProxyStatus {
            name: "test".to_string(),
            listen: "tcp:127.0.0.1:8080".to_string(),
            backend: "tcp:127.0.0.1:5432".to_string(),
            service: "postgres".to_string(),
            total_connections: 100,
            active_connections: 5,
            bytes_to_backend: 1024,
            bytes_from_backend: 2048,
            running: true,
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("\"name\":\"test\""));

        // Every field must survive a JSON round-trip.
        let back: ProxyStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, status.name);
        assert_eq!(back.listen, status.listen);
        assert_eq!(back.backend, status.backend);
        assert_eq!(back.service, status.service);
        assert_eq!(back.total_connections, 100);
        assert_eq!(back.active_connections, 5);
        assert_eq!(back.bytes_to_backend, 1024);
        assert_eq!(back.bytes_from_backend, 2048);
        assert!(back.running);
    }
}