1use async_trait::async_trait;
45use parking_lot::RwLock;
46use pingora::prelude::*;
47use pingora_load_balancing::discovery::{ServiceDiscovery, Static as StaticDiscovery};
48use pingora_load_balancing::Backend;
49use std::collections::{BTreeSet, HashMap};
50use std::net::ToSocketAddrs;
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53use tracing::{debug, error, info, trace, warn};
54
/// Selects and configures the service discovery mechanism for one upstream.
#[derive(Debug, Clone)]
pub enum DiscoveryConfig {
    /// Fixed list of backends that is resolved once at registration time.
    Static {
        /// Backend addresses in `host:port` form.
        backends: Vec<String>,
    },
    /// Periodic DNS resolution of a hostname.
    Dns {
        /// Hostname to resolve.
        hostname: String,
        /// Port paired with every resolved address.
        port: u16,
        /// Minimum time between two resolutions.
        refresh_interval: Duration,
    },
    /// DNS SRV style service name.
    ///
    /// Registration currently strips the leading `_`-prefixed labels and falls
    /// back to plain DNS resolution on port 80.
    DnsSrv {
        /// SRV service name.
        service: String,
        /// Minimum time between two resolutions.
        refresh_interval: Duration,
    },
    /// Discovery through the Consul health endpoint.
    Consul {
        /// Base address of the Consul HTTP API.
        address: String,
        /// Name of the service to look up.
        service: String,
        /// Optional datacenter, sent as the `dc` query parameter.
        datacenter: Option<String>,
        /// When `true`, `passing=true` is added to the query.
        only_passing: bool,
        /// Minimum time between two Consul queries.
        refresh_interval: Duration,
        /// Optional tag, sent as the `tag` query parameter.
        tag: Option<String>,
    },
    /// Discovery through the Kubernetes Endpoints API.
    Kubernetes {
        /// Namespace containing the service.
        namespace: String,
        /// Name of the service whose endpoints are fetched.
        service: String,
        /// Endpoint port name to select; the first listed port is used when `None`.
        port_name: Option<String>,
        /// Minimum time between two API queries.
        refresh_interval: Duration,
        /// Path to a kubeconfig file; when `None`, in-cluster configuration is
        /// tried first.
        kubeconfig: Option<String>,
    },
    /// Backends read from a text file that is re-read when it changes.
    File {
        /// Path of the backends file.
        path: String,
        /// Minimum time between two checks of the file.
        watch_interval: Duration,
    },
}
115
116impl Default for DiscoveryConfig {
117 fn default() -> Self {
118 Self::Static {
119 backends: vec!["127.0.0.1:8080".to_string()],
120 }
121 }
122}
123
/// DNS-based service discovery that resolves `hostname:port` and caches the
/// resulting backends between refreshes.
pub struct DnsDiscovery {
    /// Hostname to resolve.
    hostname: String,
    /// Port paired with every resolved address.
    port: u16,
    /// Minimum time between two resolutions.
    refresh_interval: Duration,
    /// Backends produced by the most recent successful resolution.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent successful resolution.
    last_resolution: RwLock<Instant>,
}
136
137impl DnsDiscovery {
138 pub fn new(hostname: String, port: u16, refresh_interval: Duration) -> Self {
140 Self {
141 hostname,
142 port,
143 refresh_interval,
144 cached_backends: RwLock::new(BTreeSet::new()),
145 last_resolution: RwLock::new(Instant::now() - refresh_interval),
146 }
147 }
148
149 fn resolve(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
151 let address = format!("{}:{}", self.hostname, self.port);
152
153 trace!(
154 hostname = %self.hostname,
155 port = self.port,
156 "Resolving DNS for service discovery"
157 );
158
159 match address.to_socket_addrs() {
160 Ok(addrs) => {
161 let backends: BTreeSet<Backend> = addrs
162 .map(|addr| Backend {
163 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
164 weight: 1,
165 ext: http::Extensions::new(),
166 })
167 .collect();
168
169 debug!(
170 hostname = %self.hostname,
171 backend_count = backends.len(),
172 "DNS resolution successful"
173 );
174
175 Ok(backends)
176 }
177 Err(e) => {
178 error!(
179 hostname = %self.hostname,
180 error = %e,
181 "DNS resolution failed"
182 );
183 Err(Error::explain(
184 ErrorType::ConnectNoRoute,
185 format!("DNS resolution failed for {}: {}", self.hostname, e),
186 ))
187 }
188 }
189 }
190
191 fn needs_refresh(&self) -> bool {
193 let last = *self.last_resolution.read();
194 last.elapsed() >= self.refresh_interval
195 }
196}
197
198#[async_trait]
199impl ServiceDiscovery for DnsDiscovery {
200 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
201 if self.needs_refresh() {
203 match self.resolve() {
204 Ok(backends) => {
205 *self.cached_backends.write() = backends;
206 *self.last_resolution.write() = Instant::now();
207 }
208 Err(e) => {
209 let cached = self.cached_backends.read().clone();
211 if !cached.is_empty() {
212 warn!(
213 hostname = %self.hostname,
214 error = %e,
215 cached_count = cached.len(),
216 "DNS resolution failed, using cached backends"
217 );
218 return Ok((cached, HashMap::new()));
219 }
220 return Err(e);
221 }
222 }
223 }
224
225 let backends = self.cached_backends.read().clone();
226 Ok((backends, HashMap::new()))
227 }
228}
229
/// Service discovery backed by the Consul health endpoint
/// (`/v1/health/service/<service>`), with cached results between refreshes.
pub struct ConsulDiscovery {
    /// Base address of the Consul HTTP API.
    address: String,
    /// Name of the service to look up.
    service: String,
    /// Optional datacenter, sent as the `dc` query parameter.
    datacenter: Option<String>,
    /// When `true`, `passing=true` is added to the query.
    only_passing: bool,
    /// Minimum time between two Consul queries.
    refresh_interval: Duration,
    /// Optional tag, sent as the `tag` query parameter.
    tag: Option<String>,
    /// Backends produced by the most recent successful query.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent successful query.
    last_resolution: RwLock<Instant>,
}
255
256impl ConsulDiscovery {
257 pub fn new(
259 address: String,
260 service: String,
261 datacenter: Option<String>,
262 only_passing: bool,
263 refresh_interval: Duration,
264 tag: Option<String>,
265 ) -> Self {
266 Self {
267 address,
268 service,
269 datacenter,
270 only_passing,
271 refresh_interval,
272 tag,
273 cached_backends: RwLock::new(BTreeSet::new()),
274 last_resolution: RwLock::new(Instant::now() - refresh_interval),
275 }
276 }
277
278 fn build_url(&self) -> String {
280 let mut url = format!(
281 "{}/v1/health/service/{}",
282 self.address.trim_end_matches('/'),
283 self.service
284 );
285
286 let mut params = Vec::new();
287 if self.only_passing {
288 params.push("passing=true".to_string());
289 }
290 if let Some(dc) = &self.datacenter {
291 params.push(format!("dc={}", dc));
292 }
293 if let Some(tag) = &self.tag {
294 params.push(format!("tag={}", tag));
295 }
296
297 if !params.is_empty() {
298 url.push('?');
299 url.push_str(¶ms.join("&"));
300 }
301
302 url
303 }
304
305 fn needs_refresh(&self) -> bool {
307 let last = *self.last_resolution.read();
308 last.elapsed() >= self.refresh_interval
309 }
310}
311
312#[async_trait]
313impl ServiceDiscovery for ConsulDiscovery {
314 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
315 if !self.needs_refresh() {
316 let backends = self.cached_backends.read().clone();
317 return Ok((backends, HashMap::new()));
318 }
319
320 let url = self.build_url();
321 trace!(
322 service = %self.service,
323 url = %url,
324 "Querying Consul for service discovery"
325 );
326
327 let result = tokio::task::spawn_blocking({
330 let url = url.clone();
331 let service = self.service.clone();
332 move || -> Result<BTreeSet<Backend>, Box<Error>> {
333 let url_parsed = url
336 .trim_start_matches("http://")
337 .trim_start_matches("https://");
338 let (host_port, path) = url_parsed.split_once('/').unwrap_or((url_parsed, ""));
339
340 let socket_addr = host_port
341 .to_socket_addrs()
342 .map_err(|e| {
343 Error::explain(
344 ErrorType::ConnectNoRoute,
345 format!("Failed to resolve Consul address: {}", e),
346 )
347 })?
348 .next()
349 .ok_or_else(|| {
350 Error::explain(
351 ErrorType::ConnectNoRoute,
352 "Failed to resolve Consul address",
353 )
354 })?;
355
356 let stream = match std::net::TcpStream::connect_timeout(
357 &socket_addr,
358 Duration::from_secs(5),
359 ) {
360 Ok(s) => s,
361 Err(e) => {
362 return Err(Error::explain(
363 ErrorType::ConnectTimedout,
364 format!("Failed to connect to Consul: {}", e),
365 ));
366 }
367 };
368
369 stream
370 .set_read_timeout(Some(Duration::from_secs(10)))
371 .map_err(|e| {
372 Error::explain(
373 ErrorType::InternalError,
374 format!("Failed to set read timeout: {}", e),
375 )
376 })?;
377 stream
378 .set_write_timeout(Some(Duration::from_secs(5)))
379 .map_err(|e| {
380 Error::explain(
381 ErrorType::InternalError,
382 format!("Failed to set write timeout: {}", e),
383 )
384 })?;
385
386 use std::io::{Read, Write};
387 let request = format!(
388 "GET /{} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
389 path, host_port
390 );
391
392 let mut stream = stream;
393 stream.write_all(request.as_bytes()).map_err(|e| {
394 Error::explain(
395 ErrorType::WriteError,
396 format!("Failed to send request: {}", e),
397 )
398 })?;
399
400 let mut response = String::new();
401 stream.read_to_string(&mut response).map_err(|e| {
402 Error::explain(
403 ErrorType::ReadError,
404 format!("Failed to read response: {}", e),
405 )
406 })?;
407
408 let body = response.split("\r\n\r\n").nth(1).unwrap_or("");
410
411 let backends = parse_consul_response(body, &service)?;
414
415 Ok(backends)
416 }
417 })
418 .await
419 .map_err(|e| Error::explain(ErrorType::InternalError, format!("Task failed: {}", e)))?;
420
421 match result {
422 Ok(backends) => {
423 info!(
424 service = %self.service,
425 backend_count = backends.len(),
426 "Consul discovery successful"
427 );
428 *self.cached_backends.write() = backends.clone();
429 *self.last_resolution.write() = Instant::now();
430 Ok((backends, HashMap::new()))
431 }
432 Err(e) => {
433 let cached = self.cached_backends.read().clone();
434 if !cached.is_empty() {
435 warn!(
436 service = %self.service,
437 error = %e,
438 cached_count = cached.len(),
439 "Consul query failed, using cached backends"
440 );
441 return Ok((cached, HashMap::new()));
442 }
443 Err(e)
444 }
445 }
446 }
447}
448
449fn parse_consul_response(body: &str, service_name: &str) -> Result<BTreeSet<Backend>, Box<Error>> {
451 let mut backends = BTreeSet::new();
454
455 let entries: Vec<&str> = body.split(r#""Service":"#).skip(1).collect();
457
458 for entry in entries {
459 let port = entry
461 .split(r#""Port":"#)
462 .nth(1)
463 .and_then(|s| s.split(|c: char| !c.is_ascii_digit()).next())
464 .and_then(|s| s.parse::<u16>().ok());
465
466 let service_addr = entry
468 .split(r#""Address":""#)
469 .nth(1)
470 .and_then(|s| s.split('"').next())
471 .filter(|s| !s.is_empty());
472
473 let node_addr = body
475 .split(r#""Node":"#)
476 .nth(1)
477 .and_then(|s| s.split(r#""Address":""#).nth(1))
478 .and_then(|s| s.split('"').next());
479
480 let address = service_addr.or(node_addr);
481
482 if let (Some(addr), Some(port)) = (address, port) {
483 let full_addr = format!("{}:{}", addr, port);
484 if let Ok(mut addrs) = full_addr.to_socket_addrs() {
485 if let Some(socket_addr) = addrs.next() {
486 backends.insert(Backend {
487 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(socket_addr),
488 weight: 1,
489 ext: http::Extensions::new(),
490 });
491 }
492 }
493 }
494 }
495
496 if backends.is_empty() && !body.starts_with("[]") && !body.is_empty() {
497 warn!(
498 service = %service_name,
499 body_len = body.len(),
500 "Failed to parse Consul response, no backends found"
501 );
502 }
503
504 Ok(backends)
505}
506
507#[cfg(feature = "kubernetes")]
512use crate::kubeconfig::{KubeAuth, Kubeconfig, ResolvedKubeConfig};
513
/// Service discovery backed by the Kubernetes Endpoints API, with cached
/// results between refreshes.
pub struct KubernetesDiscovery {
    /// Namespace containing the service.
    namespace: String,
    /// Name of the service whose endpoints are fetched.
    service: String,
    /// Endpoint port name to select; the first listed port is used when `None`.
    port_name: Option<String>,
    /// Minimum time between two API queries.
    refresh_interval: Duration,
    /// Path to a kubeconfig file, if one was configured.
    kubeconfig: Option<String>,
    /// Backends produced by the most recent successful query.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent successful query.
    last_resolution: RwLock<Instant>,
    /// Kubeconfig resolved on first use and reused afterwards.
    #[cfg(feature = "kubernetes")]
    resolved_config: RwLock<Option<ResolvedKubeConfig>>,
}
556
557impl KubernetesDiscovery {
558 pub fn new(
560 namespace: String,
561 service: String,
562 port_name: Option<String>,
563 refresh_interval: Duration,
564 kubeconfig: Option<String>,
565 ) -> Self {
566 Self {
567 namespace,
568 service,
569 port_name,
570 refresh_interval,
571 kubeconfig,
572 cached_backends: RwLock::new(BTreeSet::new()),
573 last_resolution: RwLock::new(Instant::now() - refresh_interval),
574 #[cfg(feature = "kubernetes")]
575 resolved_config: RwLock::new(None),
576 }
577 }
578
579 fn needs_refresh(&self) -> bool {
581 let last = *self.last_resolution.read();
582 last.elapsed() >= self.refresh_interval
583 }
584
585 fn get_in_cluster_config(&self) -> Result<(String, String), Box<Error>> {
587 let host = std::env::var("KUBERNETES_SERVICE_HOST").map_err(|_| {
588 Error::explain(
589 ErrorType::InternalError,
590 "KUBERNETES_SERVICE_HOST not set, not running in Kubernetes?",
591 )
592 })?;
593 let port = std::env::var("KUBERNETES_SERVICE_PORT").unwrap_or_else(|_| "443".to_string());
594 let token = std::fs::read_to_string("/var/run/secrets/kubernetes.io/serviceaccount/token")
595 .map_err(|e| {
596 Error::explain(
597 ErrorType::InternalError,
598 format!("Failed to read service account token: {}", e),
599 )
600 })?;
601
602 Ok((
603 format!("https://{}:{}", host, port),
604 token.trim().to_string(),
605 ))
606 }
607
608 #[cfg(feature = "kubernetes")]
610 fn load_kubeconfig(&self) -> Result<ResolvedKubeConfig, Box<Error>> {
611 if let Some(config) = self.resolved_config.read().as_ref() {
613 return Ok(config.clone());
614 }
615
616 let kubeconfig = if let Some(path) = &self.kubeconfig {
617 Kubeconfig::from_file(path).map_err(|e| {
618 Error::explain(
619 ErrorType::InternalError,
620 format!("Failed to load kubeconfig from {}: {}", path, e),
621 )
622 })?
623 } else {
624 Kubeconfig::from_default_location().map_err(|e| {
625 Error::explain(
626 ErrorType::InternalError,
627 format!("Failed to load kubeconfig from default location: {}", e),
628 )
629 })?
630 };
631
632 let resolved = kubeconfig.resolve_current().map_err(|e| {
633 Error::explain(
634 ErrorType::InternalError,
635 format!("Failed to resolve kubeconfig context: {}", e),
636 )
637 })?;
638
639 *self.resolved_config.write() = Some(resolved.clone());
641
642 Ok(resolved)
643 }
644}
645
/// Deserialization targets for the Kubernetes Endpoints API response.
///
/// Only the fields declared here are deserialized.
#[cfg(feature = "kubernetes")]
mod k8s_types {
    use serde::Deserialize;

    /// Top-level Endpoints object.
    #[derive(Debug, Deserialize)]
    pub struct Endpoints {
        /// Address/port groups; may be absent.
        pub subsets: Option<Vec<EndpointSubset>>,
    }

    /// One group of addresses sharing the same set of ports.
    #[derive(Debug, Deserialize)]
    pub struct EndpointSubset {
        /// Addresses in this subset; may be absent.
        pub addresses: Option<Vec<EndpointAddress>>,
        /// Ports exposed by the addresses of this subset; may be absent.
        pub ports: Option<Vec<EndpointPort>>,
    }

    /// A single endpoint address.
    #[derive(Debug, Deserialize)]
    pub struct EndpointAddress {
        /// IP address of the endpoint.
        pub ip: String,
        /// Optional hostname of the endpoint.
        pub hostname: Option<String>,
    }

    /// A single endpoint port.
    #[derive(Debug, Deserialize)]
    pub struct EndpointPort {
        /// Optional port name, matched against the configured port name.
        pub name: Option<String>,
        /// Port number.
        pub port: u16,
        /// Optional protocol of the port.
        pub protocol: Option<String>,
    }
}
675
#[cfg(feature = "kubernetes")]
impl KubernetesDiscovery {
    /// Queries the Kubernetes API for the service's Endpoints object and
    /// converts its addresses into backends with weight 1.
    ///
    /// Connection settings come from the configured kubeconfig when one is
    /// set. Otherwise the in-cluster configuration is tried first and the
    /// default kubeconfig is the fallback.
    ///
    /// # Errors
    ///
    /// Returns an error when no configuration can be loaded, the HTTP client
    /// cannot be built, the request fails, the API answers with a non-success
    /// status, or the response cannot be deserialized.
    async fn fetch_endpoints(&self) -> Result<BTreeSet<Backend>> {
        let (api_server, auth, ca_cert, skip_verify) = if self.kubeconfig.is_some() {
            let config = self.load_kubeconfig()?;
            (
                config.server,
                config.auth,
                config.ca_cert,
                config.insecure_skip_tls_verify,
            )
        } else {
            match self.get_in_cluster_config() {
                Ok((server, token)) => {
                    // The cluster CA is optional: without it, the client's
                    // default trust roots are used.
                    let ca =
                        std::fs::read("/var/run/secrets/kubernetes.io/serviceaccount/ca.crt").ok();
                    (server, KubeAuth::Token(token), ca, false)
                }
                Err(e) => {
                    debug!(
                        error = %e,
                        "In-cluster config not available, trying default kubeconfig"
                    );
                    let config = self.load_kubeconfig()?;
                    (
                        config.server,
                        config.auth,
                        config.ca_cert,
                        config.insecure_skip_tls_verify,
                    )
                }
            }
        };

        let url = format!(
            "{}/api/v1/namespaces/{}/endpoints/{}",
            api_server.trim_end_matches('/'),
            self.namespace,
            self.service
        );

        debug!(
            url = %url,
            namespace = %self.namespace,
            service = %self.service,
            "Fetching Kubernetes endpoints"
        );

        let client_builder = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .danger_accept_invalid_certs(skip_verify);

        let client_builder = if let Some(ca_data) = ca_cert {
            let cert = reqwest::Certificate::from_pem(&ca_data).map_err(|e| {
                Error::explain(
                    ErrorType::InternalError,
                    format!("Failed to parse CA certificate: {}", e),
                )
            })?;
            client_builder.add_root_certificate(cert)
        } else {
            client_builder
        };

        let client_builder = match &auth {
            KubeAuth::ClientCert { cert, key } => {
                // The identity is built from the certificate and key
                // concatenated into one PEM buffer.
                let mut identity_pem = cert.clone();
                identity_pem.extend_from_slice(key);
                let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
                    Error::explain(
                        ErrorType::InternalError,
                        format!("Failed to create client identity: {}", e),
                    )
                })?;
                client_builder.identity(identity)
            }
            _ => client_builder,
        };

        let client = client_builder.build().map_err(|e| {
            Error::explain(
                ErrorType::InternalError,
                format!("Failed to create HTTP client: {}", e),
            )
        })?;

        let mut request = client.get(&url);
        if let KubeAuth::Token(token) = &auth {
            request = request.bearer_auth(token);
        }

        let response = request.send().await.map_err(|e| {
            Error::explain(
                ErrorType::ConnectError,
                format!("Failed to connect to Kubernetes API: {}", e),
            )
        })?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(Error::explain(
                ErrorType::HTTPStatus(status.as_u16()),
                format!("Kubernetes API returned {}: {}", status, body),
            ));
        }

        let endpoints: k8s_types::Endpoints = response.json().await.map_err(|e| {
            Error::explain(
                ErrorType::InternalError,
                format!("Failed to parse Kubernetes endpoints: {}", e),
            )
        })?;

        let mut backends = BTreeSet::new();
        for subset in endpoints.subsets.into_iter().flatten() {
            // Pick the named port when one is configured, else the first port.
            let target_port = subset.ports.as_ref().and_then(|ports| {
                if let Some(port_name) = &self.port_name {
                    ports
                        .iter()
                        .find(|p| p.name.as_ref() == Some(port_name))
                        .map(|p| p.port)
                } else {
                    ports.first().map(|p| p.port)
                }
            });

            if let (Some(addresses), Some(port)) = (subset.addresses, target_port) {
                for addr in addresses {
                    // Endpoint addresses are IP literals: parse them directly
                    // instead of calling the blocking resolver from async
                    // code. Anything else keeps the old resolver path.
                    let resolved = match addr.ip.parse::<std::net::IpAddr>() {
                        Ok(ip) => Some(std::net::SocketAddr::new(ip, port)),
                        Err(_) => format!("{}:{}", addr.ip, port)
                            .to_socket_addrs()
                            .ok()
                            .and_then(|mut addrs| addrs.next()),
                    };

                    if let Some(socket_addr) = resolved {
                        backends.insert(Backend {
                            addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(
                                socket_addr,
                            ),
                            weight: 1,
                            ext: http::Extensions::new(),
                        });
                    }
                }
            }
        }

        Ok(backends)
    }
}

#[cfg(feature = "kubernetes")]
#[async_trait]
impl ServiceDiscovery for KubernetesDiscovery {
    /// Returns the backends behind the configured Kubernetes service.
    ///
    /// Cached backends are returned while the refresh interval has not
    /// elapsed. A failed query falls back to the cached backends when there
    /// are any, matching the other discovery types; with an empty cache the
    /// error is returned.
    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
        if !self.needs_refresh() {
            let backends = self.cached_backends.read().clone();
            return Ok((backends, HashMap::new()));
        }

        trace!(
            namespace = %self.namespace,
            service = %self.service,
            "Querying Kubernetes for endpoint discovery"
        );

        match self.fetch_endpoints().await {
            Ok(backends) => {
                info!(
                    service = %self.service,
                    namespace = %self.namespace,
                    backend_count = backends.len(),
                    "Kubernetes endpoint discovery successful"
                );

                *self.cached_backends.write() = backends.clone();
                *self.last_resolution.write() = Instant::now();

                Ok((backends, HashMap::new()))
            }
            Err(e) => {
                let cached = self.cached_backends.read().clone();
                if !cached.is_empty() {
                    warn!(
                        service = %self.service,
                        namespace = %self.namespace,
                        error = %e,
                        cached_count = cached.len(),
                        "Kubernetes query failed, using cached backends"
                    );
                    return Ok((cached, HashMap::new()));
                }
                Err(e)
            }
        }
    }
}
865
866#[cfg(not(feature = "kubernetes"))]
868#[async_trait]
869impl ServiceDiscovery for KubernetesDiscovery {
870 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
871 if !self.needs_refresh() {
872 let backends = self.cached_backends.read().clone();
873 return Ok((backends, HashMap::new()));
874 }
875
876 if self.kubeconfig.is_none() {
878 if let Ok((_, _)) = self.get_in_cluster_config() {
879 warn!(
880 service = %self.service,
881 "Kubernetes discovery requires 'kubernetes' feature flag for full support"
882 );
883 }
884 } else {
885 warn!(
886 service = %self.service,
887 kubeconfig = ?self.kubeconfig,
888 "Kubeconfig support requires 'kubernetes' feature flag"
889 );
890 }
891
892 let cached = self.cached_backends.read().clone();
893 Ok((cached, HashMap::new()))
894 }
895}
896
/// File-based service discovery that reads backends from a text file and
/// re-reads the file when its modification time changes.
pub struct FileDiscovery {
    /// Path of the backends file.
    path: String,
    /// Minimum time between two checks of the file.
    watch_interval: Duration,
    /// Backends produced by the most recent successful read.
    cached_backends: RwLock<BTreeSet<Backend>>,
    /// Time of the most recent check.
    last_check: RwLock<Instant>,
    /// Modification time recorded at the most recent successful read, if any.
    last_modified: RwLock<Option<std::time::SystemTime>>,
}
941
942impl FileDiscovery {
943 pub fn new(path: String, watch_interval: Duration) -> Self {
945 Self {
946 path,
947 watch_interval,
948 cached_backends: RwLock::new(BTreeSet::new()),
949 last_check: RwLock::new(Instant::now() - watch_interval),
950 last_modified: RwLock::new(None),
951 }
952 }
953
954 fn needs_check(&self) -> bool {
956 let last = *self.last_check.read();
957 last.elapsed() >= self.watch_interval
958 }
959
960 fn file_modified(&self) -> bool {
962 let metadata = match std::fs::metadata(&self.path) {
963 Ok(m) => m,
964 Err(_) => return true, };
966
967 let modified = match metadata.modified() {
968 Ok(m) => m,
969 Err(_) => return true,
970 };
971
972 let last_known = *self.last_modified.read();
973 match last_known {
974 Some(last) => modified > last,
975 None => true, }
977 }
978
979 fn read_backends(&self) -> Result<BTreeSet<Backend>, Box<Error>> {
981 trace!(path = %self.path, "Reading backends from file");
982
983 let content = std::fs::read_to_string(&self.path).map_err(|e| {
984 Error::explain(
985 ErrorType::ReadError,
986 format!("Failed to read backends file '{}': {}", self.path, e),
987 )
988 })?;
989
990 if let Ok(metadata) = std::fs::metadata(&self.path) {
992 if let Ok(modified) = metadata.modified() {
993 *self.last_modified.write() = Some(modified);
994 }
995 }
996
997 let mut backends = BTreeSet::new();
998 let mut line_num = 0;
999
1000 for line in content.lines() {
1001 line_num += 1;
1002 let line = line.trim();
1003
1004 if line.is_empty() || line.starts_with('#') {
1006 continue;
1007 }
1008
1009 let (address, weight) = Self::parse_backend_line(line, line_num)?;
1011
1012 match address.to_socket_addrs() {
1014 Ok(mut addrs) => {
1015 if let Some(socket_addr) = addrs.next() {
1016 backends.insert(Backend {
1017 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(
1018 socket_addr,
1019 ),
1020 weight,
1021 ext: http::Extensions::new(),
1022 });
1023 trace!(
1024 address = %address,
1025 weight = weight,
1026 "Added backend from file"
1027 );
1028 } else {
1029 warn!(
1030 path = %self.path,
1031 line = line_num,
1032 address = %address,
1033 "Address resolved but no socket address found"
1034 );
1035 }
1036 }
1037 Err(e) => {
1038 warn!(
1039 path = %self.path,
1040 line = line_num,
1041 address = %address,
1042 error = %e,
1043 "Failed to resolve backend address, skipping"
1044 );
1045 }
1046 }
1047 }
1048
1049 debug!(
1050 path = %self.path,
1051 backend_count = backends.len(),
1052 "Loaded backends from file"
1053 );
1054
1055 Ok(backends)
1056 }
1057
1058 fn parse_backend_line(line: &str, line_num: usize) -> Result<(String, usize), Box<Error>> {
1062 let parts: Vec<&str> = line.split_whitespace().collect();
1063
1064 if parts.is_empty() {
1065 return Err(Error::explain(
1066 ErrorType::InternalError,
1067 format!("Empty backend line at line {}", line_num),
1068 ));
1069 }
1070
1071 let address = parts[0].to_string();
1072 let mut weight = 1usize;
1073
1074 for part in parts.iter().skip(1) {
1076 if let Some(weight_str) = part.strip_prefix("weight=") {
1077 weight = weight_str.parse().unwrap_or_else(|_| {
1078 warn!(
1079 line = line_num,
1080 weight = weight_str,
1081 "Invalid weight value, using default 1"
1082 );
1083 1
1084 });
1085 }
1086 }
1087
1088 Ok((address, weight))
1089 }
1090}
1091
1092#[async_trait]
1093impl ServiceDiscovery for FileDiscovery {
1094 async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
1095 if self.needs_check() {
1097 *self.last_check.write() = Instant::now();
1098
1099 if self.file_modified() {
1101 match self.read_backends() {
1102 Ok(backends) => {
1103 info!(
1104 path = %self.path,
1105 backend_count = backends.len(),
1106 "File-based discovery updated backends"
1107 );
1108 *self.cached_backends.write() = backends;
1109 }
1110 Err(e) => {
1111 let cached = self.cached_backends.read().clone();
1113 if !cached.is_empty() {
1114 warn!(
1115 path = %self.path,
1116 error = %e,
1117 cached_count = cached.len(),
1118 "File read failed, using cached backends"
1119 );
1120 return Ok((cached, HashMap::new()));
1121 }
1122 return Err(e);
1123 }
1124 }
1125 }
1126 }
1127
1128 let backends = self.cached_backends.read().clone();
1129 Ok((backends, HashMap::new()))
1130 }
1131}
1132
/// Registry of service discovery instances keyed by upstream id.
pub struct DiscoveryManager {
    /// Registered discoveries, indexed by upstream id.
    discoveries: RwLock<HashMap<String, Arc<dyn ServiceDiscovery + Send + Sync>>>,
}
1145
1146impl DiscoveryManager {
1147 pub fn new() -> Self {
1149 Self {
1150 discoveries: RwLock::new(HashMap::new()),
1151 }
1152 }
1153
1154 pub fn register(&self, upstream_id: &str, config: DiscoveryConfig) -> Result<(), Box<Error>> {
1156 let discovery: Arc<dyn ServiceDiscovery + Send + Sync> = match config {
1157 DiscoveryConfig::Static { backends } => {
1158 let backend_set = backends
1159 .iter()
1160 .filter_map(|addr| {
1161 addr.to_socket_addrs()
1162 .ok()
1163 .and_then(|mut addrs| addrs.next())
1164 .map(|addr| Backend {
1165 addr: pingora_core::protocols::l4::socket::SocketAddr::Inet(addr),
1166 weight: 1,
1167 ext: http::Extensions::new(),
1168 })
1169 })
1170 .collect();
1171
1172 info!(
1173 upstream_id = %upstream_id,
1174 backend_count = backends.len(),
1175 "Registered static service discovery"
1176 );
1177
1178 Arc::new(StaticWrapper(StaticDiscovery::new(backend_set)))
1179 }
1180 DiscoveryConfig::Dns {
1181 hostname,
1182 port,
1183 refresh_interval,
1184 } => {
1185 info!(
1186 upstream_id = %upstream_id,
1187 hostname = %hostname,
1188 port = port,
1189 refresh_interval_secs = refresh_interval.as_secs(),
1190 "Registered DNS service discovery"
1191 );
1192
1193 Arc::new(DnsDiscovery::new(hostname, port, refresh_interval))
1194 }
1195 DiscoveryConfig::DnsSrv {
1196 service,
1197 refresh_interval,
1198 } => {
1199 info!(
1200 upstream_id = %upstream_id,
1201 service = %service,
1202 refresh_interval_secs = refresh_interval.as_secs(),
1203 "DNS SRV discovery not yet fully implemented, using DNS A record fallback"
1204 );
1205
1206 let hostname = service
1209 .split('.')
1210 .skip_while(|s| s.starts_with('_'))
1211 .collect::<Vec<_>>()
1212 .join(".");
1213 Arc::new(DnsDiscovery::new(hostname, 80, refresh_interval))
1214 }
1215 DiscoveryConfig::Consul {
1216 address,
1217 service,
1218 datacenter,
1219 only_passing,
1220 refresh_interval,
1221 tag,
1222 } => {
1223 info!(
1224 upstream_id = %upstream_id,
1225 address = %address,
1226 service = %service,
1227 datacenter = datacenter.as_deref().unwrap_or("default"),
1228 only_passing = only_passing,
1229 refresh_interval_secs = refresh_interval.as_secs(),
1230 "Registered Consul service discovery"
1231 );
1232
1233 Arc::new(ConsulDiscovery::new(
1234 address,
1235 service,
1236 datacenter,
1237 only_passing,
1238 refresh_interval,
1239 tag,
1240 ))
1241 }
1242 DiscoveryConfig::Kubernetes {
1243 namespace,
1244 service,
1245 port_name,
1246 refresh_interval,
1247 kubeconfig,
1248 } => {
1249 info!(
1250 upstream_id = %upstream_id,
1251 namespace = %namespace,
1252 service = %service,
1253 port_name = port_name.as_deref().unwrap_or("default"),
1254 refresh_interval_secs = refresh_interval.as_secs(),
1255 "Registered Kubernetes endpoint discovery"
1256 );
1257
1258 Arc::new(KubernetesDiscovery::new(
1259 namespace,
1260 service,
1261 port_name,
1262 refresh_interval,
1263 kubeconfig,
1264 ))
1265 }
1266 DiscoveryConfig::File {
1267 path,
1268 watch_interval,
1269 } => {
1270 info!(
1271 upstream_id = %upstream_id,
1272 path = %path,
1273 watch_interval_secs = watch_interval.as_secs(),
1274 "Registered file-based service discovery"
1275 );
1276
1277 Arc::new(FileDiscovery::new(path, watch_interval))
1278 }
1279 };
1280
1281 self.discoveries
1282 .write()
1283 .insert(upstream_id.to_string(), discovery);
1284 Ok(())
1285 }
1286
1287 pub fn get(&self, upstream_id: &str) -> Option<Arc<dyn ServiceDiscovery + Send + Sync>> {
1289 self.discoveries.read().get(upstream_id).cloned()
1290 }
1291
1292 pub async fn discover(
1294 &self,
1295 upstream_id: &str,
1296 ) -> Option<Result<(BTreeSet<Backend>, HashMap<u64, bool>)>> {
1297 let discovery = self.get(upstream_id)?;
1298 Some(discovery.discover().await)
1299 }
1300
1301 pub fn remove(&self, upstream_id: &str) {
1303 self.discoveries.write().remove(upstream_id);
1304 }
1305
1306 pub fn count(&self) -> usize {
1308 self.discoveries.read().len()
1309 }
1310}
1311
1312impl Default for DiscoveryManager {
1313 fn default() -> Self {
1314 Self::new()
1315 }
1316}
1317
/// Newtype around the boxed static discovery so it can be stored as an
/// `Arc<dyn ServiceDiscovery + Send + Sync>` alongside the other discoveries.
struct StaticWrapper(Box<StaticDiscovery>);

