1use std::fmt;
2use thiserror::Error;
3
4use super::gem_version::Version;
5
/// A comparison operator in a RubyGems-style version constraint.
///
/// The enum carries no data, so it derives `Copy` (pass it by value freely) and
/// `Hash` (usable as a `HashMap`/`HashSet` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
    /// `=` — the candidate must equal the constraint's version.
    Equal,
    /// `!=` — the candidate must differ from the constraint's version.
    NotEqual,
    /// `>` — the candidate must be strictly greater.
    GreaterThan,
    /// `<` — the candidate must be strictly less.
    LessThan,
    /// `>=` — the candidate must be greater than or equal.
    GreaterThanOrEqual,
    /// `<=` — the candidate must be less than or equal.
    LessThanOrEqual,
    /// `~>` — "pessimistic" match: at least the constraint's version but below
    /// its bumped version (e.g. `~> 2.3` accepts `2.3` up to, not including, `3.0`).
    Pessimistic,
}
24
25impl fmt::Display for Operator {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 match self {
28 Operator::Equal => write!(f, "="),
29 Operator::NotEqual => write!(f, "!="),
30 Operator::GreaterThan => write!(f, ">"),
31 Operator::LessThan => write!(f, "<"),
32 Operator::GreaterThanOrEqual => write!(f, ">="),
33 Operator::LessThanOrEqual => write!(f, "<="),
34 Operator::Pessimistic => write!(f, "~>"),
35 }
36 }
37}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct VersionConstraint {
42 pub operator: Operator,
43 pub version: Version,
44}
45
46impl VersionConstraint {
47 pub fn satisfied_by(&self, version: &Version) -> bool {
49 match &self.operator {
50 Operator::Equal => version == &self.version,
51 Operator::NotEqual => version != &self.version,
52 Operator::GreaterThan => version > &self.version,
53 Operator::LessThan => version < &self.version,
54 Operator::GreaterThanOrEqual => version >= &self.version,
55 Operator::LessThanOrEqual => version <= &self.version,
56 Operator::Pessimistic => {
57 let upper = self.version.bump();
60 version >= &self.version && version < &upper
61 }
62 }
63 }
64}
65
66impl fmt::Display for VersionConstraint {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 write!(f, "{} {}", self.operator, self.version)
69 }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq)]
84pub struct Requirement {
85 pub constraints: Vec<VersionConstraint>,
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Error)]
89pub enum RequirementError {
90 #[error("invalid operator: '{0}'")]
91 InvalidOperator(String),
92 #[error("invalid version: '{0}'")]
93 InvalidVersion(String),
94 #[error("empty requirement string")]
95 Empty,
96}
97
98impl Requirement {
99 pub fn parse(input: &str) -> Result<Self, RequirementError> {
106 let input = input.trim();
107 if input.is_empty() {
108 return Ok(Requirement::default());
109 }
110
111 let parts: Vec<&str> = input.split(',').map(|s| s.trim()).collect();
112 let mut constraints = Vec::with_capacity(parts.len());
113
114 for part in parts {
115 let constraint = parse_single_constraint(part)?;
116 constraints.push(constraint);
117 }
118
119 if constraints.is_empty() {
120 return Err(RequirementError::Empty);
121 }
122
123 Ok(Requirement { constraints })
124 }
125
126 pub fn parse_multiple(inputs: &[&str]) -> Result<Self, RequirementError> {
130 let mut constraints = Vec::new();
131
132 for input in inputs {
133 let req = Requirement::parse(input)?;
134 constraints.extend(req.constraints);
135 }
136
137 if constraints.is_empty() {
138 return Ok(Requirement::default());
139 }
140
141 Ok(Requirement { constraints })
142 }
143
144 pub fn satisfied_by(&self, version: &Version) -> bool {
146 self.constraints.iter().all(|c| c.satisfied_by(version))
147 }
148}
149
150impl Default for Requirement {
151 fn default() -> Self {
153 Requirement {
154 constraints: vec![VersionConstraint {
155 operator: Operator::GreaterThanOrEqual,
156 version: Version::parse("0").unwrap(),
157 }],
158 }
159 }
160}
161
162impl fmt::Display for Requirement {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 let parts: Vec<String> = self.constraints.iter().map(|c| c.to_string()).collect();
165 write!(f, "{}", parts.join(", "))
166 }
167}
168
169fn parse_single_constraint(input: &str) -> Result<VersionConstraint, RequirementError> {
171 let input = input.trim();
172
173 if input.is_empty() {
174 return Err(RequirementError::Empty);
175 }
176
177 let (operator, version_str) = if let Some(rest) = input.strip_prefix("~>") {
179 (Operator::Pessimistic, rest.trim())
180 } else if let Some(rest) = input.strip_prefix(">=") {
181 (Operator::GreaterThanOrEqual, rest.trim())
182 } else if let Some(rest) = input.strip_prefix("<=") {
183 (Operator::LessThanOrEqual, rest.trim())
184 } else if let Some(rest) = input.strip_prefix("!=") {
185 (Operator::NotEqual, rest.trim())
186 } else if let Some(rest) = input.strip_prefix('>') {
187 (Operator::GreaterThan, rest.trim())
188 } else if let Some(rest) = input.strip_prefix('<') {
189 (Operator::LessThan, rest.trim())
190 } else if let Some(rest) = input.strip_prefix('=') {
191 (Operator::Equal, rest.trim())
192 } else {
193 (Operator::Equal, input)
195 };
196
197 let version = Version::parse(version_str)
198 .map_err(|_| RequirementError::InvalidVersion(version_str.to_string()))?;
199
200 Ok(VersionConstraint { operator, version })
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    // ----- Parsing -----

    #[test]
    fn parse_simple_equality() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert_eq!(req.constraints.len(), 1);
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_pessimistic() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Pessimistic);
        assert_eq!(req.constraints[0].version, Version::parse("1.2.3").unwrap());
    }

    #[test]
    fn parse_greater_than_or_equal() {
        let req = Requirement::parse(">= 2.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
    }

    #[test]
    fn parse_less_than() {
        let req = Requirement::parse("< 3.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::LessThan);
    }

    #[test]
    fn parse_not_equal() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::NotEqual);
    }

    #[test]
    fn parse_compound_requirement() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.constraints.len(), 2);
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
        assert_eq!(req.constraints[1].operator, Operator::LessThan);
    }

    #[test]
    fn parse_no_operator_defaults_to_equal() {
        let req = Requirement::parse("1.0.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_multiple_strings() {
        let req = Requirement::parse_multiple(&[">= 1.0", "< 2.0", "!= 1.5"]).unwrap();
        assert_eq!(req.constraints.len(), 3);
    }

    #[test]
    fn parse_blank_input_yields_default() {
        // A wholly blank requirement is not an error: it means "any version".
        assert_eq!(Requirement::parse("").unwrap(), Requirement::default());
        assert_eq!(Requirement::parse("   ").unwrap(), Requirement::default());
    }

    #[test]
    fn parse_rejects_blank_constraint_between_commas() {
        // ...but a blank *part* of a comma-separated list is.
        assert_eq!(Requirement::parse(">= 1.0,"), Err(RequirementError::Empty));
        assert_eq!(Requirement::parse(","), Err(RequirementError::Empty));
    }

    #[test]
    fn parse_multiple_with_no_inputs_yields_default() {
        assert_eq!(Requirement::parse_multiple(&[]).unwrap(), Requirement::default());
    }

    // ----- Matching: default requirement -----

    #[test]
    fn default_requirement_matches_any() {
        let req = Requirement::default();
        assert!(req.satisfied_by(&Version::parse("0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("999.999.999").unwrap()));
    }

    // ----- Matching: single comparison operators -----

    #[test]
    fn equal_matches_exact() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        // Trailing zero segments compare equal: 1.0 == 1.0.0.
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    #[test]
    fn not_equal_excludes_version() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5.0").unwrap()));
    }

    #[test]
    fn greater_than() {
        let req = Requirement::parse("> 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
    }

    #[test]
    fn less_than() {
        let req = Requirement::parse("< 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    #[test]
    fn greater_than_or_equal() {
        let req = Requirement::parse(">= 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    #[test]
    fn less_than_or_equal() {
        let req = Requirement::parse("<= 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    // ----- Matching: pessimistic operator (~>) -----

    #[test]
    fn pessimistic_two_segments() {
        let req = Requirement::parse("~> 2.3").unwrap();
        // ~> 2.3 behaves like >= 2.3, < 3.0
        assert!(req.satisfied_by(&Version::parse("2.3").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments() {
        let req = Requirement::parse("~> 2.3.0").unwrap();
        // ~> 2.3.0 behaves like >= 2.3.0, < 2.4
        assert!(req.satisfied_by(&Version::parse("2.3.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2.9").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments_nonzero() {
        let req = Requirement::parse("~> 2.3.18").unwrap();
        // ~> 2.3.18 behaves like >= 2.3.18, < 2.4
        assert!(req.satisfied_by(&Version::parse("2.3.18").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.20").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.3.17").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
    }

    #[test]
    fn pessimistic_single_segment() {
        let req = Requirement::parse("~> 2").unwrap();
        // ~> 2 behaves like >= 2, < 3
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.9").unwrap()));
    }

    #[test]
    fn pessimistic_four_segments() {
        let req = Requirement::parse("~> 1.2.3.4").unwrap();
        // ~> 1.2.3.4 behaves like >= 1.2.3.4, < 1.2.4
        assert!(req.satisfied_by(&Version::parse("1.2.3.4").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.3.3").unwrap()));
    }

    // ----- Matching: compound requirements (all constraints must hold) -----

    #[test]
    fn compound_range() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    #[test]
    fn compound_with_exclusion() {
        let req = Requirement::parse(">= 1.0, < 2.0, != 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.4.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    // ----- Advisory-style usage: several requirements, any one may match -----

    #[test]
    fn advisory_patched_versions_pattern() {
        // Several patched release lines; a version counts as patched if it
        // falls in any one of them.
        let patch1 = Requirement::parse("~> 0.1.42").unwrap();
        let patch2 = Requirement::parse("~> 0.2.42").unwrap();
        let patch3 = Requirement::parse(">= 1.0.0").unwrap();

        let is_patched = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            patch1.satisfied_by(&ver) || patch2.satisfied_by(&ver) || patch3.satisfied_by(&ver)
        };

        // Inside one of the patched lines.
        assert!(is_patched("0.1.42"));
        assert!(is_patched("0.1.50"));
        assert!(is_patched("0.2.42"));
        assert!(is_patched("0.2.99"));
        assert!(is_patched("1.0.0"));
        assert!(is_patched("2.0.0"));

        // Below the fix in each line, or in a line that never got a fix.
        assert!(!is_patched("0.1.0"));
        assert!(!is_patched("0.1.41"));
        assert!(!is_patched("0.2.0"));
        assert!(!is_patched("0.2.41"));
        assert!(!is_patched("0.3.0"));
        assert!(!is_patched("0.9.0"));
    }

    #[test]
    fn advisory_unaffected_versions_pattern() {
        let unaffected = Requirement::parse("< 0.1.0").unwrap();

        assert!(unaffected.satisfied_by(&Version::parse("0.0.9").unwrap()));
        assert!(unaffected.satisfied_by(&Version::parse("0.0.1").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.1.0").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.2.0").unwrap()));
    }

    #[test]
    fn vulnerability_check_full() {
        let patched: Vec<Requirement> = vec![
            Requirement::parse("~> 0.1.42").unwrap(),
            Requirement::parse("~> 0.2.42").unwrap(),
            Requirement::parse(">= 1.0.0").unwrap(),
        ];
        let unaffected: Vec<Requirement> = vec![Requirement::parse("< 0.1.0").unwrap()];

        let is_patched = |v: &Version| -> bool { patched.iter().any(|req| req.satisfied_by(v)) };
        let is_unaffected =
            |v: &Version| -> bool { unaffected.iter().any(|req| req.satisfied_by(v)) };
        // Vulnerable = neither patched nor unaffected.
        let is_vulnerable = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            !is_patched(&ver) && !is_unaffected(&ver)
        };

        // Unaffected.
        assert!(!is_vulnerable("0.0.9"));

        // Patched.
        assert!(!is_vulnerable("0.1.42"));
        assert!(!is_vulnerable("1.0.0"));
        assert!(!is_vulnerable("2.0.0"));

        // Vulnerable.
        assert!(is_vulnerable("0.1.0"));
        assert!(is_vulnerable("0.1.41"));
        assert!(is_vulnerable("0.2.0"));
        assert!(is_vulnerable("0.2.41"));
        assert!(is_vulnerable("0.3.0"));
    }

    // ----- Display -----

    #[test]
    fn display_single_constraint() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.to_string(), "~> 1.2.3");
    }

    #[test]
    fn display_compound() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.to_string(), ">= 1.0, < 2.0");
    }
}