use sha2::{Digest, Sha256};
use crate::error::AcdpError;
use crate::safe_http::SsrfPolicy;
use crate::types::data_ref::{DataRef, Location};
use crate::types::primitives::ContentHash;
/// Default cap, in bytes, on a fetched response body: 16 MiB.
///
/// This is the limit [`HttpsDataRefFetcher::new`] passes to
/// [`HttpsDataRefFetcher::with_max_bytes`].
pub const DEFAULT_MAX_BYTES: u64 = 16 * 1024 * 1024;
/// Backend capable of retrieving the bytes that a [`Location`] points at.
///
/// Implementors must be `Send + Sync`, and the future returned by
/// [`fetch`](DataRefFetcher::fetch) must be `Send`.
pub trait DataRefFetcher: Send + Sync {
/// Fetches the content found at `location`.
///
/// The returned future resolves to the fetched bytes on success, or to an
/// [`AcdpError`] describing why the content could not be retrieved.
fn fetch(
&self,
location: &Location,
) -> impl std::future::Future<Output = Result<Vec<u8>, AcdpError>> + Send;
}
/// [`DataRefFetcher`] that retrieves URI locations with an HTTP GET issued
/// through `reqwest`, after checking the URI against an [`SsrfPolicy`] and
/// while capping the size of the response body.
pub struct HttpsDataRefFetcher {
/// HTTP client produced by `build_data_ref_http_client` for the current policy.
http: reqwest::Client,
/// Policy every URI is checked against before a request is sent.
ssrf_policy: SsrfPolicy,
/// Maximum number of response-body bytes accepted before the fetch fails.
max_bytes: u64,
}
impl Default for HttpsDataRefFetcher {
fn default() -> Self {
Self::new()
}
}
impl HttpsDataRefFetcher {
    /// Creates a fetcher with the default [`SsrfPolicy`] and a response-body
    /// cap of [`DEFAULT_MAX_BYTES`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn new() -> Self {
        Self::with_max_bytes(DEFAULT_MAX_BYTES)
    }

    /// Creates a fetcher with the default [`SsrfPolicy`] that rejects any
    /// response body larger than `max_bytes`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_max_bytes(max_bytes: u64) -> Self {
        let ssrf_policy = SsrfPolicy::default();
        Self {
            // The client is built from the very policy stored alongside it.
            http: build_data_ref_http_client(&ssrf_policy)
                .expect("HttpsDataRefFetcher HTTP client build failed"),
            ssrf_policy,
            max_bytes,
        }
    }

    /// Returns this fetcher reconfigured to use `policy`, keeping its byte cap.
    ///
    /// The HTTP client is rebuilt from `policy` so that the client and the
    /// stored policy always agree.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be rebuilt.
    pub fn with_ssrf_policy(self, policy: SsrfPolicy) -> Self {
        let rebuilt = build_data_ref_http_client(&policy)
            .expect("rebuild HttpsDataRefFetcher HTTP client with new SSRF policy");
        Self {
            http: rebuilt,
            ssrf_policy: policy,
            max_bytes: self.max_bytes,
        }
    }
}
/// Builds the `reqwest` client used for data_ref downloads.
///
/// The client uses rustls, a 5 second connect timeout, a 30 second overall
/// request timeout, a DNS resolver constructed from `policy`, and a custom
/// redirect policy that:
///
/// * fails once the number of previously visited URLs reaches
///   `MAX_REDIRECTS`, and
/// * fails when the redirect target does not share a fetch authority (as
///   judged by `same_fetch_authority`) with the first URL of the chain.
///
/// # Errors
///
/// Returns [`AcdpError::Http`] carrying the builder's error text when the
/// client cannot be constructed.
fn build_data_ref_http_client(policy: &SsrfPolicy) -> Result<reqwest::Client, AcdpError> {
    use crate::limits::MAX_REDIRECTS;
    use std::time::Duration;

    let connect_timeout = Duration::from_secs(5);
    let overall_timeout = Duration::from_secs(30);

    let redirects = reqwest::redirect::Policy::custom(|attempt| {
        if attempt.previous().len() >= MAX_REDIRECTS {
            return attempt.error(format!(
                "data_ref fetch: exceeded {MAX_REDIRECTS} redirects"
            ));
        }

        // Render both URLs to owned strings up front: `attempt.error`
        // consumes `attempt`, so no borrow of it may outlive this match.
        let rejected_hop = match attempt.previous().first() {
            Some(origin)
                if !crate::safe_http::same_fetch_authority(origin, attempt.url()) =>
            {
                Some((origin.to_string(), attempt.url().to_string()))
            }
            _ => None,
        };

        match rejected_hop {
            Some((from, to)) => attempt.error(format!(
                "data_ref fetch: cross-authority redirect rejected ({from} -> {to})"
            )),
            None => attempt.follow(),
        }
    });

    let resolver = crate::safe_http::SafeDnsResolver::arc(policy.clone());

    reqwest::Client::builder()
        .use_rustls_tls()
        .connect_timeout(connect_timeout)
        .timeout(overall_timeout)
        .redirect(redirects)
        .dns_resolver(resolver)
        .build()
        .map_err(|e| AcdpError::Http(e.to_string()))
}
impl DataRefFetcher for HttpsDataRefFetcher {
    /// Downloads the body behind a URI location, subject to the SSRF policy
    /// and the configured size cap.
    ///
    /// # Errors
    ///
    /// * [`AcdpError::NotImplemented`] — `location` is a structured locator;
    ///   only plain URIs are handled here.
    /// * [`AcdpError::SchemaViolation`] — the SSRF policy rejected the URI.
    /// * [`AcdpError::Http`] — the request failed, the server answered with a
    ///   non-success status, or reading the body failed.
    /// * [`AcdpError::PayloadTooLarge`] — the body is, or is declared to be,
    ///   larger than `max_bytes`.
    async fn fetch(&self, location: &Location) -> Result<Vec<u8>, AcdpError> {
        let uri = match location {
            Location::Uri(s) => s,
            Location::Structured(_) => {
                return Err(AcdpError::NotImplemented(
                    "HttpsDataRefFetcher does not handle structured locators \
                     (kafka.offset, ipfs.cid, …) — implement DataRefFetcher \
                     for the relevant scheme"
                        .into(),
                ));
            }
        };

        // Refuse disallowed targets before any network activity happens.
        self.ssrf_policy
            .check_url(uri)
            .map_err(|e| AcdpError::SchemaViolation(format!("SSRF policy on data_ref: {e}")))?;

        let mut resp = self
            .http
            .get(uri)
            .send()
            .await
            .map_err(|e| AcdpError::Http(e.to_string()))?;

        if !resp.status().is_success() {
            return Err(AcdpError::Http(format!(
                "data_ref fetch returned HTTP {}",
                resp.status()
            )));
        }

        let limit = self.max_bytes;
        let too_large = move || {
            AcdpError::PayloadTooLarge(format!("data_ref response exceeded {limit} bytes"))
        };

        // Fail fast: when the body size is known up front and already over
        // the cap, reject without downloading anything.
        if resp.content_length().is_some_and(|declared| declared > limit) {
            return Err(too_large());
        }

        // The declared length may be absent or wrong, so the cap is still
        // enforced chunk by chunk while streaming.
        let mut buf = Vec::with_capacity(8 * 1024);
        while let Some(chunk) = resp
            .chunk()
            .await
            .map_err(|e| AcdpError::Http(e.to_string()))?
        {
            if (buf.len() as u64).saturating_add(chunk.len() as u64) > limit {
                return Err(too_large());
            }
            buf.extend_from_slice(&chunk);
        }
        Ok(buf)
    }
}
/// Resolves `dr` to its raw bytes, checking them against the declared
/// `content_hash` when one is present.
///
/// Embedded content takes precedence: if `dr.embedded` is set it is decoded
/// and returned without consulting `fetcher`, even when a location is also
/// present. Otherwise the bytes are obtained from `fetcher` using
/// `dr.location`. When no `content_hash` is declared the bytes are returned
/// unverified.
///
/// # Errors
///
/// * Any error produced while decoding or hash-verifying embedded content.
/// * [`AcdpError::SchemaViolation`] when `dr` has neither embedded content
///   nor a location, or when the declared hash of fetched content does not
///   start with `sha256:`.
/// * Any error returned by `fetcher`.
/// * [`AcdpError::DataRefHashMismatch`] when fetched bytes do not match the
///   declared SHA-256 digest.
pub async fn fetch_and_verify_data_ref(
    dr: &DataRef,
    fetcher: &impl DataRefFetcher,
) -> Result<Vec<u8>, AcdpError> {
    let declared_hash = dr.content_hash.as_ref();

    match (&dr.embedded, &dr.location) {
        // Inline payload: decode first, then verify only if a hash was declared.
        (Some(inline), _) => {
            let payload = crate::validation::embedded_decoded_bytes(inline)?;
            if declared_hash.is_some() {
                crate::validation::verify_embedded_hash(dr)?;
            }
            Ok(payload)
        }
        // Remote payload: fetch, then compare against the declared digest.
        (None, Some(location)) => {
            let payload = fetcher.fetch(location).await?;
            if let Some(declared) = declared_hash {
                check_sha256(&payload, declared)?;
            }
            Ok(payload)
        }
        (None, None) => Err(AcdpError::SchemaViolation(
            "data_ref has neither embedded nor location — cannot fetch".into(),
        )),
    }
}
/// Verifies that the SHA-256 digest of `bytes` equals the one in `declared`.
///
/// `declared` must have the form `sha256:<hex>`. The hex part is compared
/// verbatim against the lowercase hex rendering of the computed digest.
///
/// # Errors
///
/// * [`AcdpError::SchemaViolation`] when `declared` lacks the `sha256:` prefix.
/// * [`AcdpError::DataRefHashMismatch`] when the digests differ.
fn check_sha256(bytes: &[u8], declared: &ContentHash) -> Result<(), AcdpError> {
    let declared_hex = match declared.as_str().strip_prefix("sha256:") {
        Some(hex) => hex,
        None => {
            return Err(AcdpError::SchemaViolation(format!(
                "data_ref content_hash must start with 'sha256:', got '{}'",
                declared.as_str()
            )));
        }
    };

    let digest = Sha256::digest(bytes);
    let got = format!("{:x}", digest);

    if got == declared_hex {
        Ok(())
    } else {
        Err(AcdpError::DataRefHashMismatch(format!(
            "data_ref content_hash mismatch: declared sha256:{declared_hex}, computed sha256:{got}"
        )))
    }
}
// Unit tests for `fetch_and_verify_data_ref` (driven by an in-memory stub
// fetcher) and for the rejection paths of `HttpsDataRefFetcher`.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::data_ref::{DataRefType, EmbeddedContent, EmbeddedEncoding};
use sha2::{Digest, Sha256};
// Fetcher that ignores the location and always returns a fixed byte buffer.
struct StubFetcher {
bytes: Vec<u8>,
}
impl DataRefFetcher for StubFetcher {
async fn fetch(&self, _location: &Location) -> Result<Vec<u8>, AcdpError> {
Ok(self.bytes.clone())
}
}
// A declared hash that matches the fetched bytes lets the bytes through unchanged.
#[tokio::test]
async fn fetch_and_verify_uri_ref_passes_with_matching_hash() {
let bytes = b"hello-world".to_vec();
let hash = format!("sha256:{:x}", Sha256::digest(&bytes));
let dr = DataRef::uri_verified(
DataRefType::RawData,
"https://example.com/data",
ContentHash(hash),
);
let got = fetch_and_verify_data_ref(
&dr,
&StubFetcher {
bytes: bytes.clone(),
},
)
.await
.unwrap();
assert_eq!(got, bytes);
}
// A declared hash (64 zeros) that cannot match the fetched bytes must
// surface as `DataRefHashMismatch`.
#[tokio::test]
async fn fetch_and_verify_uri_ref_fails_on_hash_mismatch() {
let dr = DataRef::uri_verified(
DataRefType::RawData,
"https://example.com/data",
ContentHash(format!("sha256:{}", "0".repeat(64))),
);
let err = fetch_and_verify_data_ref(
&dr,
&StubFetcher {
bytes: b"different bytes".to_vec(),
},
)
.await
.unwrap_err();
assert!(
matches!(err, AcdpError::DataRefHashMismatch(_)),
"expected DataRefHashMismatch, got {err:?}"
);
}
// Without a declared hash there is nothing to verify: the fetched bytes
// are returned as-is.
#[tokio::test]
async fn fetch_and_verify_uri_ref_without_declared_hash_returns_bytes_unverified() {
let dr = DataRef::uri(DataRefType::RawData, "https://example.com/data");
let got = fetch_and_verify_data_ref(
&dr,
&StubFetcher {
bytes: b"unverified".to_vec(),
},
)
.await
.unwrap();
assert_eq!(got, b"unverified");
}
// Embedded base64 content is decoded and returned. The stub fetcher holds
// an empty buffer, so a non-empty result shows it came from the embedded
// content rather than from the fetcher.
#[tokio::test]
async fn fetch_and_verify_embedded_ref_returns_decoded_bytes() {
use base64::{engine::general_purpose::STANDARD, Engine};
let payload = b"embedded-bytes";
let encoded = STANDARD.encode(payload);
let dr = DataRef {
ref_type: DataRefType::RawData,
description: None,
size_bytes: None,
format: None,
schema_version: None,
content_hash: None,
location: None,
embedded: Some(EmbeddedContent {
encoding: EmbeddedEncoding::Base64,
content: serde_json::json!(encoded),
}),
extensions: serde_json::Map::new(),
};
let got = fetch_and_verify_data_ref(&dr, &StubFetcher { bytes: vec![] })
.await
.unwrap();
assert_eq!(got, payload);
}
// A plain `http://` URI must fail with `SchemaViolation`.
// NOTE(review): presumably the default SSRF policy's URL check is what
// rejects the non-HTTPS scheme — confirm against `SsrfPolicy::check_url`.
#[tokio::test]
async fn https_fetcher_rejects_http_uri() {
let f = HttpsDataRefFetcher::new();
let err = f
.fetch(&Location::Uri("http://insecure.example.com/x".into()))
.await
.unwrap_err();
assert!(matches!(err, AcdpError::SchemaViolation(_)));
}
// Structured locators are not handled by the HTTPS fetcher and must yield
// `NotImplemented`.
#[tokio::test]
async fn https_fetcher_rejects_structured_locator() {
let f = HttpsDataRefFetcher::new();
let mut m = serde_json::Map::new();
m.insert("scheme".into(), serde_json::json!("kafka.offset"));
let err = f.fetch(&Location::Structured(m)).await.unwrap_err();
assert!(matches!(err, AcdpError::NotImplemented(_)));
}
// data-ref-ssrf-001: IP-literal URIs in private, loopback and link-local
// ranges must each be refused with `SchemaViolation`.
#[tokio::test]
async fn https_fetcher_rejects_ip_literal_private_location() {
let f = HttpsDataRefFetcher::new();
for uri in [
"https://10.0.0.1/data.csv",
"https://127.0.0.1/data.csv",
"https://[::1]/data.csv",
"https://169.254.169.254/latest/meta-data/",
"https://192.168.1.10/export.parquet",
] {
let err = f.fetch(&Location::Uri(uri.into())).await.unwrap_err();
assert!(
matches!(err, AcdpError::SchemaViolation(_)),
"data-ref-ssrf-001: '{uri}' must be refused by the SSRF policy, got {err:?}"
);
}
}
// data-ref-ssrf-002: fetching `https://localhost/...` with the default
// policy must fail.
// NOTE(review): the assertion only rules out `NotImplemented`; it does not
// pin which error variant the block produces. Presumably this can trigger
// a real lookup of `localhost` — confirm the test environment permits it.
#[tokio::test]
async fn https_fetcher_blocks_hostname_resolving_to_loopback() {
let f = HttpsDataRefFetcher::new();
let err = f
.fetch(&Location::Uri("https://localhost/data.csv".into()))
.await
.unwrap_err();
assert!(
!matches!(err, AcdpError::NotImplemented(_)),
"data-ref-ssrf-002: loopback-resolving host must be blocked, got {err:?}"
);
}
// Smoke test: the fetch result is discarded, so this only checks that a
// fetcher rebuilt with the `allow_test_loopback` policy can be constructed
// and can attempt a fetch against `localhost` without panicking.
#[tokio::test]
async fn https_fetcher_allow_test_loopback_permits_localhost_dns() {
let f = HttpsDataRefFetcher::new()
.with_ssrf_policy(crate::safe_http::SsrfPolicy::allow_test_loopback());
let _ = f
.fetch(&Location::Uri("https://localhost:1/data.csv".into()))
.await;
}
}