use std::fmt;
use std::path::{Path, PathBuf};
use thiserror::Error;
use super::model::Advisory;
use crate::version::Version;
/// Handle to an on-disk copy of the rubysec advisory database.
///
/// Only the root directory is stored; advisories are read from disk each time
/// they are requested.
#[derive(Debug)]
pub struct Database {
    path: PathBuf,
}
/// Upstream git repository that `Database::download` clones.
const ADVISORY_DB_URL: &str = "https://github.com/rubysec/ruby-advisory-db.git";
/// Errors produced while opening, downloading or updating the advisory database.
#[derive(Debug, Error)]
pub enum DatabaseError {
    /// The requested database root is not an existing directory.
    #[error("database not found at {}", .0.display())]
    NotFound(PathBuf),
    /// Opening the on-disk git repository failed.
    #[error("git error: {0}")]
    Git(String),
    /// Cloning or checking out the upstream repository failed.
    #[error("download failed: {0}")]
    DownloadFailed(String),
    /// Fetching or re-checking-out an existing clone failed.
    #[error("update failed: {0}")]
    UpdateFailed(String),
    /// A filesystem operation failed; convertible from `std::io::Error` via `?`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
impl Database {
    /// Opens an existing advisory database rooted at `path`.
    ///
    /// Only verifies that `path` is a directory; [`Database::exists`] checks
    /// whether it actually contains a `gems/` tree.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::NotFound`] when `path` is not a directory.
    pub fn open(path: &Path) -> Result<Self, DatabaseError> {
        if !path.is_dir() {
            return Err(DatabaseError::NotFound(path.to_path_buf()));
        }
        Ok(Database {
            path: path.to_path_buf(),
        })
    }

    /// Returns the location the database is expected at by default.
    ///
    /// `$GEM_AUDIT_DB` wins when it is set and non-empty; otherwise the
    /// per-user location computed by `dirs_fallback()` is used.
    pub fn default_path() -> PathBuf {
        // `var_os` honours overrides that are not valid UTF-8 (which `var`
        // would silently ignore), and an empty value is treated as "unset"
        // instead of yielding the meaningless path "".
        match std::env::var_os("GEM_AUDIT_DB") {
            Some(custom) if !custom.is_empty() => PathBuf::from(custom),
            _ => dirs_fallback(),
        }
    }

