Skip to main content

rlx_flow/
composite.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Composite layer environments (Slang `LightPair` / `LightArray`).
17//!
18//! Describe static layer stacks for specialization keys and flow assembly.
19
20use std::collections::hash_map::DefaultHasher;
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use crate::stage::FlowStage;
25
/// Static description of a layer stack for cache keys and recipes.
///
/// The description is purely structural: it records stage names, repeat
/// counts, and how sub-stacks are chained, and carries no stage
/// implementation of its own.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LayerComposition {
    /// Single named stage.
    Single {
        /// Name identifying the stage.
        name: String,
    },
    /// `count` identical layers (homogeneous array).
    Homogeneous {
        /// Name shared by every layer in the array.
        layer_name: String,
        /// Number of repetitions of the layer.
        count: usize,
    },
    /// Heterogeneous head + tail (pair).
    Pair {
        /// Sub-stack that comes first.
        head: Box<LayerComposition>,
        /// Sub-stack that follows `head`.
        tail: Box<LayerComposition>,
    },
}
39
40impl LayerComposition {
41    pub fn single(name: impl Into<String>) -> Self {
42        Self::Single { name: name.into() }
43    }
44
45    pub fn homogeneous(layer_name: impl Into<String>, count: usize) -> Self {
46        Self::Homogeneous {
47            layer_name: layer_name.into(),
48            count,
49        }
50    }
51
52    pub fn pair(head: LayerComposition, tail: LayerComposition) -> Self {
53        Self::Pair {
54            head: Box::new(head),
55            tail: Box::new(tail),
56        }
57    }
58
59    /// Fingerprint for [`rlx_ir::ModelComponent::layer_composition_key`].
60    pub fn cache_key(&self) -> u64 {
61        let mut h = DefaultHasher::new();
62        self.hash_fragment(&mut h);
63        h.finish()
64    }
65
66    fn hash_fragment(&self, h: &mut DefaultHasher) {
67        match self {
68            Self::Single { name } => {
69                0u8.hash(h);
70                name.hash(h);
71            }
72            Self::Homogeneous { layer_name, count } => {
73                1u8.hash(h);
74                layer_name.hash(h);
75                count.hash(h);
76            }
77            Self::Pair { head, tail } => {
78                2u8.hash(h);
79                head.hash_fragment(h);
80                tail.hash_fragment(h);
81            }
82        }
83    }
84
85    /// Expand into a [`FlowStage::Sequence`] by repeating `build_layer`.
86    pub fn to_flow_stage(&self, build_layer: &dyn Fn(&str, usize) -> FlowStage) -> FlowStage {
87        match self {
88            Self::Single { name } => build_layer(name, 0),
89            Self::Homogeneous { layer_name, count } => {
90                let stages: Vec<_> = (0..*count)
91                    .map(|i| FlowStage::Named {
92                        name: format!("{layer_name}{i}"),
93                        inner: Arc::new(build_layer(layer_name, i)),
94                    })
95                    .collect();
96                FlowStage::Sequence(stages)
97            }
98            Self::Pair { head, tail } => FlowStage::Sequence(vec![
99                head.to_flow_stage(build_layer),
100                tail.to_flow_stage(build_layer),
101            ]),
102        }
103    }
104
105    pub fn depth_hint(&self) -> usize {
106        match self {
107            Self::Single { .. } => 1,
108            Self::Homogeneous { count, .. } => *count,
109            Self::Pair { head, tail } => head.depth_hint() + tail.depth_hint(),
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stage::FlowStage;

    /// Two homogeneous stacks that differ only in layer count must not share
    /// a fingerprint.
    #[test]
    fn homogeneous_cache_key_scales_with_count() {
        let key_for = |count| LayerComposition::homogeneous("layer", count).cache_key();
        assert_ne!(key_for(8), key_for(32));
    }

    /// A pair of two single stages expands to a sequence with two entries.
    #[test]
    fn pair_expands_two_stages() {
        // Builder that wraps every layer in a named, empty sequence.
        let leaf = |name: &str, _index: usize| FlowStage::Named {
            name: name.into(),
            inner: Arc::new(FlowStage::Sequence(vec![])),
        };
        let head = LayerComposition::single("a");
        let tail = LayerComposition::single("b");

        let expanded = LayerComposition::pair(head, tail).to_flow_stage(&leaf);

        assert!(matches!(expanded, FlowStage::Sequence(children) if children.len() == 2));
    }
}
136}