// crtx-retrieval 0.1.1
//
// Hybrid retrieval over memory views (lexical + salience; vectors later).
// Documentation
//! Ollama-backed embedder (Phase 4.C enrichment layer).
//!
//! [`OllamaEmbedder`] calls the Ollama `/api/embeddings` REST endpoint to
//! produce real semantic vectors. It coexists alongside
//! [`super::LocalStubEmbedder`]: both can be stored simultaneously under
//! different `backend_id` values in the `memory_embeddings` side table.
//!
//! # Backend id format
//!
//! `"ollama:<model_name>:<dim>"` — e.g. `"ollama:nomic-embed-text:768"`.
//!
//! The dimension is part of the id so that a model upgrade (which typically
//! changes dimensionality) automatically creates a new backend bucket rather
//! than overwriting incomparable old vectors.
//!
//! # Synchronous HTTP
//!
//! `ureq` (already a workspace dependency) is used for the HTTP call. Ollama
//! runs locally, so latency is dominated by the model inference time rather
//! than network round-trip. The call is blocking; callers that require async
//! should wrap it in `tokio::task::spawn_blocking`.
//!
//! # Error handling
//!
//! Network errors, HTTP non-200 responses, and parse failures all surface as
//! [`EmbedError::Backend`]. A response whose vector length differs from the
//! configured dimension (including a zero-length vector) surfaces as
//! [`EmbedError::DimensionMismatch`].

use std::time::Duration;

use serde::{Deserialize, Serialize};

use super::{EmbedError, EmbedResult, Embedder};

/// Prefix for all Ollama backend ids (the leading segment of
/// `"ollama:<model>:<dim>"`).
pub const OLLAMA_BACKEND_ID_PREFIX: &str = "ollama";

/// Default Ollama endpoint used when none is configured.
///
/// This is a base URL; the `/api/embeddings` path is appended when a request
/// is made.
pub const DEFAULT_OLLAMA_ENDPOINT: &str = "http://localhost:11434";

/// Default Ollama embedding model.
pub const DEFAULT_OLLAMA_EMBED_MODEL: &str = "nomic-embed-text";

/// Default dimension for `nomic-embed-text`.
///
/// [`OllamaEmbedder::default_nomic`] passes this as the expected dimension.
/// The embedder does not discover the dimension at runtime: a response whose
/// vector length differs from the configured value is rejected with
/// [`EmbedError::DimensionMismatch`].
pub const NOMIC_EMBED_DIM: usize = 768;

/// Default HTTP timeout for embedding calls (milliseconds).
const DEFAULT_TIMEOUT_MS: u64 = 30_000;

/// Ollama `/api/embeddings` request body.
///
/// Borrows both fields so building a request does not copy the model name or
/// the prompt.
#[derive(Debug, Serialize)]
struct EmbedRequest<'a> {
    /// Name of the embedding model to run.
    model: &'a str,
    /// Text to embed.
    prompt: &'a str,
}

/// Ollama `/api/embeddings` response body.
///
/// Only the `embedding` field is read.
#[derive(Debug, Deserialize)]
struct EmbedResponse {
    /// Embedding components, deserialized as `f64`; the embedder narrows
    /// them to `f32` before returning.
    embedding: Vec<f64>,
}

/// Return `true` if `endpoint` targets a loopback address.
///
/// Accepted hosts:
///
/// * `localhost` (ASCII case-insensitive),
/// * any dotted-quad IPv4 literal in `127.0.0.0/8`,
/// * a bracketed IPv6 loopback literal such as `[::1]`.
///
/// An optional `http://` / `https://` scheme prefix, the port, and any path,
/// query, or fragment are ignored; only the host is inspected.
///
/// The host is parsed as an IP literal rather than prefix-matched, so DNS
/// names that merely *look* like loopback (`127.evil.com`,
/// `127.0.0.1.example.org`) are rejected. Endpoints carrying a userinfo
/// component (`user:pw@host`) are rejected outright: the text before `@` is
/// not the host the HTTP client connects to, so it must never satisfy the
/// check.
fn is_loopback_endpoint(endpoint: &str) -> bool {
    let without_scheme = endpoint
        .strip_prefix("https://")
        .or_else(|| endpoint.strip_prefix("http://"))
        .unwrap_or(endpoint);

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(without_scheme);

    // `localhost:pw@evil.com` would otherwise be read as host `localhost`.
    if authority.contains('@') {
        return false;
    }

    if let Some(rest) = authority.strip_prefix('[') {
        // IPv6 literal bracket form: `[::1]:11434`.
        return match rest.split_once(']') {
            Some((host, after)) => {
                // Only nothing or `:port` may follow the closing bracket.
                (after.is_empty() || after.starts_with(':'))
                    && matches!(host.parse::<std::net::Ipv6Addr>(), Ok(ip) if ip.is_loopback())
            }
            None => false,
        };
    }

    // IPv4 or hostname: drop `:port`.
    let host = authority.split(':').next().unwrap_or(authority);
    host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<std::net::Ipv4Addr>(), Ok(ip) if ip.is_loopback())
}