    /// Clones the upstream advisory repository (`ADVISORY_DB_URL`) into
    /// `path` and checks out its main worktree.
    ///
    /// Missing parent directories of `path` are created first. `_quiet` is
    /// currently ignored: clone progress is always discarded.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Io`] when the parent directory cannot be
    /// created, and [`DatabaseError::DownloadFailed`] when the clone or the
    /// checkout fails.
    pub fn download(path: &Path, _quiet: bool) -> Result<Self, DatabaseError> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(DatabaseError::Io)?;
        }
        let (mut checkout, _outcome) = gix::prepare_clone(ADVISORY_DB_URL, path)
            .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?
            .fetch_then_checkout(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
            .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
        let (_repo, _outcome) = checkout
            .main_worktree(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
            .map_err(|e| DatabaseError::DownloadFailed(e.to_string()))?;
        Ok(Database {
            path: path.to_path_buf(),
        })
    }

    /// Updates a git-backed database to the latest upstream commit.
    ///
    /// Fetches from the default remote, moves the checked-out branch to the
    /// fetched upstream commit, then rewrites the working tree from it. If
    /// fetching or moving the branch fails, the database is re-cloned from
    /// scratch instead.
    ///
    /// Returns `Ok(false)` without touching anything when the database is not
    /// a git checkout, and `Ok(true)` once an update has been applied.
    ///
    /// # Errors
    ///
    /// Propagates failures of the working-tree checkout or of the re-clone
    /// fallback.
    pub fn update(&self) -> Result<bool, DatabaseError> {
        if !self.is_git() {
            return Ok(false);
        }
        // A fetch only moves the remote-tracking refs (`refs/remotes/*`). The
        // checked-out branch must be advanced too, otherwise `checkout_head`
        // would just re-materialize the commit we already had.
        if let Err(e) = self.try_fetch().and_then(|()| self.advance_head()) {
            eprintln!("warning: git fetch failed ({}), re-cloning ...", e);
            return self.reclone();
        }
        self.checkout_head()
    }

    /// Fetches from the default fetch remote into its remote-tracking refs.
    fn try_fetch(&self) -> Result<(), DatabaseError> {
        let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
        let remote = repo
            .find_default_remote(gix::remote::Direction::Fetch)
            .ok_or_else(|| DatabaseError::UpdateFailed("no remote configured".to_string()))?
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        let connection = remote
            .connect(gix::remote::Direction::Fetch)
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        connection
            .prepare_fetch(gix::progress::Discard, Default::default())
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
            .receive(gix::progress::Discard, &gix::interrupt::IS_INTERRUPTED)
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        Ok(())
    }

    /// Points the checked-out branch at the commit of its upstream
    /// remote-tracking ref, as just refreshed by `try_fetch`.
    ///
    /// The database is treated as a read-only mirror, so the branch ref is
    /// moved outright rather than merged.
    fn advance_head(&self) -> Result<(), DatabaseError> {
        let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
        let mut head = repo
            .head_ref()
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
            .ok_or_else(|| DatabaseError::UpdateFailed("HEAD is detached".to_string()))?;
        let upstream = repo
            .branch_remote_tracking_ref_name(head.name(), gix::remote::Direction::Fetch)
            .ok_or_else(|| {
                DatabaseError::UpdateFailed("no upstream branch configured".to_string())
            })?
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        let target = repo
            .find_reference(&*upstream)
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
            .try_id()
            .ok_or_else(|| DatabaseError::UpdateFailed("upstream ref is symbolic".to_string()))?
            .detach();
        if head.try_id().map(|id| id.detach()) == Some(target) {
            // Already at upstream; skip writing a no-op reflog entry.
            return Ok(());
        }
        head.set_target_id(target, "gem-audit: update advisory database")
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        Ok(())
    }

    /// Rewrites the working tree from the tree of the current HEAD commit,
    /// overwriting files that already exist. Returns `Ok(true)` on success.
    ///
    /// NOTE(review): this writes every entry of the HEAD tree but does not
    /// appear to delete files missing from it, so advisories removed upstream
    /// may linger on disk. The index built here is also not written back to
    /// `.git/index`. Confirm against the gix checkout documentation.
    fn checkout_head(&self) -> Result<bool, DatabaseError> {
        let repo = gix::open(&self.path).map_err(|e| DatabaseError::Git(e.to_string()))?;
        let tree = repo
            .head_commit()
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?
            .tree()
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        let mut index = repo
            .index_from_tree(&tree.id)
            .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        let opts = gix::worktree::state::checkout::Options {
            overwrite_existing: true,
            ..Default::default()
        };
        gix::worktree::state::checkout(
            &mut index,
            repo.workdir()
                .ok_or_else(|| DatabaseError::UpdateFailed("bare repository".to_string()))?,
            repo.objects
                .clone()
                .into_arc()
                .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?,
            &gix::progress::Discard,
            &gix::progress::Discard,
            &gix::interrupt::IS_INTERRUPTED,
            opts,
        )
        .map_err(|e| DatabaseError::UpdateFailed(e.to_string()))?;
        Ok(true)
    }

    /// Replaces the database with a fresh clone.
    ///
    /// The clone goes into `<path>_tmp`. The current database is then moved
    /// aside to `<path>_old` and the new clone renamed into place. If that
    /// last rename fails, the previous database is moved back. Leftover
    /// `_tmp`/`_old` directories from earlier runs are removed first.
    fn reclone(&self) -> Result<bool, DatabaseError> {
        let sibling = |suffix: &str| {
            let mut p = self.path.clone().into_os_string();
            p.push(suffix);
            PathBuf::from(p)
        };
        let tmp = sibling("_tmp");
        let old = sibling("_old");
        let _ = std::fs::remove_dir_all(&tmp);
        let _ = std::fs::remove_dir_all(&old);
        if let Err(e) = Database::download(&tmp, true) {
            // Don't leave a half-finished clone lying next to the database.
            let _ = std::fs::remove_dir_all(&tmp);
            return Err(e);
        }
        if let Err(e) = std::fs::rename(&self.path, &old) {
            let _ = std::fs::remove_dir_all(&tmp);
            return Err(DatabaseError::Io(e));
        }
        if let Err(e) = std::fs::rename(&tmp, &self.path) {
            // Restore the previous database so the user is no worse off.
            let _ = std::fs::rename(&old, &self.path);
            let _ = std::fs::remove_dir_all(&tmp);
            return Err(DatabaseError::Io(e));
        }
        let _ = std::fs::remove_dir_all(&old);
        Ok(true)
    }

    /// Returns `true` when the database root contains a `.git` directory.
    pub fn is_git(&self) -> bool {
        self.path.join(".git").is_dir()
    }

    /// Returns `true` when the root is a directory containing a `gems/` directory.
    pub fn exists(&self) -> bool {
        self.path.is_dir() && self.path.join("gems").is_dir()
    }

    /// The database root directory.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// The HEAD commit id as a string.
    ///
    /// Returns `None` for non-git databases or when the repository or HEAD
    /// cannot be read.
    pub fn commit_id(&self) -> Option<String> {
        if !self.is_git() {
            return None;
        }
        let repo = gix::open(&self.path).ok()?;
        let id = repo.head_id().ok()?;
        Some(id.to_string())
    }

    /// The `seconds` field of the HEAD commit's time, i.e. seconds since the
    /// Unix epoch.
    ///
    /// Returns `None` for non-git databases or when the commit cannot be read.
    pub fn last_updated_at(&self) -> Option<i64> {
        if !self.is_git() {
            return None;
        }
        let repo = gix::open(&self.path).ok()?;
        let commit = repo.head_commit().ok()?;
        let time = commit.time().ok()?;
        Some(time.seconds)
    }

    /// Loads every gem advisory under `gems/*/`, in sorted path order.
    ///
    /// Files that fail to load are reported on stderr and skipped.
    pub fn advisories(&self) -> Vec<Advisory> {
        let mut results = Vec::new();
        let Ok(entries) = std::fs::read_dir(self.path.join("gems")) else {
            return results;
        };
        let mut gem_dirs: Vec<PathBuf> = entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| path.is_dir())
            .collect();
        // `read_dir` order is platform-dependent; sort for reproducible output.
        gem_dirs.sort();
        for gem_dir in &gem_dirs {
            // Failures are already reported on stderr; this API returns no count.
            let _ = self.load_advisories_from_dir(gem_dir, &mut results);
        }
        results
    }

