use crate::config::{CompiledEndpointRules, RouteConfig};
use crate::error::{ProxyError, Result};
use rustls::pki_types::pem::PemObject;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;
use zeroize::Zeroizing;
pub struct LoadedRoute {
pub upstream: String,
pub upstream_host_port: Option<String>,
pub endpoint_rules: CompiledEndpointRules,
pub tls_connector: Option<tokio_rustls::TlsConnector>,
}
impl std::fmt::Debug for LoadedRoute {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LoadedRoute")
.field("upstream", &self.upstream)
.field("upstream_host_port", &self.upstream_host_port)
.field("endpoint_rules", &self.endpoint_rules)
.field("has_custom_tls_ca", &self.tls_connector.is_some())
.finish()
}
}
#[derive(Debug)]
pub struct RouteStore {
routes: HashMap<String, LoadedRoute>,
}
impl RouteStore {
pub fn load(routes: &[RouteConfig]) -> Result<Self> {
let mut loaded = HashMap::new();
for route in routes {
let normalized_prefix = route.prefix.trim_matches('/').to_string();
debug!(
"Loading route '{}' -> {}",
normalized_prefix, route.upstream
);
let endpoint_rules = CompiledEndpointRules::compile(&route.endpoint_rules)
.map_err(|e| ProxyError::Config(format!("route '{}': {}", normalized_prefix, e)))?;
let tls_connector = if route.tls_ca.is_some()
|| route.tls_client_cert.is_some()
|| route.tls_client_key.is_some()
{
debug!(
"Building TLS connector for route '{}' (ca={}, client_cert={})",
normalized_prefix,
route.tls_ca.is_some(),
route.tls_client_cert.is_some(),
);
Some(build_tls_connector(
route.tls_ca.as_deref(),
route.tls_client_cert.as_deref(),
route.tls_client_key.as_deref(),
)?)
} else {
None
};
let upstream_host_port = extract_host_port(&route.upstream);
loaded.insert(
normalized_prefix,
LoadedRoute {
upstream: route.upstream.clone(),
upstream_host_port,
endpoint_rules,
tls_connector,
},
);
}
Ok(Self { routes: loaded })
}
#[must_use]
pub fn empty() -> Self {
Self {
routes: HashMap::new(),
}
}
#[must_use]
pub fn get(&self, prefix: &str) -> Option<&LoadedRoute> {
self.routes.get(prefix)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.routes.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.routes.len()
}
#[must_use]
pub fn is_route_upstream(&self, host_port: &str) -> bool {
let normalised = host_port.to_lowercase();
self.routes.values().any(|route| {
route
.upstream_host_port
.as_ref()
.is_some_and(|hp| *hp == normalised)
})
}
#[must_use]
pub fn route_upstream_hosts(&self) -> std::collections::HashSet<String> {
self.routes
.values()
.filter_map(|route| route.upstream_host_port.clone())
.collect()
}
}
fn extract_host_port(url: &str) -> Option<String> {
let parsed = url::Url::parse(url).ok()?;
let host = parsed.host_str()?;
let default_port = match parsed.scheme() {
"https" => 443,
"http" => 80,
_ => return None,
};
let port = parsed.port().unwrap_or(default_port);
Some(format!("{}:{}", host.to_lowercase(), port))
}
fn read_pem_file(path: &std::path::Path, label: &str) -> Result<Zeroizing<Vec<u8>>> {
std::fs::read(path)
.map(Zeroizing::new)
.map_err(|e| match e.kind() {
std::io::ErrorKind::NotFound => {
ProxyError::Config(format!("{} file not found: '{}'", label, path.display()))
}
std::io::ErrorKind::PermissionDenied => ProxyError::Config(format!(
"{} permission denied: '{}' (check that nono can read this file)",
label,
path.display()
)),
_ => ProxyError::Config(format!(
"failed to read {} '{}': {}",
label,
path.display(),
e
)),
})
}
fn build_tls_connector(
ca_path: Option<&str>,
client_cert_path: Option<&str>,
client_key_path: Option<&str>,
) -> Result<tokio_rustls::TlsConnector> {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if let Some(ca_path) = ca_path {
let ca_path = std::path::Path::new(ca_path);
let ca_pem = read_pem_file(ca_path, "CA certificate")?;
let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_slice_iter(ca_pem.as_ref())
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
ProxyError::Config(format!(
"failed to parse CA certificate '{}': {}",
ca_path.display(),
e
))
})?;
if certs.is_empty() {
return Err(ProxyError::Config(format!(
"CA certificate file '{}' contains no valid PEM certificates",
ca_path.display()
)));
}
for cert in certs {
root_store.add(cert).map_err(|e| {
ProxyError::Config(format!(
"invalid CA certificate in '{}': {}",
ca_path.display(),
e
))
})?;
}
}
let builder = rustls::ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_safe_default_protocol_versions()
.map_err(|e| ProxyError::Config(format!("TLS config error: {}", e)))?
.with_root_certificates(root_store);
let tls_config = match (client_cert_path, client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert_path = std::path::Path::new(cert_path);
let key_path = std::path::Path::new(key_path);
let cert_pem = read_pem_file(cert_path, "client certificate")?;
let key_pem = read_pem_file(key_path, "client key")?;
let cert_chain: Vec<rustls::pki_types::CertificateDer> =
rustls::pki_types::CertificateDer::pem_slice_iter(cert_pem.as_ref())
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
ProxyError::Config(format!(
"failed to parse client certificate '{}': {}",
cert_path.display(),
e
))
})?;
if cert_chain.is_empty() {
return Err(ProxyError::Config(format!(
"client certificate file '{}' contains no valid PEM certificates",
cert_path.display()
)));
}
let private_key = rustls::pki_types::PrivateKeyDer::from_pem_slice(key_pem.as_ref())
.map_err(|e| match e {
rustls::pki_types::pem::Error::NoItemsFound => ProxyError::Config(format!(
"client key file '{}' contains no valid PEM private key",
key_path.display()
)),
_ => ProxyError::Config(format!(
"failed to parse client key '{}': {}",
key_path.display(),
e
)),
})?;
builder
.with_client_auth_cert(cert_chain, private_key)
.map_err(|e| {
ProxyError::Config(format!(
"invalid client certificate/key pair ('{}', '{}'): {}",
cert_path.display(),
key_path.display(),
e
))
})?
}
(Some(_), None) => {
return Err(ProxyError::Config(
"tls_client_cert is set but tls_client_key is missing".to_string(),
));
}
(None, Some(_)) => {
return Err(ProxyError::Config(
"tls_client_key is set but tls_client_cert is missing".to_string(),
));
}
(None, None) => builder.with_no_client_auth(),
};
let mut tls_config = tls_config;
if client_cert_path.is_some() {
tls_config.resumption = rustls::client::Resumption::disabled();
}
Ok(tokio_rustls::TlsConnector::from(Arc::new(tls_config)))
}
#[cfg(test)]
fn build_tls_connector_with_ca(ca_path: &str) -> Result<tokio_rustls::TlsConnector> {
build_tls_connector(Some(ca_path), None, None)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use crate::config::EndpointRule;
#[test]
fn test_empty_route_store() {
let store = RouteStore::empty();
assert!(store.is_empty());
assert_eq!(store.len(), 0);
assert!(store.get("openai").is_none());
}
#[test]
fn test_load_routes_without_credentials() {
let routes = vec![RouteConfig {
prefix: "/openai".to_string(),
upstream: "https://api.openai.com".to_string(),
credential_key: None,
inject_mode: Default::default(),
inject_header: "Authorization".to_string(),
credential_format: "Bearer {}".to_string(),
path_pattern: None,
path_replacement: None,
query_param_name: None,
proxy: None,
env_var: None,
endpoint_rules: vec![
EndpointRule {
method: "POST".to_string(),
path: "/v1/chat/completions".to_string(),
},
EndpointRule {
method: "GET".to_string(),
path: "/v1/models".to_string(),
},
],
tls_ca: None,
tls_client_cert: None,
tls_client_key: None,
oauth2: None,
}];
let store = RouteStore::load(&routes).unwrap();
assert_eq!(store.len(), 1);
let route = store.get("openai").unwrap();
assert_eq!(route.upstream, "https://api.openai.com");
assert!(route
.endpoint_rules
.is_allowed("POST", "/v1/chat/completions"));
assert!(route.endpoint_rules.is_allowed("GET", "/v1/models"));
assert!(!route
.endpoint_rules
.is_allowed("DELETE", "/v1/files/file-123"));
}
#[test]
fn test_load_routes_normalises_prefix() {
let routes = vec![RouteConfig {
prefix: "/anthropic/".to_string(),
upstream: "https://api.anthropic.com".to_string(),
credential_key: None,
inject_mode: Default::default(),
inject_header: "Authorization".to_string(),
credential_format: "Bearer {}".to_string(),
path_pattern: None,
path_replacement: None,
query_param_name: None,
proxy: None,
env_var: None,
endpoint_rules: vec![],
tls_ca: None,
tls_client_cert: None,
tls_client_key: None,
oauth2: None,
}];
let store = RouteStore::load(&routes).unwrap();
assert!(store.get("anthropic").is_some());
assert!(store.get("/anthropic/").is_none());
}
#[test]
fn test_is_route_upstream() {
let routes = vec![RouteConfig {
prefix: "openai".to_string(),
upstream: "https://api.openai.com".to_string(),
credential_key: None,
inject_mode: Default::default(),
inject_header: "Authorization".to_string(),
credential_format: "Bearer {}".to_string(),
path_pattern: None,
path_replacement: None,
query_param_name: None,
proxy: None,
env_var: None,
endpoint_rules: vec![],
tls_ca: None,
tls_client_cert: None,
tls_client_key: None,
oauth2: None,
}];
let store = RouteStore::load(&routes).unwrap();
assert!(store.is_route_upstream("api.openai.com:443"));
assert!(!store.is_route_upstream("github.com:443"));
}
#[test]
fn test_route_upstream_hosts() {
let routes = vec![
RouteConfig {
prefix: "openai".to_string(),
upstream: "https://api.openai.com".to_string(),
credential_key: None,
inject_mode: Default::default(),
inject_header: "Authorization".to_string(),
credential_format: "Bearer {}".to_string(),
path_pattern: None,
path_replacement: None,
query_param_name: None,
proxy: None,
env_var: None,
endpoint_rules: vec![],
tls_ca: None,
tls_client_cert: None,
tls_client_key: None,
oauth2: None,
},
RouteConfig {
prefix: "anthropic".to_string(),
upstream: "https://api.anthropic.com".to_string(),
credential_key: None,
inject_mode: Default::default(),
inject_header: "Authorization".to_string(),
credential_format: "Bearer {}".to_string(),
path_pattern: None,
path_replacement: None,
query_param_name: None,
proxy: None,
env_var: None,
endpoint_rules: vec![],
tls_ca: None,
tls_client_cert: None,
tls_client_key: None,
oauth2: None,
},
];
let store = RouteStore::load(&routes).unwrap();
let hosts = store.route_upstream_hosts();
assert!(hosts.contains("api.openai.com:443"));
assert!(hosts.contains("api.anthropic.com:443"));
assert_eq!(hosts.len(), 2);
}
#[test]
fn test_extract_host_port_https() {
assert_eq!(
extract_host_port("https://api.openai.com"),
Some("api.openai.com:443".to_string())
);
}
#[test]
fn test_extract_host_port_with_port() {
assert_eq!(
extract_host_port("https://api.example.com:8443"),
Some("api.example.com:8443".to_string())
);
}
#[test]
fn test_extract_host_port_http() {
assert_eq!(
extract_host_port("http://internal-service"),
Some("internal-service:80".to_string())
);
}
#[test]
fn test_extract_host_port_normalises_case() {
assert_eq!(
extract_host_port("https://API.Example.COM"),
Some("api.example.com:443".to_string())
);
}
#[test]
fn test_loaded_route_debug() {
let route = LoadedRoute {
upstream: "https://api.openai.com".to_string(),
upstream_host_port: Some("api.openai.com:443".to_string()),
endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
tls_connector: None,
};
let debug_output = format!("{:?}", route);
assert!(debug_output.contains("api.openai.com"));
assert!(debug_output.contains("has_custom_tls_ca"));
}
const TEST_CA_PEM: &str = "\
-----BEGIN CERTIFICATE-----
MIIBnjCCAUWgAwIBAgIUT0bpOJJvHdOdZt+gW1stR8VBgXowCgYIKoZIzj0EAwIw
FzEVMBMGA1UEAwwMbm9uby10ZXN0LWNhMCAXDTI1MDEwMTAwMDAwMFoYDzIxMjQx
MjA3MDAwMDAwWjAXMRUwEwYDVQQDDAxub25vLXRlc3QtY2EwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAAR8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAo1MwUTAdBgNVHQ4EFgQUAAAAAAAAAAAAAAAAAAAAAAAA
AAAAMB8GA1UdIwQYMBaAFAAAAAAAAAAAAAAAAAAAAAAAAAAAADAPBgNVHRMBAf8E
BTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-----END CERTIFICATE-----";
#[test]
fn test_build_tls_connector_with_valid_ca() {
let dir = tempfile::tempdir().unwrap();
let ca_path = dir.path().join("ca.pem");
std::fs::write(&ca_path, TEST_CA_PEM).unwrap();
let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
match result {
Ok(connector) => {
drop(connector);
}
Err(ProxyError::Config(msg)) => {
assert!(
msg.contains("invalid CA certificate") || msg.contains("CA certificate"),
"unexpected error: {}",
msg
);
}
Err(e) => panic!("unexpected error type: {}", e),
}
}
#[test]
fn test_build_tls_connector_missing_file() {
let result = build_tls_connector_with_ca("/nonexistent/path/ca.pem");
let err = result
.err()
.expect("should fail for missing file")
.to_string();
assert!(
err.contains("CA certificate file not found"),
"unexpected error: {}",
err
);
}
#[test]
fn test_build_tls_connector_empty_pem() {
let dir = tempfile::tempdir().unwrap();
let ca_path = dir.path().join("empty.pem");
std::fs::write(&ca_path, "not a certificate\n").unwrap();
let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
let err = result
.err()
.expect("should fail for invalid PEM")
.to_string();
assert!(
err.contains("no valid PEM certificates"),
"unexpected error: {}",
err
);
}
const TEST_CLIENT_CERT_PEM: &str = "\
-----BEGIN CERTIFICATE-----
MIIBijCCATGgAwIBAgIUEoEb+0z+4CTRCzN98MqeTEXgdO8wCgYIKoZIzj0EAwIw
GzEZMBcGA1UEAwwQbm9uby10ZXN0LWNsaWVudDAeFw0yNjA0MTAwMDIwNTdaFw0z
NjA0MDcwMDIwNTdaMBsxGTAXBgNVBAMMEG5vbm8tdGVzdC1jbGllbnQwWTATBgcq
hkjOPQIBBggqhkjOPQMBBwNCAASt6g2Zt0STlgF+wZ64JzdDRlpPeNr1h56ZLEEq
HfVWFhJWIKRSabtxYPV/VJyMv+lo3L0QwSKsouHs3dtF1zVQo1MwUTAdBgNVHQ4E
FgQUTiHidg8uqgrJ1qlaVvR+XSebAlEwHwYDVR0jBBgwFoAUTiHidg8uqgrJ1qla
VvR+XSebAlEwDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiA9PwBU
f832cQkGS9cyYaU7Ij5U8Rcy/g4J7Ckf2nKX3gIgG0aarAFcIzAi5VpxbCwEScnr
m0lHTyp6E7ut7llwMBY=
-----END CERTIFICATE-----";
const TEST_CLIENT_KEY_PEM: &str = "\
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgskOkyJkTwlMZkm/L
eEleLY6bARaHFnqauYJqxNoJWvihRANCAASt6g2Zt0STlgF+wZ64JzdDRlpPeNr1
h56ZLEEqHfVWFhJWIKRSabtxYPV/VJyMv+lo3L0QwSKsouHs3dtF1zVQ
-----END PRIVATE KEY-----";
#[test]
fn test_build_tls_connector_cert_without_key_errors() {
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.crt");
std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
let result = build_tls_connector(None, Some(cert_path.to_str().unwrap()), None);
let err = result
.err()
.expect("should fail with half-pair")
.to_string();
assert!(
err.contains("tls_client_cert is set but tls_client_key is missing"),
"unexpected error: {}",
err
);
}
#[test]
fn test_build_tls_connector_key_without_cert_errors() {
let dir = tempfile::tempdir().unwrap();
let key_path = dir.path().join("client.key");
std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
let result = build_tls_connector(None, None, Some(key_path.to_str().unwrap()));
let err = result
.err()
.expect("should fail with half-pair")
.to_string();
assert!(
err.contains("tls_client_key is set but tls_client_cert is missing"),
"unexpected error: {}",
err
);
}
#[test]
fn test_build_tls_connector_missing_client_cert_file() {
let dir = tempfile::tempdir().unwrap();
let key_path = dir.path().join("client.key");
std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
let result = build_tls_connector(
None,
Some("/nonexistent/client.crt"),
Some(key_path.to_str().unwrap()),
);
let err = result.err().expect("should fail").to_string();
assert!(
err.contains("client certificate file not found"),
"unexpected error: {}",
err
);
}
#[test]
fn test_build_tls_connector_missing_client_key_file() {
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.crt");
std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
let result = build_tls_connector(
None,
Some(cert_path.to_str().unwrap()),
Some("/nonexistent/client.key"),
);
let err = result.err().expect("should fail").to_string();
assert!(
err.contains("client key file not found"),
"unexpected error: {}",
err
);
}
#[test]
#[cfg(unix)]
fn test_build_tls_connector_permission_denied() {
use std::os::unix::fs::PermissionsExt;
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.crt");
std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
std::fs::set_permissions(&cert_path, std::fs::Permissions::from_mode(0o000)).unwrap();
if std::fs::read(&cert_path).is_ok() {
return;
}
let result = build_tls_connector(
None,
Some(cert_path.to_str().unwrap()),
Some("/nonexistent/key"),
);
let err = result.err().expect("should fail").to_string();
assert!(
err.contains("permission denied"),
"expected permission denied error, got: {}",
err
);
}
#[test]
fn test_build_tls_connector_empty_client_cert_pem() {
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.crt");
let key_path = dir.path().join("client.key");
std::fs::write(&cert_path, "not a certificate\n").unwrap();
std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
let result = build_tls_connector(
None,
Some(cert_path.to_str().unwrap()),
Some(key_path.to_str().unwrap()),
);
let err = result.err().expect("should fail").to_string();
assert!(
err.contains("no valid PEM certificates"),
"unexpected error: {}",
err
);
}
#[test]
fn test_build_tls_connector_empty_client_key_pem() {
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.crt");
let key_path = dir.path().join("client.key");
std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
std::fs::write(&key_path, "not a key\n").unwrap();
let result = build_tls_connector(
None,
Some(cert_path.to_str().unwrap()),
Some(key_path.to_str().unwrap()),
);
let err = result
.err()
.expect("should fail with invalid PEM")
.to_string();
assert!(err.contains("client key"), "unexpected error: {}", err);
}
#[test]
fn test_route_store_loads_mtls_route() {
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.crt");
let key_path = dir.path().join("client.key");
std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
let routes = vec![RouteConfig {
prefix: "k8s".to_string(),
upstream: "https://192.168.64.1:6443".to_string(),
credential_key: None,
inject_mode: Default::default(),
inject_header: "Authorization".to_string(),
credential_format: "Bearer {}".to_string(),
path_pattern: None,
path_replacement: None,
query_param_name: None,
proxy: None,
env_var: None,
endpoint_rules: vec![],
tls_ca: None,
tls_client_cert: Some(cert_path.to_str().unwrap().to_string()),
tls_client_key: Some(key_path.to_str().unwrap().to_string()),
oauth2: None,
}];
let store = RouteStore::load(&routes).expect("should load mTLS route");
let route = store.get("k8s").unwrap();
assert!(
route.tls_connector.is_some(),
"connector must be built when tls_client_cert/key are set"
);
}
}