1use std::fmt;
2use std::path::{Path, PathBuf};
3use thiserror::Error;
4
5use super::model::Advisory;
6use crate::version::Version;
7
/// Git remote that [`Database::download`] clones the advisory data from.
const ADVISORY_DB_URL: &'static str = "https://github.com/rubysec/ruby-advisory-db.git";
10
/// Handle to an advisory database stored in a directory on disk.
///
/// The handle only records where the database lives; every query re-reads
/// the directory contents.
#[derive(Debug)]
pub struct Database {
    /// Root directory of the database (the one holding `gems/` and `rubies/`).
    path: PathBuf,
}
16
17#[derive(Debug, Error)]
18pub enum DatabaseError {
19 #[error("database not found at {}", .0.display())]
20 NotFound(PathBuf),
21 #[error("download failed: {0}")]
22 DownloadFailed(String),
23 #[error("update failed: {0}")]
24 UpdateFailed(String),
25 #[error("git error: {0}")]
26 Git(String),
27 #[error("IO error: {0}")]
28 Io(#[from] std::io::Error),
29}
30
31impl Database {
32 pub fn open(path: &Path) -> Result<Self, DatabaseError> {
34 if !path.is_dir() {
35 return Err(DatabaseError::NotFound(path.to_path_buf()));
36 }
37 Ok(Database {
38 path: path.to_path_buf(),
39 })
40 }
41
42 pub fn default_path() -> PathBuf {
46 if let Ok(custom) = std::env::var("GEM_AUDIT_DB") {
47 return PathBuf::from(custom);
48 }
49 dirs_fallback()
50 }
51
52 pub fn download(path: &Path, _quiet: bool) -> Result<Self, DatabaseError> {
54 if let Some(parent) = path.parent() {
56 std::fs::create_dir_all(parent).map_err(DatabaseError::Io)?;
57 }
58
59 let (mut checkout, _outcome) = gix::prepare_clone(ADVISORY_DB_URL, path)
60 .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?
61 .fetch_then_checkout(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
62 .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
63
64 let (_repo, _outcome) = checkout
65 .main_worktree(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
66 .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
67
68 Ok(Database {
69 path: path.to_path_buf(),
70 })
71 }
72
73 pub fn update(&self) -> Result<bool, DatabaseError> {
78 if !self.is_git() {
79 return Ok(false);
80 }
81
82 if let Err(e) = self.try_fetch() {
83 eprintln!("warning: git fetch failed ({}), re-cloning ...", e);
84 return self.reclone();
85 }
86
87 self.checkout_head()
88 }
89
90 fn try_fetch(&self) -> Result<(), DatabaseError> {
92 let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
93
94 let remote = repo
95 .find_default_remote(gix::remote::Direction::Fetch)
96 .ok_or_else(|| DatabaseError::UpdateFailed("no remote configured".to_string()))?
97 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
98
99 let connection = remote
100 .connect(gix::remote::Direction::Fetch)
101 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
102
103 connection
104 .prepare_fetch(gix::progress::Discard, Default::default())
105 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
106 .receive(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
107 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
108
109 Ok(())
110 }
111
112 fn checkout_head(&self) -> Result<bool, DatabaseError> {
114 let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
115 let tree = repo
116 .head_commit()
117 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
118 .tree()
119 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
120
121 let mut index = repo
122 .index_from_tree(&tree.id)
123 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
124
125 let opts = gix::worktree::state::checkout::Options {
126 overwrite_existing: true,
127 ..Default::default()
128 };
129
130 gix::worktree::state::checkout(
131 &mut index,
132 repo.workdir()
133 .ok_or_else(|| DatabaseError::UpdateFailed("bare repository".to_string()))?,
134 repo.objects
135 .clone()
136 .into_arc()
137 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?,
138 &gix::progress::Discard,
139 &gix::progress::Discard,
140 &gix::interrupt::IS_INTERRUPTED,
141 opts,
142 )
143 .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
144
145 Ok(true)
146 }
147
148 fn reclone(&self) -> Result<bool, DatabaseError> {
154 let tmp = {
157 let mut p = self.path.clone().into_os_string();
158 p.push("_tmp");
159 PathBuf::from(p)
160 };
161 let old = {
162 let mut p = self.path.clone().into_os_string();
163 p.push("_old");
164 PathBuf::from(p)
165 };
166
167 let _ = std::fs::remove_dir_all(&tmp);
169 let _ = std::fs::remove_dir_all(&old);
170
171 Database::download(&tmp, true)?;
173 std::fs::rename(&self.path, &old).map_err(DatabaseError::Io)?;
174 std::fs::rename(&tmp, &self.path).map_err(|e| {
175 let _ = std::fs::rename(&old, &self.path);
177 DatabaseError::Io(e)
178 })?;
179
180 let _ = std::fs::remove_dir_all(&old);
182
183 Ok(true)
184 }
185
186 pub fn is_git(&self) -> bool {
188 self.path.join(".git").is_dir()
189 }
190
191 pub fn exists(&self) -> bool {
193 self.path.is_dir() && self.path.join("gems").is_dir()
194 }
195
196 pub fn path(&self) -> &Path {
198 &self.path
199 }
200
201 pub fn commit_id(&self) -> Option<String> {
203 if !self.is_git() {
204 return None;
205 }
206 let repo = gix::open(&self.path).ok()?;
207 let id = repo.head_id().ok()?;
208 Some(id.to_string())
209 }
210
211 pub fn last_updated_at(&self) -> Option<i64> {
213 if !self.is_git() {
214 return None;
215 }
216 let repo = gix::open(&self.path).ok()?;
217 let commit = repo.head_commit().ok()?;
218 let time = commit.time().ok()?;
219 Some(time.seconds)
220 }
221
222 pub fn advisories(&self) -> Vec<Advisory> {
224 let mut results = Vec::new();
225 let gems_dir = self.path.join("gems");
226
227 if !gems_dir.is_dir() {
228 return results;
229 }
230
231 if let Ok(entries) = std::fs::read_dir(&gems_dir) {
232 for entry in entries.flatten() {
233 if entry.path().is_dir() {
234 let _ = self.load_advisories_from_dir(&entry.path(), &mut results);
235 }
236 }
237 }
238
239 results
240 }
241
242 pub fn advisories_for(&self, gem_name: &str) -> Vec<Advisory> {
244 self.advisories_for_with_errors(gem_name).0
245 }
246
247 fn advisories_for_with_errors(&self, gem_name: &str) -> (Vec<Advisory>, usize) {
249 let mut results = Vec::new();
250 let gem_dir = self.path.join("gems").join(gem_name);
251
252 let errors = if gem_dir.is_dir() {
253 self.load_advisories_from_dir(&gem_dir, &mut results)
254 } else {
255 0
256 };
257
258 (results, errors)
259 }
260
261 pub fn check_gem(&self, gem_name: &str, version: &Version) -> (Vec<Advisory>, usize) {
266 let (advisories, errors) = self.advisories_for_with_errors(gem_name);
267 let vulnerable = advisories
268 .into_iter()
269 .filter(|advisory| advisory.vulnerable(version))
270 .collect();
271 (vulnerable, errors)
272 }
273
274 pub fn advisories_for_ruby(&self, engine: &str) -> Vec<Advisory> {
276 self.advisories_for_ruby_with_errors(engine).0
277 }
278
279 fn advisories_for_ruby_with_errors(&self, engine: &str) -> (Vec<Advisory>, usize) {
281 let mut results = Vec::new();
282 let engine_dir = self.path.join("rubies").join(engine);
283
284 let errors = if engine_dir.is_dir() {
285 self.load_advisories_from_dir(&engine_dir, &mut results)
286 } else {
287 0
288 };
289
290 (results, errors)
291 }
292
293 pub fn check_ruby(&self, engine: &str, version: &Version) -> (Vec<Advisory>, usize) {
298 let (advisories, errors) = self.advisories_for_ruby_with_errors(engine);
299 let vulnerable = advisories
300 .into_iter()
301 .filter(|advisory| advisory.vulnerable(version))
302 .collect();
303 (vulnerable, errors)
304 }
305
306 pub fn size(&self) -> usize {
308 self.count_advisories_in("gems")
309 }
310
311 pub fn rubies_size(&self) -> usize {
313 self.count_advisories_in("rubies")
314 }
315
316 fn count_advisories_in(&self, subdir: &str) -> usize {
318 let dir = self.path.join(subdir);
319 if !dir.is_dir() {
320 return 0;
321 }
322
323 let mut count = 0;
324 if let Ok(entries) = std::fs::read_dir(&dir) {
325 for entry in entries.flatten() {
326 if entry.path().is_dir()
327 && let Ok(advisory_files) = std::fs::read_dir(entry.path())
328 {
329 count += advisory_files
330 .flatten()
331 .filter(|f| f.path().extension().is_some_and(|ext| ext == "yml"))
332 .count();
333 }
334 }
335 }
336
337 count
338 }
339
340 fn load_advisories_from_dir(&self, dir: &Path, results: &mut Vec<Advisory>) -> usize {
344 let mut errors = 0;
345 if let Ok(entries) = std::fs::read_dir(dir) {
346 for entry in entries.flatten() {
347 let path = entry.path();
348 if path.extension().is_some_and(|ext| ext == "yml") {
349 match Advisory::load(&path) {
350 Ok(advisory) => results.push(advisory),
351 Err(e) => {
352 eprintln!("warning: failed to load advisory {}: {}", path.display(), e);
353 errors += 1;
354 }
355 }
356 }
357 }
358 }
359 errors
360 }
361}
362
363impl fmt::Display for Database {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 write!(f, "{}", self.path.display())
366 }
367}
368
/// Default database location when no explicit override is configured.
///
/// Returns `$HOME/.local/share/ruby-advisory-db` when `HOME` is set to a
/// non-empty value, and the relative path `.ruby-advisory-db` otherwise.
/// `HOME` is read with `var_os`, so a home directory that is not valid
/// Unicode is still honoured; an empty `HOME` is treated like an unset one
/// instead of producing a path relative to the current directory.
fn dirs_fallback() -> PathBuf {
    match std::env::var_os("HOME") {
        Some(home) if !home.is_empty() => PathBuf::from(home)
            .join(".local")
            .join("share")
            .join("ruby-advisory-db"),
        _ => PathBuf::from(".ruby-advisory-db"),
    }
}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory holding the checked-in test fixtures.
    fn fixture_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
    }

    /// Opens the read-only `mock_db` fixture shipped with the crate.
    fn mock_db() -> Database {
        Database::open(&fixture_dir().join("mock_db")).unwrap()
    }

    /// Returns the locally installed advisory database, if there is one.
    /// Tests built on it are no-ops on machines without a database.
    fn local_db() -> Option<Database> {
        let path = Database::default_path();
        if path.is_dir() && path.join("gems").is_dir() {
            Database::open(&path).ok()
        } else {
            None
        }
    }

    /// Throw-away database in the system temp directory containing the
    /// single advisory `gems/test/CVE-2020-1234.yml`.
    ///
    /// The directory name includes the process id so concurrent test runs do
    /// not collide, and the directory is removed on drop so a failing
    /// assertion does not leave it behind.
    struct TempDb {
        dir: PathBuf,
    }

    impl TempDb {
        fn new(suffix: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "gem_audit_db_test_{}_{}",
                suffix,
                std::process::id()
            ));
            let _ = std::fs::remove_dir_all(&dir);

            // Build the guard first so a failing setup step still cleans up.
            let db = TempDb { dir };
            let gem_dir = db.dir.join("gems").join("test");
            std::fs::create_dir_all(&gem_dir).unwrap();
            std::fs::copy(
                fixture_dir().join("advisory/CVE-2020-1234.yml"),
                gem_dir.join("CVE-2020-1234.yml"),
            )
            .unwrap();
            db
        }

        fn open(&self) -> Database {
            Database::open(&self.dir).unwrap()
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn open_local_database() {
        if let Some(db) = local_db() {
            assert!(db.exists());
            assert!(db.is_git());
        }
    }

    #[test]
    fn database_size() {
        if let Some(db) = local_db() {
            let size = db.size();
            assert!(size > 100, "expected > 100 advisories, got {}", size);
        }
    }

    #[test]
    fn database_commit_id() {
        if let Some(db) = local_db() {
            let id = db.commit_id().expect("git database should have a HEAD id");
            assert_eq!(id.len(), 40);
        }
    }

    #[test]
    fn database_last_updated() {
        if let Some(db) = local_db() {
            let ts = db
                .last_updated_at()
                .expect("git database should have a HEAD commit time");
            assert!(ts > 0);
        }
    }

    #[test]
    fn advisories_for_actionpack() {
        if let Some(db) = local_db() {
            let advisories = db.advisories_for("actionpack");
            assert!(!advisories.is_empty(), "expected advisories for actionpack");
        }
    }

    #[test]
    fn check_vulnerable_gem() {
        if let Some(db) = local_db() {
            let version = Version::parse("3.2.10").unwrap();
            let (vulnerabilities, _errors) = db.check_gem("activerecord", &version);
            assert!(
                !vulnerabilities.is_empty(),
                "expected activerecord 3.2.10 to have vulnerabilities"
            );
        }
    }

    #[test]
    fn check_nonexistent_gem() {
        if let Some(db) = local_db() {
            let version = Version::parse("1.0.0").unwrap();
            let (vulnerabilities, _errors) = db.check_gem("nonexistent-gem-xyz", &version);
            assert!(vulnerabilities.is_empty());
        }
    }

    #[test]
    fn open_fixture_advisory_dir() {
        let tmp = TempDb::new("fixture");
        let db = tmp.open();
        assert!(!db.is_git());

        let advisories = db.advisories_for("test");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2020-1234");

        let (vulns, _errors) = db.check_gem("test", &Version::parse("0.1.0").unwrap());
        assert_eq!(vulns.len(), 1);

        let (vulns, _errors) = db.check_gem("test", &Version::parse("1.0.0").unwrap());
        assert!(vulns.is_empty());
    }

    #[test]
    fn open_nonexistent_path() {
        let result = Database::open(Path::new("/nonexistent/path"));
        assert!(matches!(result, Err(DatabaseError::NotFound(_))));
    }

    #[test]
    fn default_path_is_sensible() {
        let path = Database::default_path();
        let path_str = path.to_string_lossy();
        assert!(
            path_str.contains("ruby-advisory-db"),
            "default path should contain ruby-advisory-db: {}",
            path_str
        );
    }

    #[test]
    fn database_display() {
        let tmp = TempDb::new("display");
        let display = tmp.open().to_string();
        assert!(display.contains("gem_audit_db_test_display"));
    }

    #[test]
    fn database_exists_with_gems() {
        let tmp = TempDb::new("exists");
        let db = tmp.open();
        assert!(db.exists());
        assert_eq!(db.path(), tmp.dir.as_path());
    }

    #[test]
    fn database_advisories_with_mock() {
        let tmp = TempDb::new("advisories");
        let all = tmp.open().advisories();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, "CVE-2020-1234");
    }

    #[test]
    fn database_size_with_mock() {
        let tmp = TempDb::new("size");
        assert_eq!(tmp.open().size(), 1);
    }

    #[test]
    fn rubies_size_with_mock() {
        assert_eq!(mock_db().rubies_size(), 1);
    }

    #[test]
    fn advisories_for_ruby_with_mock() {
        let advisories = mock_db().advisories_for_ruby("ruby");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2021-31810");
    }

    #[test]
    fn check_ruby_vulnerable_version() {
        let (vulns, _) = mock_db().check_ruby("ruby", &Version::parse("2.6.0").unwrap());
        assert_eq!(vulns.len(), 1);
    }

    #[test]
    fn check_ruby_patched_version() {
        let (vulns, _) = mock_db().check_ruby("ruby", &Version::parse("3.0.2").unwrap());
        assert!(vulns.is_empty());
    }

    #[test]
    fn check_ruby_nonexistent_engine() {
        let (vulns, _) = mock_db().check_ruby("nonexistent", &Version::parse("1.0.0").unwrap());
        assert!(vulns.is_empty());
    }

    #[test]
    fn commit_id_none_for_non_git() {
        let tmp = TempDb::new("nongit");
        let db = tmp.open();
        assert_eq!(db.commit_id(), None);
        assert_eq!(db.last_updated_at(), None);
    }

    #[test]
    fn database_error_not_found_display() {
        let err = DatabaseError::NotFound(PathBuf::from("/tmp/missing"));
        assert!(err.to_string().contains("database not found"));
        assert!(err.to_string().contains("/tmp/missing"));
    }

    #[test]
    fn database_error_download_failed_display() {
        let err = DatabaseError::DownloadFailed("network error".to_string());
        assert!(err.to_string().contains("download failed"));
        assert!(err.to_string().contains("network error"));
    }

    #[test]
    fn database_error_update_failed_display() {
        let err = DatabaseError::UpdateFailed("merge conflict".to_string());
        assert!(err.to_string().contains("update failed"));
    }

    #[test]
    fn database_error_git_display() {
        let err = DatabaseError::Git("corrupt repo".to_string());
        assert!(err.to_string().contains("git error"));
    }
}