Skip to main content

rlx_ir/
variant.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Model execution variants — one object drives cache keys and [`DimBinding`].
5//!
6//! Mirrors the “shader components” idea from extensible shading systems: the same
7//! granularity selects **what to specialize** and **which symbolic dims to bind**.
8
9use std::collections::hash_map::DefaultHasher;
10use std::hash::{Hash, Hasher};
11
12use crate::dynamic::sym;
13use crate::shape::DimBinding;
14
/// Coarse execution phase that a [`ModelVariant`] is specialized for.
///
/// The phase is hashed into the variant's cache key and decides which leading
/// symbolic dims are bound when a `DimBinding` is built.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ModelPhase {
    /// Prompt ingestion; produced by `ModelVariant::prefill`.
    Prefill,
    /// Incremental step on top of an existing KV cache; produced by
    /// `ModelVariant::decode`, which also records the past sequence length.
    Decode,
    /// Encoder pass; produced by `ModelVariant::encoder`.
    Encoder,
    /// Generic inference phase.
    // NOTE(review): presumably for models without a prefill/decode split — confirm intended use.
    Inference,
}
23
24/// Concrete shape bucket for compile-once / specialize-at-runtime workflows.
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct ModelVariant {
27    pub batch: usize,
28    pub seq: usize,
29    pub past_seq: Option<usize>,
30    pub phase: ModelPhase,
31    /// Extra dynamic symbols beyond batch/seq/past (e.g. custom ragged axes).
32    pub extra: Vec<(u32, usize)>,
33}
34
35impl ModelVariant {
36    pub fn prefill(batch: usize, seq: usize) -> Self {
37        Self {
38            batch,
39            seq,
40            past_seq: None,
41            phase: ModelPhase::Prefill,
42            extra: Vec::new(),
43        }
44    }
45
46    /// Single-step decode: `seq` is the new token count (often 1); `past_seq` is KV length.
47    pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
48        Self {
49            batch,
50            seq: new_tokens,
51            past_seq: Some(past_seq),
52            phase: ModelPhase::Decode,
53            extra: Vec::new(),
54        }
55    }
56
57    pub fn encoder(batch: usize, seq: usize) -> Self {
58        Self {
59            batch,
60            seq,
61            past_seq: None,
62            phase: ModelPhase::Encoder,
63            extra: Vec::new(),
64        }
65    }
66
67    pub fn with_extra(mut self, symbol: u32, size: usize) -> Self {
68        self.extra.push((symbol, size));
69        self
70    }
71
72    /// Stable cache key: phase + bound leading dims + extra symbols.
73    pub fn cache_key(&self) -> u64 {
74        let mut h = DefaultHasher::new();
75        self.phase.hash(&mut h);
76        self.batch.hash(&mut h);
77        self.seq.hash(&mut h);
78        self.past_seq.hash(&mut h);
79        for (sym, size) in &self.extra {
80            sym.hash(&mut h);
81            size.hash(&mut h);
82        }
83        h.finish()
84    }
85
86    /// Symbol bindings used by [`crate::dynamic::bind_graph`] / compile specialization.
87    pub fn dim_binding(&self) -> DimBinding {
88        let mut b = match (self.phase, self.past_seq) {
89            (ModelPhase::Decode, Some(past)) => DimBinding::batch_past_seq(self.batch, past),
90            _ => DimBinding::batch_seq(self.batch, self.seq),
91        };
92        if self.phase == ModelPhase::Decode {
93            b.set(sym::SEQ, self.seq);
94        }
95        for (sym, size) in &self.extra {
96            b.set(*sym, *size);
97        }
98        b
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// A prefill variant binds batch and sequence length from its constructor arguments.
    #[test]
    fn prefill_binding_sets_batch_seq() {
        let binding = ModelVariant::prefill(2, 128).dim_binding();
        assert_eq!(binding.get(sym::BATCH), Some(2));
        assert_eq!(binding.get(sym::SEQ), Some(128));
    }

    /// A decode variant binds batch, the past (KV) length, and the new-token count.
    #[test]
    fn decode_binding_sets_past_and_new_seq() {
        let binding = ModelVariant::decode(1, 64, 1).dim_binding();
        assert_eq!(binding.get(sym::BATCH), Some(1));
        assert_eq!(binding.get(sym::PAST_SEQ), Some(64));
        assert_eq!(binding.get(sym::SEQ), Some(1));
    }

    /// A prefill variant and a decode variant must not share a cache key.
    #[test]
    fn cache_key_differs_by_phase() {
        let prefill_key = ModelVariant::prefill(1, 8).cache_key();
        let decode_key = ModelVariant::decode(1, 7, 1).cache_key();
        assert_ne!(prefill_key, decode_key);
    }
}
129}