    /// Loads all advisories for the gem `gem_name` (from `gems/<gem_name>/`).
    ///
    /// Returns an empty list when the gem has no directory or the name is not a
    /// single plain path component.
    pub fn advisories_for(&self, gem_name: &str) -> Vec<Advisory> {
        self.advisories_for_with_errors(gem_name).0
    }

    /// Like [`Database::advisories_for`], also returning the number of load failures.
    fn advisories_for_with_errors(&self, gem_name: &str) -> (Vec<Advisory>, usize) {
        self.load_named_subdir("gems", gem_name)
    }

    /// Returns the advisories for `gem_name` whose `vulnerable` check matches
    /// `version`, together with the number of advisory files that failed to load.
    pub fn check_gem(&self, gem_name: &str, version: &Version) -> (Vec<Advisory>, usize) {
        let (mut advisories, errors) = self.advisories_for_with_errors(gem_name);
        advisories.retain(|advisory| advisory.vulnerable(version));
        (advisories, errors)
    }

    /// Loads all advisories for the Ruby engine `engine` (from `rubies/<engine>/`).
    ///
    /// Returns an empty list when the engine has no directory or the name is
    /// not a single plain path component.
    pub fn advisories_for_ruby(&self, engine: &str) -> Vec<Advisory> {
        self.advisories_for_ruby_with_errors(engine).0
    }

