use super::response::{OsvBatchRequest, OsvBatchResponse, OsvQuery};
use crate::error::{EnrichmentErrorKind, Result, SbomDiffError};
use reqwest::blocking::Client;
use std::time::Duration;
/// Connection and batching settings for [`OsvClient`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OsvClientConfig {
    /// Base URL of the OSV API; endpoint paths such as `/v1/querybatch` are appended to it.
    pub api_base: String,
    /// Timeout installed on the underlying HTTP client for every request.
    pub timeout: Duration,
    /// How many times a failed `/v1/querybatch` request may be retried after the first attempt.
    pub max_retries: u8,
    /// Maximum number of queries sent in a single `/v1/querybatch` request.
    pub batch_size: usize,
}
impl Default for OsvClientConfig {
fn default() -> Self {
Self {
api_base: "https://api.osv.dev".to_string(),
timeout: Duration::from_secs(30),
max_retries: 3,
batch_size: 1000, }
}
}
/// Blocking client for the OSV vulnerability API.
#[derive(Debug)]
pub struct OsvClient {
    /// Shared HTTP client configured with the timeout and user agent from `new`.
    client: Client,
    /// Settings this client was constructed with.
    config: OsvClientConfig,
}
/// Wraps a transport-level `reqwest` failure as an enrichment `NetworkError`,
/// keeping the rendered reqwest error text as the detail and `msg` as the context.
fn network_error(msg: &str, err: &reqwest::Error) -> SbomDiffError {
    let detail = err.to_string();
    let kind = EnrichmentErrorKind::NetworkError(detail);
    SbomDiffError::enrichment(msg, kind)
}
/// Builds an enrichment `ApiError` carrying `msg`, always under the "API request" context.
fn api_error(msg: impl Into<String>) -> SbomDiffError {
    let detail: String = msg.into();
    SbomDiffError::enrichment("API request", EnrichmentErrorKind::ApiError(detail))
}
impl OsvClient {
pub fn new(config: OsvClientConfig) -> Result<Self> {
let client = Client::builder()
.timeout(config.timeout)
.user_agent(concat!(
env!("CARGO_PKG_NAME"),
"/",
env!("CARGO_PKG_VERSION")
))
.build()
.map_err(|e| network_error("Failed to create HTTP client", &e))?;
Ok(Self { client, config })
}
pub fn health_check(&self) -> Result<bool> {
let url = format!("{}/v1/vulns/OSV-2020-1", self.config.api_base);
let response = self
.client
.get(&url)
.send()
.map_err(|e| network_error("Health check request failed", &e))?;
Ok(response.status().is_success() || response.status().as_u16() == 404)
}
pub fn query_batch(&self, queries: &[OsvQuery]) -> Result<Vec<OsvBatchResponse>> {
if queries.is_empty() {
return Ok(vec![]);
}
let mut results = Vec::new();
for chunk in queries.chunks(self.config.batch_size) {
let response = self.query_batch_internal(chunk)?;
results.push(response);
}
Ok(results)
}
fn query_batch_internal(&self, queries: &[OsvQuery]) -> Result<OsvBatchResponse> {
let url = format!("{}/v1/querybatch", self.config.api_base);
let request_body = OsvBatchRequest {
queries: queries.to_vec(),
};
let mut last_error = None;
for attempt in 0..=self.config.max_retries {
if attempt > 0 {
let delay = Duration::from_secs(1 << (attempt - 1));
std::thread::sleep(delay);
tracing::debug!("Retry attempt {} after {:?}", attempt, delay);
}
match self.send_batch_request(&url, &request_body) {
Ok(response) => return Ok(response),
Err(e) => {
tracing::debug!("Batch request attempt {} failed: {}", attempt + 1, e);
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| api_error("Unknown error")))
}
fn send_batch_request(
&self,
url: &str,
request_body: &OsvBatchRequest,
) -> Result<OsvBatchResponse> {
let response = self
.client
.post(url)
.json(request_body)
.send()
.map_err(|e| network_error("Failed to send batch request", &e))?;
let status = response.status();
if !status.is_success() {
let body = response.text().unwrap_or_default();
return Err(api_error(format!(
"OSV API returned error status {}: {}",
status.as_u16(),
body
)));
}
let batch_response: OsvBatchResponse = response.json().map_err(|e| {
SbomDiffError::enrichment(
"parsing response",
EnrichmentErrorKind::InvalidResponse(e.to_string()),
)
})?;
Ok(batch_response)
}
pub fn get_vulnerability(
&self,
vuln_id: &str,
) -> Result<Option<super::response::OsvVulnerability>> {
let url = format!("{}/v1/vulns/{}", self.config.api_base, vuln_id);
let response = self
.client
.get(&url)
.send()
.map_err(|e| network_error("Failed to fetch vulnerability", &e))?;
if response.status().as_u16() == 404 {
return Ok(None);
}
if !response.status().is_success() {
return Err(api_error(format!(
"OSV API returned error status {}",
response.status().as_u16()
)));
}
let vuln = response.json().map_err(|e| {
SbomDiffError::enrichment(
"parsing vulnerability",
EnrichmentErrorKind::InvalidResponse(e.to_string()),
)
})?;
Ok(Some(vuln))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default, not just the endpoint and batch size.
    #[test]
    fn test_client_config_defaults() {
        let config = OsvClientConfig::default();
        assert_eq!(config.api_base, "https://api.osv.dev");
        assert_eq!(config.timeout, Duration::from_secs(30));
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.batch_size, 1000);
    }

    /// Serialising a purl-based query keeps the package name in the JSON.
    #[test]
    fn test_query_construction() {
        let query = OsvQuery::from_purl("pkg:npm/lodash@4.17.21".to_string());
        let json = serde_json::to_string(&query).unwrap();
        assert!(json.contains("lodash"));
    }

    /// Empty input must not touch the network, even with a degenerate batch size.
    #[test]
    fn test_query_batch_empty_input() {
        let config = OsvClientConfig {
            batch_size: 0,
            ..OsvClientConfig::default()
        };
        let client = OsvClient::new(config).expect("HTTP client should build");
        let responses = client
            .query_batch(&[])
            .expect("empty batch should not fail");
        assert!(responses.is_empty());
    }
}