vectordb-cli 1.6.0

A CLI tool for semantic code search.
use crate::vectordb::error::Result;
use crate::vectordb::provider::{EmbeddingProvider};
// Explicitly import the concrete provider type
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::vectordb::error::VectorDBError;
use std::fmt;
use crate::config::AppConfig;
#[cfg(feature = "ort")]
use crate::vectordb::provider::onnx::OnnxEmbeddingModel;

// Use the embedding dimensions from the providers
// use crate::vectordb::provider::fast::FAST_EMBEDDING_DIM;

// Type alias for embeddings (vectors of f32); currently disabled.
// pub type Embedding = Vec<f32>; // Keep commented out or remove if confirmed unused

/// Enum representing the type of embedding model to use.
///
/// The derived `Default` implementation yields [`EmbeddingModelType::Default`]
/// (the variant tagged `#[default]`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EmbeddingModelType {
    /// Stand-in for "the default model"; this is what
    /// `EmbeddingModelType::default()` returns.
    #[default]
    Default, // Represents the default model
    /// ONNX-backed embedding model.
    Onnx,
    // Add other model types here in the future (e.g., SentenceTransformers)
}

impl EmbeddingModelType {
    /// Get the expected dimension for the model type.
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModelType::Onnx => 384,
            EmbeddingModelType::Default => 384, // Use default dimension for now
            // Add other model types here
        }
    }
}

impl fmt::Display for EmbeddingModelType {
    /// Writes the human-readable label of the variant: `"ONNX"` or `"Default"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            EmbeddingModelType::Default => "Default",
            EmbeddingModelType::Onnx => "ONNX",
        };
        f.write_str(label)
    }
}

/// Represents an embedding model (currently ONNX-based).
///
/// Bundles a shared [`EmbeddingProvider`] trait object with the model type
/// and the file paths recorded when the model was constructed.
#[derive(Clone, Debug)]
pub struct EmbeddingModel {
    /// Provider that produces the embeddings. Held in an `Arc`, so cloning
    /// the model shares the same provider instance.
    provider: Arc<dyn EmbeddingProvider + Send + Sync>,
    /// Kind of model this instance wraps; exposed through `model_type()`.
    model_type: EmbeddingModelType,
    /// Path of the ONNX model file recorded at construction, if any.
    onnx_model_path: Option<PathBuf>,
    /// Path of the tokenizer file recorded at construction, if any.
    onnx_tokenizer_path: Option<PathBuf>,
}

impl EmbeddingModel {
    /// Creates a new ONNX-based EmbeddingModel.
    pub fn new_onnx<P: AsRef<Path>>(model_path: P, tokenizer_path: P) -> Result<Self> {
        // Use full path for the provider constructor
        let onnx_provider = crate::vectordb::provider::onnx::OnnxEmbeddingModel::new(
            model_path.as_ref(), 
            tokenizer_path.as_ref()
        ).map_err(|e| VectorDBError::EmbeddingError(format!("Failed to create ONNX provider: {}", e)))?; // Explicitly map error
        
        Ok(Self {
            provider: Arc::new(onnx_provider),
            model_type: EmbeddingModelType::Onnx,
            onnx_model_path: Some(model_path.as_ref().to_path_buf()),
            onnx_tokenizer_path: Some(tokenizer_path.as_ref().to_path_buf()),
        })
    }

    /// Get the type of the embedding model.
    pub fn model_type(&self) -> EmbeddingModelType {
        self.model_type
    }

    /// Get the dimensions of the embeddings generated by this model.
    pub fn dim(&self) -> usize {
        let provider_ref: &dyn EmbeddingProvider = self.provider.as_ref();
        provider_ref.dimension()
    }

    /// Generates embeddings for a batch of texts.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let provider_ref: &dyn EmbeddingProvider = self.provider.as_ref();
        provider_ref.embed_batch(texts).map_err(Into::into)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf; // Needed by the invalid-path test

    // History: tests that exercised the removed default FastText model
    // (test_fast_embedding, test_embedding_batch, test_model_cloning) and the
    // never-implemented `FromStr` conversion were dropped. Deeper coverage of
    // `EmbeddingModel` (a valid-path constructor test, embedding generation)
    // still needs a mock/dummy provider.

