// reasonkit-core 0.1.8
//
// The Reasoning Engine — Auditable Reasoning for Production AI | Rust-Native | Turn Prompts into Protocols
//! Local ML Inference with Candle
//!
//! This module provides local machine learning inference using Hugging Face's
//! Candle framework, enabling serverless ML without external API calls.
//!
//! # Features
//! - Local embedding generation (no API calls)
//! - Model loading from Hugging Face Hub
//! - GPU acceleration (CUDA/Metal) when available
//! - Support for popular models (LLaMA, Whisper, T5, BERT, etc.)
//!
//! Enable with: `cargo build --features local-ml`

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

// Re-export Candle modules for direct access
pub use candle_core;
pub use candle_nn;
pub use candle_transformers;

use candle_core::{DType, Device, Tensor};

/// Supported model types
///
/// Classifies a model by the task it performs. Used by [`LocalMlConfig`] to
/// describe what kind of inference a configured model is intended for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Text embedding models (BERT, BGE, etc.)
    Embedding,
    /// Autoregressive text generation models (LLaMA, Mistral, etc.)
    TextGeneration,
    /// Speech recognition models (Whisper)
    SpeechRecognition,
    /// Text-to-text / sequence-to-sequence models (T5, BART)
    Seq2Seq,
    /// Vision models
    Vision,
}

/// Configuration for local ML inference
///
/// The [`Default`] implementation selects the BGE small English embedding
/// model with GPU use enabled and `f32` inference; see the preset
/// constructors (`bge_small`, `llama`, `whisper`, ...) for common setups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalMlConfig {
    /// Model identifier (Hugging Face format: org/model)
    pub model_id: String,
    /// Task category of the model (embedding, generation, ...)
    pub model_type: ModelType,
    /// Use GPU if available
    pub use_gpu: bool,
    /// Cache directory for downloaded models; when `None`, consumers fall
    /// back to [`default_cache_dir`] (see [`is_model_cached`])
    pub cache_dir: Option<PathBuf>,
    /// Use quantized model if available
    pub quantized: bool,
    /// Data type for inference; see [`dtype_from_str`] for accepted values
    /// ("f16", "bf16", "f64", anything else falls back to "f32")
    pub dtype: String,
}

impl Default for LocalMlConfig {
    fn default() -> Self {
        Self {
            model_id: "BAAI/bge-small-en-v1.5".to_string(),
            model_type: ModelType::Embedding,
            use_gpu: true,
            cache_dir: None,
            quantized: false,
            dtype: "f32".to_string(),
        }
    }
}

impl LocalMlConfig {
    /// Create config for BGE embedding model
    pub fn bge_small() -> Self {
        Self {
            model_id: "BAAI/bge-small-en-v1.5".to_string(),
            model_type: ModelType::Embedding,
            ..Default::default()
        }
    }

    /// Create config for BGE-M3 multilingual embedding
    pub fn bge_m3() -> Self {
        Self {
            model_id: "BAAI/bge-m3".to_string(),
            model_type: ModelType::Embedding,
            ..Default::default()
        }
    }

    /// Create config for all-MiniLM embedding
    pub fn minilm() -> Self {
        Self {
            model_id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            model_type: ModelType::Embedding,
            ..Default::default()
        }
    }

    /// Create config for LLaMA text generation
    pub fn llama(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            model_type: ModelType::TextGeneration,
            quantized: true,
            ..Default::default()
        }
    }

    /// Create config for Whisper speech recognition
    pub fn whisper(size: &str) -> Self {
        Self {
            model_id: format!("openai/whisper-{}", size),
            model_type: ModelType::SpeechRecognition,
            ..Default::default()
        }
    }
}

/// Device selection helper
///
/// Stateless namespace for choosing and inspecting Candle compute devices.
/// GPU backends are probed only when the corresponding Cargo feature
/// (`cuda` or `metal`) was enabled at compile time.
pub struct DeviceSelector;

impl DeviceSelector {
    /// Get the best available device (GPU if available, otherwise CPU).
    ///
    /// Probes CUDA first (with the `cuda` feature), then Metal (with the
    /// `metal` feature), and falls back to the CPU when neither backend
    /// yields a usable device.
    pub fn best_available() -> Device {
        #[cfg(feature = "cuda")]
        {
            if let Ok(device) = Device::new_cuda(0) {
                return device;
            }
        }

        #[cfg(feature = "metal")]
        {
            if let Ok(device) = Device::new_metal(0) {
                return device;
            }
        }

        Device::Cpu
    }

    /// Get the CPU device.
    pub fn cpu() -> Device {
        Device::Cpu
    }

    /// Check if a GPU is available.
    ///
    /// Delegates to [`Self::best_available`] rather than duplicating the
    /// per-backend probing: a GPU is available exactly when the best device
    /// is not the CPU. This keeps the two functions from drifting apart when
    /// a new backend is added.
    pub fn is_gpu_available() -> bool {
        !matches!(Self::best_available(), Device::Cpu)
    }

    /// Get a human-readable name for `device`.
    pub fn device_name(device: &Device) -> &'static str {
        match device {
            Device::Cpu => "CPU",
            Device::Cuda(_) => "CUDA GPU",
            Device::Metal(_) => "Metal GPU",
        }
    }
}

/// Embedding result
///
/// Output of an embedding run: the vector itself plus lightweight metadata
/// for display and bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
    /// The embedding vector
    pub embedding: Vec<f32>,
    /// Embedding dimension (expected to equal `embedding.len()` — not
    /// enforced here)
    pub dimension: usize,
    /// Source text (truncated preview, not the full input)
    pub source_preview: String,
}

