use crate::util::DURATION_0;
use std::fs::File;
use std::io::Read;
use std::sync::Arc;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::util::{
hash_string, logger, path_within_duration, CacheConfig, FlagLog, ResultDynError,
};
use crate::{package::Package, ureq_client::UreqClient};
/// Identifies one package inside an OSV query.
///
/// Serialized with serde using the field names below as JSON keys. The
/// declaration order of the fields is also the order in which they appear in
/// the serialized query.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct OSVPackage {
/// Package name, copied from the project's `Package` when built through
/// `OSVPackageQuery::from_package`.
name: String,
/// Ecosystem label the package belongs to; `from_package` always sets
/// this to `"PyPI"`.
ecosystem: String,
}
/// A single query entry: which package, at which version, to look up.
///
/// A list of these is serialized both as the request payload (wrapped in
/// `OSVQueryBatch`) and as the input that is hashed to name the cache file.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct OSVPackageQuery {
/// The package being asked about.
package: OSVPackage,
/// Version of the package, stored as its string rendering.
version: String,
}
impl OSVPackageQuery {
    /// Builds the OSV query that describes `package`.
    ///
    /// The package name is copied as-is, the version is rendered through its
    /// `to_string` implementation, and the ecosystem is fixed to `"PyPI"`.
    fn from_package(package: &Package) -> Self {
        let identity = OSVPackage {
            name: package.name.clone(),
            ecosystem: String::from("PyPI"),
        };
        Self {
            package: identity,
            version: package.version.to_string(),
        }
    }
}
/// Request envelope for a batch lookup: serialized to JSON and sent as the
/// POST body, with the individual queries under the `queries` key.
#[derive(Serialize, Deserialize, Debug)]
struct OSVQueryBatch {
/// The queries of this batch, in the order their results are expected back.
queries: Vec<OSVPackageQuery>,
}
/// One vulnerability entry as it appears in a batch response.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct OSVVuln {
/// Vulnerability identifier (for example `GHSA-34rf-p3r3-58x2`); this is
/// the only field the query code extracts.
id: String,
/// Modification timestamp string as delivered in the response (for example
/// `2024-05-06T14:46:47.572046Z`). It must be present for deserialization
/// to succeed, but is not otherwise read by the code in this section.
modified: String,
}
/// Response entry for a single query of the batch.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct OSVQueryResult {
/// Vulnerabilities reported for the queried package. Being an `Option`,
/// this deserializes to `None` when the `vulns` key is absent or null.
vulns: Option<Vec<OSVVuln>>,
}
/// Top-level shape of a batch response body.
///
/// `results` is not optional, so a body without that key does not
/// deserialize into this type.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct OSVResponse {
/// One result per submitted query; consumers match these to queries by
/// position.
results: Vec<OSVQueryResult>,
}
/// URL of the OSV batch-query endpoint; each serialized batch is POSTed here.
const OSV_BATCH_URL: &str = "https://api.osv.dev/v1/querybatch";
/// Sends one batch of queries to the OSV batch endpoint and extracts the
/// vulnerability identifiers for each queried package.
///
/// # Returns
///
/// A vector with exactly `packages.len()` entries, positionally aligned with
/// `packages`:
///
/// * `Some(ids)` when the response entry carries a `vulns` list;
/// * `None` when the entry has no `vulns`, or when nothing could be learned
///   about the package.
///
/// Every failure mode (request serialization, transport error, a body that is
/// not a valid response) yields `None` for all packages instead of panicking.
/// A response with too few results is padded with `None` and one with too many
/// is truncated, so that concatenating several batches keeps every result
/// aligned with its package.
///
/// NOTE(review): a failed lookup is indistinguishable from "no known
/// vulnerabilities" for callers; the return type leaves no room to signal the
/// difference.
fn query_osv_batch(
    client: Arc<dyn UreqClient>,
    packages: &[OSVPackageQuery],
) -> Vec<Option<Vec<String>>> {
    // Shared fallback: nothing is known about any package of this batch.
    let unavailable = || vec![None; packages.len()];

    let batch_query = OSVQueryBatch {
        queries: packages.to_vec(),
    };
    let body = match serde_json::to_string(&batch_query) {
        Ok(body) => body,
        Err(_) => return unavailable(),
    };

    let response: Result<String, ureq::Error> = client.post(OSV_BATCH_URL, &body);
    let body_str = match response {
        Ok(body_str) => body_str,
        Err(_) => return unavailable(),
    };

    // The body comes from the network: treat anything unparsable like a
    // failed request rather than aborting the whole process.
    let osv_res: OSVResponse = match serde_json::from_str(&body_str) {
        Ok(parsed) => parsed,
        Err(_) => return unavailable(),
    };

    let mut ids: Vec<Option<Vec<String>>> = osv_res
        .results
        .into_iter()
        .map(|result| {
            result
                .vulns
                .map(|vuln_list| vuln_list.into_iter().map(|v| v.id).collect::<Vec<String>>())
        })
        .collect();

    // Callers flatten consecutive batches, so each batch must contribute
    // exactly one entry per query or later batches would shift out of place.
    ids.resize(packages.len(), None);
    ids
}
/// Looks up known vulnerabilities for `packages`, using an on-disk cache when
/// enabled.
///
/// The packages are converted to OSV queries and submitted in chunks of at
/// most 64 queries, with the chunks processed through rayon's parallel
/// iterators. The returned vector holds one entry per package: `Some(ids)`
/// with the reported vulnerability identifiers, or `None`.
///
/// # Arguments
///
/// * `client` - HTTP client abstraction used for the POST requests.
/// * `packages` - packages to query; an empty slice returns an empty result
///   without touching the network or the cache.
/// * `cache_config` - cache directory and maximum age. A duration equal to
///   `DURATION_0` bypasses the cache entirely.
/// * `log` - logging flag forwarded to the `logger!` macro.
///
/// # Caching
///
/// The serialized query list is hashed to form the cache file name
/// (`osv_batch_<hash>.json` inside `cache_config.directory`). A cache file is
/// used only if it is recent enough according to `path_within_duration`,
/// parses as the expected result type, and has one entry per package;
/// otherwise the API is queried and the cache file is (re)written on a
/// best-effort basis.
///
/// NOTE(review): results obtained while the service was unreachable are all
/// `None` and are cached like any other result.
///
/// # Panics
///
/// Panics if the query list cannot be serialized to JSON for the cache key.
pub(crate) fn query_osv_batches(
    client: Arc<dyn UreqClient>,
    packages: &[Package],
    cache_config: &CacheConfig,
    log: FlagLog,
) -> ResultDynError<Vec<Option<Vec<String>>>> {
    /// Upper bound on the number of queries sent in a single request.
    const OSV_BATCH_MAX: usize = 64;

    // Without this guard the chunk size below would be zero, and rayon's
    // `par_chunks` panics on a zero chunk size.
    if packages.is_empty() {
        return Ok(Vec::new());
    }

    let packages_osv: Vec<OSVPackageQuery> =
        packages.iter().map(OSVPackageQuery::from_package).collect();

    let query_api = || -> Vec<Option<Vec<String>>> {
        // Non-zero because `packages` (and thus `packages_osv`) is non-empty.
        let chunk_size = packages_osv.len().min(OSV_BATCH_MAX);
        packages_osv
            .par_chunks(chunk_size)
            .flat_map(|chunk| query_osv_batch(client.clone(), chunk))
            .collect()
    };

    if cache_config.duration == DURATION_0 {
        logger!(log, module_path!(), "Cache OSV batch disabled by duration");
        return Ok(query_api());
    }

    // The cache key is derived from the exact serialized queries, so any
    // change in the package set or versions maps to a different file.
    let json =
        serde_json::to_string(&packages_osv).expect("Failed to serialize packages_osv");
    let cache_key = hash_string(&json);
    let cache_fp = cache_config
        .directory
        .join(format!("osv_batch_{cache_key}"))
        .with_extension("json");

    if path_within_duration(&cache_fp, cache_config.duration) {
        logger!(
            log,
            module_path!(),
            "Loading OSV batch cache: {:?}",
            cache_fp
        );
        if let Ok(mut file) = File::open(&cache_fp) {
            let mut contents = String::new();
            if file.read_to_string(&mut contents).is_ok() {
                if let Ok(cached_results) =
                    serde_json::from_str::<Vec<Option<Vec<String>>>>(&contents)
                {
                    // A readable file with the wrong number of entries cannot
                    // be matched to `packages`; fall through and re-query.
                    if cached_results.len() == packages.len() {
                        return Ok(cached_results);
                    }
                }
            }
        }
    }

    let results = query_api();
    if let Ok(json) = serde_json::to_string(&results) {
        logger!(
            log,
            module_path!(),
            "Writing OSV batch cache: {:?}",
            cache_fp
        );
        // Best effort: failing to persist the cache must not fail the lookup.
        let _ = std::fs::write(&cache_fp, json);
    }
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ureq_client::UreqClientMock, util::path_cache};
    use std::collections::HashMap;

    /// Creates a mock client whose POST table holds a single entry, keyed by
    /// the OSV host, with `body` as the canned response text.
    fn client_replying(body: &str) -> Arc<UreqClientMock> {
        let mock_post = HashMap::from([("https://api.osv.dev".to_string(), body.to_string())]);
        Arc::new(UreqClientMock {
            mock_post: Some(mock_post),
            mock_get: None,
        })
    }

    /// Cache configuration with a zero duration, which makes
    /// `query_osv_batches` skip the cache.
    fn uncached_config() -> CacheConfig {
        CacheConfig::new(DURATION_0, path_cache(true).unwrap())
    }

    /// Shorthand for the expected `Some(vec![...])` of owned identifiers.
    fn ids(raw: &[&str]) -> Option<Vec<String>> {
        Some(raw.iter().map(|id| id.to_string()).collect())
    }

    #[test]
    fn test_osv_querybatch_a() {
        let client = client_replying("{\"results\":[{\"vulns\":[{\"id\":\"GHSA-34rf-p3r3-58x2\",\"modified\":\"2024-05-06T14:46:47.572046Z\"},{\"id\":\"GHSA-3f95-mxq2-2f63\",\"modified\":\"2024-04-10T22:19:39.095481Z\"},{\"id\":\"GHSA-48cq-79qq-6f7x\",\"modified\":\"2024-05-21T14:58:25.710902Z\"}]},{\"vulns\":[{\"id\":\"GHSA-pmv9-3xqp-8w42\",\"modified\":\"2024-09-18T19:36:03.377591Z\"}]}]}");
        let packages = [
            Package::from_name_version_durl("gradio", "4.0.0", None).unwrap(),
            Package::from_name_version_durl("mesop", "0.11.1", None).unwrap(),
        ];
        let cache_config = uncached_config();

        let results =
            query_osv_batches(client, &packages, &cache_config, FlagLog(false)).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(
            results[0],
            ids(&[
                "GHSA-34rf-p3r3-58x2",
                "GHSA-3f95-mxq2-2f63",
                "GHSA-48cq-79qq-6f7x",
            ])
        );
        assert_eq!(results[1], ids(&["GHSA-pmv9-3xqp-8w42"]));
    }

    #[test]
    fn test_osv_querybatch_cache_disabled() {
        let client = client_replying("{\"results\":[{\"vulns\":[{\"id\":\"GHSA-test-disabled\",\"modified\":\"2024-05-06T14:46:47.572046Z\"}]}]}");
        let packages = [Package::from_name_version_durl("test-package", "1.0.0", None).unwrap()];
        let cache_config = uncached_config();

        let results =
            query_osv_batches(client, &packages, &cache_config, FlagLog(false)).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0], ids(&["GHSA-test-disabled"]));
    }
}