    /// Like [`Database::advisories_for_ruby`], also returning the number of load failures.
    fn advisories_for_ruby_with_errors(&self, engine: &str) -> (Vec<Advisory>, usize) {
        self.load_named_subdir("rubies", engine)
    }

    /// Returns the advisories for `engine` whose `vulnerable` check matches
    /// `version`, together with the number of advisory files that failed to load.
    pub fn check_ruby(&self, engine: &str, version: &Version) -> (Vec<Advisory>, usize) {
        let (mut advisories, errors) = self.advisories_for_ruby_with_errors(engine);
        advisories.retain(|advisory| advisory.vulnerable(version));
        (advisories, errors)
    }

    /// Number of `.yml` files across all `gems/<name>/` directories (not parsed).
    pub fn size(&self) -> usize {
        self.count_advisories_in("gems")
    }

    /// Number of `.yml` files across all `rubies/<engine>/` directories (not parsed).
    pub fn rubies_size(&self) -> usize {
        self.count_advisories_in("rubies")
    }

    /// Counts `.yml` files one level below each subdirectory of `<root>/<subdir>`.
    fn count_advisories_in(&self, subdir: &str) -> usize {
        let dir = self.path.join(subdir);
        if !dir.is_dir() {
            return 0;
        }
        let mut count = 0;
        if let Ok(entries) = std::fs::read_dir(&dir) {
            for entry in entries.flatten() {
                if entry.path().is_dir()
                    && let Ok(advisory_files) = std::fs::read_dir(entry.path())
                {
                    count += advisory_files
                        .flatten()
                        .filter(|f| f.path().extension().is_some_and(|ext| ext == "yml"))
                        .count();
                }
            }
        }
        count
    }

    /// Loads `<root>/<category>/<name>/*.yml`, returning the advisories and the
    /// number of failures.
    ///
    /// `name` must be a single plain path component. Anything else (absolute
    /// paths, `..`, `a/b`, empty) yields no advisories. This is because
    /// `Path::join` with an absolute path discards the database root entirely,
    /// and relative traversal would escape `category/`.
    fn load_named_subdir(&self, category: &str, name: &str) -> (Vec<Advisory>, usize) {
        let mut results = Vec::new();
        let mut components = Path::new(name).components();
        let is_plain_name = matches!(
            (components.next(), components.next()),
            (Some(std::path::Component::Normal(_)), None)
        );
        if !is_plain_name {
            return (results, 0);
        }
        let dir = self.path.join(category).join(name);
        let errors = if dir.is_dir() {
            self.load_advisories_from_dir(&dir, &mut results)
        } else {
            0
        };
        (results, errors)
    }

