1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use float8::F8E4M3;
6use half::{bf16, f16};
7use hanzo_ml::{
8 quantized::{QMatMul, QTensor},
9 Context, DType, Device, IndexOp, Result, Tensor, D,
10};
11use hanzo_nn::{
12 BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
13 LayerNorm, LayerNormConfig, Linear, Module,
14};
15use hanzo_quant::{
16 AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
17 ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::{CausalMaskConfig, CausalMasker};
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25 amoe::{AnyMoeTrainableLayer, MlpLayer},
26 embedding_models::embedding_gemma::EmbeddingGemmaConfig,
27 gguf::Content,
28 models::{llama, smollm3},
29 ops::SplitOp,
30 vision_models::{
31 gemma3::config::Gemma3TextConfig,
32 gemma3n::config::Gemma3nTextConfig,
33 llama4,
34 mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
35 phi4::Phi4MMConfig,
36 },
37};
38
39pub use hanzo_quant::MatMul;
40
41pub fn embedding(
42 in_size: usize,
43 out_size: usize,
44 vb: ShardedVarBuilder,
45 config: &Option<QuantizedConfig>,
46) -> Result<Embedding> {
47 let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
49 let afq_layer =
50 AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
51 afq_layer.dequantize_w()?
52 } else {
53 vb.get_with_hints((in_size, out_size), "weight", Default::default())?
54 };
55 Ok(Embedding::new(embeddings, out_size))
56}
57
58pub fn layer_norm<C: Into<LayerNormConfig>>(
59 size: usize,
60 config: C,
61 vb: ShardedVarBuilder,
62) -> Result<LayerNorm> {
63 let config = config.into();
64 let weight = vb.get(size, "weight")?;
65 if config.affine {
66 let bias = vb.get(size, "bias")?;
67 Ok(LayerNorm::new(weight, bias, config.eps))
68 } else {
69 Ok(LayerNorm::new_no_bias(weight, config.eps))
70 }
71}
72
73pub fn batch_norm<C: Into<BatchNormConfig>>(
74 num_features: usize,
75 config: C,
76 vb: ShardedVarBuilder,
77) -> Result<BatchNorm> {
78 let config = config.into();
79 if config.eps < 0. {
80 hanzo_ml::bail!("batch-norm eps cannot be negative {}", config.eps)
81 }
82 let running_mean = vb.get(num_features, "running_mean")?;
83 let running_var = vb.get(num_features, "running_var")?;
84
85 if config.affine {
86 let weight = vb.get(num_features, "weight")?;
87 let bias = vb.get(num_features, "bias")?;
88 BatchNorm::new(
89 num_features,
90 running_mean,
91 running_var,
92 weight,
93 bias,
94 config.eps,
95 )
96 } else {
97 BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
98 }
99}
100
101pub fn group_norm(
102 num_groups: usize,
103 num_channels: usize,
104 eps: f64,
105 vb: ShardedVarBuilder,
106) -> Result<GroupNorm> {
107 let weight = vb.get(num_channels, "weight")?;
108 let bias = vb.get(num_channels, "bias")?;
109 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
110}
111
112pub fn conv2d(
113 in_channels: usize,
114 out_channels: usize,
115 kernel_size: usize,
116 cfg: Conv2dConfig,
117 vb: ShardedVarBuilder,
118) -> Result<Conv2d> {
119 let ws = vb.get(
120 (
121 out_channels,
122 in_channels / cfg.groups,
123 kernel_size,
124 kernel_size,
125 ),
126 "weight",
127 )?;
128 let bs = vb.get(out_channels, "bias")?;
129 Ok(Conv2d::new(ws, Some(bs), cfg))
130}
131
132pub fn conv2d_no_bias(
133 in_channels: usize,
134 out_channels: usize,
135 kernel_size: usize,
136 cfg: Conv2dConfig,
137 vb: ShardedVarBuilder,
138) -> Result<Conv2d> {
139 let ws = vb.get(
140 (
141 out_channels,
142 in_channels / cfg.groups,
143 kernel_size,
144 kernel_size,
145 ),
146 "weight",
147 )?;
148 Ok(Conv2d::new(ws, None, cfg))
149}
150
151pub fn conv1d(
152 in_channels: usize,
153 out_channels: usize,
154 kernel_size: usize,
155 cfg: Conv1dConfig,
156 vb: ShardedVarBuilder,
157) -> Result<Conv1d> {
158 let ws = vb.get(
159 (out_channels, in_channels / cfg.groups, kernel_size),
160 "weight",
161 )?;
162 let bs = vb.get(out_channels, "bias")?;
163 Ok(Conv1d::new(ws, Some(bs), cfg))
164}
165
166pub fn conv1d_no_bias(
167 in_channels: usize,
168 out_channels: usize,
169 kernel_size: usize,
170 cfg: Conv1dConfig,
171 vb: ShardedVarBuilder,
172) -> Result<Conv1d> {
173 let ws = vb.get(
174 (out_channels, in_channels / cfg.groups, kernel_size),
175 "weight",
176 )?;
177 Ok(Conv1d::new(ws, None, cfg))
178}
179
180pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
181 let ws = vb.get((out_dim, in_dim), "weight")?;
182 let bs = vb.get(out_dim, "bias")?;
183 Ok(Linear::new(ws, Some(bs)))
184}
185
186pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
187 let ws = vb.get((out_dim, in_dim), "weight")?;
188 Ok(Linear::new(ws, None))
189}
190
191pub fn linear_b(
192 in_dim: usize,
193 out_dim: usize,
194 bias: bool,
195 vb: ShardedVarBuilder,
196) -> Result<Linear> {
197 if bias {
198 linear(in_dim, out_dim, vb)
199 } else {
200 linear_no_bias(in_dim, out_dim, vb)
201 }
202}
203
/// RMS normalization layer: `forward` calls `hanzo_nn::ops::rms_norm` with `weight` as the
/// scale tensor and `eps` (cast to `f32`) as the stabilizing epsilon.
#[derive(Debug, Clone)]
pub struct RmsNorm {
    /// Epsilon passed to the RMS-norm kernel.
    eps: f64,
    /// Scale tensor passed to the RMS-norm kernel.
    weight: Tensor,
}
209
210impl RmsNorm {
211 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
212 let w = vb.get(size, "weight")?;
213 Ok(Self { eps, weight: w })
214 }
215
216 #[deprecated(
218 note = "Use GemmaRmsNorm::new() instead, which handles UQFF serialization correctly"
219 )]
220 pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
221 let w = vb.get(size, "weight")?;
222 let w = (w + 1.0)?;
223 Ok(Self { eps, weight: w })
224 }
225
226 pub fn new_gemma_3n(
228 size: usize,
229 eps: f64,
230 with_scale: bool,
231 vb: ShardedVarBuilder,
232 ) -> Result<Self> {
233 let w = if with_scale {
234 vb.get(size, "weight")?
235 } else {
236 Tensor::ones(size, vb.dtype(), vb.device())?
237 };
238 Ok(Self { eps, weight: w })
239 }
240
241 #[deprecated(note = "Use GemmaRmsNorm instead, which handles UQFF serialization automatically")]
243 pub fn undo_gemma(&self) -> Result<Self> {
244 Ok(Self {
245 eps: self.eps,
246 weight: (&self.weight - 1.0)?,
247 })
248 }
249
250 pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
251 Ok(Self { eps, weight: w })
252 }
253
254 pub fn weight(&self) -> &Tensor {
255 &self.weight
256 }
257
258 pub fn eps(&self) -> f64 {
259 self.eps
260 }
261
262 pub fn forward_residual(&self, x: &Tensor, residual: &Tensor) -> Result<Tensor> {
263 rms_norm_forward_residual(x, residual, &self.weight, self.eps, None)
264 }
265
266 pub fn forward_residual_scaled(
267 &self,
268 x: &Tensor,
269 residual: &Tensor,
270 scale: &Tensor,
271 ) -> Result<Tensor> {
272 rms_norm_forward_residual(x, residual, &self.weight, self.eps, Some(scale))
273 }
274}
275
276impl Module for RmsNorm {
277 fn forward(&self, x: &Tensor) -> Result<Tensor> {
278 hanzo_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
279 }
280}
281
/// Computes `residual + rms_norm(x, weight, eps)` and, when `scale` is given,
/// broadcast-multiplies that sum by `scale`.
///
/// With the `cuda` or `metal` feature enabled, a fused kernel from `crate::ops` is tried
/// first. This happens when `x` is on that backend, every other operand shares `x`'s
/// device and dtype, and the dtype is BF16, F16 or F32. The CUDA kernel's result is
/// returned directly. The Metal kernel returns an `Option`; on `None` the generic path at
/// the bottom is used.
/// NOTE(review): the fused kernels are assumed to compute the same expression as the
/// generic fallback — confirm against `crate::ops`.
fn rms_norm_forward_residual(
    x: &Tensor,
    residual: &Tensor,
    weight: &Tensor,
    eps: f64,
    scale: Option<&Tensor>,
) -> Result<Tensor> {
    // Fused CUDA path: only when all operands are co-located and share a supported dtype.
    #[cfg(feature = "cuda")]
    if x.device().is_cuda()
        && residual.device().same_device(x.device())
        && weight.device().same_device(x.device())
        && scale.is_none_or(|scale| scale.device().same_device(x.device()))
        && x.dtype() == residual.dtype()
        && x.dtype() == weight.dtype()
        && scale.is_none_or(|scale| scale.dtype() == x.dtype())
        && matches!(x.dtype(), DType::BF16 | DType::F16 | DType::F32)
    {
        return crate::ops::cuda_rms_norm_residual(x, residual, weight, scale, eps as f32);
    }

    // Fused Metal path under the same conditions; falls through when the kernel returns `None`.
    #[cfg(feature = "metal")]
    if x.device().is_metal()
        && residual.device().same_device(x.device())
        && weight.device().same_device(x.device())
        && scale.is_none_or(|scale| scale.device().same_device(x.device()))
        && x.dtype() == residual.dtype()
        && x.dtype() == weight.dtype()
        && scale.is_none_or(|scale| scale.dtype() == x.dtype())
        && matches!(x.dtype(), DType::BF16 | DType::F16 | DType::F32)
    {
        if let Some(out) =
            crate::ops::metal_rms_norm_residual(x, residual, weight, scale, eps as f32)?
        {
            return Ok(out);
        }
    }

    // Generic path: normalize, add the residual, then apply the optional scale.
    let normed = hanzo_nn::ops::rms_norm(&x.contiguous()?, weight, eps as f32)?;
    let out = (residual + normed)?;
    if let Some(scale) = scale {
        out.broadcast_mul(scale)
    } else {
        Ok(out)
    }
}
327
/// Gemma-style RMS norm, whose effective scale is `1 + w` for the stored checkpoint tensor `w`.
///
/// Both tensors are kept. `weight` (= `original_weight + 1`) is used by the forward passes.
/// `original_weight` is the unshifted tensor as loaded — presumably kept so the checkpoint
/// value can be re-serialized (e.g. to UQFF, per the deprecation notes on `RmsNorm`).
#[derive(Debug, Clone)]
pub struct GemmaRmsNorm {
    /// Epsilon passed to the RMS-norm kernel.
    eps: f64,
    /// The `weight` tensor exactly as read from the checkpoint.
    original_weight: Tensor,
    /// `original_weight + 1`, the scale actually used when normalizing.
    weight: Tensor,
}
340
341impl GemmaRmsNorm {
342 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
343 let original_weight = vb.get(size, "weight")?;
344 let weight = (&original_weight + 1.0)?;
345 Ok(Self {
346 eps,
347 original_weight,
348 weight,
349 })
350 }
351
352 pub fn weight(&self) -> &Tensor {
353 &self.weight
354 }
355
356 pub fn original_weight(&self) -> &Tensor {
357 &self.original_weight
358 }
359
360 pub fn eps(&self) -> f64 {
361 self.eps
362 }
363
364 pub fn forward_residual(&self, x: &Tensor, residual: &Tensor) -> Result<Tensor> {
365 rms_norm_forward_residual(x, residual, &self.weight, self.eps, None)
366 }
367
368 pub fn forward_residual_scaled(
369 &self,
370 x: &Tensor,
371 residual: &Tensor,
372 scale: &Tensor,
373 ) -> Result<Tensor> {
374 rms_norm_forward_residual(x, residual, &self.weight, self.eps, Some(scale))
375 }
376}
377
378impl Module for GemmaRmsNorm {
379 fn forward(&self, x: &Tensor) -> Result<Tensor> {
380 hanzo_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
381 }
382}
383
/// RMS norm whose statistics are computed in F32 regardless of the input dtype.
///
/// The normalized activations are cast back to the input dtype before being multiplied by `w`.
#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    /// Per-channel scale, applied after casting back to the input dtype.
    w: Tensor,
    /// Added to the mean square before the square root.
    eps: f64,
}
389
390impl F32RmsNorm {
391 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
392 Ok(Self {
393 w: vb.get((size,), "weight")?,
394 eps,
395 })
396 }
397
398 pub fn weight(&self) -> &Tensor {
399 &self.w
400 }
401}
402
403impl Module for F32RmsNorm {
404 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
405 let initial_type = xs.dtype();
406 let mut xs = xs.to_dtype(DType::F32)?;
407 let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
408 xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
409 xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
410 }
411}
412
/// RMS norm built from a quantized scale tensor, which `QRmsNorm::new` dequantizes once.
#[derive(Debug, Clone)]
pub struct QRmsNorm {
    /// Epsilon passed to the RMS-norm kernel (widened from the `f32` given at construction).
    eps: f64,
    /// Dequantized scale tensor.
    weight: Tensor,
}
418
419impl QRmsNorm {
420 pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
421 let scale = scale.dequantize(&scale.device())?;
422 Ok(Self {
423 eps: eps as f64,
424 weight: scale,
425 })
426 }
427
428 pub fn weight(&self) -> &Tensor {
429 &self.weight
430 }
431
432 pub fn eps(&self) -> f64 {
433 self.eps
434 }
435
436 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
437 hanzo_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
438 }
439}
440
/// Rotary position embedding used by Phi models.
///
/// Holds precomputed sin/cos tables. Scaled configurations also carry a separate pair of
/// "long" tables, which are selected once the requested positions exceed
/// `original_max_position_embeddings`.
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    /// Sine table used for short contexts (or always, when unscaled).
    short_sin: Tensor,
    /// Cosine table used for short contexts (or always, when unscaled).
    short_cos: Tensor,
    /// Cosine table for long contexts; `None` for unscaled RoPE.
    long_cos: Option<Tensor>,
    /// Sine table for long contexts; `None` for unscaled RoPE.
    long_sin: Option<Tensor>,
    /// Position threshold above which the long tables are used.
    original_max_position_embeddings: usize,
}
450
/// RoPE scaling scheme named by the `type` field of a Phi `rope_scaling` config.
///
/// Variants deserialize from their lowercase names; `longrope` is also accepted for `Su`.
/// `PhiRotaryEmbedding` only builds tables for `Su` and `Yarn`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    /// LongRoPE ("su"/"longrope") scaling.
    // The `su` alias duplicates the `rename_all` name; `longrope` is the extra spelling.
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    /// YaRN scaling.
    #[serde(alias = "yarn")]
    Yarn,
    /// Dynamic scaling (rejected by the Phi constructors).
    #[serde(alias = "dynamic")]
    Dynamic,
    /// Linear scaling (rejected by the Phi constructors).
    #[serde(alias = "linear")]
    Linear,
}
464
465impl FromStr for ScaledRopeType {
466 type Err = hanzo_ml::Error;
467 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
468 match s {
469 "su" | "longrope" => Ok(Self::Su),
470 "yarn" => Ok(Self::Yarn),
471 "linear" => Ok(Self::Linear),
472 "dynamic" => Ok(Self::Dynamic),
473 _ => Err(hanzo_ml::Error::Msg(
474 "Expected either `su` or `yarn` scaled RoPE type.".to_string(),
475 )),
476 }
477 }
478}
479
480#[derive(Debug, Clone, Deserialize, Serialize)]
481#[serde(untagged)]
482pub enum PhiRopeScalingConfig {
483 Classic {
484 short_factor: Vec<f64>,
485 long_factor: Vec<f64>,
486 #[serde(rename = "type")]
487 scaling_type: ScaledRopeType,
488 },
489 Scaled {
490 short_factor: Vec<f64>,
491 long_factor: Vec<f64>,
492 #[serde(rename = "type")]
493 scaling_type: ScaledRopeType,
494 long_mscale: f64,
495 short_mscale: f64,
496 },
497}
498
/// Inputs needed to build a [`PhiRotaryEmbedding`].
pub struct PhiRopeConfig {
    /// Optional scaling section; `None` builds an unscaled embedding.
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    /// Number of positions the sin/cos tables are precomputed for.
    pub max_position_embeddings: usize,
    /// Context length beyond which the long tables are used (when they exist).
    pub original_max_position_embeddings: usize,
    /// RoPE base `theta`.
    pub rope_theta: f64,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Fraction of `head_dim` that is rotated; treated as 1.0 when `None`.
    pub partial_rotary_factor: Option<f64>,
}
507
508impl PhiRotaryEmbedding {
509 fn new_classic_scaled(
510 short_factor: &[f64],
511 long_factor: &[f64],
512 scaling_type: &ScaledRopeType,
513 cfg: &PhiRopeConfig,
514 dtype: DType,
515 dev: &Device,
516 ) -> Result<Self> {
517 let max_seq_len = cfg.max_position_embeddings;
518 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
519
520 let scale =
522 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
523 let scaling_factor = if scale <= 1.0 {
524 1.0
525 } else {
526 match scaling_type {
527 ScaledRopeType::Su => {
528 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
529 }
530 ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
531 _ => hanzo_ml::bail!("Expected either `su` or `yarn` RoPE"),
532 }
533 };
534
535 let inv_freq_long = (0..dim)
537 .step_by(2)
538 .enumerate()
539 .map(|(k, i)| {
540 (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
541 })
542 .collect::<Vec<_>>();
543 let inv_freq_short = (0..dim)
544 .step_by(2)
545 .enumerate()
546 .map(|(k, i)| {
547 (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
548 })
549 .collect::<Vec<_>>();
550 let inv_freq_len = inv_freq_long.len();
551
552 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
553 .to_dtype(DType::F32)?
554 .reshape((max_seq_len, 1))?;
555
556 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
558 let freqs_long = t.matmul(&inv_freq_long)?;
559 let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
560 let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
561
562 let inv_freq_short =
564 Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
565 let freqs_short = t.matmul(&inv_freq_short)?;
566 let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
567 let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
568
569 Ok(Self {
570 short_cos,
571 short_sin,
572 long_cos: Some(long_cos),
573 long_sin: Some(long_sin),
574 original_max_position_embeddings: cfg.original_max_position_embeddings,
575 })
576 }
577
578 fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
579 let max_seq_len = cfg.max_position_embeddings;
580 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
581
582 let inv_freq: Vec<_> = (0..dim)
583 .step_by(2)
584 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
585 .collect();
586 let inv_freq_len = inv_freq.len();
587 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
588 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
589 .to_dtype(DType::F32)?
590 .reshape((max_seq_len, 1))?;
591 let freqs = t.matmul(&inv_freq)?;
592 let sin = freqs.sin()?.to_dtype(dtype)?;
593 let cos = freqs.cos()?.to_dtype(dtype)?;
594 Ok(Self {
595 short_cos: cos,
596 short_sin: sin,
597 long_cos: None,
598 long_sin: None,
599 original_max_position_embeddings: cfg.original_max_position_embeddings,
600 })
601 }
602
603 #[allow(clippy::too_many_arguments)]
604 fn new_scaled(
605 short_factor: &[f64],
606 long_factor: &[f64],
607 scaling_type: &ScaledRopeType,
608 long_mscale: f64,
609 short_mscale: f64,
610 cfg: &PhiRopeConfig,
611 dtype: DType,
612 dev: &Device,
613 ) -> Result<Self> {
614 let max_seq_len = cfg.max_position_embeddings;
615 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
616
617 if !matches!(scaling_type, ScaledRopeType::Su) {
618 hanzo_ml::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
619 }
620
621 if short_factor.len() != dim / 2 {
622 hanzo_ml::bail!(
623 "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
624 short_factor.len(),
625 dim / 2
626 );
627 }
628 if long_factor.len() != dim / 2 {
629 hanzo_ml::bail!(
630 "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
631 long_factor.len(),
632 dim / 2
633 );
634 }
635
636 let inv_freq_short: Vec<_> = (0..dim)
638 .step_by(2)
639 .enumerate()
640 .map(|(k, i)| {
641 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
642 })
643 .collect();
644 let inv_freq_len_short = inv_freq_short.len();
645 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
646 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
647 .to_dtype(DType::F32)?
648 .reshape((max_seq_len, 1))?;
649 let freqs_short = t_short.matmul(&inv_freq_short)?;
650 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
651 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
652
653 let inv_freq_long: Vec<_> = (0..dim)
655 .step_by(2)
656 .enumerate()
657 .map(|(k, i)| {
658 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
659 })
660 .collect();
661 let inv_freq_len_long = inv_freq_long.len();
662 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
663 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
664 .to_dtype(DType::F32)?
665 .reshape((max_seq_len, 1))?;
666 let freqs_long = t_long.matmul(&inv_freq_long)?;
667 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
668 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
669 Ok(Self {
670 short_cos: cos_short,
671 short_sin: sin_short,
672 long_cos: Some(cos_long),
673 long_sin: Some(sin_long),
674 original_max_position_embeddings: cfg.original_max_position_embeddings,
675 })
676 }
677
678 pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
679 let cfg: PhiRopeConfig = cfg.into();
680
681 match &cfg.rope_scaling {
682 Some(PhiRopeScalingConfig::Classic {
683 short_factor,
684 long_factor,
685 scaling_type,
686 }) => {
687 Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
688 }
689
690 Some(PhiRopeScalingConfig::Scaled {
691 short_factor,
692 long_factor,
693 scaling_type,
694 long_mscale,
695 short_mscale,
696 }) => Self::new_scaled(
697 short_factor,
698 long_factor,
699 scaling_type,
700 *long_mscale,
701 *short_mscale,
702 &cfg,
703 dtype,
704 dev,
705 ),
706
707 None => Self::new_unscaled(&cfg, dtype, dev),
708 }
709 }
710
711 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
713 if self.long_cos.is_none() {
714 return (&self.short_sin, &self.short_cos);
715 }
716 let seq_len = position_ids.iter().max().unwrap() + 1;
717 if seq_len > self.original_max_position_embeddings {
718 (
719 self.long_sin.as_ref().unwrap(),
720 self.long_cos.as_ref().unwrap(),
721 )
722 } else {
723 (&self.short_sin, &self.short_cos)
724 }
725 }
726
727 pub fn forward(
728 &self,
729 q: &Tensor,
730 k: &Tensor,
731 seqlen_offsets: &[usize],
732 position_ids: &[usize],
733 ) -> Result<(Tensor, Tensor)> {
734 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
735 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
736
737 let rot_dim = cos.dim(D::Minus1)? * 2;
738
739 if rot_dim != q.dim(D::Minus1)? {
741 let rot_dim = cos.dim(D::Minus1)? * 2;
742 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
743 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
744 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
745 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
746
747 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
748 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
749 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
750 let q_embed = hanzo_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
751 let k_embed = hanzo_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
752 (q_embed, k_embed)
753 } else {
754 let mut q_embeds = Vec::new();
755 let mut k_embeds = Vec::new();
756 for (i, offset) in seqlen_offsets.iter().enumerate() {
757 let cos = cos.narrow(0, *offset, seq_len)?;
758 let sin = sin.narrow(0, *offset, seq_len)?;
759 let q_embed = hanzo_nn::rotary_emb::rope(
760 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
761 &cos,
762 &sin,
763 )?;
764 let k_embed = hanzo_nn::rotary_emb::rope(
765 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
766 &cos,
767 &sin,
768 )?;
769 q_embeds.push(q_embed);
770 k_embeds.push(k_embed);
771 }
772 let q_rot = Tensor::cat(&q_embeds, 0)?;
773 let k_rot = Tensor::cat(&k_embeds, 0)?;
774 (q_rot, k_rot)
775 };
776
777 Ok((
778 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
779 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
780 ))
781 } else if seqlen_offsets.len() == 1 {
782 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
783 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
784 let q_embed = hanzo_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
785 let k_embed = hanzo_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
786 Ok((q_embed, k_embed))
787 } else {
788 let mut q_embeds = Vec::new();
789 let mut k_embeds = Vec::new();
790 for (i, offset) in seqlen_offsets.iter().enumerate() {
791 let cos = cos.narrow(0, *offset, seq_len)?;
792 let sin = sin.narrow(0, *offset, seq_len)?;
793 let q_embed =
794 hanzo_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
795 let k_embed =
796 hanzo_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
797 q_embeds.push(q_embed);
798 k_embeds.push(k_embed);
799 }
800 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
801 }
802 }
803}
804
/// Wrapper over [`RotaryEmbedding`] whose constructors support the default, `llama3` and
/// `linear` RoPE scaling schemes used by Llama-family configs.
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);
808
/// RoPE scaling scheme named by a Llama config's `rope_type` field.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    /// `"llama3"`: frequency-dependent scaling with smoothing between two wavelength bands.
    #[serde(rename = "llama3")]
    Llama3,
    /// `"linear"`: every inverse frequency is divided by `factor`.
    #[serde(rename = "linear")]
    Linear,
    /// `"default"`: unscaled RoPE; also this enum's `Default` value.
    #[default]
    #[serde(rename = "default")]
    Default,
}
819
/// The `rope_scaling` section of a Llama config.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    /// Divisor applied to inverse frequencies (to all of them for `linear`; to the low
    /// frequencies, and partially to the middle band, for `llama3`).
    pub factor: f32,
    /// Required for `llama3`; defines the low-frequency wavelength cutoff.
    pub low_freq_factor: Option<f32>,
    /// Required for `llama3`; defines the high-frequency wavelength cutoff.
    pub high_freq_factor: Option<f32>,
    /// Required for `llama3`; the context length the cutoffs are measured against.
    pub original_max_position_embeddings: Option<usize>,
    /// Which scaling scheme to apply.
    pub rope_type: Llama3RopeType,
}
828
829fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
830 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
831 (0..head_dim)
832 .step_by(2)
833 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
834 .collect()
835}
836
837fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
838 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
839 (0..head_dim)
840 .step_by(2)
841 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
842 .collect()
843}
844
845impl Llama3RotaryEmbedding {
847 pub fn new_llama3(
848 dtype: DType,
849 cfg: &llama::Config,
850 dev: &Device,
851 is_gpt_neox: bool,
852 ) -> Result<Self> {
853 match &cfg.rope_scaling {
854 None
855 | Some(Llama3RopeConfig {
856 rope_type: Llama3RopeType::Default,
857 ..
858 }) => Ok(Self(RotaryEmbedding::new(
859 cfg.rope_theta,
860 cfg.hidden_size / cfg.num_attention_heads,
861 cfg.max_position_embeddings,
862 dev,
863 is_gpt_neox,
864 dtype,
865 )?)),
866 Some(Llama3RopeConfig {
867 rope_type: Llama3RopeType::Llama3,
868 factor,
869 low_freq_factor,
870 high_freq_factor,
871 original_max_position_embeddings,
872 }) => {
873 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
874 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
875 let original_max_position_embeddings = original_max_position_embeddings
876 .context("original_max_position_embeddings is required")?;
877
878 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
879 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
880
881 let inv_freq = calculate_default_inv_freq(cfg)
882 .into_iter()
883 .map(|freq| {
884 let wavelen = 2. * PI / freq;
885 if wavelen < high_freq_wavelen {
886 freq
887 } else if wavelen > low_freq_wavelen {
888 freq / *factor
889 } else {
890 let smooth = (original_max_position_embeddings as f32 / wavelen
891 - low_freq_factor)
892 / (high_freq_factor - low_freq_factor);
893 (1. - smooth) * freq / *factor + smooth * freq
894 }
895 })
896 .collect::<Vec<_>>();
897 let inv_freq_len = inv_freq.len();
898 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
899 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
900 .to_dtype(DType::F32)?
901 .reshape((cfg.max_position_embeddings, 1))?;
902 let freqs = t.matmul(&inv_freq)?;
903 let sin = freqs.sin()?.to_dtype(dtype)?;
904 let cos = freqs.cos()?.to_dtype(dtype)?;
905 Ok(Self(RotaryEmbedding {
906 sin,
907 cos,
908 is_gpt_neox,
909 }))
910 }
911 Some(Llama3RopeConfig {
912 rope_type: Llama3RopeType::Linear,
913 factor,
914 ..
915 }) => {
916 let inv_freq_vec = calculate_default_inv_freq(cfg)
917 .into_iter()
918 .map(|freq| freq / *factor)
919 .collect::<Vec<_>>();
920 let inv_freq_len = inv_freq_vec.len();
921 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
922 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
923 .to_dtype(DType::F32)?
924 .reshape((cfg.max_position_embeddings, 1))?;
925 let freqs = t.matmul(&inv_freq)?;
926 let sin = freqs.sin()?.to_dtype(dtype)?;
927 let cos = freqs.cos()?.to_dtype(dtype)?;
928 Ok(Self(RotaryEmbedding {
929 sin,
930 cos,
931 is_gpt_neox,
932 }))
933 }
934 }
935 }
936
937 pub fn new_llama4(
938 dtype: DType,
939 cfg: &llama4::TextConfig,
940 dev: &Device,
941 is_gpt_neox: bool,
942 ) -> Result<Self> {
943 match &cfg.rope_scaling {
944 None
945 | Some(Llama3RopeConfig {
946 rope_type: Llama3RopeType::Default,
947 ..
948 }) => Ok(Self(RotaryEmbedding::new(
949 cfg.rope_theta,
950 cfg.hidden_size / cfg.num_attention_heads,
951 cfg.max_position_embeddings,
952 dev,
953 is_gpt_neox,
954 dtype,
955 )?)),
956 Some(Llama3RopeConfig {
957 rope_type: Llama3RopeType::Llama3,
958 factor,
959 low_freq_factor,
960 high_freq_factor,
961 original_max_position_embeddings,
962 }) => {
963 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
964 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
965 let original_max_position_embeddings = original_max_position_embeddings
966 .context("original_max_position_embeddings is required")?;
967
968 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
969 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
970
971 let inv_freq = calculate_default_inv_freq_llama4(cfg)
972 .into_iter()
973 .map(|freq| {
974 let wavelen = 2. * PI / freq;
975 if wavelen < high_freq_wavelen {
976 freq
977 } else if wavelen > low_freq_wavelen {
978 freq / *factor
979 } else {
980 let smooth = (original_max_position_embeddings as f32 / wavelen
981 - low_freq_factor)
982 / (high_freq_factor - low_freq_factor);
983 (1. - smooth) * freq / *factor + smooth * freq
984 }
985 })
986 .collect::<Vec<_>>();
987 let inv_freq_len = inv_freq.len();
988 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
989 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
990 .to_dtype(DType::F32)?
991 .reshape((cfg.max_position_embeddings, 1))?;
992 let freqs = t.matmul(&inv_freq)?;
993 let sin = freqs.sin()?.to_dtype(dtype)?;
994 let cos = freqs.cos()?.to_dtype(dtype)?;
995 Ok(Self(RotaryEmbedding {
996 sin,
997 cos,
998 is_gpt_neox,
999 }))
1000 }
1001 Some(Llama3RopeConfig {
1002 rope_type: Llama3RopeType::Linear,
1003 factor,
1004 ..
1005 }) => {
1006 let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
1007 .into_iter()
1008 .map(|freq| freq / *factor)
1009 .collect::<Vec<_>>();
1010 let inv_freq_len = inv_freq_vec.len();
1011 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1012 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1013 .to_dtype(DType::F32)?
1014 .reshape((cfg.max_position_embeddings, 1))?;
1015 let freqs = t.matmul(&inv_freq)?;
1016 let sin = freqs.sin()?.to_dtype(dtype)?;
1017 let cos = freqs.cos()?.to_dtype(dtype)?;
1018 Ok(Self(RotaryEmbedding {
1019 sin,
1020 cos,
1021 is_gpt_neox,
1022 }))
1023 }
1024 }
1025 }
1026
1027 pub fn new_mllama3(
1028 dtype: DType,
1029 cfg: &MLlamaTextConfig,
1030 dev: &Device,
1031 is_gpt_neox: bool,
1032 ) -> Result<Self> {
1033 match &cfg.rope_scaling {
1034 None
1035 | Some(MLlamaRopeScaling {
1036 rope_type: MLlamaRopeType::Default,
1037 ..
1038 }) => Ok(Self(RotaryEmbedding::new(
1039 cfg.rope_theta,
1040 cfg.hidden_size / cfg.num_attention_heads,
1041 cfg.max_position_embeddings,
1042 dev,
1043 is_gpt_neox,
1044 dtype,
1045 )?)),
1046 Some(MLlamaRopeScaling {
1047 rope_type: MLlamaRopeType::Llama3,
1048 original_max_position_embeddings,
1049 factor,
1050 attention_factor: _,
1051 beta_fast: _,
1052 beta_slow: _,
1053 short_factor: _,
1054 long_factor: _,
1055 low_freq_factor,
1056 high_freq_factor,
1057 }) => {
1058 let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
1059 let low_freq_factor = low_freq_factor
1060 .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
1061 let high_freq_factor = high_freq_factor
1062 .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
1063
1064 let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
1065 let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
1066
1067 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1068
1069 let inv_freq = (0..head_dim)
1070 .step_by(2)
1071 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1072 .map(|freq| {
1073 let wavelen = 2. * PI / freq;
1074 if wavelen < high_freq_wavelen {
1075 freq
1076 } else if wavelen > low_freq_wavelen {
1077 freq / factor
1078 } else {
1079 let smooth = (*original_max_position_embeddings as f32 / wavelen
1080 - low_freq_factor)
1081 / (high_freq_factor - low_freq_factor);
1082 (1. - smooth) * freq / factor + smooth * freq
1083 }
1084 })
1085 .collect::<Vec<_>>();
1086 let inv_freq_len = inv_freq.len();
1087 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1088
1089 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1090 .to_dtype(DType::F32)?
1091 .reshape((cfg.max_position_embeddings, 1))?;
1092 let freqs = t.matmul(&inv_freq)?;
1093 let sin = freqs.sin()?.to_dtype(dtype)?;
1094 let cos = freqs.cos()?.to_dtype(dtype)?;
1095 Ok(Self(RotaryEmbedding {
1096 sin,
1097 cos,
1098 is_gpt_neox,
1099 }))
1100 }
1101 Some(MLlamaRopeScaling {
1102 rope_type: other, ..
1103 }) => {
1104 hanzo_ml::bail!(
1105 "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
1106 )
1107 }
1108 }
1109 }
1110
1111 pub fn forward(
1112 &self,
1113 q: &Tensor,
1114 k: &Tensor,
1115 seqlen_offsets: &[usize],
1116 ) -> Result<(Tensor, Tensor)> {
1117 self.0.forward(q, k, seqlen_offsets)
1118 }
1119
1120 pub fn forward_q_norm(
1121 &self,
1122 q: &Tensor,
1123 q_weight: &Tensor,
1124 q_eps: f64,
1125 seqlen_offsets: &[usize],
1126 ) -> Result<Tensor> {
1127 self.0.forward_q_norm(q, q_weight, q_eps, seqlen_offsets)
1128 }
1129
1130 #[allow(clippy::too_many_arguments)]
1131 pub fn forward_qk_norm(
1132 &self,
1133 q: &Tensor,
1134 k: &Tensor,
1135 q_weight: &Tensor,
1136 k_weight: &Tensor,
1137 q_eps: f64,
1138 k_eps: f64,
1139 seqlen_offsets: &[usize],
1140 ) -> Result<(Tensor, Tensor)> {
1141 self.0
1142 .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
1143 }
1144}
1145
#[derive(Debug, Clone)]
/// Rotary embedding for SmolLM3: a thin wrapper around [`RotaryEmbedding`]
/// whose sin/cos tables are built according to `SmolLm3RopeConfig`.
pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);
1149
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
/// RoPE scaling strategy named by the `rope_type` key of SmolLM3's
/// `rope_scaling` config entry.
pub enum SmolLm3RopeType {
    /// Llama 3 band-wise frequency rescaling (`"llama3"`).
    #[serde(rename = "llama3")]
    Llama3,
    /// Every inverse frequency divided by `factor` (`"linear"`).
    #[serde(rename = "linear")]
    Linear,
    /// Unscaled RoPE (`"default"`); also the `Default` value of this enum.
    #[default]
    #[serde(rename = "default")]
    Default,
}
1160
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
/// Deserialized `rope_scaling` section of a SmolLM3 config.
pub struct SmolLm3RopeConfig {
    /// Divisor applied to inverse frequencies by the `llama3` and `linear` modes.
    pub factor: f32,
    /// Required by the `llama3` mode; construction errors if it is absent.
    pub low_freq_factor: Option<f32>,
    /// Required by the `llama3` mode; construction errors if it is absent.
    pub high_freq_factor: Option<f32>,
    /// Pre-extension context length; required by the `llama3` mode.
    pub original_max_position_embeddings: Option<usize>,
    /// Which scaling strategy to apply.
    pub rope_type: SmolLm3RopeType,
}
1169
1170fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
1171 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1172 (0..head_dim)
1173 .step_by(2)
1174 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1175 .collect()
1176}
1177
1178impl SmolLm3RotaryEmbedding {
1179 pub fn new_llama3(
1180 dtype: DType,
1181 cfg: &smollm3::Config,
1182 dev: &Device,
1183 is_gpt_neox: bool,
1184 ) -> Result<Self> {
1185 match &cfg.rope_scaling {
1186 None
1187 | Some(SmolLm3RopeConfig {
1188 rope_type: SmolLm3RopeType::Default,
1189 ..
1190 }) => Ok(Self(RotaryEmbedding::new(
1191 cfg.rope_theta,
1192 cfg.hidden_size / cfg.num_attention_heads,
1193 cfg.max_position_embeddings,
1194 dev,
1195 is_gpt_neox,
1196 dtype,
1197 )?)),
1198 Some(SmolLm3RopeConfig {
1199 rope_type: SmolLm3RopeType::Llama3,
1200 factor,
1201 low_freq_factor,
1202 high_freq_factor,
1203 original_max_position_embeddings,
1204 }) => {
1205 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
1206 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
1207 let original_max_position_embeddings = original_max_position_embeddings
1208 .context("original_max_position_embeddings is required")?;
1209
1210 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
1211 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
1212
1213 let inv_freq = calculate_default_inv_freq_smollm3(cfg)
1214 .into_iter()
1215 .map(|freq| {
1216 let wavelen = 2. * PI / freq;
1217 if wavelen < high_freq_wavelen {
1218 freq
1219 } else if wavelen > low_freq_wavelen {
1220 freq / *factor
1221 } else {
1222 let smooth = (original_max_position_embeddings as f32 / wavelen
1223 - low_freq_factor)
1224 / (high_freq_factor - low_freq_factor);
1225 (1. - smooth) * freq / *factor + smooth * freq
1226 }
1227 })
1228 .collect::<Vec<_>>();
1229 let inv_freq_len = inv_freq.len();
1230 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1231 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1232 .to_dtype(DType::F32)?
1233 .reshape((cfg.max_position_embeddings, 1))?;
1234 let freqs = t.matmul(&inv_freq)?;
1235 let sin = freqs.sin()?.to_dtype(dtype)?;
1236 let cos = freqs.cos()?.to_dtype(dtype)?;
1237 Ok(Self(RotaryEmbedding {
1238 sin,
1239 cos,
1240 is_gpt_neox,
1241 }))
1242 }
1243 Some(SmolLm3RopeConfig {
1244 rope_type: SmolLm3RopeType::Linear,
1245 factor,
1246 ..
1247 }) => {
1248 let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
1249 .into_iter()
1250 .map(|freq| freq / *factor)
1251 .collect::<Vec<_>>();
1252 let inv_freq_len = inv_freq_vec.len();
1253 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1254 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1255 .to_dtype(DType::F32)?
1256 .reshape((cfg.max_position_embeddings, 1))?;
1257 let freqs = t.matmul(&inv_freq)?;
1258 let sin = freqs.sin()?.to_dtype(dtype)?;
1259 let cos = freqs.cos()?.to_dtype(dtype)?;
1260 Ok(Self(RotaryEmbedding {
1261 sin,
1262 cos,
1263 is_gpt_neox,
1264 }))
1265 }
1266 }
1267 }
1268
1269 pub fn forward(
1270 &self,
1271 q: &Tensor,
1272 k: &Tensor,
1273 seqlen_offsets: &[usize],
1274 ) -> Result<(Tensor, Tensor)> {
1275 self.0.forward(q, k, seqlen_offsets)
1276 }
1277
1278 pub fn forward_q_norm(
1279 &self,
1280 q: &Tensor,
1281 q_weight: &Tensor,
1282 q_eps: f64,
1283 seqlen_offsets: &[usize],
1284 ) -> Result<Tensor> {
1285 self.0.forward_q_norm(q, q_weight, q_eps, seqlen_offsets)
1286 }
1287
1288 #[allow(clippy::too_many_arguments)]
1289 pub fn forward_qk_norm(
1290 &self,
1291 q: &Tensor,
1292 k: &Tensor,
1293 q_weight: &Tensor,
1294 k_weight: &Tensor,
1295 q_eps: f64,
1296 k_eps: f64,
1297 seqlen_offsets: &[usize],
1298 ) -> Result<(Tensor, Tensor)> {
1299 self.0
1300 .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
1301 }
1302}
1303
1304#[derive(Debug, Clone)]
1306pub struct Qwen2VLRotaryEmbedding {
1307 inv_freq: Tensor,
1308 mrope_section: Vec<usize>,
1309}
1310
1311impl Qwen2VLRotaryEmbedding {
1312 pub fn new(
1313 base: f32,
1314 head_dim: usize,
1315 device: &Device,
1316 mrope_section: Vec<usize>,
1317 ) -> Result<Self> {
1318 let inv_freq: Vec<_> = (0..head_dim)
1319 .step_by(2)
1320 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1321 .collect();
1322 let inv_freq_len = inv_freq.len();
1323 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1324 Ok(Self {
1325 inv_freq,
1326 mrope_section,
1327 })
1328 }
1329
1330 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1332 let inv_freq_expanded =
1333 self.inv_freq
1334 .reshape((1, 1, (), 1))?
1335 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1336 let position_ids_expanded = position_ids.unsqueeze(2)?;
1337 let freqs = inv_freq_expanded
1338 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1339 .transpose(2, 3)?;
1340 let cos = freqs.cos()?;
1341 let sin = freqs.sin()?;
1342
1343 let cos = Tensor::cat(
1344 &cos.split(&self.mrope_section, D::Minus1)?
1345 .into_iter()
1346 .enumerate()
1347 .map(|(i, m)| m.i(i % 3))
1348 .collect::<Result<Vec<_>>>()?,
1349 D::Minus1,
1350 )?
1351 .squeeze(0)?
1352 .to_dtype(dtype)?
1353 .contiguous()?;
1354 let sin = Tensor::cat(
1355 &sin.split(&self.mrope_section, D::Minus1)?
1356 .into_iter()
1357 .enumerate()
1358 .map(|(i, m)| m.i(i % 3))
1359 .collect::<Result<Vec<_>>>()?,
1360 D::Minus1,
1361 )?
1362 .squeeze(0)?
1363 .to_dtype(dtype)?
1364 .contiguous()?;
1365
1366 Ok((cos, sin))
1367 }
1368
1369 pub fn forward(
1371 &self,
1372 (cos, sin): &(Tensor, Tensor),
1373 q: &mut Tensor,
1374 k: &mut Tensor,
1375 ) -> Result<()> {
1376 *q = hanzo_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1377 *k = hanzo_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1378 Ok(())
1379 }
1380
1381 #[allow(clippy::too_many_arguments)]
1382 pub fn forward_qk_norm(
1383 &self,
1384 (cos, sin): &(Tensor, Tensor),
1385 q: &Tensor,
1386 k: &Tensor,
1387 q_weight: &Tensor,
1388 k_weight: &Tensor,
1389 q_eps: f64,
1390 k_eps: f64,
1391 ) -> Result<(Tensor, Tensor)> {
1392 qk_rms_norm_mrope(q, k, q_weight, k_weight, q_eps, k_eps, cos, sin, true)
1393 }
1394}
1395
1396#[derive(Debug, Clone)]
1400pub struct Qwen3VLRotaryEmbedding {
1401 inv_freq: Tensor,
1402 interleave_indices: Vec<(Tensor, usize)>,
1405}
1406
1407impl Qwen3VLRotaryEmbedding {
1408 pub fn new(
1409 base: f32,
1410 head_dim: usize,
1411 device: &Device,
1412 mrope_section: Vec<usize>,
1413 ) -> Result<Self> {
1414 let inv_freq: Vec<_> = (0..head_dim)
1415 .step_by(2)
1416 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1417 .collect();
1418 let inv_freq_len = inv_freq.len();
1419 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1420
1421 let half_dim = head_dim / 2;
1424 let mut interleave_indices = Vec::new();
1425 for (dim_idx, offset) in [(1usize, 1usize), (2usize, 2usize)] {
1426 let indices: Vec<u32> = (offset..)
1427 .step_by(3)
1428 .take(mrope_section[dim_idx])
1429 .filter(|&i| i < half_dim)
1430 .map(|i| i as u32)
1431 .collect();
1432 if !indices.is_empty() {
1433 let num = indices.len();
1434 let idx_tensor = Tensor::from_vec(indices, (num,), device)?;
1435 interleave_indices.push((idx_tensor, dim_idx));
1436 }
1437 }
1438
1439 Ok(Self {
1440 inv_freq,
1441 interleave_indices,
1442 })
1443 }
1444
1445 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1449 let inv_freq_expanded =
1451 self.inv_freq
1452 .reshape((1, 1, (), 1))?
1453 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1454 let position_ids_expanded = position_ids.unsqueeze(2)?;
1456 let freqs = inv_freq_expanded
1459 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1460 .transpose(2, 3)?;
1461
1462 let mut freqs_t = freqs.i(0)?.contiguous()?;
1465 let (batch, seq_len, _) = freqs_t.dims3()?;
1466
1467 for (idx_tensor, dim_idx) in &self.interleave_indices {
1469 let freqs_dim = freqs.i(*dim_idx)?.contiguous()?;
1470 let num_indices = idx_tensor.dim(0)?;
1471 let idx_expanded = idx_tensor
1472 .reshape((1, 1, num_indices))?
1473 .repeat((batch, seq_len, 1))?;
1474 let src_vals = freqs_dim.gather(&idx_expanded, D::Minus1)?;
1475 freqs_t = freqs_t.scatter(&idx_expanded, &src_vals, D::Minus1)?;
1476 }
1477
1478 let cos = freqs_t.cos()?.to_dtype(dtype)?.contiguous()?;
1481 let sin = freqs_t.sin()?.to_dtype(dtype)?.contiguous()?;
1482 Ok((cos, sin))
1483 }
1484
1485 pub fn forward(
1486 &self,
1487 (cos, sin): &(Tensor, Tensor),
1488 q: &mut Tensor,
1489 k: &mut Tensor,
1490 ) -> Result<()> {
1491 *q = hanzo_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1492 *k = hanzo_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1493 Ok(())
1494 }
1495
1496 #[allow(clippy::too_many_arguments)]
1497 pub fn forward_qk_norm(
1498 &self,
1499 (cos, sin): &(Tensor, Tensor),
1500 q: &Tensor,
1501 k: &Tensor,
1502 q_weight: &Tensor,
1503 k_weight: &Tensor,
1504 q_eps: f64,
1505 k_eps: f64,
1506 ) -> Result<(Tensor, Tensor)> {
1507 qk_rms_norm_mrope(q, k, q_weight, k_weight, q_eps, k_eps, cos, sin, true)
1508 }
1509}
1510
1511#[derive(Debug, Clone)]
1512pub struct Qwen2_5VLRotaryEmbedding {
1513 inv_freq: Tensor,
1514 mrope_section: Vec<usize>,
1515}
1516
1517impl Qwen2_5VLRotaryEmbedding {
1518 pub fn new(
1519 base: f32,
1520 head_dim: usize,
1521 device: &Device,
1522 mrope_section: Vec<usize>,
1523 ) -> Result<Self> {
1524 let inv_freq: Vec<_> = (0..head_dim)
1525 .step_by(2)
1526 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1527 .collect();
1528 let inv_freq_len = inv_freq.len();
1529 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1530 Ok(Self {
1531 inv_freq,
1532 mrope_section,
1533 })
1534 }
1535
1536 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1538 let inv_freq_expanded =
1539 self.inv_freq
1540 .reshape((1, 1, (), 1))?
1541 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1542 let position_ids_expanded = position_ids.unsqueeze(2)?;
1543 let freqs = inv_freq_expanded
1544 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1545 .transpose(2, 3)?;
1546 let cos = freqs.cos()?;
1547 let sin = freqs.sin()?;
1548
1549 let cos = Tensor::cat(
1550 &cos.split(&self.mrope_section, D::Minus1)?
1551 .into_iter()
1552 .enumerate()
1553 .map(|(i, m)| m.i(i % 3))
1554 .collect::<Result<Vec<_>>>()?,
1555 D::Minus1,
1556 )?
1557 .squeeze(0)?
1558 .to_dtype(dtype)?
1559 .contiguous()?;
1560 let sin = Tensor::cat(
1561 &sin.split(&self.mrope_section, D::Minus1)?
1562 .into_iter()
1563 .enumerate()
1564 .map(|(i, m)| m.i(i % 3))
1565 .collect::<Result<Vec<_>>>()?,
1566 D::Minus1,
1567 )?
1568 .squeeze(0)?
1569 .to_dtype(dtype)?
1570 .contiguous()?;
1571
1572 Ok((cos, sin))
1573 }
1574
1575 pub fn forward(
1576 &self,
1577 (cos, sin): &(Tensor, Tensor),
1578 q: &mut Tensor,
1579 k: &mut Tensor,
1580 ) -> Result<()> {
1581 *q = hanzo_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1582 *k = hanzo_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1583 Ok(())
1584 }
1585}
1586
#[derive(Debug, Clone)]
/// Precomputed rotary tables for DeepSeek-V2, either unscaled or YaRN-scaled.
pub struct DeepSeekV2RotaryEmbedding {
    /// Sine table, `(max_position_embeddings, n_freqs)`, in the requested dtype.
    sin: Tensor,
    /// Cosine table, `(max_position_embeddings, n_freqs)`, in the requested dtype.
    cos: Tensor,
}
1592
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
/// `rope_scaling` section of a DeepSeek-V2 config.
///
/// The enum is `untagged`, so serde tries the variants in declaration order.
/// `Yarn` matches only when all of its fields are present; otherwise the
/// input is tried as `LinearOrDynamic`. Do not reorder the variants.
pub enum DeepSeekV2RopeScaling {
    /// YaRN scaling parameters.
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    /// Linear or dynamic scaling. `DeepSeekV2RotaryEmbedding::new` currently
    /// rejects this variant with an error.
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}
1612
/// Subset of the DeepSeek-V2 model config needed to build its rotary embedding.
pub struct DeepSeekV2RopeConfig {
    /// Optional RoPE scaling; `None` builds unscaled tables.
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Number of positions the sin/cos tables cover.
    pub max_position_embeddings: usize,
    /// RoPE base `theta`.
    pub rope_theta: f32,
    /// Width of the rotary part of each query/key head.
    pub qk_rope_head_dim: usize,
}
1619
impl DeepSeekV2RotaryEmbedding {
    /// Builds plain RoPE tables: `1 / theta^(i / dim)` for even `i < dim`,
    /// taken as an outer product with positions `0..max_position_embeddings`.
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        // (max_seq_len, 1) x (1, n_freqs) -> (max_seq_len, n_freqs)
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    /// Returns the (fractional) rotary dimension index at which a frequency
    /// completes `num_rot` rotations over `max_position_embeddings` positions:
    /// `dim * ln(max_pos / (num_rot * 2π)) / (2 * ln(base))`.
    fn yarn_find_correction_dim(
        num_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> f32 {
        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
            / (2. * base.ln())
    }

    /// Integer `(low, high)` dimension bounds of the YaRN blending ramp,
    /// floored/ceiled and clamped to `[0, dim - 1]`.
    fn yarn_find_correction_range(
        low_rot: f32,
        high_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> (f32, f32) {
        let low =
            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
        let high =
            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
        (low.max(0.), high.min(dim as f32 - 1.))
    }

    /// 1-D f32 ramp of length `dim`: `(j - min) / (max - min)` clamped to `[0, 1]`.
    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
        if min == max {
            // Nudge `max` so the division below cannot be by zero.
            max += 0.001;
        }
        let linear_func =
            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1)
    }

