use std::path::PathBuf;

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use serde::{Deserialize, Serialize};

// Re-export the candle crates so downstream code can use them without
// declaring its own dependency.
pub use candle_core;
pub use candle_nn;
pub use candle_transformers;

/// The kind of model described by a [`LocalMlConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Text embedding models (e.g. BGE, MiniLM).
    Embedding,
    /// Autoregressive text generation models (e.g. LLaMA-family).
    TextGeneration,
    /// Speech-to-text models (e.g. Whisper).
    SpeechRecognition,
    /// Encoder-decoder models.
    Seq2Seq,
    /// Vision models.
    Vision,
}

/// Configuration for loading and running a local model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalMlConfig {
    /// Hugging Face model identifier, e.g. "BAAI/bge-small-en-v1.5".
    pub model_id: String,
    pub model_type: ModelType,
    /// Prefer a GPU device when one is available.
    pub use_gpu: bool,
    /// Override for the model cache directory; `None` means [`default_cache_dir`].
    pub cache_dir: Option<PathBuf>,
    /// Load quantized weights when supported by the model.
    pub quantized: bool,
    /// Weight dtype as a string ("f32", "f16", "bf16", "f64"); parsed by [`dtype_from_str`].
    pub dtype: String,
}

impl Default for LocalMlConfig {
    fn default() -> Self {
        Self {
            model_id: "BAAI/bge-small-en-v1.5".to_string(),
            model_type: ModelType::Embedding,
            use_gpu: true,
            cache_dir: None,
            quantized: false,
            dtype: "f32".to_string(),
        }
    }
}

impl LocalMlConfig {
    /// Preset for the `BAAI/bge-small-en-v1.5` embedding model.
    pub fn bge_small() -> Self {
        Self {
            model_id: "BAAI/bge-small-en-v1.5".to_string(),
            model_type: ModelType::Embedding,
            ..Default::default()
        }
    }

    /// Preset for the multilingual `BAAI/bge-m3` embedding model.
    pub fn bge_m3() -> Self {
        Self {
            model_id: "BAAI/bge-m3".to_string(),
            model_type: ModelType::Embedding,
            ..Default::default()
        }
    }

    /// Preset for the `sentence-transformers/all-MiniLM-L6-v2` embedding model.
    pub fn minilm() -> Self {
        Self {
            model_id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            model_type: ModelType::Embedding,
            ..Default::default()
        }
    }

    /// Preset for a LLaMA-family text-generation model; quantized weights
    /// are enabled by default.
    pub fn llama(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            model_type: ModelType::TextGeneration,
            quantized: true,
            ..Default::default()
        }
    }

    /// Preset for an OpenAI Whisper model of the given size
    /// (e.g. "tiny", "base", "small").
    pub fn whisper(size: &str) -> Self {
        Self {
            model_id: format!("openai/whisper-{}", size),
            model_type: ModelType::SpeechRecognition,
            ..Default::default()
        }
    }
}

/// Helper for choosing a compute device at runtime.
pub struct DeviceSelector;

impl DeviceSelector {
    /// Returns the best available device: CUDA first, then Metal, then CPU.
    /// GPU backends are only probed when the corresponding cargo feature is enabled.
    pub fn best_available() -> Device {
        #[cfg(feature = "cuda")]
        {
            if let Ok(device) = Device::new_cuda(0) {
                return device;
            }
        }
        #[cfg(feature = "metal")]
        {
            if let Ok(device) = Device::new_metal(0) {
                return device;
            }
        }
        Device::Cpu
    }

    /// Always returns the CPU device.
    pub fn cpu() -> Device {
        Device::Cpu
    }

    /// Whether any GPU backend can be initialized.
    pub fn is_gpu_available() -> bool {
        #[cfg(feature = "cuda")]
        {
            if Device::new_cuda(0).is_ok() {
                return true;
            }
        }
        #[cfg(feature = "metal")]
        {
            if Device::new_metal(0).is_ok() {
                return true;
            }
        }
        false
    }

    /// Human-readable name for a device.
    pub fn device_name(device: &Device) -> &'static str {
        match device {
            Device::Cpu => "CPU",
            Device::Cuda(_) => "CUDA GPU",
            Device::Metal(_) => "Metal GPU",
        }
    }
}

/// An embedding vector together with metadata about its source text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
    pub embedding: Vec<f32>,
    /// Number of dimensions in `embedding`.
    pub dimension: usize,
    /// Short preview of the text that was embedded.
    pub source_preview: String,
}

/// Sampling parameters for text generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationParams {
    /// Maximum number of tokens to generate.
    pub max_tokens: usize,
    /// Sampling temperature; higher values give more varied output.
    pub temperature: f64,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: f64,
    /// Top-k sampling cutoff.
    pub top_k: usize,
    /// Penalty applied to repeated tokens; values above 1.0 discourage repetition.
    pub repetition_penalty: f32,
    /// Generation stops as soon as any of these sequences appears.
    pub stop_sequences: Vec<String>,
}

impl Default for GenerationParams {
    fn default() -> Self {
        Self {
            max_tokens: 512,
            temperature: 0.7,
            top_p: 0.95,
            top_k: 40,
            repetition_penalty: 1.1,
            stop_sequences: Vec::new(),
        }
    }
}

/// Normalizes an embedding to unit (L2) length in place.
/// Zero vectors are left unchanged to avoid dividing by zero.
pub fn normalize_embedding(embedding: &mut [f32]) {
    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in embedding.iter_mut() {
            *x /= norm;
        }
    }
}

/// Cosine similarity between two vectors.
/// Returns 0.0 for mismatched lengths or zero-norm inputs.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

/// Mean-pools token embeddings into a single sentence embedding, weighting
/// each token by the attention mask so that padding positions are ignored.
///
/// `token_embeddings` is expected to have shape `(batch, seq_len, hidden)`
/// and `attention_mask` shape `(batch, seq_len)`.
pub fn mean_pooling(token_embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // Expand the mask to (batch, seq_len, hidden) so it lines up with the embeddings.
    let mask = attention_mask
        .unsqueeze(2)?
        .broadcast_as(token_embeddings.shape())?;
    // Zero out padding positions, then sum over the sequence dimension.
    let masked = token_embeddings.broadcast_mul(&mask.to_dtype(token_embeddings.dtype())?)?;
    let sum = masked.sum(1)?;
    // Divide by the number of unmasked tokens.
    let count = mask.sum(1)?.to_dtype(token_embeddings.dtype())?;
    Ok(sum.broadcast_div(&count)?)
}

/// Parses a dtype string; unrecognized values fall back to `F32`.
pub fn dtype_from_str(s: &str) -> DType {
    match s.to_lowercase().as_str() {
        "f16" | "float16" => DType::F16,
        "bf16" | "bfloat16" => DType::BF16,
        "f64" | "float64" => DType::F64,
        _ => DType::F32,
    }
}

/// Default on-disk location for downloaded models: the platform cache
/// directory, under `reasonkit/models`.
pub fn default_cache_dir() -> PathBuf {
    let base = dirs::cache_dir().unwrap_or_else(|| PathBuf::from(".cache"));
    base.join("reasonkit").join("models")
}

/// Whether a model directory already exists in the cache. Slashes in the
/// model id are replaced with `--` to form the directory name.
pub fn is_model_cached(model_id: &str, cache_dir: Option<&PathBuf>) -> bool {
    let cache = cache_dir.cloned().unwrap_or_else(default_cache_dir);
    let model_path = cache.join(model_id.replace('/', "--"));
    model_path.exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = LocalMlConfig::default();
        assert_eq!(config.model_type, ModelType::Embedding);
        assert!(config.use_gpu);
    }

    #[test]
    fn test_config_presets() {
        let bge = LocalMlConfig::bge_small();
        assert!(bge.model_id.contains("bge"));
        let minilm = LocalMlConfig::minilm();
        assert!(minilm.model_id.contains("MiniLM"));
        let whisper = LocalMlConfig::whisper("small");
        assert!(whisper.model_id.contains("whisper"));
    }

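    // Sanity check that the GenerationParams defaults stay in sync with the
    // values set in its Default impl above.
    #[test]
    fn test_generation_params_default() {
        let params = GenerationParams::default();
        assert_eq!(params.max_tokens, 512);
        assert_eq!(params.top_k, 40);
        assert!(params.stop_sequences.is_empty());
    }
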
    #[test]
    fn test_normalize_embedding() {
        let mut embedding = vec![3.0, 4.0];
        normalize_embedding(&mut embedding);
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &c).abs() < 1e-6);
    }

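    // Edge cases for cosine_similarity: zero-norm inputs and mismatched
    // lengths both fall back to 0.0 instead of dividing by zero.
    #[test]
    fn test_cosine_similarity_edge_cases() {
        let zero = vec![0.0f32, 0.0];
        let unit = vec![1.0f32, 0.0];
        assert_eq!(cosine_similarity(&zero, &unit), 0.0);
        assert_eq!(cosine_similarity(&unit[..1], &zero), 0.0);
    }
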
    #[test]
    fn test_dtype_from_str() {
        assert!(matches!(dtype_from_str("f32"), DType::F32));
        assert!(matches!(dtype_from_str("f16"), DType::F16));
        assert!(matches!(dtype_from_str("bf16"), DType::BF16));
    }

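    // A minimal sketch of a mean_pooling check on the CPU backend: one batch
    // with two tokens of dimension two, where the second token is masked out,
    // so the pooled vector should equal the first token's embedding.
    #[test]
    fn test_mean_pooling_ignores_masked_tokens() -> Result<()> {
        let device = Device::Cpu;
        let token_embeddings = Tensor::new(&[[[1.0f32, 2.0], [10.0, 20.0]]], &device)?;
        let attention_mask = Tensor::new(&[[1u32, 0]], &device)?;
        let pooled = mean_pooling(&token_embeddings, &attention_mask)?;
        let values: Vec<f32> = pooled.flatten_all()?.to_vec1()?;
        assert!((values[0] - 1.0).abs() < 1e-6);
        assert!((values[1] - 2.0).abs() < 1e-6);
        Ok(())
    }
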
    #[test]
    fn test_device_cpu() {
        let device = DeviceSelector::cpu();
        assert!(matches!(device, Device::Cpu));
    }

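    // Two further smoke checks. best_available should always hand back a
    // usable device (the CPU when no GPU feature is compiled in), and the
    // cache helpers should behave predictably for a directory that does not
    // exist (assuming /nonexistent-reasonkit-cache is absent on the test
    // machine).
    #[test]
    fn test_best_available_returns_device() {
        let device = DeviceSelector::best_available();
        assert!(!DeviceSelector::device_name(&device).is_empty());
    }

    #[test]
    fn test_cache_paths() {
        assert!(default_cache_dir().ends_with("reasonkit/models"));
        let missing = PathBuf::from("/nonexistent-reasonkit-cache");
        assert!(!is_model_cached("BAAI/bge-small-en-v1.5", Some(&missing)));
    }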
}