1use crate::task::{CrackleTask, TaskOutput};
2
3pub struct GlazeLayer<T: CrackleTask> {
35 inner: T,
36 derived_metrics: Vec<DerivedMetric<T::Output>>,
37 label_override: Option<String>,
38}
39
40type DerivedMetric<T> = (String, Box<dyn Fn(&TaskOutput<T>) -> f64>);
41
42impl<T: CrackleTask> GlazeLayer<T> {
43 pub fn new(task: T) -> Self {
45 GlazeLayer {
46 inner: task,
47 derived_metrics: Vec::new(),
48 label_override: None,
49 }
50 }
51
52 pub fn with_derived_metric(
57 mut self,
58 name: impl Into<String>,
59 compute: impl Fn(&TaskOutput<T::Output>) -> f64 + 'static,
60 ) -> Self {
61 self.derived_metrics
62 .push((name.into(), Box::new(compute)));
63 self
64 }
65
66 pub fn with_label(mut self, label: impl Into<String>) -> Self {
68 self.label_override = Some(label.into());
69 self
70 }
71}
72
73impl<T: CrackleTask> CrackleTask for GlazeLayer<T> {
74 type Output = T::Output;
75
76 fn fire(&self) -> TaskOutput<Self::Output> {
77 let mut output = self.inner.fire();
78 for (name, compute) in &self.derived_metrics {
79 let value = compute(&output);
80 output.metrics.push((name.clone(), value));
81 }
82 output
83 }
84
85 fn cool(
86 &self,
87 output: &TaskOutput<Self::Output>,
88 all_metrics: &[(String, Vec<(String, f64)>)],
89 ) -> Vec<(String, f64)> {
90 let mut results = self.inner.cool(output, all_metrics);
91 for (name, compute) in &self.derived_metrics {
93 let value = compute(output);
94 results.push((name.clone(), value));
95 }
96 results
97 }
98
99 fn label(&self) -> String {
100 self.label_override
101 .clone()
102 .unwrap_or_else(|| self.inner.label())
103 }
104}
105
106#[allow(dead_code)]
110pub struct GlazeBatch<T: CrackleTask> {
111 tasks: Vec<GlazeLayer<T>>,
112}
113
114impl<T: CrackleTask> GlazeBatch<T> {
115 #[allow(dead_code)]
116 pub fn new() -> Self {
118 GlazeBatch { tasks: Vec::new() }
119 }
120
121 #[allow(dead_code)]
123 pub fn add(mut self, task: T) -> Self {
124 self.tasks.push(GlazeLayer::new(task));
125 self
126 }
127
128 #[allow(dead_code)]
129 pub fn add_glazed(mut self, glazed: GlazeLayer<T>) -> Self {
130 self.tasks.push(glazed);
131 self
132 }
133
134 #[allow(dead_code)]
135 pub fn tasks(&self) -> &[GlazeLayer<T>] {
136 &self.tasks
137 }
138
139 #[allow(dead_code)]
140 pub fn into_tasks(self) -> Vec<GlazeLayer<T>> {
141 self.tasks
142 }
143}
144
145impl<T: CrackleTask> Default for GlazeBatch<T> {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal task whose output value is a fixed number.
    struct NumTask {
        v: f64,
    }

    impl CrackleTask for NumTask {
        type Output = f64;
        fn fire(&self) -> TaskOutput<Self::Output> {
            TaskOutput::simple(self.v)
        }
    }

    /// Looks up a metric by name, panicking with a readable message if it is absent.
    fn metric_value(output: &TaskOutput<f64>, name: &str) -> f64 {
        output
            .metrics
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, v)| *v)
            .unwrap_or_else(|| panic!("missing metric `{}`", name))
    }

    #[test]
    fn glaze_adds_derived_metrics() {
        let glazed = GlazeLayer::new(NumTask { v: 4.0 })
            .with_derived_metric("squared", |o| o.value * o.value);
        let output = glazed.fire();
        assert_eq!(output.value, 4.0);
        assert!((metric_value(&output, "squared") - 16.0).abs() < 0.001);
    }

    #[test]
    fn derived_metrics_run_in_registration_order() {
        // The second metric reads the first one, which only works if `fire` appends
        // them one at a time in registration order.
        let glazed = GlazeLayer::new(NumTask { v: 4.0 })
            .with_derived_metric("doubled", |o| o.value * 2.0)
            .with_derived_metric("doubled_plus_one", |o| {
                o.metrics
                    .iter()
                    .find(|(n, _)| n == "doubled")
                    .map_or(-1.0, |(_, v)| v + 1.0)
            });
        let output = glazed.fire();
        let count = output.metrics.len();
        assert!(count >= 2);
        assert_eq!(output.metrics[count - 2].0, "doubled");
        assert_eq!(output.metrics[count - 1].0, "doubled_plus_one");
        assert!((metric_value(&output, "doubled_plus_one") - 9.0).abs() < 0.001);
    }

    #[test]
    fn glaze_label_override() {
        let glazed = GlazeLayer::new(NumTask { v: 1.0 }).with_label("special");
        assert_eq!(glazed.label(), "special");
    }

    #[test]
    fn glaze_label_falls_back_to_inner() {
        let bare = NumTask { v: 1.0 };
        let expected = bare.label();
        assert_eq!(GlazeLayer::new(bare).label(), expected);
    }

    #[test]
    fn glaze_batch_builder() {
        let batch = GlazeBatch::new()
            .add(NumTask { v: 1.0 })
            .add(NumTask { v: 2.0 });
        assert_eq!(batch.tasks().len(), 2);
    }

    #[test]
    fn glaze_batch_keeps_insertion_order() {
        let batch = GlazeBatch::new()
            .add(NumTask { v: 1.0 })
            .add_glazed(GlazeLayer::new(NumTask { v: 2.0 }).with_label("second"))
            .add(NumTask { v: 3.0 });
        assert_eq!(batch.tasks().len(), 3);
        assert_eq!(batch.tasks()[1].label(), "second");
        let values: Vec<f64> = batch.into_tasks().iter().map(|t| t.fire().value).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }
}