//! Trace-replay LLM provider (`assay_core/providers/trace.rs`).

use crate::errors::{diagnostic::codes, similarity::closest_prompt, Diagnostic};
use crate::model::LlmResponse;
use crate::providers::llm::LlmClient;
use async_trait::async_trait;
use sha2::Digest;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
11
12#[derive(Clone)]
13pub struct TraceClient {
14    // prompts -> response
15    traces: Arc<HashMap<String, LlmResponse>>,
16    fingerprint: String,
17}
18impl TraceClient {
19    pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
20        let file = File::open(path.as_ref()).map_err(|e| {
21            anyhow::anyhow!(
22                "failed to open trace file '{}': {}",
23                path.as_ref().display(),
24                e
25            )
26        })?;
27        let reader = std::io::BufReader::new(file);
28
29        let mut traces = HashMap::new();
30        let mut request_ids = HashSet::new();
31
32        // State for accumulating V2 episodes
33        struct EpisodeState {
34            input: Option<String>,
35            output: Option<String>,
36            model: Option<String>,
37            meta: serde_json::Value,
38            input_is_model: bool,
39            tool_calls: Vec<crate::model::ToolCallRecord>,
40        }
41        let mut active_episodes: HashMap<String, EpisodeState> = HashMap::new();
42
43        for (i, line_res) in reader.lines().enumerate() {
44            let line = line_res?;
45            if line.trim().is_empty() {
46                continue;
47            }
48
49            // Attempt V2 Parse first (TraceEntry enum)
50            // If it fails, fallback to legacy V1 (TraceEntryV1/TraceEntry struct local def)
51            // Actually, we can use `TraceEntry` enum from schema if we have it?
52            // But schema might not be strictly followed in loose JSON files.
53            // Let's use serde_json::Value to sniff.
54
55            let v: serde_json::Value = serde_json::from_str(&line).map_err(|e| {
56                anyhow::anyhow!(
57                    "line {}: Invalid trace format. Expected JSONL object.\n  Error: {}\n  Content: {}",
58                    i + 1,
59                    e,
60                    line.chars().take(50).collect::<String>()
61                )
62            })?;
63
64            // Heuristic detection
65            let mut prompt_opt = None;
66            let mut response_opt = None;
67            let mut model = "trace".to_string();
68            let mut meta = serde_json::json!({});
69            let mut request_id_check = None;
70
71            if let Some(t) = v.get("type").and_then(|t| t.as_str()) {
72                match t {
73                    "assay.trace" => {
74                        // V1
75                        prompt_opt = v.get("prompt").and_then(|s| s.as_str()).map(String::from);
76                        response_opt = v
77                            .get("response")
78                            .or(v.get("text"))
79                            .and_then(|s| s.as_str())
80                            .map(String::from);
81                        if let Some(m) = v.get("model").and_then(|s| s.as_str()) {
82                            model = m.to_string();
83                        }
84                        if let Some(m) = v.get("meta") {
85                            meta = m.clone();
86                        }
87                        if let Some(r) = v.get("request_id").and_then(|s| s.as_str()) {
88                            request_id_check = Some(r.to_string());
89                        }
90                    }
91                    "episode_start" => {
92                        // START V2
93                        if let Ok(ev) =
94                            serde_json::from_value::<crate::trace::schema::EpisodeStart>(v.clone())
95                        {
96                            let input_prompt = ev
97                                .input
98                                .get("prompt")
99                                .and_then(|s| s.as_str())
100                                .map(String::from);
101                            let has_input = input_prompt.is_some();
102                            let state = EpisodeState {
103                                input: input_prompt,
104                                output: None, // accum later
105                                model: None,  // extract from steps?
106                                meta: ev.meta,
107                                input_is_model: has_input, // authoritative only if present
108                                tool_calls: Vec::new(),
109                            };
110                            active_episodes.insert(ev.episode_id, state);
111                            continue; // Wait for end
112                        }
113                    }
114                    "tool_call" => {
115                        if let Ok(ev) =
116                            serde_json::from_value::<crate::trace::schema::ToolCallEntry>(v.clone())
117                        {
118                            if let Some(state) = active_episodes.get_mut(&ev.episode_id) {
119                                state.tool_calls.push(crate::model::ToolCallRecord {
120                                    id: format!("{}-{}", ev.step_id, ev.call_index.unwrap_or(0)),
121                                    tool_name: ev.tool_name,
122                                    args: ev.args,
123                                    result: ev.result,
124                                    error: ev.error.map(serde_json::Value::String),
125                                    index: state.tool_calls.len(), // Global index for sequence validation
126                                    ts_ms: ev.timestamp,
127                                });
128                            }
129                        }
130                    }
131                    "episode_end" => {
132                        // END V2
133                        if let Ok(ev) =
134                            serde_json::from_value::<crate::trace::schema::EpisodeEnd>(v.clone())
135                        {
136                            if let Some(mut state) = active_episodes.remove(&ev.episode_id) {
137                                // Finalize
138                                if let Some(out) = ev.final_output {
139                                    state.output = Some(out);
140                                }
141
142                                if let Some(p) = state.input {
143                                    prompt_opt = Some(p);
144                                    response_opt = state.output;
145
146                                    // Inject tool calls into meta
147                                    if !state.tool_calls.is_empty() {
148                                        state.meta["tool_calls"] =
149                                            serde_json::to_value(&state.tool_calls)
150                                                .unwrap_or_default();
151                                    }
152
153                                    meta = state.meta;
154                                    // model?
155                                }
156                            }
157                        }
158                    }
159
160                    "step" => {
161                        if let Ok(ev) =
162                            serde_json::from_value::<crate::trace::schema::StepEntry>(v.clone())
163                        {
164                            if let Some(state) = active_episodes.get_mut(&ev.episode_id) {
165                                // PROMPT EXTRACTION
166                                // Logic:
167                                // 1. If step is MODEL: Prefer this prompt over any previous (unless locked? No, "First Wins" for model steps).
168                                //    Actually standard "First Wins" means first MODEL step.
169                                // 2. If step is NOT model: Use as fallback only if we have NO input yet.
170
171                                let is_model = ev.kind == "model";
172                                let can_extract = if is_model {
173                                    // If we are model, we overwrite if current input is NOT model (fallback) OR if input is None.
174                                    // If we already have a model input, we skip (First Model Wins).
175                                    !state.input_is_model
176                                } else {
177                                    // If not model, only extract if we have absolutely nothing.
178                                    state.input.is_none()
179                                };
180
181                                if can_extract {
182                                    let mut found_prompt = None;
183
184                                    if let Some(c) = &ev.content {
185                                        if let Ok(c_json) =
186                                            serde_json::from_str::<serde_json::Value>(c)
187                                        {
188                                            if let Some(p) =
189                                                c_json.get("prompt").and_then(|s| s.as_str())
190                                            {
191                                                found_prompt = Some(p.to_string());
192                                            }
193                                        }
194                                    }
195                                    if found_prompt.is_none() {
196                                        if let Some(p) =
197                                            ev.meta.get("gen_ai.prompt").and_then(|s| s.as_str())
198                                        {
199                                            found_prompt = Some(p.to_string());
200                                        }
201                                    }
202
203                                    if let Some(p) = found_prompt {
204                                        state.input = Some(p);
205                                        if is_model {
206                                            state.input_is_model = true;
207                                        }
208                                        // DEBUG: remove me
209                                        /*
210                                        eprintln!("DEBUG: TraceClient extracted prompt: '{}' is_model={}", state.input.as_ref().unwrap(), is_model);
211                                        */
212                                    }
213                                }
214
215                                // --- OUTPUT EXTRACTION (Last Wins) ---
216                                // Rule 4: Step Content "completion"
217                                if let Some(c) = &ev.content {
218                                    let mut extracted = None;
219                                    if let Ok(c_json) = serde_json::from_str::<serde_json::Value>(c)
220                                    {
221                                        if let Some(resp) =
222                                            c_json.get("completion").and_then(|s| s.as_str())
223                                        {
224                                            extracted = Some(resp.to_string());
225                                            // Capture model if present
226                                            if let Some(m) =
227                                                c_json.get("model").and_then(|s| s.as_str())
228                                            {
229                                                state.model = Some(m.to_string());
230                                            }
231                                        }
232                                    }
233
234                                    if let Some(out) = extracted {
235                                        state.output = Some(out);
236                                    } else {
237                                        // Fallback: use raw content as output if structured extraction failed
238                                        state.output = Some(c.clone());
239                                    }
240                                }
241                                // Rule 5: Step Meta "gen_ai.completion"
242                                if let Some(resp) =
243                                    ev.meta.get("gen_ai.completion").and_then(|s| s.as_str())
244                                {
245                                    state.output = Some(resp.to_string());
246                                }
247                                if let Some(m) = ev
248                                    .meta
249                                    .get("gen_ai.request.model")
250                                    .or(ev.meta.get("gen_ai.response.model"))
251                                    .and_then(|s| s.as_str())
252                                {
253                                    state.model = Some(m.to_string());
254                                }
255                            }
256                        }
257                        continue;
258                    }
259                    _ => {
260                        continue;
261                    }
262                }
263            } else {
264                // Legacy loose JSON (no type)
265                prompt_opt = v.get("prompt").and_then(|s| s.as_str()).map(String::from);
266                response_opt = v
267                    .get("response")
268                    .or(v.get("text"))
269                    .and_then(|s| s.as_str())
270                    .map(String::from);
271                // Fix: Extract other fields too
272                if let Some(m) = v.get("model").and_then(|s| s.as_str()) {
273                    model = m.to_string();
274                }
275                if let Some(r) = v.get("request_id").and_then(|s| s.as_str()) {
276                    request_id_check = Some(r.to_string());
277                }
278
279                // Fix: Extract tool calls for V1/Legacy trace validation
280                let tool_name = v.get("tool").and_then(|s| s.as_str()).map(String::from);
281                let tool_args = v.get("args").cloned();
282
283                if let Some(tool) = tool_name {
284                    let record = crate::model::ToolCallRecord {
285                        id: "legacy-v1".to_string(),
286                        tool_name: tool,
287                        args: tool_args.unwrap_or(serde_json::json!({})),
288                        result: None,
289                        error: None,
290                        index: 0,
291                        ts_ms: 0,
292                    };
293                    meta["tool_calls"] = serde_json::json!([record]);
294                } else if let Some(calls) = v.get("tool_calls").and_then(|v| v.as_array()) {
295                    // Propagate full list if present in V1
296                    meta["tool_calls"] = serde_json::Value::Array(calls.clone());
297                }
298            }
299
300            if let (Some(p), Some(r)) = (prompt_opt, response_opt) {
301                // Finalize Entry
302                // Uniqueness Check
303                if let Some(rid) = &request_id_check {
304                    if request_ids.contains(rid) {
305                        return Err(anyhow::anyhow!(
306                            "line {}: Duplicate request_id {}",
307                            i + 1,
308                            rid
309                        ));
310                    }
311                    request_ids.insert(rid.clone());
312                }
313
314                if traces.contains_key(&p) {
315                    // Duplicate prompt handling? Overwrite or Error?
316                    // Existing code errors.
317                    return Err(anyhow::anyhow!(
318                        "Duplicate prompt found in trace file: {}",
319                        p
320                    ));
321                }
322
323                traces.insert(
324                    p,
325                    LlmResponse {
326                        text: r,
327                        meta,
328                        model,
329                        provider: "trace".to_string(),
330                        ..Default::default()
331                    },
332                );
333            }
334        }
335
336        // Flush active episodes at EOF
337        for (id, state) in active_episodes {
338            if let (Some(p), Some(r)) = (state.input.clone(), state.output.clone()) {
339                // ... reuse insertion logic (refactor to helper?) ...
340                // Duplicate check
341                if traces.contains_key(&p) {
342                    eprintln!("Warning: Duplicate prompt skipped at EOF for id {}", id);
343                    continue;
344                }
345                traces.insert(
346                    p,
347                    LlmResponse {
348                        text: r,
349                        meta: state.meta,
350                        model: state.model.unwrap_or_else(|| "trace".to_string()),
351                        provider: "trace".to_string(),
352                        ..Default::default()
353                    },
354                );
355            }
356        }
357
358        // Compute deterministic fingerprint of traces
359        let mut keys: Vec<&String> = traces.keys().collect();
360        keys.sort();
361        let mut hasher = sha2::Sha256::new();
362        for k in keys {
363            use sha2::Digest;
364            hasher.update(k.as_bytes());
365            if let Some(v) = traces.get(k) {
366                // hash validation relevant parts of response
367                hasher.update(v.text.as_bytes());
368                // include meta/model? yes for completeness
369                hasher.update(v.model.as_bytes());
370            }
371        }
372        let fingerprint = hex::encode(hasher.finalize());
373
374        Ok(Self {
375            traces: Arc::new(traces),
376            fingerprint,
377        })
378    }
379}
380
381#[async_trait]
382impl LlmClient for TraceClient {
383    async fn complete(
384        &self,
385        prompt: &str,
386        _context: Option<&[String]>,
387    ) -> anyhow::Result<LlmResponse> {
388        if let Some(resp) = self.traces.get(prompt) {
389            Ok(resp.clone())
390        } else {
391            // Find closest match for hint
392            let closest = closest_prompt(prompt, self.traces.keys());
393
394            let mut diag = Diagnostic::new(
395                codes::E_TRACE_MISS,
396                "Trace miss: prompt not found in loaded traces".to_string(),
397            )
398            .with_source("trace")
399            .with_context(serde_json::json!({
400                "prompt": prompt,
401                "closest_match": closest
402            }));
403
404            if let Some(match_) = closest {
405                diag = diag.with_fix_step(format!(
406                    "Did you mean '{}'? (similarity: {:.2})",
407                    match_.prompt, match_.similarity
408                ));
409                diag = diag.with_fix_step("Update your input prompt to match the trace exactly");
410            } else {
411                diag = diag.with_fix_step("No similar prompts found in trace file");
412            }
413
414            diag = diag.with_fix_step("Regenerate the trace file: assay trace ingest ...");
415
416            Err(anyhow::Error::new(diag))
417        }
418    }
419
420    fn provider_name(&self) -> &'static str {
421        "trace"
422    }
423
424    fn fingerprint(&self) -> Option<String> {
425        Some(self.fingerprint.clone())
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Legacy loose-JSON lines load and replay exactly, preserving model and
    /// defaulting provider to "trace".
    #[tokio::test]
    async fn test_trace_client_happy_path() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(
            tmp,
            r#"{{"prompt": "hello", "response": "world", "model": "gpt-4"}}"#
        )?;
        writeln!(tmp, r#"{{"prompt": "foo", "response": "bar"}}"#)?;

        let client = TraceClient::from_path(tmp.path())?;

        let resp1 = client.complete("hello", None).await?;
        assert_eq!(resp1.text, "world");
        assert_eq!(resp1.model, "gpt-4");

        let resp2 = client.complete("foo", None).await?;
        assert_eq!(resp2.text, "bar");
        assert_eq!(resp2.provider, "trace"); // default

        Ok(())
    }

    /// A prompt absent from the trace file surfaces as an error (trace miss).
    #[tokio::test]
    async fn test_trace_client_miss() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(tmp, r#"{{"prompt": "exists", "response": "yes"}}"#)?;

        let client = TraceClient::from_path(tmp.path())?;
        let result = client.complete("does not exist", None).await;
        assert!(result.is_err());
        Ok(())
    }

    /// The same prompt on two lines is rejected at load time.
    #[tokio::test]
    async fn test_trace_client_duplicate_prompt() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(tmp, r#"{{"prompt": "dup", "response": "1"}}"#)?;
        writeln!(tmp, r#"{{"prompt": "dup", "response": "2"}}"#)?;

        let result = TraceClient::from_path(tmp.path());
        assert!(result.is_err());
        Ok(())
    }

    /// A repeated request_id is rejected at load time even when the prompts
    /// differ.
    #[tokio::test]
    async fn test_trace_client_duplicate_request_id() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        // different prompts, same ID
        writeln!(
            tmp,
            r#"{{"request_id": "id1", "prompt": "p1", "response": "1"}}"#
        )?;
        writeln!(
            tmp,
            r#"{{"request_id": "id1", "prompt": "p2", "response": "2"}}"#
        )?;

        let result = TraceClient::from_path(tmp.path());
        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("Duplicate request_id"));
        Ok(())
    }

    /// Malformed or incomplete lines are skipped (load succeeds) and later
    /// lookups for their prompts miss.
    #[tokio::test]
    async fn test_trace_schema_validation() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        // Bad version (Legacy JSON with version but missing response should be skipped)
        writeln!(tmp, r#"{{"schema_version": 2, "prompt": "p"}}"#)?;
        let client = TraceClient::from_path(tmp.path())?;
        assert!(client.complete("p", None).await.is_err()); // Trace miss

        let mut tmp2 = NamedTempFile::new()?;
        // Bad type - should be ignored (Ok, empty) or Err depending on policy.
        // Current implementation ignores unknown types (forward compat).
        writeln!(
            tmp2,
            r#"{{"type": "wrong", "prompt": "p", "response": "r"}}"#
        )?;
        let client = TraceClient::from_path(tmp2.path())?;
        assert!(client.complete("p", None).await.is_err()); // "p" not found because line ignored

        let mut tmp3 = NamedTempFile::new()?;
        // Missing text/response
        writeln!(tmp3, r#"{{"prompt": "p"}}"#)?;
        // Valid legacy line but missing required response -> TraceClient skips it.
        // So client is empty, returns Ok.
        let client = TraceClient::from_path(tmp3.path())?;
        assert!(client.complete("p", None).await.is_err()); // Trace miss expected

        Ok(())
    }

    /// The full meta object (e.g. embeddings) round-trips through load/replay.
    #[tokio::test]
    async fn test_trace_meta_preservation() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        // Using verbatim JSON from trace.jsonl (simplified)
        let json = r#"{"schema_version":1,"type":"assay.trace","request_id":"test-1","prompt":"Say hello","response":"Hello world","meta":{"assay":{"embeddings":{"model":"text-embedding-3-small","response":[0.1],"reference":[0.1]}}}}"#;
        writeln!(tmp, "{}", json)?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("Say hello", None).await?;

        println!("Meta from test: {}", resp.meta);
        assert!(
            resp.meta.pointer("/assay/embeddings/response").is_some(),
            "Meta embeddings missing!"
        );
        Ok(())
    }

    /// V2 precedence: first model step's prompt wins; last step's
    /// completion (content or gen_ai.completion meta) wins for output.
    #[tokio::test]
    async fn test_v2_replay_precedence() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        // Scenario: Input in Step Content should override nothing (it's first),
        // Output in 2nd Step should override 1st Step.

        let ep_start = r#"{"type":"episode_start","episode_id":"e1","timestamp":100,"input":null}"#;
        let step1 = r#"{"type":"step","episode_id":"e1","step_id":"s1","kind":"model","timestamp":101,"content":"{\"prompt\":\"original_prompt\",\"completion\":\"output_1\"}"}"#;
        // Step 2 has same prompt (ignored if input set) but new completion (should override)
        let step2 = r#"{"type":"step","episode_id":"e1","step_id":"s2","kind":"model","timestamp":102,"content":"{\"prompt\":\"ignored\",\"completion\":\"final_output\"}"}"#;
        // Step 3 has meta completion (should override content?) per our rule "last wins" for output
        let step3 = r#"{"type":"step","episode_id":"e1","step_id":"s3","kind":"model","timestamp":103,"content":null,"meta":{"gen_ai.completion":"meta_final"}}"#;

        let ep_end = r#"{"type":"episode_end","episode_id":"e1","timestamp":104}"#;

        writeln!(tmp, "{}", ep_start)?;
        writeln!(tmp, "{}", step1)?;
        writeln!(tmp, "{}", step2)?;
        writeln!(tmp, "{}", step3)?;
        writeln!(tmp, "{}", ep_end)?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("original_prompt", None).await?; // Should find via Step 1

        // Output should be from Step 3 (last one)
        assert_eq!(resp.text, "meta_final");

        Ok(())
    }

    /// Episodes without an episode_end are still flushed into the table at EOF.
    #[tokio::test]
    async fn test_eof_flush_partial_episode() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        // No episode_end
        let ep_start = r#"{"type":"episode_start","episode_id":"e_flush","timestamp":100,"input":{"prompt":"flush_me"}}"#;
        let step1 = r#"{"type":"step","episode_id":"e_flush","step_id":"s1","kind":"model","timestamp":101,"content":"{\"completion\":\"flushed_output\"}"}"#;

        writeln!(tmp, "{}", ep_start)?;
        writeln!(tmp, "{}", step1)?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("flush_me", None).await?;
        assert_eq!(resp.text, "flushed_output");

        Ok(())
    }
}
595}