sps-net 0.1.56

Networking library for the sps package manager
Documentation
use std::collections::HashMap;
use std::fs::{remove_file, File};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use futures::StreamExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use reqwest::header::{ACCEPT, AUTHORIZATION};
use reqwest::{Client, Response, StatusCode};
use serde::{Deserialize, Serialize};
use sps_common::config::Config;
use sps_common::error::{Result, SpsError};
use tracing::{debug, error};
use url::Url;

use crate::http::ProgressCallback;
use crate::validation::{validate_url, verify_checksum};

// `Accept` value used when fetching a manifest index. Despite the constant's name, the value is
// the OCI image *index* media type.
const OCI_MANIFEST_V1_TYPE: &str = "application/vnd.oci.image.index.v1+json";
// `Accept` value used when downloading a blob (the gzip-compressed tar layer media type).
const OCI_LAYER_V1_TYPE: &str = "application/vnd.oci.image.layer.v1.tar+gzip";
// Token endpoint queried for anonymous pull tokens when the registry host is ghcr.io.
const DEFAULT_GHCR_TOKEN_ENDPOINT: &str = "https://ghcr.io/token";
/// Registry host assumed when a request URL has no host component. Within this module it is
/// also the only registry for which an anonymous bearer token is requested.
pub const DEFAULT_GHCR_DOMAIN: &str = "ghcr.io";

// Connect timeout (seconds) applied by `build_oci_client`.
const CONNECT_TIMEOUT_SECS: u64 = 30;
// Overall per-request timeout (seconds) applied by `build_oci_client`.
const REQUEST_TIMEOUT_SECS: u64 = 300;
// `User-Agent` header value sent by clients built with `build_oci_client`.
const USER_AGENT_STRING: &str = "sps package manager (Rust; +https://github.com/alexykn/sps)";

/// JSON body of a successful token-endpoint response. Only `token` is read; any other fields in
/// the response are ignored by deserialization.
#[derive(Deserialize, Debug)]
struct OciTokenResponse {
    // Bearer credential later sent as `Authorization: Bearer <token>`.
    token: String,
}

/// An OCI image index document, as returned by `fetch_oci_manifest_index`.
///
/// Rust field names map to camelCase JSON keys (`schemaVersion`, `mediaType`, `manifests`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct OciManifestIndex {
    /// The document's `schemaVersion` value.
    pub schema_version: u32,
    /// The index's own `mediaType`, if the document declares one.
    pub media_type: Option<String>,
    /// Descriptors of the manifests listed by the index.
    pub manifests: Vec<OciManifestDescriptor>,
}

/// One entry of an `OciManifestIndex`'s `manifests` list.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct OciManifestDescriptor {
    /// Media type of the referenced content (`mediaType`).
    pub media_type: String,
    /// Digest of the referenced content (`digest`); presumably `<algorithm>:<hex>`. Confirm
    /// against real registry output.
    pub digest: String,
    /// Declared size of the referenced content (`size`).
    pub size: u64,
    /// Target platform of the referenced manifest, when the entry declares one.
    pub platform: Option<OciPlatform>,
    /// Free-form string annotations, when present.
    pub annotations: Option<HashMap<String, String>>,
}

/// Platform description attached to an `OciManifestDescriptor`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct OciPlatform {
    /// CPU architecture (`architecture`).
    pub architecture: String,
    /// Operating system (`os`).
    pub os: String,
    /// Operating-system version, when present.
    // The JSON key contains a dot, so it is renamed explicitly instead of via camelCase.
    #[serde(rename = "os.version")]
    pub os_version: Option<String>,
    /// Values of the `features` key; empty when the key is absent.
    // NOTE(review): the OCI spec's `os.features` key is not mapped by this struct. Confirm that
    // reading `features` is what callers want.
    #[serde(default)]
    pub features: Vec<String>,
    /// CPU variant (`variant`), when present.
    pub variant: Option<String>,
}

/// Authentication attached to an OCI registry request by `execute_oci_request`.
///
/// Any variant carrying an empty string results in no `Authorization` header being sent.
#[derive(Debug, Clone)]
enum OciAuth {
    /// No `Authorization` header.
    None,
    /// Bearer token obtained anonymously from the registry's token endpoint.
    AnonymousBearer { token: String },
    /// Bearer token taken from `config.docker_registry_token`.
    ExplicitBearer { token: String },
    /// Pre-encoded credentials sent as `Authorization: Basic <encoded>`, taken from
    /// `config.docker_registry_basic_auth`.
    Basic { encoded: String },
}