    /// Appends every loadable `*.yml` advisory in `dir` to `results`, in sorted
    /// path order, and returns how many failures occurred.
    ///
    /// A failure is either a file that `Advisory::load` rejected or, counted
    /// once, a directory that could not be listed. Each failure is also
    /// reported on stderr.
    fn load_advisories_from_dir(&self, dir: &Path, results: &mut Vec<Advisory>) -> usize {
        let entries = match std::fs::read_dir(dir) {
            Ok(entries) => entries,
            Err(e) => {
                // Callers only get here for existing directories, so an
                // unreadable one must not masquerade as "no advisories".
                eprintln!(
                    "warning: failed to read advisory directory {}: {}",
                    dir.display(),
                    e
                );
                return 1;
            }
        };
        let mut files: Vec<PathBuf> = entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| path.extension().is_some_and(|ext| ext == "yml"))
            .collect();
        files.sort();
        let mut errors = 0;
        for path in files {
            match Advisory::load(&path) {
                Ok(advisory) => results.push(advisory),
                Err(e) => {
                    eprintln!("warning: failed to load advisory {}: {}", path.display(), e);
                    errors += 1;
                }
            }
        }
        errors
    }
}
impl fmt::Display for Database {
    /// Renders the database as its root directory path.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown = self.path.display();
        f.write_fmt(format_args!("{}", shown))
    }
}
/// Per-user default location of the advisory database.
///
/// Resolves to `$HOME/.local/share/ruby-advisory-db`. When `HOME` is unset or
/// empty, falls back to `.ruby-advisory-db` relative to the current directory.
fn dirs_fallback() -> PathBuf {
    // `var_os` accepts home directories that are not valid UTF-8 (`var` would
    // reject them and silently fall back to the cwd). An empty HOME would
    // otherwise produce a cwd-relative `.local/share/...` path.
    match std::env::var_os("HOME") {
        Some(home) if !home.is_empty() => PathBuf::from(home)
            .join(".local")
            .join("share")
            .join("ruby-advisory-db"),
        _ => PathBuf::from(".ruby-advisory-db"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Root of the checked-in test fixtures.
    fn fixtures_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
    }

    /// Opens the checked-in `mock_db` fixture database.
    fn open_mock_db() -> Database {
        let db_dir = fixtures_dir().join("mock_db");
        Database::open(&db_dir).unwrap()
    }

    /// The real advisory database at the default location, if one has been
    /// downloaded. Tests built on it return early (pass) when it is absent.
    fn local_db() -> Option<Database> {
        let path = Database::default_path();
        let looks_populated = path.is_dir() && path.join("gems").is_dir();
        if !looks_populated {
            return None;
        }
        Database::open(&path).ok()
    }

    /// Builds a temporary database containing only `gems/test/CVE-2020-1234.yml`,
    /// copied from the fixtures. The returned `TempDir` must stay alive while
    /// the database is in use.
    fn temp_mock_db(_suffix: &str) -> (tempfile::TempDir, PathBuf) {
        let fixture_dir = fixtures_dir();
        let tmp = tempfile::tempdir().unwrap();
        let gem_dir = tmp.path().join("gems").join("test");
        std::fs::create_dir_all(&gem_dir).unwrap();
        let source = fixture_dir.join("advisory/CVE-2020-1234.yml");
        let destination = gem_dir.join("CVE-2020-1234.yml");
        std::fs::copy(source, destination).unwrap();
        (tmp, fixture_dir)
    }

    // Tests against the locally downloaded database (skipped when absent).

    #[test]
    fn open_local_database() {
        let Some(db) = local_db() else { return };
        assert!(db.exists());
        assert!(db.is_git());
    }

    #[test]
    fn database_size() {
        let Some(db) = local_db() else { return };
        let size = db.size();
        assert!(size > 100, "expected > 100 advisories, got {}", size);
    }

    #[test]
    fn database_commit_id() {
        let Some(db) = local_db() else { return };
        let commit = db.commit_id();
        assert!(commit.is_some());
        let id = commit.unwrap();
        assert_eq!(id.len(), 40);
    }

    #[test]
    fn database_last_updated() {
        let Some(db) = local_db() else { return };
        let ts = db.last_updated_at();
        assert!(ts.is_some());
        assert!(ts.unwrap() > 0);
    }

    #[test]
    fn advisories_for_actionpack() {
        let Some(db) = local_db() else { return };
        let advisories = db.advisories_for("actionpack");
        assert!(!advisories.is_empty(), "expected advisories for actionpack");
    }

    #[test]
    fn check_vulnerable_gem() {
        let Some(db) = local_db() else { return };
        let version = Version::parse("3.2.10").unwrap();
        let (vulnerabilities, _errors) = db.check_gem("activerecord", &version);
        assert!(
            !vulnerabilities.is_empty(),
            "expected activerecord 3.2.10 to have vulnerabilities"
        );
    }

    #[test]
    fn check_nonexistent_gem() {
        let Some(db) = local_db() else { return };
        let version = Version::parse("1.0.0").unwrap();
        let (vulnerabilities, _errors) = db.check_gem("nonexistent-gem-xyz", &version);
        assert!(vulnerabilities.is_empty());
    }

    // Tests against temporary and fixture databases.

    #[test]
    fn open_fixture_advisory_dir() {
        let (tmp, _) = temp_mock_db("fixture");
        let db = Database::open(tmp.path()).unwrap();
        assert!(!db.is_git());

        let advisories = db.advisories_for("test");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2020-1234");

        let affected = Version::parse("0.1.0").unwrap();
        let (vulns, _errors) = db.check_gem("test", &affected);
        assert_eq!(vulns.len(), 1);

        let patched = Version::parse("1.0.0").unwrap();
        let (vulns, _errors) = db.check_gem("test", &patched);
        assert!(vulns.is_empty());
    }

    #[test]
    fn open_nonexistent_path() {
        let outcome = Database::open(Path::new("/nonexistent/path"));
        assert!(outcome.is_err());
    }

    #[test]
    fn default_path_is_sensible() {
        let path = Database::default_path();
        let path_str = path.to_string_lossy();
        assert!(
            path_str.contains("ruby-advisory-db"),
            "default path should contain ruby-advisory-db: {}",
            path_str
        );
    }

    #[test]
    fn database_display() {
        let (tmp, _) = temp_mock_db("display");
        let db = Database::open(tmp.path()).unwrap();
        let rendered = db.to_string();
        assert_eq!(rendered, tmp.path().to_string_lossy());
    }

    #[test]
    fn database_exists_with_gems() {
        let (tmp, _) = temp_mock_db("exists");
        let db = Database::open(tmp.path()).unwrap();
        assert!(db.exists());
        assert!(db.path() == tmp.path());
    }

    #[test]
    fn database_advisories_with_mock() {
        let (tmp, _) = temp_mock_db("advisories");
        let db = Database::open(tmp.path()).unwrap();
        let every = db.advisories();
        assert_eq!(every.len(), 1);
        assert_eq!(every[0].id, "CVE-2020-1234");
    }

    #[test]
    fn database_size_with_mock() {
        let (tmp, _) = temp_mock_db("size");
        let db = Database::open(tmp.path()).unwrap();
        assert_eq!(db.size(), 1);
    }

    #[test]
    fn rubies_size_with_mock() {
        let db = open_mock_db();
        assert_eq!(db.rubies_size(), 1);
    }

    #[test]
    fn advisories_for_ruby_with_mock() {
        let db = open_mock_db();
        let advisories = db.advisories_for_ruby("ruby");
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "CVE-2021-31810");
    }

    #[test]
    fn check_ruby_vulnerable_version() {
        let db = open_mock_db();
        let (vulns, _) = db.check_ruby("ruby", &Version::parse("2.6.0").unwrap());
        assert_eq!(vulns.len(), 1);
    }

    #[test]
    fn check_ruby_patched_version() {
        let db = open_mock_db();
        let (vulns, _) = db.check_ruby("ruby", &Version::parse("3.0.2").unwrap());
        assert!(vulns.is_empty());
    }

    #[test]
    fn check_ruby_nonexistent_engine() {
        let db = open_mock_db();
        let (vulns, _) = db.check_ruby("nonexistent", &Version::parse("1.0.0").unwrap());
        assert!(vulns.is_empty());
    }

    #[test]
    fn commit_id_none_for_non_git() {
        let (tmp, _) = temp_mock_db("nongit");
        let db = Database::open(tmp.path()).unwrap();
        assert_eq!(db.commit_id(), None);
        assert_eq!(db.last_updated_at(), None);
    }

    // Display formatting of the error variants.

    #[test]
    fn database_error_not_found_display() {
        let message = DatabaseError::NotFound(PathBuf::from("/tmp/missing")).to_string();
        assert!(message.contains("database not found"));
        assert!(message.contains("/tmp/missing"));
    }

    #[test]
    fn database_error_download_failed_display() {
        let message = DatabaseError::DownloadFailed("network error".to_string()).to_string();
        assert!(message.contains("download failed"));
        assert!(message.contains("network error"));
    }

    #[test]
    fn database_error_update_failed_display() {
        let message = DatabaseError::UpdateFailed("merge conflict".to_string()).to_string();
        assert!(message.contains("update failed"));
    }

    #[test]
    fn database_error_git_display() {
        let message = DatabaseError::Git("corrupt repo".to_string()).to_string();
        assert!(message.contains("git error"));
    }
}