    /// YaRN attention-magnitude factor: `1` for `scale <= 1`, otherwise
    /// `0.1 * mscale * ln(scale) + 1`.
    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
        if scale <= 1. {
            return 1.;
        }
        0.1 * mscale * scale.ln() + 1.
    }

    /// Builds YaRN-scaled tables.
    ///
    /// Each inverse frequency blends `freq_extra` (unscaled) and `freq_inter`
    /// (divided by `factor`) using a ramp over the correction range. Dims
    /// below `low` keep `freq_extra`, dims above `high` use `freq_inter`.
    /// The resulting sin/cos are multiplied by
    /// `mscale(factor, mscale) / mscale(factor, mscale_all_dim)`.
    #[allow(clippy::too_many_arguments)]
    fn new_yarn(
        cfg: &DeepSeekV2RopeConfig,
        dtype: DType,
        dev: &Device,
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        factor: f32,
        mscale: f32,
        mscale_all_dim: f32,
    ) -> Result<Self> {
        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
            .collect();
        let freq_extra_len = freq_extra.len();
        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
            .collect();
        let freq_inter_len = freq_inter.len();
        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;

        let (low, high) = Self::yarn_find_correction_range(
            beta_fast,
            beta_slow,
            cfg.qk_rope_head_dim,
            cfg.rope_theta,
            original_max_position_embeddings,
        );
        // 1 where the unscaled frequency is kept, 0 where the interpolated one is used.
        let inv_freq_mask =
            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
        let inv_freq = freq_inter
            .broadcast_mul(&(1. - &inv_freq_mask)?)?
            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;

        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let mscale =
            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    /// Builds the tables according to `cfg.rope_scaling`.
    ///
    /// # Errors
    /// Returns an error for `LinearOrDynamic`, which is not implemented.
    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
                scaling_type: _,
                factor: _,
            }) => hanzo_ml::bail!("linear and dynamic rope are not implemented yet!"),
            Some(DeepSeekV2RopeScaling::Yarn {
                original_max_position_embeddings,
                beta_fast,
                beta_slow,
                factor,
                mscale,
                mscale_all_dim,
                scaling_type: _,
            }) => Self::new_yarn(
                cfg,
                dtype,
                dev,
                *original_max_position_embeddings,
                *beta_fast,
                *beta_slow,
                *factor,
                *mscale,
                *mscale_all_dim,
            ),
            None => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    /// Rotates `q` and `k` with `rope_i` (the interleaved-pair kernel;
    /// NOTE(review): confirm the layout against hanzo_nn).
    ///
    /// `q` must be 4-D; its third axis is taken as the sequence length. With a
    /// single offset the whole batch shares it. Otherwise batch item `i` uses
    /// `seqlen_offsets[i]` and the per-item results are concatenated on axis 0.
    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = hanzo_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
            let k_embed = hanzo_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    hanzo_nn::rotary_emb::rope_i(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    hanzo_nn::rotary_emb::rope_i(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}
1794
#[derive(Debug, Clone)]
/// Partial rotary embedding for Phi-4-multimodal, with optional LongRoPE
/// short/long tables.
pub struct Phi4MMRotaryEmbedding {
    /// Sine table used for short contexts (or always, when unscaled).
    short_sin: Tensor,
    /// Cosine table used for short contexts (or always, when unscaled).
    short_cos: Tensor,
    /// LongRoPE cosine table; `None` when no LongRoPE scaling is configured.
    long_cos: Option<Tensor>,
    /// LongRoPE sine table; `None` when no LongRoPE scaling is configured.
    long_sin: Option<Tensor>,
    /// Context length above which the long tables are selected.
    original_max_position_embeddings: usize,
}
1803
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
/// RoPE scaling kind in a Phi-4-multimodal `rope_scaling.type` field.
///
/// `rename_all = "lowercase"` maps the variants to `"longrope"` and `"default"`.
/// The `longrope` alias repeats the lowercase name.
pub enum Phi4MMScaledRopeType {
    #[serde(alias = "longrope")]
    LongRope,
    #[default]
    Default,
}
1812
#[derive(Debug, Clone, Deserialize, Serialize)]
/// `rope_scaling` section of a Phi-4-multimodal config.
pub struct Phi4MMRopeScalingConfig {
    /// Per-frequency divisors for the short-context table (LongRoPE).
    short_factor: Option<Vec<f64>>,
    /// Per-frequency divisors for the long-context table (LongRoPE).
    long_factor: Option<Vec<f64>>,
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}
1820
1821impl Phi4MMRotaryEmbedding {
1822 fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1823 let max_seq_len = cfg.max_position_embeddings;
1824 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1825
1826 let inv_freq: Vec<_> = (0..dim)
1827 .step_by(2)
1828 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1829 .collect();
1830 let inv_freq_len = inv_freq.len();
1831 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1832 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1833 .to_dtype(DType::F32)?
1834 .reshape((max_seq_len, 1))?;
1835 let freqs = t.matmul(&inv_freq)?;
1836 let sin = freqs.sin()?.to_dtype(dtype)?;
1837 let cos = freqs.cos()?.to_dtype(dtype)?;
1838 Ok(Self {
1839 short_cos: cos,
1840 short_sin: sin,
1841 long_cos: None,
1842 long_sin: None,
1843 original_max_position_embeddings: cfg.original_max_position_embeddings,
1844 })
1845 }
1846
1847 #[allow(clippy::too_many_arguments)]
1848 fn new_longrope(
1849 short_factor: &[f64],
1850 long_factor: &[f64],
1851 cfg: &Phi4MMConfig,
1852 dtype: DType,
1853 dev: &Device,
1854 ) -> Result<Self> {
1855 let max_seq_len = cfg.max_position_embeddings;
1856 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1857
1858 let scale =
1860 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1861 let scaling_factor = if scale <= 1.0 {
1862 1.0
1863 } else {
1864 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1865 };
1866
1867 let inv_freq_short: Vec<_> = (0..dim)
1869 .step_by(2)
1870 .enumerate()
1871 .map(|(k, i)| {
1872 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1873 })
1874 .collect();
1875 let inv_freq_len_short = inv_freq_short.len();
1876 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1877 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1878 .to_dtype(DType::F32)?
1879 .reshape((max_seq_len, 1))?;
1880 let freqs_short = t_short.matmul(&inv_freq_short)?;
1881 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1882 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1883
1884 let inv_freq_long: Vec<_> = (0..dim)
1886 .step_by(2)
1887 .enumerate()
1888 .map(|(k, i)| {
1889 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1890 })
1891 .collect();
1892 let inv_freq_len_long = inv_freq_long.len();
1893 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1894 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1895 .to_dtype(DType::F32)?
1896 .reshape((max_seq_len, 1))?;
1897 let freqs_long = t_long.matmul(&inv_freq_long)?;
1898 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1899 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1900
1901 Ok(Self {
1902 short_cos: cos_short,
1903 short_sin: sin_short,
1904 long_cos: Some(cos_long),
1905 long_sin: Some(sin_long),
1906 original_max_position_embeddings: cfg.original_max_position_embeddings,
1907 })
1908 }
1909
1910 pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1911 match &cfg.rope_scaling {
1912 Some(Phi4MMRopeScalingConfig {
1913 scaling_type: Phi4MMScaledRopeType::LongRope,
1914 short_factor: Some(short_factor),
1915 long_factor: Some(long_factor),
1916 }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1917
1918 _ => Self::new_unscaled(cfg, dtype, dev),
1919 }
1920 }
1921
1922 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1924 if self.long_cos.is_none() {
1925 return (&self.short_sin, &self.short_cos);
1926 }
1927 let seq_len = position_ids.iter().max().unwrap() + 1;
1928 if seq_len > self.original_max_position_embeddings {
1929 (
1930 self.long_sin.as_ref().unwrap(),
1931 self.long_cos.as_ref().unwrap(),
1932 )
1933 } else {
1934 (&self.short_sin, &self.short_cos)
1935 }
1936 }
1937
1938 pub fn forward(
1939 &self,
1940 q: &Tensor,
1941 k: &Tensor,
1942 seqlen_offsets: &[usize],
1943 position_ids: &[usize],
1944 ) -> Result<(Tensor, Tensor)> {
1945 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1946 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1947
1948 let rot_dim = cos.dim(D::Minus1)? * 2;
1949 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1950 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1951 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1952 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1953
1954 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1955 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1956 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1957 let q_embed = hanzo_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1958 let k_embed = hanzo_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1959 (q_embed, k_embed)
1960 } else {
1961 let mut q_embeds = Vec::new();
1962 let mut k_embeds = Vec::new();
1963 for (i, offset) in seqlen_offsets.iter().enumerate() {
1964 let cos = cos.narrow(0, *offset, seq_len)?;
1965 let sin = sin.narrow(0, *offset, seq_len)?;
1966 let q_embed = hanzo_nn::rotary_emb::rope(
1967 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1968 &cos,
1969 &sin,
1970 )?;
1971 let k_embed = hanzo_nn::rotary_emb::rope(
1972 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1973 &cos,
1974 &sin,
1975 )?;
1976 q_embeds.push(q_embed);
1977 k_embeds.push(k_embed);
1978 }
1979 let q_rot = Tensor::cat(&q_embeds, 0)?;
1980 let k_rot = Tensor::cat(&k_embeds, 0)?;
1981 (q_rot, k_rot)
1982 };
1983
1984 Ok((
1985 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1986 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1987 ))
1988 }
1989}
1990
#[derive(Debug, Clone)]
/// Rotary embedding for Gemma 3n: a wrapper around [`RotaryEmbedding`] with
/// optional linear scaling.
pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1993
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
/// Gemma 3n RoPE scaling kind; only `"linear"` is defined.
///
/// NOTE(review): `Gemma3nRotaryEmbedding::new` matches on the Gemma 3
/// scaling types, not on this one. Confirm whether anything else uses it.
pub enum Gemma3nScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}
2000
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Gemma 3n `rope_scaling` section.
///
/// NOTE(review): `Gemma3nRotaryEmbedding::new` matches on the Gemma 3
/// scaling config, not on this type. Confirm whether anything else uses it.
pub struct Gemma3nRopeScalingConfig {
    /// Divisor applied to every inverse frequency.
    factor: f64,
    rope_type: Gemma3nScaledRopeType,
}
2006
2007impl Gemma3nRotaryEmbedding {
2008 fn new_linear(
2009 cfg: &Gemma3nTextConfig,
2010 factor: f64,
2011 is_gpt_neox: bool,
2012 dtype: DType,
2013 dev: &Device,
2014 ) -> Result<Self> {
2015 let max_seq_len = cfg.max_position_embeddings;
2016 let dim = cfg.head_dim;
2017
2018 let inv_freq: Vec<_> = (0..dim)
2019 .step_by(2)
2020 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
2021 .collect();
2022 let inv_freq_len = inv_freq.len();
2023 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
2024 let inv_freq = (inv_freq / factor)?;
2025
2026 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
2027 .to_dtype(DType::F32)?
2028 .reshape((max_seq_len, 1))?;
2029 let freqs = t.matmul(&inv_freq)?;
2030 let sin = freqs.sin()?.to_dtype(dtype)?;
2031 let cos = freqs.cos()?.to_dtype(dtype)?;
2032 Ok(Self(RotaryEmbedding {
2033 cos,
2034 sin,
2035 is_gpt_neox,
2036 }))
2037 }
2038
2039 pub fn new(
2040 is_gpt_neox: bool,
2041 dtype: DType,
2042 cfg: &Gemma3nTextConfig,
2043 dev: &Device,
2044 ) -> Result<Self> {
2045 match &cfg.rope_scaling {
2046 Some(Gemma3RopeScalingConfig {
2047 rope_type: Gemma3ScaledRopeType::Linear,
2048 factor,
2049 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
2050
2051 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
2052 }
2053 }
2054
2055 pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2056 self.0.get_cos_sin()
2057 }
2058
2059 pub fn forward(
2060 &self,
2061 q: &Tensor,
2062 k: &Tensor,
2063 seqlen_offsets: &[usize],
2064 ) -> Result<(Tensor, Tensor)> {
2065 self.0.forward(q, k, seqlen_offsets)
2066 }
2067
2068 #[allow(clippy::too_many_arguments)]
2069 pub fn forward_qk_norm(
2070 &self,
2071 q: &Tensor,
2072 k: &Tensor,
2073 q_weight: &Tensor,
2074 k_weight: &Tensor,
2075 q_eps: f64,
2076 k_eps: f64,
2077 seqlen_offsets: &[usize],
2078 ) -> Result<(Tensor, Tensor)> {
2079 self.0
2080 .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
2081 }
2082}
2083
#[derive(Debug, Clone)]
/// Rotary embedding for Gemma 3 and EmbeddingGemma: a wrapper around
/// [`RotaryEmbedding`] with optional linear scaling.
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
2086
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
/// Gemma 3 RoPE scaling kind; only `"linear"` is defined.
pub enum Gemma3ScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}
2093
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Gemma 3 `rope_scaling` section.
pub struct Gemma3RopeScalingConfig {
    /// Divisor applied to every inverse frequency.
    factor: f64,
    rope_type: Gemma3ScaledRopeType,
}
2099
2100impl Gemma3RotaryEmbedding {
2101 fn new_linear(
2102 cfg: &Gemma3TextConfig,
2103 factor: f64,
2104 is_gpt_neox: bool,
2105 dtype: DType,
2106 dev: &Device,
2107 ) -> Result<Self> {
2108 let max_seq_len = cfg.max_position_embeddings;
2109 let dim = cfg.head_dim;
2110
2111 let inv_freq: Vec<_> = (0..dim)
2112 .step_by(2)
2113 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
2114 .collect();
2115 let inv_freq_len = inv_freq.len();
2116 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
2117 let inv_freq = (inv_freq / factor)?;
2118
2119 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
2120 .to_dtype(DType::F32)?
2121 .reshape((max_seq_len, 1))?;
2122 let freqs = t.matmul(&inv_freq)?;
2123 let sin = freqs.sin()?.to_dtype(dtype)?;
2124 let cos = freqs.cos()?.to_dtype(dtype)?;
2125 Ok(Self(RotaryEmbedding {
2126 cos,
2127 sin,
2128 is_gpt_neox,
2129 }))
2130 }
2131
2132 pub fn new(
2133 is_gpt_neox: bool,
2134 dtype: DType,
2135 cfg: &Gemma3TextConfig,
2136 dev: &Device,
2137 ) -> Result<Self> {
2138 match &cfg.rope_scaling {
2139 Some(Gemma3RopeScalingConfig {
2140 rope_type: Gemma3ScaledRopeType::Linear,
2141 factor,
2142 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
2143
2144 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
2145 }
2146 }
2147
2148 fn new_linear_embedding_gemma(
2149 cfg: &EmbeddingGemmaConfig,
2150 factor: f64,
2151 is_gpt_neox: bool,
2152 dtype: DType,
2153 dev: &Device,
2154 ) -> Result<Self> {
2155 let max_seq_len = cfg.max_position_embeddings;
2156 let dim = cfg.head_dim;
2157
2158 let inv_freq: Vec<_> = (0..dim)
2159 .step_by(2)
2160 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
2161 .collect();
2162 let inv_freq_len = inv_freq.len();
2163 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
2164 let inv_freq = (inv_freq / factor)?;
2165
2166 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
2167 .to_dtype(DType::F32)?
2168 .reshape((max_seq_len, 1))?;
2169 let freqs = t.matmul(&inv_freq)?;
2170 let sin = freqs.sin()?.to_dtype(dtype)?;
2171 let cos = freqs.cos()?.to_dtype(dtype)?;
2172 Ok(Self(RotaryEmbedding {
2173 cos,
2174 sin,
2175 is_gpt_neox,
2176 }))
2177 }
2178
2179 pub fn new_embedding_gemma(
2180 is_gpt_neox: bool,
2181 dtype: DType,
2182 cfg: &EmbeddingGemmaConfig,
2183 dev: &Device,
2184 ) -> Result<Self> {
2185 match &cfg.rope_scaling {
2186 Some(Gemma3RopeScalingConfig {
2187 rope_type: Gemma3ScaledRopeType::Linear,
2188 factor,
2189 }) => Self::new_linear_embedding_gemma(cfg, *factor, is_gpt_neox, dtype, dev),
2190
2191 _ => Self::new_linear_embedding_gemma(cfg, 1.0, is_gpt_neox, dtype, dev),
2192 }
2193 }
2194
2195 pub fn forward(
2196 &self,
2197 q: &Tensor,
2198 k: &Tensor,
2199 seqlen_offsets: &[usize],
2200 ) -> Result<(Tensor, Tensor)> {
2201 self.0.forward(q, k, seqlen_offsets)
2202 }
2203
2204 pub fn forward_q_norm(
2205 &self,
2206 q: &Tensor,
2207 q_weight: &Tensor,
2208 q_eps: f64,
2209 seqlen_offsets: &[usize],
2210 ) -> Result<Tensor> {
2211 self.0.forward_q_norm(q, q_weight, q_eps, seqlen_offsets)
2212 }
2213
2214 #[allow(clippy::too_many_arguments)]
2215 pub fn forward_qk_norm(
2216 &self,
2217 q: &Tensor,
2218 k: &Tensor,
2219 q_weight: &Tensor,
2220 k_weight: &Tensor,
2221 q_eps: f64,
2222 k_eps: f64,
2223 seqlen_offsets: &[usize],
2224 ) -> Result<(Tensor, Tensor)> {
2225 self.0
2226 .forward_qk_norm(q, k, q_weight, k_weight, q_eps, k_eps, seqlen_offsets)
2227 }
2228}
2229
2230pub struct DiaRotaryEmbedding {
2231 timescale: Tensor,
2232 dtype: DType,
2233}
2234
2235impl DiaRotaryEmbedding {
2236 pub fn new(
2237 min_timescale: f32,
2238 max_timescale: f32,
2239 head_dim: usize,
2240 device: &Device,
2241 dtype: DType,
2242 ) -> Result<Self> {
2243 assert_eq!(head_dim % 2, 0);
2244 let half_embedding_dim = head_dim / 2;
2245
2246 let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
2247 let timescale = fraction
2248 .into_iter()
2249 .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
2250 .collect::<Vec<_>>();
2251
2252 let timescale_len = timescale.len();
2253 let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
2254
2255 Ok(Self { timescale, dtype })
2256 }
2257
2258 pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
2259 let freqs = positions
2260 .unsqueeze(D::Minus1)?
2261 .unsqueeze(D::Minus1)?
2262 .broadcast_div(&self.timescale)?;
2263
2264 let sin = freqs.sin()?.to_dtype(self.dtype)?;
2265 let cos = freqs.cos()?.to_dtype(self.dtype)?;
2266
2267 let split = xs.chunk(2, D::Minus1)?;
2268 let first_half = &split[0];
2269 let second_half = &split[1];
2270
2271 let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
2272 let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
2273
2274 Tensor::cat(&[first_part, second_part], D::Minus1)
2275 }
2276}
/// Linear layer over a [`QMatMul`], whose weight may be quantized or a plain tensor.
#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    /// Optional bias added to the matmul output.
    bias: Option<Tensor>,
    /// Dtype the `forward` output is cast to.
    dtype: DType,
}
2283
2284impl QLinear {
2285 pub fn new<R: std::io::Read + std::io::Seek>(
2286 ct: &mut Content<'_, R>,
2287 name: &str,
2288 device: &Device,
2289 ) -> Result<Self> {
2290 let w = ct.tensor(&format!("{name}.weight"), device)?;
2291 let b = ct.tensor(&format!("{name}.bias"), device)?;
2292 let inner = QMatMul::from_qtensor(w)?;
2293 let bias = b.dequantize(device)?;
2294 Ok(Self {
2295 inner,
2296 bias: Some(bias),
2297 dtype: DType::F32,
2298 })
2299 }
2300
2301 pub fn from_linear(linear: Linear) -> Self {
2302 Self {
2303 inner: QMatMul::Tensor(linear.weight().clone()),
2304 bias: linear.bias().cloned(),
2305 dtype: linear.weight().dtype(),
2306 }
2307 }
2308
2309 pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
2310 let dtype = w.dtype();
2311 Self {
2312 inner: QMatMul::Tensor(w),
2313 bias: b,
2314 dtype,
2315 }
2316 }
2317
2318 pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
2319 if let Some(ref b) = b {
2320 assert_eq!(b.dtype(), DType::F32);
2321 }
2322 Self {
2323 inner: QMatMul::QTensor(Arc::new(w)),
2324 bias: b,
2325 dtype: DType::F32,
2326 }
2327 }
2328
2329 pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
2330 Self {
2331 inner,
2332 bias: old.bias.clone(),
2333 dtype: old.dtype,
2334 }
2335 }
2336
2337 pub fn inner(&mut self) -> &mut QMatMul {
2338 &mut self.inner
2339 }
2340
2341 pub fn inner_ref(&self) -> &QMatMul {
2342 &self.inner
2343 }
2344
2345 pub fn is_quant(&self) -> bool {
2346 matches!(self.inner, QMatMul::QTensor(_))
2347 }
2348
2349 pub fn bias(&self) -> Option<&Tensor> {
2350 self.bias.as_ref()
2351 }
2352
2353 pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
2354 self.bias.as_mut()
2355 }
2356}
2357
2358impl Module for QLinear {
2359 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2360 let xs = if self.is_quant() {
2361 xs.to_dtype(DType::F32)?
2362 } else {
2363 xs.clone()
2364 };
2365 if let Some(bias) = &self.bias {
2366 self.inner
2367 .forward(&xs)?
2368 .broadcast_add(bias)?
2369 .to_dtype(self.dtype)
2370 } else {
2371 self.inner.forward(&xs)?.to_dtype(self.dtype)
2372 }
2373 }
2374}
2375
/// Precomputed rotary position embedding (RoPE) tables.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// Cosine table with one row per position.
    cos: Tensor,
    /// Sine table with the same layout as `cos`.
    sin: Tensor,
    /// Uses `hanzo_nn::rotary_emb::rope` when true and `rope_i` otherwise.
    is_gpt_neox: bool,
}
2382
2383fn selected_rope_cache(
2384 cos: &Tensor,
2385 sin: &Tensor,
2386 batch: usize,
2387 seq_len: usize,
2388 seqlen_offsets: &[usize],
2389) -> Result<(Tensor, Tensor)> {
2390 if seqlen_offsets.len() == 1 {
2391 Ok((
2392 cos.narrow(0, seqlen_offsets[0], seq_len)?,
2393 sin.narrow(0, seqlen_offsets[0], seq_len)?,
2394 ))
2395 } else {
2396 if seqlen_offsets.len() != batch {
2397 hanzo_ml::bail!(
2398 "RoPE offset count {} does not match batch size {batch}",
2399 seqlen_offsets.len()
2400 );
2401 }
2402 let mut cos_s = Vec::with_capacity(batch);
2403 let mut sin_s = Vec::with_capacity(batch);
2404 for offset in seqlen_offsets {
2405 cos_s.push(cos.narrow(0, *offset, seq_len)?);
2406 sin_s.push(sin.narrow(0, *offset, seq_len)?);
2407 }
2408 Ok((Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?))
2409 }
2410}
2411
/// RMS-normalizes `q` and `k` (each with its own weight and epsilon), then applies rotary
/// position embeddings taken from `cos_cache`/`sin_cache`.
///
/// `q` must be rank 4; its dims are read as `(batch, heads, seq_len, head_dim)` (the reshape
/// in the CUDA path fixes that order). `seqlen_offsets` holds either one offset shared by the
/// whole batch or one offset per batch entry. `is_gpt_neox` selects
/// `hanzo_nn::rotary_emb::rope` over `rope_i`.
///
/// With the `cuda` feature, a fused norm+RoPE kernel is tried first, then an in-place rotary
/// kernel after the unfused norm. Otherwise the generic `rms_norm` + `rope` path is used.
#[allow(clippy::too_many_arguments)]
pub fn qk_rms_norm_rope(
    q: &Tensor,
    k: &Tensor,
    q_weight: &Tensor,
    k_weight: &Tensor,
    q_eps: f64,
    k_eps: f64,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_gpt_neox: bool,
    seqlen_offsets: &[usize],
) -> Result<(Tensor, Tensor)> {
    let (batch, _, seq_len, _) = q.dims4()?;
    // `seq_len` rows for a shared offset, `batch * seq_len` rows for per-sequence offsets.
    let (cos, sin) = selected_rope_cache(cos_cache, sin_cache, batch, seq_len, seqlen_offsets)?;

    // Fused fast path; falls through when the helper does not return both outputs.
    #[cfg(feature = "cuda")]
    if let Some((q, Some(k))) = crate::ops::try_cuda_qk_rms_norm_rope(
        q,
        Some(k),
        q_weight,
        Some(k_weight),
        q_eps as f32,
        k_eps as f32,
        &cos,
        &sin,
        is_gpt_neox,
    )? {
        return Ok((q, k));
    }

    let rope = if is_gpt_neox {
        hanzo_nn::rotary_emb::rope
    } else {
        hanzo_nn::rotary_emb::rope_i
    };
    let q = hanzo_nn::ops::rms_norm(&q.contiguous()?, q_weight, q_eps as f32)?;
    let k = hanzo_nn::ops::rms_norm(&k.contiguous()?, k_weight, k_eps as f32)?;

    // In-place CUDA rotary kernel on `(batch * seq_len, heads, head_dim)` views. This path is
    // taken only when q/k share a head count and the cache has one row per (batch, position)
    // pair, i.e. per-sequence offsets were used or `batch == 1`.
    #[cfg(feature = "cuda")]
    if q.device().is_cuda() && q.dim(1)? == k.dim(1)? && cos.dim(0)? == batch * seq_len {
        let qh = q.dim(1)?;
        let n_embd = q.dim(D::Minus1)?;
        let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
        let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
        hanzo_quant::rotary::apply_rotary_inplace(&q_embed, &k_embed, &cos, &sin, is_gpt_neox)?;
        let mut q = q_embed
            .reshape((batch, seq_len, qh, n_embd))?
            .transpose(1, 2)?;
        let mut k = k_embed
            .reshape((batch, seq_len, k.dim(1)?, n_embd))?
            .transpose(1, 2)?;
        // Flash-attention builds keep the transposed (strided) layout; others make it contiguous.
        if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
            q = q.contiguous()?;
            k = k.contiguous()?;
        }
        return Ok((q, k));
    }

    if seqlen_offsets.len() == 1 {
        // A shared offset lets the whole batch be rotated in one call.
        Ok((
            rope(&q.contiguous()?, &cos, &sin)?,
            rope(&k.contiguous()?, &cos, &sin)?,
        ))
    } else {
        // Per-sequence offsets: rotate each batch entry with its slice of the concatenated cache.
        let mut q_embeds = Vec::with_capacity(batch);
        let mut k_embeds = Vec::with_capacity(batch);
        for seq_idx in 0..batch {
            let cos = cos.narrow(0, seq_idx * seq_len, seq_len)?;
            let sin = sin.narrow(0, seq_idx * seq_len, seq_len)?;
            q_embeds.push(rope(
                &q.i(seq_idx)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
            k_embeds.push(rope(
                &k.i(seq_idx)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
    }
}
2496
/// Query-only variant of `qk_rms_norm_rope`: RMS-normalizes `q` with `q_weight`/`q_eps`, then
/// applies rotary position embeddings from `cos_cache`/`sin_cache`.
///
/// `q` must be rank 4, with dims read as `(batch, heads, seq_len, head_dim)`. `seqlen_offsets`
/// holds one shared offset or one offset per batch entry.
#[allow(clippy::too_many_arguments)]
pub fn q_rms_norm_rope(
    q: &Tensor,
    q_weight: &Tensor,
    q_eps: f64,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_gpt_neox: bool,
    seqlen_offsets: &[usize],
) -> Result<Tensor> {
    let (batch, _qh, seq_len, _head_dim) = q.dims4()?;
    let (cos, sin) = selected_rope_cache(cos_cache, sin_cache, batch, seq_len, seqlen_offsets)?;

    // Fused fast path without `k`; `q_eps` fills the unused k-epsilon slot.
    #[cfg(feature = "cuda")]
    if let Some((q, None)) = crate::ops::try_cuda_qk_rms_norm_rope(
        q,
        None,
        q_weight,
        None,
        q_eps as f32,
        q_eps as f32,
        &cos,
        &sin,
        is_gpt_neox,
    )? {
        return Ok(q);
    }

    let rope = if is_gpt_neox {
        hanzo_nn::rotary_emb::rope
    } else {
        hanzo_nn::rotary_emb::rope_i
    };
    let q = hanzo_nn::ops::rms_norm(&q.contiguous()?, q_weight, q_eps as f32)?;
    if seqlen_offsets.len() == 1 {
        rope(&q.contiguous()?, &cos, &sin)
    } else {
        // Per-sequence offsets: each batch entry uses its own `seq_len` rows of the cache.
        let mut q_embeds = Vec::with_capacity(batch);
        for seq_idx in 0..batch {
            let cos = cos.narrow(0, seq_idx * seq_len, seq_len)?;
            let sin = sin.narrow(0, seq_idx * seq_len, seq_len)?;
            q_embeds.push(rope(
                &q.i(seq_idx)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&q_embeds, 0)
    }
}
2547
/// RMS-normalizes `q` and `k`, then applies rotary embeddings with caller-supplied
/// (multimodal/M-RoPE) `cos`/`sin` tables.
///
/// The rotated width is `2 * cos.dim(-1)`. When that is narrower than `head_dim`, only the
/// leading `rot_width` channels are rotated and the rest pass through unchanged.
///
/// With the `cuda` feature, `cos`/`sin` are first flattened to 2D rows and a fused kernel is
/// tried. Shapes other than `(batch, seq_len, _)`, `(seq_len, _)` or `(batch * seq_len, _)`
/// produce an error on that path.
#[allow(clippy::too_many_arguments)]
pub fn qk_rms_norm_mrope(
    q: &Tensor,
    k: &Tensor,
    q_weight: &Tensor,
    k_weight: &Tensor,
    q_eps: f64,
    k_eps: f64,
    cos: &Tensor,
    sin: &Tensor,
    is_gpt_neox: bool,
) -> Result<(Tensor, Tensor)> {
    let (_, _q_heads, _, head_dim) = q.dims4()?;
    let rot_width = cos.dim(D::Minus1)? * 2;

    // NOTE(review): the fused kernel receives full-width q/k even when
    // `rot_width < head_dim`; confirm it honours partial rotation.
    #[cfg(feature = "cuda")]
    {
        let (batch, _, seq_len, _) = q.dims4()?;
        let cos_flat = match cos.dims() {
            [cos_batch, cos_seq, _] if *cos_batch == batch && *cos_seq == seq_len => {
                cos.reshape((batch * seq_len, ()))?
            }
            [cos_rows, _] if *cos_rows == seq_len || *cos_rows == batch * seq_len => cos.clone(),
            _ => hanzo_ml::bail!(
                "MRoPE cos shape {:?} is incompatible with q shape {:?}",
                cos.shape(),
                q.shape()
            ),
        };
        let sin_flat = match sin.dims() {
            [sin_batch, sin_seq, _] if *sin_batch == batch && *sin_seq == seq_len => {
                sin.reshape((batch * seq_len, ()))?
            }
            [sin_rows, _] if *sin_rows == seq_len || *sin_rows == batch * seq_len => sin.clone(),
            _ => hanzo_ml::bail!(
                "MRoPE sin shape {:?} is incompatible with q shape {:?}",
                sin.shape(),
                q.shape()
            ),
        };
        if let Some((q, Some(k))) = crate::ops::try_cuda_qk_rms_norm_rope(
            q,
            Some(k),
            q_weight,
            Some(k_weight),
            q_eps as f32,
            k_eps as f32,
            &cos_flat,
            &sin_flat,
            is_gpt_neox,
        )? {
            return Ok((q, k));
        }
    }

    let rope = if is_gpt_neox {
        hanzo_nn::rotary_emb::rope
    } else {
        hanzo_nn::rotary_emb::rope_i
    };
    let q = hanzo_nn::ops::rms_norm(&q.contiguous()?, q_weight, q_eps as f32)?;
    let k = hanzo_nn::ops::rms_norm(&k.contiguous()?, k_weight, k_eps as f32)?;
    if rot_width < head_dim {
        // Partial rotary: rotate the leading channels, then re-append the untouched tail.
        let q_rot = q.narrow(D::Minus1, 0, rot_width)?;
        let q_pass = q.narrow(D::Minus1, rot_width, head_dim - rot_width)?;
        let k_rot = k.narrow(D::Minus1, 0, rot_width)?;
        let k_pass = k.narrow(D::Minus1, rot_width, head_dim - rot_width)?;
        let q_rot = rope(&q_rot.contiguous()?, cos, sin)?;
        let k_rot = rope(&k_rot.contiguous()?, cos, sin)?;
        Ok((
            Tensor::cat(&[q_rot, q_pass], D::Minus1)?,
            Tensor::cat(&[k_rot, k_pass], D::Minus1)?,
        ))
    } else {
        Ok((
            rope(&q.contiguous()?, cos, sin)?,
            rope(&k.contiguous()?, cos, sin)?,
        ))
    }
}
2628
2629impl RotaryEmbedding {
2630 pub fn new(
2631 base: f32,
2632 head_dim: usize,
2633 max_position_embeddings: usize,
2634 device: &Device,
2635 is_gpt_neox: bool,
2636 dtype: DType,
2637 ) -> Result<Self> {
2638 let inv_freq: Vec<_> = (0..head_dim)
2639 .step_by(2)
2640 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
2641 .collect();
2642 let inv_freq_len = inv_freq.len();
2643 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2644 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2645 .to_dtype(DType::F32)?
2646 .reshape((max_position_embeddings, 1))?;
2647 let freqs = t.matmul(&inv_freq)?;
2648 let sin = freqs.sin()?.to_dtype(dtype)?;
2649 let cos = freqs.cos()?.to_dtype(dtype)?;
2650
2651 Ok(Self {
2652 cos,
2653 sin,
2654 is_gpt_neox,
2655 })
2656 }
2657
2658 pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2659 Ok((self.cos.clone(), self.sin.clone()))
2660 }
2661
2662 #[allow(clippy::too_many_arguments)]
2663 pub fn forward_qk_norm(
2664 &self,
2665 q: &Tensor,
2666 k: &Tensor,
2667 q_weight: &Tensor,
2668 k_weight: &Tensor,
2669 q_eps: f64,
2670 k_eps: f64,
2671 seqlen_offsets: &[usize],
2672 ) -> Result<(Tensor, Tensor)> {
2673 qk_rms_norm_rope(
2674 q,
2675 k,
2676 q_weight,
2677 k_weight,
2678 q_eps,
2679 k_eps,
2680 &self.cos,
2681 &self.sin,
2682 self.is_gpt_neox,
2683 seqlen_offsets,
2684 )
2685 }
2686
2687 pub fn forward_q_norm(
2688 &self,
2689 q: &Tensor,
2690 q_weight: &Tensor,
2691 q_eps: f64,
2692 seqlen_offsets: &[usize],
2693 ) -> Result<Tensor> {
2694 q_rms_norm_rope(
2695 q,
2696 q_weight,
2697 q_eps,
2698 &self.cos,
2699 &self.sin,
2700 self.is_gpt_neox,
2701 seqlen_offsets,
2702 )
2703 }
2704
2705 pub fn new_partial(
2706 base: f32,
2707 rot_dim: usize,
2708 max_position_embeddings: usize,
2709 device: &Device,
2710 is_gpt_neox: bool,
2711 dtype: DType,
2712 ) -> Result<Self> {
2713 let inv_freq: Vec<_> = (0..rot_dim)
2714 .step_by(2)
2715 .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
2716 .collect();
2717 let inv_freq_len = inv_freq.len();
2718 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2719 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2720 .to_dtype(DType::F32)?
2721 .reshape((max_position_embeddings, 1))?;
2722 let freqs = t.matmul(&inv_freq)?;
2723 let sin = freqs.sin()?.to_dtype(dtype)?;
2724 let cos = freqs.cos()?.to_dtype(dtype)?;
2725
2726 Ok(Self {
2727 cos,
2728 sin,
2729 is_gpt_neox,
2730 })
2731 }
2732
2733 pub fn forward(
2734 &self,
2735 q: &Tensor,
2736 k: &Tensor,
2737 seqlen_offsets: &[usize],
2738 ) -> Result<(Tensor, Tensor)> {
2739 let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2740 let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
2741
2742 let rope = if self.is_gpt_neox {
2743 hanzo_nn::rotary_emb::rope
2744 } else {
2745 hanzo_nn::rotary_emb::rope_i
2746 };
2747
2748 if cfg!(feature = "cuda") && qh == kh {
2749 let (cos, sin) = if seqlen_offsets.len() == 1 {
2750 (
2751 self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2752 self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2753 )
2754 } else {
2755 let mut cos_s = Vec::new();
2756 let mut sin_s = Vec::new();
2757 for offset in seqlen_offsets {
2758 cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2759 sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2760 }
2761 (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2762 };
2763
2764 let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2765 let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2766 hanzo_quant::rotary::apply_rotary_inplace(
2767 &q_embed,
2768 &k_embed,
2769 &cos,
2770 &sin,
2771 self.is_gpt_neox,
2772 )?;
2773 let mut q = q_embed
2774 .reshape((b_sz, seq_len, qh, n_embd))?
2775 .transpose(1, 2)?;
2776 let mut k = k_embed
2777 .reshape((b_sz, seq_len, kh, n_embd))?
2778 .transpose(1, 2)?;
2779 if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2780 q = q.contiguous()?;
2781 k = k.contiguous()?;
2782 }
2783 Ok((q, k))
2784 } else if seqlen_offsets.len() == 1 {
2785 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2786 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2787 let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
2788 let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
2789 Ok((q_embed, k_embed))
2790 } else {
2791 let mut q_embeds = Vec::new();
2792 let mut k_embeds = Vec::new();
2793 for (i, offset) in seqlen_offsets.iter().enumerate() {
2794 let cos = self.cos.narrow(0, *offset, seq_len)?;
2795 let sin = self.sin.narrow(0, *offset, seq_len)?;
2796 let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2797 let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2798 q_embeds.push(q_embed);
2799 k_embeds.push(k_embed);
2800 }
2801 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
2802 }
2803 }
2804
2805 pub fn forward_q(&self, q: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
2807 let (_b_sz, _qh, seq_len, _n_embd) = q.dims4()?;
2808 let rope = if self.is_gpt_neox {
2809 hanzo_nn::rotary_emb::rope
2810 } else {
2811 hanzo_nn::rotary_emb::rope_i
2812 };
2813 if seqlen_offsets.len() == 1 {
2814 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2815 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2816 rope(&q.contiguous()?, &cos, &sin)
2817 } else {
2818 let mut q_embeds = Vec::new();
2819 for (i, offset) in seqlen_offsets.iter().enumerate() {
2820 let cos = self.cos.narrow(0, *offset, seq_len)?;
2821 let sin = self.sin.narrow(0, *offset, seq_len)?;
2822 q_embeds.push(rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?);
2823 }
2824 Tensor::cat(&q_embeds, 0)
2825 }
2826 }
2827}
2828
/// YaRN-scaled rotary embedding tables for GPT-OSS. The `cos`/`sin` tables are stored
/// already multiplied by `attention_scale`.
#[derive(Debug, Clone)]
pub struct GptOssRotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    /// Scale folded into `cos`/`sin` at construction; not read afterwards.
    #[allow(dead_code)]
    attention_scale: f32,
}
2838
2839impl GptOssRotaryEmbedding {
2840 #[allow(clippy::too_many_arguments)]
2854 pub fn new(
2855 base: f64,
2856 head_dim: usize,
2857 max_position_embeddings: usize,
2858 factor: f64,
2859 original_max_position_embeddings: usize,
2860 beta_fast: f64,
2861 beta_slow: f64,
2862 truncate: bool,
2863 device: &Device,
2864 dtype: DType,
2865 ) -> Result<Self> {
2866 let dim = head_dim;
2867
2868 let attention_scale = (0.1 * factor.ln() + 1.0) as f32;
2870
2871 let find_correction_dim = |num_rotations: f64| -> f64 {
2874 (dim as f64
2875 * (original_max_position_embeddings as f64
2876 / (num_rotations * 2.0 * std::f64::consts::PI))
2877 .ln())
2878 / (2.0 * base.ln())
2879 };
2880
2881 let mut low = find_correction_dim(beta_fast);
2883 let mut high = find_correction_dim(beta_slow);
2884 if truncate {
2885 low = low.floor();
2886 high = high.ceil();
2887 }
2888 low = low.max(0.0);
2889 high = high.min((dim - 1) as f64);
2890
2891 let half_dim = dim / 2;
2893 let inv_freq_extrapolation: Vec<f64> = (0..dim)
2894 .step_by(2)
2895 .map(|i| 1.0 / base.powf(i as f64 / dim as f64))
2896 .collect();
2897 let inv_freq_interpolation: Vec<f64> =
2898 inv_freq_extrapolation.iter().map(|f| f / factor).collect();
2899
2900 let inv_freq: Vec<f64> = (0..half_dim)
2902 .map(|i| {
2903 let range = if (high - low).abs() < 0.001 {
2904 0.001
2905 } else {
2906 high - low
2907 };
2908 let linear = (i as f64 - low) / range;
2909 let ramp = linear.clamp(0.0, 1.0);
2910 inv_freq_interpolation[i] * ramp + inv_freq_extrapolation[i] * (1.0 - ramp)
2911 })
2912 .collect();
2913
2914 let inv_freq_len = inv_freq.len();
2915 let inv_freq_tensor = Tensor::from_vec(
2916 inv_freq.iter().map(|&x| x as f32).collect::<Vec<_>>(),
2917 (1, inv_freq_len),
2918 device,
2919 )?;
2920
2921 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2922 .to_dtype(DType::F32)?
2923 .reshape((max_position_embeddings, 1))?;
2924
2925 let freqs = t.matmul(&inv_freq_tensor)?;
2926
2927 let sin = (freqs.sin()? * attention_scale as f64)?.to_dtype(dtype)?;
2929 let cos = (freqs.cos()? * attention_scale as f64)?.to_dtype(dtype)?;
2930
2931 Ok(Self {
2932 cos,
2933 sin,
2934 attention_scale,
2935 })
2936 }
2937
2938 pub fn forward(
2939 &self,
2940 q: &Tensor,
2941 k: &Tensor,
2942 seqlen_offsets: &[usize],
2943 ) -> Result<(Tensor, Tensor)> {
2944 #[allow(unused_variables)]
2945 let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2946 #[allow(unused_variables)]
2947 let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;
2948
2949 #[cfg(feature = "cuda")]
2952 if q.device().is_cuda() && qh == k.dim(1)? {
2953 let (cos, sin) = if seqlen_offsets.len() == 1 {
2954 (
2955 self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2956 self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2957 )
2958 } else {
2959 let mut cos_s = Vec::new();
2960 let mut sin_s = Vec::new();
2961 for offset in seqlen_offsets {
2962 cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2963 sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2964 }
2965 (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2966 };
2967
2968 let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2970 let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2971
2972 hanzo_quant::rotary::apply_rotary_inplace(&q_embed, &k_embed, &cos, &sin, true)?;
2974
2975 let mut q = q_embed
2977 .reshape((b_sz, seq_len, qh, n_embd))?
2978 .transpose(1, 2)?;
2979 let mut k = k_embed
2980 .reshape((b_sz, seq_len, kh, n_embd))?
2981 .transpose(1, 2)?;
2982
2983 if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2984 q = q.contiguous()?;
2985 k = k.contiguous()?;
2986 }
2987 return Ok((q, k));
2988 }
2989
2990 if seqlen_offsets.len() == 1 {
2992 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2993 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2994 let q_embed = hanzo_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
2995 let k_embed = hanzo_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
2996 Ok((q_embed, k_embed))
2997 } else {
2998 let mut q_embeds = Vec::new();
2999 let mut k_embeds = Vec::new();
3000 for (i, offset) in seqlen_offsets.iter().enumerate() {
3001 let cos = self.cos.narrow(0, *offset, seq_len)?;
3002 let sin = self.sin.narrow(0, *offset, seq_len)?;
3003 let q_embed =
3004 hanzo_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
3005 let k_embed =
3006 hanzo_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
3007 q_embeds.push(q_embed);
3008 k_embeds.push(k_embed);
3009 }
3010 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
3011 }
3012 }
3013}
3014
/// Activation functions selectable from model configs.
///
/// Deserialized from the lowercase variant name (serde `rename_all = "lowercase"`) or any of
/// the listed aliases. Defaults to `Gelu`.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    /// Erf-based GELU (`Tensor::gelu_erf`).
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    /// `Tensor::gelu`; also accepted as `"gelu_new"`.
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    /// Squared ReLU.
    Relu2,
    /// ReLU clamped to `[0, 6]`.
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    /// `x * sigmoid(x)`.
    Swish,
    /// `x * hard_sigmoid(x)`.
    HardSwish,
    /// ELU with the given alpha.
    Elu(f64),
    /// Leaky ReLU with the given negative slope.
    LeakyRelu(f64),
    /// Same computation as `NewGelu` (`Tensor::gelu`); accepted as `"gelu_pytorch_tanh"`.
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    /// `x * sigmoid(1.702 * x)`; has no `hanzo_nn::Activation` counterpart.
    QuickGelu,
}
3038
3039impl Module for Activation {
3040 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3041 match self {
3042 Self::Gelu => xs.gelu_erf(),
3043 Self::NewGelu => xs.gelu(),
3045 Self::Relu => xs.relu(),
3046 Self::Relu2 => xs.relu()?.sqr(),
3047 Self::Relu6 => xs.clamp(0f32, 6f32),
3048 Self::Silu => xs.silu(),
3049 Self::Sigmoid => hanzo_nn::ops::sigmoid(xs),
3050 Self::HardSigmoid => hanzo_nn::ops::hard_sigmoid(xs),
3051 Self::Swiglu => hanzo_nn::ops::swiglu(xs),
3052 Self::Swish => xs * hanzo_nn::ops::sigmoid(xs)?,
3053 Self::HardSwish => xs * hanzo_nn::ops::hard_sigmoid(xs)?,
3054 &Self::Elu(alpha) => xs.elu(alpha),
3055 &Self::LeakyRelu(negative_slope) => hanzo_nn::ops::leaky_relu(xs, negative_slope),
3056 Self::GeluPytorchTanh => xs.gelu(),
3057 Self::QuickGelu => xs * hanzo_nn::ops::sigmoid(&(xs * 1.702f64)?),
3058 }
3059 }
3060}
3061
3062impl TryInto<hanzo_nn::Activation> for Activation {
3063 type Error = hanzo_ml::Error;
3064
3065 fn try_into(self) -> Result<hanzo_nn::Activation> {
3066 match self {
3067 Self::Gelu => Ok(hanzo_nn::Activation::Gelu),
3068 Self::Relu => Ok(hanzo_nn::Activation::Relu),
3069 Self::Silu => Ok(hanzo_nn::Activation::Silu),
3070 Self::NewGelu => Ok(hanzo_nn::Activation::NewGelu),
3071 Self::Relu2 => Ok(hanzo_nn::Activation::Relu2),
3072 Self::Relu6 => Ok(hanzo_nn::Activation::Relu6),
3073 Self::Sigmoid => Ok(hanzo_nn::Activation::Sigmoid),
3074 Self::HardSigmoid => Ok(hanzo_nn::Activation::HardSigmoid),
3075 Self::Swiglu => Ok(hanzo_nn::Activation::Swiglu),
3076 Self::Swish => Ok(hanzo_nn::Activation::Swish),
3077 Self::HardSwish => Ok(hanzo_nn::Activation::HardSwish),
3078 Self::Elu(x) => Ok(hanzo_nn::Activation::Elu(x)),
3079 Self::LeakyRelu(x) => Ok(hanzo_nn::Activation::LeakyRelu(x)),
3080 Self::GeluPytorchTanh => Ok(hanzo_nn::Activation::GeluPytorchTanh),
3081 Self::QuickGelu => hanzo_ml::bail!("No mapping to hanzo_nn for QuickGelu"),
3082 }
3083 }
3084}
3085
/// Scalar convolution settings for `Conv3dNoBias` (one value per setting, not per axis).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    pub padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for Conv3dConfig {
    /// Unpadded, unit-stride, undilated, single-group convolution.
    fn default() -> Self {
        Self {
            groups: 1,
            dilation: 1,
            stride: 1,
            padding: 0,
        }
    }
}
3104
/// Bias-free 3D convolution with a temporal kernel of 2, computed as the sum of two 2D
/// convolutions, one per temporal slice.
pub struct Conv3dNoBias {
    /// Kernel applied to temporal slice 0.
    conv2d_1: Conv2d,
    /// Kernel applied to temporal slice 1.
    conv2d_2: Conv2d,
}
3109
3110impl Conv3dNoBias {
3111 pub fn new(
3112 in_channels: usize,
3113 out_channels: usize,
3114 kernel_sizes: [usize; 3],
3115 cfg: Conv3dConfig,
3116 vb: ShardedVarBuilder,
3117 ) -> Result<Self> {
3118 let expected_shape = (
3119 out_channels,
3120 in_channels / cfg.groups,
3121 kernel_sizes[0],
3122 kernel_sizes[1],
3123 kernel_sizes[2],
3124 );
3125 let mlx_shape = (
3128 out_channels,
3129 kernel_sizes[0],
3130 kernel_sizes[1],
3131 kernel_sizes[2],
3132 in_channels / cfg.groups,
3133 );
3134 let ws = if vb.contains_tensor("weight") {
3135 match vb.get(expected_shape, "weight") {
3137 Ok(ws) => ws,
3138 Err(_) => {
3139 let ws = vb.get(mlx_shape, "weight")?;
3141 ws.permute((0, 4, 1, 2, 3))?
3142 }
3143 }
3144 } else {
3145 vb.get(expected_shape, "weight")?
3146 };
3147
3148 let w1 = ws.i((.., .., 0, .., ..))?;
3152 let w2 = ws.i((.., .., 1, .., ..))?;
3153
3154 let cfg = Conv2dConfig {
3155 padding: cfg.padding,
3156 stride: cfg.stride,
3157 dilation: cfg.dilation,
3158 groups: cfg.groups,
3159 cudnn_fwd_algo: None,
3160 };
3161
3162 Ok(Self {
3163 conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
3164 conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
3165 })
3166 }
3167
3168 pub fn weight(&self) -> Result<Tensor> {
3169 let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
3170 let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
3171 Tensor::cat(&[w1, w2], 2)
3172 }
3173}
3174
3175impl Module for Conv3dNoBias {
3176 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3177 let xs1 = xs.i((.., .., 0, .., ..))?;
3178 let xs2 = xs.i((.., .., 1, .., ..))?;
3179
3180 (Convolution.forward_2d(&self.conv2d_1, &xs1)?
3181 + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
3182 .unsqueeze(2)
3183 }
3184}
3185
/// Helpers for detecting infinities in tensors.
pub trait TensorInfExtend {
    /// Element-wise mask of entries equal to `+inf`.
    fn is_inf(&self) -> Result<Self>
    where
        Self: Sized;
    /// Reduces the tensor to a single boolean; by its name, intended to report whether any
    /// element is non-zero.
    fn any(&self) -> Result<bool>;
}
3192
3193impl TensorInfExtend for Tensor {
3194 fn is_inf(&self) -> Result<Self> {
3195 self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
3196 }
3197
3198 fn any(&self) -> Result<bool> {
3199 let sum = self.sum_all()?;
3200 match self.dtype() {
3201 DType::U8 => Ok(sum.to_scalar::<u8>()? == 0),
3202 DType::U32 => Ok(sum.to_scalar::<u32>()? == 0),
3203 DType::I16 => Ok(sum.to_scalar::<i16>()? == 0),
3204 DType::I32 => Ok(sum.to_scalar::<i32>()? == 0),
3205 DType::I64 => Ok(sum.to_scalar::<i64>()? == 0),
3206 DType::F16 => Ok(sum.to_scalar::<half::f16>()? == half::f16::from_f32_const(0.)),
3207 DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
3208 DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
3209 DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
3210 DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
3211 _ => {
3212 hanzo_ml::bail!("dtype {:?} is not supported with .any", self.dtype())
3213 }
3214 }
3215 }
3216}
3217
3218pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
3219 let mut max = match xs.dtype() {
3220 DType::U8 => u8::MAX as f32 - 1000.,
3221 DType::U32 => u32::MAX as f32 - 1000.,
3222 DType::I16 => i16::MAX as f32 - 1000.,
3223 DType::I32 => i32::MAX as f32 - 1000.,
3224 DType::I64 => i64::MAX as f32 - 1000.,
3225 DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
3226 DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
3227 DType::F32 => f32::MAX - 1000.,
3228 DType::F64 => f64::MAX as f32 - 1000.,
3229 DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
3230 _ => {
3231 hanzo_ml::bail!("dtype {:?} is not supported with clamp_for_f16", xs.dtype())
3232 }
3233 };
3234 if xs.is_inf()?.any()? {
3235 max -= 1000.;
3236 }
3237 xs.clamp(-max, max)
3238}
3239
/// Numeric limits of a floating-point [`DType`], widened to `f64`.
pub struct FloatInfo {
    /// Most negative finite value.
    pub min: f64,
    /// Largest finite value.
    pub max: f64,
    /// Machine epsilon.
    pub eps: f64,
    /// The dtype these limits describe.
    pub dtype: DType,
}
3249
/// Lookup of [`FloatInfo`] limits, analogous to `torch.finfo`.
pub trait GetFloatInfo {
    /// Returns the limits, or an error for non-floating-point types.
    fn finfo(&self) -> Result<FloatInfo>;
}
3253
3254impl GetFloatInfo for DType {
3255 fn finfo(&self) -> Result<FloatInfo> {
3256 let finfo = match self {
3257 Self::BF16 => FloatInfo {
3258 min: bf16::MIN.to_f64(),
3259 max: bf16::MAX.to_f64(),
3260 eps: bf16::EPSILON.to_f64(),
3261 dtype: DType::BF16,
3262 },
3263 Self::F16 => FloatInfo {
3264 min: f16::MIN.to_f64(),
3265 max: f16::MAX.to_f64(),
3266 eps: f16::EPSILON.to_f64(),
3267 dtype: DType::F16,
3268 },
3269 Self::F32 => FloatInfo {
3270 min: f32::MIN as f64,
3271 max: f32::MAX as f64,
3272 eps: f32::EPSILON as f64,
3273 dtype: DType::F32,
3274 },
3275 Self::F64 => FloatInfo {
3276 min: f64::MIN,
3277 max: f64::MAX,
3278 eps: f64::EPSILON,
3279 dtype: DType::F64,
3280 },
3281 Self::F8E4M3 => FloatInfo {
3282 min: F8E4M3::MIN.to_f64(),
3283 max: F8E4M3::MAX.to_f64(),
3284 eps: F8E4M3::EPSILON.to_f64(),
3285 dtype: DType::F8E4M3,
3286 },
3287 other => {
3288 hanzo_ml::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
3289 }
3290 };
3291 Ok(finfo)
3292 }
3293}
3294
/// Gated feed-forward block with `gate`, `up` and `down` projections, evaluated through
/// `crate::ops::quantized_ffn`.
#[derive(Clone)]
pub struct Mlp {
    pub gate: Arc<dyn QuantMethod>,
    pub up: Arc<dyn QuantMethod>,
    pub down: Arc<dyn QuantMethod>,
    /// Activation passed to `quantized_ffn`.
    act: Activation,
    /// `[hidden_size, intermediate_size]`, consumed by `replicate`.
    params: Vec<usize>,
}
3303
3304impl Mlp {
3305 pub fn new(
3306 vb: ShardedVarBuilder,
3307 hidden_size: usize,
3308 intermediate_size: usize,
3309 quantization_config: &Option<QuantizedConfig>,
3310 hidden_act: Activation,
3311 comm: &Arc<hanzo_quant::Comm>,
3312 ) -> Result<Self> {
3313 Ok(Self {
3314 gate: ColumnParallelLayer::new(
3315 hidden_size,
3316 intermediate_size,
3317 quantization_config,
3318 false,
3319 comm,
3320 vb.pp("gate_proj"),
3321 )?,
3322 up: ColumnParallelLayer::new(
3323 hidden_size,
3324 intermediate_size,
3325 quantization_config,
3326 false,
3327 comm,
3328 vb.pp("up_proj"),
3329 )?,
3330 down: RowParallelLayer::new(
3331 intermediate_size,
3332 hidden_size,
3333 quantization_config,
3334 false,
3335 comm,
3336 vb.pp("down_proj"),
3337 )?,
3338 act: hidden_act,
3339 params: vec![hidden_size, intermediate_size],
3340 })
3341 }
3342
3343 pub fn new_merged(
3344 vb: ShardedVarBuilder,
3345 hidden_size: usize,
3346 intermediate_size: usize,
3347 chunks: usize,
3348 quantization_config: &Option<QuantizedConfig>,
3349 hidden_act: Activation,
3350 comm: &Arc<hanzo_quant::Comm>,
3351 ) -> Result<Self> {
3352 assert!(chunks == 2, "Only gate_up_proj merge is supported!");
3353 let gate_up_projs = ColumnParallelLayer::new_merged(
3354 hidden_size,
3355 intermediate_size * 2,
3356 2,
3357 quantization_config,
3358 false,
3359 comm,
3360 vb.pp("gate_up_proj"),
3361 )?;
3362
3363 Ok(Self {
3364 gate: gate_up_projs[0].to_owned(),
3365 up: gate_up_projs[1].to_owned(),
3366 down: RowParallelLayer::new(
3367 intermediate_size,
3368 hidden_size,
3369 quantization_config,
3370 false,
3371 comm,
3372 vb.pp("down_proj"),
3373 )?,
3374 act: hidden_act,
3375 params: vec![hidden_size, intermediate_size],
3376 })
3377 }
3378
3379 pub fn replicate(
3380 params: &[usize],
3381 vb: ShardedVarBuilder,
3382 act: Activation,
3383 comm: &Arc<hanzo_quant::Comm>,
3384 ) -> Result<Self> {
3385 Self::new(vb, params[0], params[1], &None, act, comm)
3386 }
3387
3388 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3389 let res = crate::ops::quantized_ffn(xs, &*self.gate, &*self.up, &*self.down, self.act)?;
3390 Ok(res)
3391 }
3392}
3393
// AnyMoE training support relies entirely on the trait's provided items; nothing is overridden.
impl AnyMoeTrainableLayer for Mlp {}
3395
3396impl MlpLayer for Mlp {
3397 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3398 let res = crate::ops::quantized_ffn(xs, &*self.gate, &*self.up, &*self.down, self.act)?;
3399 Ok(res)
3400 }
3401 fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
3402 vec![&mut self.gate, &mut self.up, &mut self.down]
3403 }
3404 fn clone(&self) -> Box<dyn MlpLayer> {
3405 Box::new(Clone::clone(self))
3406 }
3407 fn get_params(&self) -> &[usize] {
3408 &self.params
3409 }
3410 fn hidden_act(&self) -> Activation {
3411 self.act
3412 }
3413 fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
3415 let gate = if let Some(ref delta) = deltas[0] {
3416 self.gate.add_delta_w(delta)?
3417 } else {
3418 self.gate.clone()
3419 };
3420 let up = if let Some(ref delta) = deltas[1] {
3421 self.up.add_delta_w(delta)?
3422 } else {
3423 self.up.clone()
3424 };
3425 let down = if let Some(ref delta) = deltas[2] {
3426 self.down.add_delta_w(delta)?
3427 } else {
3428 self.down.clone()
3429 };
3430
3431 Ok(Box::new(Self {
3432 gate,
3433 up,
3434 down,
3435 act: self.act,
3436 params: self.params.clone(),
3437 }))
3438 }
3439
3440 fn dtype_device(&self) -> (DType, Device) {
3441 self.gate.dtype_and_device()
3442 }
3443}
3444
3445pub struct AvgPool2d {
3446 kernel_size: usize,
3447 stride: usize,
3448}
3449
3450impl AvgPool2d {
3451 pub fn new(kernel_size: usize, stride: usize) -> Self {
3452 Self {
3453 kernel_size,
3454 stride,
3455 }
3456 }
3457
3458 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3459 xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
3460 }
3461}
3462
3463pub struct ReflectionPad2d {
3470 padding: (usize, usize, usize, usize),
3471}
3472
3473impl ReflectionPad2d {
3474 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
3475 Self { padding }
3476 }
3477}
3478
3479impl Module for ReflectionPad2d {
3480 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3481 let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
3482
3483 let (_n, _c, h, w) = xs.dims4()?;
3484
3485 let left_pad = if pad_left > 0 {
3488 let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
3490 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
3491 } else {
3492 None
3493 };
3494
3495 let right_pad = if pad_right > 0 {
3497 let start = w as i64 - 2;
3499 let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
3500 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
3501 } else {
3502 None
3503 };
3504
3505 let x_padded_width = match (left_pad, right_pad) {
3507 (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
3508 (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
3509 (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
3510 (None, None) => xs.clone(),
3511 };
3512
3513 let top_pad = if pad_top > 0 {
3516 let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
3517 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
3518 } else {
3519 None
3520 };
3521
3522 let bottom_pad = if pad_bottom > 0 {
3524 let start = h as i64 - 2;
3525 let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
3526 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
3527 } else {
3528 None
3529 };
3530
3531 let x_padded = match (top_pad, bottom_pad) {
3533 (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
3534 (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
3535 (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
3536 (None, None) => x_padded_width,
3537 };
3538
3539 Ok(x_padded)
3540 }
3541}
3542
3543pub struct ScaledEmbedding {
3544 scale: f64,
3545 pub embedding: Tensor,
3546}
3547
3548impl ScaledEmbedding {
3549 pub fn new(scale: f64, embedding: Embedding) -> Self {
3550 Self {
3551 scale,
3552 embedding: embedding.embeddings().clone(),
3553 }
3554 }
3555
3556 pub fn embeddings(&self) -> &Tensor {
3557 &self.embedding
3558 }
3559}
3560
3561impl Module for ScaledEmbedding {
3562 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
3563 let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
3564 xs.apply(&embedding)? * self.scale
3565 }
3566}