Skip to main content

open_feature_flagd/resolver/common/
upstream.rs

1use crate::error::FlagdError;
2use std::str::FromStr;
3use tonic::transport::{Certificate, ClientTlsConfig};
4use tonic::transport::{Endpoint, Uri};
5use tracing::debug;
6
/// Connection settings for reaching a flagd upstream, produced by [`UpstreamConfig::new`].
///
/// Holds the tonic [`Endpoint`] to dial and, for `envoy://` targets, the service name
/// taken from the URI path. NOTE(review): callers presumably send that name as the request
/// authority when routing through Envoy — confirm at the call sites of `authority()`.
#[derive(Debug)]
pub struct UpstreamConfig {
    /// Endpoint to connect to (scheme, host, port and, when enabled, TLS settings).
    endpoint: Endpoint,
    /// Only set for custom name resolution (envoy://)
    authority: Option<String>,
}
12
13impl UpstreamConfig {
14    /// Creates a new upstream configuration for connecting to flagd.
15    ///
16    /// # Arguments
17    /// * `target` - The target address (host:port, URL, or envoy:// URI)
18    /// * `is_in_process` - Whether this is for in-process resolver (affects default port)
19    /// * `tls` - Whether to use TLS for the connection
20    /// * `cert_path` - Optional path to a PEM-encoded CA certificate for custom/self-signed certs
21    ///
22    /// # TLS Behavior
23    /// - If `cert_path` is provided, the certificate is loaded and used as the trusted CA
24    /// - If `cert_path` is None and TLS is enabled, system/webpki roots are used
25    /// - Self-signed certificates require providing the CA cert via `cert_path`
26    pub fn new(
27        target: String,
28        is_in_process: bool,
29        tls: bool,
30        cert_path: Option<&str>,
31    ) -> Result<Self, FlagdError> {
32        debug!(
33            "Creating upstream config for target: {}, tls: {}, cert_path: {:?}",
34            target, tls, cert_path
35        );
36
37        let scheme = if tls { "https" } else { "http" };
38
39        if target.starts_with("http://") || target.starts_with("https://") {
40            debug!("Target is already an HTTP(S) endpoint");
41            let mut endpoint = Endpoint::from_shared(target.clone())
42                .map_err(|e| FlagdError::Config(format!("Invalid endpoint: {}", e)))?;
43
44            // Apply TLS config for https URLs
45            if target.starts_with("https://") {
46                let tls_config = Self::build_tls_config(cert_path)?;
47                endpoint = endpoint
48                    .tls_config(tls_config)
49                    .map_err(|e| FlagdError::Config(format!("TLS config error: {}", e)))?;
50            }
51
52            return Ok(Self {
53                endpoint,
54                authority: None, // Standard HTTP(S) doesn't need custom authority
55            });
56        }
57
58        let (endpoint_str, authority) = if target.starts_with("envoy://") {
59            let uri = Uri::from_str(&target)
60                .map_err(|e| FlagdError::Config(format!("Failed to parse target URI: {}", e)))?;
61            let authority = uri.path().trim_start_matches('/');
62
63            if authority.is_empty() {
64                return Err(FlagdError::Config(
65                    "Service name (authority) cannot be empty".to_string(),
66                ));
67            }
68
69            let host = uri.host().unwrap_or("localhost");
70            let port = uri.port_u16().unwrap_or(9211); // Use Envoy port directly
71
72            (
73                format!("{}://{}:{}", scheme, host, port),
74                Some(authority.to_string()),
75            )
76        } else {
77            let parts: Vec<&str> = target.split(':').collect();
78            let host = parts.first().unwrap_or(&"localhost").to_string();
79            let port = parts
80                .get(1)
81                .and_then(|p| p.parse().ok())
82                .unwrap_or(if is_in_process { 8015 } else { 8013 });
83
84            debug!("Using standard resolution with {}:{}", host, port);
85            (format!("{}://{}:{}", scheme, host, port), None)
86        };
87
88        let mut endpoint = Endpoint::from_shared(endpoint_str)
89            .map_err(|e| FlagdError::Config(format!("Invalid endpoint: {}", e)))?;
90
91        // Apply TLS config when tls is enabled
92        if tls {
93            let tls_config = Self::build_tls_config(cert_path)?;
94            endpoint = endpoint
95                .tls_config(tls_config)
96                .map_err(|e| FlagdError::Config(format!("TLS config error: {}", e)))?;
97        }
98
99        Ok(Self {
100            endpoint,
101            authority,
102        })
103    }
104
105    /// Builds a TLS configuration, optionally loading a custom CA certificate.
106    ///
107    /// # Arguments
108    /// * `cert_path` - Optional path to a PEM-encoded CA certificate file
109    ///
110    /// # Returns
111    /// A configured `ClientTlsConfig` with either custom CA or system roots
112    fn build_tls_config(cert_path: Option<&str>) -> Result<ClientTlsConfig, FlagdError> {
113        let mut tls_config = ClientTlsConfig::new();
114
115        if let Some(path) = cert_path {
116            debug!("Loading custom CA certificate from: {}", path);
117            let cert_pem = std::fs::read(path).map_err(|e| {
118                FlagdError::Config(format!("Failed to read certificate file '{}': {}", path, e))
119            })?;
120            let ca_cert = Certificate::from_pem(cert_pem);
121            tls_config = tls_config.ca_certificate(ca_cert);
122        } else {
123            tls_config = tls_config.with_enabled_roots();
124        }
125
126        Ok(tls_config)
127    }
128
129    pub fn endpoint(&self) -> &Endpoint {
130        &self.endpoint
131    }
132
133    pub fn authority(&self) -> Option<String> {
134        self.authority.clone()
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config without a custom CA and returns its endpoint URI as a string.
    fn endpoint_uri(target: &str, is_in_process: bool, tls: bool) -> String {
        UpstreamConfig::new(target.to_string(), is_in_process, tls, None)
            .expect("config should build")
            .endpoint()
            .uri()
            .to_string()
    }

    #[test]
    fn test_tls_disabled_uses_http_scheme() {
        assert_eq!(
            endpoint_uri("localhost:8013", false, false),
            "http://localhost:8013/"
        );
    }

    #[test]
    fn test_tls_enabled_uses_https_scheme() {
        assert_eq!(
            endpoint_uri("localhost:8013", false, true),
            "https://localhost:8013/"
        );
    }

    #[test]
    fn test_in_process_default_port_with_tls() {
        assert_eq!(
            endpoint_uri("localhost", true, true),
            "https://localhost:8015/"
        );
    }

    #[test]
    fn test_rpc_default_port_with_tls() {
        assert_eq!(
            endpoint_uri("localhost", false, true),
            "https://localhost:8013/"
        );
    }

    #[test]
    fn test_explicit_port_is_used() {
        assert_eq!(
            endpoint_uri("example.com:9999", true, false),
            "http://example.com:9999/"
        );
    }

    #[test]
    fn test_standard_target_has_no_authority() {
        let config =
            UpstreamConfig::new("localhost:8013".to_string(), false, false, None).unwrap();
        assert_eq!(config.authority(), None);
    }

    #[test]
    fn test_explicit_http_url_preserved() {
        // The URL's scheme wins over the `tls` flag.
        assert_eq!(
            endpoint_uri("http://example.com:9000", false, true),
            "http://example.com:9000/"
        );
    }

    #[test]
    fn test_explicit_https_url_preserved() {
        assert_eq!(
            endpoint_uri("https://example.com:9000", false, false),
            "https://example.com:9000/"
        );
    }

    #[test]
    fn test_envoy_target_with_tls() {
        let config = UpstreamConfig::new(
            "envoy://localhost:9211/my-service".to_string(),
            false,
            true,
            None,
        )
        .unwrap();
        assert_eq!(
            config.endpoint().uri().to_string(),
            "https://localhost:9211/"
        );
        assert_eq!(config.authority(), Some("my-service".to_string()));
    }

    #[test]
    fn test_envoy_target_without_tls() {
        let config = UpstreamConfig::new(
            "envoy://localhost:9211/my-service".to_string(),
            false,
            false,
            None,
        )
        .unwrap();
        assert_eq!(
            config.endpoint().uri().to_string(),
            "http://localhost:9211/"
        );
        assert_eq!(config.authority(), Some("my-service".to_string()));
    }

    #[test]
    fn test_envoy_default_port() {
        let config =
            UpstreamConfig::new("envoy://flagd/my-service".to_string(), false, false, None)
                .unwrap();
        assert_eq!(config.endpoint().uri().to_string(), "http://flagd:9211/");
        assert_eq!(config.authority(), Some("my-service".to_string()));
    }

    #[test]
    fn test_envoy_empty_service_name_rejected() {
        for target in ["envoy://localhost:9211", "envoy://localhost:9211/"] {
            let err = UpstreamConfig::new(target.to_string(), false, false, None).unwrap_err();
            assert!(
                err.to_string().contains("Service name (authority) cannot be empty"),
                "unexpected error for {}: {}",
                target,
                err
            );
        }
    }

    #[test]
    fn test_cert_path_file_not_found() {
        let err = UpstreamConfig::new(
            "localhost:8013".to_string(),
            false,
            true,
            Some("/nonexistent/path/to/cert.pem"),
        )
        .unwrap_err();
        assert!(err.to_string().contains("Failed to read certificate file"));
    }

    #[test]
    fn test_cert_path_ignored_when_tls_disabled() {
        // Without TLS the certificate is never loaded, so a missing file is not an error.
        let config = UpstreamConfig::new(
            "localhost:8013".to_string(),
            false,
            false,
            Some("/nonexistent/path/to/cert.pem"),
        )
        .unwrap();
        assert_eq!(
            config.endpoint().uri().to_string(),
            "http://localhost:8013/"
        );
    }

    #[test]
    fn test_tls_with_no_cert_path_uses_system_roots() {
        // `Endpoint` does not expose its TLS settings, so this only checks that a TLS
        // config built from system roots is accepted and the https scheme is used.
        assert_eq!(
            endpoint_uri("localhost:8013", false, true),
            "https://localhost:8013/"
        );
    }
}
246}