use candle_core::{Device, Tensor};
use std::collections::HashMap;
use std::path::Path;
/// Model families that can be selected by name.
///
/// Use [`CandleModelType::from_str`] (or `str::parse`) to obtain a variant from
/// a textual name; matching is case-insensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CandleModelType {
    /// Selected by the name `"whisper"`.
    Whisper,
    /// Selected by the names `"llama"`, `"llama2"` and `"llama3"`.
    LLaMA,
    /// Selected by the name `"bert"`.
    Bert,
    /// Fallback used for every name that is not recognised.
    Generic,
}

impl CandleModelType {
    /// Recognised (lower-case) names and the variant each one selects.
    const NAMES: [(&'static str, Self); 5] = [
        ("whisper", Self::Whisper),
        ("llama", Self::LLaMA),
        ("llama2", Self::LLaMA),
        ("llama3", Self::LLaMA),
        ("bert", Self::Bert),
    ];

    /// Maps a model name to its [`CandleModelType`], ignoring letter case.
    ///
    /// Unrecognised names (including the empty string) yield
    /// [`CandleModelType::Generic`]; this function never fails.
    pub fn from_str(s: &str) -> Self {
        // ASCII-only case folding avoids allocating a lowered copy of `s`.
        // It accepts exactly the same inputs as full Unicode lowercasing here:
        // the only non-ASCII character that lowercases to an ASCII letter is
        // KELVIN SIGN (-> `k`), and no recognised name contains `k`.
        Self::NAMES
            .iter()
            .find(|&&(name, _)| s.eq_ignore_ascii_case(name))
            .map_or(Self::Generic, |&(_, kind)| kind)
    }
}

/// Enables `"whisper".parse::<CandleModelType>()`.
///
/// Parsing is infallible: unknown names become [`CandleModelType::Generic`],
/// exactly as with the inherent [`CandleModelType::from_str`].
impl std::str::FromStr for CandleModelType {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Path lookup prefers the inherent associated function over this trait
        // method, so this delegates rather than recursing.
        Ok(CandleModelType::from_str(s))
    }
}
/// Convenience alias for results whose error type is [`ModelError`].
pub type ModelResult<T> = Result<T, ModelError>;
/// Errors reported while loading or running a model.
///
/// The `Display` text of each variant is the `#[error(...)]` message with the
/// payload substituted for `{0}`.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
/// The model could not be loaded; the payload describes the cause.
#[error("Failed to load model: {0}")]
LoadFailed(String),
/// Running the model failed; the payload describes the cause.
#[error("Inference failed: {0}")]
InferenceFailed(String),
/// The supplied inputs were not what the model expects.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// The requested model type or operation is not available.
#[error("Unsupported operation: {0}")]
Unsupported(String),
/// An error raised by `candle_core`; `#[from]` lets `?` convert it automatically.
#[error("Candle error: {0}")]
Candle(#[from] candle_core::Error),
}
/// Common interface for models that exchange tensors keyed by name.
///
/// Implementors must be `Send`.
pub trait CandleModel: Send {
/// Returns the [`CandleModelType`] of this model.
fn model_type(&self) -> CandleModelType;
/// Returns the device associated with this model.
fn device(&self) -> &Device;
/// Runs the model on `inputs` (tensors keyed by name) and returns the
/// produced tensors, also keyed by name.
///
/// Takes `&mut self`, so an implementation may update its own state while running.
fn run(&mut self, inputs: HashMap<String, Tensor>) -> ModelResult<HashMap<String, Tensor>>;
/// Names of the input tensors; presumably the keys `run` looks up in `inputs`
/// (the Whisper wrapper in this file reports `"mel"`).
fn input_names(&self) -> Vec<&str>;
/// Names of the output tensors; presumably the keys of the map returned by `run`
/// (the Whisper wrapper in this file reports `"text"`).
fn output_names(&self) -> Vec<&str>;
}
/// Loads a model of the requested type and returns it behind the
/// [`CandleModel`] trait object.
///
/// Only [`CandleModelType::Whisper`] can currently be loaded; it is built from
/// `model_path` on `device` using the default `WhisperConfig`.
///
/// # Errors
///
/// * [`ModelError::LoadFailed`] when the Whisper loader reports an error (the
///   payload is that error's `Display` text).
/// * [`ModelError::Unsupported`] for the `LLaMA`, `Bert` and `Generic` types.
pub fn load_candle_model(
    model_type: CandleModelType,
    model_path: &Path,
    device: &Device,
) -> ModelResult<Box<dyn CandleModel>> {
    use super::whisper::{WhisperConfig, WhisperModel};

    // Whisper returns straight from its arm; every other arm only yields the
    // explanation for why it cannot be loaded.
    let unsupported_reason = match model_type {
        CandleModelType::Whisper => {
            let config = WhisperConfig::default();
            return match WhisperModel::load_with_config(model_path, device, config) {
                Ok(model) => Ok(Box::new(WhisperModelWrapper { model })),
                Err(err) => Err(ModelError::LoadFailed(err.to_string())),
            };
        }
        CandleModelType::LLaMA => "LLaMA model support not yet implemented",
        CandleModelType::Bert => "BERT model support not yet implemented",
        CandleModelType::Generic => "Generic model loading not supported. Specify a model type.",
    };

    Err(ModelError::Unsupported(unsupported_reason.to_string()))
}
/// Adapter that exposes a Whisper model through the [`CandleModel`] trait.
pub struct WhisperModelWrapper {
/// The wrapped Whisper model; `run` delegates transcription to it.
pub model: super::whisper::WhisperModel,
}
impl CandleModel for WhisperModelWrapper {
    /// Always [`CandleModelType::Whisper`].
    fn model_type(&self) -> CandleModelType {
        CandleModelType::Whisper
    }

