Skip to main content

synapse_pingora/
access.rs

1//! Access control lists with CIDR-based allow/deny rules.
2//!
3//! Provides IP-based access control with support for IPv4 and IPv6 CIDR notation.
4//! Rules are evaluated in order: first matching rule wins.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::net::IpAddr;
9use tracing::debug;
10
/// Outcome of evaluating an IP address against access-control rules.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccessDecision {
    /// The request may proceed.
    Allow,
    /// The request must be rejected.
    Deny,
    /// No rule matched; the caller should fall back to its default policy.
    NoMatch,
}
21
22/// Access control rule action.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
24#[serde(rename_all = "lowercase")]
25pub enum AccessAction {
26    Allow,
27    #[default]
28    Deny,
29}
30
31/// A CIDR network range for IP matching.
32#[derive(Debug, Clone)]
33pub struct CidrRange {
34    /// Network address
35    network: IpAddr,
36    /// Prefix length (0-32 for IPv4, 0-128 for IPv6)
37    prefix_len: u8,
38}
39
40impl CidrRange {
41    /// Parses a CIDR string (e.g., "192.168.1.0/24" or "10.0.0.1").
42    pub fn parse(cidr: &str) -> Result<Self, AccessError> {
43        use std::str::FromStr;
44        Self::from_str(cidr)
45    }
46
47    /// Checks if an IP address is within this CIDR range.
48    pub fn contains(&self, ip: &IpAddr) -> bool {
49        match (&self.network, ip) {
50            (IpAddr::V4(net), IpAddr::V4(addr)) => {
51                let net_bits = u32::from_be_bytes(net.octets());
52                let addr_bits = u32::from_be_bytes(addr.octets());
53                let mask = if self.prefix_len == 0 {
54                    0
55                } else {
56                    !0u32 << (32 - self.prefix_len)
57                };
58                (addr_bits & mask) == (net_bits & mask)
59            }
60            (IpAddr::V6(net), IpAddr::V6(addr)) => {
61                let net_bits = u128::from_be_bytes(net.octets());
62                let addr_bits = u128::from_be_bytes(addr.octets());
63                let mask = if self.prefix_len == 0 {
64                    0
65                } else {
66                    !0u128 << (128 - self.prefix_len)
67                };
68                (addr_bits & mask) == (net_bits & mask)
69            }
70            // IPv4 and IPv6 don't match
71            _ => false,
72        }
73    }
74}
75
76impl std::str::FromStr for CidrRange {
77    type Err = AccessError;
78
79    fn from_str(cidr: &str) -> Result<Self, Self::Err> {
80        let (addr_str, prefix_str) = if let Some(idx) = cidr.find('/') {
81            (&cidr[..idx], Some(&cidr[idx + 1..]))
82        } else {
83            (cidr, None)
84        };
85
86        let network: IpAddr = addr_str.parse().map_err(|_| AccessError::InvalidCidr {
87            cidr: cidr.to_string(),
88            reason: "invalid IP address".to_string(),
89        })?;
90
91        let max_prefix = match network {
92            IpAddr::V4(_) => 32,
93            IpAddr::V6(_) => 128,
94        };
95
96        let prefix_len: u8 = match prefix_str {
97            Some(s) => s.parse().map_err(|_| AccessError::InvalidCidr {
98                cidr: cidr.to_string(),
99                reason: "invalid prefix length".to_string(),
100            })?,
101            None => max_prefix,
102        };
103
104        if prefix_len > max_prefix {
105            return Err(AccessError::InvalidCidr {
106                cidr: cidr.to_string(),
107                reason: format!(
108                    "prefix length {} exceeds maximum {}",
109                    prefix_len, max_prefix
110                ),
111            });
112        }
113
114        Ok(Self {
115            network,
116            prefix_len,
117        })
118    }
119}
120
121/// A single access control rule.
122#[derive(Debug, Clone)]
123pub struct AccessRule {
124    /// CIDR range to match
125    pub cidr: CidrRange,
126    /// Action to take on match
127    pub action: AccessAction,
128    /// Optional comment/description
129    pub comment: Option<String>,
130}
131
132impl AccessRule {
133    /// Creates a new allow rule for the given CIDR.
134    pub fn allow(cidr: &str) -> Result<Self, AccessError> {
135        Ok(Self {
136            cidr: CidrRange::parse(cidr)?,
137            action: AccessAction::Allow,
138            comment: None,
139        })
140    }
141
142    /// Creates a new deny rule for the given CIDR.
143    pub fn deny(cidr: &str) -> Result<Self, AccessError> {
144        Ok(Self {
145            cidr: CidrRange::parse(cidr)?,
146            action: AccessAction::Deny,
147            comment: None,
148        })
149    }
150
151    /// Adds a comment to the rule.
152    pub fn with_comment(mut self, comment: &str) -> Self {
153        self.comment = Some(comment.to_string());
154        self
155    }
156
157    /// Checks if this rule matches the given IP.
158    pub fn matches(&self, ip: &IpAddr) -> bool {
159        self.cidr.contains(ip)
160    }
161}
162
163/// Access control list for a site.
164#[derive(Debug, Default)]
165pub struct AccessList {
166    /// Rules evaluated in order
167    rules: Vec<AccessRule>,
168    /// Default action when no rule matches
169    default_action: AccessAction,
170}
171
172impl AccessList {
173    /// Creates a new access list with deny as default.
174    pub fn new() -> Self {
175        Self {
176            rules: Vec::new(),
177            default_action: AccessAction::Deny,
178        }
179    }
180
181    /// Creates an access list that allows all by default.
182    pub fn allow_all() -> Self {
183        Self {
184            rules: Vec::new(),
185            default_action: AccessAction::Allow,
186        }
187    }
188
189    /// Creates an access list that denies all by default.
190    pub fn deny_all() -> Self {
191        Self {
192            rules: Vec::new(),
193            default_action: AccessAction::Deny,
194        }
195    }
196
197    /// Adds a rule to the access list.
198    pub fn add_rule(&mut self, rule: AccessRule) {
199        self.rules.push(rule);
200    }
201
202    /// Adds an allow rule for the given CIDR.
203    pub fn allow(&mut self, cidr: &str) -> Result<(), AccessError> {
204        self.rules.push(AccessRule::allow(cidr)?);
205        Ok(())
206    }
207
208    /// Adds a deny rule for the given CIDR.
209    pub fn deny(&mut self, cidr: &str) -> Result<(), AccessError> {
210        self.rules.push(AccessRule::deny(cidr)?);
211        Ok(())
212    }
213
214    /// Sets the default action.
215    pub fn set_default(&mut self, action: AccessAction) {
216        self.default_action = action;
217    }
218
219    /// Checks if an IP address is allowed.
220    pub fn check(&self, ip: &IpAddr) -> AccessDecision {
221        for rule in &self.rules {
222            if rule.matches(ip) {
223                debug!(
224                    "IP {} matched rule {:?} -> {:?}",
225                    ip, rule.cidr.network, rule.action
226                );
227                return match rule.action {
228                    AccessAction::Allow => AccessDecision::Allow,
229                    AccessAction::Deny => AccessDecision::Deny,
230                };
231            }
232        }
233
234        debug!(
235            "IP {} no match, using default {:?}",
236            ip, self.default_action
237        );
238        match self.default_action {
239            AccessAction::Allow => AccessDecision::Allow,
240            AccessAction::Deny => AccessDecision::Deny,
241        }
242    }
243
244    /// Returns true if the IP is allowed.
245    pub fn is_allowed(&self, ip: &IpAddr) -> bool {
246        matches!(self.check(ip), AccessDecision::Allow)
247    }
248
249    /// Returns the number of rules.
250    pub fn rule_count(&self) -> usize {
251        self.rules.len()
252    }
253}
254
255/// Per-site access list manager.
256#[derive(Debug, Default)]
257pub struct AccessListManager {
258    /// Site hostname -> access list mapping
259    lists: HashMap<String, AccessList>,
260    /// Global access list (checked first)
261    global: AccessList,
262}
263
264impl AccessListManager {
265    /// Creates a new manager with allow-all defaults.
266    pub fn new() -> Self {
267        Self {
268            lists: HashMap::new(),
269            global: AccessList::allow_all(),
270        }
271    }
272
273    /// Sets the global access list.
274    pub fn set_global(&mut self, list: AccessList) {
275        self.global = list;
276    }
277
278    /// Adds a site-specific access list.
279    pub fn add_site(&mut self, hostname: &str, list: AccessList) {
280        self.lists.insert(hostname.to_lowercase(), list);
281    }
282
283    /// Removes a site-specific access list.
284    pub fn remove_site(&mut self, hostname: &str) {
285        self.lists.remove(&hostname.to_lowercase());
286    }
287
288    /// Checks if an IP is allowed for a site.
289    ///
290    /// Evaluation order:
291    /// 1. Global deny rules
292    /// 2. Site-specific rules
293    /// 3. Global allow rules
294    /// 4. Default action
295    pub fn check(&self, hostname: &str, ip: &IpAddr) -> AccessDecision {
296        // Check global rules first
297        let global_decision = self.global.check(ip);
298        if matches!(global_decision, AccessDecision::Deny) {
299            return AccessDecision::Deny;
300        }
301
302        // Check site-specific rules
303        let normalized = hostname.to_lowercase();
304        if let Some(site_list) = self.lists.get(&normalized) {
305            let site_decision = site_list.check(ip);
306            if !matches!(site_decision, AccessDecision::NoMatch) {
307                return site_decision;
308            }
309        }
310
311        // Fall back to global decision
312        global_decision
313    }
314
315    /// Returns true if the IP is allowed for the site.
316    pub fn is_allowed(&self, hostname: &str, ip: &IpAddr) -> bool {
317        matches!(self.check(hostname, ip), AccessDecision::Allow)
318    }
319
320    /// Returns the number of configured sites.
321    pub fn site_count(&self) -> usize {
322        self.lists.len()
323    }
324
325    /// Dynamically adds a deny rule for an IP address to the global list.
326    ///
327    /// Used by CampaignManager for automated mitigation of high-confidence campaigns.
328    ///
329    /// # Arguments
330    /// * `ip` - The IP address to deny
331    /// * `comment` - Reason for the denial (e.g., campaign ID)
332    ///
333    /// # Returns
334    /// Ok(()) on success, or an error if the IP is invalid.
335    pub fn add_deny_ip(&mut self, ip: &IpAddr, comment: Option<&str>) -> Result<(), AccessError> {
336        let cidr = match ip {
337            IpAddr::V4(_) => format!("{}/32", ip),
338            IpAddr::V6(_) => format!("{}/128", ip),
339        };
340
341        let mut rule = AccessRule::deny(&cidr)?;
342        if let Some(c) = comment {
343            rule = rule.with_comment(c);
344        }
345
346        self.global.add_rule(rule);
347        tracing::info!(ip = %ip, comment = ?comment, "Added dynamic deny rule");
348        Ok(())
349    }
350
351    /// Removes all deny rules for a specific IP from the global list.
352    ///
353    /// Used for mitigation rollback when campaign confidence drops.
354    ///
355    /// # Arguments
356    /// * `ip` - The IP address to unblock
357    ///
358    /// # Returns
359    /// The number of rules removed.
360    pub fn remove_deny_ip(&mut self, ip: &IpAddr) -> usize {
361        let ip_str = ip.to_string();
362
363        let before_count = self.global.rules.len();
364        self.global.rules.retain(|rule| {
365            // Match rules by network IP and deny action
366            let network_str = match rule.cidr.network {
367                IpAddr::V4(v4) => v4.to_string(),
368                IpAddr::V6(v6) => v6.to_string(),
369            };
370            !(network_str == ip_str && matches!(rule.action, AccessAction::Deny))
371        });
372        let removed = before_count - self.global.rules.len();
373
374        if removed > 0 {
375            tracing::info!(ip = %ip, removed = removed, "Removed dynamic deny rules");
376        }
377
378        removed
379    }
380
381    /// Returns a list of all configured site hostnames.
382    pub fn list_sites(&self) -> Vec<String> {
383        self.lists.keys().cloned().collect()
384    }
385
386    /// Returns the global access list for inspection/modification.
387    pub fn global_list(&self) -> &AccessList {
388        &self.global
389    }
390
391    /// Returns a mutable reference to the global access list.
392    pub fn global_list_mut(&mut self) -> &mut AccessList {
393        &mut self.global
394    }
395}
396
397/// Errors that can occur during access control operations.
398#[derive(Debug, thiserror::Error)]
399pub enum AccessError {
400    #[error("invalid CIDR '{cidr}': {reason}")]
401    InvalidCidr { cidr: String, reason: String },
402}
403
404/// Parses an IP address from a string, handling common formats.
405pub fn parse_ip(s: &str) -> Result<IpAddr, AccessError> {
406    // Handle IPv6 with brackets
407    let s = s.trim_start_matches('[').trim_end_matches(']');
408
409    s.parse().map_err(|_| AccessError::InvalidCidr {
410        cidr: s.to_string(),
411        reason: "invalid IP address format".to_string(),
412    })
413}
414
415// ========== SSRF Protection Functions ==========
416
/// Extract IPv4 address from IPv6-mapped IPv4 address (::ffff:x.x.x.x).
///
/// IPv6-mapped IPv4 addresses are a common way to slip past SSRF filters that
/// only inspect IPv4 addresses. This returns the embedded IPv4 address so it can
/// be validated like any other IPv4 address.
///
/// Only the mapped form (first 80 bits zero, next 16 bits all ones) is recognized.
/// Plain IPv4 input and every other IPv6 address, including the deprecated
/// IPv4-compatible `::a.b.c.d` form, yield `None`.
pub fn extract_mapped_ipv4(ip: &IpAddr) -> Option<std::net::Ipv4Addr> {
    match ip {
        IpAddr::V4(_) => None,
        IpAddr::V6(v6) => v6.to_ipv4_mapped(),
    }
}
449
/// Check if an IPv4 address is in an RFC 1918 private range.
///
/// Ranges:
/// - 10.0.0.0/8
/// - 172.16.0.0/12 (172.16.0.0 - 172.31.255.255)
/// - 192.168.0.0/16
fn is_private_ipv4(ip: &std::net::Ipv4Addr) -> bool {
    match ip.octets() {
        [10, ..] => true,
        [172, second, ..] => (16..=31).contains(&second),
        [192, 168, ..] => true,
        _ => false,
    }
}
472
/// Check if an IP address reaches the local host.
///
/// Covers:
/// - IPv4 loopback 127.0.0.0/8 and IPv6 loopback ::1
/// - IPv4 "this network" 0.0.0.0/8 and the IPv6 unspecified address ::
///
/// The unspecified addresses are included because connecting to `0.0.0.0` (or
/// `::`) reaches services listening on the local machine on common operating
/// systems such as Linux. That makes them a well-known bypass of loopback-only
/// SSRF filters.
fn is_loopback(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => matches!(v4.octets()[0], 0 | 127),
        IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified(),
    }
}
484
/// Check if an IP address is link-local.
///
/// - IPv4: 169.254.0.0/16 (this range contains the 169.254.169.254 metadata address)
/// - IPv6: fe80::/10
fn is_link_local(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => matches!(v4.octets(), [169, 254, _, _]),
        // fe80::/10 — only the top ten bits of the first segment are significant.
        IpAddr::V6(v6) => (v6.segments()[0] & 0xffc0) == 0xfe80,
    }
}
503
/// Check if an IP is a cloud metadata endpoint.
///
/// Known endpoints:
/// - 169.254.169.254: AWS / Azure / GCP instance metadata
/// - 169.254.170.2: AWS ECS task metadata
/// - 100.100.100.200: Alibaba Cloud ECS metadata. It sits in 100.64.0.0/10, which
///   none of the private/link-local checks cover, so it must be listed here.
/// - fd00:ec2::254: AWS instance metadata over IPv6 (Nitro instances)
fn is_cloud_metadata(ip: &IpAddr) -> bool {
    const V4_ENDPOINTS: [[u8; 4]; 3] = [
        [169, 254, 169, 254],
        [169, 254, 170, 2],
        [100, 100, 100, 200],
    ];
    // fd00:ec2::254
    const V6_ENDPOINTS: [[u16; 8]; 1] = [[0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254]];

    match ip {
        IpAddr::V4(v4) => V4_ENDPOINTS.contains(&v4.octets()),
        IpAddr::V6(v6) => V6_ENDPOINTS.contains(&v6.segments()),
    }
}
527
528/// Comprehensive SSRF check for an IP address.
529///
530/// Returns `true` if the IP address is potentially dangerous for SSRF attacks:
531/// - Loopback addresses (127.0.0.0/8, ::1)
532/// - Private addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
533/// - Link-local addresses (169.254.0.0/16, fe80::/10)
534/// - Cloud metadata endpoints (169.254.169.254, 169.254.170.2)
535/// - IPv6-mapped IPv4 addresses that resolve to any of the above
536///
537/// # Security
538/// This function is critical for SSRF prevention. Always call this before
539/// making outbound HTTP requests to user-controlled URLs.
540///
541/// # Example
542/// ```
543/// use synapse_pingora::access::is_ssrf_target;
544/// use std::net::IpAddr;
545///
546/// // Direct localhost
547/// assert!(is_ssrf_target(&"127.0.0.1".parse().unwrap()));
548///
549/// // IPv6-mapped localhost (SSRF bypass attempt)
550/// assert!(is_ssrf_target(&"::ffff:127.0.0.1".parse().unwrap()));
551///
552/// // Cloud metadata endpoint
553/// assert!(is_ssrf_target(&"169.254.169.254".parse().unwrap()));
554///
555/// // Public IP is safe
556/// assert!(!is_ssrf_target(&"8.8.8.8".parse().unwrap()));
557/// ```
558pub fn is_ssrf_target(ip: &IpAddr) -> bool {
559    // First, check if this is an IPv6-mapped IPv4 address
560    // This catches SSRF bypass attempts using ::ffff:127.0.0.1
561    if let Some(mapped_v4) = extract_mapped_ipv4(ip) {
562        // Check the underlying IPv4 address
563        if mapped_v4.octets()[0] == 127 {
564            tracing::warn!(
565                ip = %ip,
566                mapped = %mapped_v4,
567                "SSRF attempt blocked: IPv6-mapped loopback"
568            );
569            return true;
570        }
571        if is_private_ipv4(&mapped_v4) {
572            tracing::warn!(
573                ip = %ip,
574                mapped = %mapped_v4,
575                "SSRF attempt blocked: IPv6-mapped private IP"
576            );
577            return true;
578        }
579        if is_cloud_metadata(&IpAddr::V4(mapped_v4)) {
580            tracing::warn!(
581                ip = %ip,
582                mapped = %mapped_v4,
583                "SSRF attempt blocked: IPv6-mapped cloud metadata"
584            );
585            return true;
586        }
587        if is_link_local(&IpAddr::V4(mapped_v4)) {
588            tracing::warn!(
589                ip = %ip,
590                mapped = %mapped_v4,
591                "SSRF attempt blocked: IPv6-mapped link-local"
592            );
593            return true;
594        }
595        // The mapped IPv4 is public, allow it
596        return false;
597    }
598
599    // Check direct addresses
600    if is_loopback(ip) {
601        tracing::debug!(ip = %ip, "SSRF blocked: loopback address");
602        return true;
603    }
604
605    if is_cloud_metadata(ip) {
606        tracing::warn!(ip = %ip, "SSRF blocked: cloud metadata endpoint");
607        return true;
608    }
609
610    if is_link_local(ip) {
611        tracing::debug!(ip = %ip, "SSRF blocked: link-local address");
612        return true;
613    }
614
615    // Check private IPv4 ranges
616    if let IpAddr::V4(v4) = ip {
617        if is_private_ipv4(v4) {
618            tracing::debug!(ip = %ip, "SSRF blocked: private IPv4");
619            return true;
620        }
621    }
622
623    // Check IPv6 unique local (fc00::/7) and site-local (deprecated but still used)
624    if let IpAddr::V6(v6) = ip {
625        let segments = v6.segments();
626        // fc00::/7 (Unique Local Address)
627        if (segments[0] & 0xfe00) == 0xfc00 {
628            tracing::debug!(ip = %ip, "SSRF blocked: IPv6 unique local");
629            return true;
630        }
631    }
632
633    false
634}
635
/// Classification produced by an SSRF check, naming why an address was blocked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SsrfCheckResult {
    /// Not in any blocked range; safe for outbound connections
    Safe,
    /// Loopback address (e.g. 127.0.0.0/8, ::1)
    Loopback,
    /// Private RFC 1918 IPv4 range
    Private,
    /// Link-local address (169.254.0.0/16 or fe80::/10)
    LinkLocal,
    /// Known cloud metadata endpoint
    CloudMetadata,
    /// IPv6-mapped IPv4 address whose embedded IPv4 address is blocked
    MappedBlocked {
        /// The embedded IPv4 address
        mapped_v4: std::net::Ipv4Addr,
        /// Short label for the blocked category (e.g. "loopback", "private")
        reason: &'static str,
    },
    /// IPv6 unique local address (fc00::/7)
    Ipv6UniqueLocal,
}

