1use std::collections::hash_map::DefaultHasher;
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use crate::stage::FlowStage;
25
/// A tree describing how a stack of layers is composed.
///
/// Leaves name layers (`Single`, `Homogeneous`); `Pair` joins two sub-compositions,
/// with `head` ordered before `tail`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LayerComposition {
    /// Exactly one layer.
    Single {
        /// Name of the layer.
        name: String,
    },
    /// `count` copies of the same layer.
    Homogeneous {
        /// Name of the repeated layer.
        layer_name: String,
        /// Number of copies.
        count: usize,
    },
    /// Two sub-compositions, `head` first and `tail` second.
    Pair {
        /// First sub-composition.
        head: Box<Self>,
        /// Second sub-composition.
        tail: Box<Self>,
    },
}
39
40impl LayerComposition {
41 pub fn single(name: impl Into<String>) -> Self {
42 Self::Single { name: name.into() }
43 }
44
45 pub fn homogeneous(layer_name: impl Into<String>, count: usize) -> Self {
46 Self::Homogeneous {
47 layer_name: layer_name.into(),
48 count,
49 }
50 }
51
52 pub fn pair(head: LayerComposition, tail: LayerComposition) -> Self {
53 Self::Pair {
54 head: Box::new(head),
55 tail: Box::new(tail),
56 }
57 }
58
59 pub fn cache_key(&self) -> u64 {
61 let mut h = DefaultHasher::new();
62 self.hash_fragment(&mut h);
63 h.finish()
64 }
65
66 fn hash_fragment(&self, h: &mut DefaultHasher) {
67 match self {
68 Self::Single { name } => {
69 0u8.hash(h);
70 name.hash(h);
71 }
72 Self::Homogeneous { layer_name, count } => {
73 1u8.hash(h);
74 layer_name.hash(h);
75 count.hash(h);
76 }
77 Self::Pair { head, tail } => {
78 2u8.hash(h);
79 head.hash_fragment(h);
80 tail.hash_fragment(h);
81 }
82 }
83 }
84
85 pub fn to_flow_stage(&self, build_layer: &dyn Fn(&str, usize) -> FlowStage) -> FlowStage {
87 match self {
88 Self::Single { name } => build_layer(name, 0),
89 Self::Homogeneous { layer_name, count } => {
90 let stages: Vec<_> = (0..*count)
91 .map(|i| FlowStage::Named {
92 name: format!("{layer_name}{i}"),
93 inner: Arc::new(build_layer(layer_name, i)),
94 })
95 .collect();
96 FlowStage::Sequence(stages)
97 }
98 Self::Pair { head, tail } => FlowStage::Sequence(vec![
99 head.to_flow_stage(build_layer),
100 tail.to_flow_stage(build_layer),
101 ]),
102 }
103 }
104
105 pub fn depth_hint(&self) -> usize {
106 match self {
107 Self::Single { .. } => 1,
108 Self::Homogeneous { count, .. } => *count,
109 Self::Pair { head, tail } => head.depth_hint() + tail.depth_hint(),
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stage::FlowStage;

    /// Layer builder that records the requested name and index in the name of the stage
    /// it returns, so tests can check which arguments `to_flow_stage` passed.
    fn tagging_builder(name: &str, index: usize) -> FlowStage {
        FlowStage::Named {
            name: format!("{name}#{index}"),
            inner: Arc::new(FlowStage::Sequence(vec![])),
        }
    }

    /// Returns the name of a `FlowStage::Named`, or `None` for any other variant.
    fn stage_name(stage: &FlowStage) -> Option<&str> {
        match stage {
            FlowStage::Named { name, .. } => Some(name.as_str()),
            _ => None,
        }
    }

    #[test]
    fn homogeneous_cache_key_scales_with_count() {
        let a = LayerComposition::homogeneous("layer", 8).cache_key();
        let b = LayerComposition::homogeneous("layer", 32).cache_key();
        assert_ne!(a, b);
    }

    #[test]
    fn cache_key_distinguishes_variant_and_order() {
        let a = LayerComposition::single("a");
        let b = LayerComposition::single("b");
        // Same name, different variant: the tag byte must keep these apart.
        assert_ne!(a.cache_key(), LayerComposition::homogeneous("a", 1).cache_key());
        // Pair order is significant.
        assert_ne!(
            LayerComposition::pair(a.clone(), b.clone()).cache_key(),
            LayerComposition::pair(b, a).cache_key()
        );
    }

    #[test]
    fn equal_compositions_share_cache_key() {
        let comp = LayerComposition::pair(
            LayerComposition::homogeneous("blk", 4),
            LayerComposition::single("out"),
        );
        assert_eq!(comp.cache_key(), comp.clone().cache_key());
    }

    #[test]
    fn pair_expands_two_stages() {
        let comp =
            LayerComposition::pair(LayerComposition::single("a"), LayerComposition::single("b"));
        let stage = comp.to_flow_stage(&|name, _| FlowStage::Named {
            name: name.into(),
            inner: Arc::new(FlowStage::Sequence(vec![])),
        });
        assert!(matches!(stage, FlowStage::Sequence(s) if s.len() == 2));
    }

    #[test]
    fn single_returns_builder_output_at_index_zero() {
        let stage = LayerComposition::single("solo").to_flow_stage(&tagging_builder);
        assert_eq!(stage_name(&stage), Some("solo#0"));
    }

    #[test]
    fn homogeneous_wraps_each_layer_with_indexed_name() {
        let stage = LayerComposition::homogeneous("blk", 3).to_flow_stage(&tagging_builder);
        let children = match stage {
            FlowStage::Sequence(children) => children,
            _ => panic!("homogeneous composition should expand to a sequence"),
        };
        assert_eq!(children.len(), 3);
        for (i, child) in children.iter().enumerate() {
            match child {
                FlowStage::Named { name, inner } => {
                    assert_eq!(*name, format!("blk{i}"));
                    let expected = format!("blk#{i}");
                    assert_eq!(stage_name(inner), Some(expected.as_str()));
                }
                _ => panic!("each repeated layer should be a named stage"),
            }
        }
    }

    #[test]
    fn zero_count_homogeneous_is_empty_sequence() {
        let stage = LayerComposition::homogeneous("none", 0).to_flow_stage(&tagging_builder);
        assert!(matches!(stage, FlowStage::Sequence(s) if s.is_empty()));
    }

    #[test]
    fn depth_hint_matches_layer_builder_calls() {
        let comp = LayerComposition::pair(
            LayerComposition::homogeneous("x", 3),
            LayerComposition::pair(
                LayerComposition::single("y"),
                LayerComposition::homogeneous("z", 0),
            ),
        );
        let calls = std::cell::Cell::new(0usize);
        let _ = comp.to_flow_stage(&|_, _| {
            calls.set(calls.get() + 1);
            FlowStage::Sequence(vec![])
        });
        assert_eq!(comp.depth_hint(), 4);
        assert_eq!(calls.get(), comp.depth_hint());
    }
}