1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use anyhow::Result;
use rlx_ir::{DType, Shape};
use crate::context::FlowCtx;
/// Flow stage that carries a pair of precomputed RoPE tables (cos and
/// sin) plus the parameter keys and dimensions used when the tables
/// are registered by `emit`.
#[derive(Debug, Clone)]
pub struct RopeTablesStage {
/// Key passed to `FlowCtx::synth_param` when registering the cos table.
pub cos_key: String,
/// Key passed to `FlowCtx::synth_param` when registering the sin table.
pub sin_key: String,
/// First dimension of the shape declared for both tables in `emit`.
pub max_positions: usize,
/// Second dimension of the shape declared for both tables in `emit`.
pub half_dim: usize,
/// Cos table values; cloned and handed to `synth_param` on every `emit`.
/// NOTE(review): presumably holds `max_positions * half_dim` elements —
/// nothing in this file checks that; confirm against callers.
pub cos_data: Vec<f32>,
/// Sin table values; cloned and handed to `synth_param` on every `emit`.
/// NOTE(review): presumably holds `max_positions * half_dim` elements —
/// nothing in this file checks that; confirm against callers.
pub sin_data: Vec<f32>,
/// When `Some(slot)`, push the (cos, sin) HIR ids into
/// `state.named["{slot}_cos"]` / `"{slot}_sin"` instead of the
/// default `state.rope_cos`/`state.rope_sin`. Self-attention
/// blocks opt into the named slot via
/// `SelfAttnPrefillSpec::rope_table`. Used by Gemma 4 which
/// ships split sliding/full RoPE thetas.
pub named_slot: Option<String>,
}
impl RopeTablesStage {
    /// Builds a stage that publishes its tables through the default
    /// flow-state handles.
    ///
    /// The parameter keys are fixed to `"rope.cos"` and `"rope.sin"`, and
    /// `named_slot` is left as `None`. The remaining arguments are stored
    /// unchanged.
    pub fn param(
        max_positions: usize,
        half_dim: usize,
        cos_data: Vec<f32>,
        sin_data: Vec<f32>,
    ) -> Self {
        Self {
            named_slot: None,
            cos_key: String::from("rope.cos"),
            sin_key: String::from("rope.sin"),
            max_positions,
            half_dim,
            cos_data,
            sin_data,
        }
    }

    /// Builds a stage that publishes its tables under a named slot (for
    /// per-layer RoPE) instead of the default flow-state handles.
    ///
    /// The parameter keys become `"rope.{slot}.cos"` and
    /// `"rope.{slot}.sin"`, and the slot name is kept in `named_slot`.
    pub fn param_named(
        slot: impl Into<String>,
        max_positions: usize,
        half_dim: usize,
        cos_data: Vec<f32>,
        sin_data: Vec<f32>,
    ) -> Self {
        let slot: String = slot.into();
        // Derive both keys before `slot` is moved into the struct.
        let cos_key = format!("rope.{slot}.cos");
        let sin_key = format!("rope.{slot}.sin");
        Self {
            cos_key,
            sin_key,
            max_positions,
            half_dim,
            cos_data,
            sin_data,
            named_slot: Some(slot),
        }
    }

    /// Registers both tables with `ctx` and records the returned ids.
    ///
    /// Each table is passed to `ctx.synth_param` under its key with an
    /// F32 shape of `[max_positions, half_dim]`; the cos table is
    /// registered first. With a named slot the ids are inserted into
    /// `ctx.state.named` as `"{slot}_cos"` / `"{slot}_sin"`; otherwise
    /// they are written to `ctx.state.rope_cos` / `ctx.state.rope_sin`.
    ///
    /// This body has no failing path of its own and always returns
    /// `Ok(())`.
    ///
    /// NOTE(review): the lengths of `cos_data` / `sin_data` are not
    /// compared against `max_positions * half_dim` here — confirm that
    /// `synth_param` (or the caller) enforces this.
    pub fn emit(&self, ctx: &mut FlowCtx<'_>) -> Result<()> {
        let dims = [self.max_positions, self.half_dim];
        let table_shape = || Shape::new(&dims, DType::F32);
        // Both shapes are built up front, before anything is registered.
        let (cos_shape, sin_shape) = (table_shape(), table_shape());

        let cos_handle = ctx.synth_param(&self.cos_key, self.cos_data.clone(), cos_shape);
        let sin_handle = ctx.synth_param(&self.sin_key, self.sin_data.clone(), sin_shape);

        if let Some(slot) = &self.named_slot {
            ctx.state.named.insert(format!("{slot}_cos"), cos_handle);
            ctx.state.named.insert(format!("{slot}_sin"), sin_handle);
        } else {
            ctx.state.rope_cos = Some(cos_handle);
            ctx.state.rope_sin = Some(sin_handle);
        }
        Ok(())
    }
}