Skip to main content

gem_audit/advisory/
database.rs

1use std::fmt;
2use std::path::{Path, PathBuf};
3use thiserror::Error;
4
5use super::model::Advisory;
6use crate::version::Version;
7
/// Git URL of the rubysec ruby-advisory-db.
///
/// This is the clone source used by `Database::download` (and therefore also
/// by the re-clone fallback of `Database::update`).
const ADVISORY_DB_URL: &str = "https://github.com/rubysec/ruby-advisory-db.git";
10
/// Handle to a ruby-advisory-db directory on disk.
///
/// Advisories are read from `gems/<gem>/*.yml` and `rubies/<engine>/*.yml`
/// below the root path. The directory may or may not be a git checkout
/// (see `Database::is_git`); git-only features such as `update` and
/// `commit_id` are no-ops / `None` otherwise.
#[derive(Debug)]
pub struct Database {
    /// Root directory of the database (the directory containing `gems/`).
    path: PathBuf,
}
16
/// Errors produced while opening, downloading or updating a [`Database`].
#[derive(Debug, Error)]
pub enum DatabaseError {
    /// The given path is not a directory (returned by `Database::open`).
    #[error("database not found at {}", .0.display())]
    NotFound(PathBuf),
    /// Cloning the advisory repository or checking out its worktree failed.
    #[error("download failed: {0}")]
    DownloadFailed(String),
    /// Fetching, resolving HEAD, or checking out during an update failed.
    #[error("update failed: {0}")]
    UpdateFailed(String),
    /// The database directory could not be opened as a git repository.
    #[error("git error: {0}")]
    Git(String),
    /// A filesystem operation (creating directories, renaming) failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
30
31impl Database {
32    /// Open an existing advisory database at the given path.
33    pub fn open(path: &Path) -> Result<Self, DatabaseError> {
34        if !path.is_dir() {
35            return Err(DatabaseError::NotFound(path.to_path_buf()));
36        }
37        Ok(Database {
38            path: path.to_path_buf(),
39        })
40    }
41
42    /// The default database path: `~/.local/share/ruby-advisory-db`.
43    ///
44    /// Can be overridden by `GEM_AUDIT_DB` environment variable.
45    pub fn default_path() -> PathBuf {
46        if let Ok(custom) = std::env::var("GEM_AUDIT_DB") {
47            return PathBuf::from(custom);
48        }
49        dirs_fallback()
50    }
51
52    /// Download the ruby-advisory-db to the given path.
53    pub fn download(path: &Path, _quiet: bool) -> Result<Self, DatabaseError> {
54        // Ensure the parent directory exists; gix does not create it automatically.
55        if let Some(parent) = path.parent() {
56            std::fs::create_dir_all(parent).map_err(DatabaseError::Io)?;
57        }
58
59        let (mut checkout, _outcome) = gix::prepare_clone(ADVISORY_DB_URL, path)
60            .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?
61            .fetch_then_checkout(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
62            .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
63
64        let (_repo, _outcome) = checkout
65            .main_worktree(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
66            .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
67
68        Ok(Database {
69            path: path.to_path_buf(),
70        })
71    }
72
73    /// Update the database by fetching from origin and fast-forwarding.
74    ///
75    /// If the git fetch fails (e.g. due to ref-update issues in containerised environments),
76    /// falls back to a fresh clone so the update always succeeds.
77    pub fn update(&self) -> Result<bool, DatabaseError> {
78        if !self.is_git() {
79            return Ok(false);
80        }
81
82        if let Err(e) = self.try_fetch() {
83            eprintln!("warning: git fetch failed ({}), re-cloning ...", e);
84            return self.reclone();
85        }
86
87        self.checkout_head()
88    }
89
90    /// Attempt a git fetch from origin.  Returns `Err` on any failure.
91    fn try_fetch(&self) -> Result<(), DatabaseError> {
92        let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
93
94        let remote = repo
95            .find_default_remote(gix::remote::Direction::Fetch)
96            .ok_or_else(|| DatabaseError::UpdateFailed("no remote configured".to_string()))?
97            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
98
99        let connection = remote
100            .connect(gix::remote::Direction::Fetch)
101            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
102
103        connection
104            .prepare_fetch(gix::progress::Discard, Default::default())
105            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
106            .receive(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
107            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
108
109        Ok(())
110    }
111
112    /// Checkout the current HEAD into the working tree.
113    fn checkout_head(&self) -> Result<bool, DatabaseError> {
114        let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
115        let tree = repo
116            .head_commit()
117            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
118            .tree()
119            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
120
121        let mut index = repo
122            .index_from_tree(&tree.id)
123            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
124
125        let opts = gix::worktree::state::checkout::Options {
126            overwrite_existing: true,
127            ..Default::default()
128        };
129
130        gix::worktree::state::checkout(
131            &mut index,
132            repo.workdir()
133                .ok_or_else(|| DatabaseError::UpdateFailed("bare repository".to_string()))?,
134            repo.objects
135                .clone()
136                .into_arc()
137                .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?,
138            &gix::progress::Discard,
139            &gix::progress::Discard,
140            &gix::interrupt::IS_INTERRUPTED,
141            opts,
142        )
143        .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
144
145        Ok(true)
146    }
147
148    /// Delete the existing DB and re-clone from scratch.
149    ///
150    /// Uses an atomic swap: clone to a sibling `_tmp` directory, rename the
151    /// existing DB to `_old`, rename `_tmp` to the final path, then remove
152    /// `_old`.  This ensures `self.path` always contains a valid database.
153    fn reclone(&self) -> Result<bool, DatabaseError> {
154        // Derive sibling paths by appending a suffix to the directory name,
155        // avoiding `with_extension()` which replaces rather than appends.
156        let tmp = {
157            let mut p = self.path.clone().into_os_string();
158            p.push("_tmp");
159            PathBuf::from(p)
160        };
161        let old = {
162            let mut p = self.path.clone().into_os_string();
163            p.push("_old");
164            PathBuf::from(p)
165        };
166
167        // Clean up any leftover from a previous failed attempt.
168        let _ = std::fs::remove_dir_all(&tmp);
169        let _ = std::fs::remove_dir_all(&old);
170
171        // Clone into tmp, then atomically swap with the live DB.
172        Database::download(&tmp, true)?;
173        std::fs::rename(&self.path, &old).map_err(DatabaseError::Io)?;
174        std::fs::rename(&tmp, &self.path).map_err(|e| {
175            // Best-effort rollback: restore the old DB before returning the error.
176            let _ = std::fs::rename(&old, &self.path);
177            DatabaseError::Io(e)
178        })?;
179
180        // Remove old DB (best-effort, failure is non-fatal).
181        let _ = std::fs::remove_dir_all(&old);
182
183        Ok(true)
184    }
185
186    /// Check whether the database path is a git repository.
187    pub fn is_git(&self) -> bool {
188        self.path.join(".git").is_dir()
189    }
190
191    /// Check whether the database exists and is non-empty.
192    pub fn exists(&self) -> bool {
193        self.path.is_dir() && self.path.join("gems").is_dir()
194    }
195
196    /// The path to the database.
197    pub fn path(&self) -> &Path {
198        &self.path
199    }
200
201    /// The last commit ID (HEAD) of the database repository.
202    pub fn commit_id(&self) -> Option<String> {
203        if !self.is_git() {
204            return None;
205        }
206        let repo = gix::open(&self.path).ok()?;
207        let id = repo.head_id().ok()?;
208        Some(id.to_string())
209    }
210
211    /// The timestamp of the last commit.
212    pub fn last_updated_at(&self) -> Option<i64> {
213        if !self.is_git() {
214            return None;
215        }
216        let repo = gix::open(&self.path).ok()?;
217        let commit = repo.head_commit().ok()?;
218        let time = commit.time().ok()?;
219        Some(time.seconds)
220    }
221
222    /// Enumerate all advisories in the database.
223    pub fn advisories(&self) -> Vec<Advisory> {
224        let mut results = Vec::new();
225        let gems_dir = self.path.join("gems");
226
227        if !gems_dir.is_dir() {
228            return results;
229        }
230
231        if let Ok(entries) = std::fs::read_dir(&gems_dir) {
232            for entry in entries.flatten() {
233                if entry.path().is_dir() {
234                    let _ = self.load_advisories_from_dir(&entry.path(), &mut results);
235                }
236            }
237        }
238
239        results
240    }
241
242    /// Get advisories for a specific gem.
243    pub fn advisories_for(&self, gem_name: &str) -> Vec<Advisory> {
244        self.advisories_for_with_errors(gem_name).0
245    }
246
247    /// Get advisories for a specific gem, along with the count of load errors.
248    fn advisories_for_with_errors(&self, gem_name: &str) -> (Vec<Advisory>, usize) {
249        let mut results = Vec::new();
250        let gem_dir = self.path.join("gems").join(gem_name);
251
252        let errors = if gem_dir.is_dir() {
253            self.load_advisories_from_dir(&gem_dir, &mut results)
254        } else {
255            0
256        };
257
258        (results, errors)
259    }
260
261    /// Check a gem (name + version) against the database.
262    ///
263    /// Returns all advisories that the gem version is vulnerable to,
264    /// along with the count of advisory files that failed to load.
265    pub fn check_gem(&self, gem_name: &str, version: &Version) -> (Vec<Advisory>, usize) {
266        let (advisories, errors) = self.advisories_for_with_errors(gem_name);
267        let vulnerable = advisories
268            .into_iter()
269            .filter(|advisory| advisory.vulnerable(version))
270            .collect();
271        (vulnerable, errors)
272    }
273
274    /// Get advisories for a specific Ruby engine (e.g., "ruby", "jruby").
275    pub fn advisories_for_ruby(&self, engine: &str) -> Vec<Advisory> {
276        self.advisories_for_ruby_with_errors(engine).0
277    }
278
279    /// Get advisories for a specific Ruby engine, along with the count of load errors.
280    fn advisories_for_ruby_with_errors(&self, engine: &str) -> (Vec<Advisory>, usize) {
281        let mut results = Vec::new();
282        let engine_dir = self.path.join("rubies").join(engine);
283
284        let errors = if engine_dir.is_dir() {
285            self.load_advisories_from_dir(&engine_dir, &mut results)
286        } else {
287            0
288        };
289
290        (results, errors)
291    }
292
293    /// Check a Ruby engine+version against the database.
294    ///
295    /// Returns all advisories that the Ruby version is vulnerable to,
296    /// along with the count of advisory files that failed to load.
297    pub fn check_ruby(&self, engine: &str, version: &Version) -> (Vec<Advisory>, usize) {
298        let (advisories, errors) = self.advisories_for_ruby_with_errors(engine);
299        let vulnerable = advisories
300            .into_iter()
301            .filter(|advisory| advisory.vulnerable(version))
302            .collect();
303        (vulnerable, errors)
304    }
305
306    /// Total number of gem advisories in the database.
307    pub fn size(&self) -> usize {
308        self.count_advisories_in("gems")
309    }
310
311    /// Total number of Ruby interpreter advisories in the database.
312    pub fn rubies_size(&self) -> usize {
313        self.count_advisories_in("rubies")
314    }
315
316    /// Count advisory YAML files under a top-level directory (e.g., "gems" or "rubies").
317    fn count_advisories_in(&self, subdir: &str) -> usize {
318        let dir = self.path.join(subdir);
319        if !dir.is_dir() {
320            return 0;
321        }
322
323        let mut count = 0;
324        if let Ok(entries) = std::fs::read_dir(&dir) {
325            for entry in entries.flatten() {
326                if entry.path().is_dir()
327                    && let Ok(advisory_files) = std::fs::read_dir(entry.path())
328                {
329                    count += advisory_files
330                        .flatten()
331                        .filter(|f| f.path().extension().is_some_and(|ext| ext == "yml"))
332                        .count();
333                }
334            }
335        }
336
337        count
338    }
339
340    /// Load all advisory YAML files from a gem directory.
341    ///
342    /// Returns the number of files that failed to load.
343    fn load_advisories_from_dir(&self, dir: &Path, results: &mut Vec<Advisory>) -> usize {
344        let mut errors = 0;
345        if let Ok(entries) = std::fs::read_dir(dir) {
346            for entry in entries.flatten() {
347                let path = entry.path();
348                if path.extension().is_some_and(|ext| ext == "yml") {
349                    match Advisory::load(&path) {
350                        Ok(advisory) => results.push(advisory),
351                        Err(e) => {
352                            eprintln!("warning: failed to load advisory {}: {}", path.display(), e);
353                            errors += 1;
354                        }
355                    }
356                }
357            }
358        }
359        errors
360    }
361}
362
363impl fmt::Display for Database {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        write!(f, "{}", self.path.display())
366    }
367}
368
/// Fallback for getting the default database path when the `dirs` crate is not available.
///
/// Returns `$HOME/.local/share/ruby-advisory-db`, or the relative path
/// `.ruby-advisory-db` (resolved against the current directory) when `HOME`
/// is unset or empty. `HOME` is read as an `OsString`, so a non-UTF-8 home
/// directory is still honoured instead of silently falling back.
fn dirs_fallback() -> PathBuf {
    match std::env::var_os("HOME") {
        // An empty HOME is treated as unset rather than yielding a
        // `.local/share/...` path relative to the current directory.
        Some(home) if !home.is_empty() => PathBuf::from(home)
            .join(".local")
            .join("share")
            .join("ruby-advisory-db"),
        _ => PathBuf::from(".ruby-advisory-db"),
    }
}
380
#[cfg(test)]
mod tests {
    use super::*;

    // ========== Database with real ruby-advisory-db ==========

    /// The locally installed advisory DB, if one exists at the default path.
    /// Tests using it are skipped (pass vacuously) when it is absent.
    fn local_db() -> Option<Database> {
        let path = Database::default_path();
        if path.is_dir() && path.join("gems").is_dir() {
            Database::open(&path).ok()
        } else {
            None
        }
    }

    #[test]
    fn open_local_database() {
        if let Some(db) = local_db() {
            assert!(db.exists());
            assert!(db.is_git());
        }
    }

    #[test]
    fn database_size() {
        if let Some(db) = local_db() {
            let size = db.size();
            // ruby-advisory-db has hundreds of advisories
            assert!(size > 100, "expected > 100 advisories, got {}", size);
        }
    }

    #[test]
    fn database_commit_id() {
        if let Some(db) = local_db() {
            let commit = db.commit_id();
            assert!(commit.is_some());
            let id = commit.unwrap();
            assert_eq!(id.len(), 40); // SHA-1 hex
        }
    }

    #[test]
    fn database_last_updated() {
        if let Some(db) = local_db() {
            let ts = db.last_updated_at();
            assert!(ts.is_some());
            assert!(ts.unwrap() > 0);
        }
    }

    #[test]
    fn advisories_for_actionpack() {
        if let Some(db) = local_db() {
            let advisories = db.advisories_for("actionpack");
            // actionpack has many known CVEs
            assert!(!advisories.is_empty(), "expected advisories for actionpack");
        }
    }

    #[test]
    fn check_vulnerable_gem() {
        if let Some(db) = local_db() {
            // Rails 3.2.10 is known to have vulnerabilities
            let version = Version::parse("3.2.10").unwrap();
            let (vulnerabilities, _errors) = db.check_gem("activerecord", &version);
            assert!(
                !vulnerabilities.is_empty(),
                "expected activerecord 3.2.10 to have vulnerabilities"
            );
        }
    }

    #[test]
    fn check_nonexistent_gem() {
        if let Some(db) = local_db() {
            let version = Version::parse("1.0.0").unwrap();
            let (vulnerabilities, _errors) = db.check_gem("nonexistent-gem-xyz", &version);
            assert!(vulnerabilities.is_empty());
        }
    }

    // ========== Database with fixture advisory ==========

    #[test]
    fn open_fixture_advisory_dir() {
        let (db_dir, _) = temp_mock_db("fixture");

        let db = Database::open(&db_dir).unwrap();
        assert!(!db.is_git());

        let advisories = db.advisories_for("test");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2020-1234");

        // Check vulnerable version
        let (vulns, _errors) = db.check_gem("test", &Version::parse("0.1.0").unwrap());
        assert_eq!(vulns.len(), 1);

        // Check patched version
        let (vulns, _errors) = db.check_gem("test", &Version::parse("1.0.0").unwrap());
        assert!(vulns.is_empty());

        std::fs::remove_dir_all(&db_dir).unwrap();
    }

    // ========== Error Cases ==========

    #[test]
    fn open_nonexistent_path() {
        let result = Database::open(Path::new("/nonexistent/path"));
        assert!(result.is_err());
    }

    #[test]
    fn default_path_is_sensible() {
        let path = Database::default_path();
        let path_str = path.to_string_lossy();
        assert!(
            path_str.contains("ruby-advisory-db"),
            "default path should contain ruby-advisory-db: {}",
            path_str
        );
    }

    // Helper: create an isolated temporary mock DB for tests that don't
    // share state with `mock_database()` in scanner tests.
    //
    // The directory name includes the process id so that concurrent test
    // processes (e.g. two checkouts tested at once) never delete or
    // overwrite each other's fixtures. Returns `(db_dir, fixture_dir)`.
    fn temp_mock_db(suffix: &str) -> (PathBuf, PathBuf) {
        let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
        let db_dir = std::env::temp_dir().join(format!(
            "gem_audit_db_test_{}_{}",
            suffix,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&db_dir);
        let gem_dir = db_dir.join("gems").join("test");
        std::fs::create_dir_all(&gem_dir).unwrap();
        std::fs::copy(
            fixture_dir.join("advisory/CVE-2020-1234.yml"),
            gem_dir.join("CVE-2020-1234.yml"),
        )
        .unwrap();
        (db_dir, fixture_dir)
    }

    // ========== Database Display ==========

    #[test]
    fn database_display() {
        let (db_dir, _) = temp_mock_db("display");
        let db = Database::open(&db_dir).unwrap();
        let display = db.to_string();
        assert!(display.contains("gem_audit_db_test_display"));
        std::fs::remove_dir_all(&db_dir).unwrap();
    }

    // ========== Database exists/path ==========

    #[test]
    fn database_exists_with_gems() {
        let (db_dir, _) = temp_mock_db("exists");
        let db = Database::open(&db_dir).unwrap();
        assert!(db.exists());
        assert!(db.path() == db_dir.as_path());
        std::fs::remove_dir_all(&db_dir).unwrap();
    }

    // ========== Database advisories/size with mock ==========

    #[test]
    fn database_advisories_with_mock() {
        let (db_dir, _) = temp_mock_db("advisories");
        let db = Database::open(&db_dir).unwrap();
        let all = db.advisories();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, "CVE-2020-1234");
        std::fs::remove_dir_all(&db_dir).unwrap();
    }

    #[test]
    fn database_size_with_mock() {
        let (db_dir, _) = temp_mock_db("size");
        let db = Database::open(&db_dir).unwrap();
        assert_eq!(db.size(), 1);
        std::fs::remove_dir_all(&db_dir).unwrap();
    }

    // ========== Ruby advisory methods ==========

    #[test]
    fn rubies_size_with_mock() {
        // Use the shared mock_db fixture which has rubies/ruby/CVE-2021-31810.yml
        let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
        let db_dir = fixture_dir.join("mock_db");
        let db = Database::open(&db_dir).unwrap();
        assert_eq!(db.rubies_size(), 1);
    }

    #[test]
    fn advisories_for_ruby_with_mock() {
        let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
        let db_dir = fixture_dir.join("mock_db");
        let db = Database::open(&db_dir).unwrap();
        let advisories = db.advisories_for_ruby("ruby");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2021-31810");
    }

    #[test]
    fn check_ruby_vulnerable_version() {
        let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
        let db_dir = fixture_dir.join("mock_db");
        let db = Database::open(&db_dir).unwrap();
        let (vulns, _) = db.check_ruby("ruby", &Version::parse("2.6.0").unwrap());
        assert_eq!(vulns.len(), 1);
    }

    #[test]
    fn check_ruby_patched_version() {
        let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
        let db_dir = fixture_dir.join("mock_db");
        let db = Database::open(&db_dir).unwrap();
        let (vulns, _) = db.check_ruby("ruby", &Version::parse("3.0.2").unwrap());
        assert!(vulns.is_empty());
    }

    #[test]
    fn check_ruby_nonexistent_engine() {
        let fixture_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
        let db_dir = fixture_dir.join("mock_db");
        let db = Database::open(&db_dir).unwrap();
        let (vulns, _) = db.check_ruby("nonexistent", &Version::parse("1.0.0").unwrap());
        assert!(vulns.is_empty());
    }

    // ========== commit_id / last_updated_at for non-git ==========

    #[test]
    fn commit_id_none_for_non_git() {
        let (db_dir, _) = temp_mock_db("nongit");
        let db = Database::open(&db_dir).unwrap();
        assert_eq!(db.commit_id(), None);
        assert_eq!(db.last_updated_at(), None);
        std::fs::remove_dir_all(&db_dir).unwrap();
    }

    // ========== DatabaseError Display ==========

    #[test]
    fn database_error_not_found_display() {
        let err = DatabaseError::NotFound(PathBuf::from("/tmp/missing"));
        assert!(err.to_string().contains("database not found"));
        assert!(err.to_string().contains("/tmp/missing"));
    }

    #[test]
    fn database_error_download_failed_display() {
        let err = DatabaseError::DownloadFailed("network error".to_string());
        assert!(err.to_string().contains("download failed"));
        assert!(err.to_string().contains("network error"));
    }

    #[test]
    fn database_error_update_failed_display() {
        let err = DatabaseError::UpdateFailed("merge conflict".to_string());
        assert!(err.to_string().contains("update failed"));
    }

    #[test]
    fn database_error_git_display() {
        let err = DatabaseError::Git("corrupt repo".to_string());
        assert!(err.to_string().contains("git error"));
    }
}
646    fn database_error_git_display() {
647        let err = DatabaseError::Git("corrupt repo".to_string());
648        assert!(err.to_string().contains("git error"));
649    }
650}