1use std::fmt;
2use std::path::{Path, PathBuf};
3use thiserror::Error;
4
5use super::model::Advisory;
6use crate::version::Version;
7
/// Git URL of the upstream ruby-advisory-db repository cloned by `Database::download`.
const ADVISORY_DB_URL: &str = "https://github.com/rubysec/ruby-advisory-db.git";
10
/// Handle to an advisory database stored in a directory on disk.
///
/// Advisories are read from `gems/<name>/*.yml` and `rubies/<engine>/*.yml`
/// below the root directory.
#[derive(Debug)]
pub struct Database {
    /// Root directory of the database.
    path: PathBuf,
}
16
/// Errors produced while opening, downloading or updating the advisory database.
#[derive(Debug, Error)]
pub enum DatabaseError {
    /// The given path is not an existing directory.
    #[error("database not found at {}", .0.display())]
    NotFound(PathBuf),
    /// Cloning the advisory repository failed; carries the underlying error text.
    #[error("download failed: {0}")]
    DownloadFailed(String),
    /// Fetching or checking out new advisory data failed; carries the underlying error text.
    #[error("update failed: {0}")]
    UpdateFailed(String),
    /// The local repository could not be opened; carries the underlying error text.
    #[error("git error: {0}")]
    Git(String),
    /// A filesystem operation failed (converted automatically from `std::io::Error`).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
30
31impl Database {
32 pub fn open(path: &Path) -> Result<Self, DatabaseError> {
34 if !path.is_dir() {
35 return Err(DatabaseError::NotFound(path.to_path_buf()));
36 }
37 Ok(Database {
38 path: path.to_path_buf(),
39 })
40 }
41
42 pub fn default_path() -> PathBuf {
46 if let Ok(custom) = std::env::var("GEM_AUDIT_DB") {
47 return PathBuf::from(custom);
48 }
49 dirs_fallback()
50 }
51
52 pub fn download(path: &Path, _quiet: bool) -> Result<Self, DatabaseError> {
54 if let Some(parent) = path.parent() {
56 std::fs::create_dir_all(parent).map_err(DatabaseError::Io)?;
57 }
58
59 let (mut checkout, _outcome) = gix::prepare_clone(ADVISORY_DB_URL, path)
60 .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?
61 .fetch_then_checkout(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
62 .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
63
64 let (_repo, _outcome) = checkout
65 .main_worktree(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
66 .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
67
68 Ok(Database {
69 path: path.to_path_buf(),
70 })
71 }
72
73 pub fn update(&self) -> Result<bool, DatabaseError> {
78 if !self.is_git() {
79 return Ok(false);
80 }
81
82 if let Err(e) = self.try_fetch() {
83 eprintln!("warning: git fetch failed ({}), re-cloning ...", e);
84 return self.reclone();
85 }
86
87 self.checkout_head()
88 }
89
90 fn try_fetch(&self) -> Result<(), DatabaseError> {
92 let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
93
94 let remote = repo
95 .find_default_remote(gix::remote::Direction::Fetch)
96 .ok_or_else(|| DatabaseError::UpdateFailed("no remote configured".to_string()))?
97 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
98
99 let connection = remote
100 .connect(gix::remote::Direction::Fetch)
101 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
102
103 connection
104 .prepare_fetch(gix::progress::Discard, Default::default())
105 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
106 .receive(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
107 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
108
109 Ok(())
110 }
111
112 fn checkout_head(&self) -> Result<bool, DatabaseError> {
114 let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
115
116 let remote_commit = self.find_remote_head(&repo)?;
118 let head_commit = repo
119 .head_commit()
120 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
121
122 if remote_commit.id != head_commit.id {
123 repo.reference(
125 repo.head_name()
126 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
127 .ok_or_else(|| DatabaseError::UpdateFailed("detached HEAD".to_string()))?
128 .as_ref(),
129 remote_commit.id,
130 gix::refs::transaction::PreviousValue::MustExist,
131 "gem-audit update",
132 )
133 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
134 }
135
136 let tree = remote_commit
137 .tree()
138 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
139
140 let mut index = repo
141 .index_from_tree(&tree.id)
142 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
143
144 let opts = gix::worktree::state::checkout::Options {
145 overwrite_existing: true,
146 ..Default::default()
147 };
148
149 gix::worktree::state::checkout(
150 &mut index,
151 repo.workdir()
152 .ok_or_else(|| DatabaseError::UpdateFailed("bare repository".to_string()))?,
153 repo.objects
154 .clone()
155 .into_arc()
156 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?,
157 &gix::progress::Discard,
158 &gix::progress::Discard,
159 &gix::interrupt::IS_INTERRUPTED,
160 opts,
161 )
162 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
163
164 Ok(true)
165 }
166
167 fn find_remote_head<'a>(
172 &self,
173 repo: &'a gix::Repository,
174 ) -> Result<gix::Commit<'a>, DatabaseError> {
175 let candidates = ["refs/remotes/origin/main", "refs/remotes/origin/master"];
177
178 for refname in &candidates {
179 if let Ok(reference) = repo.find_reference(*refname) {
180 let commit = reference
181 .into_fully_peeled_id()
182 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
183 .object()
184 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
185 .try_into_commit()
186 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
187 return Ok(commit);
188 }
189 }
190
191 Err(DatabaseError::UpdateFailed(
192 "no remote tracking branch found (tried origin/main, origin/master)".to_string(),
193 ))
194 }
195
196 fn reclone(&self) -> Result<bool, DatabaseError> {
202 let tmp = {
205 let mut p = self.path.clone().into_os_string();
206 p.push("_tmp");
207 PathBuf::from(p)
208 };
209 let old = {
210 let mut p = self.path.clone().into_os_string();
211 p.push("_old");
212 PathBuf::from(p)
213 };
214
215 let _ = std::fs::remove_dir_all(&tmp);
217 let _ = std::fs::remove_dir_all(&old);
218
219 Database::download(&tmp, true)?;
221 std::fs::rename(&self.path, &old).map_err(DatabaseError::Io)?;
222 std::fs::rename(&tmp, &self.path).map_err(|e| {
223 let _ = std::fs::rename(&old, &self.path);
225 DatabaseError::Io(e)
226 })?;
227
228 let _ = std::fs::remove_dir_all(&old);
230
231 Ok(true)
232 }
233
234 pub fn is_git(&self) -> bool {
236 self.path.join(".git").is_dir()
237 }
238
239 pub fn exists(&self) -> bool {
241 self.path.is_dir() && self.path.join("gems").is_dir()
242 }
243
244 pub fn path(&self) -> &Path {
246 &self.path
247 }
248
249 pub fn commit_id(&self) -> Option<String> {
251 if !self.is_git() {
252 return None;
253 }
254 let repo = gix::open(&self.path).ok()?;
255 let id = repo.head_id().ok()?;
256 Some(id.to_string())
257 }
258
259 pub fn last_updated_at(&self) -> Option<i64> {
261 if !self.is_git() {
262 return None;
263 }
264 let repo = gix::open(&self.path).ok()?;
265 let commit = repo.head_commit().ok()?;
266 let time = commit.time().ok()?;
267 Some(time.seconds)
268 }
269
270 pub fn advisories(&self) -> Vec<Advisory> {
272 let mut results = Vec::new();
273 let gems_dir = self.path.join("gems");
274
275 if !gems_dir.is_dir() {
276 return results;
277 }
278
279 if let Ok(entries) = std::fs::read_dir(&gems_dir) {
280 for entry in entries.flatten() {
281 if entry.path().is_dir() {
282 let _ = self.load_advisories_from_dir(&entry.path(), &mut results);
283 }
284 }
285 }
286
287 results
288 }
289
290 pub fn advisories_for(&self, gem_name: &str) -> Vec<Advisory> {
292 self.advisories_for_with_errors(gem_name).0
293 }
294
295 fn advisories_for_with_errors(&self, gem_name: &str) -> (Vec<Advisory>, usize) {
297 let mut results = Vec::new();
298 let gem_dir = self.path.join("gems").join(gem_name);
299
300 if !is_contained_in(&gem_dir, &self.path) {
301 return (results, 0);
302 }
303
304 let errors = if gem_dir.is_dir() {
305 self.load_advisories_from_dir(&gem_dir, &mut results)
306 } else {
307 0
308 };
309
310 (results, errors)
311 }
312
313 pub fn check_gem(&self, gem_name: &str, version: &Version) -> (Vec<Advisory>, usize) {
318 let (advisories, errors) = self.advisories_for_with_errors(gem_name);
319 let vulnerable = advisories
320 .into_iter()
321 .filter(|advisory| advisory.vulnerable(version))
322 .collect();
323 (vulnerable, errors)
324 }
325
326 pub fn advisories_for_ruby(&self, engine: &str) -> Vec<Advisory> {
328 self.advisories_for_ruby_with_errors(engine).0
329 }
330
331 fn advisories_for_ruby_with_errors(&self, engine: &str) -> (Vec<Advisory>, usize) {
333 let mut results = Vec::new();
334 let engine_dir = self.path.join("rubies").join(engine);
335
336 if !is_contained_in(&engine_dir, &self.path) {
337 return (results, 0);
338 }
339
340 let errors = if engine_dir.is_dir() {
341 self.load_advisories_from_dir(&engine_dir, &mut results)
342 } else {
343 0
344 };
345
346 (results, errors)
347 }
348
349 pub fn check_ruby(&self, engine: &str, version: &Version) -> (Vec<Advisory>, usize) {
354 let (advisories, errors) = self.advisories_for_ruby_with_errors(engine);
355 let vulnerable = advisories
356 .into_iter()
357 .filter(|advisory| advisory.vulnerable(version))
358 .collect();
359 (vulnerable, errors)
360 }
361
362 pub fn size(&self) -> usize {
364 self.count_advisories_in("gems")
365 }
366
367 pub fn rubies_size(&self) -> usize {
369 self.count_advisories_in("rubies")
370 }
371
372 fn count_advisories_in(&self, subdir: &str) -> usize {
374 let dir = self.path.join(subdir);
375 if !dir.is_dir() {
376 return 0;
377 }
378
379 let mut count = 0;
380 if let Ok(entries) = std::fs::read_dir(&dir) {
381 for entry in entries.flatten() {
382 if entry.path().is_dir()
383 && let Ok(advisory_files) = std::fs::read_dir(entry.path())
384 {
385 count += advisory_files
386 .flatten()
387 .filter(|f| f.path().extension().is_some_and(|ext| ext == "yml"))
388 .count();
389 }
390 }
391 }
392
393 count
394 }
395
396 fn load_advisories_from_dir(&self, dir: &Path, results: &mut Vec<Advisory>) -> usize {
400 let mut errors = 0;
401 if let Ok(entries) = std::fs::read_dir(dir) {
402 for entry in entries.flatten() {
403 let path = entry.path();
404 if path.extension().is_some_and(|ext| ext == "yml") {
405 match Advisory::load(&path) {
406 Ok(advisory) => results.push(advisory),
407 Err(e) => {
408 eprintln!("warning: failed to load advisory {}: {}", path.display(), e);
409 errors += 1;
410 }
411 }
412 }
413 }
414 }
415 errors
416 }
417}
418
419impl fmt::Display for Database {
420 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421 write!(f, "{}", self.path.display())
422 }
423}
424
/// Checks lexically that `child` stays inside `parent`.
///
/// `child` must start with `parent`; the remainder is then walked component
/// by component and rejected as soon as a `..` would climb above `parent`.
/// No filesystem access is performed, so symlinks are not resolved.
///
/// Returns `false` when `child` is not below `parent` at all. That happens,
/// for example, when an absolute name was joined onto the database root and
/// replaced it entirely.
fn is_contained_in(child: &Path, parent: &Path) -> bool {
    use std::path::Component;

    // A child that does not share the parent's prefix is outside by definition.
    let Ok(relative) = child.strip_prefix(parent) else {
        return false;
    };

    let mut depth: usize = 0;
    for component in relative.components() {
        match component {
            Component::ParentDir => {
                // Climbing while already at the parent's level escapes it.
                if depth == 0 {
                    return false;
                }
                depth -= 1;
            }
            Component::Normal(_) => depth += 1,
            Component::CurDir => {}
            // A root or drive prefix can survive stripping only when `parent`
            // is empty; such a path is anchored elsewhere, not inside `parent`.
            Component::RootDir | Component::Prefix(_) => return false,
        }
    }
    true
}
445
/// Returns the default database location when no override is configured.
///
/// With a non-empty `HOME` this is `$HOME/.local/share/ruby-advisory-db`;
/// otherwise it is the relative path `.ruby-advisory-db`.
fn dirs_fallback() -> PathBuf {
    // `var_os` keeps home directories that are not valid UTF-8. An empty
    // `HOME` is treated like a missing one so the result is never an
    // accidental cwd-relative `.local/share/...` path.
    match std::env::var_os("HOME") {
        Some(home) if !home.is_empty() => PathBuf::from(home)
            .join(".local")
            .join("share")
            .join("ruby-advisory-db"),
        _ => PathBuf::from(".ruby-advisory-db"),
    }
}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Root of the checked-in test fixtures.
    fn fixture_root() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
    }

    /// Opens the checked-in `mock_db` fixture database.
    fn mock_db() -> Database {
        Database::open(&fixture_root().join("mock_db")).unwrap()
    }

    /// Builds a throw-away database holding the single fixture advisory
    /// `CVE-2020-1234` for a gem named `test`.
    fn temp_mock_db() -> tempfile::TempDir {
        let tmp = tempfile::tempdir().unwrap();
        let gem_dir = tmp.path().join("gems").join("test");
        std::fs::create_dir_all(&gem_dir).unwrap();
        std::fs::copy(
            fixture_root().join("advisory/CVE-2020-1234.yml"),
            gem_dir.join("CVE-2020-1234.yml"),
        )
        .unwrap();
        tmp
    }

    /// Opens the database at the default location if one is installed, so
    /// the tests depending on it are silently skipped otherwise.
    fn local_db() -> Option<Database> {
        let path = Database::default_path();
        let installed = path.is_dir() && path.join("gems").is_dir();
        if !installed {
            return None;
        }
        Database::open(&path).ok()
    }

    // --- tests against a locally installed database (skipped when absent) ---

    #[test]
    fn open_local_database() {
        let Some(db) = local_db() else {
            return;
        };
        assert!(db.exists());
        assert!(db.is_git());
    }

    #[test]
    fn database_size() {
        let Some(db) = local_db() else {
            return;
        };
        let size = db.size();
        assert!(size > 100, "expected > 100 advisories, got {}", size);
    }

    #[test]
    fn database_commit_id() {
        let Some(db) = local_db() else {
            return;
        };
        let commit = db.commit_id();
        assert!(commit.is_some());
        // A hex-encoded SHA-1 is 40 characters long.
        assert_eq!(commit.unwrap().len(), 40);
    }

    #[test]
    fn database_last_updated() {
        let Some(db) = local_db() else {
            return;
        };
        let ts = db.last_updated_at();
        assert!(ts.is_some());
        assert!(ts.unwrap() > 0);
    }

    #[test]
    fn advisories_for_actionpack() {
        let Some(db) = local_db() else {
            return;
        };
        let advisories = db.advisories_for("actionpack");
        assert!(!advisories.is_empty(), "expected advisories for actionpack");
    }

    #[test]
    fn check_vulnerable_gem() {
        let Some(db) = local_db() else {
            return;
        };
        let version = Version::parse("3.2.10").unwrap();
        let (vulnerabilities, _errors) = db.check_gem("activerecord", &version);
        assert!(
            !vulnerabilities.is_empty(),
            "expected activerecord 3.2.10 to have vulnerabilities"
        );
    }

    #[test]
    fn check_nonexistent_gem() {
        let Some(db) = local_db() else {
            return;
        };
        let version = Version::parse("1.0.0").unwrap();
        let (vulnerabilities, _errors) = db.check_gem("nonexistent-gem-xyz", &version);
        assert!(vulnerabilities.is_empty());
    }

    // --- tests against the temporary single-advisory database ---

    #[test]
    fn open_fixture_advisory_dir() {
        let tmp = temp_mock_db();

        let db = Database::open(tmp.path()).unwrap();
        assert!(!db.is_git());

        let advisories = db.advisories_for("test");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2020-1234");

        let affected = Version::parse("0.1.0").unwrap();
        let (vulns, _errors) = db.check_gem("test", &affected);
        assert_eq!(vulns.len(), 1);

        let patched = Version::parse("1.0.0").unwrap();
        let (vulns, _errors) = db.check_gem("test", &patched);
        assert!(vulns.is_empty());
    }

    #[test]
    fn open_nonexistent_path() {
        let result = Database::open(Path::new("/nonexistent/path"));
        assert!(result.is_err());
    }

    #[test]
    fn default_path_is_sensible() {
        let path = Database::default_path();
        let path_str = path.to_string_lossy();
        assert!(
            path_str.contains("ruby-advisory-db"),
            "default path should contain ruby-advisory-db: {}",
            path_str
        );
    }

    #[test]
    fn database_display() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        assert_eq!(db.to_string(), tmp.path().to_string_lossy());
    }

    #[test]
    fn database_exists_with_gems() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        assert!(db.exists());
        assert!(db.path() == tmp.path());
    }

    #[test]
    fn database_advisories_with_mock() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        let all = db.advisories();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, "CVE-2020-1234");
    }

    #[test]
    fn database_size_with_mock() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        assert_eq!(db.size(), 1);
    }

    // --- tests against the checked-in `mock_db` fixture ---

    #[test]
    fn rubies_size_with_mock() {
        assert_eq!(mock_db().rubies_size(), 1);
    }

    #[test]
    fn advisories_for_ruby_with_mock() {
        let advisories = mock_db().advisories_for_ruby("ruby");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2021-31810");
    }

    #[test]
    fn check_ruby_vulnerable_version() {
        let version = Version::parse("2.6.0").unwrap();
        let (vulns, _) = mock_db().check_ruby("ruby", &version);
        assert_eq!(vulns.len(), 1);
    }

    #[test]
    fn check_ruby_patched_version() {
        let version = Version::parse("3.0.2").unwrap();
        let (vulns, _) = mock_db().check_ruby("ruby", &version);
        assert!(vulns.is_empty());
    }

    #[test]
    fn check_ruby_nonexistent_engine() {
        let version = Version::parse("1.0.0").unwrap();
        let (vulns, _) = mock_db().check_ruby("nonexistent", &version);
        assert!(vulns.is_empty());
    }

    #[test]
    fn commit_id_none_for_non_git() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        assert_eq!(db.commit_id(), None);
        assert_eq!(db.last_updated_at(), None);
    }

    // --- error message formatting ---

    #[test]
    fn database_error_not_found_display() {
        let message = DatabaseError::NotFound(PathBuf::from("/tmp/missing")).to_string();
        assert!(message.contains("database not found"));
        assert!(message.contains("/tmp/missing"));
    }

    #[test]
    fn database_error_download_failed_display() {
        let message = DatabaseError::DownloadFailed("network error".to_string()).to_string();
        assert!(message.contains("download failed"));
        assert!(message.contains("network error"));
    }

    #[test]
    fn database_error_update_failed_display() {
        let message = DatabaseError::UpdateFailed("merge conflict".to_string()).to_string();
        assert!(message.contains("update failed"));
    }

    #[test]
    fn database_error_git_display() {
        let message = DatabaseError::Git("corrupt repo".to_string()).to_string();
        assert!(message.contains("git error"));
    }

    // --- path containment ---

    #[test]
    fn is_contained_in_normal_path() {
        let parent = Path::new("/db");
        let child = parent.join("gems").join("rails");
        assert!(is_contained_in(&child, parent));
    }

    #[test]
    fn is_contained_in_rejects_traversal() {
        let parent = Path::new("/db");
        let escaping = parent.join("gems").join("..").join("..").join("etc");
        assert!(!is_contained_in(&escaping, parent));
    }

    #[test]
    fn advisories_for_traversal_gem_returns_empty() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        let (advisories, errors) = db.advisories_for_with_errors("../../etc");
        assert!(advisories.is_empty());
        assert_eq!(errors, 0);
    }

    #[test]
    fn advisories_for_ruby_traversal_returns_empty() {
        let tmp = temp_mock_db();
        let db = Database::open(tmp.path()).unwrap();
        let (advisories, errors) = db.advisories_for_ruby_with_errors("../../etc");
        assert!(advisories.is_empty());
        assert_eq!(errors, 0);
    }
}