use std::time::Duration;
use crate::packs::manifest::{PackIndex, PackManifest, manifest_sha256};
/// Base URL of the official rule-pack registry (written without a trailing slash).
pub const DEFAULT_PACK_REGISTRY: &str =
    "https://raw.githubusercontent.com/difflore/rule-packs/main";
/// Timeout handed to the HTTP client builder for registry requests.
const FETCH_TIMEOUT: Duration = Duration::from_millis(20_000);
/// Upper bound on redirects the HTTP client follows for a registry request.
const MAX_REDIRECTS: usize = 4;
/// Errors produced while fetching rule-pack registry data (index or manifests).
///
/// Every payload is an owned `String`/integer, so the type is cheap to clone and
/// can be compared directly (handy for callers and tests alike).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PackFetchError {
    /// The registry base URL was unusable (e.g. empty after trimming).
    BadUrl(String),
    /// The HTTP layer failed: building the client, sending the request, or
    /// reading the response body.
    Transport(String),
    /// The registry answered with a non-success HTTP status.
    Status {
        /// URL that was requested.
        url: String,
        /// Numeric HTTP status code returned.
        status: u16,
    },
    /// A `file://` registry path could not be read.
    Io(String),
    /// The payload could not be deserialized from JSON.
    Parse(String),
    /// A manifest's SHA-256 digest did not match the expected one.
    IntegrityMismatch {
        /// Digest the caller expected (normalised to lowercase).
        expected: String,
        /// Digest computed from the downloaded bytes.
        actual: String,
    },
}

impl std::fmt::Display for PackFetchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BadUrl(m) => write!(f, "invalid registry URL: {m}"),
            Self::Transport(m) => write!(f, "could not reach registry: {m}"),
            Self::Status { url, status } => {
                write!(f, "registry returned HTTP {status} for {url}")
            }
            Self::Io(m) => write!(f, "could not read local registry path: {m}"),
            Self::Parse(m) => write!(f, "registry payload did not parse: {m}"),
            Self::IntegrityMismatch { expected, actual } => write!(
                f,
                "pack manifest failed integrity check (sha256 expected {expected}, got {actual}) \
                 — refusing to install"
            ),
        }
    }
}