async fn fetch_oci_resource<T: serde::de::DeserializeOwned>(
    resource_url: &str,
    accept_header: &str,
    config: &Config,
    client: &Client,
) -> Result<T> {
    let url = Url::parse(resource_url)
        .map_err(|e| SpsError::Generic(format!("Invalid URL '{resource_url}': {e}")))?;
    validate_url(url.as_str())?;
    let registry_domain = url.host_str().unwrap_or(DEFAULT_GHCR_DOMAIN);
    let repo_path = extract_repo_path_from_url(&url).unwrap_or("");

    let auth = determine_auth(config, client, registry_domain, repo_path).await?;
    let resp = execute_oci_request(client, resource_url, accept_header, &auth).await?;
    let txt = resp.text().await.map_err(|e| SpsError::Http(Arc::new(e)))?;

    debug!("OCI response ({} bytes) from {}", txt.len(), resource_url);
    serde_json::from_str(&txt).map_err(|e| {
        error!("JSON parse error from {}: {}", resource_url, e);
        SpsError::Json(Arc::new(e))
    })
}

pub async fn download_oci_blob(
    blob_url: &str,
    destination_path: &Path,
    config: &Config,
    client: &Client,
    expected_digest: &str,
) -> Result<()> {
    download_oci_blob_with_progress(
        blob_url,
        destination_path,
        config,
        client,
        expected_digest,
        None,
    )
    .await
}

/// Downloads an OCI blob to `destination_path`, optionally reporting progress.
///
/// The body is streamed into a hidden sibling temp file named `.<file name>.download`. When
/// `expected_digest` is non-empty, the temp file is checked with `verify_checksum` *before* it is
/// renamed over `destination_path`, so the destination only ever receives verified content. An
/// empty `expected_digest` skips verification and logs a warning.
///
/// If `progress_callback` is given, it is called after every received chunk with the
/// cumulative byte count and the response's `Content-Length`, if the server sent one.
///
/// On any failure after the temp file is created (stream error, write error, checksum
/// mismatch, failed rename, or the future being dropped), the temp file is removed on a
/// best-effort basis. `destination_path` is left untouched in those cases.
///
/// # Errors
/// - `SpsError::Generic` if `blob_url` cannot be parsed or `destination_path` has no file name.
/// - Whatever `validate_url`, `determine_auth` or `execute_oci_request` return.
/// - `SpsError::Http` if the body stream fails.
/// - `SpsError::Io` on file-system failures.
/// - The error from `verify_checksum` when the digest does not match.
pub async fn download_oci_blob_with_progress(
    blob_url: &str,
    destination_path: &Path,
    config: &Config,
    client: &Client,
    expected_digest: &str,
    progress_callback: Option<ProgressCallback>,
) -> Result<()> {
    // Deletes the temp file when dropped while it still holds a path. Because cleanup happens
    // in `Drop`, every early `?`/`return` and cancellation of this future is covered.
    struct TempFileGuard<'a>(Option<&'a Path>);

    impl Drop for TempFileGuard<'_> {
        fn drop(&mut self) {
            if let Some(path) = self.0 {
                let _ = remove_file(path);
            }
        }
    }

    debug!("Downloading OCI blob: {}", blob_url);
    let url = Url::parse(blob_url)
        .map_err(|e| SpsError::Generic(format!("Invalid URL '{blob_url}': {e}")))?;
    validate_url(url.as_str())?;
    let registry_domain = url.host_str().unwrap_or(DEFAULT_GHCR_DOMAIN);
    let repo_path = extract_repo_path_from_url(&url).unwrap_or("");

    // Paths such as `/` or `dir/..` have no final component. Fail cleanly here instead of
    // panicking after the network request.
    let file_name = destination_path.file_name().ok_or_else(|| {
        SpsError::Generic(format!(
            "Invalid download destination '{}': path has no file name",
            destination_path.display()
        ))
    })?;
    let tmp = destination_path.with_file_name(format!(".{}.download", file_name.to_string_lossy()));

    let auth = determine_auth(config, client, registry_domain, repo_path).await?;
    let resp = execute_oci_request(client, blob_url, OCI_LAYER_V1_TYPE, &auth).await?;

    // Get total size from Content-Length header if available
    let total_size = resp.content_length();

    let mut tmp_guard = TempFileGuard(Some(&tmp));
    {
        let mut out = File::create(&tmp).map_err(|e| SpsError::Io(Arc::new(e)))?;
        let mut stream = resp.bytes_stream();
        let mut bytes_downloaded = 0u64;

        while let Some(chunk) = stream.next().await {
            let b = chunk.map_err(|e| SpsError::Http(Arc::new(e)))?;
            std::io::Write::write_all(&mut out, &b).map_err(|e| SpsError::Io(Arc::new(e)))?;

            bytes_downloaded += b.len() as u64;

            if let Some(ref callback) = progress_callback {
                callback(bytes_downloaded, total_size);
            }
        }
        // `out` is dropped (closed) at the end of this scope, before the file is checksummed
        // and renamed.
    }

    if expected_digest.is_empty() {
        tracing::warn!(
            "Skipping checksum verification for OCI blob {} - no checksum provided.",
            destination_path.display()
        );
    } else {
        match verify_checksum(&tmp, expected_digest) {
            Ok(_) => {
                tracing::debug!("OCI Blob checksum verified: {}", destination_path.display());
            }
            Err(e) => {
                tracing::error!(
                    "OCI Blob checksum mismatch ({}). Deleting downloaded file.",
                    e
                );
                // Returning drops `tmp_guard`, which deletes the unverified temp file.
                return Err(e);
            }
        }
    }

    std::fs::rename(&tmp, destination_path).map_err(|e| SpsError::Io(Arc::new(e)))?;
    // The temp file now lives at `destination_path`; disarm the guard so it is kept.
    tmp_guard.0 = None;

    debug!("Blob saved to {}", destination_path.display());
    Ok(())
}

