Skip to main content

rlx_flow/
composite.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Composite layer environments (Slang `LightPair` / `LightArray`).
5//!
6//! Describe static layer stacks for specialization keys and flow assembly.
7
8use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11
12use crate::stage::FlowStage;
13
/// Static description of a layer stack for cache keys and recipes.
///
/// Compositions nest: a [`LayerComposition::Pair`] boxes two further
/// compositions, so arbitrary stacks can be described as a tree.
///
/// The type derives `Hash` (in addition to `Eq`) so that a composition can be
/// used directly as a `HashMap` / `HashSet` key. The derived hash is a separate
/// mechanism from the `u64` fingerprint returned by `cache_key`; the two values
/// are not expected to match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LayerComposition {
    /// Single named stage.
    Single {
        /// Name of the stage.
        name: String,
    },
    /// `count` identical layers (homogeneous array).
    Homogeneous {
        /// Name shared by every layer in the array.
        layer_name: String,
        /// Number of layers; `0` describes an empty array.
        count: usize,
    },
    /// Heterogeneous head + tail (pair).
    Pair {
        /// Composition that comes first.
        head: Box<LayerComposition>,
        /// Composition that follows `head`.
        tail: Box<LayerComposition>,
    },
}
27
28impl LayerComposition {
29    pub fn single(name: impl Into<String>) -> Self {
30        Self::Single { name: name.into() }
31    }
32
33    pub fn homogeneous(layer_name: impl Into<String>, count: usize) -> Self {
34        Self::Homogeneous {
35            layer_name: layer_name.into(),
36            count,
37        }
38    }
39
40    pub fn pair(head: LayerComposition, tail: LayerComposition) -> Self {
41        Self::Pair {
42            head: Box::new(head),
43            tail: Box::new(tail),
44        }
45    }
46
47    /// Fingerprint for [`rlx_ir::ModelComponent::layer_composition_key`].
48    pub fn cache_key(&self) -> u64 {
49        let mut h = DefaultHasher::new();
50        self.hash_fragment(&mut h);
51        h.finish()
52    }
53
54    fn hash_fragment(&self, h: &mut DefaultHasher) {
55        match self {
56            Self::Single { name } => {
57                0u8.hash(h);
58                name.hash(h);
59            }
60            Self::Homogeneous { layer_name, count } => {
61                1u8.hash(h);
62                layer_name.hash(h);
63                count.hash(h);
64            }
65            Self::Pair { head, tail } => {
66                2u8.hash(h);
67                head.hash_fragment(h);
68                tail.hash_fragment(h);
69            }
70        }
71    }
72
73    /// Expand into a [`FlowStage::Sequence`] by repeating `build_layer`.
74    pub fn to_flow_stage(&self, build_layer: &dyn Fn(&str, usize) -> FlowStage) -> FlowStage {
75        match self {
76            Self::Single { name } => build_layer(name, 0),
77            Self::Homogeneous { layer_name, count } => {
78                let stages: Vec<_> = (0..*count)
79                    .map(|i| FlowStage::Named {
80                        name: format!("{layer_name}{i}"),
81                        inner: Arc::new(build_layer(layer_name, i)),
82                    })
83                    .collect();
84                FlowStage::Sequence(stages)
85            }
86            Self::Pair { head, tail } => FlowStage::Sequence(vec![
87                head.to_flow_stage(build_layer),
88                tail.to_flow_stage(build_layer),
89            ]),
90        }
91    }
92
93    pub fn depth_hint(&self) -> usize {
94        match self {
95            Self::Single { .. } => 1,
96            Self::Homogeneous { count, .. } => *count,
97            Self::Pair { head, tail } => head.depth_hint() + tail.depth_hint(),
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stage::FlowStage;

    /// Two homogeneous stacks that differ only in layer count must produce
    /// different cache keys.
    #[test]
    fn homogeneous_cache_key_scales_with_count() {
        let shallow = LayerComposition::homogeneous("layer", 8);
        let deep = LayerComposition::homogeneous("layer", 32);
        assert_ne!(shallow.cache_key(), deep.cache_key());
    }

    /// A pair of single stages expands to a sequence with exactly two children.
    #[test]
    fn pair_expands_two_stages() {
        let head = LayerComposition::single("a");
        let tail = LayerComposition::single("b");
        let composition = LayerComposition::pair(head, tail);

        // Stub builder: an empty sequence wrapped in a stage named after the layer.
        let build_stub = |name: &str, _index: usize| FlowStage::Named {
            name: name.into(),
            inner: Arc::new(FlowStage::Sequence(vec![])),
        };
        let expanded = composition.to_flow_stage(&build_stub);

        let is_two_stage_sequence =
            matches!(&expanded, FlowStage::Sequence(children) if children.len() == 2);
        assert!(is_two_stage_sequence);
    }
}