impl SsrfCheckResult {
    /// Returns `true` for every variant except [`SsrfCheckResult::Safe`].
    pub fn is_blocked(&self) -> bool {
        *self != Self::Safe
    }
}
664
665/// Comprehensive SSRF check with detailed result.
666///
667/// Similar to `is_ssrf_target` but returns detailed information about why
668/// an IP was blocked, useful for logging and debugging.
669pub fn check_ssrf(ip: &IpAddr) -> SsrfCheckResult {
670    // Check IPv6-mapped IPv4 first
671    if let Some(mapped_v4) = extract_mapped_ipv4(ip) {
672        if mapped_v4.octets()[0] == 127 {
673            return SsrfCheckResult::MappedBlocked {
674                mapped_v4,
675                reason: "loopback",
676            };
677        }
678        if is_private_ipv4(&mapped_v4) {
679            return SsrfCheckResult::MappedBlocked {
680                mapped_v4,
681                reason: "private",
682            };
683        }
684        if is_cloud_metadata(&IpAddr::V4(mapped_v4)) {
685            return SsrfCheckResult::MappedBlocked {
686                mapped_v4,
687                reason: "cloud_metadata",
688            };
689        }
690        if is_link_local(&IpAddr::V4(mapped_v4)) {
691            return SsrfCheckResult::MappedBlocked {
692                mapped_v4,
693                reason: "link_local",
694            };
695        }
696        return SsrfCheckResult::Safe;
697    }
698
699    if is_loopback(ip) {
700        return SsrfCheckResult::Loopback;
701    }
702    if is_cloud_metadata(ip) {
703        return SsrfCheckResult::CloudMetadata;
704    }
705    if is_link_local(ip) {
706        return SsrfCheckResult::LinkLocal;
707    }
708    if let IpAddr::V4(v4) = ip {
709        if is_private_ipv4(v4) {
710            return SsrfCheckResult::Private;
711        }
712    }
713    if let IpAddr::V6(v6) = ip {
714        let segments = v6.segments();
715        if (segments[0] & 0xfe00) == 0xfc00 {
716            return SsrfCheckResult::Ipv6UniqueLocal;
717        }
718    }
719
720    SsrfCheckResult::Safe
721}
722
723#[cfg(test)]
724mod tests {
725    use super::*;
726
727    #[test]
728    fn test_cidr_parse_ipv4() {
729        let cidr = CidrRange::parse("192.168.1.0/24").unwrap();
730        assert!(cidr.contains(&"192.168.1.1".parse().unwrap()));
731        assert!(cidr.contains(&"192.168.1.254".parse().unwrap()));
732        assert!(!cidr.contains(&"192.168.2.1".parse().unwrap()));
733    }
734
735    #[test]
736    fn test_cidr_parse_ipv4_single() {
737        let cidr = CidrRange::parse("10.0.0.1").unwrap();
738        assert!(cidr.contains(&"10.0.0.1".parse().unwrap()));
739        assert!(!cidr.contains(&"10.0.0.2".parse().unwrap()));
740    }
741
742    #[test]
743    fn test_cidr_parse_ipv6() {
744        let cidr = CidrRange::parse("2001:db8::/32").unwrap();
745        assert!(cidr.contains(&"2001:db8::1".parse().unwrap()));
746        assert!(cidr.contains(&"2001:db8:ffff::1".parse().unwrap()));
747        assert!(!cidr.contains(&"2001:db9::1".parse().unwrap()));
748    }
749
750    #[test]
751    fn test_cidr_invalid() {
752        assert!(CidrRange::parse("not-an-ip").is_err());
753        assert!(CidrRange::parse("192.168.1.0/33").is_err());
754        assert!(CidrRange::parse("192.168.1.0/abc").is_err());
755    }
756
757    #[test]
758    fn test_access_rule_allow() {
759        let rule = AccessRule::allow("10.0.0.0/8").unwrap();
760        assert!(rule.matches(&"10.1.2.3".parse().unwrap()));
761        assert!(!rule.matches(&"192.168.1.1".parse().unwrap()));
762    }
763
764    #[test]
765    fn test_access_rule_deny() {
766        let rule = AccessRule::deny("192.168.0.0/16").unwrap();
767        assert_eq!(rule.action, AccessAction::Deny);
768        assert!(rule.matches(&"192.168.1.1".parse().unwrap()));
769    }
770
771    #[test]
772    fn test_access_list_allow_all() {
773        let list = AccessList::allow_all();
774        assert!(list.is_allowed(&"1.2.3.4".parse().unwrap()));
775        assert!(list.is_allowed(&"::1".parse().unwrap()));
776    }
777
778    #[test]
779    fn test_access_list_deny_all() {
780        let list = AccessList::deny_all();
781        assert!(!list.is_allowed(&"1.2.3.4".parse().unwrap()));
782        assert!(!list.is_allowed(&"::1".parse().unwrap()));
783    }
784
785    #[test]
786    fn test_access_list_rules() {
787        let mut list = AccessList::deny_all();
788        // Order matters: first match wins
789        // To deny a specific IP within an allow range, add the deny first
790        list.deny("10.0.0.1").unwrap(); // Specific deny - must come first
791        list.allow("10.0.0.0/8").unwrap(); // Then allow the broader range
792        list.allow("192.168.1.0/24").unwrap();
793
794        assert!(!list.is_allowed(&"10.0.0.1".parse().unwrap())); // Denied by specific rule
795        assert!(list.is_allowed(&"10.0.0.2".parse().unwrap())); // Allowed by /8 rule
796        assert!(list.is_allowed(&"192.168.1.100".parse().unwrap()));
797        assert!(!list.is_allowed(&"8.8.8.8".parse().unwrap())); // Default deny
798    }
799
800    #[test]
801    fn test_access_list_manager() {
802        let mut manager = AccessListManager::new();
803
804        // Global: deny known bad actors
805        let mut global = AccessList::allow_all();
806        global.deny("1.2.3.4").unwrap();
807        manager.set_global(global);
808
809        // Site-specific: only allow internal
810        let mut site_list = AccessList::deny_all();
811        site_list.allow("10.0.0.0/8").unwrap();
812        manager.add_site("internal.example.com", site_list);
813
814        // Global deny takes precedence
815        assert!(!manager.is_allowed("any.com", &"1.2.3.4".parse().unwrap()));
816
817        // Site-specific rules
818        assert!(manager.is_allowed("internal.example.com", &"10.0.0.1".parse().unwrap()));
819        assert!(!manager.is_allowed("internal.example.com", &"8.8.8.8".parse().unwrap()));
820
821        // Other sites use global
822        assert!(manager.is_allowed("public.example.com", &"8.8.8.8".parse().unwrap()));
823    }
824
825    #[test]
826    fn test_manager_case_insensitive() {
827        let mut manager = AccessListManager::new();
828        manager.add_site("Example.COM", AccessList::deny_all());
829
830        assert!(!manager.is_allowed("example.com", &"1.2.3.4".parse().unwrap()));
831        assert!(!manager.is_allowed("EXAMPLE.COM", &"1.2.3.4".parse().unwrap()));
832    }
833
834    #[test]
835    fn test_rule_with_comment() {
836        let rule = AccessRule::deny("0.0.0.0/0")
837            .unwrap()
838            .with_comment("Block all by default");
839
840        assert_eq!(rule.comment, Some("Block all by default".to_string()));
841    }
842
843    #[test]
844    fn test_parse_ip_formats() {
845        assert!(parse_ip("192.168.1.1").is_ok());
846        assert!(parse_ip("::1").is_ok());
847        assert!(parse_ip("[::1]").is_ok()); // Bracketed IPv6
848        assert!(parse_ip("invalid").is_err());
849    }
850
851    #[test]
852    fn test_cidr_zero_prefix() {
853        let cidr = CidrRange::parse("0.0.0.0/0").unwrap();
854        assert!(cidr.contains(&"1.2.3.4".parse().unwrap()));
855        assert!(cidr.contains(&"255.255.255.255".parse().unwrap()));
856    }
857
858    #[test]
859    fn test_rule_count() {
860        let mut list = AccessList::new();
861        assert_eq!(list.rule_count(), 0);
862
863        list.allow("10.0.0.0/8").unwrap();
864        list.deny("192.168.0.0/16").unwrap();
865
866        assert_eq!(list.rule_count(), 2);
867    }
868
869    // ==================== SSRF Protection Tests ====================
870
871    #[test]
872    fn test_extract_mapped_ipv4() {
873        // IPv6-mapped IPv4 localhost
874        let mapped_localhost: IpAddr = "::ffff:127.0.0.1".parse().unwrap();
875        let extracted = extract_mapped_ipv4(&mapped_localhost);
876        assert!(extracted.is_some());
877        assert_eq!(extracted.unwrap().to_string(), "127.0.0.1");
878
879        // IPv6-mapped private IP
880        let mapped_private: IpAddr = "::ffff:192.168.1.1".parse().unwrap();
881        let extracted = extract_mapped_ipv4(&mapped_private);
882        assert!(extracted.is_some());
883        assert_eq!(extracted.unwrap().to_string(), "192.168.1.1");
884
885        // Regular IPv6 (not mapped)
886        let regular_v6: IpAddr = "2001:db8::1".parse().unwrap();
887        assert!(extract_mapped_ipv4(&regular_v6).is_none());
888
889        // IPv4 (not applicable)
890        let v4: IpAddr = "127.0.0.1".parse().unwrap();
891        assert!(extract_mapped_ipv4(&v4).is_none());
892
893        // IPv6-mapped cloud metadata
894        let mapped_metadata: IpAddr = "::ffff:169.254.169.254".parse().unwrap();
895        let extracted = extract_mapped_ipv4(&mapped_metadata);
896        assert!(extracted.is_some());
897        assert_eq!(extracted.unwrap().to_string(), "169.254.169.254");
898    }
899
900    #[test]
901    fn test_ssrf_loopback() {
902        // IPv4 localhost
903        assert!(is_ssrf_target(&"127.0.0.1".parse().unwrap()));
904        assert!(is_ssrf_target(&"127.0.0.2".parse().unwrap()));
905        assert!(is_ssrf_target(&"127.255.255.255".parse().unwrap()));
906
907        // IPv6 localhost
908        assert!(is_ssrf_target(&"::1".parse().unwrap()));
909    }
910
911    #[test]
912    fn test_ssrf_private_ipv4() {
913        // 10.0.0.0/8
914        assert!(is_ssrf_target(&"10.0.0.1".parse().unwrap()));
915        assert!(is_ssrf_target(&"10.255.255.255".parse().unwrap()));
916
917        // 172.16.0.0/12
918        assert!(is_ssrf_target(&"172.16.0.1".parse().unwrap()));
919        assert!(is_ssrf_target(&"172.31.255.255".parse().unwrap()));
920        assert!(!is_ssrf_target(&"172.15.0.1".parse().unwrap())); // Not in range
921        assert!(!is_ssrf_target(&"172.32.0.1".parse().unwrap())); // Not in range
922
923        // 192.168.0.0/16
924        assert!(is_ssrf_target(&"192.168.0.1".parse().unwrap()));
925        assert!(is_ssrf_target(&"192.168.255.255".parse().unwrap()));
926    }
927
928    #[test]
929    fn test_ssrf_cloud_metadata() {
930        // AWS/Azure/GCP metadata
931        assert!(is_ssrf_target(&"169.254.169.254".parse().unwrap()));
932        // AWS ECS task metadata
933        assert!(is_ssrf_target(&"169.254.170.2".parse().unwrap()));
934    }
935
936    #[test]
937    fn test_ssrf_link_local() {
938        // IPv4 link-local
939        assert!(is_ssrf_target(&"169.254.0.1".parse().unwrap()));
940        assert!(is_ssrf_target(&"169.254.255.255".parse().unwrap()));
941
942        // IPv6 link-local (fe80::/10)
943        assert!(is_ssrf_target(&"fe80::1".parse().unwrap()));
944        assert!(is_ssrf_target(&"fe80::abcd:1234".parse().unwrap()));
945    }
946
947    #[test]
948    fn test_ssrf_ipv6_mapped_bypass_attempts() {
949        // CRITICAL: These are common SSRF bypass attempts using IPv6-mapped IPv4
950
951        // Mapped localhost
952        assert!(is_ssrf_target(&"::ffff:127.0.0.1".parse().unwrap()));
953
954        // Mapped private IPs
955        assert!(is_ssrf_target(&"::ffff:10.0.0.1".parse().unwrap()));
956        assert!(is_ssrf_target(&"::ffff:172.16.0.1".parse().unwrap()));
957        assert!(is_ssrf_target(&"::ffff:192.168.1.1".parse().unwrap()));
958
959        // Mapped cloud metadata (HIGH SEVERITY)
960        assert!(is_ssrf_target(&"::ffff:169.254.169.254".parse().unwrap()));
961
962        // Mapped link-local
963        assert!(is_ssrf_target(&"::ffff:169.254.1.1".parse().unwrap()));
964
965        // Mapped public IP should be allowed
966        assert!(!is_ssrf_target(&"::ffff:8.8.8.8".parse().unwrap()));
967    }
968
969    #[test]
970    fn test_ssrf_ipv6_unique_local() {
971        // fc00::/7 - Unique Local Address
972        assert!(is_ssrf_target(&"fc00::1".parse().unwrap()));
973        assert!(is_ssrf_target(&"fd00::1".parse().unwrap()));
974        assert!(is_ssrf_target(&"fdab:cdef::1234".parse().unwrap()));
975    }
976
977    #[test]
978    fn test_ssrf_public_ips_allowed() {
979        // Public IPv4
980        assert!(!is_ssrf_target(&"8.8.8.8".parse().unwrap()));
981        assert!(!is_ssrf_target(&"1.1.1.1".parse().unwrap()));
982        assert!(!is_ssrf_target(&"203.0.113.1".parse().unwrap()));
983
984        // Public IPv6
985        assert!(!is_ssrf_target(&"2001:4860:4860::8888".parse().unwrap()));
986        assert!(!is_ssrf_target(&"2606:4700::1111".parse().unwrap()));
987    }
988
989    #[test]
990    fn test_check_ssrf_detailed() {
991        // Loopback
992        assert_eq!(
993            check_ssrf(&"127.0.0.1".parse().unwrap()),
994            SsrfCheckResult::Loopback
995        );
996
997        // Private
998        assert_eq!(
999            check_ssrf(&"10.0.0.1".parse().unwrap()),
1000            SsrfCheckResult::Private
1001        );
1002
1003        // Cloud metadata
1004        assert_eq!(
1005            check_ssrf(&"169.254.169.254".parse().unwrap()),
1006            SsrfCheckResult::CloudMetadata
1007        );
1008
1009        // Link-local
1010        assert_eq!(
1011            check_ssrf(&"169.254.1.1".parse().unwrap()),
1012            SsrfCheckResult::LinkLocal
1013        );
1014
1015        // IPv6 unique local
1016        assert_eq!(
1017            check_ssrf(&"fc00::1".parse().unwrap()),
1018            SsrfCheckResult::Ipv6UniqueLocal
1019        );
1020
1021        // Safe public IP
1022        assert_eq!(
1023            check_ssrf(&"8.8.8.8".parse().unwrap()),
1024            SsrfCheckResult::Safe
1025        );
1026
1027        // IPv6-mapped blocked
1028        let result = check_ssrf(&"::ffff:127.0.0.1".parse().unwrap());
1029        assert!(result.is_blocked());
1030        if let SsrfCheckResult::MappedBlocked { mapped_v4, reason } = result {
1031            assert_eq!(mapped_v4.to_string(), "127.0.0.1");
1032            assert_eq!(reason, "loopback");
1033        } else {
1034            panic!("Expected MappedBlocked");
1035        }
1036    }
1037
1038    #[test]
1039    fn test_ssrf_check_result_is_blocked() {
1040        assert!(!SsrfCheckResult::Safe.is_blocked());
1041        assert!(SsrfCheckResult::Loopback.is_blocked());
1042        assert!(SsrfCheckResult::Private.is_blocked());
1043        assert!(SsrfCheckResult::LinkLocal.is_blocked());
1044        assert!(SsrfCheckResult::CloudMetadata.is_blocked());
1045        assert!(SsrfCheckResult::Ipv6UniqueLocal.is_blocked());
1046        assert!(SsrfCheckResult::MappedBlocked {
1047            mapped_v4: "127.0.0.1".parse().unwrap(),
1048            reason: "loopback"
1049        }
1050        .is_blocked());
1051    }
1052}