/// Fetches and parses the OCI image index at `manifest_url`.
///
/// The request advertises the OCI image index media type in its `Accept` header. Errors are
/// those of the shared OCI JSON fetch path: invalid URL, authentication, HTTP status, body
/// read or JSON parse failures.
pub async fn fetch_oci_manifest_index(
    manifest_url: &str,
    config: &Config,
    client: &Client,
) -> Result<OciManifestIndex> {
    let accept = OCI_MANIFEST_V1_TYPE;
    let index: OciManifestIndex = fetch_oci_resource(manifest_url, accept, config, client).await?;
    Ok(index)
}

/// Builds the HTTP client used for OCI registry traffic.
///
/// The client sends `USER_AGENT_STRING` as its user agent. It uses a connect timeout of
/// `CONNECT_TIMEOUT_SECS` seconds and an overall request timeout of `REQUEST_TIMEOUT_SECS`
/// seconds, and follows redirects with reqwest's default redirect policy.
///
/// # Errors
/// Returns `SpsError::Http` if the client cannot be constructed.
pub fn build_oci_client() -> Result<Client> {
    let connect_timeout = Duration::from_secs(CONNECT_TIMEOUT_SECS);
    let request_timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);

    let builder = Client::builder()
        .user_agent(USER_AGENT_STRING)
        .connect_timeout(connect_timeout)
        .timeout(request_timeout)
        .redirect(reqwest::redirect::Policy::default());

    match builder.build() {
        Ok(client) => Ok(client),
        Err(e) => Err(SpsError::Http(Arc::new(e))),
    }
}

/// Extracts the repository path from a registry API URL.
///
/// Leading slashes and a single `v2/` API prefix are removed. The remainder is cut at the
/// first `/manifests/` and then at the first `/blobs/`. For example,
/// `https://ghcr.io/v2/homebrew/core/foo/blobs/sha256:abc` yields `homebrew/core/foo`.
/// Returns `None` when nothing is left.
fn extract_repo_path_from_url(url: &Url) -> Option<&str> {
    let path = url.path().trim_start_matches('/');
    // Strip the API version prefix exactly once. `trim_start_matches("v2/")` would also remove a
    // leading repository component that happens to be named `v2`.
    let path = path.strip_prefix("v2/").unwrap_or(path);
    let repo = path
        .split_once("/manifests/")
        .map_or(path, |(repo, _)| repo);
    let repo = repo.split_once("/blobs/").map_or(repo, |(repo, _)| repo);
    (!repo.is_empty()).then_some(repo)
}

/// Chooses the authentication to use for a request to `registry_domain`.
///
/// Precedence is:
/// 1. `config.docker_registry_token`, as an explicit bearer token.
/// 2. `config.docker_registry_basic_auth`, as basic credentials.
/// 3. Only for ghcr.io with a non-empty `repo_path`: an anonymous pull token.
///
/// If fetching the anonymous token fails, the failure is logged at debug level and
/// `OciAuth::None` is returned. As written, this function therefore never returns `Err`.
async fn determine_auth(
    config: &Config,
    client: &Client,
    registry_domain: &str,
    repo_path: &str,
) -> Result<OciAuth> {
    if let Some(token) = config.docker_registry_token.as_ref() {
        debug!("Using explicit bearer for {}", registry_domain);
        let token = token.clone();
        return Ok(OciAuth::ExplicitBearer { token });
    }
    if let Some(basic) = config.docker_registry_basic_auth.as_ref() {
        debug!("Using explicit basic auth for {}", registry_domain);
        let encoded = basic.clone();
        return Ok(OciAuth::Basic { encoded });
    }

    let wants_anonymous_token =
        !repo_path.is_empty() && registry_domain.eq_ignore_ascii_case(DEFAULT_GHCR_DOMAIN);
    if !wants_anonymous_token {
        return Ok(OciAuth::None);
    }

    debug!(
        "Anonymous token fetch for {} scope={}",
        registry_domain, repo_path
    );
    let auth = match fetch_anonymous_token(client, registry_domain, repo_path).await {
        Ok(token) => OciAuth::AnonymousBearer { token },
        Err(e) => {
            debug!("Anon token failed, proceeding unauthenticated: {}", e);
            OciAuth::None
        }
    };
    Ok(auth)
}

