use crate::config::OssIndexConfig;
use crate::error::AdvisoryError;
use crate::models::{
Advisory, Affected, Event, Package, Range, RangeType, Reference, ReferenceType, Severity,
};
use crate::purl::Purl;
use anyhow::Result;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, warn};
/// Base URL of the Sonatype OSS Index REST API (v3).
const API_BASE_URL: &str = "https://ossindex.sonatype.org/api/v3";
/// Largest number of coordinates placed in a single `component-report` request.
const MAX_BATCH_SIZE: usize = 128;
/// Number of chunk requests allowed in flight at once unless overridden.
const DEFAULT_CONCURRENCY: usize = 4;
/// Overall per-request timeout handed to `reqwest::ClientBuilder::timeout`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Connection-establishment timeout handed to `reqwest::ClientBuilder::connect_timeout`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// JSON body sent to the `component-report` endpoint.
#[derive(Serialize, Debug)]
struct ComponentReportRequest {
    /// Package URLs (purls) to look up in this single request.
    coordinates: Vec<String>,
}
/// One per-component entry of an OSS Index `component-report` response.
#[derive(Deserialize, Debug)]
pub struct ComponentReport {
    /// The coordinates (purl) this entry describes.
    pub coordinates: String,
    /// Optional component description; `None` when the field is absent.
    #[serde(default)]
    pub description: Option<String>,
    /// Optional component reference (presumably a URL); `None` when absent.
    #[serde(default)]
    pub reference: Option<String>,
    /// Vulnerabilities reported for the component; an absent field yields an empty list.
    #[serde(default)]
    pub vulnerabilities: Vec<OssVulnerability>,
}
/// A single vulnerability entry inside a [`ComponentReport`].
///
/// `title`, `description` and `reference` fall back to an empty string when the
/// API omits them. Without that default, one sparse record would make the whole
/// batch response fail to deserialize and every component in the chunk would
/// lose its results.
#[derive(Debug, Deserialize, Clone)]
pub struct OssVulnerability {
    /// OSS Index identifier of the vulnerability.
    pub id: String,
    /// Human-readable name; may itself be a CVE identifier.
    #[serde(rename = "displayName")]
    pub display_name: Option<String>,
    /// Short summary line.
    #[serde(default)]
    pub title: String,
    /// Long-form description.
    #[serde(default)]
    pub description: String,
    /// CVSS base score, when provided.
    #[serde(rename = "cvssScore")]
    pub cvss_score: Option<f64>,
    /// CVSS vector string, when provided.
    #[serde(rename = "cvssVector")]
    pub cvss_vector: Option<String>,
    /// Associated CWE identifier, when provided.
    #[serde(default)]
    pub cwe: Option<String>,
    /// Associated CVE identifier, when provided (may lack the `CVE-` prefix).
    #[serde(default)]
    pub cve: Option<String>,
    /// OSS Index page for this vulnerability.
    #[serde(default)]
    pub reference: String,
    /// Affected version ranges as reported by OSS Index.
    #[serde(rename = "versionRanges")]
    pub version_ranges: Option<Vec<String>>,
    /// Additional external reference URLs.
    #[serde(rename = "externalReferences")]
    pub external_references: Option<Vec<String>>,
}
/// Advisory source backed by the Sonatype OSS Index v3 REST API.
pub struct OssIndexSource {
    /// Credentials and batching options used for every request.
    config: OssIndexConfig,
    /// HTTP client wrapped in retry middleware.
    client: ClientWithMiddleware,
    /// Limits how many chunk requests run concurrently.
    semaphore: Arc<Semaphore>,
}
impl OssIndexSource {
    /// Creates a source from `config`. When `config` is `None`, credentials are
    /// read from the `OSSINDEX_USER` / `OSSINDEX_TOKEN` environment variables.
    ///
    /// The HTTP client applies `REQUEST_TIMEOUT` / `CONNECT_TIMEOUT` and retries
    /// transient failures up to three times with exponential backoff.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `reqwest` client cannot be built.
    pub fn new(config: Option<OssIndexConfig>) -> Result<Self> {
        let config = config.unwrap_or_else(Self::config_from_env);
        // Propagate builder failures. A `Client::default()` fallback would discard the
        // timeouts configured here, and it can panic itself.
        let raw_client = Client::builder()
            .timeout(REQUEST_TIMEOUT)
            .connect_timeout(CONNECT_TIMEOUT)
            .build()
            .map_err(|e| Self::fetch_error(format!("Failed to build HTTP client: {}", e)))?;
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(raw_client)
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();
        Ok(Self {
            client,
            semaphore: Arc::new(Semaphore::new(DEFAULT_CONCURRENCY)),
            config,
        })
    }

