Skip to main content

rlx_ir/
variant.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Model execution variants — one object drives cache keys and [`DimBinding`].
17//!
18//! Mirrors the “shader components” idea from extensible shading systems: the same
19//! granularity selects **what to specialize** and **which symbolic dims to bind**.
20
21use std::collections::hash_map::DefaultHasher;
22use std::hash::{Hash, Hasher};
23
24use crate::dynamic::sym;
25use crate::shape::DimBinding;
26
/// Coarse execution phase (prefill vs decode vs encoder).
///
/// The phase is hashed into [`ModelVariant::cache_key`] and selects the base
/// bindings in [`ModelVariant::dim_binding`], where only [`ModelPhase::Decode`]
/// receives special handling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelPhase {
    /// Phase assigned by [`ModelVariant::prefill`].
    Prefill,
    /// Phase assigned by [`ModelVariant::decode`]; the only phase for which
    /// `past_seq` is bound by [`ModelVariant::dim_binding`].
    Decode,
    /// Phase assigned by [`ModelVariant::encoder`].
    Encoder,
    /// Not produced by any constructor in this module; a [`ModelVariant`] with
    /// this phase has to be built through its public fields.
    // NOTE(review): presumably a generic single-pass forward — confirm intent.
    Inference,
}
35
/// Concrete shape bucket for compile-once / specialize-at-runtime workflows.
///
/// All fields are public, so a variant can be built directly as well as through
/// the `prefill` / `decode` / `encoder` constructors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelVariant {
    /// Size bound to the batch symbol.
    pub batch: usize,
    /// Size bound to the sequence symbol; for decode variants this is the
    /// number of new tokens.
    pub seq: usize,
    /// KV length for decode variants; the prefill and encoder constructors
    /// leave it `None`. It is only bound when `phase` is [`ModelPhase::Decode`].
    pub past_seq: Option<usize>,
    /// Execution phase this bucket belongs to.
    pub phase: ModelPhase,
    /// Extra dynamic symbols beyond batch/seq/past (e.g. custom ragged axes).
    /// Each entry is a `(symbol, size)` pair; `dim_binding` applies them in
    /// order after the leading dims.
    pub extra: Vec<(u32, usize)>,
}
46
47impl ModelVariant {
48    pub fn prefill(batch: usize, seq: usize) -> Self {
49        Self {
50            batch,
51            seq,
52            past_seq: None,
53            phase: ModelPhase::Prefill,
54            extra: Vec::new(),
55        }
56    }
57
58    /// Single-step decode: `seq` is the new token count (often 1); `past_seq` is KV length.
59    pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
60        Self {
61            batch,
62            seq: new_tokens,
63            past_seq: Some(past_seq),
64            phase: ModelPhase::Decode,
65            extra: Vec::new(),
66        }
67    }
68
69    pub fn encoder(batch: usize, seq: usize) -> Self {
70        Self {
71            batch,
72            seq,
73            past_seq: None,
74            phase: ModelPhase::Encoder,
75            extra: Vec::new(),
76        }
77    }
78
79    pub fn with_extra(mut self, symbol: u32, size: usize) -> Self {
80        self.extra.push((symbol, size));
81        self
82    }
83
84    /// Stable cache key: phase + bound leading dims + extra symbols.
85    pub fn cache_key(&self) -> u64 {
86        let mut h = DefaultHasher::new();
87        self.phase.hash(&mut h);
88        self.batch.hash(&mut h);
89        self.seq.hash(&mut h);
90        self.past_seq.hash(&mut h);
91        for (sym, size) in &self.extra {
92            sym.hash(&mut h);
93            size.hash(&mut h);
94        }
95        h.finish()
96    }
97
98    /// Symbol bindings used by [`crate::dynamic::bind_graph`] / compile specialization.
99    pub fn dim_binding(&self) -> DimBinding {
100        let mut b = match (self.phase, self.past_seq) {
101            (ModelPhase::Decode, Some(past)) => DimBinding::batch_past_seq(self.batch, past),
102            _ => DimBinding::batch_seq(self.batch, self.seq),
103        };
104        if self.phase == ModelPhase::Decode {
105            b.set(sym::SEQ, self.seq);
106        }
107        for (sym, size) in &self.extra {
108            b.set(*sym, *size);
109        }
110        b
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A prefill variant binds the batch and sequence sizes it was built with.
    #[test]
    fn prefill_binding_sets_batch_seq() {
        let binding = ModelVariant::prefill(2, 128).dim_binding();

        assert_eq!(binding.get(sym::BATCH), Some(2));
        assert_eq!(binding.get(sym::SEQ), Some(128));
    }

    /// A decode variant binds batch, KV length, and the new-token count.
    #[test]
    fn decode_binding_sets_past_and_new_seq() {
        let variant = ModelVariant::decode(1, 64, 1);
        let binding = variant.dim_binding();

        assert_eq!(binding.get(sym::BATCH), Some(1));
        assert_eq!(binding.get(sym::PAST_SEQ), Some(64));
        assert_eq!(binding.get(sym::SEQ), Some(1));
    }

    /// Prefill and decode variants must not share a cache key.
    #[test]
    fn cache_key_differs_by_phase() {
        let prefill_key = ModelVariant::prefill(1, 8).cache_key();
        let decode_key = ModelVariant::decode(1, 7, 1).cache_key();

        assert_ne!(prefill_key, decode_key);
    }
}