use std::io::Cursor;
use std::path::{Path, PathBuf};

use anyhow::{Context, Error, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::voxtral;
use candle_transformers::models::voxtral::{
    VoxtralCache, VoxtralConfig, VoxtralEncoderConfig, VoxtralForConditionalGeneration,
    VoxtralGenerationConfig, VoxtralLlamaConfig as LlamaConfig,
};
use serde_json;
use tekken::Tekkenizer;

use super::download;
/// Sample rate (Hz) the model expects; input audio is resampled to this rate.
const SAMPLE_RATE: u32 = 16000;
/// Outcome of a transcription run: decoded text plus the raw generated
/// token ids. Serializable so callers can emit it as JSON.
#[derive(Debug, serde::Serialize)]
pub struct TranscriptionResult {
    // Decoded transcription text.
    pub text: String,
    // Token ids the text was decoded from (prompt stripped when possible).
    pub tokens: Vec<u32>,
}
/// Bundles everything needed to transcribe audio with Voxtral:
/// the model, its tokenizer, the compute device, the placeholder token id
/// used for audio positions in the prompt, and a reusable KV cache.
pub struct VoxtralModel {
    model: VoxtralForConditionalGeneration,
    tokenizer: Tekkenizer,
    device: Device,
    // Token id standing in for audio-embedding positions in the prompt
    // (read from config.json at load time).
    audio_token_id: usize,
    // KV cache created once at load; cloned per generation call.
    cache: VoxtralCache,
}
impl VoxtralModel {
    /// Loads the Voxtral model, tokenizer, and KV cache for `model_id`.
    ///
    /// Uses CUDA device 0 when available unless `use_cpu` is set.
    ///
    /// # Errors
    /// Fails if downloading, config parsing, weight loading, or tokenizer
    /// construction fails.
    pub fn new(model_id: &str, use_cpu: bool) -> Result<Self> {
        let device = if !use_cpu && utils::cuda_is_available() {
            Device::new_cuda(0).context("Failed to create CUDA device")?
        } else {
            Device::Cpu
        };
        let (model_files, tokenizer_file) = download::model_files(model_id)?;
        // model_files.0 is the config JSON path; model_files.1 the weight shards.
        let config = load_model_config(&model_files.0)?;
        let vb = load_model_weights(&model_files.1, &device)?;
        let model = VoxtralForConditionalGeneration::new(&config, vb)?;
        let tokenizer = Tekkenizer::from_file(tokenizer_file).map_err(Error::msg)?;
        // F16 cache to match the F16 weights loaded above.
        let cache = VoxtralCache::new(true, DType::F16, &config.text_config, &device)?;
        let audio_token_id = config.audio_token_id;
        Ok(Self {
            model,
            tokenizer,
            device,
            audio_token_id,
            cache,
        })
    }

    /// Transcribes raw mono PCM samples.
    ///
    /// Resamples to 16 kHz when `sample_rate` differs, zero-pads the signal to
    /// a whole number of 480_000-sample chunks (30 s at 16 kHz), extracts
    /// 128-bin mel features, and runs generation.
    ///
    /// # Errors
    /// Fails on resampling, feature extraction, generation, or decoding errors.
    pub fn transcribe_audio(
        &mut self,
        audio_data: &[f32],
        sample_rate: u32,
    ) -> Result<TranscriptionResult> {
        let audio = if sample_rate == SAMPLE_RATE {
            audio_data.to_vec()
        } else {
            candle_examples::audio::resample(audio_data, sample_rate, SAMPLE_RATE)
                .context("Failed to resample audio")?
        };
        // Pad with silence up to the next chunk boundary.
        const CHUNK_SIZE: usize = 480_000;
        let padded_audio = if audio.len() % CHUNK_SIZE != 0 {
            let target_samples = ((audio.len() / CHUNK_SIZE) + 1) * CHUNK_SIZE;
            // The original is not needed past this point, so move it instead of
            // cloning (the previous code cloned redundantly).
            let mut padded = audio;
            padded.resize(target_samples, 0.0);
            padded
        } else {
            audio
        };
        // 128-bin mel filterbank shipped as little-endian f32 bytes.
        let mel_bytes = include_bytes!("melfilters128.bytes");
        let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
        let mut cursor = Cursor::new(mel_bytes);
        cursor.read_f32_into::<LittleEndian>(&mut mel_filters)?;
        // Propagate failures instead of panicking (was `.unwrap()`).
        let audio_features =
            voxtral::extract_features(&padded_audio, &mel_filters, self.device())
                .context("Failed to extract audio features")?;
        let (result, tokens) = transcribe_with_voxtral(
            &self.model,
            &self.tokenizer,
            &audio_features,
            &self.audio_token_id,
            &self.device,
            // Pass the cache by reference; the callee clones what it needs
            // (was `&self.cache.clone()`, an extra full-cache clone).
            &self.cache,
        )?;
        Ok(TranscriptionResult {
            text: result,
            tokens,
        })
    }

    /// Returns the device the model runs on.
    pub fn device(&self) -> &Device {
        &self.device
    }
}
/// Runs Voxtral generation over precomputed audio features and decodes the
/// resulting tokens into text.
///
/// `audio_features` must be a 3D tensor (batch, 128 mel bins, time); each
/// batch chunk is represented by 375 audio placeholder tokens in the prompt.
///
/// Returns the decoded text and the generated token ids (prompt stripped when
/// the output is longer than the input).
///
/// # Errors
/// Fails on malformed feature shapes, generation failure, or decode failure.
fn transcribe_with_voxtral(
    model: &VoxtralForConditionalGeneration,
    tokenizer: &Tekkenizer,
    audio_features: &Tensor,
    audio_token_id: &usize,
    device: &Device,
    cache: &VoxtralCache,
) -> Result<(String, Vec<u32>)> {
    let audio_dims = audio_features.dims();
    if audio_dims.len() != 3 {
        return Err(anyhow::anyhow!(
            "Audio features must be 3D tensor (batch, mels, time), got shape: {:?}",
            audio_dims
        ));
    }
    if audio_dims[1] != 128 {
        return Err(anyhow::anyhow!(
            "Audio features must have 128 mel bins, got {}",
            audio_dims[1]
        ));
    }
    // Prompt: leading control tokens, one placeholder per audio position, then
    // trailing control tokens spelling the transcription instruction.
    // NOTE(review): ids 1/3/25 and 4/9909/1058/1262/34 are model-specific magic
    // values — confirm against the tekken vocabulary for this checkpoint.
    let mut input_tokens = vec![1u32, 3, 25];
    let batch_size = audio_features.dim(0)?;
    // NOTE(review): assumes the encoder emits 375 positions per 30 s chunk.
    let tokens_per_chunk = 375;
    let num_audio_tokens = batch_size * tokens_per_chunk;
    input_tokens.extend(std::iter::repeat(*audio_token_id as u32).take(num_audio_tokens));
    input_tokens.extend([4u32, 9909, 1058, 1262, 34]);
    let input_len = input_tokens.len();
    let input_ids = Tensor::new(input_tokens, device)?.unsqueeze(0)?;
    let generation_config = VoxtralGenerationConfig {
        max_new_tokens: 1000,
        temperature: 0.0, // greedy decoding
        top_p: None,
        device: device.clone(),
        cache: Some(cache.clone()),
    };
    let generated_tokens = model
        .generate(&input_ids, Some(audio_features), generation_config)
        .map_err(|e| {
            // Diagnostics belong on stderr so they never pollute transcription
            // output on stdout (was `println!`).
            eprintln!("Generation error: {:?}", e);
            eprintln!("Error details: {:#}", e);
            anyhow::anyhow!("Failed to generate tokens: {e}")
        })?;
    // Some generate() implementations echo the prompt; strip it when present.
    let new_tokens = if generated_tokens.len() > input_len {
        &generated_tokens[input_len..]
    } else {
        &generated_tokens[..]
    };
    let decoded_text = tokenizer
        .decode(new_tokens, tekken::SpecialTokenPolicy::Ignore)
        .map_err(|e| anyhow::anyhow!("Failed to decode tokens: {}", e))?;
    Ok((decoded_text, new_tokens.to_vec()))
}
/// Memory-maps the safetensors shards in `model_files` as F16 tensors on
/// `device`, synchronizing CUDA work before and after the mapping.
fn load_model_weights<'a>(model_files: &'a [PathBuf], device: &Device) -> Result<VarBuilder<'a>> {
    let dtype = DType::F16;
    // Drain pending CUDA work around the mmap; a no-op on CPU.
    let sync_if_cuda = |d: &Device| -> Result<()> {
        if let candle::Device::Cuda(_) = d {
            d.synchronize()?;
        }
        Ok(())
    };
    sync_if_cuda(device)?;
    // SAFETY: the downloaded weight files are not mutated while mapped.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(model_files, dtype, device)? };
    sync_if_cuda(device)?;
    Ok(vb)
}
fn load_model_config(config_file: &PathBuf) -> Result<VoxtralConfig> {
let config_str = std::fs::read_to_string(config_file)?;
let json: serde_json::Value =
serde_json::from_str(&config_str).context("Failed to parse config.json")?;
let audio_token_id = json
.get("audio_token_id")
.and_then(|v| v.as_u64())
.unwrap_or(24) as usize;
let audio_config = parse_audio_config(&json)?;
let text_config = parse_text_config(&json)?;
let projector_hidden_act = json
.get("projector_hidden_act")
.and_then(|v| v.as_str())
.unwrap_or("gelu")
.to_string();
Ok(VoxtralConfig {
audio_config,
text_config,
audio_token_id,
projector_hidden_act,
})
}
/// Builds the audio encoder configuration from the `audio_config` JSON
/// section, falling back to built-in defaults for absent fields.
fn parse_audio_config(json: &serde_json::Value) -> Result<VoxtralEncoderConfig> {
    let audio_json = json
        .get("audio_config")
        .ok_or_else(|| anyhow::anyhow!("Missing audio_config in configuration"))?;
    // Typed accessors with per-field defaults; keys and defaults are unchanged.
    let usize_or = |key: &str, default: u64| -> usize {
        audio_json.get(key).and_then(|v| v.as_u64()).unwrap_or(default) as usize
    };
    let f64_or = |key: &str, default: f64| -> f64 {
        audio_json.get(key).and_then(|v| v.as_f64()).unwrap_or(default)
    };
    Ok(VoxtralEncoderConfig {
        vocab_size: usize_or("vocab_size", 51866),
        hidden_size: usize_or("hidden_size", 1280),
        num_hidden_layers: usize_or("num_hidden_layers", 32),
        num_attention_heads: usize_or("num_attention_heads", 20),
        num_key_value_heads: usize_or("num_key_value_heads", 20),
        intermediate_size: usize_or("intermediate_size", 5120),
        dropout: f64_or("dropout", 0.0),
        attention_dropout: f64_or("attention_dropout", 0.0),
        activation_dropout: f64_or("activation_dropout", 0.0),
        activation_function: audio_json
            .get("activation_function")
            .and_then(|v| v.as_str())
            .unwrap_or("gelu")
            .to_string(),
        max_source_positions: usize_or("max_source_positions", 1500),
        layerdrop: f64_or("layerdrop", 0.0),
        initializer_range: f64_or("initializer_range", 0.02),
        scale_embedding: audio_json
            .get("scale_embedding")
            .and_then(|v| v.as_bool())
            .unwrap_or(false),
        num_mel_bins: usize_or("num_mel_bins", 128),
        head_dim: usize_or("head_dim", 64),
    })
}
/// Builds the text (Llama) configuration from the `text_config` JSON section,
/// falling back to built-in defaults for absent fields.
///
/// # Errors
/// Fails when the `text_config` section is missing.
fn parse_text_config(json: &serde_json::Value) -> Result<LlamaConfig> {
    let text_json = json
        .get("text_config")
        .ok_or_else(|| anyhow::anyhow!("Missing text_config in configuration"))?;
    // Typed accessor with a per-field default.
    let usize_or = |key: &str, default: u64| -> usize {
        text_json.get(key).and_then(|v| v.as_u64()).unwrap_or(default) as usize
    };
    Ok(LlamaConfig {
        vocab_size: usize_or("vocab_size", 131072),
        hidden_size: usize_or("hidden_size", 3072),
        intermediate_size: usize_or("intermediate_size", 8192),
        num_hidden_layers: usize_or("num_hidden_layers", 30),
        num_attention_heads: usize_or("num_attention_heads", 32),
        num_key_value_heads: usize_or("num_key_value_heads", 8),
        // Optional: no default — downstream derives it from hidden_size/heads.
        head_dim: text_json
            .get("head_dim")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize),
        rms_norm_eps: text_json
            .get("rms_norm_eps")
            .and_then(|v| v.as_f64())
            .unwrap_or(1e-5),
        rope_theta: text_json
            .get("rope_theta")
            .and_then(|v| v.as_f64())
            .unwrap_or(100_000_000.0) as f32,
        max_position_embeddings: usize_or("max_position_embeddings", 131072),
        use_flash_attn: false,
        // BUG FIX: this previously read the unrelated "attention_bias" key,
        // so tie_word_embeddings could never be picked up from the config.
        tie_word_embeddings: text_json
            .get("tie_word_embeddings")
            .and_then(|v| v.as_bool())
            .unwrap_or(false),
    })
}