    /// Builds a configuration from `OSSINDEX_USER` / `OSSINDEX_TOKEN`.
    ///
    /// Empty variables are treated as unset, so blank Basic-auth credentials are
    /// never sent.
    fn config_from_env() -> OssIndexConfig {
        let non_empty = |name: &str| env::var(name).ok().filter(|value| !value.is_empty());
        OssIndexConfig {
            user: non_empty("OSSINDEX_USER"),
            token: non_empty("OSSINDEX_TOKEN"),
            // Same value as MAX_BATCH_SIZE.
            batch_size: 128,
        }
    }

    /// Like [`OssIndexSource::new`], but allows up to `concurrency` chunk requests
    /// in flight at once.
    ///
    /// A `concurrency` of `0` is treated as `1`. A zero-permit semaphore would make
    /// every query wait forever.
    pub fn with_concurrency(config: Option<OssIndexConfig>, concurrency: usize) -> Result<Self> {
        let mut source = Self::new(config)?;
        source.semaphore = Arc::new(Semaphore::new(concurrency.max(1)));
        Ok(source)
    }

    /// Queries OSS Index for `purls` and converts the findings into advisories.
    ///
    /// Advisories are de-duplicated by id. A vulnerability reported for several
    /// components becomes one advisory with several affected packages.
    pub async fn query_advisories(&self, purls: &[String]) -> Result<Vec<Advisory>> {
        let reports = self.query_batch(purls).await?;
        Ok(self.convert_reports_to_advisories(&reports))
    }

    /// Queries OSS Index for `purls` and returns the raw component reports.
    pub async fn query_components(&self, purls: &[String]) -> Result<Vec<ComponentReport>> {
        self.query_batch(purls).await
    }

    /// Chunk size for `component-report` requests.
    ///
    /// This is `config.batch_size` clamped to `1..=MAX_BATCH_SIZE`. A zero or
    /// unrepresentable value falls back to `MAX_BATCH_SIZE`; zero must be avoided
    /// because `slice::chunks(0)` panics.
    fn effective_batch_size(&self) -> usize {
        match usize::try_from(self.config.batch_size) {
            Ok(size) if size > 0 => size.min(MAX_BATCH_SIZE),
            _ => MAX_BATCH_SIZE,
        }
    }

    /// Splits `purls` into chunks and queries them concurrently, bounded by the semaphore.
    ///
    /// Reports are returned in chunk order. On the first failing chunk, the
    /// remaining in-flight tasks are aborted and that error is returned.
    async fn query_batch(&self, purls: &[String]) -> Result<Vec<ComponentReport>> {
        if purls.is_empty() {
            return Ok(Vec::new());
        }
        let handles: Vec<_> = purls
            .chunks(self.effective_batch_size())
            .map(|chunk| {
                let coordinates = chunk.to_vec();
                let client = self.client.clone();
                let config = self.config.clone();
                let semaphore = Arc::clone(&self.semaphore);
                tokio::spawn(async move {
                    // Hold the permit for the whole request to cap concurrency.
                    let _permit = semaphore
                        .acquire()
                        .await
                        .map_err(|e| Self::fetch_error(format!("Semaphore error: {}", e)))?;
                    Self::query_chunk(&client, &config, &coordinates).await
                })
            })
            .collect();

        let mut all_reports = Vec::with_capacity(purls.len());
        let mut pending = handles.into_iter();
        while let Some(handle) = pending.next() {
            let failure = match handle.await {
                Ok(Ok(reports)) => {
                    all_reports.extend(reports);
                    continue;
                }
                Ok(Err(e)) => {
                    warn!("OSS Index batch query failed: {}", e);
                    e
                }
                Err(e) => {
                    warn!("OSS Index task panicked: {}", e);
                    Self::fetch_error(format!("Task panicked: {}", e)).into()
                }
            };
            // The overall query has failed, so stop the chunks that are still running
            // instead of leaving them detached.
            for remaining in pending.by_ref() {
                remaining.abort();
            }
            return Err(failure);
        }
        Ok(all_reports)
    }