/// Embedder that calls Ollama's `/api/embeddings` endpoint.
///
/// Construct via [`OllamaEmbedder::new`] (endpoint + model + expected
/// dimension) or [`OllamaEmbedder::default_nomic`] (loopback,
/// `nomic-embed-text`, 768-dim).
///
/// The `backend_id` is fixed at construction time as
/// `"ollama:<model>:<dim>"`.
#[derive(Debug, Clone)]
pub struct OllamaEmbedder {
    /// Base URL of the Ollama server; `/api/embeddings` is appended per call.
    endpoint: String,
    /// Model name sent in each request and encoded into `backend_id`.
    model: String,
    /// Expected vector length; responses of any other length are rejected.
    dim: usize,
    /// Cached `"ollama:<model>:<dim>"` identifier.
    backend_id: String,
    /// HTTP timeout in milliseconds (see [`OllamaEmbedder::with_timeout_ms`]).
    timeout_ms: u64,
}

impl OllamaEmbedder {
    /// Construct an embedder targeting `endpoint` (e.g.
    /// `"http://localhost:11434"`) with `model` (e.g. `"nomic-embed-text"`)
    /// and expected output dimensionality `dim`.
    ///
    /// The HTTP timeout starts at the module default and can be changed with
    /// [`OllamaEmbedder::with_timeout_ms`].
    ///
    /// # Errors
    ///
    /// Returns [`EmbedError::InvalidInput`] if:
    ///
    /// * `endpoint` is empty or whitespace-only,
    /// * `model` is empty or whitespace-only,
    /// * `dim` is zero,
    /// * `endpoint` does not point at a loopback host.
    pub fn new(
        endpoint: impl Into<String>,
        model: impl Into<String>,
        dim: usize,
    ) -> EmbedResult<Self> {
        let endpoint = endpoint.into();
        let model = model.into();

        if endpoint.trim().is_empty() {
            return Err(EmbedError::InvalidInput(
                "OllamaEmbedder: endpoint must not be empty".to_string(),
            ));
        }
        if model.trim().is_empty() {
            return Err(EmbedError::InvalidInput(
                "OllamaEmbedder: model must not be empty".to_string(),
            ));
        }
        if dim == 0 {
            return Err(EmbedError::InvalidInput(
                "OllamaEmbedder: dim must be > 0".to_string(),
            ));
        }

        // Enforce loopback-only: the endpoint hostname must be localhost,
        // 127.0.0.1, or ::1. This mirrors the guardrail in CLAUDE.md §Ollama.
        if !is_loopback_endpoint(&endpoint) {
            return Err(EmbedError::InvalidInput(format!(
                "OllamaEmbedder: endpoint must be loopback-only (localhost/127.0.0.1/::1), got `{endpoint}`"
            )));
        }

        // Single source of truth for the id format, so an instance and
        // `backend_id_for` can never disagree.
        let backend_id = Self::backend_id_for(&model, dim);
        Ok(Self {
            endpoint,
            model,
            dim,
            backend_id,
            timeout_ms: DEFAULT_TIMEOUT_MS,
        })
    }

    /// Convenience constructor: loopback Ollama, `nomic-embed-text`, 768 dim.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`OllamaEmbedder::new`].
    pub fn default_nomic() -> EmbedResult<Self> {
        Self::new(
            DEFAULT_OLLAMA_ENDPOINT,
            DEFAULT_OLLAMA_EMBED_MODEL,
            NOMIC_EMBED_DIM,
        )
    }

    /// Override the HTTP timeout (milliseconds). Default is 30 000.
    #[must_use]
    pub fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        self
    }

    /// Return the backend id (`"ollama:<model>:<dim>"`) an embedder built
    /// with `model` and `dim` would report, without constructing a full
    /// instance. Useful for querying the store before creating the embedder.
    ///
    /// No validation is performed on `model` or `dim`.
    #[must_use]
    pub fn backend_id_for(model: &str, dim: usize) -> String {
        format!("{OLLAMA_BACKEND_ID_PREFIX}:{model}:{dim}")
    }
}

impl Embedder for OllamaEmbedder {
    /// Identifier of the form `"ollama:<model>:<dim>"`, fixed at construction.
    fn backend_id(&self) -> &str {
        &self.backend_id
    }

