1use std::fmt;
2use thiserror::Error;
3
4use super::gem_version::Version;
5
/// A comparison operator used in a [`VersionConstraint`].
///
/// Each variant's textual symbol is produced by its `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
    /// `=`: the candidate must equal the reference version.
    Equal,
    /// `!=`: the candidate must differ from the reference version.
    NotEqual,
    /// `>`: the candidate must be strictly greater than the reference version.
    GreaterThan,
    /// `<`: the candidate must be strictly less than the reference version.
    LessThan,
    /// `>=`: the candidate must be greater than or equal to the reference version.
    GreaterThanOrEqual,
    /// `<=`: the candidate must be less than or equal to the reference version.
    LessThanOrEqual,
    /// `~>` ("pessimistic"): at least the reference version, but strictly below
    /// the reference version's `bump()` (e.g. `~> 2.3` admits `2.3` up to, but
    /// excluding, `3.0`).
    Pessimistic,
}
24
25impl fmt::Display for Operator {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 match self {
28 Operator::Equal => write!(f, "="),
29 Operator::NotEqual => write!(f, "!="),
30 Operator::GreaterThan => write!(f, ">"),
31 Operator::LessThan => write!(f, "<"),
32 Operator::GreaterThanOrEqual => write!(f, ">="),
33 Operator::LessThanOrEqual => write!(f, "<="),
34 Operator::Pessimistic => write!(f, "~>"),
35 }
36 }
37}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct VersionConstraint {
42 pub operator: Operator,
43 pub version: Version,
44}
45
46impl VersionConstraint {
47 pub fn satisfied_by(&self, version: &Version) -> bool {
49 match &self.operator {
50 Operator::Equal => version == &self.version,
51 Operator::NotEqual => version != &self.version,
52 Operator::GreaterThan => version > &self.version,
53 Operator::LessThan => version < &self.version,
54 Operator::GreaterThanOrEqual => version >= &self.version,
55 Operator::LessThanOrEqual => version <= &self.version,
56 Operator::Pessimistic => {
57 let upper = self.version.bump();
60 version >= &self.version && version < &upper
61 }
62 }
63 }
64}
65
66impl VersionConstraint {
67 pub fn minimum_version(&self) -> Option<Version> {
75 match &self.operator {
76 Operator::GreaterThanOrEqual | Operator::Pessimistic | Operator::Equal => {
77 Some(self.version.clone())
78 }
79 Operator::GreaterThan => Some(self.version.increment_last()),
80 Operator::LessThan | Operator::LessThanOrEqual | Operator::NotEqual => None,
81 }
82 }
83}
84
85impl fmt::Display for VersionConstraint {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 write!(f, "{} {}", self.operator, self.version)
88 }
89}
90
91#[derive(Debug, Clone, PartialEq, Eq)]
103pub struct Requirement {
104 pub constraints: Vec<VersionConstraint>,
105}
106
107#[derive(Debug, Clone, PartialEq, Eq, Error)]
108pub enum RequirementError {
109 #[error("invalid operator: '{0}'")]
110 InvalidOperator(String),
111 #[error("invalid version: '{0}'")]
112 InvalidVersion(String),
113 #[error("empty requirement string")]
114 Empty,
115}
116
117impl Requirement {
118 pub fn parse(input: &str) -> Result<Self, RequirementError> {
125 let input = input.trim();
126 if input.is_empty() {
127 return Ok(Requirement::default());
128 }
129
130 let parts: Vec<&str> = input.split(',').map(|s| s.trim()).collect();
131 let mut constraints = Vec::with_capacity(parts.len());
132
133 for part in parts {
134 let constraint = parse_single_constraint(part)?;
135 constraints.push(constraint);
136 }
137
138 if constraints.is_empty() {
139 return Err(RequirementError::Empty);
140 }
141
142 Ok(Requirement { constraints })
143 }
144
145 pub fn parse_multiple(inputs: &[&str]) -> Result<Self, RequirementError> {
149 let mut constraints = Vec::new();
150
151 for input in inputs {
152 let req = Requirement::parse(input)?;
153 constraints.extend(req.constraints);
154 }
155
156 if constraints.is_empty() {
157 return Ok(Requirement::default());
158 }
159
160 Ok(Requirement { constraints })
161 }
162
163 pub fn satisfied_by(&self, version: &Version) -> bool {
165 self.constraints.iter().all(|c| c.satisfied_by(version))
166 }
167
168 pub fn minimum_version(&self) -> Option<Version> {
173 let mut candidate: Option<Version> = None;
175 for c in &self.constraints {
176 if let Some(v) = c.minimum_version() {
177 candidate = Some(match candidate {
178 Some(cur) if v > cur => v,
179 Some(cur) => cur,
180 None => v,
181 });
182 }
183 }
184
185 let mut candidate = candidate?;
187
188 for _ in 0..100 {
190 if self.satisfied_by(&candidate) {
191 return Some(candidate);
192 }
193 candidate = candidate.increment_last();
195 }
196
197 None
198 }
199}
200
201impl Default for Requirement {
202 fn default() -> Self {
204 Requirement {
205 constraints: vec![VersionConstraint {
206 operator: Operator::GreaterThanOrEqual,
207 version: Version::parse("0").unwrap(),
208 }],
209 }
210 }
211}
212
213impl fmt::Display for Requirement {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 let parts: Vec<String> = self.constraints.iter().map(|c| c.to_string()).collect();
216 write!(f, "{}", parts.join(", "))
217 }
218}
219
220fn parse_single_constraint(input: &str) -> Result<VersionConstraint, RequirementError> {
222 let input = input.trim();
223
224 if input.is_empty() {
225 return Err(RequirementError::Empty);
226 }
227
228 let (operator, version_str) = if let Some(rest) = input.strip_prefix("~>") {
230 (Operator::Pessimistic, rest.trim())
231 } else if let Some(rest) = input.strip_prefix(">=") {
232 (Operator::GreaterThanOrEqual, rest.trim())
233 } else if let Some(rest) = input.strip_prefix("<=") {
234 (Operator::LessThanOrEqual, rest.trim())
235 } else if let Some(rest) = input.strip_prefix("!=") {
236 (Operator::NotEqual, rest.trim())
237 } else if let Some(rest) = input.strip_prefix('>') {
238 (Operator::GreaterThan, rest.trim())
239 } else if let Some(rest) = input.strip_prefix('<') {
240 (Operator::LessThan, rest.trim())
241 } else if let Some(rest) = input.strip_prefix('=') {
242 (Operator::Equal, rest.trim())
243 } else {
244 (Operator::Equal, input)
246 };
247
248 let version = Version::parse(version_str)
249 .map_err(|_| RequirementError::InvalidVersion(version_str.to_string()))?;
250
251 Ok(VersionConstraint { operator, version })
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Parsing ----

    #[test]
    fn parse_simple_equality() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert_eq!(req.constraints.len(), 1);
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_pessimistic() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Pessimistic);
        assert_eq!(req.constraints[0].version, Version::parse("1.2.3").unwrap());
    }

    #[test]
    fn parse_greater_than_or_equal() {
        let req = Requirement::parse(">= 2.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
    }

    #[test]
    fn parse_less_than() {
        let req = Requirement::parse("< 3.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::LessThan);
    }

    #[test]
    fn parse_not_equal() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::NotEqual);
    }

    #[test]
    fn parse_compound_requirement() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.constraints.len(), 2);
        assert_eq!(req.constraints[0].operator, Operator::GreaterThanOrEqual);
        assert_eq!(req.constraints[1].operator, Operator::LessThan);
    }

    #[test]
    fn parse_no_operator_defaults_to_equal() {
        let req = Requirement::parse("1.0.0").unwrap();
        assert_eq!(req.constraints[0].operator, Operator::Equal);
        assert_eq!(req.constraints[0].version, Version::parse("1.0.0").unwrap());
    }

    #[test]
    fn parse_multiple_strings() {
        let req = Requirement::parse_multiple(&[">= 1.0", "< 2.0", "!= 1.5"]).unwrap();
        assert_eq!(req.constraints.len(), 3);
    }

    #[test]
    fn parse_blank_input_yields_default() {
        assert_eq!(Requirement::parse("").unwrap(), Requirement::default());
        assert_eq!(Requirement::parse("   ").unwrap(), Requirement::default());
    }

    #[test]
    fn parse_rejects_empty_segment() {
        assert_eq!(Requirement::parse(">= 1.0,"), Err(RequirementError::Empty));
        assert_eq!(
            Requirement::parse(">= 1.0, , < 2.0"),
            Err(RequirementError::Empty)
        );
    }

    #[test]
    fn parse_multiple_no_inputs_yields_default() {
        assert_eq!(
            Requirement::parse_multiple(&[]).unwrap(),
            Requirement::default()
        );
    }

    #[test]
    fn parse_multiple_propagates_error() {
        assert_eq!(
            Requirement::parse_multiple(&[">= 1.0", "< 2.0,"]),
            Err(RequirementError::Empty)
        );
    }

    // ---- Matching ----

    #[test]
    fn default_requirement_matches_any() {
        let req = Requirement::default();
        assert!(req.satisfied_by(&Version::parse("0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("999.999.999").unwrap()));
    }

    #[test]
    fn equal_matches_exact() {
        let req = Requirement::parse("= 1.0.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    #[test]
    fn not_equal_excludes_version() {
        let req = Requirement::parse("!= 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5.0").unwrap()));
    }

    #[test]
    fn greater_than() {
        let req = Requirement::parse("> 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0.1").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
    }

    #[test]
    fn less_than() {
        let req = Requirement::parse("< 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    #[test]
    fn greater_than_or_equal() {
        let req = Requirement::parse(">= 1.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9.9").unwrap()));
    }

    #[test]
    fn less_than_or_equal() {
        let req = Requirement::parse("<= 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0.1").unwrap()));
    }

    #[test]
    fn pessimistic_two_segments() {
        let req = Requirement::parse("~> 2.3").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments() {
        let req = Requirement::parse("~> 2.3.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.2.9").unwrap()));
    }

    #[test]
    fn pessimistic_three_segments_nonzero() {
        let req = Requirement::parse("~> 2.3.18").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.3.18").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.3.20").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.3.17").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.4.0").unwrap()));
    }

    #[test]
    fn pessimistic_single_segment() {
        let req = Requirement::parse("~> 2").unwrap();
        assert!(req.satisfied_by(&Version::parse("2.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("2.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("3.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.9").unwrap()));
    }

    #[test]
    fn pessimistic_four_segments() {
        let req = Requirement::parse("~> 1.2.3.4").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.2.3.4").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.2.3.99").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.4.0").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.2.3.3").unwrap()));
    }

    #[test]
    fn compound_range() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.9.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("0.9").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    #[test]
    fn compound_with_exclusion() {
        let req = Requirement::parse(">= 1.0, < 2.0, != 1.5").unwrap();
        assert!(req.satisfied_by(&Version::parse("1.0").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.4.9").unwrap()));
        assert!(req.satisfied_by(&Version::parse("1.5.1").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("1.5").unwrap()));
        assert!(!req.satisfied_by(&Version::parse("2.0").unwrap()));
    }

    // ---- Advisory-style usage ----

    #[test]
    fn advisory_patched_versions_pattern() {
        let patch1 = Requirement::parse("~> 0.1.42").unwrap();
        let patch2 = Requirement::parse("~> 0.2.42").unwrap();
        let patch3 = Requirement::parse(">= 1.0.0").unwrap();

        let is_patched = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            patch1.satisfied_by(&ver) || patch2.satisfied_by(&ver) || patch3.satisfied_by(&ver)
        };

        assert!(is_patched("0.1.42"));
        assert!(is_patched("0.1.50"));
        assert!(is_patched("0.2.42"));
        assert!(is_patched("0.2.99"));
        assert!(is_patched("1.0.0"));
        assert!(is_patched("2.0.0"));

        assert!(!is_patched("0.1.0"));
        assert!(!is_patched("0.1.41"));
        assert!(!is_patched("0.2.0"));
        assert!(!is_patched("0.2.41"));
        assert!(!is_patched("0.3.0"));
        assert!(!is_patched("0.9.0"));
    }

    #[test]
    fn advisory_unaffected_versions_pattern() {
        let unaffected = Requirement::parse("< 0.1.0").unwrap();

        assert!(unaffected.satisfied_by(&Version::parse("0.0.9").unwrap()));
        assert!(unaffected.satisfied_by(&Version::parse("0.0.1").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.1.0").unwrap()));
        assert!(!unaffected.satisfied_by(&Version::parse("0.2.0").unwrap()));
    }

    #[test]
    fn vulnerability_check_full() {
        let patched: Vec<Requirement> = vec![
            Requirement::parse("~> 0.1.42").unwrap(),
            Requirement::parse("~> 0.2.42").unwrap(),
            Requirement::parse(">= 1.0.0").unwrap(),
        ];
        let unaffected: Vec<Requirement> = vec![Requirement::parse("< 0.1.0").unwrap()];

        let is_patched = |v: &Version| -> bool { patched.iter().any(|req| req.satisfied_by(v)) };
        let is_unaffected =
            |v: &Version| -> bool { unaffected.iter().any(|req| req.satisfied_by(v)) };
        let is_vulnerable = |v: &str| -> bool {
            let ver = Version::parse(v).unwrap();
            !is_patched(&ver) && !is_unaffected(&ver)
        };

        assert!(!is_vulnerable("0.0.9"));

        assert!(!is_vulnerable("0.1.42"));
        assert!(!is_vulnerable("1.0.0"));
        assert!(!is_vulnerable("2.0.0"));

        assert!(is_vulnerable("0.1.0"));
        assert!(is_vulnerable("0.1.41"));
        assert!(is_vulnerable("0.2.0"));
        assert!(is_vulnerable("0.2.41"));
        assert!(is_vulnerable("0.3.0"));
    }

    // ---- Minimum version ----

    #[test]
    fn minimum_version_of_range_is_lower_bound() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.minimum_version(), Some(Version::parse("1.0").unwrap()));
    }

    #[test]
    fn minimum_version_uses_greatest_lower_bound() {
        let req = Requirement::parse(">= 1.0, >= 1.5, < 2.0").unwrap();
        assert_eq!(req.minimum_version(), Some(Version::parse("1.5").unwrap()));
    }

    #[test]
    fn minimum_version_without_lower_bound_is_none() {
        let req = Requirement::parse("< 2.0").unwrap();
        assert_eq!(req.minimum_version(), None);
    }

    #[test]
    fn minimum_version_of_unsatisfiable_requirement_is_none() {
        let req = Requirement::parse(">= 2.0, < 1.0").unwrap();
        assert_eq!(req.minimum_version(), None);
    }

    // ---- Display ----

    #[test]
    fn display_single_constraint() {
        let req = Requirement::parse("~> 1.2.3").unwrap();
        assert_eq!(req.to_string(), "~> 1.2.3");
    }

    #[test]
    fn display_compound() {
        let req = Requirement::parse(">= 1.0, < 2.0").unwrap();
        assert_eq!(req.to_string(), ">= 1.0, < 2.0");
    }
}