    /// Returns the device reported by the wrapped model.
    fn device(&self) -> &Device {
        self.model.device()
    }

    /// Transcribes the `"mel"` input and returns the text as a tensor.
    ///
    /// The returned map has a single entry, `"text"`: a one-dimensional `f32`
    /// tensor on [`Self::device`] holding one element per byte yielded by the
    /// transcription's `bytes()`, each converted to `f32`.
    ///
    /// # Errors
    ///
    /// * [`ModelError::InferenceFailed`] if transcription fails.
    /// * [`ModelError::Candle`] if the output tensor cannot be created.
    /// * [`ModelError::Unsupported`] if `"mel"` is absent but `"pcm"` is present.
    /// * [`ModelError::InvalidInput`] if neither `"mel"` nor `"pcm"` is present.
    fn run(&mut self, inputs: HashMap<String, Tensor>) -> ModelResult<HashMap<String, Tensor>> {
        // `"mel"` takes priority: when it is present, `"pcm"` is never consulted.
        let mel = match inputs.get("mel") {
            Some(mel) => mel,
            None if inputs.contains_key("pcm") => {
                return Err(ModelError::Unsupported(
                    "Direct PCM input not supported via trait interface. \
                     Use WhisperModel::transcribe_pcm() directly."
                        .to_string(),
                ));
            }
            None => {
                return Err(ModelError::InvalidInput(
                    "Whisper expects 'mel' tensor input".to_string(),
                ));
            }
        };

        let text = self
            .model
            .transcribe(mel)
            .map_err(|e| ModelError::InferenceFailed(e.to_string()))?;

        let encoded: Vec<f32> = text.bytes().map(|b| b as f32).collect();
        // Read the length before handing the buffer over, so it can be moved
        // into the tensor instead of cloned.
        let len = encoded.len();
        let text_tensor = Tensor::from_vec(encoded, len, self.device())?;

        Ok(HashMap::from([("text".to_string(), text_tensor)]))
    }

    /// The only input read by `run`: `"mel"`.
    fn input_names(&self) -> Vec<&str> {
        vec!["mel"]
    }

    /// The only output produced by `run`: `"text"`.
    fn output_names(&self) -> Vec<&str> {
        vec!["text"]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every recognised name (including the LLaMA aliases) maps to its variant
    /// regardless of letter case, and anything else falls back to `Generic`.
    #[test]
    fn test_model_type_from_str() {
        let cases = [
            ("whisper", CandleModelType::Whisper),
            ("WHISPER", CandleModelType::Whisper),
            ("llama", CandleModelType::LLaMA),
            ("llama2", CandleModelType::LLaMA),
            ("llama3", CandleModelType::LLaMA),
            ("LLaMA3", CandleModelType::LLaMA),
            ("bert", CandleModelType::Bert),
            ("Bert", CandleModelType::Bert),
            ("unknown", CandleModelType::Generic),
            ("", CandleModelType::Generic),
        ];

        for &(name, expected) in cases.iter() {
            assert_eq!(
                CandleModelType::from_str(name),
                expected,
                "unexpected model type for input {:?}",
                name
            );
        }
    }
}