Skip to main content

gem_audit/version/
requirement.rs

1use std::fmt;
2use thiserror::Error;
3
4use super::gem_version::Version;
5
/// Comparison operator used by a single version constraint.
///
/// Each variant corresponds to one operator token of the requirement syntax
/// (`=`, `!=`, `>`, `<`, `>=`, `<=`, `~>`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Operator {
    /// `=` — the version must equal the target.
    Equal,
    /// `!=` — the version must differ from the target.
    NotEqual,
    /// `>` — the version must be strictly newer than the target.
    GreaterThan,
    /// `<` — the version must be strictly older than the target.
    LessThan,
    /// `>=` — the version must be the target or newer.
    GreaterThanOrEqual,
    /// `<=` — the version must be the target or older.
    LessThanOrEqual,
    /// `~>` — pessimistic constraint: at least the target, but below the next
    /// "bumped" release, e.g. `~> 2.3` accepts `>= 2.3, < 3.0` and `~> 2.3.0`
    /// accepts `>= 2.3.0, < 2.4.0`.
    Pessimistic,
}
24
25impl fmt::Display for Operator {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        match self {
28            Operator::Equal => write!(f, "="),
29            Operator::NotEqual => write!(f, "!="),
30            Operator::GreaterThan => write!(f, ">"),
31            Operator::LessThan => write!(f, "<"),
32            Operator::GreaterThanOrEqual => write!(f, ">="),
33            Operator::LessThanOrEqual => write!(f, "<="),
34            Operator::Pessimistic => write!(f, "~>"),
35        }
36    }
37}
38
39/// A single version constraint, e.g., `>= 1.0.0` or `~> 2.3`.
40#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct VersionConstraint {
42    pub operator: Operator,
43    pub version: Version,
44}
45
46impl VersionConstraint {
47    /// Check if the given version satisfies this constraint.
48    pub fn satisfied_by(&self, version: &Version) -> bool {
49        match &self.operator {
50            Operator::Equal => version == &self.version,
51            Operator::NotEqual => version != &self.version,
52            Operator::GreaterThan => version > &self.version,
53            Operator::LessThan => version < &self.version,
54            Operator::GreaterThanOrEqual => version >= &self.version,
55            Operator::LessThanOrEqual => version <= &self.version,
56            Operator::Pessimistic => {
57                // ~> X.Y.Z means >= X.Y.Z AND < X.Y+1.0
58                // ~> X.Y means >= X.Y AND < X+1.0
59                let upper = self.version.bump();
60                version >= &self.version && version < &upper
61            }
62        }
63    }
64}
65
66impl VersionConstraint {
67    /// Return the minimum version that satisfies this single constraint, if one exists.
68    ///
69    /// - `>= X` → `X`
70    /// - `> X`  → `X` with last segment incremented
71    /// - `~> X` → `X`
72    /// - `= X`  → `X`
73    /// - `<=`, `<`, `!=` → `None` (upper bounds / exclusions have no lower bound)
74    pub fn minimum_version(&self) -> Option<Version> {
75        match &self.operator {
76            Operator::GreaterThanOrEqual | Operator::Pessimistic | Operator::Equal => {
77                Some(self.version.clone())
78            }
79            Operator::GreaterThan => Some(self.version.increment_last()),
80            Operator::LessThan | Operator::LessThanOrEqual | Operator::NotEqual => None,
81        }
82    }
83}
84
85impl fmt::Display for VersionConstraint {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        write!(f, "{} {}", self.operator, self.version)
88    }
89}
90
91/// A compound version requirement (one or more constraints, all must be satisfied).
92///
93/// This is equivalent to Ruby's `Gem::Requirement`. A version must satisfy ALL
94/// constraints to match the requirement.
95///
96/// # Examples
97/// ```
98/// use gem_audit::version::Requirement;
99///
100/// let req = Requirement::parse("~> 1.2.3").unwrap();
101/// ```
102#[derive(Debug, Clone, PartialEq, Eq)]
103pub struct Requirement {
104    pub constraints: Vec<VersionConstraint>,
105}
106
107#[derive(Debug, Clone, PartialEq, Eq, Error)]
108pub enum RequirementError {
109    #[error("invalid operator: '{0}'")]
110    InvalidOperator(String),
111    #[error("invalid version: '{0}'")]
112    InvalidVersion(String),
113    #[error("empty requirement string")]
114    Empty,
115}
116
117impl Requirement {
118    /// Parse a requirement string.
119    ///
120    /// Supports:
121    /// - Single constraint: `">= 1.0.0"`, `"~> 2.3"`
122    /// - Compound constraints (comma-separated): `">= 1.0, < 2.0"`
123    /// - Default operator is `=` when omitted: `"1.0.0"` means `"= 1.0.0"`
124    pub fn parse(input: &str) -> Result<Self, RequirementError> {
125        let input = input.trim();
126        if input.is_empty() {
127            return Ok(Requirement::default());
128        }
129
130        let parts: Vec<&str> = input.split(',').map(|s| s.trim()).collect();
131        let mut constraints = Vec::with_capacity(parts.len());
132
133        for part in parts {
134            let constraint = parse_single_constraint(part)?;
135            constraints.push(constraint);
136        }
137
138        if constraints.is_empty() {
139            return Err(RequirementError::Empty);
140        }
141
142        Ok(Requirement { constraints })
143    }
144
145    /// Parse multiple requirement strings (as Ruby's `Gem::Requirement.new(*args)`).
146    ///
147    /// Each string can itself contain comma-separated constraints.
148    pub fn parse_multiple(inputs: &[&str]) -> Result<Self, RequirementError> {
149        let mut constraints = Vec::new();
150
151        for input in inputs {
152            let req = Requirement::parse(input)?;
153            constraints.extend(req.constraints);
154        }
155
156        if constraints.is_empty() {
157            return Ok(Requirement::default());
158        }
159
160        Ok(Requirement { constraints })
161    }
162
163    /// Check if the given version satisfies all constraints in this requirement.
164    pub fn satisfied_by(&self, version: &Version) -> bool {
165        self.constraints.iter().all(|c| c.satisfied_by(version))
166    }
167
168    /// Compute the minimum version that satisfies all constraints in this requirement.
169    ///
170    /// Takes the maximum of all lower bounds, then verifies it passes all upper bounds
171    /// and exclusions. If a `!=` excludes the candidate, tries incrementing the last segment.
172    pub fn minimum_version(&self) -> Option<Version> {
173        // Collect lower-bound candidates
174        let mut candidate: Option<Version> = None;
175        for c in &self.constraints {
176            if let Some(v) = c.minimum_version() {
177                candidate = Some(match candidate {
178                    Some(cur) if v > cur => v,
179                    Some(cur) => cur,
180                    None => v,
181                });
182            }
183        }
184
185        // If no lower bound found, requirement is pure upper-bound (e.g. "< 2.0")
186        let mut candidate = candidate?;
187
188        // Verify candidate satisfies all constraints; handle != by incrementing
189        for _ in 0..100 {
190            if self.satisfied_by(&candidate) {
191                return Some(candidate);
192            }
193            // If blocked by !=, try next patch
194            candidate = candidate.increment_last();
195        }
196
197        None
198    }
199}
200
201impl Default for Requirement {
202    /// The default requirement: `>= 0` (matches any version).
203    fn default() -> Self {
204        Requirement {
205            constraints: vec![VersionConstraint {
206                operator: Operator::GreaterThanOrEqual,
207                version: Version::parse("0").unwrap(),
208            }],
209        }
210    }
211}
212
213impl fmt::Display for Requirement {
214    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215        let parts: Vec<String> = self.constraints.iter().map(|c| c.to_string()).collect();
216        write!(f, "{}", parts.join(", "))
217    }
218}
219
220/// Parse a single constraint string like ">= 1.0.0" or "~> 2.3" or "1.0.0".
221fn parse_single_constraint(input: &str) -> Result<VersionConstraint, RequirementError> {
222    let input = input.trim();
223
224    if input.is_empty() {
225        return Err(RequirementError::Empty);
226    }
227
228    // Try to extract operator + version (check 2-char operators first)
229    let (operator, version_str) = if let Some(rest) = input.strip_prefix("~>") {
230        (Operator::Pessimistic, rest.trim())
231    } else if let Some(rest) = input.strip_prefix(">=") {
232        (Operator::GreaterThanOrEqual, rest.trim())
233    } else if let Some(rest) = input.strip_prefix("<=") {
234        (Operator::LessThanOrEqual, rest.trim())
235    } else if let Some(rest) = input.strip_prefix("!=") {
236        (Operator::NotEqual, rest.trim())
237    } else if let Some(rest) = input.strip_prefix('>') {
238        (Operator::GreaterThan, rest.trim())
239    } else if let Some(rest) = input.strip_prefix('<') {
240        (Operator::LessThan, rest.trim())
241    } else if let Some(rest) = input.strip_prefix('=') {
242        (Operator::Equal, rest.trim())
243    } else {
244        // No operator, default to Equal
245        (Operator::Equal, input)
246    };
247
248    let version = Version::parse(version_str)
249        .map_err(|_| RequirementError::InvalidVersion(version_str.to_string()))?;
250
251    Ok(VersionConstraint { operator, version })
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    // ========== Parsing Tests ==========

    #[test]
    fn parse_simple_equality() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert_eq!(req.constraints.len(), 1);
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_pessimistic() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Pessimistic);
        assert_eq!(req.constraints[0].version, Version::parse("1.2.3").unwrap());
    }

    #[test]
    fn parse_greater_than_or_equal() {
        let req = Requirement::parse(">= 2.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
    }

    #[test]
    fn parse_less_than() {
        let req = Requirement::parse("< 3.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::LessThan);
    }

    #[test]
    fn parse_not_equal() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::NotEqual);
    }

    #[test]
    fn parse_compound_requirement() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.constraints.len(), 2);
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
        assert_eq!(req.constraints[1].operator, Operator::LessThan);
    }

    #[test]
    fn parse_no_operator_defaults_to_equal() {
        let req = Requirement::parse("1.0.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_multiple_strings() {
        let req = Requirement::parse_multiple(&[">= 1.0", "< 2.0", "!= 1.5"]).unwrap();
        assert_eq!(req.constraints.len(), 3);
    }

    #[test]
    fn parse_without_space_after_operator() {
        let req = Requirement::parse(">=1.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
        assert_eq!(req.constraints[0].version, Version::parse("1.0").unwrap());
    }

    // ========== Parse Edge Cases & Errors ==========

    #[test]
    fn parse_blank_input_yields_default() {
        assert_eq!(Requirement::parse("").unwrap(), Requirement::default());
        assert_eq!(Requirement::parse("   ").unwrap(), Requirement::default());
        let no_inputs: [&str; 0] = [];
        assert_eq!(
            Requirement::parse_multiple(&no_inputs).unwrap(),
            Requirement::default()
        );
    }

    #[test]
    fn parse_empty_segment_is_error() {
        assert_eq!(Requirement::parse(">= 1.0, "), Err(RequirementError::Empty));
        assert_eq!(Requirement::parse(", < 2.0"), Err(RequirementError::Empty));
        assert_eq!(Requirement::parse(">= 1.0,, < 2.0"), Err(RequirementError::Empty));
    }

    #[test]
    fn parse_malformed_operator_is_error() {
        assert!(Requirement::parse("=> 1.0").is_err());
        assert!(Requirement::parse("~ 1.0").is_err());
    }

    // ========== Default Requirement ==========

    #[test]
    fn default_requirement_matches_any() {
        let req = Requirement::default();
        assert!(req.satisfied_by(&Version::parse("0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("999.999.999").unwrap()));
    }

    // ========== Equality Operator ==========

    #[test]
    fn equal_matches_exact() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap())); // trailing zero equivalence
        assert!(!req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    // ========== Not Equal Operator ==========

    #[test]
    fn not_equal_excludes_version() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5.0").unwrap()));
    }

    // ========== Greater Than ==========

    #[test]
    fn greater_than() {
        let req = Requirement::parse("> 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
    }

    // ========== Less Than ==========

    #[test]
    fn less_than() {
        let req = Requirement::parse("< 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    // ========== Greater Than Or Equal ==========

    #[test]
    fn greater_than_or_equal() {
        let req = Requirement::parse(">= 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    // ========== Less Than Or Equal ==========

    #[test]
    fn less_than_or_equal() {
        let req = Requirement::parse("<= 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    // ========== Pessimistic Operator (~>) ==========

    #[test]
    fn pessimistic_two_segments() {
        // ~> 2.3 means >= 2.3, < 3.0
        let req = Requirement::parse("~> 2.3").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments() {
        // ~> 2.3.0 means >= 2.3.0, < 2.4.0
        let req = Requirement::parse("~> 2.3.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2.9").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments_nonzero() {
        // ~> 2.3.18 means >= 2.3.18, < 2.4.0
        let req = Requirement::parse("~> 2.3.18").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3.18").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.20").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.3.17").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
    }

    #[test]
    fn pessimistic_single_segment() {
        // ~> 2 means >= 2, < 3
        let req = Requirement::parse("~> 2").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.9").unwrap()));
    }

    #[test]
    fn pessimistic_four_segments() {
        // ~> 1.2.3.4 means >= 1.2.3.4, < 1.2.4.0
        let req = Requirement::parse("~> 1.2.3.4").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.2.3.4").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.3.3").unwrap()));
    }

    // ========== Compound Requirements ==========

    #[test]
    fn compound_range() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    #[test]
    fn compound_with_exclusion() {
        let req = Requirement::parse(">= 1.0, < 2.0, != 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.4.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    // ========== Minimum Version ==========

    #[test]
    fn minimum_version_of_lower_bounds() {
        let min = |s: &str| Requirement::parse(s).unwrap().minimum_version();
        assert_eq!(min(">= 1.2"), Some(Version::parse("1.2").unwrap()));
        assert_eq!(min("~> 2.3.4"), Some(Version::parse("2.3.4").unwrap()));
        assert_eq!(min("= 1.5"), Some(Version::parse("1.5").unwrap()));
        // The tightest lower bound wins.
        assert_eq!(
            min(">= 1.0, >= 1.4, < 2.0"),
            Some(Version::parse("1.4").unwrap())
        );
    }

    #[test]
    fn minimum_version_of_strict_lower_bound() {
        let req = Requirement::parse("> 1.0").unwrap();
        let min = req.minimum_version().expect("> 1.0 has a lower bound");
        assert_eq!(min, Version::parse("1.0").unwrap().increment_last());
        assert!(req.satisfied_by(&min));
    }

    #[test]
    fn minimum_version_skips_excluded_version() {
        let req = Requirement::parse(">= 1.0, != 1.0").unwrap();
        let min = req
            .minimum_version()
            .expect("a version above 1.0 satisfies the requirement");
        assert_eq!(min, Version::parse("1.0").unwrap().increment_last());
        assert!(req.satisfied_by(&min));
    }

    #[test]
    fn minimum_version_without_lower_bound_is_none() {
        assert_eq!(Requirement::parse("< 2.0").unwrap().minimum_version(), None);
        assert_eq!(
            Requirement::parse("<= 2.0, != 1.0").unwrap().minimum_version(),
            None
        );
    }

    #[test]
    fn minimum_version_unsatisfiable_is_none() {
        assert_eq!(
            Requirement::parse(">= 2.0, < 1.0").unwrap().minimum_version(),
            None
        );
    }

    // ========== Real-world Advisory Patterns ==========

    #[test]
    fn advisory_patched_versions_pattern() {
        // From a typical advisory:
        // patched_versions:
        //   - "~> 0.1.42"
        //   - "~> 0.2.42"
        //   - ">= 1.0.0"

        let patch1 = Requirement::parse("~> 0.1.42").unwrap();
        let patch2 = Requirement::parse("~> 0.2.42").unwrap();
        let patch3 = Requirement::parse(">= 1.0.0").unwrap();

        let is_patched = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            patch1.satisfied_by(&ver) || patch2.satisfied_by(&ver) || patch3.satisfied_by(&ver)
        };

        // Patched versions
        assert!(is_patched("0.1.42"));
        assert!(is_patched("0.1.50"));
        assert!(is_patched("0.2.42"));
        assert!(is_patched("0.2.99"));
        assert!(is_patched("1.0.0"));
        assert!(is_patched("2.0.0"));

        // Vulnerable versions
        assert!(!is_patched("0.1.0"));
        assert!(!is_patched("0.1.41"));
        assert!(!is_patched("0.2.0"));
        assert!(!is_patched("0.2.41"));
        assert!(!is_patched("0.3.0")); // not covered by ~> 0.2.42 (which is < 0.3.0)
        assert!(!is_patched("0.9.0"));
    }

    #[test]
    fn advisory_unaffected_versions_pattern() {
        // unaffected_versions:
        //   - "< 0.1.0"
        let unaffected = Requirement::parse("< 0.1.0").unwrap();

        assert!(unaffected.satisfied_by(&Version::parse("0.0.9").unwrap()));
        assert!(unaffected.satisfied_by(&Version::parse("0.0.1").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.1.0").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.2.0").unwrap()));
    }

    #[test]
    fn vulnerability_check_full() {
        // Simulating the full vulnerability check logic from advisory.rb
        let patched: Vec<Requirement> = vec![
            Requirement::parse("~> 0.1.42").unwrap(),
            Requirement::parse("~> 0.2.42").unwrap(),
            Requirement::parse(">= 1.0.0").unwrap(),
        ];
        let unaffected: Vec<Requirement> = vec![Requirement::parse("< 0.1.0").unwrap()];

        let is_patched = |v: &Version| -> bool { patched.iter().any(|req| req.satisfied_by(v)) };
        let is_unaffected =
            |v: &Version| -> bool { unaffected.iter().any(|req| req.satisfied_by(v)) };
        let is_vulnerable = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            !is_patched(&ver) && !is_unaffected(&ver)
        };

        // Unaffected (too old to be affected)
        assert!(!is_vulnerable("0.0.9"));

        // Patched
        assert!(!is_vulnerable("0.1.42"));
        assert!(!is_vulnerable("1.0.0"));
        assert!(!is_vulnerable("2.0.0"));

        // Vulnerable
        assert!(is_vulnerable("0.1.0"));
        assert!(is_vulnerable("0.1.41"));
        assert!(is_vulnerable("0.2.0"));
        assert!(is_vulnerable("0.2.41"));
        assert!(is_vulnerable("0.3.0"));
    }

    // ========== Display ==========

    #[test]
    fn display_single_constraint() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.to_string(), "~> 1.2.3");
    }

    #[test]
    fn display_compound() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.to_string(), ">= 1.0, < 2.0");
    }

    #[test]
    fn display_makes_implied_equal_explicit() {
        let req = Requirement::parse("1.0.0").unwrap();
        assert_eq!(req.to_string(), "= 1.0.0");
    }
}