1use crate::config::{CompiledEndpointRules, RouteConfig};
14use crate::error::{ProxyError, Result};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::debug;
18use zeroize::Zeroizing;
19
20pub struct LoadedRoute {
26 pub upstream: String,
28
29 pub upstream_host_port: Option<String>,
33
34 pub endpoint_rules: CompiledEndpointRules,
38
39 pub tls_connector: Option<tokio_rustls::TlsConnector>,
43}
44
45impl std::fmt::Debug for LoadedRoute {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("LoadedRoute")
48 .field("upstream", &self.upstream)
49 .field("upstream_host_port", &self.upstream_host_port)
50 .field("endpoint_rules", &self.endpoint_rules)
51 .field("has_custom_tls_ca", &self.tls_connector.is_some())
52 .finish()
53 }
54}
55
56#[derive(Debug)]
62pub struct RouteStore {
63 routes: HashMap<String, LoadedRoute>,
64}
65
66impl RouteStore {
67 pub fn load(routes: &[RouteConfig]) -> Result<Self> {
73 let mut loaded = HashMap::new();
74
75 for route in routes {
76 let normalized_prefix = route.prefix.trim_matches('/').to_string();
77
78 debug!(
79 "Loading route '{}' -> {}",
80 normalized_prefix, route.upstream
81 );
82
83 let endpoint_rules = CompiledEndpointRules::compile(&route.endpoint_rules)
84 .map_err(|e| ProxyError::Config(format!("route '{}': {}", normalized_prefix, e)))?;
85
86 let tls_connector = match route.tls_ca {
87 Some(ref ca_path) => {
88 debug!(
89 "Building TLS connector with custom CA for route '{}': {}",
90 normalized_prefix, ca_path
91 );
92 Some(build_tls_connector_with_ca(ca_path)?)
93 }
94 None => None,
95 };
96
97 let upstream_host_port = extract_host_port(&route.upstream);
98
99 loaded.insert(
100 normalized_prefix,
101 LoadedRoute {
102 upstream: route.upstream.clone(),
103 upstream_host_port,
104 endpoint_rules,
105 tls_connector,
106 },
107 );
108 }
109
110 Ok(Self { routes: loaded })
111 }
112
113 #[must_use]
115 pub fn empty() -> Self {
116 Self {
117 routes: HashMap::new(),
118 }
119 }
120
121 #[must_use]
123 pub fn get(&self, prefix: &str) -> Option<&LoadedRoute> {
124 self.routes.get(prefix)
125 }
126
127 #[must_use]
129 pub fn is_empty(&self) -> bool {
130 self.routes.is_empty()
131 }
132
133 #[must_use]
135 pub fn len(&self) -> usize {
136 self.routes.len()
137 }
138
139 #[must_use]
143 pub fn is_route_upstream(&self, host_port: &str) -> bool {
144 let normalised = host_port.to_lowercase();
145 self.routes.values().any(|route| {
146 route
147 .upstream_host_port
148 .as_ref()
149 .is_some_and(|hp| *hp == normalised)
150 })
151 }
152
153 #[must_use]
156 pub fn route_upstream_hosts(&self) -> std::collections::HashSet<String> {
157 self.routes
158 .values()
159 .filter_map(|route| route.upstream_host_port.clone())
160 .collect()
161 }
162}
163
164fn extract_host_port(url: &str) -> Option<String> {
169 let parsed = url::Url::parse(url).ok()?;
170 let host = parsed.host_str()?;
171 let default_port = match parsed.scheme() {
172 "https" => 443,
173 "http" => 80,
174 _ => return None,
175 };
176 let port = parsed.port().unwrap_or(default_port);
177 Some(format!("{}:{}", host.to_lowercase(), port))
178}
179
180fn build_tls_connector_with_ca(ca_path: &str) -> Result<tokio_rustls::TlsConnector> {
186 let ca_path = std::path::Path::new(ca_path);
187
188 let ca_pem = Zeroizing::new(std::fs::read(ca_path).map_err(|e| {
189 if e.kind() == std::io::ErrorKind::NotFound {
190 ProxyError::Config(format!(
191 "CA certificate file not found: '{}'",
192 ca_path.display()
193 ))
194 } else {
195 ProxyError::Config(format!(
196 "failed to read CA certificate '{}': {}",
197 ca_path.display(),
198 e
199 ))
200 }
201 })?);
202
203 let mut root_store = rustls::RootCertStore::empty();
204
205 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
207
208 let certs: Vec<_> = rustls_pemfile::certs(&mut ca_pem.as_slice())
210 .collect::<std::result::Result<Vec<_>, _>>()
211 .map_err(|e| {
212 ProxyError::Config(format!(
213 "failed to parse CA certificate '{}': {}",
214 ca_path.display(),
215 e
216 ))
217 })?;
218
219 if certs.is_empty() {
220 return Err(ProxyError::Config(format!(
221 "CA certificate file '{}' contains no valid PEM certificates",
222 ca_path.display()
223 )));
224 }
225
226 for cert in certs {
227 root_store.add(cert).map_err(|e| {
228 ProxyError::Config(format!(
229 "invalid CA certificate in '{}': {}",
230 ca_path.display(),
231 e
232 ))
233 })?;
234 }
235
236 let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
237 rustls::crypto::ring::default_provider(),
238 ))
239 .with_safe_default_protocol_versions()
240 .map_err(|e| ProxyError::Config(format!("TLS config error: {}", e)))?
241 .with_root_certificates(root_store)
242 .with_no_client_auth();
243
244 Ok(tokio_rustls::TlsConnector::from(Arc::new(tls_config)))
245}
246
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::config::EndpointRule;

    /// Builds a credential-less route with no endpoint rules and no custom CA.
    /// Tests override individual fields with struct-update syntax.
    fn route_config(prefix: &str, upstream: &str) -> RouteConfig {
        RouteConfig {
            prefix: prefix.to_string(),
            upstream: upstream.to_string(),
            credential_key: None,
            inject_mode: Default::default(),
            inject_header: "Authorization".to_string(),
            credential_format: "Bearer {}".to_string(),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            env_var: None,
            endpoint_rules: vec![],
            tls_ca: None,
        }
    }

    #[test]
    fn test_empty_route_store() {
        let store = RouteStore::empty();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.get("openai").is_none());
    }

    #[test]
    fn test_load_routes_without_credentials() {
        let routes = vec![RouteConfig {
            endpoint_rules: vec![
                EndpointRule {
                    method: "POST".to_string(),
                    path: "/v1/chat/completions".to_string(),
                },
                EndpointRule {
                    method: "GET".to_string(),
                    path: "/v1/models".to_string(),
                },
            ],
            ..route_config("/openai", "https://api.openai.com")
        }];

        let store = RouteStore::load(&routes).unwrap();
        assert_eq!(store.len(), 1);

        let route = store.get("openai").unwrap();
        assert_eq!(route.upstream, "https://api.openai.com");
        assert!(route
            .endpoint_rules
            .is_allowed("POST", "/v1/chat/completions"));
        assert!(route.endpoint_rules.is_allowed("GET", "/v1/models"));
        assert!(!route
            .endpoint_rules
            .is_allowed("DELETE", "/v1/files/file-123"));
    }

    #[test]
    fn test_load_routes_normalises_prefix() {
        let routes = vec![route_config("/anthropic/", "https://api.anthropic.com")];

        let store = RouteStore::load(&routes).unwrap();
        assert!(store.get("anthropic").is_some());
        assert!(store.get("/anthropic/").is_none());
    }

    #[test]
    fn test_is_route_upstream() {
        let routes = vec![route_config("openai", "https://api.openai.com")];

        let store = RouteStore::load(&routes).unwrap();
        assert!(store.is_route_upstream("api.openai.com:443"));
        assert!(!store.is_route_upstream("github.com:443"));
    }

    #[test]
    fn test_is_route_upstream_ignores_case_but_not_port() {
        let routes = vec![route_config("openai", "https://api.openai.com")];

        let store = RouteStore::load(&routes).unwrap();
        assert!(store.is_route_upstream("API.OpenAI.com:443"));
        assert!(!store.is_route_upstream("api.openai.com:80"));
    }

    #[test]
    fn test_route_upstream_hosts() {
        let routes = vec![
            route_config("openai", "https://api.openai.com"),
            route_config("anthropic", "https://api.anthropic.com"),
        ];

        let store = RouteStore::load(&routes).unwrap();
        let hosts = store.route_upstream_hosts();
        assert!(hosts.contains("api.openai.com:443"));
        assert!(hosts.contains("api.anthropic.com:443"));
        assert_eq!(hosts.len(), 2);
    }

    #[test]
    fn test_extract_host_port_https() {
        assert_eq!(
            extract_host_port("https://api.openai.com"),
            Some("api.openai.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_with_port() {
        assert_eq!(
            extract_host_port("https://api.example.com:8443"),
            Some("api.example.com:8443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http() {
        assert_eq!(
            extract_host_port("http://internal-service"),
            Some("internal-service:80".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_explicit_default_port() {
        assert_eq!(
            extract_host_port("http://example.com:80"),
            Some("example.com:80".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_ipv6() {
        assert_eq!(
            extract_host_port("https://[::1]:8443"),
            Some("[::1]:8443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_normalises_case() {
        assert_eq!(
            extract_host_port("https://API.Example.COM"),
            Some("api.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_rejects_unsupported_or_invalid() {
        assert_eq!(extract_host_port("ftp://example.com"), None);
        assert_eq!(extract_host_port("not a url"), None);
    }

    #[test]
    fn test_loaded_route_debug() {
        let route = LoadedRoute {
            upstream: "https://api.openai.com".to_string(),
            upstream_host_port: Some("api.openai.com:443".to_string()),
            endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
            tls_connector: None,
        };
        let debug_output = format!("{:?}", route);
        assert!(debug_output.contains("api.openai.com"));
        assert!(debug_output.contains("has_custom_tls_ca"));
    }

    // Placeholder CA fixture. Its key and signature material appears to be
    // zero-filled, so rustls may or may not accept it as a trust anchor.
    const TEST_CA_PEM: &str = "\
-----BEGIN CERTIFICATE-----
MIIBnjCCAUWgAwIBAgIUT0bpOJJvHdOdZt+gW1stR8VBgXowCgYIKoZIzj0EAwIw
FzEVMBMGA1UEAwwMbm9uby10ZXN0LWNhMCAXDTI1MDEwMTAwMDAwMFoYDzIxMjQx
MjA3MDAwMDAwWjAXMRUwEwYDVQQDDAxub25vLXRlc3QtY2EwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAAR8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAo1MwUTAdBgNVHQ4EFgQUAAAAAAAAAAAAAAAAAAAAAAAA
AAAAMB8GA1UdIwQYMBaAFAAAAAAAAAAAAAAAAAAAAAAAAAAAADAPBgNVHRMBAf8E
BTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-----END CERTIFICATE-----";

    #[test]
    fn test_build_tls_connector_with_valid_ca() {
        let dir = tempfile::tempdir().unwrap();
        let ca_path = dir.path().join("ca.pem");
        std::fs::write(&ca_path, TEST_CA_PEM).unwrap();

        // Because the fixture is a placeholder, two outcomes are acceptable:
        // - a connector is built;
        // - a `Config` error that names the CA certificate is returned.
        // Any other error type, or a panic, is a failure.
        match build_tls_connector_with_ca(ca_path.to_str().unwrap()) {
            Ok(_connector) => {}
            Err(ProxyError::Config(msg)) => {
                assert!(msg.contains("CA certificate"), "unexpected error: {}", msg);
            }
            Err(e) => panic!("unexpected error type: {}", e),
        }
    }

    #[test]
    fn test_build_tls_connector_missing_file() {
        let result = build_tls_connector_with_ca("/nonexistent/path/ca.pem");
        let err = result
            .err()
            .expect("should fail for missing file")
            .to_string();
        assert!(
            err.contains("CA certificate file not found"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_build_tls_connector_empty_pem() {
        let dir = tempfile::tempdir().unwrap();
        let ca_path = dir.path().join("empty.pem");
        std::fs::write(&ca_path, "not a certificate\n").unwrap();

        let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
        let err = result
            .err()
            .expect("should fail for invalid PEM")
            .to_string();
        assert!(
            err.contains("no valid PEM certificates"),
            "unexpected error: {}",
            err
        );
    }
}