Skip to main content

shuru_proxy/
config.rs

1use std::collections::HashMap;
2
3/// Configuration for the proxy engine.
4#[derive(Debug, Clone, Default)]
5pub struct ProxyConfig {
6    /// Secrets to inject. Key is the env var name visible to the guest.
7    /// The guest gets a random placeholder token; the proxy substitutes
8    /// the real value only when the request targets an allowed host.
9    pub secrets: HashMap<String, SecretConfig>,
10    /// Network access rules.
11    pub network: NetworkConfig,
12}
13
/// A credential that the proxy can splice into outgoing HTTP requests.
#[derive(Clone, Debug)]
pub struct SecretConfig {
    /// Name of the host-side environment variable that holds the real value.
    pub from: String,
    /// Domain patterns (e.g. `"api.openai.com"`) this secret may be sent to.
    /// Requests addressed to any other host never have the placeholder replaced.
    pub hosts: Vec<String>,
}
23
/// Outbound network access policy.
#[derive(Clone, Debug, Default)]
pub struct NetworkConfig {
    /// Domain patterns the guest may connect to; an empty list places no restriction.
    ///
    /// A leading `*.` matches subdomains (`"*.openai.com"`), while entries such as
    /// `"registry.npmjs.org"` must match the requested domain exactly.
    pub allow: Vec<String>,
}
31
32impl ProxyConfig {
33    /// Check if a domain is allowed by the network policy.
34    /// Empty allowlist means all domains are allowed.
35    pub fn is_domain_allowed(&self, domain: &str) -> bool {
36        if self.network.allow.is_empty() {
37            return true;
38        }
39        self.network
40            .allow
41            .iter()
42            .any(|pattern| domain_matches(pattern, domain))
43    }
44
45    /// Get all secret placeholder→real value mappings for a given domain.
46    pub fn secrets_for_domain(
47        &self,
48        domain: &str,
49        placeholders: &HashMap<String, String>,
50    ) -> Vec<(String, String)> {
51        let mut result = Vec::new();
52        for (name, secret) in &self.secrets {
53            if secret
54                .hosts
55                .iter()
56                .any(|pattern| domain_matches(pattern, domain))
57            {
58                if let Some(placeholder) = placeholders.get(name) {
59                    if let Ok(real_value) = std::env::var(&secret.from) {
60                        result.push((placeholder.clone(), real_value));
61                    }
62                }
63            }
64        }
65        result
66    }
67}
68
/// Simple wildcard domain matching, case-insensitive for ASCII letters.
///
/// - `"*.example.com"` matches a domain that has at least one non-empty label in
///   front of `example.com` (`"api.example.com"`, `"deep.api.example.com"`).
///   It does not match the apex `"example.com"`, nor `".example.com"`.
/// - Any other pattern must equal `domain` exactly, ignoring ASCII case
///   (`"example.com"` matches `"Example.COM"` but not `"api.example.com"`).
///
/// DNS host names are case-insensitive (RFC 4343). A case-sensitive comparison
/// would deny allowlisted hosts and silently skip secret substitution depending
/// only on how the guest capitalised the host.
fn domain_matches(pattern: &str, domain: &str) -> bool {
    let suffix = match pattern.strip_prefix("*.") {
        Some(suffix) => suffix.as_bytes(),
        None => return pattern.eq_ignore_ascii_case(domain),
    };
    let domain = domain.as_bytes();
    // The domain needs at least one label byte plus the separating '.' before the
    // suffix. This also rejects the bare apex and an empty leading label.
    if domain.len() <= suffix.len() + 1 {
        return false;
    }
    // Split on bytes rather than `str` indices so non-ASCII input can never
    // trigger a char-boundary panic.
    let (head, tail) = domain.split_at(domain.len() - suffix.len());
    head.last() == Some(&b'.') && tail.eq_ignore_ascii_case(suffix)
}
79
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Table of `(pattern, domain, expected)` cases for `domain_matches`.
    #[test]
    fn test_domain_matching() {
        let cases = [
            ("example.com", "example.com", true),
            ("example.com", "api.example.com", false),
            ("*.example.com", "api.example.com", true),
            ("*.example.com", "deep.api.example.com", true),
            ("*.example.com", "example.com", false),
            ("*.example.com", "notexample.com", false),
            // A suffix appearing mid-domain must not satisfy the wildcard.
            ("*.example.com", "example.com.evil.net", false),
            ("*.example.com", "", false),
        ];
        for &(pattern, domain, expected) in cases.iter() {
            assert_eq!(
                domain_matches(pattern, domain),
                expected,
                "pattern={:?} domain={:?}",
                pattern,
                domain
            );
        }
    }

    #[test]
    fn test_empty_allowlist_allows_everything() {
        let config = ProxyConfig::default();
        assert!(config.is_domain_allowed("anything.example"));
        assert!(config.is_domain_allowed(""));
    }

    #[test]
    fn test_allowlist_restricts_domains() {
        let config = ProxyConfig {
            network: NetworkConfig {
                allow: vec!["*.openai.com".to_string(), "registry.npmjs.org".to_string()],
            },
            ..Default::default()
        };
        assert!(config.is_domain_allowed("api.openai.com"));
        assert!(config.is_domain_allowed("registry.npmjs.org"));
        assert!(!config.is_domain_allowed("openai.com"));
        assert!(!config.is_domain_allowed("evil.com"));
    }

    #[test]
    fn test_secrets_skipped_without_host_match_or_placeholder() {
        let mut secrets = HashMap::new();
        secrets.insert(
            "OPENAI_API_KEY".to_string(),
            SecretConfig {
                from: "SHURU_PROXY_TEST_SECRET_SOURCE".to_string(),
                hosts: vec!["api.openai.com".to_string()],
            },
        );
        let config = ProxyConfig {
            secrets,
            ..Default::default()
        };

        let mut placeholders = HashMap::new();
        placeholders.insert("OPENAI_API_KEY".to_string(), "placeholder-token".to_string());

        // The host is not listed for the secret, so nothing may be substituted.
        assert!(config.secrets_for_domain("evil.example", &placeholders).is_empty());
        // The host matches, but no placeholder was issued under the secret's name.
        assert!(config
            .secrets_for_domain("api.openai.com", &HashMap::new())
            .is_empty());
    }
}