    /// Expected length of every vector returned by [`Embedder::embed`].
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embed `text` (optionally enriched with `tags`) via a blocking POST to
    /// `<endpoint>/api/embeddings`.
    ///
    /// # Errors
    ///
    /// * [`EmbedError::Backend`] for transport failures, a non-200 status, an
    ///   unreadable body, or a body that does not parse as the expected JSON.
    /// * [`EmbedError::DimensionMismatch`] when the returned vector's length
    ///   differs from the configured dimension.
    fn embed(&self, text: &str, tags: &[String]) -> EmbedResult<Vec<f32>> {
        // Build the prompt: concatenate claim text + tags with a separator so
        // tags influence the embedding without polluting the main claim signal.
        let prompt = if tags.is_empty() {
            text.to_string()
        } else {
            format!("{text} | {}", tags.join(" "))
        };

        // Tolerate a configured endpoint that ends in `/` (a common way to
        // write a base URL); otherwise the request path would start with `//`.
        let url = format!("{}/api/embeddings", self.endpoint.trim_end_matches('/'));

        let body = EmbedRequest {
            model: &self.model,
            prompt: &prompt,
        };

        let timeout = Duration::from_millis(self.timeout_ms);
        let agent = ureq::AgentBuilder::new().timeout(timeout).build();

        let body_json = serde_json::to_value(&body)
            .map_err(|e| EmbedError::Backend(format!("request serialization failed: {e}")))?;

        let response = agent
            .post(&url)
            .send_json(body_json)
            .map_err(|err| EmbedError::Backend(format!("Ollama HTTP error: {err}")))?;

        // NOTE(review): ureq 2.x is understood to surface 4xx/5xx as `Err`
        // above, so this guard covers the remaining non-200 successes —
        // confirm against the ureq version pinned in the workspace.
        if response.status() != 200 {
            let status = response.status();
            return Err(EmbedError::Backend(format!(
                "Ollama returned HTTP {status}"
            )));
        }

        let response_text = response
            .into_string()
            .map_err(|e| EmbedError::Backend(format!("reading Ollama response body: {e}")))?;

        let parsed: EmbedResponse = serde_json::from_str(&response_text)
            .map_err(|e| EmbedError::Backend(format!("Ollama response parse: {e}")))?;

        // Intentional lossy narrowing: the trait contract returns `f32`.
        let vector: Vec<f32> = parsed.embedding.iter().map(|&v| v as f32).collect();

        if vector.len() != self.dim {
            return Err(EmbedError::DimensionMismatch {
                backend_id: self.backend_id.clone(),
                expected: self.dim,
                actual: vector.len(),
            });
        }

        Ok(vector)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that constructing an embedder with the given arguments fails
    /// with `EmbedError::InvalidInput`.
    fn expect_invalid_input(endpoint: &str, model: &str, dim: usize) {
        let err = OllamaEmbedder::new(endpoint, model, dim).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput for endpoint `{endpoint}`, got {err:?}"
        );
    }

    #[test]
    fn constructor_rejects_empty_endpoint() {
        expect_invalid_input("", "nomic-embed-text", 768);
    }

    #[test]
    fn constructor_rejects_whitespace_only_endpoint() {
        expect_invalid_input("   ", "nomic-embed-text", 768);
    }

    #[test]
    fn constructor_rejects_empty_model() {
        expect_invalid_input("http://localhost:11434", "", 768);
    }

    #[test]
    fn constructor_rejects_zero_dim() {
        expect_invalid_input("http://localhost:11434", "nomic-embed-text", 0);
    }

    // The loopback-only guardrail is security-relevant, so pin both sides.
    #[test]
    fn constructor_rejects_non_loopback_endpoint() {
        expect_invalid_input("http://example.com:11434", "nomic-embed-text", 768);
        expect_invalid_input("http://192.168.1.10:11434", "nomic-embed-text", 768);
    }

    #[test]
    fn constructor_accepts_loopback_variants() {
        for endpoint in [
            "http://127.0.0.1:11434",
            "http://[::1]:11434",
            "https://LOCALHOST:11434",
        ] {
            assert!(
                OllamaEmbedder::new(endpoint, "nomic-embed-text", 768).is_ok(),
                "expected `{endpoint}` to be accepted as loopback"
            );
        }
    }

    #[test]
    fn backend_id_encodes_model_and_dim() {
        let e = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 768).unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), 768);
    }

    #[test]
    fn backend_id_for_matches_instance() {
        let id = OllamaEmbedder::backend_id_for("nomic-embed-text", 768);
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(id, e.backend_id());
    }

    #[test]
    fn default_nomic_has_expected_backend_id() {
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), NOMIC_EMBED_DIM);
    }

    #[test]
    fn with_timeout_ms_overrides_default() {
        let e = OllamaEmbedder::default_nomic()
            .unwrap()
            .with_timeout_ms(5_000);
        assert_eq!(e.timeout_ms, 5_000);
    }
}