async fn fetch_anonymous_token(
    client: &Client,
    registry_domain: &str,
    repo_path: &str,
) -> Result<String> {
    let endpoint = if registry_domain.eq_ignore_ascii_case(DEFAULT_GHCR_DOMAIN) {
        DEFAULT_GHCR_TOKEN_ENDPOINT.to_string()
    } else {
        format!("https://{registry_domain}/token")
    };
    let scope = format!("repository:{repo_path}:pull");
    let token_url = format!("{endpoint}?service={registry_domain}&scope={scope}");

    const MAX_RETRIES: u8 = 3;
    let base_delay = Duration::from_millis(200);
    let mut delay = base_delay;
    // Use a Sendable RNG
    let mut rng = SmallRng::from_os_rng();

    for attempt in 0..=MAX_RETRIES {
        debug!(
            "Token attempt {}/{} from {}",
            attempt + 1,
            MAX_RETRIES + 1,
            token_url
        );

        match client.get(&token_url).send().await {
            Ok(resp) if resp.status().is_success() => {
                let tok: OciTokenResponse = resp
                    .json()
                    .await
                    .map_err(|e| SpsError::ApiRequestError(format!("Parse token response: {e}")))?;
                return Ok(tok.token);
            }
            Ok(resp) => {
                let code = resp.status();
                let body = resp.text().await.unwrap_or_default();
                error!("Token fetch {}: {} – {}", attempt + 1, code, body);
                if !code.is_server_error() || attempt == MAX_RETRIES {
                    return Err(SpsError::Api(format!("Token endpoint {code}: {body}")));
                }
            }
            Err(e) => {
                error!("Network error on token fetch {}: {}", attempt + 1, e);
                if attempt == MAX_RETRIES {
                    return Err(SpsError::Http(Arc::new(e)));
                }
            }
        }

        let jitter = rng.random_range(0..(base_delay.as_millis() as u64 / 2));
        tokio::time::sleep(delay + Duration::from_millis(jitter)).await;
        delay *= 2;
    }

    Err(SpsError::Api(format!(
        "Failed to fetch OCI token after {} attempts",
        MAX_RETRIES + 1
    )))
}

/// Sends a GET request for `url` with the given `Accept` header and authentication.
///
/// Bearer auth (anonymous or explicit) and basic auth add an `Authorization` header only when
/// their token or credential string is non-empty. `OciAuth::None` sends no such header.
///
/// # Errors
/// - `SpsError::Http` if the request cannot be sent.
/// - For a non-success status, the response body is logged and then mapped as follows:
///   - 401 becomes `SpsError::Api` ("Auth required").
///   - 403 becomes `SpsError::Api` ("Permission denied").
///   - 404 becomes `SpsError::NotFound`.
///   - Any other status becomes `SpsError::Api`, with the status and body separated by `": "`.
async fn execute_oci_request(
    client: &Client,
    url: &str,
    accept: &str,
    auth: &OciAuth,
) -> Result<Response> {
    debug!("OCI request → {} (Accept: {})", url, accept);
    let mut req = client.get(url).header(ACCEPT, accept);
    match auth {
        OciAuth::AnonymousBearer { token } | OciAuth::ExplicitBearer { token }
            if !token.is_empty() =>
        {
            req = req.header(AUTHORIZATION, format!("Bearer {token}"))
        }
        OciAuth::Basic { encoded } if !encoded.is_empty() => {
            req = req.header(AUTHORIZATION, format!("Basic {encoded}"))
        }
        _ => {}
    }

    let resp = req.send().await.map_err(|e| SpsError::Http(Arc::new(e)))?;
    let status = resp.status();
    if status.is_success() {
        Ok(resp)
    } else {
        let body = resp.text().await.unwrap_or_default();
        error!("OCI {} ⇒ {} – {}", url, status, body);
        let err = match status {
            StatusCode::UNAUTHORIZED => SpsError::Api(format!("Auth required: {status}")),
            StatusCode::FORBIDDEN => SpsError::Api(format!("Permission denied: {status}")),
            StatusCode::NOT_FOUND => SpsError::NotFound(format!("Not found: {status}")),
            // Separate status and body; previously they were concatenated with no delimiter.
            _ => SpsError::Api(format!("HTTP {status}: {body}")),
        };
        Err(err)
    }
}