nabla_cli/
ssrf_protection.rs1use anyhow::{Result, anyhow};
2use std::collections::HashSet;
3use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4use url::Url;
5
/// Policy settings consulted by `SSRFValidator` when vetting outbound URLs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SSRFConfig {
    /// Host names that may be contacted. An entry also admits every subdomain
    /// of that name (e.g. `huggingface.co` admits `cdn.huggingface.co`).
    pub whitelisted_domains: HashSet<String>,
    /// IP ranges written as `address/prefix-length` (e.g. `127.0.0.1/32`).
    /// IP-literal hosts inside one of these ranges pass the IP check without
    /// further classification. The separate localhost and dangerous-port
    /// checks still apply.
    pub whitelisted_ips: HashSet<String>,
    /// Whether loopback targets such as `localhost` / `127.0.0.1` are permitted.
    pub allow_localhost: bool,
    /// Whether non-public IP literals (private, link-local, etc.) are permitted.
    pub allow_private_ips: bool,
}
18
19impl Default for SSRFConfig {
20 fn default() -> Self {
21 let mut whitelisted_domains = HashSet::new();
22
23 whitelisted_domains.insert("platform.atelierlogos.studio".to_string());
25 whitelisted_domains.insert("nabla.atelierlogos.studio".to_string());
26 whitelisted_domains.insert("custom.nabla.com".to_string());
27 whitelisted_domains.insert("aws.amazon.com".to_string());
28 whitelisted_domains.insert("marketplace.amazonaws.com".to_string());
29
30 whitelisted_domains.insert("api.openai.com".to_string());
32 whitelisted_domains.insert("api.together.xyz".to_string());
33 whitelisted_domains.insert("api.anthropic.com".to_string());
34 whitelisted_domains.insert("api.groq.com".to_string());
35
36 whitelisted_domains.insert("huggingface.co".to_string());
38 whitelisted_domains.insert("hf-mirror.com".to_string());
39
40 whitelisted_domains.insert("localhost".to_string());
42 whitelisted_domains.insert("127.0.0.1".to_string());
43 whitelisted_domains.insert("0.0.0.0".to_string());
44
45 let mut whitelisted_ips = HashSet::new();
46 whitelisted_ips.insert("127.0.0.1/32".to_string());
47 whitelisted_ips.insert("::1/128".to_string());
48
49 Self {
50 whitelisted_domains,
51 whitelisted_ips,
52 allow_localhost: true,
53 allow_private_ips: false,
54 }
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct SSRFValidator {
61 pub config: SSRFConfig,
62}
63
64impl SSRFValidator {
65 pub fn new() -> Self {
67 Self {
68 config: SSRFConfig::default(),
69 }
70 }
71
72 pub fn validate_url(&self, url_str: &str) -> Result<Url, anyhow::Error> {
76 let url = Url::parse(url_str).map_err(|e| anyhow!("Invalid URL format: {}", e))?;
78
79 if url.scheme() != "http" && url.scheme() != "https" {
81 return Err(anyhow!("Only HTTP and HTTPS schemes are allowed"));
82 }
83
84 if url.username() != "" || url.password().is_some() {
86 return Err(anyhow!("URLs with user credentials are not allowed"));
87 }
88
89 if let Some(fragment) = url.fragment() {
91 if fragment.contains("@") || fragment.contains("%") {
92 return Err(anyhow!("Suspicious URL fragment detected"));
93 }
94 }
95
96 let host = url
98 .host_str()
99 .ok_or_else(|| anyhow!("URL must have a host"))?;
100
101 if let Some(ip) = self.parse_ip(host) {
103 if !self.is_ip_allowed(&ip) {
104 return Err(anyhow!("IP address '{}' is not allowed", ip));
105 }
106 } else {
107 if !self.is_host_whitelisted(host) {
109 return Err(anyhow!("Host '{}' is not in the whitelist", host));
110 }
111 }
112
113 if !self.config.allow_localhost && self.is_localhost(host) {
115 return Err(anyhow!("Localhost is not allowed"));
116 }
117
118 if self.is_localhost(host) && self.has_dangerous_port(&url) {
120 return Err(anyhow!("Access to dangerous localhost port is not allowed"));
121 }
122
123 Ok(url)
124 }
125
126 fn is_host_whitelisted(&self, host: &str) -> bool {
128 if self.config.whitelisted_domains.contains(host) {
130 return true;
131 }
132
133 for whitelisted in &self.config.whitelisted_domains {
135 if host.ends_with(&format!(".{}", whitelisted)) {
136 return true;
137 }
138 }
139
140 false
141 }
142
143 fn parse_ip(&self, host: &str) -> Option<IpAddr> {
145 host.parse::<IpAddr>().ok()
146 }
147
148 fn is_ip_allowed(&self, ip: &IpAddr) -> bool {
150 for whitelisted in &self.config.whitelisted_ips {
152 if self.ip_in_cidr(ip, whitelisted) {
153 return true;
154 }
155 }
156
157 match ip {
158 IpAddr::V4(ipv4) => self.is_ipv4_allowed(ipv4),
159 IpAddr::V6(ipv6) => self.is_ipv6_allowed(ipv6),
160 }
161 }
162
163 fn is_ipv4_allowed(&self, ip: &Ipv4Addr) -> bool {
165 if ip.octets() == [127, 0, 0, 1] {
167 return self.config.allow_localhost;
168 }
169
170 if self.is_private_ipv4(ip) {
172 return self.config.allow_private_ips;
173 }
174
175 false
177 }
178
179 fn is_ipv6_allowed(&self, ip: &Ipv6Addr) -> bool {
181 if ip.segments() == [0, 0, 0, 0, 0, 0, 0, 1] {
183 return self.config.allow_localhost;
184 }
185
186 if self.is_private_ipv6(ip) {
188 return self.config.allow_private_ips;
189 }
190
191 false
193 }
194
195 fn is_private_ipv4(&self, ip: &Ipv4Addr) -> bool {
197 let octets = ip.octets();
198
199 (octets[0] == 10) || (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31) || (octets[0] == 192 && octets[1] == 168) || (octets[0] == 127) || (octets[0] == 0) || (octets[0] == 169 && octets[1] == 254) || (octets[0] == 224) || (octets[0] == 240) }
209
210 fn is_private_ipv6(&self, ip: &Ipv6Addr) -> bool {
212 let segments = ip.segments();
213
214 if segments == [0, 0, 0, 0, 0, 0, 0, 1] {
216 return true;
217 }
218
219 if segments[0] == 0xfe80 {
221 return true;
222 }
223
224 if segments[0] & 0xfe00 == 0xfc00 {
226 return true;
227 }
228
229 if segments[0] & 0xff00 == 0xff00 {
231 return true;
232 }
233
234 false
235 }
236
237 fn ip_in_cidr(&self, ip: &IpAddr, cidr: &str) -> bool {
239 if let Some((network, bits_str)) = cidr.split_once('/') {
241 if let (Ok(network_ip), Ok(bits)) = (network.parse::<IpAddr>(), bits_str.parse::<u8>())
242 {
243 match (ip, &network_ip) {
244 (IpAddr::V4(ip_v4), IpAddr::V4(net_v4)) => {
245 let ip_bits = u32::from_be_bytes(ip_v4.octets());
247 let net_bits = u32::from_be_bytes(net_v4.octets());
248 let mask = !((1u32 << (32 - bits)) - 1);
249 (ip_bits & mask) == (net_bits & mask)
250 }
251 (IpAddr::V6(_), IpAddr::V6(_)) => {
252 ip == &network_ip
254 }
255 _ => false,
256 }
257 } else {
258 false
259 }
260 } else {
261 false
262 }
263 }
264
265 fn is_localhost(&self, host: &str) -> bool {
267 host == "localhost"
268 || host == "127.0.0.1"
269 || host == "::1"
270 || host.starts_with("localhost:")
271 }
272
273 fn has_dangerous_port(&self, url: &Url) -> bool {
275 if let Some(port) = url.port() {
276 matches!(
278 port,
279 22 | 23 | 25 | 53 | 110 | 143 | 993 | 995 | 1433 | 3306 | 3389 | 5432 | 5984 | 6379 | 7000 | 7001 | 8086 | 9042 | 9160 | 9200 | 9300 | 11211 | 27017 | 27018 | 27019 )
305 } else {
306 false
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitelisted_domains() {
        let validator = SSRFValidator::new();

        assert!(
            validator
                .validate_url("https://api.openai.com/v1/chat/completions")
                .is_ok()
        );
        assert!(
            validator
                .validate_url("https://platform.atelierlogos.studio/marketplace/register")
                .is_ok()
        );
        assert!(
            validator
                .validate_url("https://aws.amazon.com/marketplace/listing")
                .is_ok()
        );

        assert!(validator.validate_url("https://evil.com/api").is_err());
        assert!(
            validator
                .validate_url("https://malicious.example.com/")
                .is_err()
        );
    }

    #[test]
    fn test_subdomains_and_lookalikes() {
        let validator = SSRFValidator::new();

        // Subdomains of a whitelisted name are admitted...
        assert!(
            validator
                .validate_url("https://cdn-lfs.huggingface.co/model.bin")
                .is_ok()
        );
        // ...but names that merely contain or prefix one are not.
        assert!(
            validator
                .validate_url("https://huggingface.co.evil.com/")
                .is_err()
        );
        assert!(validator.validate_url("https://evilhuggingface.co/").is_err());
    }

    #[test]
    fn test_localhost() {
        let mut validator = SSRFValidator::new();

        assert!(
            validator
                .validate_url("http://localhost:11434/completion")
                .is_ok()
        );
        assert!(validator.validate_url("http://127.0.0.1:8080/api").is_ok());

        validator.config.allow_localhost = false;
        assert!(
            validator
                .validate_url("http://localhost:11434/completion")
                .is_err()
        );
        assert!(validator.validate_url("http://127.0.0.1:8080/api").is_err());
    }

    #[test]
    fn test_dangerous_localhost_ports() {
        let validator = SSRFValidator::new();

        assert!(validator.validate_url("http://localhost:6379/").is_err());
        assert!(validator.validate_url("http://127.0.0.1:22/").is_err());
        assert!(validator.validate_url("http://[::1]:6379/").is_err());
    }

    #[test]
    fn test_private_ips() {
        let mut validator = SSRFValidator::new();

        assert!(
            validator
                .validate_url("http://192.168.1.1:8080/api")
                .is_err()
        );
        assert!(validator.validate_url("http://10.0.0.1:8080/api").is_err());
        assert!(
            validator
                .validate_url("http://172.16.0.1:8080/api")
                .is_err()
        );

        validator.config.allow_private_ips = true;
        assert!(
            validator
                .validate_url("http://192.168.1.1:8080/api")
                .is_ok()
        );
        assert!(validator.validate_url("http://10.0.0.1:8080/api").is_ok());
        assert!(validator.validate_url("http://172.16.0.1:8080/api").is_ok());
    }

    #[test]
    fn test_credentials_rejected() {
        let validator = SSRFValidator::new();

        assert!(
            validator
                .validate_url("https://user:pass@api.openai.com/v1")
                .is_err()
        );
        assert!(
            validator
                .validate_url("https://user@api.openai.com/v1")
                .is_err()
        );
    }

    #[test]
    fn test_suspicious_fragments() {
        let validator = SSRFValidator::new();

        assert!(
            validator
                .validate_url("https://api.openai.com/v1#section")
                .is_ok()
        );
        assert!(
            validator
                .validate_url("https://api.openai.com/v1#a@b")
                .is_err()
        );
        assert!(
            validator
                .validate_url("https://api.openai.com/v1#%2e")
                .is_err()
        );
    }

    #[test]
    fn test_invalid_urls() {
        let validator = SSRFValidator::new();

        assert!(validator.validate_url("ftp://example.com").is_err());
        assert!(validator.validate_url("file:///etc/passwd").is_err());

        assert!(validator.validate_url("not-a-url").is_err());
        assert!(validator.validate_url("http://").is_err());
    }
}