use crate::runtime_adapter::{BackendError, BackendResult, InferenceBackend, RuntimeType};
use candle_core::{Device, Tensor};
use ndarray::ArrayD;
use std::collections::HashMap;
use std::path::Path;
use super::device::{select_device, DeviceSelection};
use super::whisper::WhisperModel;
/// Inference backend built on the Candle framework.
///
/// Holds the device chosen at construction time and, once a model has been
/// loaded through [`InferenceBackend::load_model`], the loaded Whisper model
/// together with the path it was loaded from.
pub struct CandleBackend {
    /// Device returned by `select_device` when the backend was constructed.
    device: Device,
    /// Loaded Whisper model; `None` until `load_model` succeeds.
    whisper_model: Option<WhisperModel>,
    /// Lossy string form of the path the current model was loaded from;
    /// `None` until `load_model` succeeds.
    model_path: Option<String>,
}
impl CandleBackend {
    /// Creates a backend on an automatically selected device.
    ///
    /// Shorthand for `CandleBackend::with_device(DeviceSelection::Auto)`.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::RuntimeError`] if device selection fails.
    pub fn new() -> BackendResult<Self> {
        Self::with_device(DeviceSelection::Auto)
    }

    /// Creates a backend on the device that `select_device` picks for
    /// `preference`. No model is loaded yet.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::RuntimeError`] if device selection fails.
    pub fn with_device(preference: DeviceSelection) -> BackendResult<Self> {
        match select_device(preference) {
            Ok(device) => Ok(Self {
                device,
                whisper_model: None,
                model_path: None,
            }),
            Err(err) => Err(BackendError::RuntimeError(format!(
                "Device selection failed: {}",
                err
            ))),
        }
    }

    /// Returns the device this backend was constructed on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Reports whether a Whisper model is currently loaded.
    pub fn has_whisper_model(&self) -> bool {
        matches!(self.whisper_model, Some(_))
    }

    /// Transcribes `mel` with the loaded Whisper model.
    ///
    /// Takes `&mut self` because `WhisperModel::transcribe` requires mutable
    /// access to the model.
    ///
    /// # Errors
    ///
    /// - [`BackendError::ModelNotLoaded`] if no model has been loaded.
    /// - [`BackendError::InferenceFailed`] if transcription fails.
    pub fn run_whisper(&mut self, mel: &Tensor) -> BackendResult<String> {
        let Some(whisper) = self.whisper_model.as_mut() else {
            return Err(BackendError::ModelNotLoaded);
        };
        match whisper.transcribe(mel) {
            Ok(text) => Ok(text),
            Err(err) => Err(BackendError::InferenceFailed(format!(
                "Whisper inference failed: {}",
                err
            ))),
        }
    }
}
impl Default for CandleBackend {
fn default() -> Self {
Self::new().expect("Failed to create default CandleBackend")
}
}
impl InferenceBackend for CandleBackend {
    /// Identifies this backend as [`RuntimeType::Candle`].
    fn runtime_type(&self) -> RuntimeType {
        RuntimeType::Candle
    }

    /// Loads a Whisper model from `model_path` onto this backend's device.
    ///
    /// The path is accepted as a Whisper model when its string form contains
    /// `"whisper"`, or when it has a `config.json` or `model.safetensors`
    /// entry. `_config_path` is currently ignored.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::LoadFailed`] if the path is not recognised as a
    /// Whisper model or if `WhisperModel::load` fails. On failure, any
    /// previously loaded model and its recorded path are left untouched.
    fn load_model(&mut self, model_path: &Path, _config_path: Option<&Path>) -> BackendResult<()> {
        let path_str = model_path.to_string_lossy();
        let is_whisper = path_str.contains("whisper")
            || model_path.join("config.json").exists()
            || model_path.join("model.safetensors").exists();

        if !is_whisper {
            return Err(BackendError::LoadFailed(format!(
                "Unsupported model type at path: {}. Currently only Whisper models are supported.",
                path_str
            )));
        }

        let model = WhisperModel::load(model_path, &self.device).map_err(|e| {
            BackendError::LoadFailed(format!("Failed to load Whisper model: {}", e))
        })?;
        // Only commit state after a successful load so a failed reload keeps
        // the previous model usable.
        self.whisper_model = Some(model);
        self.model_path = Some(path_str.to_string());
        Ok(())
    }

    /// Generic tensor-in/tensor-out inference entry point.
    ///
    /// Whisper decoding needs `&mut` access to the model, which this `&self`
    /// trait method cannot provide. After validating the request, it
    /// therefore always returns an error directing callers to
    /// [`CandleBackend::run_whisper`].
    ///
    /// # Errors
    ///
    /// - [`BackendError::ModelNotLoaded`] if no model has been loaded.
    /// - [`BackendError::InvalidInput`] if neither a `"mel"` nor an `"input"`
    ///   tensor is supplied.
    /// - [`BackendError::InferenceFailed`] otherwise.
    fn run_inference(
        &self,
        inputs: HashMap<String, ArrayD<f32>>,
    ) -> BackendResult<HashMap<String, ArrayD<f32>>> {
        // Report a missing model the same way input_names/output_names and
        // run_whisper do, instead of an input or inference error.
        if self.whisper_model.is_none() {
            return Err(BackendError::ModelNotLoaded);
        }

        if !inputs.contains_key("mel") && !inputs.contains_key("input") {
            return Err(BackendError::InvalidInput(
                "Expected 'mel' or 'input' tensor for Whisper inference".to_string(),
            ));
        }

        // No device tensor is built here: it could never be consumed, so
        // copying the whole input onto the device would be wasted work.
        Err(BackendError::InferenceFailed(
            "Candle Whisper requires mutable model access for inference. \
             Use CandleBackend::run_whisper() directly with mutable reference."
                .to_string(),
        ))
    }

    /// Reports whether a Whisper model is currently loaded.
    fn is_loaded(&self) -> bool {
        self.whisper_model.is_some()
    }

    /// Names of the inputs accepted by the loaded model (`["mel"]`).
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::ModelNotLoaded`] if no model has been loaded.
    fn input_names(&self) -> BackendResult<Vec<String>> {
        if self.whisper_model.is_some() {
            Ok(vec!["mel".to_string()])
        } else {
            Err(BackendError::ModelNotLoaded)
        }
    }

    /// Names of the outputs reported for the loaded model.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::ModelNotLoaded`] if no model has been loaded.
    fn output_names(&self) -> BackendResult<Vec<String>> {
        if self.whisper_model.is_some() {
            Ok(vec!["encoder_output".to_string(), "text".to_string()])
        } else {
            Err(BackendError::ModelNotLoaded)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction succeeds and the new backend starts with no model.
    #[test]
    fn test_backend_creation() {
        let Ok(backend) = CandleBackend::new() else {
            panic!("CandleBackend::new() should succeed");
        };
        assert!(!backend.is_loaded());
    }

    /// The backend identifies itself as the Candle runtime.
    #[test]
    fn test_runtime_type() {
        let kind = CandleBackend::new().unwrap().runtime_type();
        assert_eq!(kind, RuntimeType::Candle);
    }

    /// Querying input names before loading a model reports `ModelNotLoaded`.
    #[test]
    fn test_input_names_without_model() {
        let backend = CandleBackend::new().unwrap();
        let names = backend.input_names();
        assert!(
            matches!(names, Err(BackendError::ModelNotLoaded)),
            "input_names must fail before a model is loaded"
        );
    }
}