/// Text generation parameters
///
/// Sampling and stopping knobs for autoregressive text generation. See the
/// [`Default`] impl for the baseline values (512 tokens, temperature 0.7,
/// top-p 0.95, top-k 40, repetition penalty 1.1).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationParams {
    /// Maximum tokens to generate
    pub max_tokens: usize,
    /// Sampling temperature (0.0 - 2.0); lower values are more deterministic
    pub temperature: f64,
    /// Top-p (nucleus) sampling threshold
    pub top_p: f64,
    /// Top-k sampling cutoff
    pub top_k: usize,
    /// Repetition penalty (values > 1.0 discourage repeats)
    pub repetition_penalty: f32,
    /// Stop sequences; generation halts when one is produced
    pub stop_sequences: Vec<String>,
}

impl Default for GenerationParams {
    fn default() -> Self {
        Self {
            max_tokens: 512,
            temperature: 0.7,
            top_p: 0.95,
            top_k: 40,
            repetition_penalty: 1.1,
            stop_sequences: Vec::new(),
        }
    }
}

/// Normalize an embedding vector to unit length (L2 norm) in place.
///
/// A zero vector is left untouched, avoiding division by zero.
pub fn normalize_embedding(embedding: &mut [f32]) {
    let norm = embedding.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|x| *x /= norm);
    }
}

/// Compute cosine similarity between two embedding vectors.
///
/// Returns 0.0 when the slices differ in length or either vector has zero
/// magnitude (including empty input).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass: accumulate the dot product and both squared magnitudes.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    if sq_a > 0.0 && sq_b > 0.0 {
        dot / (sq_a.sqrt() * sq_b.sqrt())
    } else {
        0.0
    }
}

/// Mean pooling over token embeddings
///
/// Averages token embeddings along dim 1 (the sequence dimension), counting
/// only positions where `attention_mask` is non-zero, so padding tokens do
/// not dilute the pooled embedding.
///
/// Assumes `token_embeddings` has exactly one more dimension than
/// `attention_mask` (typically `(batch, seq, hidden)` vs `(batch, seq)`):
/// the mask is unsqueezed at dim 2 and broadcast to the embedding shape.
/// TODO confirm callers always pass these shapes.
///
/// NOTE(review): a batch row whose mask is entirely zero makes `count` zero
/// and the final division degenerate — verify callers never pass an all-pad
/// row.
///
/// # Errors
/// Propagates any Candle error from the unsqueeze/broadcast/reduction ops.
pub fn mean_pooling(token_embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // Expand the mask so it can be multiplied elementwise with the embeddings.
    let mask = attention_mask
        .unsqueeze(2)?
        .broadcast_as(token_embeddings.shape())?;
    // Zero out embeddings at masked (padding) positions.
    let masked = token_embeddings.broadcast_mul(&mask.to_dtype(token_embeddings.dtype())?)?;
    // Sum over the sequence dimension, then divide by the number of
    // attended tokens (mask summed over the same dimension).
    let sum = masked.sum(1)?;
    let count = mask.sum(1)?.to_dtype(token_embeddings.dtype())?;
    Ok(sum.broadcast_div(&count)?)
}

/// Parse a data-type name into a Candle [`DType`].
///
/// Recognizes "f16"/"float16", "bf16"/"bfloat16", and "f64"/"float64"
/// case-insensitively; any other input falls back to [`DType::F32`].
pub fn dtype_from_str(s: &str) -> DType {
    let key = s.to_lowercase();
    if key == "f16" || key == "float16" {
        DType::F16
    } else if key == "bf16" || key == "bfloat16" {
        DType::BF16
    } else if key == "f64" || key == "float64" {
        DType::F64
    } else {
        DType::F32
    }
}

/// Default model cache directory: `<platform cache dir>/reasonkit/models`,
/// falling back to a relative `.cache` when no platform cache dir exists.
pub fn default_cache_dir() -> PathBuf {
    dirs::cache_dir()
        .unwrap_or_else(|| PathBuf::from(".cache"))
        .join("reasonkit")
        .join("models")
}

/// Check if a model is cached locally
pub fn is_model_cached(model_id: &str, cache_dir: Option<&PathBuf>) -> bool {
    let cache = cache_dir.cloned().unwrap_or_else(default_cache_dir);
    let model_path = cache.join(model_id.replace('/', "--"));
    model_path.exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let cfg = LocalMlConfig::default();
        assert_eq!(cfg.model_type, ModelType::Embedding);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_config_presets() {
        // Each preset should embed its family name in the model id.
        assert!(LocalMlConfig::bge_small().model_id.contains("bge"));
        assert!(LocalMlConfig::minilm().model_id.contains("MiniLM"));
        assert!(LocalMlConfig::whisper("small").model_id.contains("whisper"));
    }

    #[test]
    fn test_normalize_embedding() {
        let mut v = vec![3.0, 4.0];
        normalize_embedding(&mut v);

        // A normalized vector has unit L2 length.
        let length: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity() {
        let x = [1.0, 0.0];
        // Identical vectors → similarity 1.
        assert!((cosine_similarity(&x, &[1.0, 0.0]) - 1.0).abs() < 1e-6);
        // Orthogonal vectors → similarity 0.
        assert!(cosine_similarity(&x, &[0.0, 1.0]).abs() < 1e-6);
    }

    #[test]
    fn test_dtype_from_str() {
        assert!(matches!(dtype_from_str("f32"), DType::F32));
        assert!(matches!(dtype_from_str("f16"), DType::F16));
        assert!(matches!(dtype_from_str("bf16"), DType::BF16));
    }

    #[test]
    fn test_device_cpu() {
        assert!(matches!(DeviceSelector::cpu(), Device::Cpu));
    }
}