Skip to main content

hanzo_engine/
layers.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use float8::F8E4M3;
6use half::{bf16, f16};
7use hanzo_ml::{
8    quantized::{QMatMul, QTensor},
9    Context, DType, Device, IndexOp, Result, Tensor, D,
10};
11use hanzo_nn::{
12    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
13    LayerNorm, LayerNormConfig, Linear, Module,
14};
15use hanzo_quant::{
16    AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
17    ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::{CausalMaskConfig, CausalMasker};
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25    amoe::{AnyMoeTrainableLayer, MlpLayer},
26    embedding_models::embedding_gemma::EmbeddingGemmaConfig,
27    gguf::Content,
28    models::{llama, smollm3},
29    ops::SplitOp,
30    vision_models::{
31        gemma3::config::Gemma3TextConfig,
32        gemma3n::config::Gemma3nTextConfig,
33        llama4,
34        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
35        phi4::Phi4MMConfig,
36    },
37};
38
39pub use hanzo_quant::MatMul;
40
41pub fn embedding(
42    in_size: usize,
43    out_size: usize,
44    vb: ShardedVarBuilder,
45    config: &Option<QuantizedConfig>,
46) -> Result<Embedding> {
47    // AFQ quantized applies quantization to the embeddings.
48    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
49        let afq_layer =
50            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
51        afq_layer.dequantize_w()?
52    } else {
53        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
54    };
55    Ok(Embedding::new(embeddings, out_size))
56}
57
58pub fn layer_norm<C: Into<LayerNormConfig>>(
59    size: usize,
60    config: C,
61    vb: ShardedVarBuilder,
62) -> Result<LayerNorm> {
63    let config = config.into();
64    let weight = vb.get(size, "weight")?;
65    if config.affine {
66        let bias = vb.get(size, "bias")?;
67        Ok(LayerNorm::new(weight, bias, config.eps))
68    } else {
69        Ok(LayerNorm::new_no_bias(weight, config.eps))
70    }
71}
72
73pub fn batch_norm<C: Into<BatchNormConfig>>(
74    num_features: usize,
75    config: C,
76    vb: ShardedVarBuilder,
77) -> Result<BatchNorm> {
78    let config = config.into();
79    if config.eps < 0. {
80        hanzo_ml::bail!("batch-norm eps cannot be negative {}", config.eps)
81    }
82    let running_mean = vb.get(num_features, "running_mean")?;
83    let running_var = vb.get(num_features, "running_var")?;
84
85    if config.affine {
86        let weight = vb.get(num_features, "weight")?;
87        let bias = vb.get(num_features, "bias")?;
88        BatchNorm::new(
89            num_features,
90            running_mean,
91            running_var,
92            weight,
93            bias,
94            config.eps,
95        )
96    } else {
97        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
98    }
99}
100
101pub fn group_norm(
102    num_groups: usize,
103    num_channels: usize,
104    eps: f64,
105    vb: ShardedVarBuilder,
106) -> Result<GroupNorm> {
107    let weight = vb.get(num_channels, "weight")?;
108    let bias = vb.get(num_channels, "bias")?;
109    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
110}
111
112pub fn conv2d(
113    in_channels: usize,
114    out_channels: usize,
115    kernel_size: usize,
116    cfg: Conv2dConfig,
117    vb: ShardedVarBuilder,
118) -> Result<Conv2d> {
119    let ws = vb.get(
120        (
121            out_channels,
122            in_channels / cfg.groups,
123            kernel_size,
124            kernel_size,
125        ),
126        "weight",
127    )?;
128    let bs = vb.get(out_channels, "bias")?;
129    Ok(Conv2d::new(ws, Some(bs), cfg))
130}
131
132pub fn conv2d_no_bias(
133    in_channels: usize,
134    out_channels: usize,
135    kernel_size: usize,
136    cfg: Conv2dConfig,
137    vb: ShardedVarBuilder,
138) -> Result<Conv2d> {
139    let ws = vb.get(
140        (
141            out_channels,
142            in_channels / cfg.groups,
143            kernel_size,
144            kernel_size,
145        ),
146        "weight",
147    )?;
148    Ok(Conv2d::new(ws, None, cfg))
149}
150
151pub fn conv1d(
152    in_channels: usize,
153    out_channels: usize,
154    kernel_size: usize,
155    cfg: Conv1dConfig,
156    vb: ShardedVarBuilder,
157) -> Result<Conv1d> {
158    let ws = vb.get(
159        (out_channels, in_channels / cfg.groups, kernel_size),
160        "weight",
161    )?;
162    let bs = vb.get(out_channels, "bias")?;
163    Ok(Conv1d::new(ws, Some(bs), cfg))
164}
165
166pub fn conv1d_no_bias(
167    in_channels: usize,
168    out_channels: usize,
169    kernel_size: usize,
170    cfg: Conv1dConfig,
171    vb: ShardedVarBuilder,
172) -> Result<Conv1d> {
173    let ws = vb.get(
174        (out_channels, in_channels / cfg.groups, kernel_size),
175        "weight",
176    )?;
177    Ok(Conv1d::new(ws, None, cfg))
178}
179
180pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
181    let ws = vb.get((out_dim, in_dim), "weight")?;
182    let bs = vb.get(out_dim, "bias")?;
183    Ok(Linear::new(ws, Some(bs)))
184}
185
186pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
187    let ws = vb.get((out_dim, in_dim), "weight")?;
188    Ok(Linear::new(ws, None))
189}
190
191pub fn linear_b(
192    in_dim: usize,
193    out_dim: usize,
194    bias: bool,
195    vb: ShardedVarBuilder,
196) -> Result<Linear> {
197    if bias {
198        linear(in_dim, out_dim, vb)
199    } else {
200        linear_no_bias(in_dim, out_dim, vb)
201    }
202}
203
204#[derive(Debug, Clone)]
205pub struct RmsNorm {
206    eps: f64,
207    weight: Tensor,
208}
209
210impl RmsNorm {
211    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
212        let w = vb.get(size, "weight")?;
213        Ok(Self { eps, weight: w })
214    }
215
216    /// Gemma uses weight + 1.0
217    #[deprecated(
218        note = "Use GemmaRmsNorm::new() instead, which handles UQFF serialization correctly"
219    )]
220    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
221        let w = vb.get(size, "weight")?;
222        let w = (w + 1.0)?;
223        Ok(Self { eps, weight: w })
224    }
225
226    /// Gemma 3n uses weight
227    pub fn new_gemma_3n(
228        size: usize,
229        eps: f64,
230        with_scale: bool,
231        vb: ShardedVarBuilder,
232    ) -> Result<Self> {
233        let w = if with_scale {
234            vb.get(size, "weight")?
235        } else {
236            Tensor::ones(size, vb.dtype(), vb.device())?
237        };
238        Ok(Self { eps, weight: w })
239    }
240
241    /// Gemma uses weight + 1.0. Undo for UQFF generation.
242    #[deprecated(note = "Use GemmaRmsNorm instead, which handles UQFF serialization automatically")]
243    pub fn undo_gemma(&self) -> Result<Self> {
244        Ok(Self {
245            eps: self.eps,
246            weight: (&self.weight - 1.0)?,
247        })
248    }
249
250    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
251        Ok(Self { eps, weight: w })
252    }
253
254    pub fn weight(&self) -> &Tensor {
255        &self.weight
256    }
257
258    pub fn eps(&self) -> f64 {
259        self.eps
260    }
261
262    pub fn forward_residual(&self, x: &Tensor, residual: &Tensor) -> Result<Tensor> {
263        rms_norm_forward_residual(x, residual, &self.weight, self.eps, None)
264    }
265
266    pub fn forward_residual_scaled(
267        &self,
268        x: &Tensor,
269        residual: &Tensor,
270        scale: &Tensor,
271    ) -> Result<Tensor> {
272        rms_norm_forward_residual(x, residual, &self.weight, self.eps, Some(scale))
273    }
274}
275
276impl Module for RmsNorm {
277    fn forward(&self, x: &Tensor) -> Result<Tensor> {
278        hanzo_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
279    }
280}
281
282fn rms_norm_forward_residual(
283    x: &Tensor,
284    residual: &Tensor,
285    weight: &Tensor,
286    eps: f64,
287    scale: Option<&Tensor>,
288) -> Result<Tensor> {
289    #[cfg(feature = "cuda")]
290    if x.device().is_cuda()
291        && residual.device().same_device(x.device())
292        && weight.device().same_device(x.device())
293        && scale.is_none_or(|scale| scale.device().same_device(x.device()))
294        && x.dtype() == residual.dtype()
295        && x.dtype() == weight.dtype()
296        && scale.is_none_or(|scale| scale.dtype() == x.dtype())
297        && matches!(x.dtype(), DType::BF16 | DType::F16 | DType::F32)
298    {
299        return crate::ops::cuda_rms_norm_residual(x, residual, weight, scale, eps as f32);
300    }
301
302    #[cfg(feature = "metal")]
303    if x.device().is_metal()
304        && residual.device().same_device(x.device())
305        && weight.device().same_device(x.device())
306        && scale.is_none_or(|scale| scale.device().same_device(x.device()))
307        && x.dtype() == residual.dtype()
308        && x.dtype() == weight.dtype()
309        && scale.is_none_or(|scale| scale.dtype() == x.dtype())
310        && matches!(x.dtype(), DType::BF16 | DType::F16 | DType::F32)
311    {
312        if let Some(out) =
313            crate::ops::metal_rms_norm_residual(x, residual, weight, scale, eps as f32)?
314        {
315            return Ok(out);
316        }
317    }
318
319    let normed = hanzo_nn::ops::rms_norm(&x.contiguous()?, weight, eps as f32)?;
320    let out = (residual + normed)?;
321    if let Some(scale) = scale {
322        out.broadcast_mul(scale)
323    } else {
324        Ok(out)
325    }
326}
327
328/// Gemma-style RmsNorm that adds +1.0 to the weight during initialization.
329///
330/// Unlike using `RmsNorm::new_gemma()`, this type stores the original checkpoint
331/// weight separately, ensuring that UQFF serialization (via `ToTensors`) always
332/// returns the un-offset weight. This prevents the double-addition bug where
333/// `new_gemma` would add +1.0 on both write and read.
334#[derive(Debug, Clone)]
335pub struct GemmaRmsNorm {
336    eps: f64,
337    original_weight: Tensor,
338    weight: Tensor,
339}
340
341impl GemmaRmsNorm {
342    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
343        let original_weight = vb.get(size, "weight")?;
344        let weight = (&original_weight + 1.0)?;
345        Ok(Self {
346            eps,
347            original_weight,
348            weight,
349        })
350    }
351
352    pub fn weight(&self) -> &Tensor {
353        &self.weight
354    }
355
356    pub fn original_weight(&self) -> &Tensor {
357        &self.original_weight
358    }
359
360    pub fn eps(&self) -> f64 {
361        self.eps
362    }
363
364    pub fn forward_residual(&self, x: &Tensor, residual: &Tensor) -> Result<Tensor> {
365        rms_norm_forward_residual(x, residual, &self.weight, self.eps, None)
366    }
367
368    pub fn forward_residual_scaled(
369        &self,
370        x: &Tensor,
371        residual: &Tensor,
372        scale: &Tensor,
373    ) -> Result<Tensor> {
374        rms_norm_forward_residual(x, residual, &self.weight, self.eps, Some(scale))
375    }
376}
377
378impl Module for GemmaRmsNorm {
379    fn forward(&self, x: &Tensor) -> Result<Tensor> {
380        hanzo_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
381    }
382}
383
384#[derive(Debug, Clone)]
385pub struct F32RmsNorm {
386    w: Tensor,
387    eps: f64,
388}
389
390impl F32RmsNorm {
391    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
392        Ok(Self {
393            w: vb.get((size,), "weight")?,
394            eps,
395        })
396    }
397
398    pub fn weight(&self) -> &Tensor {
399        &self.w
400    }
401}
402
403impl Module for F32RmsNorm {
404    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
405        let initial_type = xs.dtype();
406        let mut xs = xs.to_dtype(DType::F32)?;
407        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
408        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
409        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
410    }
411}
412
413#[derive(Debug, Clone)]
414pub struct QRmsNorm {
415    eps: f64,
416    weight: Tensor,
417}
418
419impl QRmsNorm {
420    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
421        let scale = scale.dequantize(&scale.device())?;
422        Ok(Self {
423            eps: eps as f64,
424            weight: scale,
425        })
426    }
427
428    pub fn weight(&self) -> &Tensor {
429        &self.weight
430    }
431
432    pub fn eps(&self) -> f64 {
433        self.eps
434    }
435
436    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
437        hanzo_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
438    }
439}
440
441/// RoPE supporting LongRope
442#[derive(Debug, Clone)]
443pub struct PhiRotaryEmbedding {
444    short_sin: Tensor,
445    short_cos: Tensor,
446    long_cos: Option<Tensor>,
447    long_sin: Option<Tensor>,
448    original_max_position_embeddings: usize,
449}
450
451#[derive(Debug, Clone, Deserialize, Serialize)]
452#[serde(rename_all = "lowercase")]
453pub enum ScaledRopeType {
454    #[serde(alias = "su")]
455    #[serde(alias = "longrope")]
456    Su,
457    #[serde(alias = "yarn")]
458    Yarn,
459    #[serde(alias = "dynamic")]
460    Dynamic,
461    #[serde(alias = "linear")]
462    Linear,
463}
464
465impl FromStr for ScaledRopeType {
466    type Err = hanzo_ml::Error;
467    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
468        match s {
469            "su" | "longrope" => Ok(Self::Su),
470            "yarn" => Ok(Self::Yarn),
471            "linear" => Ok(Self::Linear),
472            "dynamic" => Ok(Self::Dynamic),
473            _ => Err(hanzo_ml::Error::Msg(
474                "Expected either `su` or `yarn` scaled RoPE type.".to_string(),
475            )),
476        }
477    }
478}
479
480#[derive(Debug, Clone, Deserialize, Serialize)]
481#[serde(untagged)]
482pub enum PhiRopeScalingConfig {
483    Classic {
484        short_factor: Vec<f64>,
485        long_factor: Vec<f64>,
486        #[serde(rename = "type")]
487        scaling_type: ScaledRopeType,
488    },
489    Scaled {
490        short_factor: Vec<f64>,
491        long_factor: Vec<f64>,
492        #[serde(rename = "type")]
493        scaling_type: ScaledRopeType,
494        long_mscale: f64,
495        short_mscale: f64,
496    },
497}
498
499pub struct PhiRopeConfig {
500    pub rope_scaling: Option<PhiRopeScalingConfig>,
501    pub max_position_embeddings: usize,
502    pub original_max_position_embeddings: usize,
503    pub rope_theta: f64,
504    pub head_dim: usize,
505    pub partial_rotary_factor: Option<f64>,
506}
507
508impl PhiRotaryEmbedding {
509    fn new_classic_scaled(
510        short_factor: &[f64],
511        long_factor: &[f64],
512        scaling_type: &ScaledRopeType,
513        cfg: &PhiRopeConfig,
514        dtype: DType,
515        dev: &Device,
516    ) -> Result<Self> {
517        let max_seq_len = cfg.max_position_embeddings;
518        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
519
520        // Calculate scale
521        let scale =
522            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
523        let scaling_factor = if scale <= 1.0 {
524            1.0
525        } else {
526            match scaling_type {
527                ScaledRopeType::Su => {
528                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
529                }
530                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
531                _ => hanzo_ml::bail!("Expected either `su` or `yarn` RoPE"),
532            }
533        };
534
535        // Calculate inv freqs for short, long
536        let inv_freq_long = (0..dim)
537            .step_by(2)
538            .enumerate()
539            .map(|(k, i)| {
540                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
541            })
542            .collect::<Vec<_>>();
543        let inv_freq_short = (0..dim)
544            .step_by(2)
545            .enumerate()
546            .map(|(k, i)| {
547                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
548            })
549            .collect::<Vec<_>>();
550        let inv_freq_len = inv_freq_long.len();
551
552        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
553            .to_dtype(DType::F32)?
554            .reshape((max_seq_len, 1))?;
555
556        // Calculate sin,cos for long
557        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
558        let freqs_long = t.matmul(&inv_freq_long)?;
559        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
560        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
561
562        // Calculate sin,cos for short
563        let inv_freq_short =
564            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
565        let freqs_short = t.matmul(&inv_freq_short)?;
566        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
567        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
568
569        Ok(Self {
570            short_cos,
571            short_sin,
572            long_cos: Some(long_cos),
573            long_sin: Some(long_sin),
574            original_max_position_embeddings: cfg.original_max_position_embeddings,
575        })
576    }
577
578    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
579        let max_seq_len = cfg.max_position_embeddings;
580        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
581
582        let inv_freq: Vec<_> = (0..dim)
583            .step_by(2)
584            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
585            .collect();
586        let inv_freq_len = inv_freq.len();
587        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
588        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
589            .to_dtype(DType::F32)?
590            .reshape((max_seq_len, 1))?;
591        let freqs = t.matmul(&inv_freq)?;
592        let sin = freqs.sin()?.to_dtype(dtype)?;
593        let cos = freqs.cos()?.to_dtype(dtype)?;
594        Ok(Self {
595            short_cos: cos,
596            short_sin: sin,
597            long_cos: None,
598            long_sin: None,
599            original_max_position_embeddings: cfg.original_max_position_embeddings,
600        })
601    }
602
603    #[allow(clippy::too_many_arguments)]
604    fn new_scaled(
605        short_factor: &[f64],
606        long_factor: &[f64],
607        scaling_type: &ScaledRopeType,
608        long_mscale: f64,
609        short_mscale: f64,
610        cfg: &PhiRopeConfig,
611        dtype: DType,
612        dev: &Device,
613    ) -> Result<Self> {
614        let max_seq_len = cfg.max_position_embeddings;
615        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
616
617        if !matches!(scaling_type, ScaledRopeType::Su) {
618            hanzo_ml::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
619        }
620
621        if short_factor.len() != dim / 2 {
622            hanzo_ml::bail!(
623                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
624                short_factor.len(),
625                dim / 2
626            );
627        }
628        if long_factor.len() != dim / 2 {
629            hanzo_ml::bail!(
630                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
631                long_factor.len(),
632                dim / 2
633            );
634        }
635
636        // Short cos/sin
637        let inv_freq_short: Vec<_> = (0..dim)
638            .step_by(2)
639            .enumerate()
640            .map(|(k, i)| {
641                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
642            })
643            .collect();
644        let inv_freq_len_short = inv_freq_short.len();
645        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
646        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
647            .to_dtype(DType::F32)?
648            .reshape((max_seq_len, 1))?;
649        let freqs_short = t_short.matmul(&inv_freq_short)?;
650        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
651        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
652
653        // Long cos/sin
654        let inv_freq_long: Vec<_> = (0..dim)
655            .step_by(2)
656            .enumerate()
657            .map(|(k, i)| {
658                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
659            })
660            .collect();
661        let inv_freq_len_long = inv_freq_long.len();
662        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
663        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
664            .to_dtype(DType::F32)?
665            .reshape((max_seq_len, 1))?;
666        let freqs_long = t_long.matmul(&inv_freq_long)?;
667        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
668        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
669        Ok(Self {
670            short_cos: cos_short,
671            short_sin: sin_short,
672            long_cos: Some(cos_long),
673            long_sin: Some(sin_long),
674            original_max_position_embeddings: cfg.original_max_position_embeddings,
675        })
676    }
677
678    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
679        let cfg: PhiRopeConfig = cfg.into();
680
681        match &cfg.rope_scaling {
682            Some(PhiRopeScalingConfig::Classic {
683                short_factor,
684                long_factor,
685                scaling_type,
686            }) => {
687                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
688            }
689
690            Some(PhiRopeScalingConfig::Scaled {
691                short_factor,
692                long_factor,
693                scaling_type,
694                long_mscale,
695                short_mscale,
696            }) => Self::new_scaled(
697                short_factor,
698                long_factor,
699                scaling_type,
700                *long_mscale,
701                *short_mscale,
702                &cfg,
703                dtype,
704                dev,
705            ),
706
707            None => Self::new_unscaled(&cfg, dtype, dev),
708        }
709    }
710
711    /// Returns (sin, cos) taking into account LongRope
712    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
713        if self.long_cos.is_none() {
714            return (&self.short_sin, &self.short_cos);
715        }
716        let seq_len = position_ids.iter().max().unwrap() + 1;
717        if seq_len > self.original_max_position_embeddings {
718            (
719                self.long_sin.as_ref().unwrap(),
720                self.long_cos.as_ref().unwrap(),
721            )
722        } else {
723            (&self.short_sin, &self.short_cos)
724        }
725    }
726
727    pub fn forward(
728        &self,
729        q: &Tensor,
730        k: &Tensor,
731        seqlen_offsets: &[usize],
732        position_ids: &[usize],
733    ) -> Result<(Tensor, Tensor)> {
734        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
735        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
736
737        let rot_dim = cos.dim(D::Minus1)? * 2;
738
739        // Case for Phi 3 / Phi 4 mini
740        if rot_dim != q.dim(D::Minus1)? {
741            let rot_dim = cos.dim(D::Minus1)? * 2;
742            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
743            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
744            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
745            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
746
747            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
748                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
749                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
750                let q_embed = hanzo_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
751                let k_embed = hanzo_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
752                (q_embed, k_embed)
753            } else {
754                let mut q_embeds = Vec::new();
755                let mut k_embeds = Vec::new();
756                for (i, offset) in seqlen_offsets.iter().enumerate() {
757                    let cos = cos.narrow(0, *offset, seq_len)?;
758                    let sin = sin.narrow(0, *offset, seq_len)?;
759                    let q_embed = hanzo_nn::rotary_emb::rope(
760                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
761                        &cos,
762                        &sin,
763                    )?;
764                    let k_embed = hanzo_nn::rotary_emb::rope(
765                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
766                        &cos,
767                        &sin,
768                    )?;
769                    q_embeds.push(q_embed);
770                    k_embeds.push(k_embed);
771                }
772                let q_rot = Tensor::cat(&q_embeds, 0)?;
773                let k_rot = Tensor::cat(&k_embeds, 0)?;
774                (q_rot, k_rot)
775            };
776
777            Ok((
778                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
779                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
780            ))
781        } else if seqlen_offsets.len() == 1 {
782            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
783            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
784            let q_embed = hanzo_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
785            let k_embed = hanzo_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
786            Ok((q_embed, k_embed))
787        } else {
788            let mut q_embeds = Vec::new();
789            let mut k_embeds = Vec::new();
790            for (i, offset) in seqlen_offsets.iter().enumerate() {
791                let cos = cos.narrow(0, *offset, seq_len)?;
792                let sin = sin.narrow(0, *offset, seq_len)?;
793                let q_embed =
794                    hanzo_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
795                let k_embed =
796                    hanzo_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
797                q_embeds.push(q_embed);
798                k_embeds.push(k_embed);
799            }
800            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
801        }
802    }
803}
804
805/// RoPE for Llama3
806#[derive(Debug, Clone)]
807pub struct Llama3RotaryEmbedding(RotaryEmbedding);
808
809#[derive(Debug, Clone, Deserialize, Serialize, Default)]
810pub enum Llama3RopeType {
811    #[serde(rename = "llama3")]
812    Llama3,
813    #[serde(rename = "linear")]
814    Linear,
815    #[default]
816    #[serde(rename = "default")]
817    Default,
818}
819
820#[derive(Debug, Clone, Deserialize, Serialize, Default)]
821pub struct Llama3RopeConfig {
822    pub factor: f32,
823    pub low_freq_factor: Option<f32>,
824    pub high_freq_factor: Option<f32>,
825    pub original_max_position_embeddings: Option<usize>,
826    pub rope_type: Llama3RopeType,
827}
828
829fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
830    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
831    (0..head_dim)
832        .step_by(2)
833        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
834        .collect()
835}
836
837fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
838    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
839    (0..head_dim)
840        .step_by(2)
841        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
842        .collect()
843}
844
845// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
846impl Llama3RotaryEmbedding {
847    pub fn new_llama3(
848        dtype: DType,
849        cfg: &llama::Config,
850        dev: &Device,
851        is_gpt_neox: bool,
852    ) -> Result<Self> {
853        match &cfg.rope_scaling {
854            None
855            | Some(Llama3RopeConfig {
856                rope_type: Llama3RopeType::Default,
857                ..
858            }) => Ok(Self(RotaryEmbedding::new(
859                cfg.rope_theta,
860                cfg.hidden_size / cfg.num_attention_heads,
861                cfg.max_position_embeddings,
862                dev,
863                is_gpt_neox,
864                dtype,
865            )?)),
866            Some(Llama3RopeConfig {
867                rope_type: Llama3RopeType::Llama3,
868                factor,
869                low_freq_factor,
870                high_freq_factor,
871                original_max_position_embeddings,
872            }) => {
873                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
874                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
875                let original_max_position_embeddings = original_max_position_embeddings
876                    .context("original_max_position_embeddings is required")?;
877
878                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
879                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
880
881                let inv_freq = calculate_default_inv_freq(cfg)
882                    .into_iter()
883                    .map(|freq| {
884                        let wavelen = 2. * PI / freq;
885                        if wavelen < high_freq_wavelen {
886                            freq
887                        } else if wavelen > low_freq_wavelen {
888                            freq / *factor
889                        } else {
890                            let smooth = (original_max_position_embeddings as f32 / wavelen
891                                - low_freq_factor)
892                                / (high_freq_factor - low_freq_factor);
893                            (1. - smooth) * freq / *factor + smooth * freq
894                        }
895                    })
896                    .collect::<Vec<_>>();
897                let inv_freq_len = inv_freq.len();
898                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
899                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
900                    .to_dtype(DType::F32)?
901                    .reshape((cfg.max_position_embeddings, 1))?;
902                let freqs = t.matmul(&inv_freq)?;
903                let sin = freqs.sin()?.to_dtype(dtype)?;
904                let cos = freqs.cos()?.to_dtype(dtype)?;
905                Ok(Self(RotaryEmbedding {
906                    sin,
907                    cos,
908                    is_gpt_neox,
909                }))
910            }
911            Some(Llama3RopeConfig {
912                rope_type: Llama3RopeType::Linear,
913                factor,
914                ..
915            }) => {
916                let inv_freq_vec = calculate_default_inv_freq(cfg)
917                    .into_iter()
918                    .map(|freq| freq / *factor)
919                    .collect::<Vec<_>>();
920                let inv_freq_len = inv_freq_vec.len();
921                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
922                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
923                    .to_dtype(DType::F32)?
924                    .reshape((cfg.max_position_embeddings, 1))?;
925                let freqs = t.matmul(&inv_freq)?;
926                let sin = freqs.sin()?.to_dtype(dtype)?;
927                let cos = freqs.cos()?.to_dtype(dtype)?;
928                Ok(Self(RotaryEmbedding {
929                    sin,
930                    cos,
931                    is_gpt_neox,
932                }))
933            }
934        }
935    }
936
937    pub fn new_llama4(
938        dtype: DType,
939        cfg: &llama4::TextConfig,
940        dev: &Device,
941        is_gpt_neox: bool,
942    ) -> Result<Self> {
943        match &cfg.rope_scaling {
944            None
945            | Some(Llama3RopeConfig {
946                rope_type: Llama3RopeType::Default,
947                ..
948            }) => Ok(Self(RotaryEmbedding::new(
949                cfg.rope_theta,
950                cfg.hidden_size / cfg.num_attention_heads,
951                cfg.max_position_embeddings,
952                dev,
953                is_gpt_neox,
954                dtype,
955            )?)),
956            Some(Llama3RopeConfig {
957                rope_type: Llama3RopeType::Llama3,
958                factor,
959                low_freq_factor,
960                high_freq_factor,
961                original_max_position_embeddings,
962            }) => {
963                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
964                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
965                let original_max_position_embeddings = original_max_position_embeddings
966                    .context("original_max_position_embeddings is required")?;
967
968                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
969                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
970
971                let inv_freq = calculate_default_inv_freq_llama4(cfg)
972                    .into_iter()
973                    .map(|freq| {
974                        let wavelen = 2. * PI / freq;
975                        if wavelen < high_freq_wavelen {
976                            freq
977                        } else if wavelen > low_freq_wavelen {
978                            freq / *factor
979                        } else {
980                            let smooth = (original_max_position_embeddings as f32 / wavelen
981                                - low_freq_factor)
982                                / (high_freq_factor - low_freq_factor);
983                            (1. - smooth) * freq / *factor + smooth * freq
984                        }
985                    })
986                    .collect::<Vec<_>>();
987                let inv_freq_len = inv_freq.len();
988                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
989                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
990                    .to_dtype(DType::F32)?
991                    .reshape((cfg.max_position_embeddings, 1))?;
992                let freqs = t.matmul(&inv_freq)?;
993                let sin = freqs.sin()?.to_dtype(dtype)?;
994                let cos = freqs.cos()?.to_dtype(dtype)?;
995                Ok(Self(RotaryEmbedding {
996                    sin,
997                    cos,
998                    is_gpt_neox,
999                }))
1000            }
1001            Some(Llama3RopeConfig {
1002                rope_type: Llama3RopeType::Linear,
1003                factor,
1004                ..
1005            }) => {
1006                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
1007                    .into_iter()
1008                    .map(|freq| freq / *factor)
1009                    .collect::<Vec<_>>();
1010                let inv_freq_len = inv_freq_vec.len();
1011                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1012                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1013                    .to_dtype(DType::F32)?
1014                    .reshape((cfg.max_position_embeddings, 1))?;
1015                let freqs = t.matmul(&inv_freq)?;
1016                let sin = freqs.sin()?.to_dtype(dtype)?;
1017                let cos = freqs.cos()?.to_dtype(dtype)?;
1018                Ok(Self(RotaryEmbedding {
1019                    sin,
1020                    cos,
1021                    is_gpt_neox,
1022                }))
1023            }
1024        }
1025    }
1026
1027    pub fn new_mllama3(
1028        dtype: DType,
1029        cfg: &MLlamaTextConfig,
1030        dev: &Device,
1031        is_gpt_neox: bool,
1032    ) -> Result<Self> {
1033        match &cfg.rope_scaling {
1034            None
1035            | Some(MLlamaRopeScaling {
1036                rope_type: MLlamaRopeType::Default,
1037                ..
1038            }) => Ok(Self(RotaryEmbedding::new(
1039                cfg.rope_theta,
1040                cfg.hidden_size / cfg.num_attention_heads,
1041                cfg.max_position_embeddings,
1042                dev,
1043                is_gpt_neox,
1044                dtype,
1045            )?)),
1046            Some(MLlamaRopeScaling {
1047                rope_type: MLlamaRopeType::Llama3,
1048                original_max_position_embeddings,
1049                factor,
1050                attention_factor: _,
1051                beta_fast: _,
1052                beta_slow: _,
1053                short_factor: _,
1054                long_factor: _,
1055                low_freq_factor,
1056                high_freq_factor,
1057            }) => {
1058                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
1059                let low_freq_factor = low_freq_factor
1060                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
1061                let high_freq_factor = high_freq_factor
1062                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
1063
1064                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
1065                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
1066
1067                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1068
1069                let inv_freq = (0..head_dim)
1070                    .step_by(2)
1071                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1072                    .map(|freq| {
1073                        let wavelen = 2. * PI / freq;
1074                        if wavelen < high_freq_wavelen {
1075                            freq
1076                        } else if wavelen > low_freq_wavelen {
1077                            freq / factor
1078                        } else {
1079                            let smooth = (*original_max_position_embeddings as f32 / wavelen
1080                                - low_freq_factor)
1081                                / (high_freq_factor - low_freq_factor);
1082                            (1. - smooth) * freq / factor + smooth * freq
1083                        }
1084                    })
1085                    .collect::<Vec<_>>();
1086                let inv_freq_len = inv_freq.len();
1087                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1088
1089                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1090                    .to_dtype(DType::F32)?
1091                    .reshape((cfg.max_position_embeddings, 1))?;
1092                let freqs = t.matmul(&inv_freq)?;
1093                let sin = freqs.sin()?.to_dtype(dtype)?;
1094                let cos = freqs.cos()?.to_dtype(dtype)?;
1095                Ok(Self(RotaryEmbedding {
1096                    sin,
1097                    cos,
1098                    is_gpt_neox,
1099                }))
1100            }
1101            Some(MLlamaRopeScaling {
1102                rope_type: other, ..
1103            }) => {
1104                hanzo_ml::bail!(
1105                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
1106                )
1107            }
1108        }
1109    }
1110
1111    pub fn forward(
1112        &self,
1113        q: &Tensor,
1114        k: &Tensor,
1115        seqlen_offsets: &[usize],
1116    ) -> Result<(Tensor, Tensor)> {
1117        self.0.forward(q, k, seqlen_offsets)
1118    }
1119
1120    pub fn forward_q_norm(
1121        &self,
1122        q: &Tensor,
1123        q_weight: &Tensor,
1124        q_eps: f64,
1125        seqlen_offsets: &[usize],
1126    ) -> Result<Tensor> {
1127        self.0.forward_q_norm(q, q_weight, q_eps, seqlen_offsets)
1128    }
1129
1130    #[allow(clippy::too_many_arguments)]
1131    pub fn forward_qk_norm(
1132        &self,
1133        q: &Tensor,
1134        k: &Tensor,
1135        q_weight: &Tensor,
1136        k_weight: &Tensor,
1137        q_eps: f64,
1138        k_eps: f64,
1139        seqlen_offsets: &[usize],
1140    ) -> Result<(Tensor, Tensor)> {
1141        self.0
1142            .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
1143    }
1144}
1145
1146/// RoPE for SmolLm3
1147#[derive(Debug, Clone)]
1148pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);
1149
1150#[derive(Debug, Clone, Deserialize, Serialize, Default)]
1151pub enum SmolLm3RopeType {
1152    #[serde(rename = "llama3")]
1153    Llama3,
1154    #[serde(rename = "linear")]
1155    Linear,
1156    #[default]
1157    #[serde(rename = "default")]
1158    Default,
1159}
1160
1161#[derive(Debug, Clone, Deserialize, Serialize, Default)]
1162pub struct SmolLm3RopeConfig {
1163    pub factor: f32,
1164    pub low_freq_factor: Option<f32>,
1165    pub high_freq_factor: Option<f32>,
1166    pub original_max_position_embeddings: Option<usize>,
1167    pub rope_type: SmolLm3RopeType,
1168}
1169
1170fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
1171    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1172    (0..head_dim)
1173        .step_by(2)
1174        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1175        .collect()
1176}
1177
1178impl SmolLm3RotaryEmbedding {
1179    pub fn new_llama3(
1180        dtype: DType,
1181        cfg: &smollm3::Config,
1182        dev: &Device,
1183        is_gpt_neox: bool,
1184    ) -> Result<Self> {
1185        match &cfg.rope_scaling {
1186            None
1187            | Some(SmolLm3RopeConfig {
1188                rope_type: SmolLm3RopeType::Default,
1189                ..
1190            }) => Ok(Self(RotaryEmbedding::new(
1191                cfg.rope_theta,
1192                cfg.hidden_size / cfg.num_attention_heads,
1193                cfg.max_position_embeddings,
1194                dev,
1195                is_gpt_neox,
1196                dtype,
1197            )?)),
1198            Some(SmolLm3RopeConfig {
1199                rope_type: SmolLm3RopeType::Llama3,
1200                factor,
1201                low_freq_factor,
1202                high_freq_factor,
1203                original_max_position_embeddings,
1204            }) => {
1205                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
1206                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
1207                let original_max_position_embeddings = original_max_position_embeddings
1208                    .context("original_max_position_embeddings is required")?;
1209
1210                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
1211                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
1212
1213                let inv_freq = calculate_default_inv_freq_smollm3(cfg)
1214                    .into_iter()
1215                    .map(|freq| {
1216                        let wavelen = 2. * PI / freq;
1217                        if wavelen < high_freq_wavelen {
1218                            freq
1219                        } else if wavelen > low_freq_wavelen {
1220                            freq / *factor
1221                        } else {
1222                            let smooth = (original_max_position_embeddings as f32 / wavelen
1223                                - low_freq_factor)
1224                                / (high_freq_factor - low_freq_factor);
1225                            (1. - smooth) * freq / *factor + smooth * freq
1226                        }
1227                    })
1228                    .collect::<Vec<_>>();
1229                let inv_freq_len = inv_freq.len();
1230                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1231                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1232                    .to_dtype(DType::F32)?
1233                    .reshape((cfg.max_position_embeddings, 1))?;
1234                let freqs = t.matmul(&inv_freq)?;
1235                let sin = freqs.sin()?.to_dtype(dtype)?;
1236                let cos = freqs.cos()?.to_dtype(dtype)?;
1237                Ok(Self(RotaryEmbedding {
1238                    sin,
1239                    cos,
1240                    is_gpt_neox,
1241                }))
1242            }
1243            Some(SmolLm3RopeConfig {
1244                rope_type: SmolLm3RopeType::Linear,
1245                factor,
1246                ..
1247            }) => {
1248                let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
1249                    .into_iter()
1250                    .map(|freq| freq / *factor)
1251                    .collect::<Vec<_>>();
1252                let inv_freq_len = inv_freq_vec.len();
1253                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1254                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1255                    .to_dtype(DType::F32)?
1256                    .reshape((cfg.max_position_embeddings, 1))?;
1257                let freqs = t.matmul(&inv_freq)?;
1258                let sin = freqs.sin()?.to_dtype(dtype)?;
1259                let cos = freqs.cos()?.to_dtype(dtype)?;
1260                Ok(Self(RotaryEmbedding {
1261                    sin,
1262                    cos,
1263                    is_gpt_neox,
1264                }))
1265            }
1266        }
1267    }
1268
1269    pub fn forward(
1270        &self,
1271        q: &Tensor,
1272        k: &Tensor,
1273        seqlen_offsets: &[usize],
1274    ) -> Result<(Tensor, Tensor)> {
1275        self.0.forward(q, k, seqlen_offsets)
1276    }
1277
1278    pub fn forward_q_norm(
1279        &self,
1280        q: &Tensor,
1281        q_weight: &Tensor,
1282        q_eps: f64,
1283        seqlen_offsets: &[usize],
1284    ) -> Result<Tensor> {
1285        self.0.forward_q_norm(q, q_weight, q_eps, seqlen_offsets)
1286    }
1287
1288    #[allow(clippy::too_many_arguments)]
1289    pub fn forward_qk_norm(
1290        &self,
1291        q: &Tensor,
1292        k: &Tensor,
1293        q_weight: &Tensor,
1294        k_weight: &Tensor,
1295        q_eps: f64,
1296        k_eps: f64,
1297        seqlen_offsets: &[usize],
1298    ) -> Result<(Tensor, Tensor)> {
1299        self.0
1300            .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
1301    }
1302}
1303
1304// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
1305#[derive(Debug, Clone)]
1306pub struct Qwen2VLRotaryEmbedding {
1307    inv_freq: Tensor,
1308    mrope_section: Vec<usize>,
1309}
1310
1311impl Qwen2VLRotaryEmbedding {
1312    pub fn new(
1313        base: f32,
1314        head_dim: usize,
1315        device: &Device,
1316        mrope_section: Vec<usize>,
1317    ) -> Result<Self> {
1318        let inv_freq: Vec<_> = (0..head_dim)
1319            .step_by(2)
1320            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1321            .collect();
1322        let inv_freq_len = inv_freq.len();
1323        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1324        Ok(Self {
1325            inv_freq,
1326            mrope_section,
1327        })
1328    }
1329
1330    /// (cos, sin)
1331    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1332        let inv_freq_expanded =
1333            self.inv_freq
1334                .reshape((1, 1, (), 1))?
1335                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1336        let position_ids_expanded = position_ids.unsqueeze(2)?;
1337        let freqs = inv_freq_expanded
1338            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1339            .transpose(2, 3)?;
1340        let cos = freqs.cos()?;
1341        let sin = freqs.sin()?;
1342
1343        let cos = Tensor::cat(
1344            &cos.split(&self.mrope_section, D::Minus1)?
1345                .into_iter()
1346                .enumerate()
1347                .map(|(i, m)| m.i(i % 3))
1348                .collect::<Result<Vec<_>>>()?,
1349            D::Minus1,
1350        )?
1351        .squeeze(0)?
1352        .to_dtype(dtype)?
1353        .contiguous()?;
1354        let sin = Tensor::cat(
1355            &sin.split(&self.mrope_section, D::Minus1)?
1356                .into_iter()
1357                .enumerate()
1358                .map(|(i, m)| m.i(i % 3))
1359                .collect::<Result<Vec<_>>>()?,
1360            D::Minus1,
1361        )?
1362        .squeeze(0)?
1363        .to_dtype(dtype)?
1364        .contiguous()?;
1365
1366        Ok((cos, sin))
1367    }
1368
1369    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
1370    pub fn forward(
1371        &self,
1372        (cos, sin): &(Tensor, Tensor),
1373        q: &mut Tensor,
1374        k: &mut Tensor,
1375    ) -> Result<()> {
1376        *q = hanzo_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1377        *k = hanzo_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1378        Ok(())
1379    }
1380
1381    #[allow(clippy::too_many_arguments)]
1382    pub fn forward_qk_norm(
1383        &self,
1384        (cos, sin): &(Tensor, Tensor),
1385        q: &Tensor,
1386        k: &Tensor,
1387        q_weight: &Tensor,
1388        k_weight: &Tensor,
1389        q_eps: f64,
1390        k_eps: f64,
1391    ) -> Result<(Tensor, Tensor)> {
1392        qk_rms_norm_mrope(q, k, q_weight, k_weight, q_eps, k_eps, cos, sin, true)
1393    }
1394}
1395
1396/// Qwen3 VL uses **interleaved** MRoPE (not chunked like Qwen2 VL).
1397/// Frequencies are arranged as THW THW THW... TTTT pattern.
1398/// See `apply_interleaved_mrope` in modeling_qwen3_vl.py.
1399#[derive(Debug, Clone)]
1400pub struct Qwen3VLRotaryEmbedding {
1401    inv_freq: Tensor,
1402    /// Precomputed interleave indices for H (dim=1, offset=1) and W (dim=2, offset=2).
1403    /// Stored as (indices_1d, dim_idx) pairs. Created once at init to avoid CPU->GPU sync per step.
1404    interleave_indices: Vec<(Tensor, usize)>,
1405}
1406
1407impl Qwen3VLRotaryEmbedding {
1408    pub fn new(
1409        base: f32,
1410        head_dim: usize,
1411        device: &Device,
1412        mrope_section: Vec<usize>,
1413    ) -> Result<Self> {
1414        let inv_freq: Vec<_> = (0..head_dim)
1415            .step_by(2)
1416            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1417            .collect();
1418        let inv_freq_len = inv_freq.len();
1419        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1420
1421        // Precompute interleave index tensors for H (dim=1, offset=1) and W (dim=2, offset=2)
1422        // to avoid CPU->GPU sync from Tensor::from_vec on every decode step.
1423        let half_dim = head_dim / 2;
1424        let mut interleave_indices = Vec::new();
1425        for (dim_idx, offset) in [(1usize, 1usize), (2usize, 2usize)] {
1426            let indices: Vec<u32> = (offset..)
1427                .step_by(3)
1428                .take(mrope_section[dim_idx])
1429                .filter(|&i| i < half_dim)
1430                .map(|i| i as u32)
1431                .collect();
1432            if !indices.is_empty() {
1433                let num = indices.len();
1434                let idx_tensor = Tensor::from_vec(indices, (num,), device)?;
1435                interleave_indices.push((idx_tensor, dim_idx));
1436            }
1437        }
1438
1439        Ok(Self {
1440            inv_freq,
1441            interleave_indices,
1442        })
1443    }
1444
1445    /// Compute (cos, sin) from 3D position_ids of shape (3, batch, seq_len).
1446    /// Applies interleaved MRoPE: starts with temporal freqs, then overwrites
1447    /// H positions (slice 1::3) and W positions (slice 2::3) within their sections.
1448    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1449        // inv_freq: (head_dim/2,) -> (1, 1, head_dim/2, 1) -> expand to (3, batch, head_dim/2, 1)
1450        let inv_freq_expanded =
1451            self.inv_freq
1452                .reshape((1, 1, (), 1))?
1453                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1454        // position_ids: (3, batch, seq_len) -> (3, batch, 1, seq_len)
1455        let position_ids_expanded = position_ids.unsqueeze(2)?;
1456        // freqs: (3, batch, head_dim/2, 1) @ (3, batch, 1, seq_len) -> (3, batch, head_dim/2, seq_len)
1457        // -> transpose -> (3, batch, seq_len, head_dim/2)
1458        let freqs = inv_freq_expanded
1459            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1460            .transpose(2, 3)?;
1461
1462        // Apply interleaved MRoPE: start with temporal, overwrite H and W at interleaved positions
1463        // freqs_t = freqs[0] as base (all temporal)
1464        let mut freqs_t = freqs.i(0)?.contiguous()?;
1465        let (batch, seq_len, _) = freqs_t.dims3()?;
1466
1467        // For H (dim=1) and W (dim=2), overwrite interleaved positions using precomputed indices
1468        for (idx_tensor, dim_idx) in &self.interleave_indices {
1469            let freqs_dim = freqs.i(*dim_idx)?.contiguous()?;
1470            let num_indices = idx_tensor.dim(0)?;
1471            let idx_expanded = idx_tensor
1472                .reshape((1, 1, num_indices))?
1473                .repeat((batch, seq_len, 1))?;
1474            let src_vals = freqs_dim.gather(&idx_expanded, D::Minus1)?;
1475            freqs_t = freqs_t.scatter(&idx_expanded, &src_vals, D::Minus1)?;
1476        }
1477
1478        // cos/sin from freqs_t -> (batch, seq_len, head_dim/2)
1479        // hanzo-ml's rope() expects half-dim cos/sin and handles both halves internally
1480        let cos = freqs_t.cos()?.to_dtype(dtype)?.contiguous()?;
1481        let sin = freqs_t.sin()?.to_dtype(dtype)?.contiguous()?;
1482        Ok((cos, sin))
1483    }
1484
1485    pub fn forward(
1486        &self,
1487        (cos, sin): &(Tensor, Tensor),
1488        q: &mut Tensor,
1489        k: &mut Tensor,
1490    ) -> Result<()> {
1491        *q = hanzo_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1492        *k = hanzo_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1493        Ok(())
1494    }
1495
1496    #[allow(clippy::too_many_arguments)]
1497    pub fn forward_qk_norm(
1498        &self,
1499        (cos, sin): &(Tensor, Tensor),
1500        q: &Tensor,
1501        k: &Tensor,
1502        q_weight: &Tensor,
1503        k_weight: &Tensor,
1504        q_eps: f64,
1505        k_eps: f64,
1506    ) -> Result<(Tensor, Tensor)> {
1507        qk_rms_norm_mrope(q, k, q_weight, k_weight, q_eps, k_eps, cos, sin, true)
1508    }
1509}
1510
1511#[derive(Debug, Clone)]
1512pub struct Qwen2_5VLRotaryEmbedding {
1513    inv_freq: Tensor,
1514    mrope_section: Vec<usize>,
1515}
1516
1517impl Qwen2_5VLRotaryEmbedding {
1518    pub fn new(
1519        base: f32,
1520        head_dim: usize,
1521        device: &Device,
1522        mrope_section: Vec<usize>,
1523    ) -> Result<Self> {
1524        let inv_freq: Vec<_> = (0..head_dim)
1525            .step_by(2)
1526            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1527            .collect();
1528        let inv_freq_len = inv_freq.len();
1529        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1530        Ok(Self {
1531            inv_freq,
1532            mrope_section,
1533        })
1534    }
1535
1536    /// (cos, sin)
1537    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1538        let inv_freq_expanded =
1539            self.inv_freq
1540                .reshape((1, 1, (), 1))?
1541                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1542        let position_ids_expanded = position_ids.unsqueeze(2)?;
1543        let freqs = inv_freq_expanded
1544            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1545            .transpose(2, 3)?;
1546        let cos = freqs.cos()?;
1547        let sin = freqs.sin()?;
1548
1549        let cos = Tensor::cat(
1550            &cos.split(&self.mrope_section, D::Minus1)?
1551                .into_iter()
1552                .enumerate()
1553                .map(|(i, m)| m.i(i % 3))
1554                .collect::<Result<Vec<_>>>()?,
1555            D::Minus1,
1556        )?
1557        .squeeze(0)?
1558        .to_dtype(dtype)?
1559        .contiguous()?;
1560        let sin = Tensor::cat(
1561            &sin.split(&self.mrope_section, D::Minus1)?
1562                .into_iter()
1563                .enumerate()
1564                .map(|(i, m)| m.i(i % 3))
1565                .collect::<Result<Vec<_>>>()?,
1566            D::Minus1,
1567        )?
1568        .squeeze(0)?
1569        .to_dtype(dtype)?
1570        .contiguous()?;
1571
1572        Ok((cos, sin))
1573    }
1574
1575    pub fn forward(
1576        &self,
1577        (cos, sin): &(Tensor, Tensor),
1578        q: &mut Tensor,
1579        k: &mut Tensor,
1580    ) -> Result<()> {
1581        *q = hanzo_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1582        *k = hanzo_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1583        Ok(())
1584    }
1585}
1586
1587#[derive(Debug, Clone)]
1588pub struct DeepSeekV2RotaryEmbedding {
1589    sin: Tensor,
1590    cos: Tensor,
1591}
1592
1593#[derive(Debug, Clone, Deserialize, Serialize)]
1594#[serde(untagged)]
1595pub enum DeepSeekV2RopeScaling {
1596    Yarn {
1597        original_max_position_embeddings: usize,
1598        beta_fast: f32,
1599        beta_slow: f32,
1600        mscale: f32,
1601        mscale_all_dim: f32,
1602        factor: f32,
1603        #[serde(rename = "type")]
1604        scaling_type: ScaledRopeType,
1605    },
1606    LinearOrDynamic {
1607        #[serde(rename = "type")]
1608        scaling_type: ScaledRopeType,
1609        factor: f64,
1610    },
1611}
1612
/// Inputs needed to build a `DeepSeekV2RotaryEmbedding`.
pub struct DeepSeekV2RopeConfig {
    /// Optional scaling; `None` builds unscaled tables.
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Number of positions the sin/cos tables cover.
    pub max_position_embeddings: usize,
    /// RoPE base used for the inverse frequencies.
    pub rope_theta: f32,
    /// Width of the rotated part of each query/key head.
    pub qk_rope_head_dim: usize,
}
1619
1620impl DeepSeekV2RotaryEmbedding {
1621    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1622        let max_seq_len = cfg.max_position_embeddings;
1623        let dim = cfg.qk_rope_head_dim;
1624
1625        let inv_freq: Vec<_> = (0..dim)
1626            .step_by(2)
1627            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1628            .collect();
1629        let inv_freq_len = inv_freq.len();
1630        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1631        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1632            .to_dtype(DType::F32)?
1633            .reshape((max_seq_len, 1))?;
1634        let freqs = t.matmul(&inv_freq)?;
1635
1636        let sin = freqs.sin()?.to_dtype(dtype)?;
1637        let cos = freqs.cos()?.to_dtype(dtype)?;
1638
1639        Ok(Self { sin, cos })
1640    }
1641
1642    fn yarn_find_correction_dim(
1643        num_rot: f32,
1644        dim: usize,
1645        base: f32,
1646        max_position_embeddings: usize,
1647    ) -> f32 {
1648        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1649            / (2. * base.ln())
1650    }
1651
1652    fn yarn_find_correction_range(
1653        low_rot: f32,
1654        high_rot: f32,
1655        dim: usize,
1656        base: f32,
1657        max_position_embeddings: usize,
1658    ) -> (f32, f32) {
1659        let low =
1660            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1661        let high =
1662            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1663        (low.max(0.), high.min(dim as f32 - 1.))
1664    }
1665
1666    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1667        if min == max {
1668            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
1669            max += 0.001;
1670        }
1671        let linear_func =
1672            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1673        linear_func.clamp(0., 1)
1674    }
1675
1676    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1677        if scale <= 1. {
1678            return 1.;
1679        }
1680        0.1 * mscale * scale.ln() + 1.
1681    }
1682
1683    #[allow(clippy::too_many_arguments)]
1684    fn new_yarn(
1685        cfg: &DeepSeekV2RopeConfig,
1686        dtype: DType,
1687        dev: &Device,
1688        original_max_position_embeddings: usize,
1689        beta_fast: f32,
1690        beta_slow: f32,
1691        factor: f32,
1692        mscale: f32,
1693        mscale_all_dim: f32,
1694    ) -> Result<Self> {
1695        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1696            .step_by(2)
1697            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1698            .collect();
1699        let freq_extra_len = freq_extra.len();
1700        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1701        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1702            .step_by(2)
1703            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1704            .collect();
1705        let freq_inter_len = freq_inter.len();
1706        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1707
1708        let (low, high) = Self::yarn_find_correction_range(
1709            beta_fast,
1710            beta_slow,
1711            cfg.qk_rope_head_dim,
1712            cfg.rope_theta,
1713            original_max_position_embeddings,
1714        );
1715        let inv_freq_mask =
1716            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1717        let inv_freq = freq_inter
1718            .broadcast_mul(&(1. - &inv_freq_mask)?)?
1719            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1720
1721        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1722            .to_dtype(DType::F32)?
1723            .reshape((cfg.max_position_embeddings, 1))?;
1724        let freqs = t.matmul(&inv_freq)?;
1725
1726        let mscale =
1727            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1728        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1729        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1730
1731        Ok(Self { sin, cos })
1732    }
1733
1734    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1735        match &cfg.rope_scaling {
1736            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1737                scaling_type: _,
1738                factor: _,
1739            }) => hanzo_ml::bail!("linear and dynamic rope are not implemented yet!"),
1740            Some(DeepSeekV2RopeScaling::Yarn {
1741                original_max_position_embeddings,
1742                beta_fast,
1743                beta_slow,
1744                factor,
1745                mscale,
1746                mscale_all_dim,
1747                scaling_type: _,
1748            }) => Self::new_yarn(
1749                cfg,
1750                dtype,
1751                dev,
1752                *original_max_position_embeddings,
1753                *beta_fast,
1754                *beta_slow,
1755                *factor,
1756                *mscale,
1757                *mscale_all_dim,
1758            ),
1759            None => Self::new_unscaled(cfg, dtype, dev),
1760        }
1761    }
1762
1763    pub fn forward(
1764        &self,
1765        q: &Tensor,
1766        k: &Tensor,
1767        seqlen_offsets: &[usize],
1768    ) -> Result<(Tensor, Tensor)> {
1769        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1770
1771        if seqlen_offsets.len() == 1 {
1772            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1773            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1774            let q_embed = hanzo_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1775            let k_embed = hanzo_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1776            Ok((q_embed, k_embed))
1777        } else {
1778            let mut q_embeds = Vec::new();
1779            let mut k_embeds = Vec::new();
1780            for (i, offset) in seqlen_offsets.iter().enumerate() {
1781                let cos = self.cos.narrow(0, *offset, seq_len)?;
1782                let sin = self.sin.narrow(0, *offset, seq_len)?;
1783                let q_embed =
1784                    hanzo_nn::rotary_emb::rope_i(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1785                let k_embed =
1786                    hanzo_nn::rotary_emb::rope_i(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1787                q_embeds.push(q_embed);
1788                k_embeds.push(k_embed);
1789            }
1790            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1791        }
1792    }
1793}
1794
/// Rotary embedding for the `Phi4MMConfig` model, with optional LongRope scaling.
///
/// Only the first `head_dim * partial_rotary_factor` channels are rotated. When LongRope is
/// configured, separate short and long tables are kept, and one pair is chosen per call
/// from the position ids.
#[derive(Debug, Clone)]
pub struct Phi4MMRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    // Long-context tables; `Some` only when LongRope scaling is configured.
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    // Sequences up to this length use the short tables.
    original_max_position_embeddings: usize,
}
1803
/// RoPE scaling kind for `Phi4MMRopeScalingConfig`, read from its `type` key.
///
/// With `rename_all = "lowercase"`, `LongRope` already deserializes from "longrope", so the
/// explicit alias appears redundant (kept as-is).
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Phi4MMScaledRopeType {
    /// LongRope: separate per-frequency short/long factor lists.
    #[serde(alias = "longrope")]
    LongRope,
    /// No scaling; plain RoPE tables.
    #[default]
    Default,
}
1812
/// `rope_scaling` section of the Phi-4-MM config.
///
/// LongRope is used only when `type` is LongRope and both factor lists are present;
/// anything else falls back to unscaled RoPE.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Phi4MMRopeScalingConfig {
    // Per-frequency divisors applied to the inverse frequencies of the short tables.
    short_factor: Option<Vec<f64>>,
    // Per-frequency divisors applied to the inverse frequencies of the long tables.
    long_factor: Option<Vec<f64>>,
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}
1820
1821impl Phi4MMRotaryEmbedding {
1822    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1823        let max_seq_len = cfg.max_position_embeddings;
1824        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1825
1826        let inv_freq: Vec<_> = (0..dim)
1827            .step_by(2)
1828            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1829            .collect();
1830        let inv_freq_len = inv_freq.len();
1831        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1832        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1833            .to_dtype(DType::F32)?
1834            .reshape((max_seq_len, 1))?;
1835        let freqs = t.matmul(&inv_freq)?;
1836        let sin = freqs.sin()?.to_dtype(dtype)?;
1837        let cos = freqs.cos()?.to_dtype(dtype)?;
1838        Ok(Self {
1839            short_cos: cos,
1840            short_sin: sin,
1841            long_cos: None,
1842            long_sin: None,
1843            original_max_position_embeddings: cfg.original_max_position_embeddings,
1844        })
1845    }
1846
1847    #[allow(clippy::too_many_arguments)]
1848    fn new_longrope(
1849        short_factor: &[f64],
1850        long_factor: &[f64],
1851        cfg: &Phi4MMConfig,
1852        dtype: DType,
1853        dev: &Device,
1854    ) -> Result<Self> {
1855        let max_seq_len = cfg.max_position_embeddings;
1856        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1857
1858        // Calculate scale
1859        let scale =
1860            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1861        let scaling_factor = if scale <= 1.0 {
1862            1.0
1863        } else {
1864            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1865        };
1866
1867        // Short cos/sin
1868        let inv_freq_short: Vec<_> = (0..dim)
1869            .step_by(2)
1870            .enumerate()
1871            .map(|(k, i)| {
1872                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1873            })
1874            .collect();
1875        let inv_freq_len_short = inv_freq_short.len();
1876        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1877        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1878            .to_dtype(DType::F32)?
1879            .reshape((max_seq_len, 1))?;
1880        let freqs_short = t_short.matmul(&inv_freq_short)?;
1881        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1882        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1883
1884        // Long cos/sin
1885        let inv_freq_long: Vec<_> = (0..dim)
1886            .step_by(2)
1887            .enumerate()
1888            .map(|(k, i)| {
1889                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1890            })
1891            .collect();
1892        let inv_freq_len_long = inv_freq_long.len();
1893        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1894        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1895            .to_dtype(DType::F32)?
1896            .reshape((max_seq_len, 1))?;
1897        let freqs_long = t_long.matmul(&inv_freq_long)?;
1898        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1899        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1900
1901        Ok(Self {
1902            short_cos: cos_short,
1903            short_sin: sin_short,
1904            long_cos: Some(cos_long),
1905            long_sin: Some(sin_long),
1906            original_max_position_embeddings: cfg.original_max_position_embeddings,
1907        })
1908    }
1909
1910    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1911        match &cfg.rope_scaling {
1912            Some(Phi4MMRopeScalingConfig {
1913                scaling_type: Phi4MMScaledRopeType::LongRope,
1914                short_factor: Some(short_factor),
1915                long_factor: Some(long_factor),
1916            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1917
1918            _ => Self::new_unscaled(cfg, dtype, dev),
1919        }
1920    }
1921
1922    /// Returns (sin, cos) taking into account LongRope
1923    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1924        if self.long_cos.is_none() {
1925            return (&self.short_sin, &self.short_cos);
1926        }
1927        let seq_len = position_ids.iter().max().unwrap() + 1;
1928        if seq_len > self.original_max_position_embeddings {
1929            (
1930                self.long_sin.as_ref().unwrap(),
1931                self.long_cos.as_ref().unwrap(),
1932            )
1933        } else {
1934            (&self.short_sin, &self.short_cos)
1935        }
1936    }
1937
1938    pub fn forward(
1939        &self,
1940        q: &Tensor,
1941        k: &Tensor,
1942        seqlen_offsets: &[usize],
1943        position_ids: &[usize],
1944    ) -> Result<(Tensor, Tensor)> {
1945        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1946        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1947
1948        let rot_dim = cos.dim(D::Minus1)? * 2;
1949        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1950        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1951        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1952        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1953
1954        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1955            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1956            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1957            let q_embed = hanzo_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1958            let k_embed = hanzo_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1959            (q_embed, k_embed)
1960        } else {
1961            let mut q_embeds = Vec::new();
1962            let mut k_embeds = Vec::new();
1963            for (i, offset) in seqlen_offsets.iter().enumerate() {
1964                let cos = cos.narrow(0, *offset, seq_len)?;
1965                let sin = sin.narrow(0, *offset, seq_len)?;
1966                let q_embed = hanzo_nn::rotary_emb::rope(
1967                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1968                    &cos,
1969                    &sin,
1970                )?;
1971                let k_embed = hanzo_nn::rotary_emb::rope(
1972                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1973                    &cos,
1974                    &sin,
1975                )?;
1976                q_embeds.push(q_embed);
1977                k_embeds.push(k_embed);
1978            }
1979            let q_rot = Tensor::cat(&q_embeds, 0)?;
1980            let k_rot = Tensor::cat(&k_embeds, 0)?;
1981            (q_rot, k_rot)
1982        };
1983
1984        Ok((
1985            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1986            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1987        ))
1988    }
1989}
1990
/// Rotary embedding for Gemma-3n text models: a thin wrapper over [`RotaryEmbedding`]
/// whose tables are built with linear position scaling.
#[derive(Debug, Clone)]
pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1993
/// RoPE scaling kinds for Gemma-3n (only linear).
///
/// NOTE(review): `Gemma3nRotaryEmbedding::new` matches on the Gemma-3 scaling types, so this
/// enum is not referenced in this section; confirm whether it is used elsewhere.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3nScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}
2000
/// `rope_scaling` section for Gemma-3n: a scaling `factor` plus its `rope_type`.
///
/// NOTE(review): not referenced in this section (`Gemma3nRotaryEmbedding::new` matches on
/// `Gemma3RopeScalingConfig`); confirm whether it is used elsewhere.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3nRopeScalingConfig {
    factor: f64,
    rope_type: Gemma3nScaledRopeType,
}
2006
2007impl Gemma3nRotaryEmbedding {
2008    fn new_linear(
2009        cfg: &Gemma3nTextConfig,
2010        factor: f64,
2011        is_gpt_neox: bool,
2012        dtype: DType,
2013        dev: &Device,
2014    ) -> Result<Self> {
2015        let max_seq_len = cfg.max_position_embeddings;
2016        let dim = cfg.head_dim;
2017
2018        let inv_freq: Vec<_> = (0..dim)
2019            .step_by(2)
2020            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
2021            .collect();
2022        let inv_freq_len = inv_freq.len();
2023        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
2024        let inv_freq = (inv_freq / factor)?;
2025
2026        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
2027            .to_dtype(DType::F32)?
2028            .reshape((max_seq_len, 1))?;
2029        let freqs = t.matmul(&inv_freq)?;
2030        let sin = freqs.sin()?.to_dtype(dtype)?;
2031        let cos = freqs.cos()?.to_dtype(dtype)?;
2032        Ok(Self(RotaryEmbedding {
2033            cos,
2034            sin,
2035            is_gpt_neox,
2036        }))
2037    }
2038
2039    pub fn new(
2040        is_gpt_neox: bool,
2041        dtype: DType,
2042        cfg: &Gemma3nTextConfig,
2043        dev: &Device,
2044    ) -> Result<Self> {
2045        match &cfg.rope_scaling {
2046            Some(Gemma3RopeScalingConfig {
2047                rope_type: Gemma3ScaledRopeType::Linear,
2048                factor,
2049            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
2050
2051            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
2052        }
2053    }
2054
2055    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2056        self.0.get_cos_sin()
2057    }
2058
2059    pub fn forward(
2060        &self,
2061        q: &Tensor,
2062        k: &Tensor,
2063        seqlen_offsets: &[usize],
2064    ) -> Result<(Tensor, Tensor)> {
2065        self.0.forward(q, k, seqlen_offsets)
2066    }
2067
2068    #[allow(clippy::too_many_arguments)]
2069    pub fn forward_qk_norm(
2070        &self,
2071        q: &Tensor,
2072        k: &Tensor,
2073        q_weight: &Tensor,
2074        k_weight: &Tensor,
2075        q_eps: f64,
2076        k_eps: f64,
2077        seqlen_offsets: &[usize],
2078    ) -> Result<(Tensor, Tensor)> {
2079        self.0
2080            .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
2081    }
2082}
2083
/// Rotary embedding for Gemma-3 text models and EmbeddingGemma: a thin wrapper over
/// [`RotaryEmbedding`] whose tables are built with linear position scaling.
#[derive(Debug, Clone)]
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
2086
/// RoPE scaling kinds supported for the Gemma-3 family (only linear).
///
/// With `rename_all = "lowercase"`, `Linear` already deserializes from "linear", so the
/// alias appears redundant (kept as-is).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3ScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}
2093
/// `rope_scaling` section for the Gemma-3 family.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3RopeScalingConfig {
    // Every RoPE inverse frequency is divided by this value (linear scaling).
    factor: f64,
    rope_type: Gemma3ScaledRopeType,
}
2099
2100impl Gemma3RotaryEmbedding {
2101    fn new_linear(
2102        cfg: &Gemma3TextConfig,
2103        factor: f64,
2104        is_gpt_neox: bool,
2105        dtype: DType,
2106        dev: &Device,
2107    ) -> Result<Self> {
2108        let max_seq_len = cfg.max_position_embeddings;
2109        let dim = cfg.head_dim;
2110
2111        let inv_freq: Vec<_> = (0..dim)
2112            .step_by(2)
2113            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
2114            .collect();
2115        let inv_freq_len = inv_freq.len();
2116        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
2117        let inv_freq = (inv_freq / factor)?;
2118
2119        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
2120            .to_dtype(DType::F32)?
2121            .reshape((max_seq_len, 1))?;
2122        let freqs = t.matmul(&inv_freq)?;
2123        let sin = freqs.sin()?.to_dtype(dtype)?;
2124        let cos = freqs.cos()?.to_dtype(dtype)?;
2125        Ok(Self(RotaryEmbedding {
2126            cos,
2127            sin,
2128            is_gpt_neox,
2129        }))
2130    }
2131
2132    pub fn new(
2133        is_gpt_neox: bool,
2134        dtype: DType,
2135        cfg: &Gemma3TextConfig,
2136        dev: &Device,
2137    ) -> Result<Self> {
2138        match &cfg.rope_scaling {
2139            Some(Gemma3RopeScalingConfig {
2140                rope_type: Gemma3ScaledRopeType::Linear,
2141                factor,
2142            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
2143
2144            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
2145        }
2146    }
2147
2148    fn new_linear_embedding_gemma(
2149        cfg: &EmbeddingGemmaConfig,
2150        factor: f64,
2151        is_gpt_neox: bool,
2152        dtype: DType,
2153        dev: &Device,
2154    ) -> Result<Self> {
2155        let max_seq_len = cfg.max_position_embeddings;
2156        let dim = cfg.head_dim;
2157
2158        let inv_freq: Vec<_> = (0..dim)
2159            .step_by(2)
2160            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
2161            .collect();
2162        let inv_freq_len = inv_freq.len();
2163        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
2164        let inv_freq = (inv_freq / factor)?;
2165
2166        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
2167            .to_dtype(DType::F32)?
2168            .reshape((max_seq_len, 1))?;
2169        let freqs = t.matmul(&inv_freq)?;
2170        let sin = freqs.sin()?.to_dtype(dtype)?;
2171        let cos = freqs.cos()?.to_dtype(dtype)?;
2172        Ok(Self(RotaryEmbedding {
2173            cos,
2174            sin,
2175            is_gpt_neox,
2176        }))
2177    }
2178
2179    pub fn new_embedding_gemma(
2180        is_gpt_neox: bool,
2181        dtype: DType,
2182        cfg: &EmbeddingGemmaConfig,
2183        dev: &Device,
2184    ) -> Result<Self> {
2185        match &cfg.rope_scaling {
2186            Some(Gemma3RopeScalingConfig {
2187                rope_type: Gemma3ScaledRopeType::Linear,
2188                factor,
2189            }) => Self::new_linear_embedding_gemma(cfg, *factor, is_gpt_neox, dtype, dev),
2190
2191            _ => Self::new_linear_embedding_gemma(cfg, 1.0, is_gpt_neox, dtype, dev),
2192        }
2193    }
2194
2195    pub fn forward(
2196        &self,
2197        q: &Tensor,
2198        k: &Tensor,
2199        seqlen_offsets: &[usize],
2200    ) -> Result<(Tensor, Tensor)> {
2201        self.0.forward(q, k, seqlen_offsets)
2202    }
2203
2204    pub fn forward_q_norm(
2205        &self,
2206        q: &Tensor,
2207        q_weight: &Tensor,
2208        q_eps: f64,
2209        seqlen_offsets: &[usize],
2210    ) -> Result<Tensor> {
2211        self.0.forward_q_norm(q, q_weight, q_eps, seqlen_offsets)
2212    }
2213
2214    #[allow(clippy::too_many_arguments)]
2215    pub fn forward_qk_norm(
2216        &self,
2217        q: &Tensor,
2218        k: &Tensor,
2219        q_weight: &Tensor,
2220        k_weight: &Tensor,
2221        q_eps: f64,
2222        k_eps: f64,
2223        seqlen_offsets: &[usize],
2224    ) -> Result<(Tensor, Tensor)> {
2225        self.0
2226            .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
2227    }
2228}
2229
2230pub struct DiaRotaryEmbedding {
2231    timescale: Tensor,
2232    dtype: DType,
2233}
2234
2235impl DiaRotaryEmbedding {
2236    pub fn new(
2237        min_timescale: f32,
2238        max_timescale: f32,
2239        head_dim: usize,
2240        device: &Device,
2241        dtype: DType,
2242    ) -> Result<Self> {
2243        assert_eq!(head_dim % 2, 0);
2244        let half_embedding_dim = head_dim / 2;
2245
2246        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
2247        let timescale = fraction
2248            .into_iter()
2249            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
2250            .collect::<Vec<_>>();
2251
2252        let timescale_len = timescale.len();
2253        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
2254
2255        Ok(Self { timescale, dtype })
2256    }
2257
2258    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
2259        let freqs = positions
2260            .unsqueeze(D::Minus1)?
2261            .unsqueeze(D::Minus1)?
2262            .broadcast_div(&self.timescale)?;
2263
2264        let sin = freqs.sin()?.to_dtype(self.dtype)?;
2265        let cos = freqs.cos()?.to_dtype(self.dtype)?;
2266
2267        let split = xs.chunk(2, D::Minus1)?;
2268        let first_half = &split[0];
2269        let second_half = &split[1];
2270
2271        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
2272        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
2273
2274        Tensor::cat(&[first_part, second_part], D::Minus1)
2275    }
2276}
/// Linear layer backed by a [`QMatMul`] (quantized or dense weight) with an optional bias.
#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    // Output dtype: F32 for quantized weights, otherwise the dense weight's dtype.
    dtype: DType,
}
2283
2284impl QLinear {
2285    pub fn new<R: std::io::Read + std::io::Seek>(
2286        ct: &mut Content<'_, R>,
2287        name: &str,
2288        device: &Device,
2289    ) -> Result<Self> {
2290        let w = ct.tensor(&format!("{name}.weight"), device)?;
2291        let b = ct.tensor(&format!("{name}.bias"), device)?;
2292        let inner = QMatMul::from_qtensor(w)?;
2293        let bias = b.dequantize(device)?;
2294        Ok(Self {
2295            inner,
2296            bias: Some(bias),
2297            dtype: DType::F32,
2298        })
2299    }
2300
2301    pub fn from_linear(linear: Linear) -> Self {
2302        Self {
2303            inner: QMatMul::Tensor(linear.weight().clone()),
2304            bias: linear.bias().cloned(),
2305            dtype: linear.weight().dtype(),
2306        }
2307    }
2308
2309    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
2310        let dtype = w.dtype();
2311        Self {
2312            inner: QMatMul::Tensor(w),
2313            bias: b,
2314            dtype,
2315        }
2316    }
2317
2318    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
2319        if let Some(ref b) = b {
2320            assert_eq!(b.dtype(), DType::F32);
2321        }
2322        Self {
2323            inner: QMatMul::QTensor(Arc::new(w)),
2324            bias: b,
2325            dtype: DType::F32,
2326        }
2327    }
2328
2329    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
2330        Self {
2331            inner,
2332            bias: old.bias.clone(),
2333            dtype: old.dtype,
2334        }
2335    }
2336
2337    pub fn inner(&mut self) -> &mut QMatMul {
2338        &mut self.inner
2339    }
2340
2341    pub fn inner_ref(&self) -> &QMatMul {
2342        &self.inner
2343    }
2344
2345    pub fn is_quant(&self) -> bool {
2346        matches!(self.inner, QMatMul::QTensor(_))
2347    }
2348
2349    pub fn bias(&self) -> Option<&Tensor> {
2350        self.bias.as_ref()
2351    }
2352
2353    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
2354        self.bias.as_mut()
2355    }
2356}
2357
2358impl Module for QLinear {
2359    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2360        let xs = if self.is_quant() {
2361            xs.to_dtype(DType::F32)?
2362        } else {
2363            xs.clone()
2364        };
2365        if let Some(bias) = &self.bias {
2366            self.inner
2367                .forward(&xs)?
2368                .broadcast_add(bias)?
2369                .to_dtype(self.dtype)
2370        } else {
2371            self.inner.forward(&xs)?.to_dtype(self.dtype)
2372        }
2373    }
2374}
2375
/// Precomputed RoPE cos/sin tables, one row per position, shared by several model families.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    // RoPE layout flag; in the free helpers below, `true` selects `rotary_emb::rope` and
    // `false` selects `rotary_emb::rope_i`.
    is_gpt_neox: bool,
}
2382
2383fn selected_rope_cache(
2384    cos: &Tensor,
2385    sin: &Tensor,
2386    batch: usize,
2387    seq_len: usize,
2388    seqlen_offsets: &[usize],
2389) -> Result<(Tensor, Tensor)> {
2390    if seqlen_offsets.len() == 1 {
2391        Ok((
2392            cos.narrow(0, seqlen_offsets[0], seq_len)?,
2393            sin.narrow(0, seqlen_offsets[0], seq_len)?,
2394        ))
2395    } else {
2396        if seqlen_offsets.len() != batch {
2397            hanzo_ml::bail!(
2398                "RoPE offset count {} does not match batch size {batch}",
2399                seqlen_offsets.len()
2400            );
2401        }
2402        let mut cos_s = Vec::with_capacity(batch);
2403        let mut sin_s = Vec::with_capacity(batch);
2404        for offset in seqlen_offsets {
2405            cos_s.push(cos.narrow(0, *offset, seq_len)?);
2406            sin_s.push(sin.narrow(0, *offset, seq_len)?);
2407        }
2408        Ok((Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?))
2409    }
2410}
2411
/// RMS-normalizes `q` and `k` over their last dim (weights `q_weight`/`k_weight`, epsilons
/// `q_eps`/`k_eps`), then applies rotary embedding with rows of `cos_cache`/`sin_cache`
/// selected by `seqlen_offsets`.
///
/// `q`/`k` are 4-D with the sequence on dim 2 (read via `dims4`). `is_gpt_neox` selects
/// `rope` (true) or `rope_i` (false). With the `cuda` feature, a fused kernel is tried
/// first, then an in-place rotary kernel. Otherwise, or when neither applies, the generic
/// path runs, once per batch element when there are per-sequence offsets.
///
/// # Errors
/// Fails if the offset count does not match the batch size (see `selected_rope_cache`) or on
/// any tensor-op failure.
#[allow(clippy::too_many_arguments)]
pub fn qk_rms_norm_rope(
    q: &Tensor,
    k: &Tensor,
    q_weight: &Tensor,
    k_weight: &Tensor,
    q_eps: f64,
    k_eps: f64,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_gpt_neox: bool,
    seqlen_offsets: &[usize],
) -> Result<(Tensor, Tensor)> {
    let (batch, _, seq_len, _) = q.dims4()?;
    // Single offset: `seq_len` rows; per-sequence offsets: `batch * seq_len` rows.
    let (cos, sin) = selected_rope_cache(cos_cache, sin_cache, batch, seq_len, seqlen_offsets)?;

    // Fused norm + RoPE kernel. Presumably it returns `None` when it cannot handle these
    // inputs, in which case execution falls through (TODO confirm in `crate::ops`).
    #[cfg(feature = "cuda")]
    if let Some((q, Some(k))) = crate::ops::try_cuda_qk_rms_norm_rope(
        q,
        Some(k),
        q_weight,
        Some(k_weight),
        q_eps as f32,
        k_eps as f32,
        &cos,
        &sin,
        is_gpt_neox,
    )? {
        return Ok((q, k));
    }

    let rope = if is_gpt_neox {
        hanzo_nn::rotary_emb::rope
    } else {
        hanzo_nn::rotary_emb::rope_i
    };
    let q = hanzo_nn::ops::rms_norm(&q.contiguous()?, q_weight, q_eps as f32)?;
    let k = hanzo_nn::ops::rms_norm(&k.contiguous()?, k_weight, k_eps as f32)?;

    // CUDA in-place rotary on a (batch * seq_len, heads, head_dim) view. It requires equal
    // q/k head counts and one cos row per token, so a single shared offset only qualifies
    // when batch == 1.
    #[cfg(feature = "cuda")]
    if q.device().is_cuda() && q.dim(1)? == k.dim(1)? && cos.dim(0)? == batch * seq_len {
        let qh = q.dim(1)?;
        let n_embd = q.dim(D::Minus1)?;
        let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
        let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
        hanzo_quant::rotary::apply_rotary_inplace(&q_embed, &k_embed, &cos, &sin, is_gpt_neox)?;
        let mut q = q_embed
            .reshape((batch, seq_len, qh, n_embd))?
            .transpose(1, 2)?;
        let mut k = k_embed
            .reshape((batch, seq_len, k.dim(1)?, n_embd))?
            .transpose(1, 2)?;
        // Flash-attn builds skip the contiguous copy and return the transposed views.
        if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
            q = q.contiguous()?;
            k = k.contiguous()?;
        }
        return Ok((q, k));
    }

    if seqlen_offsets.len() == 1 {
        Ok((
            rope(&q.contiguous()?, &cos, &sin)?,
            rope(&k.contiguous()?, &cos, &sin)?,
        ))
    } else {
        // `cos`/`sin` hold `batch * seq_len` rows here; take each sequence's block.
        let mut q_embeds = Vec::with_capacity(batch);
        let mut k_embeds = Vec::with_capacity(batch);
        for seq_idx in 0..batch {
            let cos = cos.narrow(0, seq_idx * seq_len, seq_len)?;
            let sin = sin.narrow(0, seq_idx * seq_len, seq_len)?;
            q_embeds.push(rope(
                &q.i(seq_idx)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
            k_embeds.push(rope(
                &k.i(seq_idx)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
    }
}
2496
/// Query-only variant of `qk_rms_norm_rope`.
///
/// RMS-normalizes `q` over its last dim (weight `q_weight`, epsilon `q_eps`), then applies
/// rotary embedding with cache rows selected by `seqlen_offsets`. `is_gpt_neox` selects
/// `rope` (true) or `rope_i` (false).
///
/// # Errors
/// Fails if the offset count does not match the batch size or on any tensor-op failure.
#[allow(clippy::too_many_arguments)]
pub fn q_rms_norm_rope(
    q: &Tensor,
    q_weight: &Tensor,
    q_eps: f64,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_gpt_neox: bool,
    seqlen_offsets: &[usize],
) -> Result<Tensor> {
    let (batch, _qh, seq_len, _head_dim) = q.dims4()?;
    let (cos, sin) = selected_rope_cache(cos_cache, sin_cache, batch, seq_len, seqlen_offsets)?;

    // Fused kernel with no `k`. `q_eps` is also passed in the k-epsilon slot, which is
    // presumably unused when `k` is `None` (TODO confirm in `crate::ops`).
    #[cfg(feature = "cuda")]
    if let Some((q, None)) = crate::ops::try_cuda_qk_rms_norm_rope(
        q,
        None,
        q_weight,
        None,
        q_eps as f32,
        q_eps as f32,
        &cos,
        &sin,
        is_gpt_neox,
    )? {
        return Ok(q);
    }

    let rope = if is_gpt_neox {
        hanzo_nn::rotary_emb::rope
    } else {
        hanzo_nn::rotary_emb::rope_i
    };
    let q = hanzo_nn::ops::rms_norm(&q.contiguous()?, q_weight, q_eps as f32)?;
    if seqlen_offsets.len() == 1 {
        rope(&q.contiguous()?, &cos, &sin)
    } else {
        // `cos`/`sin` hold `batch * seq_len` rows here; take each sequence's block.
        let mut q_embeds = Vec::with_capacity(batch);
        for seq_idx in 0..batch {
            let cos = cos.narrow(0, seq_idx * seq_len, seq_len)?;
            let sin = sin.narrow(0, seq_idx * seq_len, seq_len)?;
            q_embeds.push(rope(
                &q.i(seq_idx)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&q_embeds, 0)
    }
}
2547
/// RMS-normalizes `q` and `k`, then applies rotary embedding with caller-supplied
/// (multimodal / M-RoPE) `cos`/`sin` tables.
///
/// The rotated width is `2 * cos.dim(-1)`. When it is smaller than `head_dim`, only the
/// leading channels are rotated and the rest pass through. `is_gpt_neox` selects `rope`
/// (true) or `rope_i` (false).
///
/// # Errors
/// With the `cuda` feature, fails if `cos`/`sin` are neither `(batch, seq_len, _)` nor
/// `(seq_len | batch * seq_len, _)`. This check runs even when the tensors are not on a
/// CUDA device. Also fails on any tensor-op failure.
#[allow(clippy::too_many_arguments)]
pub fn qk_rms_norm_mrope(
    q: &Tensor,
    k: &Tensor,
    q_weight: &Tensor,
    k_weight: &Tensor,
    q_eps: f64,
    k_eps: f64,
    cos: &Tensor,
    sin: &Tensor,
    is_gpt_neox: bool,
) -> Result<(Tensor, Tensor)> {
    let (_, _q_heads, _, head_dim) = q.dims4()?;
    let rot_width = cos.dim(D::Minus1)? * 2;

    // Flatten 3-D tables to one row per token for the fused kernel.
    // NOTE(review): the full `q`/`k` are passed here even when `rot_width < head_dim`;
    // confirm that the kernel handles partial rotary widths.
    #[cfg(feature = "cuda")]
    {
        let (batch, _, seq_len, _) = q.dims4()?;
        let cos_flat = match cos.dims() {
            [cos_batch, cos_seq, _] if *cos_batch == batch && *cos_seq == seq_len => {
                cos.reshape((batch * seq_len, ()))?
            }
            [cos_rows, _] if *cos_rows == seq_len || *cos_rows == batch * seq_len => cos.clone(),
            _ => hanzo_ml::bail!(
                "MRoPE cos shape {:?} is incompatible with q shape {:?}",
                cos.shape(),
                q.shape()
            ),
        };
        let sin_flat = match sin.dims() {
            [sin_batch, sin_seq, _] if *sin_batch == batch && *sin_seq == seq_len => {
                sin.reshape((batch * seq_len, ()))?
            }
            [sin_rows, _] if *sin_rows == seq_len || *sin_rows == batch * seq_len => sin.clone(),
            _ => hanzo_ml::bail!(
                "MRoPE sin shape {:?} is incompatible with q shape {:?}",
                sin.shape(),
                q.shape()
            ),
        };
        if let Some((q, Some(k))) = crate::ops::try_cuda_qk_rms_norm_rope(
            q,
            Some(k),
            q_weight,
            Some(k_weight),
            q_eps as f32,
            k_eps as f32,
            &cos_flat,
            &sin_flat,
            is_gpt_neox,
        )? {
            return Ok((q, k));
        }
    }

    let rope = if is_gpt_neox {
        hanzo_nn::rotary_emb::rope
    } else {
        hanzo_nn::rotary_emb::rope_i
    };
    let q = hanzo_nn::ops::rms_norm(&q.contiguous()?, q_weight, q_eps as f32)?;
    let k = hanzo_nn::ops::rms_norm(&k.contiguous()?, k_weight, k_eps as f32)?;
    if rot_width < head_dim {
        // Partial rotary: rotate the leading `rot_width` channels and pass the rest through.
        let q_rot = q.narrow(D::Minus1, 0, rot_width)?;
        let q_pass = q.narrow(D::Minus1, rot_width, head_dim - rot_width)?;
        let k_rot = k.narrow(D::Minus1, 0, rot_width)?;
        let k_pass = k.narrow(D::Minus1, rot_width, head_dim - rot_width)?;
        let q_rot = rope(&q_rot.contiguous()?, cos, sin)?;
        let k_rot = rope(&k_rot.contiguous()?, cos, sin)?;
        Ok((
            Tensor::cat(&[q_rot, q_pass], D::Minus1)?,
            Tensor::cat(&[k_rot, k_pass], D::Minus1)?,
        ))
    } else {
        Ok((
            rope(&q.contiguous()?, cos, sin)?,
            rope(&k.contiguous()?, cos, sin)?,
        ))
    }
}
2628
2629impl RotaryEmbedding {
2630    pub fn new(
2631        base: f32,
2632        head_dim: usize,
2633        max_position_embeddings: usize,
2634        device: &Device,
2635        is_gpt_neox: bool,
2636        dtype: DType,
2637    ) -> Result<Self> {
2638        let inv_freq: Vec<_> = (0..head_dim)
2639            .step_by(2)
2640            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
2641            .collect();
2642        let inv_freq_len = inv_freq.len();
2643        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2644        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2645            .to_dtype(DType::F32)?
2646            .reshape((max_position_embeddings, 1))?;
2647        let freqs = t.matmul(&inv_freq)?;
2648        let sin = freqs.sin()?.to_dtype(dtype)?;
2649        let cos = freqs.cos()?.to_dtype(dtype)?;
2650
2651        Ok(Self {
2652            cos,
2653            sin,
2654            is_gpt_neox,
2655        })
2656    }
2657
2658    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2659        Ok((self.cos.clone(), self.sin.clone()))
2660    }
2661
2662    #[allow(clippy::too_many_arguments)]
2663    pub fn forward_qk_norm(
2664        &self,
2665        q: &Tensor,
2666        k: &Tensor,
2667        q_weight: &Tensor,
2668        k_weight: &Tensor,
2669        q_eps: f64,
2670        k_eps: f64,
2671        seqlen_offsets: &[usize],
2672    ) -> Result<(Tensor, Tensor)> {
2673        qk_rms_norm_rope(
2674            q,
2675            k,
2676            q_weight,
2677            k_weight,
2678            q_eps,
2679            k_eps,
2680            &self.cos,
2681            &self.sin,
2682            self.is_gpt_neox,
2683            seqlen_offsets,
2684        )
2685    }
2686
2687    pub fn forward_q_norm(
2688        &self,
2689        q: &Tensor,
2690        q_weight: &Tensor,
2691        q_eps: f64,
2692        seqlen_offsets: &[usize],
2693    ) -> Result<Tensor> {
2694        q_rms_norm_rope(
2695            q,
2696            q_weight,
2697            q_eps,
2698            &self.cos,
2699            &self.sin,
2700            self.is_gpt_neox,
2701            seqlen_offsets,
2702        )
2703    }
2704
2705    pub fn new_partial(
2706        base: f32,
2707        rot_dim: usize,
2708        max_position_embeddings: usize,
2709        device: &Device,
2710        is_gpt_neox: bool,
2711        dtype: DType,
2712    ) -> Result<Self> {
2713        let inv_freq: Vec<_> = (0..rot_dim)
2714            .step_by(2)
2715            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
2716            .collect();
2717        let inv_freq_len = inv_freq.len();
2718        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2719        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2720            .to_dtype(DType::F32)?
2721            .reshape((max_position_embeddings, 1))?;
2722        let freqs = t.matmul(&inv_freq)?;
2723        let sin = freqs.sin()?.to_dtype(dtype)?;
2724        let cos = freqs.cos()?.to_dtype(dtype)?;
2725
2726        Ok(Self {
2727            cos,
2728            sin,
2729            is_gpt_neox,
2730        })
2731    }
2732
2733    pub fn forward(
2734        &self,
2735        q: &Tensor,
2736        k: &Tensor,
2737        seqlen_offsets: &[usize],
2738    ) -> Result<(Tensor, Tensor)> {
2739        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2740        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
2741
2742        let rope = if self.is_gpt_neox {
2743            hanzo_nn::rotary_emb::rope
2744        } else {
2745            hanzo_nn::rotary_emb::rope_i
2746        };
2747
2748        if cfg!(feature = "cuda") && qh == kh {
2749            let (cos, sin) = if seqlen_offsets.len() == 1 {
2750                (
2751                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2752                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2753                )
2754            } else {
2755                let mut cos_s = Vec::new();
2756                let mut sin_s = Vec::new();
2757                for offset in seqlen_offsets {
2758                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2759                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2760                }
2761                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2762            };
2763
2764            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2765            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2766            hanzo_quant::rotary::apply_rotary_inplace(
2767                &q_embed,
2768                &k_embed,
2769                &cos,
2770                &sin,
2771                self.is_gpt_neox,
2772            )?;
2773            let mut q = q_embed
2774                .reshape((b_sz, seq_len, qh, n_embd))?
2775                .transpose(1, 2)?;
2776            let mut k = k_embed
2777                .reshape((b_sz, seq_len, kh, n_embd))?
2778                .transpose(1, 2)?;
2779            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2780                q = q.contiguous()?;
2781                k = k.contiguous()?;
2782            }
2783            Ok((q, k))
2784        } else if seqlen_offsets.len() == 1 {
2785            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2786            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2787            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
2788            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
2789            Ok((q_embed, k_embed))
2790        } else {
2791            let mut q_embeds = Vec::new();
2792            let mut k_embeds = Vec::new();
2793            for (i, offset) in seqlen_offsets.iter().enumerate() {
2794                let cos = self.cos.narrow(0, *offset, seq_len)?;
2795                let sin = self.sin.narrow(0, *offset, seq_len)?;
2796                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2797                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2798                q_embeds.push(q_embed);
2799                k_embeds.push(k_embed);
2800            }
2801            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
2802        }
2803    }
2804
2805    /// Apply RoPE to Q only (skip K rotation for shared KV layers).
2806    pub fn forward_q(&self, q: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
2807        let (_b_sz, _qh, seq_len, _n_embd) = q.dims4()?;
2808        let rope = if self.is_gpt_neox {
2809            hanzo_nn::rotary_emb::rope
2810        } else {
2811            hanzo_nn::rotary_emb::rope_i
2812        };
2813        if seqlen_offsets.len() == 1 {
2814            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2815            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2816            rope(&q.contiguous()?, &cos, &sin)
2817        } else {
2818            let mut q_embeds = Vec::new();
2819            for (i, offset) in seqlen_offsets.iter().enumerate() {
2820                let cos = self.cos.narrow(0, *offset, seq_len)?;
2821                let sin = self.sin.narrow(0, *offset, seq_len)?;
2822                q_embeds.push(rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?);
2823            }
2824            Tensor::cat(&q_embeds, 0)
2825        }
2826    }
2827}
2828
2829/// GPT-OSS style rotary embedding with YARN scaling support.
2830/// Uses chunked/GPT-NeoX style rotation and applies attention scaling.
2831#[derive(Debug, Clone)]
2832pub struct GptOssRotaryEmbedding {
2833    cos: Tensor,
2834    sin: Tensor,
2835    #[allow(dead_code)]
2836    attention_scale: f32,
2837}
2838
2839impl GptOssRotaryEmbedding {
2840    /// Create a new GPT-OSS rotary embedding with YARN scaling.
2841    ///
2842    /// # Arguments
2843    /// * `base` - Base frequency for RoPE
2844    /// * `head_dim` - Dimension of each attention head
2845    /// * `max_position_embeddings` - Maximum sequence length
2846    /// * `factor` - YARN scaling factor
2847    /// * `original_max_position_embeddings` - Original max positions before scaling
2848    /// * `beta_fast` - YARN beta_fast parameter
2849    /// * `beta_slow` - YARN beta_slow parameter
2850    /// * `truncate` - Whether to truncate correction dimensions
2851    /// * `device` - Device to create tensors on
2852    /// * `dtype` - Data type for the embeddings
2853    #[allow(clippy::too_many_arguments)]
2854    pub fn new(
2855        base: f64,
2856        head_dim: usize,
2857        max_position_embeddings: usize,
2858        factor: f64,
2859        original_max_position_embeddings: usize,
2860        beta_fast: f64,
2861        beta_slow: f64,
2862        truncate: bool,
2863        device: &Device,
2864        dtype: DType,
2865    ) -> Result<Self> {
2866        let dim = head_dim;
2867
2868        // Compute attention scale: 0.1 * ln(factor) + 1.0 for YARN
2869        let attention_scale = (0.1 * factor.ln() + 1.0) as f32;
2870
2871        // Helper: find correction dimension based on number of rotations
2872        // HF: (dim * log(max_pos / (num_rotations * 2 * pi))) / (2 * log(base))
2873        let find_correction_dim = |num_rotations: f64| -> f64 {
2874            (dim as f64
2875                * (original_max_position_embeddings as f64
2876                    / (num_rotations * 2.0 * std::f64::consts::PI))
2877                    .ln())
2878                / (2.0 * base.ln())
2879        };
2880
2881        // Find correction range based on beta_fast and beta_slow
2882        let mut low = find_correction_dim(beta_fast);
2883        let mut high = find_correction_dim(beta_slow);
2884        if truncate {
2885            low = low.floor();
2886            high = high.ceil();
2887        }
2888        low = low.max(0.0);
2889        high = high.min((dim - 1) as f64);
2890
2891        // Compute base inverse frequencies
2892        let half_dim = dim / 2;
2893        let inv_freq_extrapolation: Vec<f64> = (0..dim)
2894            .step_by(2)
2895            .map(|i| 1.0 / base.powf(i as f64 / dim as f64))
2896            .collect();
2897        let inv_freq_interpolation: Vec<f64> =
2898            inv_freq_extrapolation.iter().map(|f| f / factor).collect();
2899
2900        // Linear ramp factor over dimension indices
2901        let inv_freq: Vec<f64> = (0..half_dim)
2902            .map(|i| {
2903                let range = if (high - low).abs() < 0.001 {
2904                    0.001
2905                } else {
2906                    high - low
2907                };
2908                let linear = (i as f64 - low) / range;
2909                let ramp = linear.clamp(0.0, 1.0);
2910                inv_freq_interpolation[i] * ramp + inv_freq_extrapolation[i] * (1.0 - ramp)
2911            })
2912            .collect();
2913
2914        let inv_freq_len = inv_freq.len();
2915        let inv_freq_tensor = Tensor::from_vec(
2916            inv_freq.iter().map(|&x| x as f32).collect::<Vec<_>>(),
2917            (1, inv_freq_len),
2918            device,
2919        )?;
2920
2921        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2922            .to_dtype(DType::F32)?
2923            .reshape((max_position_embeddings, 1))?;
2924
2925        let freqs = t.matmul(&inv_freq_tensor)?;
2926
2927        // Apply attention scale to sin/cos (matches HF transformers behavior)
2928        let sin = (freqs.sin()? * attention_scale as f64)?.to_dtype(dtype)?;
2929        let cos = (freqs.cos()? * attention_scale as f64)?.to_dtype(dtype)?;
2930
2931        Ok(Self {
2932            cos,
2933            sin,
2934            attention_scale,
2935        })
2936    }
2937
2938    pub fn forward(
2939        &self,
2940        q: &Tensor,
2941        k: &Tensor,
2942        seqlen_offsets: &[usize],
2943    ) -> Result<(Tensor, Tensor)> {
2944        #[allow(unused_variables)]
2945        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2946        #[allow(unused_variables)]
2947        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;
2948
2949        // Use CUDA optimized kernel when available and q/k have same number of heads
2950        // The CUDA kernel uses is_neox=true for chunked/GPT-NeoX style rotary
2951        #[cfg(feature = "cuda")]
2952        if q.device().is_cuda() && qh == k.dim(1)? {
2953            let (cos, sin) = if seqlen_offsets.len() == 1 {
2954                (
2955                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2956                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2957                )
2958            } else {
2959                let mut cos_s = Vec::new();
2960                let mut sin_s = Vec::new();
2961                for offset in seqlen_offsets {
2962                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2963                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2964                }
2965                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2966            };
2967
2968            // Reshape for CUDA kernel: [b, h, seq, dim] -> [b*seq, h, dim]
2969            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2970            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2971
2972            // Apply rotary with is_neox=true for chunked style
2973            hanzo_quant::rotary::apply_rotary_inplace(&q_embed, &k_embed, &cos, &sin, true)?;
2974
2975            // Reshape back: [b*seq, h, dim] -> [b, h, seq, dim]
2976            let mut q = q_embed
2977                .reshape((b_sz, seq_len, qh, n_embd))?
2978                .transpose(1, 2)?;
2979            let mut k = k_embed
2980                .reshape((b_sz, seq_len, kh, n_embd))?
2981                .transpose(1, 2)?;
2982
2983            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2984                q = q.contiguous()?;
2985                k = k.contiguous()?;
2986            }
2987            return Ok((q, k));
2988        }
2989
2990        // CPU fallback using hanzo_nn's rope (GPT-NeoX/chunked style)
2991        if seqlen_offsets.len() == 1 {
2992            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2993            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2994            let q_embed = hanzo_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
2995            let k_embed = hanzo_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
2996            Ok((q_embed, k_embed))
2997        } else {
2998            let mut q_embeds = Vec::new();
2999            let mut k_embeds = Vec::new();
3000            for (i, offset) in seqlen_offsets.iter().enumerate() {
3001                let cos = self.cos.narrow(0, *offset, seq_len)?;
3002                let sin = self.sin.narrow(0, *offset, seq_len)?;
3003                let q_embed =
3004                    hanzo_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
3005                let k_embed =
3006                    hanzo_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
3007                q_embeds.push(q_embed);
3008                k_embeds.push(k_embed);
3009            }
3010            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
3011        }
3012    }
3013}
3014
3015#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
3016#[serde(rename_all = "lowercase")]
3017pub enum Activation {
3018    #[default]
3019    #[serde(alias = "gelu")]
3020    Gelu,
3021    #[serde(alias = "gelu_new")]
3022    NewGelu,
3023    Relu,
3024    Relu2,
3025    Relu6,
3026    Silu,
3027    Sigmoid,
3028    HardSigmoid,
3029    Swiglu,
3030    Swish,
3031    HardSwish,
3032    Elu(f64),
3033    LeakyRelu(f64),
3034    #[serde(alias = "gelu_pytorch_tanh")]
3035    GeluPytorchTanh,
3036    QuickGelu,
3037}
3038
3039impl Module for Activation {
3040    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3041        match self {
3042            Self::Gelu => xs.gelu_erf(),
3043            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
3044            Self::NewGelu => xs.gelu(),
3045            Self::Relu => xs.relu(),
3046            Self::Relu2 => xs.relu()?.sqr(),
3047            Self::Relu6 => xs.clamp(0f32, 6f32),
3048            Self::Silu => xs.silu(),
3049            Self::Sigmoid => hanzo_nn::ops::sigmoid(xs),
3050            Self::HardSigmoid => hanzo_nn::ops::hard_sigmoid(xs),
3051            Self::Swiglu => hanzo_nn::ops::swiglu(xs),
3052            Self::Swish => xs * hanzo_nn::ops::sigmoid(xs)?,
3053            Self::HardSwish => xs * hanzo_nn::ops::hard_sigmoid(xs)?,
3054            &Self::Elu(alpha) => xs.elu(alpha),
3055            &Self::LeakyRelu(negative_slope) => hanzo_nn::ops::leaky_relu(xs, negative_slope),
3056            Self::GeluPytorchTanh => xs.gelu(),
3057            Self::QuickGelu => xs * hanzo_nn::ops::sigmoid(&(xs * 1.702f64)?),
3058        }
3059    }
3060}
3061
3062impl TryInto<hanzo_nn::Activation> for Activation {
3063    type Error = hanzo_ml::Error;
3064
3065    fn try_into(self) -> Result<hanzo_nn::Activation> {
3066        match self {
3067            Self::Gelu => Ok(hanzo_nn::Activation::Gelu),
3068            Self::Relu => Ok(hanzo_nn::Activation::Relu),
3069            Self::Silu => Ok(hanzo_nn::Activation::Silu),
3070            Self::NewGelu => Ok(hanzo_nn::Activation::NewGelu),
3071            Self::Relu2 => Ok(hanzo_nn::Activation::Relu2),
3072            Self::Relu6 => Ok(hanzo_nn::Activation::Relu6),
3073            Self::Sigmoid => Ok(hanzo_nn::Activation::Sigmoid),
3074            Self::HardSigmoid => Ok(hanzo_nn::Activation::HardSigmoid),
3075            Self::Swiglu => Ok(hanzo_nn::Activation::Swiglu),
3076            Self::Swish => Ok(hanzo_nn::Activation::Swish),
3077            Self::HardSwish => Ok(hanzo_nn::Activation::HardSwish),
3078            Self::Elu(x) => Ok(hanzo_nn::Activation::Elu(x)),
3079            Self::LeakyRelu(x) => Ok(hanzo_nn::Activation::LeakyRelu(x)),
3080            Self::GeluPytorchTanh => Ok(hanzo_nn::Activation::GeluPytorchTanh),
3081            Self::QuickGelu => hanzo_ml::bail!("No mapping to hanzo_nn for QuickGelu"),
3082        }
3083    }
3084}
3085
/// Hyper-parameters for [`Conv3dNoBias`]; each field is copied into the `Conv2dConfig` of both
/// temporal-slice 2D convolutions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    /// Spatial zero-padding.
    pub padding: usize,
    /// Spatial stride.
    pub stride: usize,
    /// Spatial dilation.
    pub dilation: usize,
    /// Number of channel groups; also divides `in_channels` when sizing the kernel.
    pub groups: usize,
}

impl Default for Conv3dConfig {
    /// No padding, unit stride and dilation, and a single group.
    fn default() -> Self {
        Self {
            groups: 1,
            dilation: 1,
            stride: 1,
            padding: 0,
        }
    }
}
3104
3105pub struct Conv3dNoBias {
3106    conv2d_1: Conv2d,
3107    conv2d_2: Conv2d,
3108}
3109
3110impl Conv3dNoBias {
3111    pub fn new(
3112        in_channels: usize,
3113        out_channels: usize,
3114        kernel_sizes: [usize; 3],
3115        cfg: Conv3dConfig,
3116        vb: ShardedVarBuilder,
3117    ) -> Result<Self> {
3118        let expected_shape = (
3119            out_channels,
3120            in_channels / cfg.groups,
3121            kernel_sizes[0],
3122            kernel_sizes[1],
3123            kernel_sizes[2],
3124        );
3125        // MLX format has channels-last: (out, temporal, h, w, in)
3126        // PyTorch format has channels-first: (out, in, temporal, h, w)
3127        let mlx_shape = (
3128            out_channels,
3129            kernel_sizes[0],
3130            kernel_sizes[1],
3131            kernel_sizes[2],
3132            in_channels / cfg.groups,
3133        );
3134        let ws = if vb.contains_tensor("weight") {
3135            // Try to load with expected shape first, if it fails try MLX shape and permute
3136            match vb.get(expected_shape, "weight") {
3137                Ok(ws) => ws,
3138                Err(_) => {
3139                    // Try MLX format and permute from (out, t, h, w, in) to (out, in, t, h, w)
3140                    let ws = vb.get(mlx_shape, "weight")?;
3141                    ws.permute((0, 4, 1, 2, 3))?
3142                }
3143            }
3144        } else {
3145            vb.get(expected_shape, "weight")?
3146        };
3147
3148        // Split on temporal dimension
3149        // https://github.com/pytorch/pytorch/issues/139066
3150
3151        let w1 = ws.i((.., .., 0, .., ..))?;
3152        let w2 = ws.i((.., .., 1, .., ..))?;
3153
3154        let cfg = Conv2dConfig {
3155            padding: cfg.padding,
3156            stride: cfg.stride,
3157            dilation: cfg.dilation,
3158            groups: cfg.groups,
3159            cudnn_fwd_algo: None,
3160        };
3161
3162        Ok(Self {
3163            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
3164            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
3165        })
3166    }
3167
3168    pub fn weight(&self) -> Result<Tensor> {
3169        let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
3170        let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
3171        Tensor::cat(&[w1, w2], 2)
3172    }
3173}
3174
3175impl Module for Conv3dNoBias {
3176    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3177        let xs1 = xs.i((.., .., 0, .., ..))?;
3178        let xs2 = xs.i((.., .., 1, .., ..))?;
3179
3180        (Convolution.forward_2d(&self.conv2d_1, &xs1)?
3181            + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
3182        .unsqueeze(2)
3183    }
3184}
3185
3186pub trait TensorInfExtend {
3187    fn is_inf(&self) -> Result<Self>
3188    where
3189        Self: Sized;
3190    fn any(&self) -> Result<bool>;
3191}
3192
3193impl TensorInfExtend for Tensor {
3194    fn is_inf(&self) -> Result<Self> {
3195        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
3196    }
3197
3198    fn any(&self) -> Result<bool> {
3199        let sum = self.sum_all()?;
3200        match self.dtype() {
3201            DType::U8 => Ok(sum.to_scalar::<u8>()? == 0),
3202            DType::U32 => Ok(sum.to_scalar::<u32>()? == 0),
3203            DType::I16 => Ok(sum.to_scalar::<i16>()? == 0),
3204            DType::I32 => Ok(sum.to_scalar::<i32>()? == 0),
3205            DType::I64 => Ok(sum.to_scalar::<i64>()? == 0),
3206            DType::F16 => Ok(sum.to_scalar::<half::f16>()? == half::f16::from_f32_const(0.)),
3207            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
3208            DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
3209            DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
3210            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
3211            _ => {
3212                hanzo_ml::bail!("dtype {:?} is not supported with .any", self.dtype())
3213            }
3214        }
3215    }
3216}
3217
3218pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
3219    let mut max = match xs.dtype() {
3220        DType::U8 => u8::MAX as f32 - 1000.,
3221        DType::U32 => u32::MAX as f32 - 1000.,
3222        DType::I16 => i16::MAX as f32 - 1000.,
3223        DType::I32 => i32::MAX as f32 - 1000.,
3224        DType::I64 => i64::MAX as f32 - 1000.,
3225        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
3226        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
3227        DType::F32 => f32::MAX - 1000.,
3228        DType::F64 => f64::MAX as f32 - 1000.,
3229        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
3230        _ => {
3231            hanzo_ml::bail!("dtype {:?} is not supported with clamp_for_f16", xs.dtype())
3232        }
3233    };
3234    if xs.is_inf()?.any()? {
3235        max -= 1000.;
3236    }
3237    xs.clamp(-max, max)
3238}
3239
3240pub struct FloatInfo {
3241    /// Minimum representable value.
3242    pub min: f64,
3243    /// Maximum representable value.
3244    pub max: f64,
3245    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
3246    pub eps: f64,
3247    pub dtype: DType,
3248}
3249
3250pub trait GetFloatInfo {
3251    fn finfo(&self) -> Result<FloatInfo>;
3252}
3253
3254impl GetFloatInfo for DType {
3255    fn finfo(&self) -> Result<FloatInfo> {
3256        let finfo = match self {
3257            Self::BF16 => FloatInfo {
3258                min: bf16::MIN.to_f64(),
3259                max: bf16::MAX.to_f64(),
3260                eps: bf16::EPSILON.to_f64(),
3261                dtype: DType::BF16,
3262            },
3263            Self::F16 => FloatInfo {
3264                min: f16::MIN.to_f64(),
3265                max: f16::MAX.to_f64(),
3266                eps: f16::EPSILON.to_f64(),
3267                dtype: DType::F16,
3268            },
3269            Self::F32 => FloatInfo {
3270                min: f32::MIN as f64,
3271                max: f32::MAX as f64,
3272                eps: f32::EPSILON as f64,
3273                dtype: DType::F32,
3274            },
3275            Self::F64 => FloatInfo {
3276                min: f64::MIN,
3277                max: f64::MAX,
3278                eps: f64::EPSILON,
3279                dtype: DType::F64,
3280            },
3281            Self::F8E4M3 => FloatInfo {
3282                min: F8E4M3::MIN.to_f64(),
3283                max: F8E4M3::MAX.to_f64(),
3284                eps: F8E4M3::EPSILON.to_f64(),
3285                dtype: DType::F8E4M3,
3286            },
3287            other => {
3288                hanzo_ml::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
3289            }
3290        };
3291        Ok(finfo)
3292    }
3293}
3294
3295#[derive(Clone)]
3296pub struct Mlp {
3297    pub gate: Arc<dyn QuantMethod>,
3298    pub up: Arc<dyn QuantMethod>,
3299    pub down: Arc<dyn QuantMethod>,
3300    act: Activation,
3301    params: Vec<usize>,
3302}
3303
3304impl Mlp {
3305    pub fn new(
3306        vb: ShardedVarBuilder,
3307        hidden_size: usize,
3308        intermediate_size: usize,
3309        quantization_config: &Option<QuantizedConfig>,
3310        hidden_act: Activation,
3311        comm: &Arc<hanzo_quant::Comm>,
3312    ) -> Result<Self> {
3313        Ok(Self {
3314            gate: ColumnParallelLayer::new(
3315                hidden_size,
3316                intermediate_size,
3317                quantization_config,
3318                false,
3319                comm,
3320                vb.pp("gate_proj"),
3321            )?,
3322            up: ColumnParallelLayer::new(
3323                hidden_size,
3324                intermediate_size,
3325                quantization_config,
3326                false,
3327                comm,
3328                vb.pp("up_proj"),
3329            )?,
3330            down: RowParallelLayer::new(
3331                intermediate_size,
3332                hidden_size,
3333                quantization_config,
3334                false,
3335                comm,
3336                vb.pp("down_proj"),
3337            )?,
3338            act: hidden_act,
3339            params: vec![hidden_size, intermediate_size],
3340        })
3341    }
3342
3343    pub fn new_merged(
3344        vb: ShardedVarBuilder,
3345        hidden_size: usize,
3346        intermediate_size: usize,
3347        chunks: usize,
3348        quantization_config: &Option<QuantizedConfig>,
3349        hidden_act: Activation,
3350        comm: &Arc<hanzo_quant::Comm>,
3351    ) -> Result<Self> {
3352        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
3353        let gate_up_projs = ColumnParallelLayer::new_merged(
3354            hidden_size,
3355            intermediate_size * 2,
3356            2,
3357            quantization_config,
3358            false,
3359            comm,
3360            vb.pp("gate_up_proj"),
3361        )?;
3362
3363        Ok(Self {
3364            gate: gate_up_projs[0].to_owned(),
3365            up: gate_up_projs[1].to_owned(),
3366            down: RowParallelLayer::new(
3367                intermediate_size,
3368                hidden_size,
3369                quantization_config,
3370                false,
3371                comm,
3372                vb.pp("down_proj"),
3373            )?,
3374            act: hidden_act,
3375            params: vec![hidden_size, intermediate_size],
3376        })
3377    }
3378
3379    pub fn replicate(
3380        params: &[usize],
3381        vb: ShardedVarBuilder,
3382        act: Activation,
3383        comm: &Arc<hanzo_quant::Comm>,
3384    ) -> Result<Self> {
3385        Self::new(vb, params[0], params[1], &None, act, comm)
3386    }
3387
3388    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3389        let res = crate::ops::quantized_ffn(xs, &*self.gate, &*self.up, &*self.down, self.act)?;
3390        Ok(res)
3391    }
3392}
3393
3394impl AnyMoeTrainableLayer for Mlp {}
3395
3396impl MlpLayer for Mlp {
3397    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3398        let res = crate::ops::quantized_ffn(xs, &*self.gate, &*self.up, &*self.down, self.act)?;
3399        Ok(res)
3400    }
3401    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
3402        vec![&mut self.gate, &mut self.up, &mut self.down]
3403    }
3404    fn clone(&self) -> Box<dyn MlpLayer> {
3405        Box::new(Clone::clone(self))
3406    }
3407    fn get_params(&self) -> &[usize] {
3408        &self.params
3409    }
3410    fn hidden_act(&self) -> Activation {
3411        self.act
3412    }
3413    // gate, up, down
3414    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
3415        let gate = if let Some(ref delta) = deltas[0] {
3416            self.gate.add_delta_w(delta)?
3417        } else {
3418            self.gate.clone()
3419        };
3420        let up = if let Some(ref delta) = deltas[1] {
3421            self.up.add_delta_w(delta)?
3422        } else {
3423            self.up.clone()
3424        };
3425        let down = if let Some(ref delta) = deltas[2] {
3426            self.down.add_delta_w(delta)?
3427        } else {
3428            self.down.clone()
3429        };
3430
3431        Ok(Box::new(Self {
3432            gate,
3433            up,
3434            down,
3435            act: self.act,
3436            params: self.params.clone(),
3437        }))
3438    }
3439
3440    fn dtype_device(&self) -> (DType, Device) {
3441        self.gate.dtype_and_device()
3442    }
3443}
3444
3445pub struct AvgPool2d {
3446    kernel_size: usize,
3447    stride: usize,
3448}
3449
3450impl AvgPool2d {
3451    pub fn new(kernel_size: usize, stride: usize) -> Self {
3452        Self {
3453            kernel_size,
3454            stride,
3455        }
3456    }
3457
3458    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3459        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
3460    }
3461}
3462
/// Reflection padding for tensors of shape `(N, C, H, W)`.
///
/// `padding` is `(pad_left, pad_right, pad_top, pad_bottom)`. Padded values mirror the input
/// about its border without repeating the border itself. The left pad takes columns
/// `pad_left, pad_left - 1, ..., 1`, and the right pad takes columns `w - 2, w - 3, ...`.
/// The height axis is padded the same way with `pad_top` / `pad_bottom`.
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    /// Stores `(pad_left, pad_right, pad_top, pad_bottom)` for use in `forward`.
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        ReflectionPad2d { padding }
    }
}
3478
3479impl Module for ReflectionPad2d {
3480    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3481        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
3482
3483        let (_n, _c, h, w) = xs.dims4()?;
3484
3485        // --- Horizontal Padding (along width, axis = 3) ---
3486        // For left padding, we reflect columns 1..=pad_left (in reverse order).
3487        let left_pad = if pad_left > 0 {
3488            // Create indices: [pad_left, pad_left-1, ..., 1]
3489            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
3490            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
3491        } else {
3492            None
3493        };
3494
3495        // For right padding, we reflect from the right side (excluding the last column).
3496        let right_pad = if pad_right > 0 {
3497            // For pad_right == 2, generate indices: [w-2, w-3, ... , w-1-pad_right]
3498            let start = w as i64 - 2;
3499            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
3500            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
3501        } else {
3502            None
3503        };
3504
3505        // Concatenate horizontally (along width, dim=3)
3506        let x_padded_width = match (left_pad, right_pad) {
3507            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
3508            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
3509            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
3510            (None, None) => xs.clone(),
3511        };
3512
3513        // --- Vertical Padding (along height, axis = 2) ---
3514        // For top padding, reflect rows 1..=pad_top (in reverse order)
3515        let top_pad = if pad_top > 0 {
3516            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
3517            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
3518        } else {
3519            None
3520        };
3521
3522        // For bottom padding, reflect from the bottom (excluding the last row)
3523        let bottom_pad = if pad_bottom > 0 {
3524            let start = h as i64 - 2;
3525            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
3526            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
3527        } else {
3528            None
3529        };
3530
3531        // Concatenate vertically (along height, dim=2)
3532        let x_padded = match (top_pad, bottom_pad) {
3533            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
3534            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
3535            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
3536            (None, None) => x_padded_width,
3537        };
3538
3539        Ok(x_padded)
3540    }
3541}
3542
3543pub struct ScaledEmbedding {
3544    scale: f64,
3545    pub embedding: Tensor,
3546}
3547
3548impl ScaledEmbedding {
3549    pub fn new(scale: f64, embedding: Embedding) -> Self {
3550        Self {
3551            scale,
3552            embedding: embedding.embeddings().clone(),
3553        }
3554    }
3555
3556    pub fn embeddings(&self) -> &Tensor {
3557        &self.embedding
3558    }
3559}
3560
3561impl Module for ScaledEmbedding {
3562    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3563        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
3564        xs.apply(&embedding)? * self.scale
3565    }
3566}