assay_core/trace/
upgrader.rs

1use crate::trace::schema::{
2    EpisodeEnd, EpisodeStart, StepEntry, TraceEntry, TraceEntryV1, TraceEvent,
3};
4use std::io::BufRead;
5
6pub struct StreamUpgrader<R> {
7    reader: R,
8    current_line_events: std::vec::IntoIter<TraceEvent>,
9}
10
11impl<R: BufRead> StreamUpgrader<R> {
12    pub fn new(reader: R) -> Self {
13        Self {
14            reader,
15            current_line_events: vec![].into_iter(),
16        }
17    }
18}
19
20impl<R: BufRead> Iterator for StreamUpgrader<R> {
21    type Item = serde_json::Result<TraceEvent>;
22
23    fn next(&mut self) -> Option<Self::Item> {
24        // If we have buffered events from a V1 upgrade, verify/return them
25        if let Some(event) = self.current_line_events.next() {
26            return Some(Ok(event));
27        }
28
29        // Read next line
30        let mut line = String::new();
31        match self.reader.read_line(&mut line) {
32            Ok(0) => return None, // EOF
33            Ok(_) => {}
34            Err(_) => return None, // Or handle error? Iterator usually expects Option<T>
35        }
36
37        let line = line.trim();
38        if line.is_empty() {
39            return self.next();
40        }
41
42        match serde_json::from_str::<TraceEntry>(line) {
43            Ok(TraceEntry::V2(mut event)) => {
44                apply_truncation(&mut event);
45                Some(Ok(event))
46            }
47            Ok(TraceEntry::V1(v1)) => {
48                let mut events = upgrade_v1_to_v2(v1);
49                for e in &mut events {
50                    apply_truncation(e);
51                }
52                self.current_line_events = events.into_iter();
53                self.next()
54            }
55            Err(e) => Some(Err(e)),
56        }
57    }
58}
59
60fn apply_truncation(event: &mut TraceEvent) {
61    use super::truncation::{
62        compute_sha256, compute_sha256_str, truncate_string, truncate_value_with_provenance,
63    };
64    match event {
65        TraceEvent::EpisodeStart(e) => {
66            truncate_value_with_provenance(&mut e.input, "input");
67            truncate_value_with_provenance(&mut e.meta, "meta");
68        }
69        TraceEvent::Step(e) => {
70            if let Some(c) = &mut e.content {
71                // Compute hash before truncation
72                e.content_sha256 = Some(compute_sha256_str(c));
73                if let Some(meta) = truncate_string(c, "content") {
74                    e.truncations.push(meta);
75                }
76            }
77            e.truncations
78                .extend(truncate_value_with_provenance(&mut e.meta, "meta"));
79        }
80        TraceEvent::ToolCall(e) => {
81            // Compute hashes
82            e.args_sha256 = Some(compute_sha256(&e.args));
83            if let Some(res) = &e.result {
84                e.result_sha256 = Some(compute_sha256(res));
85            }
86
87            e.truncations
88                .extend(truncate_value_with_provenance(&mut e.args, "args"));
89
90            if let Some(mut result_val) = e.result.take() {
91                e.truncations
92                    .extend(truncate_value_with_provenance(&mut result_val, "result"));
93                e.result = Some(result_val);
94            }
95        }
96        TraceEvent::EpisodeEnd(_) => {}
97    }
98}
99
100fn upgrade_v1_to_v2(v1: TraceEntryV1) -> Vec<TraceEvent> {
101    let ts = 0;
102    // Ideally extract from meta if possible, but keep deterministic.
103
104    let episode_id = v1.request_id.clone();
105
106    let start = TraceEvent::EpisodeStart(EpisodeStart {
107        episode_id: episode_id.clone(),
108        timestamp: ts,
109        input: serde_json::json!({ "prompt": v1.prompt }),
110        meta: v1.meta.clone(),
111    });
112
113    let step = TraceEvent::Step(StepEntry {
114        episode_id: episode_id.clone(),
115        step_id: format!("{}-step-0", episode_id),
116        idx: 0,
117        timestamp: ts + 1,
118        kind: "llm_completion".to_string(),
119        name: Some("model".to_string()),
120        content: Some(v1.response),
121        meta: serde_json::Value::Null,
122        content_sha256: None, // Filled later
123        truncations: Vec::new(),
124    });
125
126    let end = TraceEvent::EpisodeEnd(EpisodeEnd {
127        episode_id,
128        timestamp: ts + 2,
129        outcome: Some("pass".to_string()), // V1 usually implies successful run?
130        final_output: None,
131    });
132
133    vec![start, step, end]
134}