Skip to main content

shuru_proxy/
config.rs

1use std::collections::HashMap;
2use std::net::Ipv4Addr;
3
/// A host port exposed to the guest via host.shuru.internal.
///
/// `ProxyConfig::exposed_host_port` looks entries up by `guest_port` and
/// returns the matching `host_port`. Equality and hashing are derived so that
/// mappings can be compared (e.g. in tests) or de-duplicated in sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExposeHostMapping {
    /// Port on the host (127.0.0.1:host_port).
    pub host_port: u16,
    /// Port the guest connects to (host.shuru.internal:guest_port).
    pub guest_port: u16,
}
12
13/// Configuration for the proxy engine.
14#[derive(Debug, Clone, Default)]
15pub struct ProxyConfig {
16    /// Secrets to inject. Key is the env var name visible to the guest.
17    /// The guest gets a random placeholder token; the proxy substitutes
18    /// the real value only when the request targets an allowed host.
19    pub secrets: HashMap<String, SecretConfig>,
20    /// Network access rules.
21    pub network: NetworkConfig,
22    /// Host ports exposed to the guest via host.shuru.internal.
23    pub expose_host: Vec<ExposeHostMapping>,
24}
25
/// A secret that the proxy injects into HTTP requests.
///
/// `Debug` is implemented by hand rather than derived: a derived impl would
/// print an inline `value` verbatim, leaking the credential into any log line
/// or panic message that formats this struct (or a `ProxyConfig` holding it).
/// The hand-written impl shows only whether a value is set.
#[derive(Clone)]
pub struct SecretConfig {
    /// Host environment variable to read the real value from.
    pub from: String,
    /// Domain patterns where this secret may be sent (e.g., "api.openai.com").
    /// The proxy only substitutes the placeholder on requests to these hosts.
    pub hosts: Vec<String>,
    /// If set, use this value directly instead of reading from the host env var.
    pub value: Option<String>,
}

