iam_rs/evaluation/
matcher.rs

1use crate::{Arn, ArnError};
2use std::collections::HashSet;
3
4/// Advanced ARN matching capabilities for policy evaluation
5#[derive(Debug, Clone)]
6pub struct ArnMatcher {
7    /// Pre-compiled patterns for efficient matching
8    patterns: Vec<ArnPattern>,
9}
10
11/// Internal representation of an ARN pattern with pre-computed matching data
12#[derive(Debug, Clone)]
13struct ArnPattern {
14    /// The original pattern string
15    pattern: String,
16    /// Parsed ARN (with wildcards allowed)
17    arn: Arn,
18    /// Pre-computed flags for optimization
19    has_wildcards: bool,
20    /// Component-level wildcard flags
21    partition_wildcard: bool,
22    service_wildcard: bool,
23    region_wildcard: bool,
24    account_wildcard: bool,
25    resource_wildcard: bool,
26}
27
28impl ArnMatcher {
29    /// Create a new ARN matcher with a set of patterns
30    pub fn new<I>(patterns: I) -> Result<Self, ArnError>
31    where
32        I: IntoIterator<Item = String>,
33    {
34        let mut compiled_patterns = Vec::new();
35
36        for pattern in patterns {
37            let arn_pattern = ArnPattern::compile(&pattern)?;
38            compiled_patterns.push(arn_pattern);
39        }
40
41        Ok(ArnMatcher {
42            patterns: compiled_patterns,
43        })
44    }
45
46    /// Create a matcher from a single pattern
47    pub fn from_pattern(pattern: &str) -> Result<Self, ArnError> {
48        Self::new(vec![pattern.to_string()])
49    }
50
51    /// Check if any pattern matches the given ARN
52    pub fn matches(&self, arn: &str) -> Result<bool, ArnError> {
53        let target_arn = Arn::parse(arn)?;
54
55        for pattern in &self.patterns {
56            if pattern.matches(&target_arn) {
57                return Ok(true);
58            }
59        }
60
61        Ok(false)
62    }
63
64    /// Check if any pattern matches the given parsed ARN
65    pub fn matches_arn(&self, arn: &Arn) -> bool {
66        self.patterns.iter().any(|pattern| pattern.matches(arn))
67    }
68
69    /// Get all patterns that match the given ARN
70    pub fn matching_patterns(&self, arn: &str) -> Result<Vec<&str>, ArnError> {
71        let target_arn = Arn::parse(arn)?;
72
73        Ok(self
74            .patterns
75            .iter()
76            .filter(|pattern| pattern.matches(&target_arn))
77            .map(|pattern| pattern.pattern.as_str())
78            .collect())
79    }
80
81    /// Find ARNs from a collection that match any of our patterns
82    pub fn filter_matching<'a>(&self, arns: &'a [String]) -> Result<Vec<&'a str>, ArnError> {
83        let mut matching = Vec::new();
84
85        for arn_str in arns {
86            if self.matches(arn_str)? {
87                matching.push(arn_str.as_str());
88            }
89        }
90
91        Ok(matching)
92    }
93
94    /// Check if this matcher would match everything (contains "*")
95    pub fn matches_all(&self) -> bool {
96        self.patterns.iter().any(|p| p.pattern == "*")
97    }
98
99    /// Get the list of patterns this matcher uses
100    pub fn patterns(&self) -> Vec<&str> {
101        self.patterns.iter().map(|p| p.pattern.as_str()).collect()
102    }
103
104    /// Create a matcher that combines multiple matchers (OR logic)
105    pub fn combine(matchers: Vec<ArnMatcher>) -> Self {
106        let mut all_patterns = Vec::new();
107
108        for matcher in matchers {
109            all_patterns.extend(matcher.patterns);
110        }
111
112        ArnMatcher {
113            patterns: all_patterns,
114        }
115    }
116}
117
118impl ArnPattern {
119    /// Compile a pattern string into an optimized pattern
120    fn compile(pattern: &str) -> Result<Self, ArnError> {
121        // Handle the special case of "*" (matches everything)
122        if pattern == "*" {
123            return Ok(ArnPattern {
124                pattern: pattern.to_string(),
125                arn: Arn {
126                    partition: "*".to_string(),
127                    service: "*".to_string(),
128                    region: "*".to_string(),
129                    account_id: "*".to_string(),
130                    resource: "*".to_string(),
131                },
132                has_wildcards: true,
133                partition_wildcard: true,
134                service_wildcard: true,
135                region_wildcard: true,
136                account_wildcard: true,
137                resource_wildcard: true,
138            });
139        }
140
141        let arn = Arn::parse(pattern)?;
142        let has_wildcards = pattern.contains('*') || pattern.contains('?');
143
144        Ok(ArnPattern {
145            pattern: pattern.to_string(),
146            partition_wildcard: arn.partition.contains('*') || arn.partition.contains('?'),
147            service_wildcard: arn.service.contains('*') || arn.service.contains('?'),
148            region_wildcard: arn.region.contains('*') || arn.region.contains('?'),
149            account_wildcard: arn.account_id.contains('*') || arn.account_id.contains('?'),
150            resource_wildcard: arn.resource.contains('*') || arn.resource.contains('?'),
151            arn,
152            has_wildcards,
153        })
154    }
155
156    /// Check if this pattern matches the given ARN
157    fn matches(&self, target: &Arn) -> bool {
158        // Special case: "*" matches everything
159        if self.pattern == "*" {
160            return true;
161        }
162
163        // For performance, check exact matches first if no wildcards
164        if !self.has_wildcards {
165            return self.arn.partition == target.partition
166                && self.arn.service == target.service
167                && self.arn.region == target.region
168                && self.arn.account_id == target.account_id
169                && self.arn.resource == target.resource;
170        }
171
172        // Service cannot contain wildcards for security reasons
173        if self.service_wildcard {
174            return false;
175        }
176
177        // Check each component
178        self.match_component(&target.partition, &self.arn.partition, self.partition_wildcard)
179            && target.service == self.arn.service  // Service must match exactly
180            && self.match_component(&target.region, &self.arn.region, self.region_wildcard)
181            && self.match_component(&target.account_id, &self.arn.account_id, self.account_wildcard)
182            && self.match_component(&target.resource, &self.arn.resource, self.resource_wildcard)
183    }
184
185    /// Match a single component, using wildcards if needed
186    fn match_component(&self, target: &str, pattern: &str, has_wildcard: bool) -> bool {
187        if has_wildcard {
188            Arn::wildcard_match(target, pattern)
189        } else {
190            target == pattern
191        }
192    }
193}
194
/// Builder for assembling an [`Arn`] one component at a time.
///
/// Every setter consumes the builder and returns the updated one. Dropping
/// that return value silently loses the update, which is why the type is
/// marked `#[must_use]`.
#[must_use = "ArnBuilder setters return the updated builder; call `build()` to produce an Arn"]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ArnBuilder {
    /// Partition, e.g. `"aws"` or `"aws-cn"`.
    partition: Option<String>,
    /// Service namespace, e.g. `"s3"`.
    service: Option<String>,
    /// Region, e.g. `"us-east-1"`.
    region: Option<String>,
    /// Account ID.
    account_id: Option<String>,
    /// Resource type, when the resource was given as separate type and id.
    resource_type: Option<String>,
    /// Resource id (or the whole resource string when no type was split off).
    resource_id: Option<String>,
}
205
206impl ArnBuilder {
207    /// Create a new ARN builder
208    pub fn new() -> Self {
209        Self::default()
210    }
211
212    /// Set the partition (e.g., "aws", "aws-cn")
213    pub fn partition<S: Into<String>>(mut self, partition: S) -> Self {
214        self.partition = Some(partition.into());
215        self
216    }
217
218    /// Set the service (e.g., "s3", "ec2", "iam")
219    pub fn service<S: Into<String>>(mut self, service: S) -> Self {
220        self.service = Some(service.into());
221        self
222    }
223
224    /// Set the region (e.g., "us-east-1")
225    pub fn region<S: Into<String>>(mut self, region: S) -> Self {
226        self.region = Some(region.into());
227        self
228    }
229
230    /// Set the account ID
231    pub fn account_id<S: Into<String>>(mut self, account_id: S) -> Self {
232        self.account_id = Some(account_id.into());
233        self
234    }
235
236    /// Set the resource type and ID separately
237    pub fn resource<S: Into<String>>(mut self, resource_type: S, resource_id: S) -> Self {
238        self.resource_type = Some(resource_type.into());
239        self.resource_id = Some(resource_id.into());
240        self
241    }
242
243    /// Set the full resource string
244    pub fn resource_string<S: Into<String>>(mut self, resource: S) -> Self {
245        let resource_str = resource.into();
246        if let Some(slash_pos) = resource_str.find('/') {
247            self.resource_type = Some(resource_str[..slash_pos].to_string());
248            self.resource_id = Some(resource_str[slash_pos + 1..].to_string());
249        } else if let Some(colon_pos) = resource_str.find(':') {
250            self.resource_type = Some(resource_str[..colon_pos].to_string());
251            self.resource_id = Some(resource_str[colon_pos + 1..].to_string());
252        } else {
253            self.resource_type = None;
254            self.resource_id = Some(resource_str);
255        }
256        self
257    }
258
259    /// Build the ARN
260    pub fn build(self) -> Result<Arn, ArnError> {
261        let partition = self.partition.unwrap_or_else(|| "aws".to_string());
262        let service = self
263            .service
264            .ok_or_else(|| ArnError::InvalidService("Service is required".to_string()))?;
265        let region = self.region.unwrap_or_default();
266        let account_id = self.account_id.unwrap_or_default();
267
268        let resource = match (self.resource_type, self.resource_id) {
269            (Some(rt), Some(ri)) => format!("{}/{}", rt, ri),
270            (None, Some(ri)) => ri,
271            (Some(rt), None) => rt,
272            (None, None) => {
273                return Err(ArnError::InvalidResource(
274                    "Resource is required".to_string(),
275                ));
276            }
277        };
278
279        Ok(Arn {
280            partition,
281            service,
282            region,
283            account_id,
284            resource,
285        })
286    }
287
288    /// Build the ARN and convert to string
289    pub fn build_string(self) -> Result<String, ArnError> {
290        Ok(self.build()?.to_string())
291    }
292}
293
/// A deduplicated collection of ARN strings with simple filtering helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArnSet {
    /// ARN strings exactly as added (`add` validates each with `Arn::parse` first).
    arns: HashSet<String>,
}
298
299impl ArnSet {
300    /// Create a new ARN set
301    pub fn new() -> Self {
302        Self {
303            arns: HashSet::new(),
304        }
305    }
306
307    /// Create from a collection of ARNs
308    pub fn from_arns<I>(arns: I) -> Result<Self, ArnError>
309    where
310        I: IntoIterator<Item = String>,
311    {
312        let mut set = Self::new();
313        for arn in arns {
314            set.add(arn)?;
315        }
316        Ok(set)
317    }
318
319    /// Add an ARN to the set (validates it first)
320    pub fn add(&mut self, arn: String) -> Result<(), ArnError> {
321        // Validate the ARN
322        Arn::parse(&arn)?;
323        self.arns.insert(arn);
324        Ok(())
325    }
326
327    /// Check if the set contains an ARN
328    pub fn contains(&self, arn: &str) -> bool {
329        self.arns.contains(arn)
330    }
331
332    /// Get ARNs that match any of the given patterns
333    pub fn filter_by_patterns(&self, patterns: &[String]) -> Result<Vec<&str>, ArnError> {
334        let matcher = ArnMatcher::new(patterns.iter().cloned())?;
335
336        let mut matching = Vec::new();
337        for arn in &self.arns {
338            if matcher.matches(arn)? {
339                matching.push(arn.as_str());
340            }
341        }
342
343        Ok(matching)
344    }
345
346    /// Get all ARNs for a specific service
347    pub fn filter_by_service(&self, service: &str) -> Result<Vec<&str>, ArnError> {
348        let mut matching = Vec::new();
349
350        for arn_str in &self.arns {
351            let arn = Arn::parse(arn_str)?;
352            if arn.service == service {
353                matching.push(arn_str.as_str());
354            }
355        }
356
357        Ok(matching)
358    }
359
360    /// Get all ARNs for a specific account
361    pub fn filter_by_account(&self, account_id: &str) -> Result<Vec<&str>, ArnError> {
362        let mut matching = Vec::new();
363
364        for arn_str in &self.arns {
365            let arn = Arn::parse(arn_str)?;
366            if arn.account_id == account_id {
367                matching.push(arn_str.as_str());
368            }
369        }
370
371        Ok(matching)
372    }
373
374    /// Get the number of ARNs in the set
375    pub fn len(&self) -> usize {
376        self.arns.len()
377    }
378
379    /// Check if the set is empty
380    pub fn is_empty(&self) -> bool {
381        self.arns.is_empty()
382    }
383
384    /// Get all ARNs as a vector
385    pub fn to_vec(&self) -> Vec<&str> {
386        self.arns.iter().map(|s| s.as_str()).collect()
387    }
388}
389
390impl Default for ArnSet {
391    fn default() -> Self {
392        Self::new()
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    const S3_FILE: &str = "arn:aws:s3:::my-bucket/file.txt";
    const EC2_INSTANCE: &str = "arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0";
    const IAM_USER: &str = "arn:aws:iam::123456789012:user/username";

    /// Checks `matcher.matches` against each `(arn, expected)` pair, naming the
    /// offending ARN on failure.
    fn check_matches(matcher: &ArnMatcher, cases: &[(&str, bool)]) {
        for &(arn, expected) in cases {
            assert_eq!(
                matcher.matches(arn).unwrap(),
                expected,
                "unexpected match result for {}",
                arn
            );
        }
    }

    /// Converts string literals into the owned `String`s the APIs expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| String::from(*s)).collect()
    }

    #[test]
    fn test_arn_matcher_exact_match() {
        let matcher = ArnMatcher::from_pattern("arn:aws:s3:::my-bucket/*").unwrap();
        check_matches(
            &matcher,
            &[
                (S3_FILE, true),
                ("arn:aws:s3:::my-bucket/folder/file.txt", true),
                ("arn:aws:s3:::other-bucket/file.txt", false),
                (EC2_INSTANCE, false),
            ],
        );
    }

    #[test]
    fn test_arn_matcher_wildcard() {
        let matcher = ArnMatcher::from_pattern("arn:aws:s3:*:*:*").unwrap();
        check_matches(
            &matcher,
            &[
                (S3_FILE, true),
                ("arn:aws:s3:us-east-1:123456789012:bucket/my-bucket", true),
                (EC2_INSTANCE, false),
            ],
        );
    }

    #[test]
    fn test_arn_matcher_multiple_patterns() {
        let matcher = ArnMatcher::new(owned(&[
            "arn:aws:s3:::my-bucket/*",
            "arn:aws:ec2:*:*:instance/*",
        ]))
        .unwrap();
        check_matches(
            &matcher,
            &[(S3_FILE, true), (EC2_INSTANCE, true), (IAM_USER, false)],
        );
    }

    #[test]
    fn test_arn_matcher_star_matches_all() {
        let matcher = ArnMatcher::from_pattern("*").unwrap();
        assert!(matcher.matches_all());
        check_matches(
            &matcher,
            &[(S3_FILE, true), (EC2_INSTANCE, true), (IAM_USER, true)],
        );
    }

    #[test]
    fn test_arn_matcher_service_wildcards_rejected() {
        let matcher = ArnMatcher::from_pattern("arn:aws:*:*:*:*").unwrap();
        // A wildcard in the service position must never grant a match.
        check_matches(&matcher, &[(S3_FILE, false), (EC2_INSTANCE, false)]);
    }

    #[test]
    fn test_arn_builder() {
        let arn = ArnBuilder::new()
            .partition("aws")
            .service("s3")
            .region("us-east-1")
            .account_id("123456789012")
            .resource("bucket", "my-bucket")
            .build()
            .unwrap();

        let parts = (
            arn.partition.as_str(),
            arn.service.as_str(),
            arn.region.as_str(),
            arn.account_id.as_str(),
            arn.resource.as_str(),
        );
        assert_eq!(
            parts,
            ("aws", "s3", "us-east-1", "123456789012", "bucket/my-bucket")
        );
        assert_eq!(
            arn.to_string(),
            "arn:aws:s3:us-east-1:123456789012:bucket/my-bucket"
        );
    }

    #[test]
    fn test_arn_builder_defaults() {
        let arn = ArnBuilder::new()
            .service("iam")
            .resource("user", "test-user")
            .build()
            .unwrap();

        let parts = (
            arn.partition.as_str(),
            arn.service.as_str(),
            arn.region.as_str(),
            arn.account_id.as_str(),
            arn.resource.as_str(),
        );
        assert_eq!(parts, ("aws", "iam", "", "", "user/test-user"));
    }

    #[test]
    fn test_arn_set_operations() {
        let mut arn_set = ArnSet::new();
        for arn in owned(&[
            "arn:aws:s3:::bucket1/file1.txt",
            "arn:aws:s3:::bucket2/file2.txt",
            EC2_INSTANCE,
        ]) {
            arn_set.add(arn).unwrap();
        }

        assert_eq!(arn_set.len(), 3);
        assert!(arn_set.contains("arn:aws:s3:::bucket1/file1.txt"));
        assert!(!arn_set.contains("arn:aws:s3:::bucket3/file3.txt"));

        assert_eq!(arn_set.filter_by_service("s3").unwrap().len(), 2);
        assert_eq!(arn_set.filter_by_service("ec2").unwrap().len(), 1);
    }

    #[test]
    fn test_arn_set_pattern_filtering() {
        let arn_set = ArnSet::from_arns(owned(&[
            "arn:aws:s3:::my-bucket/file1.txt",
            "arn:aws:s3:::my-bucket/file2.txt",
            "arn:aws:s3:::other-bucket/file3.txt",
            EC2_INSTANCE,
        ]))
        .unwrap();

        let matching = arn_set
            .filter_by_patterns(&owned(&["arn:aws:s3:::my-bucket/*"]))
            .unwrap();

        assert_eq!(matching.len(), 2);
        for expected in [
            "arn:aws:s3:::my-bucket/file1.txt",
            "arn:aws:s3:::my-bucket/file2.txt",
        ]
        .iter()
        {
            assert!(matching.contains(expected));
        }
        assert!(!matching.contains(&"arn:aws:s3:::other-bucket/file3.txt"));
    }

    #[test]
    fn test_arn_matcher_performance_optimization() {
        // A pattern without wildcards takes the exact-equality fast path.
        let matcher =
            ArnMatcher::from_pattern("arn:aws:s3:::my-bucket/specific-file.txt").unwrap();
        check_matches(
            &matcher,
            &[
                ("arn:aws:s3:::my-bucket/specific-file.txt", true),
                ("arn:aws:s3:::my-bucket/other-file.txt", false),
            ],
        );
    }

    #[test]
    fn test_matching_patterns_list() {
        let matcher = ArnMatcher::new(owned(&[
            "arn:aws:s3:::bucket1/*",
            "arn:aws:s3:::bucket2/*",
            "arn:aws:ec2:*:*:instance/*",
        ]))
        .unwrap();

        let cases = [
            ("arn:aws:s3:::bucket1/file.txt", "arn:aws:s3:::bucket1/*"),
            (
                "arn:aws:ec2:us-east-1:123456789012:instance/i-123",
                "arn:aws:ec2:*:*:instance/*",
            ),
        ];
        for (arn, pattern) in cases.iter() {
            assert_eq!(matcher.matching_patterns(arn).unwrap(), vec![*pattern]);
        }
    }
}
608}