Skip to main content

rlx_ir/
component.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Unified model component — one object drives specialization, compile cache, and binding.
5//!
6//! Mirrors Slang “shader components”: the same granularity selects **what to specialize**
7//! (dims, dispatch, compilation mode) and **how host code binds** (via [`BindingManifest`]
8//! after specialize). Works across eager, lazy, and AOT pipelines while keeping HIR/MIR/LIR.
9
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
13use crate::logical_kernel::KernelDispatchConfig;
14use crate::quant::QuantScheme;
15use crate::variant::ModelVariant;
16
/// Controls at which point, relative to the host loop, the backend executable is built.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum CompilationMode {
    /// Compile ahead of the first `run` call; this is the default for inference.
    #[default]
    Eager,
    /// Build the template at load time, then specialize/compile lazily the first time each
    /// variant is used.
    Lazy,
    /// Serialize the LIR / executable to disk so a later load does not re-run fusion.
    Aot,
}
28
29/// Full specialization + binding bundle (Slang shader-component analogue).
30#[derive(Debug, Clone)]
31pub struct ModelComponent {
32    pub variant: ModelVariant,
33    pub kernel_dispatch: KernelDispatchConfig,
34    pub compilation_mode: CompilationMode,
35    /// Hash of tier-1 [`CompileProfile`] or arch preset (see `rlx-flow` presets).
36    pub profile_key: u64,
37    /// Optional quant scheme affecting lowers and weight layout.
38    pub quant: Option<QuantScheme>,
39    /// Composite layer-stack fingerprint (homogeneous depth, pair nesting).
40    pub layer_composition_key: u64,
41}
42
43impl ModelComponent {
44    pub fn new(variant: ModelVariant) -> Self {
45        Self {
46            variant,
47            kernel_dispatch: KernelDispatchConfig::default(),
48            compilation_mode: CompilationMode::Eager,
49            profile_key: 0,
50            quant: None,
51            layer_composition_key: 0,
52        }
53    }
54
55    pub fn with_kernel_dispatch(mut self, config: KernelDispatchConfig) -> Self {
56        self.kernel_dispatch = config;
57        self
58    }
59
60    pub fn with_compilation_mode(mut self, mode: CompilationMode) -> Self {
61        self.compilation_mode = mode;
62        self
63    }
64
65    pub fn with_profile_key(mut self, key: u64) -> Self {
66        self.profile_key = key;
67        self
68    }
69
70    pub fn with_quant(mut self, scheme: QuantScheme) -> Self {
71        self.quant = Some(scheme);
72        self
73    }
74
75    pub fn with_layer_composition_key(mut self, key: u64) -> Self {
76        self.layer_composition_key = key;
77        self
78    }
79
80    /// Stable key for compile caches (variant + dispatch + profile + composition).
81    pub fn cache_key(&self) -> u64 {
82        let mut h = DefaultHasher::new();
83        self.variant.cache_key().hash(&mut h);
84        (self.kernel_dispatch.policy as u8).hash(&mut h);
85        for k in self.kernel_dispatch.force_common_kinds.iter() {
86            k.hash(&mut h);
87        }
88        for k in self.kernel_dispatch.force_native_kinds.iter() {
89            k.hash(&mut h);
90        }
91        self.compilation_mode.hash(&mut h);
92        self.profile_key.hash(&mut h);
93        if let Some(q) = &self.quant {
94            format!("{q:?}").hash(&mut h);
95        }
96        self.layer_composition_key.hash(&mut h);
97        h.finish()
98    }
99
100    pub fn dim_binding(&self) -> crate::DimBinding {
101        self.variant.dim_binding()
102    }
103
104    /// Stable on-disk prefix for [`rlx_runtime::AotCache`] (`{base}__{binding_hash}` per variant).
105    pub fn aot_disk_base(&self) -> String {
106        format!("rlx_{:016x}", self.cache_key())
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModelVariant;
    use crate::logical_kernel::KernelDispatchPolicy;

    /// Identical components must map to the same cache entry.
    #[test]
    fn cache_key_is_deterministic() {
        let v = ModelVariant::prefill(1, 8);
        assert_eq!(
            ModelComponent::new(v.clone()).cache_key(),
            ModelComponent::new(v).cache_key()
        );
    }

    /// Each of mode and profile must affect the key on its own, not only in combination.
    #[test]
    fn cache_key_changes_with_mode_and_profile() {
        let v = ModelVariant::prefill(1, 8);
        let base = ModelComponent::new(v.clone()).cache_key();
        let lazy = ModelComponent::new(v.clone())
            .with_compilation_mode(CompilationMode::Lazy)
            .cache_key();
        let profiled = ModelComponent::new(v.clone()).with_profile_key(42).cache_key();
        let profiled_dispatch = ModelComponent::new(v)
            .with_profile_key(42)
            .with_kernel_dispatch(KernelDispatchConfig::new(KernelDispatchPolicy::ForceCommon))
            .cache_key();
        assert_ne!(base, lazy);
        assert_ne!(base, profiled);
        assert_ne!(base, profiled_dispatch);
    }

    #[test]
    fn cache_key_changes_with_layer_composition() {
        let v = ModelVariant::prefill(1, 8);
        let base = ModelComponent::new(v.clone()).cache_key();
        let composed = ModelComponent::new(v).with_layer_composition_key(7).cache_key();
        assert_ne!(base, composed);
    }

    #[test]
    fn aot_disk_base_embeds_cache_key() {
        let c = ModelComponent::new(ModelVariant::prefill(1, 8));
        let base = c.aot_disk_base();
        assert_eq!(base, format!("rlx_{:016x}", c.cache_key()));
        assert_eq!(base.len(), "rlx_".len() + 16);
    }
}
130}