agent-source-repository 0.1.0

Agent Source Repository: a local context registry for coding agents
Documentation
use std::path::Path;

use anyhow::{bail, Context, Result};
use ndarray::{Array1, Array2, Axis};

/// Model repository used when neither an explicit name nor the environment variable
/// supplies one (see `resolve_model_name`).
const DEFAULT_MODEL_NAME: &str = "minishlab/potion-code-16M";
/// Name of the environment variable that `StaticEncoder::load` reads for a model override.
const MODEL_ENV_VAR: &str = "RUMBLE_MODEL";

/// Text encoder backed by a tokenizer and a fixed table of per-token embedding rows.
///
/// Text is encoded by averaging the table rows of its token ids and scaling the result to
/// unit length (see `encode_single`).
pub struct StaticEncoder {
    /// Tokenizer used to turn input text into token ids.
    tokenizer: tokenizers::Tokenizer,
    /// Embedding table with `vocab_size` rows and `dim` columns, stored as `f32`.
    embeddings: Array2<f32>,
    /// Number of rows in `embeddings`; token ids at or above this value are skipped.
    vocab_size: usize,
    /// Number of columns in `embeddings`, and the length of every encoded vector.
    dim: usize,
}

impl StaticEncoder {
    /// Fetches `tokenizer.json` and `model.safetensors` through the HuggingFace Hub API and
    /// builds an encoder from them.
    ///
    /// The repository is `model_name` when given, otherwise the value of the environment
    /// variable named by `MODEL_ENV_VAR`, otherwise `DEFAULT_MODEL_NAME`.
    ///
    /// # Errors
    ///
    /// Fails immediately, before any Hub access, when `HF_HUB_OFFLINE` requests offline
    /// mode. Also fails when the Hub client cannot be created, when either file cannot be
    /// obtained, or when `from_files` rejects the obtained files.
    pub fn load(model_name: Option<&str>) -> Result<Self> {
        // Offline mode is checked first so that no Hub client is even constructed.
        if hf_hub_offline_enabled() {
            bail!("semantic model not available: offline mode (HF_HUB_OFFLINE=1)");
        }

        let env_override = std::env::var(MODEL_ENV_VAR).ok();
        let repo_id = resolve_model_name(model_name, env_override.as_deref()).to_string();

        let hub = hf_hub::api::sync::ApiBuilder::from_env()
            .build()
            .context("Failed to create HuggingFace Hub API")?;
        let repo = hub.model(repo_id);

        let tokenizer_file = repo
            .get("tokenizer.json")
            .context("Failed to download tokenizer.json")?;
        let weights_file = repo
            .get("model.safetensors")
            .context("Failed to download model.safetensors")?;

        Self::from_files(&tokenizer_file, &weights_file)
    }

    /// Builds an encoder from a tokenizer file and a safetensors weights file on disk.
    ///
    /// The embedding table is the first tensor found under the names `embeddings`,
    /// `static_embeddings` or `embedding.weight` (tried in that order). If none of those
    /// exist, the first two-dimensional tensor listed by the file is used instead.
    /// `F32`, `F16` and `BF16` tensors are accepted; values are read as little-endian and
    /// stored as `f32`.
    ///
    /// # Errors
    ///
    /// Fails when the tokenizer cannot be loaded, the model file cannot be read or
    /// deserialized, no embedding tensor is found, the chosen tensor is not 2D, its dtype is
    /// unsupported, or the decoded data does not fit the tensor's declared shape.
    pub fn from_files(tokenizer_path: &Path, model_path: &Path) -> Result<Self> {
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;

        let raw_model = std::fs::read(model_path).context("Failed to read model file")?;
        let tensors = safetensors::SafeTensors::deserialize(&raw_model)
            .map_err(|e| anyhow::anyhow!("Failed to deserialize safetensors: {e}"))?;
        let tensor_names: Vec<String> = tensors.names().into_iter().map(String::from).collect();

        // Pass 1: well-known tensor names, in priority order.
        let mut located = None;
        for candidate in ["embeddings", "static_embeddings", "embedding.weight"].iter() {
            if let Ok(view) = tensors.tensor(candidate) {
                located = Some(view);
                break;
            }
        }
        // Pass 2: fall back to the first tensor that is two-dimensional.
        if located.is_none() {
            for candidate in &tensor_names {
                match tensors.tensor(candidate) {
                    Ok(view) if view.shape().len() == 2 => {
                        located = Some(view);
                        break;
                    }
                    _ => {}
                }
            }
        }
        let emb = match located {
            Some(view) => view,
            None => bail!("No embedding tensor found in model. Tensors: {tensor_names:?}"),
        };

        // A tensor matched by name in pass 1 has not been shape-checked yet.
        let (vocab_size, dim) = match emb.shape() {
            &[rows, cols] => (rows, cols),
            other => bail!("Expected 2D embedding tensor, got {}D", other.len()),
        };

        let bytes = emb.data();
        let half_bits = |pair: &[u8]| u16::from_le_bytes([pair[0], pair[1]]);
        let embedding_data: Vec<f32> = match emb.dtype() {
            safetensors::Dtype::F32 => bytes
                .chunks_exact(4)
                .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
                .collect(),
            safetensors::Dtype::F16 => bytes
                .chunks_exact(2)
                .map(|pair| half::f16::from_bits(half_bits(pair)).to_f32())
                .collect(),
            safetensors::Dtype::BF16 => bytes
                .chunks_exact(2)
                .map(|pair| half::bf16::from_bits(half_bits(pair)).to_f32())
                .collect(),
            dt => bail!("Unsupported embedding dtype: {dt:?}"),
        };

        // `from_shape_vec` rejects a value count that does not match `vocab_size * dim`.
        let embeddings = Array2::from_shape_vec((vocab_size, dim), embedding_data)
            .context("Failed to reshape embedding tensor")?;

        Ok(Self {
            tokenizer,
            embeddings,
            vocab_size,
            dim,
        })
    }

