1use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11
12use crate::stage::FlowStage;
13
/// A recursive description of a stack of layers.
///
/// Compositions nest: [`LayerComposition::Pair`] joins two arbitrary
/// sub-compositions head-to-tail, so arbitrarily deep stacks can be described.
///
/// `Hash` is derived alongside `Eq` so a composition can be used directly as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LayerComposition {
    /// A single layer identified by `name`.
    Single { name: String },
    /// `count` repetitions of the layer named `layer_name`.
    Homogeneous { layer_name: String, count: usize },
    /// The composition `head` followed by the composition `tail`.
    Pair {
        head: Box<LayerComposition>,
        tail: Box<LayerComposition>,
    },
}
27
28impl LayerComposition {
29 pub fn single(name: impl Into<String>) -> Self {
30 Self::Single { name: name.into() }
31 }
32
33 pub fn homogeneous(layer_name: impl Into<String>, count: usize) -> Self {
34 Self::Homogeneous {
35 layer_name: layer_name.into(),
36 count,
37 }
38 }
39
40 pub fn pair(head: LayerComposition, tail: LayerComposition) -> Self {
41 Self::Pair {
42 head: Box::new(head),
43 tail: Box::new(tail),
44 }
45 }
46
47 pub fn cache_key(&self) -> u64 {
49 let mut h = DefaultHasher::new();
50 self.hash_fragment(&mut h);
51 h.finish()
52 }
53
54 fn hash_fragment(&self, h: &mut DefaultHasher) {
55 match self {
56 Self::Single { name } => {
57 0u8.hash(h);
58 name.hash(h);
59 }
60 Self::Homogeneous { layer_name, count } => {
61 1u8.hash(h);
62 layer_name.hash(h);
63 count.hash(h);
64 }
65 Self::Pair { head, tail } => {
66 2u8.hash(h);
67 head.hash_fragment(h);
68 tail.hash_fragment(h);
69 }
70 }
71 }
72
73 pub fn to_flow_stage(&self, build_layer: &dyn Fn(&str, usize) -> FlowStage) -> FlowStage {
75 match self {
76 Self::Single { name } => build_layer(name, 0),
77 Self::Homogeneous { layer_name, count } => {
78 let stages: Vec<_> = (0..*count)
79 .map(|i| FlowStage::Named {
80 name: format!("{layer_name}{i}"),
81 inner: Arc::new(build_layer(layer_name, i)),
82 })
83 .collect();
84 FlowStage::Sequence(stages)
85 }
86 Self::Pair { head, tail } => FlowStage::Sequence(vec![
87 head.to_flow_stage(build_layer),
88 tail.to_flow_stage(build_layer),
89 ]),
90 }
91 }
92
93 pub fn depth_hint(&self) -> usize {
94 match self {
95 Self::Single { .. } => 1,
96 Self::Homogeneous { count, .. } => *count,
97 Self::Pair { head, tail } => head.depth_hint() + tail.depth_hint(),
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    // The glob import also brings in the parent's `FlowStage` and `Arc`
    // imports, so no separate `use` lines are needed for them.
    use super::*;

    #[test]
    fn homogeneous_cache_key_depends_on_count() {
        let a = LayerComposition::homogeneous("layer", 8).cache_key();
        let b = LayerComposition::homogeneous("layer", 32).cache_key();
        assert_ne!(a, b);
    }

    #[test]
    fn cache_key_distinguishes_variants_and_order() {
        // Same name, different variant: the tag byte keeps these apart.
        let single = LayerComposition::single("a").cache_key();
        let one = LayerComposition::homogeneous("a", 1).cache_key();
        assert_ne!(single, one);

        // Pair hashes head before tail, so swapping them changes the key.
        let ab =
            LayerComposition::pair(LayerComposition::single("a"), LayerComposition::single("b"));
        let ba =
            LayerComposition::pair(LayerComposition::single("b"), LayerComposition::single("a"));
        assert_ne!(ab.cache_key(), ba.cache_key());

        // Equal values must produce equal keys.
        assert_eq!(ab.cache_key(), ab.clone().cache_key());
    }

    #[test]
    fn pair_expands_two_stages() {
        let comp =
            LayerComposition::pair(LayerComposition::single("a"), LayerComposition::single("b"));
        let stage = comp.to_flow_stage(&|name, _| FlowStage::Named {
            name: name.into(),
            inner: Arc::new(FlowStage::Sequence(vec![])),
        });
        assert!(matches!(stage, FlowStage::Sequence(s) if s.len() == 2));
    }

    #[test]
    fn homogeneous_names_each_stage_with_its_index() {
        let comp = LayerComposition::homogeneous("blk", 3);
        let stage = comp.to_flow_stage(&|_, _| FlowStage::Sequence(vec![]));
        let names: Vec<String> = match &stage {
            FlowStage::Sequence(stages) => stages
                .iter()
                .map(|s| match s {
                    FlowStage::Named { name, .. } => name.clone(),
                    _ => panic!("homogeneous layers should be wrapped in `Named`"),
                })
                .collect(),
            _ => panic!("homogeneous composition should expand to a `Sequence`"),
        };
        assert_eq!(names, ["blk0", "blk1", "blk2"]);
    }

    #[test]
    fn depth_hint_counts_expanded_layers() {
        let comp = LayerComposition::pair(
            LayerComposition::single("stem"),
            LayerComposition::homogeneous("block", 4),
        );
        assert_eq!(comp.depth_hint(), 5);
        assert_eq!(LayerComposition::homogeneous("block", 0).depth_hint(), 0);
    }
}