impl std::fmt::Debug for SecretConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SecretConfig")
            .field("from", &self.from)
            .field("hosts", &self.hosts)
            // Reveal only presence/absence of an inline value, never its contents.
            .field("value", &self.value.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
37
/// Outbound network access policy for the guest.
#[derive(Debug)]
#[derive(Clone, Default)]
pub struct NetworkConfig {
    /// Domain patterns the guest may connect to; an empty list allows every
    /// domain. Exact names such as "registry.npmjs.org" and wildcard patterns
    /// such as "*.openai.com" are both accepted.
    pub allow: Vec<String>,
}
45
46impl ProxyConfig {
47    /// Check if a domain is allowed by the network policy.
48    /// Empty allowlist means all domains are allowed.
49    pub fn is_domain_allowed(&self, domain: &str) -> bool {
50        if self.network.allow.is_empty() {
51            return true;
52        }
53        self.network
54            .allow
55            .iter()
56            .any(|pattern| domain_matches(pattern, domain))
57    }
58
59    /// Look up whether a connection to the gateway IP on `guest_port` should
60    /// be forwarded to a host port. Returns the host port if matched.
61    pub fn exposed_host_port(&self, dst_ip: Ipv4Addr, guest_port: u16) -> Option<u16> {
62        const GATEWAY: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
63        if dst_ip != GATEWAY {
64            return None;
65        }
66        self.expose_host
67            .iter()
68            .find(|m| m.guest_port == guest_port)
69            .map(|m| m.host_port)
70    }
71
72    /// Get all secret placeholder→real value mappings for a given domain.
73    pub fn secrets_for_domain(
74        &self,
75        domain: &str,
76        placeholders: &HashMap<String, String>,
77    ) -> Vec<(String, String)> {
78        let mut substitutions = Vec::new();
79        for (name, secret) in &self.secrets {
80            if secret
81                .hosts
82                .iter()
83                .any(|pattern| domain_matches(pattern, domain))
84            {
85                if let Some(placeholder) = placeholders.get(name) {
86                    let real_value = secret
87                        .value
88                        .clone()
89                        .or_else(|| std::env::var(&secret.from).ok());
90                    if let Some(real_value) = real_value {
91                        substitutions.push((placeholder.clone(), real_value));
92                    }
93                }
94            }
95        }
96        substitutions
97    }
98}
99
/// Simple wildcard domain matching.
///
/// - `"*"` matches any domain (catch-all).
/// - `"*.example.com"` matches `"api.example.com"` and `"deep.api.example.com"`.
///   It matches neither `"example.com"` itself nor `".example.com"`, because the
///   wildcard must cover a non-empty label. A bare `"*."` matches nothing.
/// - `"example.com"` matches exactly `"example.com"`.
///
/// Host names are case-insensitive (RFC 4343), so ASCII letters are compared
/// without regard to case. A single trailing dot (the fully-qualified form,
/// `"example.com."`) on either side is ignored. Comparison is done on bytes,
/// so non-ASCII input can never hit a char-boundary panic.
fn domain_matches(pattern: &str, domain: &str) -> bool {
    /// Strips one trailing `.` (the DNS root label), if present.
    fn trim_root_dot(host: &str) -> &str {
        host.strip_suffix('.').unwrap_or(host)
    }

    if pattern == "*" {
        return true;
    }
    let domain = trim_root_dot(domain).as_bytes();
    match pattern.strip_prefix("*.") {
        Some(suffix) => {
            let suffix = trim_root_dot(suffix).as_bytes();
            // `domain` must look like "<label>.<suffix>" with a non-empty label,
            // i.e. at least one byte + '.' in front of the suffix.
            if suffix.is_empty() || domain.len() < suffix.len() + 2 {
                return false;
            }
            let dot = domain.len() - suffix.len() - 1;
            domain[dot] == b'.' && domain[dot + 1..].eq_ignore_ascii_case(suffix)
        }
        None => trim_root_dot(pattern).as_bytes().eq_ignore_ascii_case(domain),
    }
}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exposed_host_port() {
        let config = ProxyConfig {
            expose_host: vec![
                ExposeHostMapping { host_port: 3000, guest_port: 8080 },
                ExposeHostMapping { host_port: 5432, guest_port: 5432 },
            ],
            ..Default::default()
        };
        // Gateway IP match
        assert_eq!(config.exposed_host_port(Ipv4Addr::new(10, 0, 0, 1), 8080), Some(3000));
        assert_eq!(config.exposed_host_port(Ipv4Addr::new(10, 0, 0, 1), 5432), Some(5432));
        // No mapping for this port
        assert_eq!(config.exposed_host_port(Ipv4Addr::new(10, 0, 0, 1), 9999), None);
        // Non-gateway IP
        assert_eq!(config.exposed_host_port(Ipv4Addr::new(1, 2, 3, 4), 8080), None);
    }

    #[test]
    fn test_domain_matching() {
        assert!(domain_matches("*", "anything.com"));
        assert!(domain_matches("*", "api.example.com"));
        assert!(domain_matches("example.com", "example.com"));
        assert!(!domain_matches("example.com", "api.example.com"));
        assert!(domain_matches("*.example.com", "api.example.com"));
        assert!(domain_matches("*.example.com", "deep.api.example.com"));
        assert!(!domain_matches("*.example.com", "example.com"));
        assert!(!domain_matches("*.example.com", "notexample.com"));
    }

    #[test]
    fn test_is_domain_allowed() {
        // An empty allowlist permits every domain.
        let open = ProxyConfig::default();
        assert!(open.is_domain_allowed("anything.example"));

        let restricted = ProxyConfig {
            network: NetworkConfig {
                allow: vec!["*.openai.com".to_string(), "registry.npmjs.org".to_string()],
            },
            ..Default::default()
        };
        assert!(restricted.is_domain_allowed("api.openai.com"));
        assert!(restricted.is_domain_allowed("registry.npmjs.org"));
        assert!(!restricted.is_domain_allowed("openai.com"));
        assert!(!restricted.is_domain_allowed("evil.com"));
    }

    #[test]
    fn test_secrets_for_domain() {
        let mut secrets = HashMap::new();
        secrets.insert(
            "OPENAI_API_KEY".to_string(),
            SecretConfig {
                // Inline `value` is set, so this env var is never consulted.
                from: "SHURU_TEST_UNUSED_VAR".to_string(),
                hosts: vec!["api.openai.com".to_string()],
                value: Some("sk-real".to_string()),
            },
        );
        let config = ProxyConfig { secrets, ..Default::default() };

        let mut placeholders = HashMap::new();
        placeholders.insert("OPENAI_API_KEY".to_string(), "ph-123".to_string());

        // Allowed host: the placeholder maps to the real value.
        assert_eq!(
            config.secrets_for_domain("api.openai.com", &placeholders),
            vec![("ph-123".to_string(), "sk-real".to_string())]
        );
        // Disallowed host: nothing is substituted.
        assert!(config.secrets_for_domain("evil.com", &placeholders).is_empty());
        // No placeholder issued for the secret: nothing is substituted.
        assert!(config.secrets_for_domain("api.openai.com", &HashMap::new()).is_empty());
    }
}