// lean_ctx/core/pipeline.rs

1use std::collections::HashMap;
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
4pub enum LayerKind {
5    Input,
6    Intent,
7    Relevance,
8    Compression,
9    Translation,
10    Delivery,
11}
12
13impl LayerKind {
14    pub fn as_str(&self) -> &'static str {
15        match self {
16            Self::Input => "input",
17            Self::Intent => "intent",
18            Self::Relevance => "relevance",
19            Self::Compression => "compression",
20            Self::Translation => "translation",
21            Self::Delivery => "delivery",
22        }
23    }
24
25    pub fn all() -> &'static [LayerKind] {
26        &[
27            Self::Input,
28            Self::Intent,
29            Self::Relevance,
30            Self::Compression,
31            Self::Translation,
32            Self::Delivery,
33        ]
34    }
35}
36
37impl std::fmt::Display for LayerKind {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "{}", self.as_str())
40    }
41}
42
/// The value handed to a layer's `process` method.
#[derive(Debug, Clone)]
pub struct LayerInput {
    /// Text payload to be processed.
    pub content: String,
    /// Token count reported for `content`; supplied by the producer and
    /// carried along as-is.
    pub tokens: usize,
    /// Free-form string key/value annotations that travel with the content.
    pub metadata: HashMap<String, String>,
}
49
/// The value a layer's `process` method returns.
#[derive(Debug, Clone)]
pub struct LayerOutput {
    /// Text payload produced by the layer.
    pub content: String,
    /// Token count the layer reports for `content`.
    pub tokens: usize,
    /// Free-form string key/value annotations that travel with the content.
    pub metadata: HashMap<String, String>,
}
56
57#[derive(Debug, Clone)]
58pub struct LayerMetrics {
59    pub layer: LayerKind,
60    pub input_tokens: usize,
61    pub output_tokens: usize,
62    pub duration_us: u64,
63    pub compression_ratio: f64,
64}
65
66impl LayerMetrics {
67    pub fn new(
68        layer: LayerKind,
69        input_tokens: usize,
70        output_tokens: usize,
71        duration_us: u64,
72    ) -> Self {
73        let ratio = if input_tokens == 0 {
74            1.0
75        } else {
76            output_tokens as f64 / input_tokens as f64
77        };
78        Self {
79            layer,
80            input_tokens,
81            output_tokens,
82            duration_us,
83            compression_ratio: ratio,
84        }
85    }
86}
87
88pub trait Layer {
89    fn kind(&self) -> LayerKind;
90    fn process(&self, input: LayerInput) -> LayerOutput;
91}
92
93pub struct Pipeline {
94    layers: Vec<Box<dyn Layer>>,
95}
96
97impl Pipeline {
98    pub fn new() -> Self {
99        Self { layers: Vec::new() }
100    }
101
102    pub fn add_layer(mut self, layer: Box<dyn Layer>) -> Self {
103        self.layers.push(layer);
104        self
105    }
106
107    pub fn execute(&self, input: LayerInput) -> (LayerOutput, Vec<LayerMetrics>) {
108        let mut current = input;
109        let mut metrics = Vec::new();
110
111        for layer in &self.layers {
112            let start = std::time::Instant::now();
113            let input_tokens = current.tokens;
114            let output = layer.process(current);
115            let duration = start.elapsed().as_micros() as u64;
116
117            metrics.push(LayerMetrics::new(
118                layer.kind(),
119                input_tokens,
120                output.tokens,
121                duration,
122            ));
123
124            current = LayerInput {
125                content: output.content,
126                tokens: output.tokens,
127                metadata: output.metadata,
128            };
129        }
130
131        let final_output = LayerOutput {
132            content: current.content,
133            tokens: current.tokens,
134            metadata: current.metadata,
135        };
136
137        (final_output, metrics)
138    }
139
140    pub fn format_metrics(metrics: &[LayerMetrics]) -> String {
141        let mut out = String::from("Pipeline Metrics:\n");
142        let mut total_saved = 0usize;
143        for m in metrics {
144            let saved = m.input_tokens.saturating_sub(m.output_tokens);
145            total_saved += saved;
146            out.push_str(&format!(
147                "  {} : {} -> {} tok ({:.0}%, {:.1}ms)\n",
148                m.layer,
149                m.input_tokens,
150                m.output_tokens,
151                m.compression_ratio * 100.0,
152                m.duration_us as f64 / 1000.0,
153            ));
154        }
155        if let (Some(first), Some(last)) = (metrics.first(), metrics.last()) {
156            let total_ratio = if first.input_tokens == 0 {
157                1.0
158            } else {
159                last.output_tokens as f64 / first.input_tokens as f64
160            };
161            out.push_str(&format!(
162                "  TOTAL: {} -> {} tok ({:.0}%, saved {})\n",
163                first.input_tokens,
164                last.output_tokens,
165                total_ratio * 100.0,
166                total_saved,
167            ));
168        }
169        out
170    }
171}
172
173#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
174pub struct PipelineStats {
175    pub runs: usize,
176    pub per_layer: HashMap<LayerKind, AggregatedMetrics>,
177}
178
179#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
180pub struct AggregatedMetrics {
181    pub total_input_tokens: usize,
182    pub total_output_tokens: usize,
183    pub total_duration_us: u64,
184    pub count: usize,
185}
186
187impl AggregatedMetrics {
188    pub fn avg_ratio(&self) -> f64 {
189        if self.total_input_tokens == 0 {
190            return 1.0;
191        }
192        self.total_output_tokens as f64 / self.total_input_tokens as f64
193    }
194
195    pub fn avg_duration_ms(&self) -> f64 {
196        if self.count == 0 {
197            return 0.0;
198        }
199        self.total_duration_us as f64 / self.count as f64 / 1000.0
200    }
201}
202
203impl PipelineStats {
204    pub fn new() -> Self {
205        Self {
206            runs: 0,
207            per_layer: HashMap::new(),
208        }
209    }
210
211    pub fn record(&mut self, metrics: &[LayerMetrics]) {
212        self.runs += 1;
213        for m in metrics {
214            let agg = self.per_layer.entry(m.layer).or_default();
215            agg.total_input_tokens += m.input_tokens;
216            agg.total_output_tokens += m.output_tokens;
217            agg.total_duration_us += m.duration_us;
218            agg.count += 1;
219        }
220    }
221
222    pub fn record_single(
223        &mut self,
224        layer: LayerKind,
225        input_tokens: usize,
226        output_tokens: usize,
227        duration: std::time::Duration,
228    ) {
229        self.runs += 1;
230        let agg = self.per_layer.entry(layer).or_default();
231        agg.total_input_tokens += input_tokens;
232        agg.total_output_tokens += output_tokens;
233        agg.total_duration_us += duration.as_micros() as u64;
234        agg.count += 1;
235    }
236
237    pub fn total_tokens_saved(&self) -> usize {
238        self.per_layer
239            .values()
240            .map(|a| a.total_input_tokens.saturating_sub(a.total_output_tokens))
241            .sum()
242    }
243
244    pub fn save(&self) {
245        if let Ok(dir) = crate::core::data_dir::lean_ctx_data_dir() {
246            let path = dir.join("pipeline_stats.json");
247            if let Ok(json) = serde_json::to_string(self) {
248                let _ = std::fs::write(path, json);
249            }
250        }
251    }
252
253    pub fn load() -> Self {
254        crate::core::data_dir::lean_ctx_data_dir()
255            .ok()
256            .map(|d| d.join("pipeline_stats.json"))
257            .and_then(|p| std::fs::read_to_string(p).ok())
258            .and_then(|s| serde_json::from_str(&s).ok())
259            .unwrap_or_default()
260    }
261
262    pub fn format_summary(&self) -> String {
263        let mut out = format!("Pipeline Stats ({} runs):\n", self.runs);
264        for kind in LayerKind::all() {
265            if let Some(agg) = self.per_layer.get(kind) {
266                out.push_str(&format!(
267                    "  {}: avg {:.0}% ratio, {:.1}ms, {} invocations\n",
268                    kind,
269                    agg.avg_ratio() * 100.0,
270                    agg.avg_duration_ms(),
271                    agg.count,
272                ));
273            }
274        }
275        out.push_str(&format!("  SAVED: {} tokens\n", self.total_tokens_saved()));
276        out
277    }
278}
279
280impl Default for Pipeline {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `LayerInput` with the given content and token count and no
    /// metadata.
    fn input_with(content: String, tokens: usize) -> LayerInput {
        LayerInput {
            content,
            tokens,
            metadata: HashMap::new(),
        }
    }

    /// Metrics for a three-layer run in which only the compression layer
    /// shrinks the token count (1000 -> 300).
    fn three_layer_metrics() -> Vec<LayerMetrics> {
        vec![
            LayerMetrics::new(LayerKind::Input, 1000, 1000, 100),
            LayerMetrics::new(LayerKind::Compression, 1000, 300, 5000),
            LayerMetrics::new(LayerKind::Delivery, 300, 300, 50),
        ]
    }

    /// Test layer that forwards its input untouched under a configurable kind.
    struct PassthroughLayer {
        kind: LayerKind,
    }

    impl Layer for PassthroughLayer {
        fn kind(&self) -> LayerKind {
            self.kind
        }

        fn process(&self, input: LayerInput) -> LayerOutput {
            let LayerInput {
                content,
                tokens,
                metadata,
            } = input;
            LayerOutput {
                content,
                tokens,
                metadata,
            }
        }
    }

    /// Test layer that scales the token count by `ratio` and cuts the content
    /// down to at most four bytes per remaining token.
    struct CompressionLayer {
        ratio: f64,
    }

    impl Layer for CompressionLayer {
        fn kind(&self) -> LayerKind {
            LayerKind::Compression
        }

        fn process(&self, input: LayerInput) -> LayerOutput {
            let LayerInput {
                mut content,
                tokens,
                metadata,
            } = input;
            let new_tokens = (tokens as f64 * self.ratio) as usize;
            let byte_budget = new_tokens * 4;
            if content.len() > byte_budget {
                content.truncate(byte_budget);
            }
            LayerOutput {
                content,
                tokens: new_tokens,
                metadata,
            }
        }
    }

    #[test]
    fn layer_kind_all_ordered() {
        let kinds = LayerKind::all();
        assert_eq!(kinds.len(), 6);
        assert_eq!(kinds[0], LayerKind::Input);
        assert_eq!(kinds[5], LayerKind::Delivery);
    }

    #[test]
    fn passthrough_preserves_content() {
        let layer = PassthroughLayer {
            kind: LayerKind::Input,
        };
        let output = layer.process(input_with("hello world".to_string(), 2));
        assert_eq!(output.content, "hello world");
        assert_eq!(output.tokens, 2);
    }

    #[test]
    fn compression_layer_reduces() {
        let layer = CompressionLayer { ratio: 0.5 };
        let output = layer.process(input_with("a ".repeat(100), 100));
        assert_eq!(output.tokens, 50);
    }

    #[test]
    fn pipeline_chains_layers() {
        let pipeline = Pipeline::new()
            .add_layer(Box::new(PassthroughLayer {
                kind: LayerKind::Input,
            }))
            .add_layer(Box::new(CompressionLayer { ratio: 0.5 }))
            .add_layer(Box::new(PassthroughLayer {
                kind: LayerKind::Delivery,
            }));

        let (output, metrics) = pipeline.execute(input_with("a ".repeat(100), 100));
        assert_eq!(output.tokens, 50);
        assert_eq!(metrics.len(), 3);

        let expected_order = [
            LayerKind::Input,
            LayerKind::Compression,
            LayerKind::Delivery,
        ];
        for (metric, expected) in metrics.iter().zip(expected_order.iter()) {
            assert_eq!(metric.layer, *expected);
        }
    }

    #[test]
    fn metrics_new_calculates_ratio() {
        let m = LayerMetrics::new(LayerKind::Compression, 100, 50, 1000);
        assert!((m.compression_ratio - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn metrics_format_readable() {
        let formatted = Pipeline::format_metrics(&three_layer_metrics());
        for needle in ["input", "compression", "delivery", "TOTAL"] {
            assert!(formatted.contains(needle));
        }
    }

    #[test]
    fn empty_pipeline_passes_through() {
        let pipeline = Pipeline::new();
        let (output, metrics) = pipeline.execute(input_with("test".to_string(), 1));
        assert_eq!(output.content, "test");
        assert!(metrics.is_empty());
    }

    #[test]
    fn pipeline_stats_record_and_summarize() {
        let metrics = three_layer_metrics();
        let mut stats = PipelineStats::default();
        stats.record(&metrics);
        stats.record(&metrics);

        assert_eq!(stats.runs, 2);
        assert_eq!(stats.total_tokens_saved(), 1400);

        let agg = stats.per_layer.get(&LayerKind::Compression).unwrap();
        assert_eq!(agg.count, 2);
        assert_eq!(agg.total_input_tokens, 2000);
        assert_eq!(agg.total_output_tokens, 600);

        let summary = stats.format_summary();
        assert!(summary.contains("2 runs"));
        assert!(summary.contains("SAVED: 1400"));
    }

    #[test]
    fn aggregated_metrics_avg() {
        let agg = AggregatedMetrics {
            total_input_tokens: 1000,
            total_output_tokens: 500,
            total_duration_us: 10000,
            count: 2,
        };
        assert!((agg.avg_ratio() - 0.5).abs() < f64::EPSILON);
        assert!((agg.avg_duration_ms() - 5.0).abs() < f64::EPSILON);
    }
}
459}