#[async_trait]
impl ServiceDiscovery for StaticWrapper {
    /// Delegates to the wrapped static discovery.
    async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
        self.0.discover().await
    }
}

// SAFETY: NOTE(review): no justification is visible in this file. These impls
// are sound only if the wrapped `StaticDiscovery` can be moved to and shared
// between threads; if that type is already `Send + Sync`, the impls are
// redundant and should be removed. TODO confirm against pingora_load_balancing.
unsafe impl Send for StaticWrapper {}
unsafe impl Sync for StaticWrapper {}
1331
/// Unit tests for the discovery types and the discovery manager.
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is one static loopback backend.
    #[test]
    fn test_discovery_config_default() {
        let config = DiscoveryConfig::default();
        match config {
            DiscoveryConfig::Static { backends } => {
                assert_eq!(backends.len(), 1);
                assert_eq!(backends[0], "127.0.0.1:8080");
            }
            _ => panic!("Expected Static config"),
        }
    }

    /// A registered static discovery returns all of its backends.
    #[tokio::test]
    async fn test_discovery_manager() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "test-upstream",
                DiscoveryConfig::Static {
                    backends: vec!["127.0.0.1:8080".to_string(), "127.0.0.1:8081".to_string()],
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);

        let result = manager.discover("test-upstream").await;
        assert!(result.is_some());
        let (backends, _) = result.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }

    /// A freshly created DNS discovery is due for a refresh.
    #[test]
    fn test_dns_discovery_needs_refresh() {
        let discovery = DnsDiscovery::new(
            "localhost".to_string(),
            8080,
            // Zero interval: a refresh is always due.
            Duration::from_secs(0),
        );

        assert!(discovery.needs_refresh());
    }

    /// All optional query parameters end up in the Consul URL.
    #[test]
    fn test_consul_discovery_url_building() {
        let discovery = ConsulDiscovery::new(
            "http://localhost:8500".to_string(),
            "my-service".to_string(),
            Some("dc1".to_string()),
            true,
            Duration::from_secs(10),
            Some("production".to_string()),
        );

        let url = discovery.build_url();
        assert!(url.starts_with("http://localhost:8500/v1/health/service/my-service"));
        assert!(url.contains("passing=true"));
        assert!(url.contains("dc=dc1"));
        assert!(url.contains("tag=production"));
    }

    /// Without options the Consul URL carries no query string.
    #[test]
    fn test_consul_discovery_url_minimal() {
        let discovery = ConsulDiscovery::new(
            "http://consul.local:8500".to_string(),
            "backend".to_string(),
            None,
            false,
            Duration::from_secs(30),
            None,
        );

        let url = discovery.build_url();
        assert_eq!(url, "http://consul.local:8500/v1/health/service/backend");
    }

    /// A freshly created Kubernetes discovery is due for a refresh.
    #[test]
    fn test_kubernetes_discovery_config() {
        let discovery = KubernetesDiscovery::new(
            "default".to_string(),
            "my-service".to_string(),
            Some("http".to_string()),
            Duration::from_secs(10),
            None,
        );

        assert!(discovery.needs_refresh());
    }

    /// An empty JSON array yields no backends.
    #[test]
    fn test_parse_consul_response_empty() {
        let body = "[]";
        let backends = parse_consul_response(body, "test").unwrap();
        assert!(backends.is_empty());
    }

    /// Registering a Consul discovery does not require a reachable Consul.
    #[tokio::test]
    async fn test_discovery_manager_consul() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "consul-upstream",
                DiscoveryConfig::Consul {
                    address: "http://localhost:8500".to_string(),
                    service: "my-service".to_string(),
                    datacenter: Some("dc1".to_string()),
                    only_passing: true,
                    refresh_interval: Duration::from_secs(10),
                    tag: None,
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("consul-upstream").is_some());
    }

    /// Registering a Kubernetes discovery does not require a reachable cluster.
    #[tokio::test]
    async fn test_discovery_manager_kubernetes() {
        let manager = DiscoveryManager::new();

        manager
            .register(
                "k8s-upstream",
                DiscoveryConfig::Kubernetes {
                    namespace: "production".to_string(),
                    service: "api-server".to_string(),
                    port_name: Some("http".to_string()),
                    refresh_interval: Duration::from_secs(15),
                    kubeconfig: None,
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("k8s-upstream").is_some());
    }

    /// A bare address gets the default weight of 1.
    #[test]
    fn test_file_discovery_parse_backend_line_simple() {
        let (address, weight) = FileDiscovery::parse_backend_line("127.0.0.1:8080", 1).unwrap();
        assert_eq!(address, "127.0.0.1:8080");
        assert_eq!(weight, 1);
    }

    /// A `weight=N` token sets the weight.
    #[test]
    fn test_file_discovery_parse_backend_line_with_weight() {
        let (address, weight) =
            FileDiscovery::parse_backend_line("10.0.0.1:8080 weight=5", 1).unwrap();
        assert_eq!(address, "10.0.0.1:8080");
        assert_eq!(weight, 5);
    }

    /// Hostnames are kept verbatim by the line parser.
    #[test]
    fn test_file_discovery_parse_backend_line_hostname() {
        let (address, weight) =
            FileDiscovery::parse_backend_line("backend.example.com:443 weight=2", 1).unwrap();
        assert_eq!(address, "backend.example.com:443");
        assert_eq!(weight, 2);
    }

    /// A freshly created file discovery is due for a check.
    #[test]
    fn test_file_discovery_needs_check() {
        let discovery = FileDiscovery::new(
            "/nonexistent/path.txt".to_string(),
            // Zero interval: a check is always due.
            Duration::from_secs(0),
        );

        assert!(discovery.needs_check());
    }

    /// Comments and blank lines are skipped and weights are honoured.
    #[tokio::test]
    async fn test_file_discovery_with_temp_file() {
        use std::io::Write;

        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        {
            let mut file = std::fs::File::create(&file_path).unwrap();
            writeln!(file, "# Backend servers").unwrap();
            writeln!(file, "127.0.0.1:8080").unwrap();
            writeln!(file, "127.0.0.1:8081 weight=2").unwrap();
            // Blank line in the middle of the file.
            writeln!(file).unwrap();
            writeln!(file, "127.0.0.1:8082 weight=3").unwrap();
        }

        let discovery = FileDiscovery::new(
            file_path.to_string_lossy().to_string(),
            Duration::from_secs(1),
        );

        let (backends, _) = discovery.discover().await.unwrap();

        assert_eq!(backends.len(), 3);

        let weights: Vec<usize> = backends.iter().map(|b| b.weight).collect();
        // The first backend has no explicit weight and defaults to 1.
        assert!(weights.contains(&1));
        assert!(weights.contains(&2));
        assert!(weights.contains(&3));
    }

    /// After the file disappears, the cached backends are still returned.
    #[tokio::test]
    async fn test_file_discovery_missing_file_uses_cache() {
        use std::io::Write;

        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        {
            let mut file = std::fs::File::create(&file_path).unwrap();
            writeln!(file, "127.0.0.1:8080").unwrap();
        }

        let discovery = FileDiscovery::new(
            file_path.to_string_lossy().to_string(),
            // Zero interval: the file is checked on every call.
            Duration::from_secs(0),
        );

        // First call populates the cache.
        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 1);

        std::fs::remove_file(&file_path).unwrap();

        std::thread::sleep(Duration::from_millis(10));

        // Second call cannot read the file and falls back to the cache.
        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 1);
    }

    /// A rewritten file is picked up on the next check.
    #[tokio::test]
    async fn test_file_discovery_hot_reload() {
        use std::io::Write;

        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        {
            let mut file = std::fs::File::create(&file_path).unwrap();
            writeln!(file, "127.0.0.1:8080").unwrap();
        }

        let discovery = FileDiscovery::new(
            file_path.to_string_lossy().to_string(),
            // Short interval so the second call performs a check.
            Duration::from_millis(10),
        );

        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 1);

        // Let the watch interval elapse before rewriting the file.
        // NOTE(review): detection depends on the rewrite changing the file's
        // modification time - confirm on filesystems with coarse timestamps.
        std::thread::sleep(Duration::from_millis(50));

        {
            let mut file = std::fs::File::create(&file_path).unwrap();
            writeln!(file, "127.0.0.1:8080").unwrap();
            writeln!(file, "127.0.0.1:8081").unwrap();
            writeln!(file, "127.0.0.1:8082").unwrap();
        }

        let (backends, _) = discovery.discover().await.unwrap();
        assert_eq!(backends.len(), 3);
    }

    /// A file discovery registered with the manager returns the file's backends.
    #[tokio::test]
    async fn test_discovery_manager_file() {
        use std::io::Write;

        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("backends.txt");

        {
            let mut file = std::fs::File::create(&file_path).unwrap();
            writeln!(file, "127.0.0.1:8080").unwrap();
            writeln!(file, "127.0.0.1:8081").unwrap();
        }

        let manager = DiscoveryManager::new();

        manager
            .register(
                "file-upstream",
                DiscoveryConfig::File {
                    path: file_path.to_string_lossy().to_string(),
                    watch_interval: Duration::from_secs(5),
                },
            )
            .unwrap();

        assert_eq!(manager.count(), 1);
        assert!(manager.get("file-upstream").is_some());

        let result = manager.discover("file-upstream").await;
        assert!(result.is_some());
        let (backends, _) = result.unwrap().unwrap();
        assert_eq!(backends.len(), 2);
    }
}