    /// Sends one `component-report` request for `purls`.
    ///
    /// Basic auth is attached only when both `config.user` and `config.token` are set.
    ///
    /// # Errors
    ///
    /// Returns a `SourceFetch` error for transport failures, HTTP 401 / 429 or any
    /// other non-success status, and unparsable response bodies.
    async fn query_chunk(
        client: &ClientWithMiddleware,
        config: &OssIndexConfig,
        purls: &[String],
    ) -> Result<Vec<ComponentReport>> {
        let url = format!("{}/component-report", API_BASE_URL);
        let payload = serde_json::to_string(&ComponentReportRequest {
            coordinates: purls.to_vec(),
        })?;
        let mut request = client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Accept", "application/json");
        if let (Some(user), Some(token)) = (&config.user, &config.token) {
            request = request.basic_auth(user, Some(token));
        }
        let response = request
            .body(payload)
            .send()
            .await
            .map_err(|e| Self::fetch_error(format!("Request failed: {}", e)))?;

        let status = response.status();
        if status == reqwest::StatusCode::UNAUTHORIZED {
            return Err(Self::fetch_error(
                "Authentication required. Set OSSINDEX_USER and OSSINDEX_TOKEN environment variables.",
            )
            .into());
        }
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(Self::fetch_error("Rate limited by OSS Index. Please retry later.").into());
        }
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(Self::fetch_error(format!("HTTP {}: {}", status, body)).into());
        }

        let reports: Vec<ComponentReport> = response
            .json()
            .await
            .map_err(|e| Self::fetch_error(format!("Failed to parse response: {}", e)))?;
        debug!("OSS Index returned {} reports", reports.len());
        Ok(reports)
    }

    /// Converts component reports into advisories, de-duplicated by advisory id.
    ///
    /// When the same advisory appears for another component, that component is
    /// appended as an extra affected package. It is skipped if the same purl is
    /// already listed.
    fn convert_reports_to_advisories(&self, reports: &[ComponentReport]) -> Vec<Advisory> {
        let mut advisories: Vec<Advisory> = Vec::new();
        let mut seen_ids: HashSet<String> = HashSet::new();
        for report in reports {
            for vuln in &report.vulnerabilities {
                let advisory_id = self.generate_advisory_id(vuln);
                if seen_ids.insert(advisory_id.clone()) {
                    advisories.push(self.convert_vulnerability(vuln, &report.coordinates));
                    continue;
                }
                if let Some(affected) = self.extract_affected(&report.coordinates, vuln) {
                    if let Some(existing) = advisories.iter_mut().find(|a| a.id == advisory_id) {
                        let already_listed = existing
                            .affected
                            .iter()
                            .any(|entry| entry.package.purl == affected.package.purl);
                        if !already_listed {
                            existing.affected.push(affected);
                        }
                    }
                }
            }
        }
        advisories
    }

    /// Picks the advisory id for `vuln`.
    ///
    /// Preference order: the normalized `cve` field, a `CVE-` display name, a CVE
    /// taken from the reference URL, and finally the OSS Index id.
    fn generate_advisory_id(&self, vuln: &OssVulnerability) -> String {
        vuln.cve
            .as_deref()
            .and_then(Self::normalize_cve)
            .or_else(|| {
                vuln.display_name
                    .as_ref()
                    .filter(|name| name.starts_with("CVE-"))
                    .cloned()
            })
            .or_else(|| Self::extract_cve_from_url(&vuln.reference))
            .unwrap_or_else(|| vuln.id.clone())
    }

    /// Trims `raw` and ensures it carries the `CVE-` prefix. Blank input yields `None`.
    fn normalize_cve(raw: &str) -> Option<String> {
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            None
        } else if trimmed.starts_with("CVE-") {
            Some(trimmed.to_string())
        } else {
            Some(format!("CVE-{}", trimmed))
        }
    }

    /// Returns the last path segment of `url` if it starts with `CVE-`.
    ///
    /// Trailing slashes are ignored.
    fn extract_cve_from_url(url: &str) -> Option<String> {
        url.trim_end_matches('/')
            .rsplit('/')
            .next()
            .filter(|segment| segment.starts_with("CVE-"))
            .map(str::to_string)
    }

    /// Builds an [`Advisory`] for `vuln` as reported for the component `coordinates`.
    ///
    /// The advisory includes aliases, references and OSS Index specific metadata
    /// (CVSS, CWE, source).
    fn convert_vulnerability(&self, vuln: &OssVulnerability, coordinates: &str) -> Advisory {
        let advisory_id = self.generate_advisory_id(vuln);
        let affected: Vec<Affected> = self.extract_affected(coordinates, vuln).into_iter().collect();

        let mut aliases = Vec::new();
        if let Some(cve) = vuln.cve.as_deref().and_then(Self::normalize_cve) {
            aliases.push(cve);
        }
        // Keep the native OSS Index id discoverable when the advisory is keyed by something else.
        if advisory_id != vuln.id && !aliases.contains(&vuln.id) {
            aliases.push(vuln.id.clone());
        }

        let mut references = Vec::new();
        if !vuln.reference.is_empty() {
            references.push(Reference {
                reference_type: ReferenceType::Advisory,
                url: vuln.reference.clone(),
            });
        }
        references.extend(
            vuln.external_references
                .iter()
                .flatten()
                .filter(|url| !url.is_empty())
                .map(|url| Reference {
                    reference_type: ReferenceType::Web,
                    url: url.clone(),
                }),
        );

        let mut db_specific = serde_json::Map::new();
        if let Some(score) = vuln.cvss_score {
            db_specific.insert("cvss_score".to_string(), serde_json::json!(score));
            db_specific.insert(
                "severity".to_string(),
                serde_json::json!(Self::cvss_to_severity(score)),
            );
        }
        if let Some(vector) = vuln.cvss_vector.as_deref().filter(|v| !v.is_empty()) {
            db_specific.insert("cvss_vector".to_string(), serde_json::json!(vector));
        }
        if let Some(cwe) = vuln.cwe.as_deref().filter(|c| !c.is_empty()) {
            db_specific.insert("cwe_ids".to_string(), serde_json::json!([cwe]));
        }
        db_specific.insert("source".to_string(), serde_json::json!("ossindex"));

        Advisory {
            id: advisory_id,
            summary: Some(vuln.title.clone()),
            details: Some(vuln.description.clone()),
            affected,
            references,
            published: None,
            modified: None,
            aliases: if aliases.is_empty() {
                None
            } else {
                Some(aliases)
            },
            database_specific: Some(serde_json::Value::Object(db_specific)),
            enrichment: None,
        }
    }

    /// Builds the affected-package entry for `coordinates`.
    ///
    /// Returns `None` when the coordinates are not a parseable purl. Version ranges
    /// that fail to parse are skipped.
    fn extract_affected(&self, coordinates: &str, vuln: &OssVulnerability) -> Option<Affected> {
        let purl = Purl::parse(coordinates).ok()?;
        let ranges = vuln
            .version_ranges
            .as_ref()
            .map(|ranges| {
                ranges
                    .iter()
                    .filter_map(|r| Self::parse_version_range(r))
                    .collect()
            })
            .unwrap_or_default();
        Some(Affected {
            package: Package {
                ecosystem: purl.ecosystem(),
                name: purl.name.clone(),
                purl: Some(coordinates.to_string()),
            },
            ranges,
            versions: Vec::new(),
            ecosystem_specific: None,
            database_specific: None,
        })
    }

    /// Parses one OSS Index version range into a semver [`Range`].
    ///
    /// Supported forms:
    /// - bare `1.0.0` → last affected 1.0.0
    /// - exact `[1.0.0]` → introduced and last affected 1.0.0
    /// - intervals such as `[a,b)`, `(,b]` and `[a,)`
    ///
    /// Blank input and other shapes (for example multi-interval unions) yield `None`.
    fn parse_version_range(range: &str) -> Option<Range> {
        let range = range.trim();
        if range.is_empty() {
            return None;
        }
        let semver = |events: Vec<Event>| Range {
            range_type: RangeType::Semver,
            events,
            repo: None,
        };
        if !range.contains(',') && !range.starts_with('[') && !range.starts_with('(') {
            return Some(semver(vec![Event::LastAffected(range.to_string())]));
        }
        let start_inclusive = range.starts_with('[');
        let end_inclusive = range.ends_with(']');
        let inner = range
            .trim_start_matches(['[', '('])
            .trim_end_matches([']', ')']);
        let parts: Vec<&str> = inner.split(',').map(str::trim).collect();
        let (start, end) = match parts.as_slice() {
            // `[1.0.0]` pins exactly one version.
            [exact] if start_inclusive && end_inclusive && !exact.is_empty() => {
                let version = (*exact).to_string();
                return Some(semver(vec![
                    Event::Introduced(version.clone()),
                    Event::LastAffected(version),
                ]));
            }
            [start, end] => (*start, *end),
            _ => return None,
        };
        // The event model here has no exclusive lower bound (OSV defines none), so
        // `(1.0.0,` is approximated by introducing at 1.0.0 itself, exactly like
        // `[1.0.0,`. An empty lower bound means "from the beginning", spelled "0".
        let introduced = if start.is_empty() { "0" } else { start };
        let mut events = vec![Event::Introduced(introduced.to_string())];
        if !end.is_empty() {
            events.push(if end_inclusive {
                Event::LastAffected(end.to_string())
            } else {
                Event::Fixed(end.to_string())
            });
        }
        Some(semver(events))
    }

    /// Maps a CVSS score to a severity label.
    ///
    /// Thresholds are 9.0 / 7.0 / 4.0 / >0. Zero, negative and NaN scores map to "NONE".
    fn cvss_to_severity(score: f64) -> &'static str {
        match score {
            s if s >= 9.0 => "CRITICAL",
            s if s >= 7.0 => "HIGH",
            s if s >= 4.0 => "MEDIUM",
            s if s > 0.0 => "LOW",
            _ => "NONE",
        }
    }

    /// Maps a CVSS score to a [`Severity`] via `Severity::from_cvss_score`.
    pub fn score_to_severity(score: f64) -> Severity {
        Severity::from_cvss_score(score)
    }

    /// Identifier of this advisory source.
    pub fn name(&self) -> &'static str {
        "ossindex"
    }

    /// Builds the `SourceFetch` error variant tagged with this source's name.
    fn fetch_error(message: impl Into<String>) -> AdvisoryError {
        AdvisoryError::SourceFetch {
            source_name: "ossindex".to_string(),
            message: message.into(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_version_range_standard() {
        let range = OssIndexSource::parse_version_range("[1.0.0,2.0.0)");
        assert!(range.is_some());
        let range = range.unwrap();
        assert_eq!(range.range_type, RangeType::Semver);
        assert_eq!(range.events.len(), 2);
        assert!(matches!(&range.events[0], Event::Introduced(v) if v == "1.0.0"));
        assert!(matches!(&range.events[1], Event::Fixed(v) if v == "2.0.0"));
    }

    #[test]
    fn test_parse_version_range_inclusive_end() {
        let range = OssIndexSource::parse_version_range("[1.0.0,2.0.0]");
        assert!(range.is_some());
        let range = range.unwrap();
        assert_eq!(range.events.len(), 2);
        assert!(matches!(&range.events[0], Event::Introduced(v) if v == "1.0.0"));
        assert!(matches!(&range.events[1], Event::LastAffected(v) if v == "2.0.0"));
    }

    #[test]
    fn test_parse_version_range_open_start() {
        let range = OssIndexSource::parse_version_range("(,1.0.0)");
        assert!(range.is_some());
        let range = range.unwrap();
        assert_eq!(range.events.len(), 2);
        assert!(matches!(&range.events[0], Event::Introduced(v) if v == "0"));
        assert!(matches!(&range.events[1], Event::Fixed(v) if v == "1.0.0"));
    }

    #[test]
    fn test_parse_version_range_open_end() {
        let range = OssIndexSource::parse_version_range("[1.0.0,)");
        assert!(range.is_some());
        let range = range.unwrap();
        assert_eq!(range.events.len(), 1);
        assert!(matches!(&range.events[0], Event::Introduced(v) if v == "1.0.0"));
    }

    #[test]
    fn test_parse_version_range_exact() {
        let range = OssIndexSource::parse_version_range("1.0.0");
        assert!(range.is_some());
        let range = range.unwrap();
        assert_eq!(range.events.len(), 1);
        assert!(matches!(&range.events[0], Event::LastAffected(v) if v == "1.0.0"));
    }

    #[test]
    fn test_parse_version_range_exclusive_start() {
        // An exclusive lower bound is approximated by introducing at the bound itself.
        let range = OssIndexSource::parse_version_range("(1.0.0,2.0.0)").unwrap();
        assert_eq!(range.events.len(), 2);
        assert!(matches!(&range.events[0], Event::Introduced(v) if v == "1.0.0"));
        assert!(matches!(&range.events[1], Event::Fixed(v) if v == "2.0.0"));
    }

    #[test]
    fn test_parse_version_range_rejects_blank_and_malformed() {
        assert!(OssIndexSource::parse_version_range("").is_none());
        assert!(OssIndexSource::parse_version_range("   ").is_none());
        assert!(OssIndexSource::parse_version_range("[1.0,2.0,3.0]").is_none());
    }

    #[test]
    fn test_cvss_to_severity() {
        assert_eq!(OssIndexSource::cvss_to_severity(9.5), "CRITICAL");
        assert_eq!(OssIndexSource::cvss_to_severity(7.5), "HIGH");
        assert_eq!(OssIndexSource::cvss_to_severity(5.0), "MEDIUM");
        assert_eq!(OssIndexSource::cvss_to_severity(2.0), "LOW");
        assert_eq!(OssIndexSource::cvss_to_severity(0.0), "NONE");
    }

    #[test]
    fn test_cvss_to_severity_boundaries() {
        // Each threshold is inclusive on its lower edge.
        assert_eq!(OssIndexSource::cvss_to_severity(9.0), "CRITICAL");
        assert_eq!(OssIndexSource::cvss_to_severity(7.0), "HIGH");
        assert_eq!(OssIndexSource::cvss_to_severity(4.0), "MEDIUM");
        assert_eq!(OssIndexSource::cvss_to_severity(0.1), "LOW");
    }

    #[test]
    fn test_extract_cve_from_url() {
        assert_eq!(
            OssIndexSource::extract_cve_from_url(
                "https://ossindex.sonatype.org/vulnerability/CVE-2021-23337"
            ),
            Some("CVE-2021-23337".to_string())
        );
        assert_eq!(
            OssIndexSource::extract_cve_from_url(
                "https://ossindex.sonatype.org/vulnerability/sonatype-2020-1234"
            ),
            None
        );
    }

    #[test]
    fn test_extract_cve_from_bare_identifier() {
        assert_eq!(
            OssIndexSource::extract_cve_from_url("CVE-2020-0001"),
            Some("CVE-2020-0001".to_string())
        );
        assert_eq!(OssIndexSource::extract_cve_from_url(""), None);
    }

    #[test]
    fn test_purl_integration() {
        let purl = Purl::new("npm", "lodash").with_version("4.17.20");
        assert_eq!(purl.to_string(), "pkg:npm/lodash@4.17.20");
        assert_eq!(purl.ecosystem(), "npm");
        assert_eq!(purl.name, "lodash");
        assert_eq!(purl.version, Some("4.17.20".to_string()));
    }

    #[test]
    fn test_score_to_severity() {
        assert_eq!(OssIndexSource::score_to_severity(9.5), Severity::Critical);
        assert_eq!(OssIndexSource::score_to_severity(7.5), Severity::High);
        assert_eq!(OssIndexSource::score_to_severity(5.0), Severity::Medium);
        assert_eq!(OssIndexSource::score_to_severity(2.0), Severity::Low);
        assert_eq!(OssIndexSource::score_to_severity(0.0), Severity::None);
    }
}