    /// End-to-end check against real ONNX files.
    ///
    /// Returns early (and therefore passes) when the model files are absent.
    #[test]
    fn test_onnx_embedding_fallback() {
        let model_path = Path::new("onnx/all-minilm-l12-v2.onnx");
        let tokenizer_path = Path::new("onnx/minilm_tokenizer.json");

        // Skip test if ONNX files don't exist
        if !model_path.exists() || !tokenizer_path.exists() {
            println!("Skipping ONNX test because model files aren't available");
            return;
        }

        // Create ONNX model
        let model = EmbeddingModel::new_onnx(model_path, tokenizer_path)
            .expect("ONNX model should load when both files exist");
        let expected_dim = model.dim(); // Get dimension from model

        // Test embedding
        let text = "fn main() { let x = 42; }";
        let embedding = model.embed_batch(&[text]).unwrap().into_iter().next().unwrap();

        assert_eq!(embedding.len(), expected_dim); // Check against model's dimension
        assert!(!embedding.iter().all(|&x| x == 0.0));

        // Test cloning: a clone must report the same dimension and produce
        // the same embedding for the same input.
        let cloned_model = model.clone();
        assert_eq!(cloned_model.dim(), expected_dim);
        let cloned_embedding = cloned_model.embed_batch(&[text]).unwrap().into_iter().next().unwrap();
        assert_eq!(embedding, cloned_embedding);
    }

    /// Every variant must have a stable display label.
    #[test]
    fn test_embedding_model_type_display() {
        assert_eq!(EmbeddingModelType::Onnx.to_string(), "ONNX");
        assert_eq!(EmbeddingModelType::Default.to_string(), "Default");
    }

    /// Pins the derived default variant and the per-variant dimensions.
    #[test]
    fn test_embedding_model_type_default_and_dimension() {
        assert_eq!(EmbeddingModelType::default(), EmbeddingModelType::Default);
        assert_eq!(EmbeddingModelType::Onnx.dimension(), 384);
        assert_eq!(EmbeddingModelType::Default.dimension(), 384);
    }

    // --- EmbeddingModel Tests (Error paths only for now) ---

    #[test]
    fn test_embedding_model_new_onnx_invalid_path() {
        // Use paths known not to exist
        let model_path = PathBuf::from("./nonexistent/model.onnx");
        let tokenizer_path = PathBuf::from("./nonexistent/tokenizer.json");

        // This check relies on the underlying `OnnxEmbeddingModel::new` failing.
        // We expect an EmbeddingError wrapping the provider's error.
        let result = EmbeddingModel::new_onnx(&model_path, &tokenizer_path);
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        // We can't easily assert the inner error message without a real provider error
    }
}

/// Builds the shared embedding provider described by `config`.
///
/// The model type is currently fixed to ONNX; `Default` is routed to the same
/// ONNX path.
///
/// # Errors
///
/// - `VectorDBError::ConfigurationError` if `config.onnx_model_path` or
///   `config.onnx_tokenizer_path` is not set.
/// - The provider's construction error, converted with `VectorDBError::from`.
/// - `VectorDBError::FeatureNotEnabled("ort")` when the crate is compiled
///   without the `ort` feature.
pub fn initialize_provider(
    config: &AppConfig,
) -> std::result::Result<Arc<dyn EmbeddingProvider + Send + Sync>, VectorDBError> {
    // Determine model type - currently only supports ONNX/Default
    // In the future, this could check a field like `config.embedding_model.model_type`
    let model_type = EmbeddingModelType::Onnx; // Assume ONNX for now

    match model_type {
        EmbeddingModelType::Default | EmbeddingModelType::Onnx => {
            #[cfg(feature = "ort")]
            {
                // Both paths are mandatory for the ONNX provider.
                let model_path = config.onnx_model_path.as_deref().ok_or_else(|| {
                    VectorDBError::ConfigurationError(
                        "ONNX model path not set in AppConfig".to_string(),
                    )
                })?;
                let tokenizer_path = config.onnx_tokenizer_path.as_deref().ok_or_else(|| {
                    VectorDBError::ConfigurationError(
                        "ONNX tokenizer path not set in AppConfig".to_string(),
                    )
                })?;

                // Convert the provider's error explicitly, then propagate it.
                let provider =
                    OnnxEmbeddingModel::new(Path::new(model_path), Path::new(tokenizer_path))
                        .map_err(VectorDBError::from)?;
                Ok(Arc::new(provider))
            }
            #[cfg(not(feature = "ort"))]
            {
                // `config` is only read by the ONNX path; touch it here so a
                // build without `ort` does not emit an unused-variable warning.
                let _ = config;
                Err(VectorDBError::FeatureNotEnabled("ort".to_string()))
            }
        }
        // Handle other model types if necessary
    }
}