impl std::error::Error for PackFetchError {}
/// Returns `true` when `base` is a local `file://` registry URL (case-sensitive scheme match).
// Within this module only the unit tests call this; `get_bytes` does its own
// `strip_prefix("file://")`, so non-test builds would otherwise warn.
#[allow(dead_code)]
fn is_file_registry(base: &str) -> bool {
    base.strip_prefix("file://").is_some()
}
/// Joins a registry base and a relative path with exactly one `/` between them.
///
/// All trailing slashes of `base` and all leading slashes of `rel` are dropped
/// before joining; nothing else is normalised.
fn join_url(base: &str, rel: &str) -> String {
    let head = base.trim_end_matches('/');
    let tail = rel.trim_start_matches('/');
    [head, tail].join("/")
}
/// Fetches the raw bytes at `url` from either a local or a remote registry.
///
/// `file://` URLs are read directly from disk; the remainder of the URL is used
/// as the path verbatim (no percent-decoding). Any other URL is fetched over HTTP
/// using [`FETCH_TIMEOUT`] and at most [`MAX_REDIRECTS`] redirects, and the body
/// is capped at `MAX_BODY_BYTES`.
///
/// # Errors
///
/// * [`PackFetchError::Io`] if a `file://` path cannot be read.
/// * [`PackFetchError::Transport`] if the HTTP client cannot be built, the request
///   fails, the body cannot be read, or the body exceeds the size cap.
/// * [`PackFetchError::Status`] for a non-success HTTP status.
async fn get_bytes(url: &str) -> Result<Vec<u8>, PackFetchError> {
    // Index and manifest files are small JSON documents; cap the body so a
    // misbehaving or hostile registry cannot make us buffer unbounded data.
    const MAX_BODY_BYTES: usize = 16 * 1024 * 1024;

    if let Some(path) = url.strip_prefix("file://") {
        // `file:///C:/reg` leaves "/C:/reg": drop the leading slash only when a
        // Windows drive letter follows, so POSIX paths like "/tmp/reg" stay absolute.
        let path = path
            .strip_prefix('/')
            .filter(|p| matches!(p.as_bytes(), [drive, b':', ..] if drive.is_ascii_alphabetic()))
            .unwrap_or(path);
        return tokio::fs::read(path)
            .await
            .map_err(|e| PackFetchError::Io(format!("{path}: {e}")));
    }

    let client = reqwest::Client::builder()
        .timeout(FETCH_TIMEOUT)
        .redirect(reqwest::redirect::Policy::limited(MAX_REDIRECTS))
        .build()
        .map_err(|e| PackFetchError::Transport(format!("could not build HTTP client: {e}")))?;
    let mut resp = client
        .get(url)
        .send()
        .await
        .map_err(|e| PackFetchError::Transport(e.to_string()))?;
    let status = resp.status();
    if !status.is_success() {
        return Err(PackFetchError::Status {
            url: url.to_owned(),
            status: status.as_u16(),
        });
    }

    let too_large = || {
        PackFetchError::Transport(format!(
            "response from {url} exceeds the {MAX_BODY_BYTES}-byte limit"
        ))
    };
    // Fail fast when the server advertises an oversized body...
    if matches!(resp.content_length(), Some(len) if len > MAX_BODY_BYTES as u64) {
        return Err(too_large());
    }
    // ...and still enforce the cap while streaming, because Content-Length may be
    // absent (e.g. chunked transfer encoding) or inaccurate.
    let mut body = Vec::new();
    while let Some(chunk) = resp
        .chunk()
        .await
        .map_err(|e| PackFetchError::Transport(e.to_string()))?
    {
        if body.len() + chunk.len() > MAX_BODY_BYTES {
            return Err(too_large());
        }
        body.extend_from_slice(&chunk);
    }
    Ok(body)
}
/// Downloads `index.json` from `registry_base` and deserializes it.
///
/// Surrounding whitespace in `registry_base` is ignored.
///
/// # Errors
///
/// * [`PackFetchError::BadUrl`] if `registry_base` is empty after trimming.
/// * Any error from fetching the bytes (`Io`, `Transport`, `Status`).
/// * [`PackFetchError::Parse`] if the payload is not a valid index document.
pub async fn fetch_index(registry_base: &str) -> Result<PackIndex, PackFetchError> {
    let base = match registry_base.trim() {
        "" => return Err(PackFetchError::BadUrl("empty registry base".to_owned())),
        trimmed => trimmed,
    };
    let payload = get_bytes(&join_url(base, "index.json")).await?;
    serde_json::from_slice::<PackIndex>(&payload)
        .map_err(|err| PackFetchError::Parse(format!("index.json: {err}")))
}
/// Downloads the manifest at `manifest_rel` (relative to `registry_base`),
/// verifies its SHA-256 digest, and deserializes it.
///
/// `expected_sha256` is trimmed and compared case-insensitively (it is
/// lowercased before comparison). An empty or whitespace-only
/// `expected_sha256` skips the integrity check entirely.
///
/// # Errors
///
/// * [`PackFetchError::BadUrl`] if `registry_base` is empty after trimming.
/// * Any error from fetching the bytes (`Io`, `Transport`, `Status`).
/// * [`PackFetchError::IntegrityMismatch`] if a non-empty expected digest differs
///   from the digest of the downloaded bytes.
/// * [`PackFetchError::Parse`] if the payload is not a valid manifest.
pub async fn fetch_manifest(
    registry_base: &str,
    manifest_rel: &str,
    expected_sha256: &str,
) -> Result<PackManifest, PackFetchError> {
    let base = registry_base.trim();
    if base.is_empty() {
        return Err(PackFetchError::BadUrl("empty registry base".to_owned()));
    }
    let payload = get_bytes(&join_url(base, manifest_rel)).await?;

    let actual = manifest_sha256(&payload);
    let expected = expected_sha256.trim().to_ascii_lowercase();
    // A blank expected digest means "no pin supplied": verification is skipped.
    let verified = expected.is_empty() || actual == expected;
    if !verified {
        return Err(PackFetchError::IntegrityMismatch { expected, actual });
    }

    serde_json::from_slice(&payload)
        .map_err(|err| PackFetchError::Parse(format!("{manifest_rel}: {err}")))
}
/// Returns `true` when `registry_base` names [`DEFAULT_PACK_REGISTRY`], ignoring
/// surrounding whitespace and any number of trailing slashes.
#[must_use]
pub fn is_default_registry(registry_base: &str) -> bool {
    let normalised = registry_base.trim().trim_end_matches('/');
    DEFAULT_PACK_REGISTRY == normalised
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a temp directory holding `name` with `contents`, returning the
    /// directory guard (keep it alive for the test) and a `file://` base URL.
    fn file_registry(name: &str, contents: &str) -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().expect("tempdir");
        std::fs::write(dir.path().join(name), contents).expect("write");
        let base = format!("file://{}", dir.path().display());
        (dir, base)
    }

    #[test]
    fn join_url_normalises_slashes() {
        assert_eq!(
            join_url("https://example.com/reg/", "/index.json"),
            "https://example.com/reg/index.json"
        );
        assert_eq!(
            join_url("https://example.com/reg", "packs/a/pack.json"),
            "https://example.com/reg/packs/a/pack.json"
        );
        assert_eq!(
            join_url("https://example.com/reg//", "//index.json"),
            "https://example.com/reg/index.json"
        );
    }

    #[test]
    fn detects_file_and_default_registries() {
        assert!(is_file_registry("file:///tmp/reg"));
        assert!(!is_file_registry("https://example.com"));
        assert!(is_default_registry(DEFAULT_PACK_REGISTRY));
        assert!(is_default_registry(&format!("{DEFAULT_PACK_REGISTRY}/")));
        assert!(is_default_registry(&format!("  {DEFAULT_PACK_REGISTRY}  ")));
        assert!(!is_default_registry("https://example.com/fork"));
    }

    #[test]
    fn integrity_error_message_says_install_is_refused() {
        let err = PackFetchError::IntegrityMismatch {
            expected: "aa".to_owned(),
            actual: "bb".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("expected aa, got bb"));
        assert!(msg.contains("refusing to install"));
    }

    #[tokio::test]
    async fn file_registry_round_trips_index() {
        let (_dir, base) = file_registry(
            "index.json",
            r#"{"schemaVersion":1,"packs":[{"id":"x/y","name":"Y","latest":"1.0.0","versions":{}}]}"#,
        );
        let index = fetch_index(&base).await.expect("fetch index");
        assert_eq!(index.packs.len(), 1);
        assert_eq!(index.packs[0].id, "x/y");
    }

    #[tokio::test]
    async fn empty_registry_base_is_rejected() {
        let err = fetch_index("   ").await.expect_err("blank base");
        assert!(matches!(err, PackFetchError::BadUrl(_)));
        let err = fetch_manifest("", "pack.json", "")
            .await
            .expect_err("empty base");
        assert!(matches!(err, PackFetchError::BadUrl(_)));
    }

    #[tokio::test]
    async fn file_registry_reports_missing_and_malformed_index() {
        let empty_dir = tempfile::tempdir().expect("tempdir");
        let base = format!("file://{}", empty_dir.path().display());
        let err = fetch_index(&base).await.expect_err("missing index");
        assert!(matches!(err, PackFetchError::Io(_)));

        let (_dir, base) = file_registry("index.json", "not json");
        let err = fetch_index(&base).await.expect_err("malformed index");
        assert!(matches!(err, PackFetchError::Parse(_)));
    }

    #[tokio::test]
    async fn manifest_integrity_mismatch_is_refused() {
        let raw = r#"{"schemaVersion":1,"id":"x/y","name":"Y","version":"1.0.0","rules":[]}"#;
        let (_dir, base) = file_registry("pack.json", raw);
        let err = fetch_manifest(&base, "pack.json", "0000")
            .await
            .expect_err("should refuse");
        assert!(matches!(err, PackFetchError::IntegrityMismatch { .. }));
        let good = manifest_sha256(raw.as_bytes());
        let manifest = fetch_manifest(&base, "pack.json", &good)
            .await
            .expect("fetch manifest");
        assert_eq!(manifest.id, "x/y");
    }

    #[tokio::test]
    async fn manifest_digest_is_case_insensitive_and_optional() {
        let raw = r#"{"schemaVersion":1,"id":"x/y","name":"Y","version":"1.0.0","rules":[]}"#;
        let (_dir, base) = file_registry("pack.json", raw);
        let upper = manifest_sha256(raw.as_bytes()).to_ascii_uppercase();
        let manifest = fetch_manifest(&base, "pack.json", &upper)
            .await
            .expect("uppercase digest accepted");
        assert_eq!(manifest.id, "x/y");
        // A blank digest skips verification entirely.
        let manifest = fetch_manifest(&base, "pack.json", "  ")
            .await
            .expect("blank digest skips check");
        assert_eq!(manifest.id, "x/y");
    }
}