// gem_audit/version/requirement.rs

use std::fmt;
use thiserror::Error;

use super::gem_version::Version;

/// The comparison operator for a version constraint.
///
/// `Operator` is a plain fieldless enum, so it is `Copy` and `Hash`: it can be
/// passed by value and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
    /// `=`  — exactly equal
    Equal,
    /// `!=` — not equal
    NotEqual,
    /// `>`  — strictly greater than
    GreaterThan,
    /// `<`  — strictly less than
    LessThan,
    /// `>=` — greater than or equal
    GreaterThanOrEqual,
    /// `<=` — less than or equal
    LessThanOrEqual,
    /// `~>` — pessimistic version constraint
    Pessimistic,
}

impl Operator {
    /// Returns the textual symbol of this operator, e.g. `">="` or `"~>"`.
    ///
    /// This is the same text produced by the `Display` implementation, but
    /// without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Operator::Equal => "=",
            Operator::NotEqual => "!=",
            Operator::GreaterThan => ">",
            Operator::LessThan => "<",
            Operator::GreaterThanOrEqual => ">=",
            Operator::LessThanOrEqual => "<=",
            Operator::Pessimistic => "~>",
        }
    }
}

impl fmt::Display for Operator {
    /// Writes the operator symbol (see [`Operator::as_str`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}

39/// A single version constraint, e.g., `>= 1.0.0` or `~> 2.3`.
40#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct VersionConstraint {
42    pub operator: Operator,
43    pub version: Version,
44}
45
46impl VersionConstraint {
47    /// Check if the given version satisfies this constraint.
48    pub fn satisfied_by(&self, version: &Version) -> bool {
49        match &self.operator {
50            Operator::Equal => version == &self.version,
51            Operator::NotEqual => version != &self.version,
52            Operator::GreaterThan => version > &self.version,
53            Operator::LessThan => version < &self.version,
54            Operator::GreaterThanOrEqual => version >= &self.version,
55            Operator::LessThanOrEqual => version <= &self.version,
56            Operator::Pessimistic => {
57                // ~> X.Y.Z means >= X.Y.Z AND < X.Y+1.0
58                // ~> X.Y means >= X.Y AND < X+1.0
59                let upper = self.version.bump();
60                version >= &self.version && version < &upper
61            }
62        }
63    }
64}
65
66impl fmt::Display for VersionConstraint {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        write!(f, "{} {}", self.operator, self.version)
69    }
70}
71
72/// A compound version requirement (one or more constraints, all must be satisfied).
73///
74/// This is equivalent to Ruby's `Gem::Requirement`. A version must satisfy ALL
75/// constraints to match the requirement.
76///
77/// # Examples
78/// ```
79/// use gem_audit::version::Requirement;
80///
81/// let req = Requirement::parse("~> 1.2.3").unwrap();
82/// ```
83#[derive(Debug, Clone, PartialEq, Eq)]
84pub struct Requirement {
85    pub constraints: Vec<VersionConstraint>,
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Error)]
89pub enum RequirementError {
90    #[error("invalid operator: '{0}'")]
91    InvalidOperator(String),
92    #[error("invalid version: '{0}'")]
93    InvalidVersion(String),
94    #[error("empty requirement string")]
95    Empty,
96}
97
98impl Requirement {
99    /// Parse a requirement string.
100    ///
101    /// Supports:
102    /// - Single constraint: `">= 1.0.0"`, `"~> 2.3"`
103    /// - Compound constraints (comma-separated): `">= 1.0, < 2.0"`
104    /// - Default operator is `=` when omitted: `"1.0.0"` means `"= 1.0.0"`
105    pub fn parse(input: &str) -> Result<Self, RequirementError> {
106        let input = input.trim();
107        if input.is_empty() {
108            return Ok(Requirement::default());
109        }
110
111        let parts: Vec<&str> = input.split(',').map(|s| s.trim()).collect();
112        let mut constraints = Vec::with_capacity(parts.len());
113
114        for part in parts {
115            let constraint = parse_single_constraint(part)?;
116            constraints.push(constraint);
117        }
118
119        if constraints.is_empty() {
120            return Err(RequirementError::Empty);
121        }
122
123        Ok(Requirement { constraints })
124    }
125
126    /// Parse multiple requirement strings (as Ruby's `Gem::Requirement.new(*args)`).
127    ///
128    /// Each string can itself contain comma-separated constraints.
129    pub fn parse_multiple(inputs: &[&str]) -> Result<Self, RequirementError> {
130        let mut constraints = Vec::new();
131
132        for input in inputs {
133            let req = Requirement::parse(input)?;
134            constraints.extend(req.constraints);
135        }
136
137        if constraints.is_empty() {
138            return Ok(Requirement::default());
139        }
140
141        Ok(Requirement { constraints })
142    }
143
144    /// Check if the given version satisfies all constraints in this requirement.
145    pub fn satisfied_by(&self, version: &Version) -> bool {
146        self.constraints.iter().all(|c| c.satisfied_by(version))
147    }
148}
149
150impl Default for Requirement {
151    /// The default requirement: `>= 0` (matches any version).
152    fn default() -> Self {
153        Requirement {
154            constraints: vec![VersionConstraint {
155                operator: Operator::GreaterThanOrEqual,
156                version: Version::parse("0").unwrap(),
157            }],
158        }
159    }
160}
161
162impl fmt::Display for Requirement {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        let parts: Vec<String> = self.constraints.iter().map(|c| c.to_string()).collect();
165        write!(f, "{}", parts.join(", "))
166    }
167}
168
169/// Parse a single constraint string like ">= 1.0.0" or "~> 2.3" or "1.0.0".
170fn parse_single_constraint(input: &str) -> Result<VersionConstraint, RequirementError> {
171    let input = input.trim();
172
173    if input.is_empty() {
174        return Err(RequirementError::Empty);
175    }
176
177    // Try to extract operator + version (check 2-char operators first)
178    let (operator, version_str) = if let Some(rest) = input.strip_prefix("~>") {
179        (Operator::Pessimistic, rest.trim())
180    } else if let Some(rest) = input.strip_prefix(">=") {
181        (Operator::GreaterThanOrEqual, rest.trim())
182    } else if let Some(rest) = input.strip_prefix("<=") {
183        (Operator::LessThanOrEqual, rest.trim())
184    } else if let Some(rest) = input.strip_prefix("!=") {
185        (Operator::NotEqual, rest.trim())
186    } else if let Some(rest) = input.strip_prefix('>') {
187        (Operator::GreaterThan, rest.trim())
188    } else if let Some(rest) = input.strip_prefix('<') {
189        (Operator::LessThan, rest.trim())
190    } else if let Some(rest) = input.strip_prefix('=') {
191        (Operator::Equal, rest.trim())
192    } else {
193        // No operator, default to Equal
194        (Operator::Equal, input)
195    };
196
197    let version = Version::parse(version_str)
198        .map_err(|_| RequirementError::InvalidVersion(version_str.to_string()))?;
199
200    Ok(VersionConstraint { operator, version })
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    // ========== Parsing Tests ==========

    #[test]
    fn parse_simple_equality() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert_eq!(req.constraints.len(), 1);
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_pessimistic() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Pessimistic);
        assert_eq!(req.constraints[0].version, Version::parse("1.2.3").unwrap());
    }

    #[test]
    fn parse_greater_than_or_equal() {
        let req = Requirement::parse(">= 2.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
    }

    #[test]
    fn parse_less_than() {
        let req = Requirement::parse("< 3.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::LessThan);
    }

    #[test]
    fn parse_not_equal() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::NotEqual);
    }

    #[test]
    fn parse_compound_requirement() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.constraints.len(), 2);
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
        assert_eq!(req.constraints[1].operator, Operator::LessThan);
    }

    #[test]
    fn parse_no_operator_defaults_to_equal() {
        let req = Requirement::parse("1.0.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_multiple_strings() {
        let req = Requirement::parse_multiple(&[">= 1.0", "< 2.0", "!= 1.5"]).unwrap();
        assert_eq!(req.constraints.len(), 3);
    }

    // ========== Parsing Edge Cases ==========

    #[test]
    fn parse_blank_string_is_default() {
        assert_eq!(Requirement::parse("").unwrap(), Requirement::default());
        assert_eq!(Requirement::parse("   ").unwrap(), Requirement::default());
    }

    #[test]
    fn parse_without_space_after_operator() {
        let req = Requirement::parse(">=1.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
        assert_eq!(req.constraints[0].version, Version::parse("1.0").unwrap());

        let req = Requirement::parse("~>2.3").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Pessimistic);
        assert_eq!(req.constraints[0].version, Version::parse("2.3").unwrap());
    }

    #[test]
    fn parse_blank_part_is_error() {
        // A trailing or doubled comma leaves a blank part, which is rejected
        // before any version parsing happens.
        assert_eq!(Requirement::parse(">= 1.0,"), Err(RequirementError::Empty));
        assert_eq!(
            Requirement::parse(">= 1.0, , < 2.0"),
            Err(RequirementError::Empty)
        );
    }

    #[test]
    fn parse_multiple_empty_is_default() {
        assert_eq!(
            Requirement::parse_multiple(&[]).unwrap(),
            Requirement::default()
        );
    }

    #[test]
    fn every_operator_parses_from_its_display_form() {
        let operators = [
            Operator::Equal,
            Operator::NotEqual,
            Operator::GreaterThan,
            Operator::LessThan,
            Operator::GreaterThanOrEqual,
            Operator::LessThanOrEqual,
            Operator::Pessimistic,
        ];
        for op in operators {
            let req = Requirement::parse(&format!("{} 1.0", op)).unwrap();
            assert_eq!(req.constraints.len(), 1);
            assert_eq!(req.constraints[0].operator, op);
        }
    }

    // ========== Default Requirement ==========

    #[test]
    fn default_requirement_matches_any() {
        let req = Requirement::default();
        assert!(req.satisfied_by(&Version::parse("0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("999.999.999").unwrap()));
    }

    // ========== Equality Operator ==========

    #[test]
    fn equal_matches_exact() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap())); // trailing zero equivalence
        assert!(!req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    // ========== Not Equal Operator ==========

    #[test]
    fn not_equal_excludes_version() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5.0").unwrap()));
    }

    // ========== Greater Than ==========

    #[test]
    fn greater_than() {
        let req = Requirement::parse("> 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
    }

    // ========== Less Than ==========

    #[test]
    fn less_than() {
        let req = Requirement::parse("< 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    // ========== Greater Than Or Equal ==========

    #[test]
    fn greater_than_or_equal() {
        let req = Requirement::parse(">= 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    // ========== Less Than Or Equal ==========

    #[test]
    fn less_than_or_equal() {
        let req = Requirement::parse("<= 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    // ========== Pessimistic Operator (~>) ==========

    #[test]
    fn pessimistic_two_segments() {
        // ~> 2.3 means >= 2.3, < 3.0
        let req = Requirement::parse("~> 2.3").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments() {
        // ~> 2.3.0 means >= 2.3.0, < 2.4.0
        let req = Requirement::parse("~> 2.3.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2.9").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments_nonzero() {
        // ~> 2.3.18 means >= 2.3.18, < 2.4.0
        let req = Requirement::parse("~> 2.3.18").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3.18").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.20").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.3.17").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
    }

    #[test]
    fn pessimistic_single_segment() {
        // ~> 2 means >= 2, < 3
        let req = Requirement::parse("~> 2").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.9").unwrap()));
    }

    #[test]
    fn pessimistic_four_segments() {
        // ~> 1.2.3.4 means >= 1.2.3.4, < 1.2.4.0
        let req = Requirement::parse("~> 1.2.3.4").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.2.3.4").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.3.3").unwrap()));
    }

    // ========== Compound Requirements ==========

    #[test]
    fn compound_range() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    #[test]
    fn compound_with_exclusion() {
        let req = Requirement::parse(">= 1.0, < 2.0, != 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.4.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    // ========== Real-world Advisory Patterns ==========

    #[test]
    fn advisory_patched_versions_pattern() {
        // From a typical advisory:
        // patched_versions:
        //   - "~> 0.1.42"
        //   - "~> 0.2.42"
        //   - ">= 1.0.0"

        let patch1 = Requirement::parse("~> 0.1.42").unwrap();
        let patch2 = Requirement::parse("~> 0.2.42").unwrap();
        let patch3 = Requirement::parse(">= 1.0.0").unwrap();

        let is_patched = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            patch1.satisfied_by(&ver) || patch2.satisfied_by(&ver) || patch3.satisfied_by(&ver)
        };

        // Patched versions
        assert!(is_patched("0.1.42"));
        assert!(is_patched("0.1.50"));
        assert!(is_patched("0.2.42"));
        assert!(is_patched("0.2.99"));
        assert!(is_patched("1.0.0"));
        assert!(is_patched("2.0.0"));

        // Vulnerable versions
        assert!(!is_patched("0.1.0"));
        assert!(!is_patched("0.1.41"));
        assert!(!is_patched("0.2.0"));
        assert!(!is_patched("0.2.41"));
        assert!(!is_patched("0.3.0")); // not covered by ~> 0.2.42 (which is < 0.3.0)
        assert!(!is_patched("0.9.0"));
    }

    #[test]
    fn advisory_unaffected_versions_pattern() {
        // unaffected_versions:
        //   - "< 0.1.0"
        let unaffected = Requirement::parse("< 0.1.0").unwrap();

        assert!(unaffected.satisfied_by(&Version::parse("0.0.9").unwrap()));
        assert!(unaffected.satisfied_by(&Version::parse("0.0.1").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.1.0").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.2.0").unwrap()));
    }

    #[test]
    fn vulnerability_check_full() {
        // Simulating the full vulnerability check logic from advisory.rb
        let patched: Vec<Requirement> = vec![
            Requirement::parse("~> 0.1.42").unwrap(),
            Requirement::parse("~> 0.2.42").unwrap(),
            Requirement::parse(">= 1.0.0").unwrap(),
        ];
        let unaffected: Vec<Requirement> = vec![Requirement::parse("< 0.1.0").unwrap()];

        let is_patched = |v: &Version| -> bool { patched.iter().any(|req| req.satisfied_by(v)) };
        let is_unaffected =
            |v: &Version| -> bool { unaffected.iter().any(|req| req.satisfied_by(v)) };
        let is_vulnerable = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            !is_patched(&ver) && !is_unaffected(&ver)
        };

        // Unaffected (too old to be affected)
        assert!(!is_vulnerable("0.0.9"));

        // Patched
        assert!(!is_vulnerable("0.1.42"));
        assert!(!is_vulnerable("1.0.0"));
        assert!(!is_vulnerable("2.0.0"));

        // Vulnerable
        assert!(is_vulnerable("0.1.0"));
        assert!(is_vulnerable("0.1.41"));
        assert!(is_vulnerable("0.2.0"));
        assert!(is_vulnerable("0.2.41"));
        assert!(is_vulnerable("0.3.0"));
    }

    // ========== Display ==========

    #[test]
    fn display_single_constraint() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.to_string(), "~> 1.2.3");
    }

    #[test]
    fn display_compound() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.to_string(), ">= 1.0, < 2.0");
    }

    #[test]
    fn display_round_trips_through_parse() {
        for input in ["~> 1.2.3", ">= 1.0, < 2.0, != 1.5", "= 0.1"] {
            let req = Requirement::parse(input).unwrap();
            assert_eq!(Requirement::parse(&req.to_string()).unwrap(), req);
        }
    }
}