use super::schema::*;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use url::Url;
/// Body of a POST to the OSV `/v1/query` endpoint.
///
/// Serialized `untagged`, so the JSON contains only the chosen variant's
/// fields (`{"commit": …}` or `{"version": …, "package": …}`) with no
/// variant name wrapped around them.
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum Request {
    /// Look up vulnerabilities associated with a specific commit.
    CommitQuery {
        commit: Commit,
    },
    /// Look up vulnerabilities affecting `version` of `package`.
    PackageQuery {
        version: Version,
        package: Package,
    },
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum Response {
Vulnerabilities { vulns: Vec<Vulnerability> },
NoResult(serde_json::Value),
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ApiError {
#[error("requested resource {0} not found")]
NotFound(String),
#[error("invalid request url: {0:?}")]
InvalidUrl(#[from] url::ParseError),
#[error("serialization failure: {0:?}")]
SerializationError(#[from] serde_json::Error),
#[error("request to osv endpoint failed: {0:?}")]
RequestFailed(reqwest::Error),
#[error("unexpected error has occurred")]
Unexpected,
}
impl From<reqwest::Error> for ApiError {
fn from(err: reqwest::Error) -> Self {
ApiError::RequestFailed(err)
}
}
/// Sends `q` to the OSV `/v1/query` endpoint and returns the matching vulnerabilities.
///
/// Returns `Ok(Some(vulns))` when the response body decodes as
/// `Response::Vulnerabilities`, and `Ok(None)` for any other JSON body.
///
/// # Errors
///
/// * [`ApiError::NotFound`] when the endpoint answers `404 Not Found`; the
///   message names the queried package or commit.
/// * [`ApiError::RequestFailed`] when the request cannot be sent or the body
///   cannot be decoded as JSON.
pub async fn query(q: &Request) -> Result<Option<Vec<Vulnerability>>, ApiError> {
    let response = reqwest::Client::new()
        .post("https://api.osv.dev/v1/query")
        .json(q)
        .send()
        .await?;

    if response.status() == StatusCode::NOT_FOUND {
        let what = match q {
            Request::PackageQuery { package, .. } => format!("package - `{}`", package.name),
            Request::CommitQuery { commit } => format!("commit - `{}`", commit),
        };
        return Err(ApiError::NotFound(what));
    }

    // Every status other than 404 (including non-2xx) is decoded as JSON; a body
    // without usable `vulns` lands in `Response::NoResult` and is reported as `None`.
    match response.json::<Response>().await? {
        Response::Vulnerabilities { vulns } => Ok(Some(vulns)),
        Response::NoResult(_) => Ok(None),
    }
}
/// Queries OSV for vulnerabilities affecting `version` of package `name` in `ecosystem`.
///
/// Thin wrapper that builds a [`Request::PackageQuery`] (with no purl) and
/// delegates to [`query`]; results and errors are exactly those of [`query`].
pub async fn query_package(
    name: &str,
    version: &str,
    ecosystem: Ecosystem,
) -> Result<Option<Vec<Vulnerability>>, ApiError> {
    let version = Version::from(version);
    let package = Package {
        name: String::from(name),
        ecosystem,
        purl: None,
    };
    query(&Request::PackageQuery { version, package }).await
}
/// Queries OSV for vulnerabilities associated with the commit hash `commit`.
///
/// Thin wrapper that builds a [`Request::CommitQuery`] and delegates to
/// [`query`]; results and errors are exactly those of [`query`].
pub async fn query_commit(commit: &str) -> Result<Option<Vec<Vulnerability>>, ApiError> {
    let request = Request::CommitQuery {
        commit: Commit::from(commit),
    };
    query(&request).await
}
/// Fetches a single vulnerability record by its OSV identifier.
///
/// `vuln_id` is appended to `https://api.osv.dev/v1/vulns/` as one
/// percent-encoded path segment. Characters such as `/`, `?` or `#`, a leading
/// `..`, or even a full URL in `vuln_id` therefore cannot redirect the request
/// to another path or host.
///
/// # Errors
///
/// * [`ApiError::NotFound`] when the endpoint answers `404 Not Found`.
/// * [`ApiError::RequestFailed`] when the request cannot be sent, the endpoint
///   answers with any other non-success status, or the body does not decode as
///   a [`Vulnerability`].
/// * [`ApiError::InvalidUrl`] / [`ApiError::Unexpected`] only if the hard-coded
///   base URL were invalid or could not take path segments.
pub async fn vulnerability(vuln_id: &str) -> Result<Vulnerability, ApiError> {
    let mut url = Url::parse("https://api.osv.dev/v1/vulns/")?;
    // `Url::join` would interpret `vuln_id` as a relative reference (or an
    // absolute URL); pushing a segment escapes it instead. `pop_if_empty` drops
    // the empty segment produced by the base URL's trailing slash.
    url.path_segments_mut()
        .map_err(|()| ApiError::Unexpected)?
        .pop_if_empty()
        .push(vuln_id);

    let res = reqwest::get(url).await?;
    if res.status() == StatusCode::NOT_FOUND {
        return Err(ApiError::NotFound(vuln_id.to_string()));
    }
    // Surface other HTTP failures as status errors rather than as a JSON decode
    // failure of whatever error body the server returned.
    let vuln = res.error_for_status()?.json::<Vulnerability>().await?;
    Ok(vuln)
}
#[cfg(test)]
mod tests {
    //! These tests call the live `api.osv.dev` service, so they need network
    //! access and their outcome depends on the data OSV currently serves.
    use super::*;

    /// Commit the commit-query tests expect OSV to report vulnerabilities for.
    const COMMIT: &str = "6879efc2c1596d11a6a6ad296f80063b558d5e0f";

    /// The PyPI `jinja2` 2.4.1 query shared by the package tests.
    fn jinja2_query() -> Request {
        Request::PackageQuery {
            version: Version::from("2.4.1"),
            package: Package {
                name: String::from("jinja2"),
                ecosystem: Ecosystem::PyPI,
                purl: None,
            },
        }
    }

    #[tokio::test]
    async fn test_package_query() {
        let found = query(&jinja2_query()).await.unwrap();
        assert!(found.is_some());
    }

    #[tokio::test]
    async fn test_package_query_wrapper() {
        let found = query_package("jinja2", "2.4.1", Ecosystem::PyPI)
            .await
            .unwrap();
        assert!(found.is_some());
    }

    #[tokio::test]
    async fn test_invalid_packagename() {
        let bogus_name = "asdfasdlfkjlksdjfklsdjfklsdjfklds";
        let found = query_package(bogus_name, "0.0.1", Ecosystem::PyPI)
            .await
            .unwrap();
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_commit_query() {
        let request = Request::CommitQuery {
            commit: Commit::from(COMMIT),
        };
        let found = query(&request).await.unwrap();
        assert!(found.is_some());
    }

    #[tokio::test]
    async fn test_commit_query_wrapper() {
        let found = query_commit(COMMIT).await.unwrap();
        assert!(found.is_some());
    }

    #[tokio::test]
    async fn test_invalid_commit() {
        let found = query_commit("zzzz").await.unwrap();
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_vulnerability() {
        assert!(vulnerability("OSV-2020-484").await.is_ok());
    }

    #[tokio::test]
    async fn test_get_missing_cve() {
        assert!(vulnerability("CVE-2014-0161").await.is_err());
    }
}