crtx-llm 0.1.1

Claude, Ollama, and replay adapters behind a shared trait.
Documentation
//! Local validation helpers for Ollama configuration.
//!
//! These checks are deliberately pure: they parse caller-supplied model and
//! endpoint strings without probing Ollama or making any network request.

use std::net::IpAddr;

use crate::adapter::LlmError;

/// Marker separating the model name from its pinned digest in a model ref
/// (`name@sha256:<digest>`).
const SHA256_PREFIX: &str = "@sha256:";
/// Number of hex characters in a SHA-256 digest (256 bits = 32 bytes = 64 hex chars).
const SHA256_HEX_LEN: usize = 64;

/// Configuration for an Ollama HTTP adapter that can be checked offline.
///
/// Both fields are plain caller-supplied strings; [`validate_config`] inspects
/// them without touching the network.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OllamaConfig {
    /// Base URL of the Ollama HTTP API (validated as an `http://`/`https://` loopback URL).
    pub endpoint_url: String,
    /// Model reference to request (validated as `name@sha256:<64 hex chars>`).
    pub model: String,
}

impl OllamaConfig {
    /// Create a config from any values convertible into owned strings.
    ///
    /// No validation is performed here; call [`OllamaConfig::validate`] to check it.
    #[must_use]
    pub fn new(endpoint_url: impl Into<String>, model: impl Into<String>) -> Self {
        let endpoint_url: String = endpoint_url.into();
        let model: String = model.into();
        Self {
            endpoint_url,
            model,
        }
    }

    /// Check this config offline, without contacting Ollama.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`validate_config`] reports for this config.
    pub fn validate(&self) -> Result<(), LlmError> {
        validate_config(self)
    }
}

/// Validate both fields of an [`OllamaConfig`] without any network access.
///
/// The endpoint URL is checked before the model reference, and evaluation
/// stops at the first failure.
///
/// # Errors
///
/// Returns the error from [`validate_endpoint_url`] if the endpoint is
/// rejected; otherwise the error from [`validate_model_ref`] if the model ref
/// is rejected.
pub fn validate_config(config: &OllamaConfig) -> Result<(), LlmError> {
    validate_endpoint_url(&config.endpoint_url).and_then(|()| validate_model_ref(&config.model))
}

/// Validate that `model` is pinned by exactly one SHA-256 digest.
///
/// Accepted refs have a non-empty, whitespace-free model name followed by a
/// single `@sha256:` marker and exactly sixty-four hex characters, such as
/// `llama3.1:8b@sha256:0123...abcd`.
///
/// # Errors
///
/// Returns [`LlmError::InvalidRequest`] when:
/// - the ref contains no `@sha256:` marker;
/// - the name before the marker is empty or contains whitespace;
/// - the text after the first marker is not exactly 64 ASCII hex characters.
///   This also rejects refs that carry more than one `@sha256:` marker.
pub fn validate_model_ref(model: &str) -> Result<(), LlmError> {
    // Split on the FIRST marker: any further `@sha256:` then ends up in the
    // digest and fails the length/hex check instead of hiding inside the name.
    let Some((name, digest)) = model.split_once(SHA256_PREFIX) else {
        return Err(invalid_request(format!(
            "ollama model ref must be digest-pinned with @sha256:<64 hex chars>: {model}"
        )));
    };

    if name.is_empty() {
        return Err(invalid_request(
            "ollama model ref must include a model name before @sha256".to_string(),
        ));
    }

    // Model names never contain whitespace; a stray space (e.g. from a config
    // file or env var) would otherwise pass validation and fail at request time.
    if name.chars().any(char::is_whitespace) {
        return Err(invalid_request(format!(
            "ollama model ref name must not contain whitespace: {model}"
        )));
    }

    if digest.len() != SHA256_HEX_LEN || !digest.as_bytes().iter().all(u8::is_ascii_hexdigit) {
        return Err(invalid_request(format!(
            "ollama model ref has invalid sha256 digest; expected 64 hex chars: {model}"
        )));
    }

    Ok(())
}