    /// Encodes one text as the unit-length mean of its token embeddings.
    ///
    /// Token ids that fall outside the embedding table are skipped. When no usable token
    /// remains, or the mean has a norm of at most `1e-12`, the vector is returned without
    /// being scaled (an all-zero vector in the no-token case).
    ///
    /// # Errors
    ///
    /// Fails when the tokenizer rejects `text`.
    pub fn encode_single(&self, text: &str) -> Result<Array1<f32>> {
        // NOTE(review): the `false` flag is presumably "add special tokens" — confirm
        // against the tokenizers API.
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {e}"))?;

        let mut pooled = Array1::<f32>::zeros(self.dim);
        let mut used = 0usize;
        let in_vocab = encoding
            .get_ids()
            .iter()
            .map(|&token| token as usize)
            .filter(|&row| row < self.vocab_size);
        for row in in_vocab {
            pooled += &self.embeddings.row(row);
            used += 1;
        }

        // Mean pooling; skipped when nothing was accumulated to avoid dividing by zero.
        if used != 0 {
            pooled /= used as f32;
        }

        let magnitude = pooled.dot(&pooled).sqrt();
        if magnitude > 1e-12 {
            pooled /= magnitude;
        }

        Ok(pooled)
    }

    pub fn encode_batch<S: AsRef<str>>(&self, texts: &[S]) -> Result<Array2<f32>> {
        if texts.is_empty() {
            return Ok(Array2::zeros((0, self.dim)));
        }
        let mut result = Array2::zeros((texts.len(), self.dim));
        for (i, text) in texts.iter().enumerate() {
            let embedding = self.encode_single(text.as_ref())?;
            result.row_mut(i).assign(&embedding);
        }
        Ok(result)
    }
}

/// Picks the model repository name to load.
///
/// Priority: `explicit`, then `configured` (the environment override), then
/// `DEFAULT_MODEL_NAME`. A candidate that is empty or whitespace-only counts as absent, so
/// an environment variable exported with an empty value does not become the repository
/// name. Surrounding whitespace is trimmed from the chosen name.
fn resolve_model_name<'a>(explicit: Option<&'a str>, configured: Option<&'a str>) -> &'a str {
    /// Trims the candidate and discards it when nothing is left.
    fn non_blank(candidate: Option<&str>) -> Option<&str> {
        candidate.map(str::trim).filter(|name| !name.is_empty())
    }

    non_blank(explicit)
        .or_else(|| non_blank(configured))
        .unwrap_or(DEFAULT_MODEL_NAME)
}

/// Reports whether the `HF_HUB_OFFLINE` environment variable requests offline mode.
///
/// After trimming surrounding whitespace, the value is compared case-insensitively against
/// `1`, `on`, `yes` and `true`, the spellings the Hugging Face tooling treats as true.
/// An unset variable, a non-Unicode value, or any other value yields `false`.
fn hf_hub_offline_enabled() -> bool {
    const TRUTHY: [&str; 4] = ["1", "on", "yes", "true"];

    std::env::var("HF_HUB_OFFLINE")
        .map(|raw| {
            let flag = raw.trim();
            TRUTHY
                .iter()
                .any(|accepted| flag.eq_ignore_ascii_case(accepted))
        })
        .unwrap_or(false)
}

/// In-memory index of embedding rows that answers nearest-neighbour queries by dot product.
pub struct SemanticIndex {
    /// One embedding per row; `new` scales each row to unit length unless its norm is at
    /// most `1e-12`.
    embeddings: Array2<f32>,
}

