Skip to main content

rlx_ir/
component.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Unified model component — one object drives specialization, compile cache, and binding.
17//!
18//! Mirrors Slang “shader components”: the same granularity selects **what to specialize**
19//! (dims, dispatch, compilation mode) and **how host code binds** (via [`BindingManifest`]
20//! after specialize). Works across eager, lazy, and AOT pipelines while keeping HIR/MIR/LIR.
21
22use std::collections::hash_map::DefaultHasher;
23use std::hash::{Hash, Hasher};
24
25use crate::logical_kernel::KernelDispatchConfig;
26use crate::quant::QuantScheme;
27use crate::variant::ModelVariant;
28
/// Point in the host lifecycle at which the backend executable gets built.
///
/// The derived [`Default`] is [`CompilationMode::Eager`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum CompilationMode {
    /// Compile ahead of the first `run` call; the default for inference.
    #[default]
    Eager,
    /// Build a template at load time, then specialize and compile each
    /// variant the first time it is used.
    Lazy,
    /// Persist the LIR / executable on disk so it can be loaded later
    /// without running fusion again.
    Aot,
}
40
41/// Full specialization + binding bundle (Slang shader-component analogue).
42#[derive(Debug, Clone)]
43pub struct ModelComponent {
44    pub variant: ModelVariant,
45    pub kernel_dispatch: KernelDispatchConfig,
46    pub compilation_mode: CompilationMode,
47    /// Hash of tier-1 [`CompileProfile`] or arch preset (see `rlx-flow` presets).
48    pub profile_key: u64,
49    /// Optional quant scheme affecting lowers and weight layout.
50    pub quant: Option<QuantScheme>,
51    /// Composite layer-stack fingerprint (homogeneous depth, pair nesting).
52    pub layer_composition_key: u64,
53}
54
55impl ModelComponent {
56    pub fn new(variant: ModelVariant) -> Self {
57        Self {
58            variant,
59            kernel_dispatch: KernelDispatchConfig::default(),
60            compilation_mode: CompilationMode::Eager,
61            profile_key: 0,
62            quant: None,
63            layer_composition_key: 0,
64        }
65    }
66
67    pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
68        self.kernel_dispatch = config;
69        self
70    }
71
72    pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
73        self.compilation_mode = mode;
74        self
75    }
76
77    pub fn with_profile_key(mut self, key: u64) -> Self {
78        self.profile_key = key;
79        self
80    }
81
82    pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
83        self.quant = Some(scheme);
84        self
85    }
86
87    pub fn with_layer_composition_key(mut self, key: u64) -> Self {
88        self.layer_composition_key = key;
89        self
90    }
91
92    /// Stable key for compile caches (variant + dispatch + profile + composition).
93    pub fn cache_key(&self) -> u64 {
94        let mut h = DefaultHasher::new();
95        self.variant.cache_key().hash(&mut h);
96        (self.kernel_dispatch.policy as u8).hash(&mut h);
97        for k in self.kernel_dispatch.force_common_kinds.iter() {
98            k.hash(&mut h);
99        }
100        for k in self.kernel_dispatch.force_native_kinds.iter() {
101            k.hash(&mut h);
102        }
103        self.compilation_mode.hash(&mut h);
104        self.profile_key.hash(&mut h);
105        if let Some(q) = &self.quant {
106            format!("{q:?}").hash(&mut h);
107        }
108        self.layer_composition_key.hash(&mut h);
109        h.finish()
110    }
111
112    pub fn dim_binding(&self) -> crate::DimBinding {
113        self.variant.dim_binding()
114    }
115
116    /// Stable on-disk prefix for [`rlx_runtime::AotCache`] (`{base}__{binding_hash}` per variant).
117    pub fn aot_disk_base(&self) -> String {
118        format!("rlx_{:016x}", self.cache_key())
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModelVariant;
    use crate::logical_kernel::KernelDispatchPolicy;

    /// The cache key must differ from the baseline when the compilation mode
    /// changes, and when the profile key and dispatch policy change.
    #[test]
    fn cache_key_changes_with_mode_and_profile() {
        let variant = ModelVariant::prefill(1, 8);

        // Baseline: defaults for everything except the variant.
        let baseline_key = ModelComponent::new(variant.clone()).cache_key();

        // Same variant, lazy compilation.
        let lazy_component =
            ModelComponent::new(variant.clone()).with_compilation_mode(CompilationMode::Lazy);
        let lazy_key = lazy_component.cache_key();

        // Same variant, non-default profile key and dispatch policy.
        let tuned_component = ModelComponent::new(variant)
            .with_profile_key(42)
            .with_kernel_dispatch(KernelDispatchConfig::new(KernelDispatchPolicy::ForceCommon));
        let tuned_key = tuned_component.cache_key();

        assert_ne!(baseline_key, lazy_key);
        assert_ne!(baseline_key, tuned_key);
    }
}