1use async_trait::async_trait;
45use parking_lot::RwLock;
46use pingora::prelude::*;
47use pingora_load_balancing::discovery::{ServiceDiscovery, Static as StaticDiscovery};
48use pingora_load_balancing::Backend;
49use std::collections::{BTreeSet, HashMap};
50use std::net::ToSocketAddrs;
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53use tracing::{debug, error, info, trace, warn};
54
/// Describes where and how an upstream's backend list is discovered.
///
/// Consumed by `DiscoveryManager::register`, which builds the matching
/// `ServiceDiscovery` implementation for each variant.
#[derive(Debug, Clone)]
pub enum DiscoveryConfig {
    /// A fixed list of backends.
    Static {
        /// Backend addresses as `host:port` strings. Entries that fail to
        /// resolve are dropped at registration time.
        backends: Vec<String>,
    },
    /// Resolve `hostname` through the system resolver and pair every returned
    /// address with `port`.
    Dns {
        /// Host name to resolve.
        hostname: String,
        /// Port attached to every resolved address.
        port: u16,
        /// Minimum time between two DNS lookups.
        refresh_interval: Duration,
    },
    /// DNS SRV-style service name.
    ///
    /// NOTE: real SRV lookups are not implemented. Registration strips the
    /// leading `_`-prefixed labels and falls back to plain DNS on port 80.
    DnsSrv {
        /// SRV-style name, e.g. `_http._tcp.example.com`.
        service: String,
        /// Minimum time between two DNS lookups.
        refresh_interval: Duration,
    },
    /// Query the Consul health endpoint for instances of a service.
    Consul {
        /// Base URL of the Consul HTTP API, e.g. `http://localhost:8500`.
        address: String,
        /// Consul service name.
        service: String,
        /// Optional datacenter, sent as the `dc` query parameter.
        datacenter: Option<String>,
        /// When true, sends `passing=true` to request only passing instances.
        only_passing: bool,
        /// Minimum time between two Consul queries.
        refresh_interval: Duration,
        /// Optional tag filter, sent as the `tag` query parameter.
        tag: Option<String>,
    },
    /// Read a Kubernetes Endpoints object (requires the `kubernetes` feature
    /// for live queries).
    Kubernetes {
        /// Namespace containing the service.
        namespace: String,
        /// Name of the service whose Endpoints are read.
        service: String,
        /// Endpoint port to use, matched by name; `None` selects the first port.
        port_name: Option<String>,
        /// Minimum time between two API queries.
        refresh_interval: Duration,
        /// Explicit kubeconfig path; `None` tries in-cluster credentials first.
        kubeconfig: Option<String>,
    },
    /// File-based discovery.
    ///
    /// NOTE: not implemented yet; registration installs an empty static
    /// discovery.
    File {
        /// Path of the file to read.
        path: String,
        /// Intended polling interval for file changes.
        watch_interval: Duration,
    },
}
115
116impl Default for DiscoveryConfig {
117 fn default() -> Self {
118 Self::Static {
119 backends: vec!["127.0.0.1:8080".to_string()],
120 }
121 }
122}
123
124pub struct DnsDiscovery {
128 hostname: String,
129 port: u16,
130 refresh_interval: Duration,
131 cached_backends: RwLock<BTreeSet<Backend>>,
133 last_resolution: RwLock<Instant>,
135}
136
137impl DnsDiscovery {
138 pub fn new(hostname: String, port: u16, refresh_interval: Duration) -> Self {
140 Self {
141 hostname,
142 port,
143 refresh_interval,
144 cached_backends: RwLock::new(BTreeSet::new()),
145 last_resolution: RwLock::new(Instant::now() - refresh_interval),
146 }
147 }
148
149 fn resolve(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
151 let address = format!("{}:{}", self.hostname, self.port);
152
153 trace!(
154 hostname = %self.hostname,
155 port = self.port,
156 "Resolving DNS for service discovery"
157 );
158
159 match address.to_socket_addrs() {
160 Ok(addrs) => {
161 let backends: BTreeSet<Backend> = addrs
162 .map(|addr| Backend {
163 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
164 weight: 1,
165 ext: http::Extensions::new(),
166 })
167 .collect();
168
169 debug!(
170 hostname = %self.hostname,
171 backend_count = backends.len(),
172 "DNS resolution successful"
173 );
174
175 Ok(backends)
176 }
177 Err(e) => {
178 error!(
179 hostname = %self.hostname,
180 error = %e,
181 "DNS resolution failed"
182 );
183 Err(Error::explain(
184 ErrorType::ConnectNoRoute,
185 format!("DNS resolution failed for {}: {}", self.hostname, e),
186 ))
187 }
188 }
189 }
190
191 fn needs_refresh(&self) -> bool {
193 let last = *self.last_resolution.read();
194 last.elapsed() >= self.refresh_interval
195 }
196}
197
198#[async_trait]
199impl ServiceDiscovery for DnsDiscovery {
200 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
201 if self.needs_refresh() {
203 match self.resolve() {
204 Ok(backends) => {
205 *self.cached_backends.write() = backends;
206 *self.last_resolution.write() = Instant::now();
207 }
208 Err(e) => {
209 let cached = self.cached_backends.read().clone();
211 if !cached.is_empty() {
212 warn!(
213 hostname = %self.hostname,
214 error = %e,
215 cached_count = cached.len(),
216 "DNS resolution failed, using cached backends"
217 );
218 return Ok((cached, HashMap::new()));
219 }
220 return Err(e);
221 }
222 }
223 }
224
225 let backends = self.cached_backends.read().clone();
226 Ok((backends, HashMap::new()))
227 }
228}
229
230pub struct ConsulDiscovery {
238 address: String,
240 service: String,
242 datacenter: Option<String>,
244 only_passing: bool,
246 refresh_interval: Duration,
248 tag: Option<String>,
250 cached_backends: RwLock<BTreeSet<Backend>>,
252 last_resolution: RwLock<Instant>,
254}
255
256impl ConsulDiscovery {
257 pub fn new(
259 address: String,
260 service: String,
261 datacenter: Option<String>,
262 only_passing: bool,
263 refresh_interval: Duration,
264 tag: Option<String>,
265 ) -> Self {
266 Self {
267 address,
268 service,
269 datacenter,
270 only_passing,
271 refresh_interval,
272 tag,
273 cached_backends: RwLock::new(BTreeSet::new()),
274 last_resolution: RwLock::new(Instant::now() - refresh_interval),
275 }
276 }
277
278 fn build_url(&self) -> String {
280 let mut url = format!(
281 "{}/v1/health/service/{}",
282 self.address.trim_end_matches('/'),
283 self.service
284 );
285
286 let mut params = Vec::new();
287 if self.only_passing {
288 params.push("passing=true".to_string());
289 }
290 if let Some(dc) = &self.datacenter {
291 params.push(format!("dc={}", dc));
292 }
293 if let Some(tag) = &self.tag {
294 params.push(format!("tag={}", tag));
295 }
296
297 if !params.is_empty() {
298 url.push('?');
299 url.push_str(¶ms.join("&"));
300 }
301
302 url
303 }
304
305 fn needs_refresh(&self) -> bool {
307 let last = *self.last_resolution.read();
308 last.elapsed() >= self.refresh_interval
309 }
310}
311
312#[async_trait]
313impl ServiceDiscovery for ConsulDiscovery {
314 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
315 if !self.needs_refresh() {
316 let backends = self.cached_backends.read().clone();
317 return Ok((backends, HashMap::new()));
318 }
319
320 let url = self.build_url();
321 trace!(
322 service = %self.service,
323 url = %url,
324 "Querying Consul for service discovery"
325 );
326
327 let result = tokio::task::spawn_blocking({
330 let url = url.clone();
331 let service = self.service.clone();
332 move || -> Result<BTreeSet<Backend>, Box<Error>> {
333 let url_parsed = url
336 .trim_start_matches("http://")
337 .trim_start_matches("https://");
338 let (host_port, path) = url_parsed.split_once('/').unwrap_or((url_parsed, ""));
339
340 let socket_addr = host_port
341 .to_socket_addrs()
342 .map_err(|e| {
343 Error::explain(
344 ErrorType::ConnectNoRoute,
345 format!("Failed to resolve Consul address: {}", e),
346 )
347 })?
348 .next()
349 .ok_or_else(|| {
350 Error::explain(
351 ErrorType::ConnectNoRoute,
352 "Failed to resolve Consul address",
353 )
354 })?;
355
356 let stream = match std::net::TcpStream::connect_timeout(
357 &socket_addr,
358 Duration::from_secs(5),
359 ) {
360 Ok(s) => s,
361 Err(e) => {
362 return Err(Error::explain(
363 ErrorType::ConnectTimedout,
364 format!("Failed to connect to Consul: {}", e),
365 ));
366 }
367 };
368
369 stream
370 .set_read_timeout(Some(Duration::from_secs(10)))
371 .map_err(|e| {
372 Error::explain(
373 ErrorType::InternalError,
374 format!("Failed to set read timeout: {}", e),
375 )
376 })?;
377 stream
378 .set_write_timeout(Some(Duration::from_secs(5)))
379 .map_err(|e| {
380 Error::explain(
381 ErrorType::InternalError,
382 format!("Failed to set write timeout: {}", e),
383 )
384 })?;
385
386 use std::io::{Read, Write};
387 let request = format!(
388 "GET /{} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
389 path, host_port
390 );
391
392 let mut stream = stream;
393 stream.write_all(request.as_bytes()).map_err(|e| {
394 Error::explain(
395 ErrorType::WriteError,
396 format!("Failed to send request: {}", e),
397 )
398 })?;
399
400 let mut response = String::new();
401 stream.read_to_string(&mut response).map_err(|e| {
402 Error::explain(
403 ErrorType::ReadError,
404 format!("Failed to read response: {}", e),
405 )
406 })?;
407
408 let body = response.split("\r\n\r\n").nth(1).unwrap_or("");
410
411 let backends = parse_consul_response(body, &service)?;
414
415 Ok(backends)
416 }
417 })
418 .await
419 .map_err(|e| Error::explain(ErrorType::InternalError, format!("Task failed: {}", e)))?;
420
421 match result {
422 Ok(backends) => {
423 info!(
424 service = %self.service,
425 backend_count = backends.len(),
426 "Consul discovery successful"
427 );
428 *self.cached_backends.write() = backends.clone();
429 *self.last_resolution.write() = Instant::now();
430 Ok((backends, HashMap::new()))
431 }
432 Err(e) => {
433 let cached = self.cached_backends.read().clone();
434 if !cached.is_empty() {
435 warn!(
436 service = %self.service,
437 error = %e,
438 cached_count = cached.len(),
439 "Consul query failed, using cached backends"
440 );
441 return Ok((cached, HashMap::new()));
442 }
443 Err(e)
444 }
445 }
446 }
447}
448
449fn parse_consul_response(body: &str, service_name: &str) -> Result<BTreeSet<Backend>, Box<Error>> {
451 let mut backends = BTreeSet::new();
454
455 let entries: Vec<&str> = body.split(r#""Service":"#).skip(1).collect();
457
458 for entry in entries {
459 let port = entry
461 .split(r#""Port":"#)
462 .nth(1)
463 .and_then(|s| s.split(|c: char| !c.is_ascii_digit()).next())
464 .and_then(|s| s.parse::<u16>().ok());
465
466 let service_addr = entry
468 .split(r#""Address":""#)
469 .nth(1)
470 .and_then(|s| s.split('"').next())
471 .filter(|s| !s.is_empty());
472
473 let node_addr = body
475 .split(r#""Node":"#)
476 .nth(1)
477 .and_then(|s| s.split(r#""Address":""#).nth(1))
478 .and_then(|s| s.split('"').next());
479
480 let address = service_addr.or(node_addr);
481
482 if let (Some(addr), Some(port)) = (address, port) {
483 let full_addr = format!("{}:{}", addr, port);
484 if let Ok(mut addrs) = full_addr.to_socket_addrs() {
485 if let Some(socket_addr) = addrs.next() {
486 backends.insert(Backend {
487 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(socket_addr),
488 weight: 1,
489 ext: http::Extensions::new(),
490 });
491 }
492 }
493 }
494 }
495
496 if backends.is_empty() && !body.starts_with("[]") && !body.is_empty() {
497 warn!(
498 service = %service_name,
499 body_len = body.len(),
500 "Failed to parse Consul response, no backends found"
501 );
502 }
503
504 Ok(backends)
505}
506
507#[cfg(feature = "kubernetes")]
512use crate::kubeconfig::{KubeAuth, Kubeconfig, ResolvedKubeConfig};
513
514pub struct KubernetesDiscovery {
538 namespace: String,
540 service: String,
542 port_name: Option<String>,
544 refresh_interval: Duration,
546 kubeconfig: Option<String>,
548 cached_backends: RwLock<BTreeSet<Backend>>,
550 last_resolution: RwLock<Instant>,
552 #[cfg(feature = "kubernetes")]
554 resolved_config: RwLock<Option<ResolvedKubeConfig>>,
555}
556
557impl KubernetesDiscovery {
558 pub fn new(
560 namespace: String,
561 service: String,
562 port_name: Option<String>,
563 refresh_interval: Duration,
564 kubeconfig: Option<String>,
565 ) -> Self {
566 Self {
567 namespace,
568 service,
569 port_name,
570 refresh_interval,
571 kubeconfig,
572 cached_backends: RwLock::new(BTreeSet::new()),
573 last_resolution: RwLock::new(Instant::now() - refresh_interval),
574 #[cfg(feature = "kubernetes")]
575 resolved_config: RwLock::new(None),
576 }
577 }
578
579 fn needs_refresh(&self) -> bool {
581 let last = *self.last_resolution.read();
582 last.elapsed() >= self.refresh_interval
583 }
584
585 fn get_in_cluster_config(&self) -> Result<(String, String), Box<Error>> {
587 let host = std::env::var("KUBERNETES_SERVICE_HOST").map_err(|_| {
588 Error::explain(
589 ErrorType::InternalError,
590 "KUBERNETES_SERVICE_HOST not set, not running in Kubernetes?",
591 )
592 })?;
593 let port = std::env::var("KUBERNETES_SERVICE_PORT").unwrap_or_else(|_| "443".to_string());
594 let token = std::fs::read_to_string("/var/run/secrets/kubernetes.io/serviceaccount/token")
595 .map_err(|e| {
596 Error::explain(
597 ErrorType::InternalError,
598 format!("Failed to read service account token: {}", e),
599 )
600 })?;
601
602 Ok((format!("https://{}:{}", host, port), token.trim().to_string()))
603 }
604
605 #[cfg(feature = "kubernetes")]
607 fn load_kubeconfig(&self) -> Result<ResolvedKubeConfig, Box<Error>> {
608 if let Some(config) = self.resolved_config.read().as_ref() {
610 return Ok(config.clone());
611 }
612
613 let kubeconfig = if let Some(path) = &self.kubeconfig {
614 Kubeconfig::from_file(path).map_err(|e| {
615 Error::explain(
616 ErrorType::InternalError,
617 format!("Failed to load kubeconfig from {}: {}", path, e),
618 )
619 })?
620 } else {
621 Kubeconfig::from_default_location().map_err(|e| {
622 Error::explain(
623 ErrorType::InternalError,
624 format!("Failed to load kubeconfig from default location: {}", e),
625 )
626 })?
627 };
628
629 let resolved = kubeconfig.resolve_current().map_err(|e| {
630 Error::explain(
631 ErrorType::InternalError,
632 format!("Failed to resolve kubeconfig context: {}", e),
633 )
634 })?;
635
636 *self.resolved_config.write() = Some(resolved.clone());
638
639 Ok(resolved)
640 }
641}
642
/// Minimal serde models of the Kubernetes `v1/Endpoints` object, limited to
/// the fields declared here; other JSON fields are ignored on deserialization.
#[cfg(feature = "kubernetes")]
mod k8s_types {
    use serde::Deserialize;

    /// Top-level Endpoints object.
    #[derive(Debug, Deserialize)]
    pub struct Endpoints {
        /// Address/port groups. `Option` because the field may be absent from
        /// the JSON.
        pub subsets: Option<Vec<EndpointSubset>>,
    }

    /// A set of addresses that share the same list of ports.
    #[derive(Debug, Deserialize)]
    pub struct EndpointSubset {
        /// Addresses in this subset; discovery turns each into a backend.
        pub addresses: Option<Vec<EndpointAddress>>,
        /// Ports exposed by every address in this subset.
        pub ports: Option<Vec<EndpointPort>>,
    }

    /// A single endpoint address.
    #[derive(Debug, Deserialize)]
    pub struct EndpointAddress {
        /// Endpoint IP as a string.
        pub ip: String,
        /// Optional host name; not currently read by discovery.
        pub hostname: Option<String>,
    }

    /// A named or unnamed endpoint port.
    #[derive(Debug, Deserialize)]
    pub struct EndpointPort {
        /// Port name, matched against `KubernetesDiscovery::port_name`.
        pub name: Option<String>,
        /// Port number.
        pub port: u16,
        /// Protocol string; not currently read by discovery.
        pub protocol: Option<String>,
    }
}
672
#[cfg(feature = "kubernetes")]
impl KubernetesDiscovery {
    /// Fetches the service's Endpoints object and converts its addresses into
    /// backends.
    ///
    /// Connection settings come from the explicit kubeconfig when one is set.
    /// Otherwise in-cluster credentials are tried, falling back to the default
    /// kubeconfig location. Does not read or update the backend cache.
    async fn fetch_endpoints(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
        let (api_server, auth, ca_cert, skip_verify) = if self.kubeconfig.is_some() {
            let config = self.load_kubeconfig()?;
            (config.server, config.auth, config.ca_cert, config.insecure_skip_tls_verify)
        } else {
            match self.get_in_cluster_config() {
                Ok((server, token)) => {
                    // The CA bundle is optional; without it the client keeps
                    // its default root store.
                    let ca = std::fs::read("/var/run/secrets/kubernetes.io/serviceaccount/ca.crt").ok();
                    (server, KubeAuth::Token(token), ca, false)
                }
                Err(e) => {
                    debug!(
                        error = %e,
                        "In-cluster config not available, trying default kubeconfig"
                    );
                    let config = self.load_kubeconfig()?;
                    (config.server, config.auth, config.ca_cert, config.insecure_skip_tls_verify)
                }
            }
        };

        let url = format!(
            "{}/api/v1/namespaces/{}/endpoints/{}",
            api_server.trim_end_matches('/'),
            self.namespace,
            self.service
        );

        debug!(
            url = %url,
            namespace = %self.namespace,
            service = %self.service,
            "Fetching Kubernetes endpoints"
        );

        let client_builder = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .danger_accept_invalid_certs(skip_verify);

        let client_builder = if let Some(ca_data) = ca_cert {
            let cert = reqwest::Certificate::from_pem(&ca_data).map_err(|e| {
                Error::explain(
                    ErrorType::InternalError,
                    format!("Failed to parse CA certificate: {}", e),
                )
            })?;
            client_builder.add_root_certificate(cert)
        } else {
            client_builder
        };

        let client_builder = match &auth {
            KubeAuth::ClientCert { cert, key } => {
                // reqwest expects certificate and key concatenated in one PEM.
                let mut identity_pem = cert.clone();
                identity_pem.extend_from_slice(key);
                let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
                    Error::explain(
                        ErrorType::InternalError,
                        format!("Failed to create client identity: {}", e),
                    )
                })?;
                client_builder.identity(identity)
            }
            _ => client_builder,
        };

        let client = client_builder.build().map_err(|e| {
            Error::explain(
                ErrorType::InternalError,
                format!("Failed to create HTTP client: {}", e),
            )
        })?;

        let mut request = client.get(&url);
        if let KubeAuth::Token(token) = &auth {
            request = request.bearer_auth(token);
        }

        let response = request.send().await.map_err(|e| {
            Error::explain(
                ErrorType::ConnectError,
                format!("Failed to connect to Kubernetes API: {}", e),
            )
        })?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(Error::explain(
                ErrorType::HTTPStatus(status.as_u16()),
                format!("Kubernetes API returned {}: {}", status, body),
            ));
        }

        let endpoints: k8s_types::Endpoints = response.json().await.map_err(|e| {
            Error::explain(
                ErrorType::InternalError,
                format!("Failed to parse Kubernetes endpoints: {}", e),
            )
        })?;

        let mut backends = BTreeSet::new();
        for subset in endpoints.subsets.unwrap_or_default() {
            // Use the named port when configured, otherwise the subset's first port.
            let target_port = subset.ports.as_ref().and_then(|ports| match &self.port_name {
                Some(port_name) => ports
                    .iter()
                    .find(|p| p.name.as_ref() == Some(port_name))
                    .map(|p| p.port),
                None => ports.first().map(|p| p.port),
            });

            let (addresses, port) = match (subset.addresses, target_port) {
                (Some(addresses), Some(port)) => (addresses, port),
                _ => continue,
            };

            for addr in addresses {
                // Endpoint IPs are literals, so they are parsed directly instead
                // of going through the system resolver on an async worker.
                match addr.ip.parse::<std::net::IpAddr>() {
                    Ok(ip) => {
                        backends.insert(Backend {
                            addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(
                                std::net::SocketAddr::new(ip, port),
                            ),
                            weight: 1,
                            ext: http::Extensions::new(),
                        });
                    }
                    Err(e) => {
                        warn!(
                            service = %self.service,
                            ip = %addr.ip,
                            error = %e,
                            "Skipping Kubernetes endpoint with unparsable IP"
                        );
                    }
                }
            }
        }

        Ok(backends)
    }
}

#[cfg(feature = "kubernetes")]
#[async_trait]
impl ServiceDiscovery for KubernetesDiscovery {
    /// Returns cached backends while fresh. Otherwise it queries the
    /// Kubernetes API and falls back to the cache (if non-empty) on failure.
    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
        if !self.needs_refresh() {
            let backends = self.cached_backends.read().clone();
            return Ok((backends, HashMap::new()));
        }

        trace!(
            namespace = %self.namespace,
            service = %self.service,
            "Querying Kubernetes for endpoint discovery"
        );

        match self.fetch_endpoints().await {
            Ok(backends) => {
                info!(
                    service = %self.service,
                    namespace = %self.namespace,
                    backend_count = backends.len(),
                    "Kubernetes endpoint discovery successful"
                );

                *self.cached_backends.write() = backends.clone();
                *self.last_resolution.write() = Instant::now();

                Ok((backends, HashMap::new()))
            }
            Err(e) => {
                // Mirrors the cache fallback of `DnsDiscovery` and
                // `ConsulDiscovery`: a transient API failure must not drop
                // backends that were discovered earlier.
                let cached = self.cached_backends.read().clone();
                if !cached.is_empty() {
                    warn!(
                        service = %self.service,
                        namespace = %self.namespace,
                        error = %e,
                        cached_count = cached.len(),
                        "Kubernetes query failed, using cached backends"
                    );
                    return Ok((cached, HashMap::new()));
                }
                Err(e)
            }
        }
    }
}
846
847#[cfg(not(feature = "kubernetes"))]
849#[async_trait]
850impl ServiceDiscovery for KubernetesDiscovery {
851 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
852 if !self.needs_refresh() {
853 let backends = self.cached_backends.read().clone();
854 return Ok((backends, HashMap::new()));
855 }
856
857 if self.kubeconfig.is_none() {
859 if let Ok((_, _)) = self.get_in_cluster_config() {
860 warn!(
861 service = %self.service,
862 "Kubernetes discovery requires 'kubernetes' feature flag for full support"
863 );
864 }
865 } else {
866 warn!(
867 service = %self.service,
868 kubeconfig = ?self.kubeconfig,
869 "Kubeconfig support requires 'kubernetes' feature flag"
870 );
871 }
872
873 let cached = self.cached_backends.read().clone();
874 Ok((cached, HashMap::new()))
875 }
876}
877
878pub struct DiscoveryManager {
883 discoveries: RwLock<HashMap<String, Arc<dyn ServiceDiscovery + Send + Sync>>>,
885}
886
887impl DiscoveryManager {
888 pub fn new() -> Self {
890 Self {
891 discoveries: RwLock::new(HashMap::new()),
892 }
893 }
894
895 pub fn register(&self, upstream_id: &str, config: DiscoveryConfig) -> Result<(), Box<Error>> {
897 let discovery: Arc<dyn ServiceDiscovery + Send + Sync> = match config {
898 DiscoveryConfig::Static { backends } => {
899 let backend_set = backends
900 .iter()
901 .filter_map(|addr| {
902 addr.to_socket_addrs()
903 .ok()
904 .and_then(|mut addrs| addrs.next())
905 .map(|addr| Backend {
906 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
907 weight: 1,
908 ext: http::Extensions::new(),
909 })
910 })
911 .collect();
912
913 info!(
914 upstream_id = %upstream_id,
915 backend_count = backends.len(),
916 "Registered static service discovery"
917 );
918
919 Arc::new(StaticWrapper(StaticDiscovery::new(backend_set)))
920 }
921 DiscoveryConfig::Dns {
922 hostname,
923 port,
924 refresh_interval,
925 } => {
926 info!(
927 upstream_id = %upstream_id,
928 hostname = %hostname,
929 port = port,
930 refresh_interval_secs = refresh_interval.as_secs(),
931 "Registered DNS service discovery"
932 );
933
934 Arc::new(DnsDiscovery::new(hostname, port, refresh_interval))
935 }
936 DiscoveryConfig::DnsSrv {
937 service,
938 refresh_interval,
939 } => {
940 info!(
941 upstream_id = %upstream_id,
942 service = %service,
943 refresh_interval_secs = refresh_interval.as_secs(),
944 "DNS SRV discovery not yet fully implemented, using DNS A record fallback"
945 );
946
947 let hostname = service
950 .split('.')
951 .skip_while(|s| s.starts_with('_'))
952 .collect::<Vec<_>>()
953 .join(".");
954 Arc::new(DnsDiscovery::new(hostname, 80, refresh_interval))
955 }
956 DiscoveryConfig::Consul {
957 address,
958 service,
959 datacenter,
960 only_passing,
961 refresh_interval,
962 tag,
963 } => {
964 info!(
965 upstream_id = %upstream_id,
966 address = %address,
967 service = %service,
968 datacenter = datacenter.as_deref().unwrap_or("default"),
969 only_passing = only_passing,
970 refresh_interval_secs = refresh_interval.as_secs(),
971 "Registered Consul service discovery"
972 );
973
974 Arc::new(ConsulDiscovery::new(
975 address,
976 service,
977 datacenter,
978 only_passing,
979 refresh_interval,
980 tag,
981 ))
982 }
983 DiscoveryConfig::Kubernetes {
984 namespace,
985 service,
986 port_name,
987 refresh_interval,
988 kubeconfig,
989 } => {
990 info!(
991 upstream_id = %upstream_id,
992 namespace = %namespace,
993 service = %service,
994 port_name = port_name.as_deref().unwrap_or("default"),
995 refresh_interval_secs = refresh_interval.as_secs(),
996 "Registered Kubernetes endpoint discovery"
997 );
998
999 Arc::new(KubernetesDiscovery::new(
1000 namespace,
1001 service,
1002 port_name,
1003 refresh_interval,
1004 kubeconfig,
1005 ))
1006 }
1007 DiscoveryConfig::File {
1008 path,
1009 watch_interval,
1010 } => {
1011 info!(
1012 upstream_id = %upstream_id,
1013 path = %path,
1014 watch_interval_secs = watch_interval.as_secs(),
1015 "File-based discovery not yet implemented, using empty static"
1016 );
1017
1018 Arc::new(StaticWrapper(StaticDiscovery::new(BTreeSet::new())))
1020 }
1021 };
1022
1023 self.discoveries
1024 .write()
1025 .insert(upstream_id.to_string(), discovery);
1026 Ok(())
1027 }
1028
1029 pub fn get(&self, upstream_id: &str) -> Option<Arc<dyn ServiceDiscovery + Send + Sync>> {
1031 self.discoveries.read().get(upstream_id).cloned()
1032 }
1033
1034 pub async fn discover(
1036 &self,
1037 upstream_id: &str,
1038 ) -> Option<Result<(BTreeSet<Backend>, HashMap<u64, bool>)>> {
1039 let discovery = self.get(upstream_id)?;
1040 Some(discovery.discover().await)
1041 }
1042
1043 pub fn remove(&self, upstream_id: &str) {
1045 self.discoveries.write().remove(upstream_id);
1046 }
1047
1048 pub fn count(&self) -> usize {
1050 self.discoveries.read().len()
1051 }
1052}
1053
1054impl Default for DiscoveryManager {
1055 fn default() -> Self {
1056 Self::new()
1057 }
1058}
1059
1060struct StaticWrapper(Box<StaticDiscovery>);
1062
1063#[async_trait]
1064impl ServiceDiscovery for StaticWrapper {
1065 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
1066 self.0.discover().await
1067 }
1068}
1069
1070unsafe impl Send for StaticWrapper {}
1072unsafe impl Sync for StaticWrapper {}
1073
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned `Vec<String>` configs expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_discovery_config_default() {
        if let DiscoveryConfig::Static { backends } = DiscoveryConfig::default() {
            assert_eq!(backends.len(), 1);
            assert_eq!(backends[0], "127.0.0.1:8080");
        } else {
            panic!("Expected Static config");
        }
    }

    #[tokio::test]
    async fn test_discovery_manager() {
        let manager = DiscoveryManager::new();
        let config = DiscoveryConfig::Static {
            backends: owned(&["127.0.0.1:8080", "127.0.0.1:8081"]),
        };
        manager.register("test-upstream", config).unwrap();

        assert_eq!(manager.count(), 1);

        let discovered = manager.discover("test-upstream").await;
        assert!(discovered.is_some());
        let (backends, _health) = discovered.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }

    #[test]
    fn test_dns_discovery_needs_refresh() {
        // A zero refresh interval means a refresh is always due.
        let discovery = DnsDiscovery::new(String::from("localhost"), 8080, Duration::from_secs(0));
        assert!(discovery.needs_refresh());
    }

    #[test]
    fn test_consul_discovery_url_building() {
        let url = ConsulDiscovery::new(
            String::from("http://localhost:8500"),
            String::from("my-service"),
            Some(String::from("dc1")),
            true,
            Duration::from_secs(10),
            Some(String::from("production")),
        )
        .build_url();

        assert!(url.starts_with("http://localhost:8500/v1/health/service/my-service"));
        for expected in &["passing=true", "dc=dc1", "tag=production"] {
            assert!(url.contains(expected));
        }
    }

    #[test]
    fn test_consul_discovery_url_minimal() {
        let url = ConsulDiscovery::new(
            String::from("http://consul.local:8500"),
            String::from("backend"),
            None,
            false,
            Duration::from_secs(30),
            None,
        )
        .build_url();

        assert_eq!(url, "http://consul.local:8500/v1/health/service/backend");
    }

    #[test]
    fn test_kubernetes_discovery_config() {
        let discovery = KubernetesDiscovery::new(
            String::from("default"),
            String::from("my-service"),
            Some(String::from("http")),
            Duration::from_secs(10),
            None,
        );

        assert!(discovery.needs_refresh());
    }

    #[test]
    fn test_parse_consul_response_empty() {
        let parsed = parse_consul_response("[]", "test").unwrap();
        assert!(parsed.is_empty());
    }

    #[tokio::test]
    async fn test_discovery_manager_consul() {
        let manager = DiscoveryManager::new();
        let config = DiscoveryConfig::Consul {
            address: String::from("http://localhost:8500"),
            service: String::from("my-service"),
            datacenter: Some(String::from("dc1")),
            only_passing: true,
            refresh_interval: Duration::from_secs(10),
            tag: None,
        };
        manager.register("consul-upstream", config).unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("consul-upstream").is_some());
    }

    #[tokio::test]
    async fn test_discovery_manager_kubernetes() {
        let manager = DiscoveryManager::new();
        let config = DiscoveryConfig::Kubernetes {
            namespace: String::from("production"),
            service: String::from("api-server"),
            port_name: Some(String::from("http")),
            refresh_interval: Duration::from_secs(15),
            kubeconfig: None,
        };
        manager.register("k8s-upstream", config).unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("k8s-upstream").is_some());
    }
}