impl SemanticIndex {
    /// Builds an index from a matrix holding one embedding per row.
    ///
    /// Each row is scaled to unit Euclidean length in place. Rows whose norm is at most
    /// `1e-12` are left untouched so that (near-)zero rows are not divided by ~0.
    pub fn new(mut embeddings: Array2<f32>) -> Self {
        embeddings.axis_iter_mut(Axis(0)).for_each(|mut row| {
            let squared: f32 = row.iter().map(|v| v * v).sum::<f32>();
            let length = squared.sqrt();
            if length > 1e-12 {
                row.iter_mut().for_each(|v| *v = *v / length);
            }
        });
        Self { embeddings }
    }

    /// Returns up to `k` indexed rows closest to `query_embedding`, nearest first.
    ///
    /// Each result is `(row_index, distance)` with `distance = 1.0 - dot(row, q)`, where `q`
    /// is the query scaled to unit length (used as-is when its norm is at most `1e-12`).
    /// For the unit-length rows produced by `new`, this is the cosine distance.
    ///
    /// With `selector`, only the listed row indices are scored; indices past the last row
    /// are ignored and duplicates are scored once per occurrence. Without it, every row is
    /// scored.
    ///
    /// Distances are ordered with a total comparison that places NaN after every real
    /// value. A NaN (for example from non-finite embedding values) therefore can neither
    /// outrank a real match nor violate the ordering contract of the sort routines.
    ///
    /// NOTE(review): the `dot` calls are expected to panic when the query length differs
    /// from the row length — confirm against the ndarray documentation.
    pub fn query(
        &self,
        query_embedding: &Array1<f32>,
        k: usize,
        selector: Option<&[usize]>,
    ) -> Vec<(usize, f32)> {
        /// Total order on distances: real values ascending, NaN last.
        fn by_distance(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
            a.1.is_nan()
                .cmp(&b.1.is_nan())
                .then_with(|| a.1.total_cmp(&b.1))
        }

        let norm: f32 = query_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        let query_norm = if norm > 1e-12 {
            query_embedding.mapv(|x| x / norm)
        } else {
            query_embedding.clone()
        };

        if let Some(selector) = selector {
            let mut dists: Vec<(usize, f32)> = selector
                .iter()
                .filter(|&&idx| idx < self.embeddings.nrows())
                .map(|&idx| {
                    let sim: f32 = self.embeddings.row(idx).dot(&query_norm);
                    (idx, 1.0 - sim)
                })
                .collect();
            // Stable sort: equally distant rows keep their order from `selector`.
            dists.sort_by(by_distance);
            dists.truncate(k);
            dists
        } else {
            let similarities = self.embeddings.dot(&query_norm);
            let mut dists: Vec<(usize, f32)> = similarities
                .iter()
                .enumerate()
                .map(|(idx, &sim)| (idx, 1.0 - sim))
                .collect();
            // Partition so the k nearest come first, then sort only those k.
            if k < dists.len() {
                dists.select_nth_unstable_by(k, by_distance);
                dists.truncate(k);
            }
            dists.sort_by(by_distance);
            dists
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Serializes tests that mutate process-wide environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// `load` must fail fast in offline mode, before touching the Hub.
    #[test]
    fn load_honors_hf_hub_offline_without_network_or_cache_lookup() {
        let _guard = ENV_LOCK.lock().expect("env lock should not be poisoned");
        let previous = std::env::var("HF_HUB_OFFLINE").ok();
        std::env::set_var("HF_HUB_OFFLINE", "1");

        let outcome = StaticEncoder::load(Some("definitely/not-needed"));

        // Restore the variable before anything below can panic, so that a failing
        // assertion does not leak `HF_HUB_OFFLINE=1` into other tests in this process.
        match previous {
            Some(value) => std::env::set_var("HF_HUB_OFFLINE", value),
            None => std::env::remove_var("HF_HUB_OFFLINE"),
        }

        let err = match outcome {
            Ok(_) => panic!("offline mode should stop before hub access"),
            Err(err) => err,
        };

        let message = err.to_string();
        assert!(
            message.contains("offline mode"),
            "error should explain offline semantic model state: {message}"
        );
    }

    /// With no explicit or configured name, the built-in default is used.
    #[test]
    fn model_name_defaults_to_builtin_model() {
        assert_eq!(resolve_model_name(None, None), DEFAULT_MODEL_NAME);
    }

    /// The configured (environment) name is used when no explicit name is given.
    #[test]
    fn model_name_uses_environment_when_no_explicit_model_is_set() {
        assert_eq!(
            resolve_model_name(None, Some("example/custom-static-encoder")),
            "example/custom-static-encoder"
        );
    }

    /// An explicit name takes priority over the configured one.
    #[test]
    fn explicit_model_name_overrides_environment() {
        assert_eq!(
            resolve_model_name(Some("example/explicit"), Some("example/env")),
            "example/explicit"
        );
    }
}