// vectordb-cli 1.2.0
//
// A CLI tool for semantic code search.
use crate::vectordb::error::Result;
use crate::vectordb::provider::{EmbeddingProvider};
// The concrete ONNX provider type is referenced by its full path in `EmbeddingModel::new_onnx`, so it is not imported here.
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::vectordb::error::VectorDBError;
use std::fmt;

// Use the embedding dimensions from the providers
// use crate::vectordb::provider::fast::FAST_EMBEDDING_DIM;

/// Supported embedding models.
///
/// This enum derives `Serialize`/`Deserialize` without rename attributes. With serde's
/// default representation, the variant name therefore appears in any serialized form
/// (e.g. stored index metadata). Renaming a variant may break reading previously
/// written data. `Onnx` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
pub enum EmbeddingModelType {
    /// Use the ONNX model for embeddings.
    #[default]
    Onnx,
    // No specific CodeBert variant is needed: the embedding dimension is queried from
    // the provider at runtime (see `EmbeddingModel::dim`) rather than fixed per variant.
}

impl EmbeddingModelType {
    /// Returns the default embedding dimension for this model type.
    /// Used as a fallback when loading an index without an explicit dimension stored.
    pub fn default_dimension(&self) -> usize {
        match self {
            // TODO: Make this dynamically configurable or read from a default ONNX model?
            // For now, assume the default ONNX is MiniLM with 384 dims.
            EmbeddingModelType::Onnx => 384,
        }
    }
}

impl fmt::Display for EmbeddingModelType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            EmbeddingModelType::Onnx => write!(f, "ONNX"),
        }
    }
}

/// Represents an embedding model (currently ONNX-based).
///
/// The provider is held in an [`Arc`], so clones share the same provider instance;
/// only the stored paths are copied.
#[derive(Clone, Debug)]
pub struct EmbeddingModel {
    /// Backend that computes the embeddings.
    provider: Arc<dyn EmbeddingProvider + Send + Sync>,
    /// Which kind of model `provider` is.
    model_type: EmbeddingModelType,
    /// Path the ONNX model was loaded from, if this is an ONNX model.
    onnx_model_path: Option<PathBuf>,
    /// Path the ONNX tokenizer was loaded from, if this is an ONNX model.
    onnx_tokenizer_path: Option<PathBuf>,
}

impl EmbeddingModel {
    /// Creates a new ONNX-based `EmbeddingModel`.
    ///
    /// `model_path` and `tokenizer_path` are independent path-like types, so callers
    /// can mix e.g. a `&str` and a `PathBuf`.
    ///
    /// # Errors
    ///
    /// Returns [`VectorDBError::EmbeddingError`] if the ONNX provider cannot be
    /// constructed from the given files.
    pub fn new_onnx<P: AsRef<Path>, Q: AsRef<Path>>(
        model_path: P,
        tokenizer_path: Q,
    ) -> Result<Self> {
        let model_path: &Path = model_path.as_ref();
        let tokenizer_path: &Path = tokenizer_path.as_ref();

        // Full path is used because only this constructor needs the concrete type.
        let onnx_provider =
            crate::vectordb::provider::onnx::OnnxEmbeddingProvider::new(model_path, tokenizer_path)
                .map_err(|e| {
                    VectorDBError::EmbeddingError(format!("Failed to create ONNX provider: {}", e))
                })?;

        Ok(Self {
            provider: Arc::new(onnx_provider),
            model_type: EmbeddingModelType::Onnx,
            onnx_model_path: Some(model_path.to_path_buf()),
            onnx_tokenizer_path: Some(tokenizer_path.to_path_buf()),
        })
    }

    /// Get the type of the embedding model.
    pub fn model_type(&self) -> EmbeddingModelType {
        self.model_type
    }

    /// Returns the dimension of the embeddings generated by this model, as reported
    /// by the underlying provider.
    pub fn dim(&self) -> usize {
        self.provider.dimension()
    }

    /// Generates an embedding for the given text.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error, converted into this crate's error type.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.provider.embed(text).map_err(Into::into)
    }

    /// Generates embeddings for a batch of texts, delegating to the provider's
    /// batch implementation.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error, converted into this crate's error type.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.provider.embed_batch(texts).map_err(Into::into)
    }
}

#[cfg(test)]
mod tests {
    // The glob also brings in the parent module's imports (`Path`, `PathBuf`,
    // `VectorDBError`, ...), so no separate `use` lines are needed here.
    use super::*;

    /// Smoke test for the ONNX provider.
    ///
    /// Builds a model from the MiniLM files under `onnx/`, embeds a code snippet,
    /// and checks that a clone yields the same vector. Skipped when the model files
    /// are not present.
    #[test]
    fn test_onnx_embedding_fallback() {
        let model_path = Path::new("onnx/all-minilm-l12-v2.onnx");
        let tokenizer_path = Path::new("onnx/minilm_tokenizer.json");

        // The model files may not be checked out; skip rather than fail.
        if !model_path.exists() || !tokenizer_path.exists() {
            println!("Skipping ONNX test because model files aren't available");
            return;
        }

        // `expect` (unlike `assert!(result.is_ok())`) reports the underlying error.
        let model = EmbeddingModel::new_onnx(model_path, tokenizer_path)
            .expect("failed to create ONNX model from existing files");
        let expected_dim = model.dim();

        let text = "fn main() { let x = 42; }";
        let embedding = model.embed(text).expect("failed to embed text");

        assert_eq!(embedding.len(), expected_dim);
        assert!(
            embedding.iter().any(|&x| x != 0.0),
            "embedding should not be all zeros"
        );

        // Clones share the provider, so they must agree on dimension and output.
        let cloned_model = model.clone();
        assert_eq!(cloned_model.dim(), expected_dim);
        let cloned_embedding = cloned_model
            .embed(text)
            .expect("failed to embed text with cloned model");
        assert_eq!(embedding, cloned_embedding);
    }

    #[test]
    fn test_embedding_model_type_display() {
        assert_eq!(EmbeddingModelType::Onnx.to_string(), "ONNX");
    }

    // TODO: add a mock `EmbeddingProvider` so the success path of `new_onnx` and
    // `embed_batch` can be tested without real model files.

    #[test]
    fn test_embedding_model_new_onnx_invalid_path() {
        let model_path = PathBuf::from("./nonexistent/model.onnx");
        let tokenizer_path = PathBuf::from("./nonexistent/tokenizer.json");

        // Relies on the underlying ONNX provider constructor failing for missing
        // files; `new_onnx` must surface that as an `EmbeddingError`.
        let result = EmbeddingModel::new_onnx(&model_path, &tokenizer_path);
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
    }
}