// nabla_cli/ssrf_protection.rs

use anyhow::{Result, anyhow};
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use url::Host;
use url::Url;

/// Configuration for SSRF (server-side request forgery) protection.
///
/// The validator only lets a URL through when its host is covered by one of
/// the allow-lists or switches below.
#[derive(Debug, Clone)]
pub struct SSRFConfig {
    /// Host names that may be requested; sub-domains of an entry are accepted too.
    pub whitelisted_domains: HashSet<String>,
    /// IP ranges that may be requested, written in CIDR notation
    /// (for example `127.0.0.1/32` or `::1/128`).
    pub whitelisted_ips: HashSet<String>,
    /// Whether requests to localhost are permitted.
    pub allow_localhost: bool,
    /// Whether requests to private IP addresses are permitted.
    pub allow_private_ips: bool,
}

19impl Default for SSRFConfig {
20    fn default() -> Self {
21        let mut whitelisted_domains = HashSet::new();
22
23        // AWS Marketplace domains
24        whitelisted_domains.insert("platform.atelierlogos.studio".to_string());
25        whitelisted_domains.insert("nabla.atelierlogos.studio".to_string());
26        whitelisted_domains.insert("custom.nabla.com".to_string());
27        whitelisted_domains.insert("aws.amazon.com".to_string());
28        whitelisted_domains.insert("marketplace.amazonaws.com".to_string());
29
30        // OpenAI and common AI providers
31        whitelisted_domains.insert("api.openai.com".to_string());
32        whitelisted_domains.insert("api.together.xyz".to_string());
33        whitelisted_domains.insert("api.anthropic.com".to_string());
34        whitelisted_domains.insert("api.groq.com".to_string());
35
36        // Hugging Face
37        whitelisted_domains.insert("huggingface.co".to_string());
38        whitelisted_domains.insert("hf-mirror.com".to_string());
39
40        // Common local inference servers
41        whitelisted_domains.insert("localhost".to_string());
42        whitelisted_domains.insert("127.0.0.1".to_string());
43        whitelisted_domains.insert("0.0.0.0".to_string());
44
45        let mut whitelisted_ips = HashSet::new();
46        whitelisted_ips.insert("127.0.0.1/32".to_string());
47        whitelisted_ips.insert("::1/128".to_string());
48
49        Self {
50            whitelisted_domains,
51            whitelisted_ips,
52            allow_localhost: true,
53            allow_private_ips: false,
54        }
55    }
56}
57
58/// SSRF protection validator
59#[derive(Debug, Clone)]
60pub struct SSRFValidator {
61    pub config: SSRFConfig,
62}
63
64impl SSRFValidator {
65    /// Create a new SSRF validator with default configuration
66    pub fn new() -> Self {
67        Self {
68            config: SSRFConfig::default(),
69        }
70    }
71
72    /// Create a new SSRF validator with custom configuration
73
74    /// Validate a URL for SSRF protection
75    pub fn validate_url(&self, url_str: &str) -> Result<Url, anyhow::Error> {
76        // Parse the URL
77        let url = Url::parse(url_str).map_err(|e| anyhow!("Invalid URL format: {}", e))?;
78
79        // Check if it's a valid scheme
80        if url.scheme() != "http" && url.scheme() != "https" {
81            return Err(anyhow!("Only HTTP and HTTPS schemes are allowed"));
82        }
83
84        // Check for URL manipulation attempts
85        if url.username() != "" || url.password().is_some() {
86            return Err(anyhow!("URLs with user credentials are not allowed"));
87        }
88
89        // Check for suspicious fragments that might indicate manipulation
90        if let Some(fragment) = url.fragment() {
91            if fragment.contains("@") || fragment.contains("%") {
92                return Err(anyhow!("Suspicious URL fragment detected"));
93            }
94        }
95
96        // Extract host
97        let host = url
98            .host_str()
99            .ok_or_else(|| anyhow!("URL must have a host"))?;
100
101        // Check for IP addresses first
102        if let Some(ip) = self.parse_ip(host) {
103            if !self.is_ip_allowed(&ip) {
104                return Err(anyhow!("IP address '{}' is not allowed", ip));
105            }
106        } else {
107            // For domain names, check if host is in whitelist
108            if !self.is_host_whitelisted(host) {
109                return Err(anyhow!("Host '{}' is not in the whitelist", host));
110            }
111        }
112
113        // Check for localhost
114        if !self.config.allow_localhost && self.is_localhost(host) {
115            return Err(anyhow!("Localhost is not allowed"));
116        }
117
118        // Check for dangerous ports on localhost
119        if self.is_localhost(host) && self.has_dangerous_port(&url) {
120            return Err(anyhow!("Access to dangerous localhost port is not allowed"));
121        }
122
123        Ok(url)
124    }
125
126    /// Check if a host is whitelisted
127    fn is_host_whitelisted(&self, host: &str) -> bool {
128        // Check exact match
129        if self.config.whitelisted_domains.contains(host) {
130            return true;
131        }
132
133        // Check subdomain matches
134        for whitelisted in &self.config.whitelisted_domains {
135            if host.ends_with(&format!(".{}", whitelisted)) {
136                return true;
137            }
138        }
139
140        false
141    }
142
143    /// Parse IP address from host string
144    fn parse_ip(&self, host: &str) -> Option<IpAddr> {
145        host.parse::<IpAddr>().ok()
146    }
147
148    /// Check if an IP address is allowed
149    fn is_ip_allowed(&self, ip: &IpAddr) -> bool {
150        // Check whitelisted IP ranges first
151        for whitelisted in &self.config.whitelisted_ips {
152            if self.ip_in_cidr(ip, whitelisted) {
153                return true;
154            }
155        }
156
157        match ip {
158            IpAddr::V4(ipv4) => self.is_ipv4_allowed(ipv4),
159            IpAddr::V6(ipv6) => self.is_ipv6_allowed(ipv6),
160        }
161    }
162
163    /// Check if IPv4 address is allowed
164    fn is_ipv4_allowed(&self, ip: &Ipv4Addr) -> bool {
165        // Check if it's localhost
166        if ip.octets() == [127, 0, 0, 1] {
167            return self.config.allow_localhost;
168        }
169
170        // Check if it's private and private IPs are allowed
171        if self.is_private_ipv4(ip) {
172            return self.config.allow_private_ips;
173        }
174
175        // For public IPs, they should be explicitly whitelisted as domains
176        false
177    }
178
179    /// Check if IPv6 address is allowed
180    fn is_ipv6_allowed(&self, ip: &Ipv6Addr) -> bool {
181        // Check if it's localhost
182        if ip.segments() == [0, 0, 0, 0, 0, 0, 0, 1] {
183            return self.config.allow_localhost;
184        }
185
186        // Check if it's private and private IPs are allowed
187        if self.is_private_ipv6(ip) {
188            return self.config.allow_private_ips;
189        }
190
191        // For public IPs, they should be explicitly whitelisted as domains
192        false
193    }
194
195    /// Check if IPv4 address is private
196    fn is_private_ipv4(&self, ip: &Ipv4Addr) -> bool {
197        let octets = ip.octets();
198
199        // Private ranges
200        (octets[0] == 10) || // 10.0.0.0/8
201        (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31) || // 172.16.0.0/12
202        (octets[0] == 192 && octets[1] == 168) || // 192.168.0.0/16
203        (octets[0] == 127) || // 127.0.0.0/8 (localhost)
204        (octets[0] == 0) || // 0.0.0.0/8
205        (octets[0] == 169 && octets[1] == 254) || // 169.254.0.0/16 (link-local)
206        (octets[0] == 224) || // 224.0.0.0/4 (multicast)
207        (octets[0] == 240) // 240.0.0.0/4 (reserved)
208    }
209
210    /// Check if IPv6 address is private
211    fn is_private_ipv6(&self, ip: &Ipv6Addr) -> bool {
212        let segments = ip.segments();
213
214        // Localhost
215        if segments == [0, 0, 0, 0, 0, 0, 0, 1] {
216            return true;
217        }
218
219        // Link-local
220        if segments[0] == 0xfe80 {
221            return true;
222        }
223
224        // Unique local
225        if segments[0] & 0xfe00 == 0xfc00 {
226            return true;
227        }
228
229        // Multicast
230        if segments[0] & 0xff00 == 0xff00 {
231            return true;
232        }
233
234        false
235    }
236
237    /// Check if IP is in CIDR range
238    fn ip_in_cidr(&self, ip: &IpAddr, cidr: &str) -> bool {
239        // Simple implementation - in production, use a proper CIDR library
240        if let Some((network, bits_str)) = cidr.split_once('/') {
241            if let (Ok(network_ip), Ok(bits)) = (network.parse::<IpAddr>(), bits_str.parse::<u8>())
242            {
243                match (ip, &network_ip) {
244                    (IpAddr::V4(ip_v4), IpAddr::V4(net_v4)) => {
245                        // Simple IPv4 CIDR matching
246                        let ip_bits = u32::from_be_bytes(ip_v4.octets());
247                        let net_bits = u32::from_be_bytes(net_v4.octets());
248                        let mask = !((1u32 << (32 - bits)) - 1);
249                        (ip_bits & mask) == (net_bits & mask)
250                    }
251                    (IpAddr::V6(_), IpAddr::V6(_)) => {
252                        // For IPv6, just do exact match for now
253                        ip == &network_ip
254                    }
255                    _ => false,
256                }
257            } else {
258                false
259            }
260        } else {
261            false
262        }
263    }
264
265    /// Check if host is localhost
266    fn is_localhost(&self, host: &str) -> bool {
267        host == "localhost"
268            || host == "127.0.0.1"
269            || host == "::1"
270            || host.starts_with("localhost:")
271    }
272
273    /// Check if URL has a dangerous port
274    fn has_dangerous_port(&self, url: &Url) -> bool {
275        if let Some(port) = url.port() {
276            // Common dangerous ports
277            matches!(
278                port,
279                22 |   // SSH
280                23 |   // Telnet
281                25 |   // SMTP
282                53 |   // DNS
283                110 |  // POP3
284                143 |  // IMAP
285                993 |  // IMAPS
286                995 |  // POP3S
287                1433 | // MSSQL
288                3306 | // MySQL
289                3389 | // RDP
290                5432 | // PostgreSQL
291                5984 | // CouchDB
292                6379 | // Redis
293                7000 | // Cassandra
294                7001 | // Cassandra SSL
295                8086 | // InfluxDB
296                9042 | // Cassandra
297                9160 | // Cassandra Thrift
298                9200 | // Elasticsearch
299                9300 | // Elasticsearch
300                11211 | // Memcached
301                27017 | // MongoDB
302                27018 | // MongoDB
303                27019 // MongoDB
304            )
305        } else {
306            false
307        }
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Whitelisted hosts pass; unrelated hosts are rejected.
    #[test]
    fn test_whitelisted_domains() {
        let validator = SSRFValidator::new();

        // Test whitelisted domains
        assert!(
            validator
                .validate_url("https://api.openai.com/v1/chat/completions")
                .is_ok()
        );
        assert!(
            validator
                .validate_url("https://platform.atelierlogos.studio/marketplace/register")
                .is_ok()
        );
        assert!(
            validator
                .validate_url("https://aws.amazon.com/marketplace/listing")
                .is_ok()
        );

        // Test non-whitelisted domains
        assert!(validator.validate_url("https://evil.com/api").is_err());
        assert!(
            validator
                .validate_url("https://malicious.example.com/")
                .is_err()
        );
    }

    /// Sub-domains of a whitelisted host pass; look-alike hosts do not.
    #[test]
    fn test_subdomains_and_lookalikes() {
        let validator = SSRFValidator::new();

        assert!(
            validator
                .validate_url("https://eu.api.openai.com/v1")
                .is_ok()
        );
        assert!(
            validator
                .validate_url("https://api.openai.com.evil.com/")
                .is_err()
        );
        assert!(
            validator
                .validate_url("https://fakehuggingface.co/")
                .is_err()
        );
    }

    /// Localhost access follows the `allow_localhost` switch.
    #[test]
    fn test_localhost() {
        let mut validator = SSRFValidator::new();

        // Test localhost with allow_localhost = true
        assert!(
            validator
                .validate_url("http://localhost:11434/completion")
                .is_ok()
        );
        assert!(validator.validate_url("http://127.0.0.1:8080/api").is_ok());

        // Test localhost with allow_localhost = false
        validator.config.allow_localhost = false;
        assert!(
            validator
                .validate_url("http://localhost:11434/completion")
                .is_err()
        );
        assert!(validator.validate_url("http://127.0.0.1:8080/api").is_err());
    }

    /// Service ports on localhost are rejected even though localhost is allowed.
    #[test]
    fn test_dangerous_localhost_ports() {
        let validator = SSRFValidator::new();

        assert!(validator.validate_url("http://localhost:6379/").is_err());
        assert!(validator.validate_url("http://127.0.0.1:5432/").is_err());
    }

    /// Private IPv4 literals follow the `allow_private_ips` switch.
    #[test]
    fn test_private_ips() {
        let mut validator = SSRFValidator::new();

        // Test private IPs with allow_private_ips = false
        assert!(
            validator
                .validate_url("http://192.168.1.1:8080/api")
                .is_err()
        );
        assert!(validator.validate_url("http://10.0.0.1:8080/api").is_err());
        assert!(
            validator
                .validate_url("http://172.16.0.1:8080/api")
                .is_err()
        );

        // Test private IPs with allow_private_ips = true
        validator.config.allow_private_ips = true;
        assert!(
            validator
                .validate_url("http://192.168.1.1:8080/api")
                .is_ok()
        );
        assert!(validator.validate_url("http://10.0.0.1:8080/api").is_ok());
        assert!(validator.validate_url("http://172.16.0.1:8080/api").is_ok());
    }

    /// Unsupported schemes, unparsable input and embedded credentials are rejected.
    #[test]
    fn test_invalid_urls() {
        let validator = SSRFValidator::new();

        // Test invalid schemes
        assert!(validator.validate_url("ftp://example.com").is_err());
        assert!(validator.validate_url("file:///etc/passwd").is_err());

        // Test invalid URLs
        assert!(validator.validate_url("not-a-url").is_err());
        assert!(validator.validate_url("http://").is_err());

        // Test embedded credentials
        assert!(
            validator
                .validate_url("https://user:pass@api.openai.com/v1")
                .is_err()
        );
    }
}