use crate::charter::Uplink;
use chrono_machines::{BackoffStrategy, ExponentialBackoff};
use rama::http::{client::EasyHttpWebClient, service::client::HttpClientExt};
use rand::rng;
use std::net::ToSocketAddrs;
use std::time::Duration;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, error, warn};
/// Errors produced while validating the configured uplinks.
#[derive(Error, Debug)]
pub enum UplinkError {
    /// At least one uplink check failed. The payload is the pre-formatted
    /// list of failures (name, URL and reason for each one).
    #[error("Uplink check failed:\n{0}")]
    ChecksFailed(String),

    /// The uplink URL could not be parsed (e.g. missing `://`, a bad port,
    /// a `tcp://` URL without a port, or an empty host).
    #[error("Invalid URL '{url}' for uplink '{name}': {reason}")]
    InvalidUrl {
        name: String,
        url: String,
        reason: String,
    },

    /// The URL scheme is not one of the recognised uplink kinds.
    #[error(
        "Unsupported scheme '{scheme}' for uplink '{name}'. Supported: postgres, postgresql, mysql, redis, memgraph, neo4j, http, https, tcp"
    )]
    UnsupportedScheme {
        name: String,
        scheme: String,
    },
}
/// Outcome of probing a single uplink, as produced by `check_uplink` and
/// aggregated by `verify_uplinks`.
#[derive(Debug)]
struct UplinkCheckResult {
    /// Uplink name as configured; used in failure logs and the aggregated error.
    name: String,
    /// Uplink URL as configured.
    url: String,
    /// `true` when the probe succeeded.
    success: bool,
    /// Human-readable failure reason; `None` when `success` is `true`.
    error: Option<String>,
}
/// Parse an uplink URL into `(host, port, is_http)`.
///
/// Only the authority component is inspected:
/// * anything after the first `/` (the path) is ignored;
/// * any `user:password@` prefix is ignored;
/// * a query or fragment that directly follows the host/port
///   (`redis://host:6379?db=0`) is ignored.
///
/// Bracketed IPv6 literals (`[::1]`, `[::1]:5432`) are supported. They are
/// returned *with* their brackets so that `"{host}:{port}"` remains a valid
/// socket-address string. When no port is given, the scheme's well-known
/// default is used; `tcp://` has none and requires an explicit port.
///
/// `is_http` is `true` for `http`/`https`. Callers probe those with an HTTP
/// request against the full URL rather than a bare TCP connect.
///
/// # Errors
///
/// * [`UplinkError::UnsupportedScheme`] for any scheme other than postgres,
///   postgresql, mysql, redis, memgraph, neo4j, http, https or tcp
///   (matched case-insensitively).
/// * [`UplinkError::InvalidUrl`] when:
///   - the `://` separator is missing;
///   - the port is not a valid `u16`;
///   - a `tcp://` URL has no port;
///   - an IPv6 literal is malformed;
///   - the host is empty.
fn parse_uplink_url(name: &str, url: &str) -> Result<(String, u16, bool), UplinkError> {
    let invalid = |reason: String| UplinkError::InvalidUrl {
        name: name.to_string(),
        url: url.to_string(),
        reason,
    };

    let (scheme, rest) = url
        .split_once("://")
        .ok_or_else(|| invalid("missing '://' scheme separator".to_string()))?;

    let (default_port, is_http) = match scheme.to_lowercase().as_str() {
        "postgres" | "postgresql" => (Some(5432), false),
        "mysql" => (Some(3306), false),
        "redis" => (Some(6379), false),
        "memgraph" | "neo4j" => (Some(7687), false),
        "http" => (Some(80), true),
        "https" => (Some(443), true),
        "tcp" => (None, false),
        _ => {
            return Err(UplinkError::UnsupportedScheme {
                name: name.to_string(),
                scheme: scheme.to_string(),
            });
        }
    };

    // Authority = text before the first '/', minus any `user:password@`
    // prefix (the last '@' delimits it).
    let authority = rest.split('/').next().unwrap_or(rest);
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    // A query or fragment may follow the authority directly
    // (`redis://host:6379?db=0`). Cut it off so it is not read as part of the
    // port or host. This runs after userinfo removal, so a literal '?' or '#'
    // inside a password does not truncate the host.
    let host_port = host_port
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or(host_port);

    let (host, port_text) = if host_port.starts_with('[') {
        // Bracketed IPv6 literal: the colons inside the brackets belong to
        // the address, so only a ':' after ']' introduces a port.
        let close = host_port
            .find(']')
            .ok_or_else(|| invalid("unterminated '[' in IPv6 host".to_string()))?;
        let (host, tail) = host_port.split_at(close + 1);
        let port_text = if tail.is_empty() {
            None
        } else {
            Some(
                tail.strip_prefix(':')
                    .ok_or_else(|| invalid(format!("unexpected '{tail}' after IPv6 host")))?,
            )
        };
        (host, port_text)
    } else {
        match host_port.rsplit_once(':') {
            Some((host, port_text)) => (host, Some(port_text)),
            None => (host_port, None),
        }
    };

    let port = match port_text {
        Some(p) => p
            .parse::<u16>()
            .map_err(|_| invalid(format!("invalid port '{p}'")))?,
        None => default_port
            .ok_or_else(|| invalid("port required for tcp:// scheme".to_string()))?,
    };

    // Applies to HTTP as well: `http://:8080` cannot be probed meaningfully.
    if host.is_empty() || host == "[]" {
        return Err(invalid("empty host".to_string()));
    }

    Ok((host.to_string(), port, is_http))
}
/// Parse a human-friendly timeout such as `"500ms"`, `"5s"`, `"1m"` or `"2h"`.
///
/// Whitespace around the value, and between the number and the unit, is
/// ignored. Anything unparseable falls back to 5 seconds: an unknown or
/// missing unit, or a non-integer or negative amount. Very large values
/// saturate instead of overflowing.
fn parse_timeout(s: &str) -> Duration {
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
    // Suffix -> milliseconds per unit. The first matching suffix wins. "ms"
    // must therefore come before "s": "500ms" also ends in "s", and stripping
    // only "s" would leave the unparseable "500m" and yield the default.
    const UNITS: [(&str, u64); 4] = [("ms", 1), ("s", 1_000), ("m", 60_000), ("h", 3_600_000)];

    let s = s.trim();
    for &(suffix, millis_per_unit) in &UNITS {
        if let Some(amount) = s.strip_suffix(suffix) {
            return amount
                .trim()
                .parse::<u64>()
                .map(|n| Duration::from_millis(n.saturating_mul(millis_per_unit)))
                .unwrap_or(DEFAULT_TIMEOUT);
        }
    }
    DEFAULT_TIMEOUT
}
/// Probe one uplink and package the outcome as an [`UplinkCheckResult`].
///
/// The timeout comes from `uplink.timeout` via `parse_timeout`. A URL that
/// fails to parse is reported as a failed result rather than propagated.
/// `http`/`https` uplinks are probed with an HTTP request to the full URL;
/// every other scheme gets a TCP connect to the parsed `host:port`.
async fn check_uplink(uplink: &Uplink) -> UplinkCheckResult {
    let limit = parse_timeout(&uplink.timeout);

    let outcome: Result<(), String> = match parse_uplink_url(&uplink.name, &uplink.url) {
        Err(parse_err) => Err(parse_err.to_string()),
        Ok((_, _, true)) => check_http(&uplink.url, limit).await,
        // The TCP probe's success message is not needed here.
        Ok((host, port, false)) => check_tcp(&host, port, limit).await.map(|_| ()),
    };

    UplinkCheckResult {
        name: uplink.name.clone(),
        url: uplink.url.clone(),
        success: outcome.is_ok(),
        error: outcome.err(),
    }
}
/// Probe a raw TCP endpoint at `host:port`, retrying with exponential backoff.
///
/// Name resolution:
/// * runs on Tokio's blocking pool, because std's resolver is synchronous;
/// * is bounded by `timeout_duration`.
///
/// On every attempt each resolved address is tried in order, so a host that
/// resolves to both IPv6 and IPv4 (e.g. `localhost`) succeeds if either
/// family accepts the connection. Each attempt is bounded by
/// `timeout_duration`. The number of attempts is governed by the
/// `ExponentialBackoff` policy below (`max_attempts(3)`).
///
/// Returns `Ok("Connected to host:port")` on success. Otherwise returns `Err`
/// describing the DNS failure or the last connection failure.
async fn check_tcp(host: &str, port: u16, timeout_duration: Duration) -> Result<String, String> {
    let addr = format!("{host}:{port}");

    // `ToSocketAddrs` blocks the calling thread. Running it inline would stall
    // the async executor and serialize the concurrent checks started by
    // `verify_uplinks`.
    let lookup_target = addr.clone();
    let lookup = tokio::task::spawn_blocking(move || {
        lookup_target
            .to_socket_addrs()
            .map(|resolved| resolved.collect::<Vec<_>>())
    });
    let socket_addrs = match timeout(timeout_duration, lookup).await {
        Ok(Ok(Ok(resolved))) => resolved,
        Ok(Ok(Err(e))) => return Err(format!("DNS resolution failed: {e}")),
        Ok(Err(join_err)) => return Err(format!("DNS resolution failed: {join_err}")),
        Err(_) => {
            return Err(format!(
                "DNS resolution timed out after {timeout_duration:?}"
            ));
        }
    };
    if socket_addrs.is_empty() {
        return Err("DNS resolution returned no addresses".to_string());
    }

    debug!(addr = %addr, "Checking TCP uplink");
    let backoff = ExponentialBackoff::new()
        .base_delay_ms(100)
        .max_delay_ms(2000)
        .max_attempts(3);
    let mut attempt = 0u8;
    loop {
        attempt += 1;
        // Passing the whole address list makes Tokio try each address in turn.
        let connect = TcpStream::connect(&socket_addrs[..]);
        let last_error = match timeout(timeout_duration, connect).await {
            Ok(Ok(_stream)) => {
                debug!(addr = %addr, attempt = attempt, "TCP uplink OK");
                return Ok(format!("Connected to {addr}"));
            }
            Ok(Err(e)) => format!("Connection failed: {e}"),
            // Duration's Debug form prints "5s" / "500ms"; `as_secs()` would
            // report sub-second timeouts as "0s".
            Err(_) => format!("Connection timed out after {timeout_duration:?}"),
        };
        // `ThreadRng` is not `Send`. Creating it inside this statement drops
        // it before the `.await` below, so the returned future stays `Send`.
        let next_delay = backoff.delay(attempt, &mut rng());
        match next_delay {
            Some(delay_ms) => {
                warn!(addr = %addr, attempt = attempt, delay_ms = delay_ms, "TCP uplink failed, retrying");
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }
            None => {
                error!(addr = %addr, attempts = attempt, "TCP uplink failed after all retries");
                return Err(last_error);
            }
        }
    }
}
/// Probe an HTTP(S) endpoint with a `GET` request, retrying with exponential backoff.
///
/// A 2xx response counts as success. Each request is bounded by
/// `timeout_duration`. The number of attempts is governed by the
/// `ExponentialBackoff` policy below (`max_attempts(3)`).
///
/// Returns `Err` describing the last failure: the non-2xx status line, the
/// client error, or the timeout.
async fn check_http(url: &str, timeout_duration: Duration) -> Result<(), String> {
    debug!(url = %url, "Checking HTTP uplink");
    let client = EasyHttpWebClient::default();
    let backoff = ExponentialBackoff::new()
        .base_delay_ms(100)
        .max_delay_ms(2000)
        .max_attempts(3);
    let mut attempt = 0u8;
    loop {
        attempt += 1;
        let last_error = match timeout(timeout_duration, client.get(url).send()).await {
            Ok(Ok(response)) => {
                let status = response.status();
                if status.is_success() {
                    debug!(url = %url, status = %status, attempt = attempt, "HTTP uplink OK");
                    return Ok(());
                }
                // Non-standard codes have no reason phrase. Omit it rather
                // than leaving a dangling space ("HTTP 599 ").
                match status.canonical_reason() {
                    Some(reason) => format!("HTTP {} {reason}", status.as_u16()),
                    None => format!("HTTP {}", status.as_u16()),
                }
            }
            Ok(Err(e)) => format!("Request failed: {e}"),
            // Duration's Debug form prints "5s" / "500ms"; `as_secs()` would
            // report sub-second timeouts as "0s".
            Err(_) => format!("Connection timed out after {timeout_duration:?}"),
        };
        // `ThreadRng` is not `Send`. Creating it inside this statement drops
        // it before the `.await` below, so the returned future stays `Send`.
        let next_delay = backoff.delay(attempt, &mut rng());
        match next_delay {
            Some(delay_ms) => {
                warn!(url = %url, attempt = attempt, delay_ms = delay_ms, "HTTP uplink failed, retrying");
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }
            None => {
                error!(url = %url, attempts = attempt, "HTTP uplink failed after all retries");
                return Err(last_error);
            }
        }
    }
}
/// Deduplication key for an uplink, or `None` if its URL does not parse.
///
/// HTTP(S) uplinks are keyed by their full URL, so different paths on one
/// host stay distinct checks. All other uplinks are keyed by `host:port`,
/// so e.g. two databases on the same server are probed only once.
fn uplink_key(uplink: &Uplink) -> Option<String> {
    match parse_uplink_url(&uplink.name, &uplink.url) {
        Ok((_, _, true)) => Some(uplink.url.clone()),
        Ok((host, port, false)) => Some(format!("{host}:{port}")),
        Err(_) => None,
    }
}
pub async fn verify_uplinks(uplinks: &[Uplink]) -> Result<(), UplinkError> {
if uplinks.is_empty() {
return Ok(());
}
let mut seen_endpoints = std::collections::HashSet::new();
let unique_uplinks: Vec<_> = uplinks
.iter()
.filter(|u| {
uplink_key(u)
.map(|key| seen_endpoints.insert(key))
.unwrap_or(true) })
.collect();
debug!(
total = uplinks.len(),
unique = unique_uplinks.len(),
"Checking uplinks (deduped by host:port)"
);
let checks: Vec<_> = unique_uplinks.iter().map(|u| check_uplink(u)).collect();
let results = futures::future::join_all(checks).await;
let failures: Vec<_> = results.iter().filter(|r| !r.success).collect();
if failures.is_empty() {
Ok(())
} else {
let mut error_msg = String::new();
for failure in failures {
error!(
uplink = %failure.name,
url = %failure.url,
error = %failure.error.as_deref().unwrap_or("unknown"),
"Uplink check failed"
);
error_msg.push_str(&format!(
"\n {} ({})\n {}",
failure.name,
failure.url,
failure.error.as_deref().unwrap_or("unknown error")
));
}
Err(UplinkError::ChecksFailed(error_msg))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an uplink with the `5s` timeout used throughout these tests.
    fn uplink(name: &str, url: &str) -> Uplink {
        Uplink {
            url: url.to_string(),
            name: name.to_string(),
            timeout: "5s".to_string(),
        }
    }

    #[test]
    fn test_parse_postgres_url() {
        let (host, port, is_http) =
            parse_uplink_url("db", "postgres://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!((host.as_str(), port), ("localhost", 5432));
        assert!(!is_http);
    }

    #[test]
    fn test_parse_postgres_default_port() {
        let (host, port, is_http) = parse_uplink_url("db", "postgres://localhost/mydb").unwrap();
        assert_eq!((host.as_str(), port), ("localhost", 5432));
        assert!(!is_http);
    }

    #[test]
    fn test_parse_redis_url() {
        let (host, port, is_http) =
            parse_uplink_url("cache", "redis://192.168.1.100:6379").unwrap();
        assert_eq!((host.as_str(), port), ("192.168.1.100", 6379));
        assert!(!is_http);
    }

    #[test]
    fn test_parse_mysql_url() {
        let (host, port, _) = parse_uplink_url("db", "mysql://root@db.example.com/app").unwrap();
        assert_eq!((host.as_str(), port), ("db.example.com", 3306));
    }

    #[test]
    fn test_parse_neo4j_url() {
        let (host, port, _) = parse_uplink_url("graph", "neo4j://localhost:7687").unwrap();
        assert_eq!((host.as_str(), port), ("localhost", 7687));
    }

    #[test]
    fn test_parse_memgraph_url() {
        let (host, port, _) = parse_uplink_url("graph", "memgraph://localhost").unwrap();
        assert_eq!((host.as_str(), port), ("localhost", 7687));
    }

    #[test]
    fn test_parse_http_url() {
        let (host, port, is_http) =
            parse_uplink_url("api", "http://api.example.com/health").unwrap();
        assert_eq!((host.as_str(), port), ("api.example.com", 80));
        assert!(is_http);
    }

    #[test]
    fn test_parse_https_url() {
        let (host, port, is_http) =
            parse_uplink_url("api", "https://api.example.com:8443/health").unwrap();
        assert_eq!((host.as_str(), port), ("api.example.com", 8443));
        assert!(is_http);
    }

    #[test]
    fn test_parse_tcp_requires_port() {
        assert!(parse_uplink_url("raw", "tcp://localhost").is_err());
    }

    #[test]
    fn test_parse_tcp_with_port() {
        let (host, port, _) = parse_uplink_url("raw", "tcp://localhost:9999").unwrap();
        assert_eq!((host.as_str(), port), ("localhost", 9999));
    }

    #[test]
    fn test_parse_unsupported_scheme() {
        let outcome = parse_uplink_url("foo", "mongodb://localhost:27017");
        assert!(matches!(outcome, Err(UplinkError::UnsupportedScheme { .. })));
    }

    #[test]
    fn test_parse_timeout() {
        let cases = [("5s", 5), ("10s", 10), ("1m", 60), ("invalid", 5)];
        for (input, expected_secs) in cases {
            assert_eq!(
                parse_timeout(input),
                Duration::from_secs(expected_secs),
                "input: {input}"
            );
        }
    }

    #[test]
    fn test_parse_postgres_custom_port() {
        let (host, port, _) = parse_uplink_url("db", "postgres://localhost:5434/mydb").unwrap();
        assert_eq!((host.as_str(), port), ("localhost", 5434));
    }

    #[test]
    fn test_dedup_uplinks_by_host_port_and_http_full_url() {
        let uplinks = [
            uplink("primary", "postgres://localhost:5432/mydb"),
            uplink("replica", "postgres://localhost:5432/mypotato"),
            uplink("cache", "redis://localhost:6379"),
            uplink("api-health", "http://api.example.com/health"),
            uplink("api-ready", "http://api.example.com/ready"),
            uplink("api-health-dup", "http://api.example.com/health"),
        ];

        let mut seen_endpoints = std::collections::HashSet::new();
        let kept: Vec<&str> = uplinks
            .iter()
            .filter(|u| match uplink_key(u) {
                Some(key) => seen_endpoints.insert(key),
                None => true,
            })
            .map(|u| u.name.as_str())
            .collect();

        assert_eq!(kept, ["primary", "cache", "api-health", "api-ready"]);
    }
}