/// Validate that `endpoint_url` uses HTTP(S) and a loopback host.
///
/// Loopback hosts are `localhost`, any `127.0.0.0/8` address, and `::1`
/// (written `[::1]` in a URL). The scheme is matched ASCII case-insensitively,
/// as RFC 3986 §3.1 recommends, so `HTTP://localhost` is accepted. This is a
/// purely lexical check: no DNS resolution or connection is attempted.
///
/// # Errors
///
/// Returns [`LlmError::InvalidRequest`] when the scheme is not `http://` or
/// `https://`, when no host can be extracted from the authority, or when the
/// host is not loopback.
pub fn validate_endpoint_url(endpoint_url: &str) -> Result<(), LlmError> {
    let rest = ["http://", "https://"]
        .iter()
        .find_map(|&scheme| {
            // `str::get` yields `None` instead of panicking when `scheme.len()`
            // would split a multi-byte character in the input.
            let prefix = endpoint_url.get(..scheme.len())?;
            if prefix.eq_ignore_ascii_case(scheme) {
                endpoint_url.get(scheme.len()..)
            } else {
                None
            }
        })
        .ok_or_else(|| {
            invalid_request(format!(
                "ollama endpoint must use http:// or https:// loopback URL: {endpoint_url}"
            ))
        })?;

    let host = extract_host(rest).ok_or_else(|| {
        invalid_request(format!(
            "ollama endpoint must include a loopback host: {endpoint_url}"
        ))
    })?;

    if is_loopback_host(host) {
        Ok(())
    } else {
        Err(invalid_request(format!(
            "ollama endpoint host must be loopback-only; got {host}"
        )))
    }
}

/// Extract the host from the part of a URL that follows `scheme://`.
///
/// Only the authority (text before the first `/`, `?`, or `#`) is examined.
/// Bracketed IPv6 hosts are returned without their brackets. Returns `None`
/// when:
/// - the authority is empty or contains userinfo (`@`);
/// - the host is empty;
/// - an IPv6 bracket literal is malformed;
/// - the port is neither empty nor a decimal number in `0..=65535`.
fn extract_host(rest: &str) -> Option<&str> {
    // `split` always yields at least one item, so the default is never used.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or_default();
    // Userinfo is rejected outright: in `http://localhost@evil.example` the real
    // host follows the `@`, so accepting it would defeat the loopback check.
    if authority.is_empty() || authority.contains('@') {
        return None;
    }

    let (host, port) = match authority.strip_prefix('[') {
        Some(after_open) => {
            let (host, suffix) = after_open.split_once(']')?;
            // After the closing bracket only nothing or `:port` may follow.
            let port = if suffix.is_empty() {
                None
            } else {
                Some(suffix.strip_prefix(':')?)
            };
            (host, port)
        }
        // Everything after the first `:` is the port, so `host:1:2` is caught
        // by the port check below rather than silently truncated.
        None => match authority.split_once(':') {
            Some((host, port)) => (host, Some(port)),
            None => (authority, None),
        },
    };

    if host.is_empty() || port.is_some_and(|p| !is_valid_port(p)) {
        return None;
    }
    Some(host)
}

/// Whether `port` is an acceptable URL port: empty (RFC 3986 allows `host:`)
/// or all ASCII digits whose value fits in a `u16`.
fn is_valid_port(port: &str) -> bool {
    // Check the digits explicitly: `u16::from_str` alone would accept a leading `+`.
    port.is_empty() || (port.bytes().all(|b| b.is_ascii_digit()) && port.parse::<u16>().is_ok())
}

/// Whether `host` names the local machine.
///
/// Accepts the name `localhost` (ASCII case-insensitive) and any IP literal for
/// which [`IpAddr::is_loopback`] holds, i.e. `127.0.0.0/8` and `::1`. The host
/// is only parsed, never resolved, so no DNS lookup happens.
fn is_loopback_host(host: &str) -> bool {
    host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<IpAddr>(), Ok(addr) if addr.is_loopback())
}

/// Wrap a human-readable validation failure in [`LlmError::InvalidRequest`].
///
/// The validators in this file build their error values through this helper.
fn invalid_request(message: String) -> LlmError {
    LlmError::InvalidRequest(message)
}