// ferrum_models/architectures/mod.rs

1//! Model architecture implementations
2
3pub mod bert;
4pub mod clip;
5pub mod llama;
6pub mod qwen2;
7pub mod qwen3;
8pub mod qwen3_tts;
9pub mod qwen3_tts_vocoder;
10pub mod speaker_encoder;
11pub mod speech_tokenizer_encoder;
12pub mod whisper;
13
14pub use bert::BertModelWrapper;
15pub use clip::ClipModelWrapper;
16pub use llama::LlamaModelWrapper;
17pub use qwen2::Qwen2ModelWrapper;
18pub use qwen3::Qwen3ModelWrapper;
19pub use whisper::WhisperModelWrapper;
20
21/// GQA repeat_kv: repeat K/V heads to match Q heads.
22pub(crate) fn repeat_kv(
23    x: candle_core::Tensor,
24    n_rep: usize,
25) -> candle_core::Result<candle_core::Tensor> {
26    if n_rep == 1 {
27        return Ok(x);
28    }
29    let (b, nkv, seq, hd) = x.dims4()?;
30    let x = x.unsqueeze(2)?;
31    let x = x.expand((b, nkv, n_rep, seq, hd))?;
32    x.reshape((b, nkv * n_rep, seq, hd))
33}