//! `assay_core/providers/trace.rs` — trace-replay LLM provider: serves `LlmClient::complete`
//! by exact-prompt lookup in a pre-recorded trace file.

1use crate::errors::{diagnostic::codes, similarity::closest_prompt, Diagnostic};
2use crate::model::LlmResponse;
3use crate::providers::llm::LlmClient;
4use async_trait::async_trait;
5use serde_json as sj;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9#[path = "trace_next/mod.rs"]
10mod trace_next;
11
12#[derive(Clone)]
13pub struct TraceClient {
14    traces: Arc<HashMap<String, LlmResponse>>,
15    fingerprint: String,
16}
17
18impl TraceClient {
19    pub fn from_path<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<Self> {
20        trace_next::from_path_impl(path)
21    }
22}
23
24#[async_trait]
25impl LlmClient for TraceClient {
26    async fn complete(
27        &self,
28        prompt: &str,
29        _context: Option<&[String]>,
30    ) -> anyhow::Result<LlmResponse> {
31        if let Some(resp) = self.traces.get(prompt) {
32            Ok(resp.clone())
33        } else {
34            let closest = closest_prompt(prompt, self.traces.keys());
35
36            let mut diag = Diagnostic::new(
37                codes::E_TRACE_MISS,
38                "Trace miss: prompt not found in loaded traces".to_string(),
39            )
40            .with_source("trace")
41            .with_context(sj::json!({
42                "prompt": prompt,
43                "closest_match": closest
44            }));
45
46            if let Some(match_) = closest {
47                diag = diag.with_fix_step(format!(
48                    "Did you mean '{}'? (similarity: {:.2})",
49                    match_.prompt, match_.similarity
50                ));
51                diag = diag.with_fix_step("Update your input prompt to match the trace exactly");
52            } else {
53                diag = diag.with_fix_step("No similar prompts found in trace file");
54            }
55
56            diag = diag.with_fix_step("Regenerate the trace file: assay trace ingest ...");
57
58            Err(anyhow::Error::new(diag))
59        }
60    }
61
62    fn provider_name(&self) -> &'static str {
63        "trace"
64    }
65
66    fn fingerprint(&self) -> Option<String> {
67        Some(self.fingerprint.clone())
68    }
69}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes each entry of `lines` to a fresh temp file, one `\n`-terminated line per entry.
    ///
    /// Keep the returned handle alive while its path is in use: `NamedTempFile` deletes the
    /// file when dropped.
    fn trace_file(lines: &[&str]) -> anyhow::Result<NamedTempFile> {
        let mut tmp = NamedTempFile::new()?;
        for line in lines {
            writeln!(tmp, "{}", line)?;
        }
        Ok(tmp)
    }

    #[tokio::test]
    async fn test_trace_client_happy_path() -> anyhow::Result<()> {
        let tmp = trace_file(&[
            r#"{"prompt": "hello", "response": "world", "model": "gpt-4"}"#,
            r#"{"prompt": "foo", "response": "bar"}"#,
        ])?;

        let client = TraceClient::from_path(tmp.path())?;

        let resp1 = client.complete("hello", None).await?;
        assert_eq!(resp1.text, "world");
        assert_eq!(resp1.model, "gpt-4");

        let resp2 = client.complete("foo", None).await?;
        assert_eq!(resp2.text, "bar");
        assert_eq!(resp2.provider, "trace"); // default

        Ok(())
    }

    #[tokio::test]
    async fn test_trace_client_miss() -> anyhow::Result<()> {
        let tmp = trace_file(&[r#"{"prompt": "exists", "response": "yes"}"#])?;

        let client = TraceClient::from_path(tmp.path())?;
        let err = client
            .complete("does not exist", None)
            .await
            .err()
            .expect("unknown prompt must be a trace miss");
        // Misses surface as a structured Diagnostic, not an ad-hoc string error.
        assert!(err.downcast_ref::<Diagnostic>().is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_trace_client_duplicate_prompt() -> anyhow::Result<()> {
        let tmp = trace_file(&[
            r#"{"prompt": "dup", "response": "1"}"#,
            r#"{"prompt": "dup", "response": "2"}"#,
        ])?;

        assert!(TraceClient::from_path(tmp.path()).is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_trace_client_duplicate_request_id() -> anyhow::Result<()> {
        // Different prompts, same request_id.
        let tmp = trace_file(&[
            r#"{"request_id": "id1", "prompt": "p1", "response": "1"}"#,
            r#"{"request_id": "id1", "prompt": "p2", "response": "2"}"#,
        ])?;

        // `unwrap_err` is unavailable because TraceClient does not implement Debug.
        let err = TraceClient::from_path(tmp.path())
            .err()
            .expect("duplicate request_id must be rejected");
        let msg = err.to_string();
        assert!(
            msg.contains("Duplicate request_id"),
            "unexpected error: {}",
            msg
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_trace_schema_validation() -> anyhow::Result<()> {
        // Each record below is skipped at load time. Loading still succeeds, but "p" never
        // becomes addressable, so the lookup reports a trace miss.
        let cases = [
            (
                "schema_version set but response missing",
                r#"{"schema_version": 2, "prompt": "p"}"#,
            ),
            (
                "unknown type is ignored for forward compatibility",
                r#"{"type": "wrong", "prompt": "p", "response": "r"}"#,
            ),
            ("legacy record without response", r#"{"prompt": "p"}"#),
        ];

        for (case, line) in cases {
            let tmp = trace_file(&[line])?;
            let client = TraceClient::from_path(tmp.path())?;
            assert!(
                client.complete("p", None).await.is_err(),
                "{}: expected a trace miss",
                case
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_trace_meta_preservation() -> anyhow::Result<()> {
        // Verbatim trace.jsonl record (simplified).
        let json = r#"{"schema_version":1,"type":"assay.trace","request_id":"test-1","prompt":"Say hello","response":"Hello world","meta":{"assay":{"embeddings":{"model":"text-embedding-3-small","response":[0.1],"reference":[0.1]}}}}"#;
        let tmp = trace_file(&[json])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("Say hello", None).await?;

        assert!(
            resp.meta.pointer("/assay/embeddings/response").is_some(),
            "Meta embeddings missing! meta = {}",
            resp.meta
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_v2_replay_precedence() -> anyhow::Result<()> {
        // Scenario: the episode has no input, so step s1 supplies the prompt.
        // s2 carries a different prompt ("ignored"), which is expected to be ignored because
        // a prompt is already set.
        // Output follows "last wins": s2's completion replaces s1's, and s3's
        // `gen_ai.completion` meta (with null content) replaces s2's.
        let ep_start = r#"{"type":"episode_start","episode_id":"e1","timestamp":100,"input":null}"#;
        let step1 = r#"{"type":"step","episode_id":"e1","step_id":"s1","kind":"model","timestamp":101,"content":"{\"prompt\":\"original_prompt\",\"completion\":\"output_1\"}"}"#;
        let step2 = r#"{"type":"step","episode_id":"e1","step_id":"s2","kind":"model","timestamp":102,"content":"{\"prompt\":\"ignored\",\"completion\":\"final_output\"}"}"#;
        let step3 = r#"{"type":"step","episode_id":"e1","step_id":"s3","kind":"model","timestamp":103,"content":null,"meta":{"gen_ai.completion":"meta_final"}}"#;
        let ep_end = r#"{"type":"episode_end","episode_id":"e1","timestamp":104}"#;

        let tmp = trace_file(&[ep_start, step1, step2, step3, ep_end])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("original_prompt", None).await?; // Prompt from Step 1

        // Output should be from Step 3 (last one).
        assert_eq!(resp.text, "meta_final");

        Ok(())
    }

    #[tokio::test]
    async fn test_eof_flush_partial_episode() -> anyhow::Result<()> {
        // No episode_end: the episode is only emitted by the EOF flush.
        let ep_start = r#"{"type":"episode_start","episode_id":"e_flush","timestamp":100,"input":{"prompt":"flush_me"}}"#;
        let step1 = r#"{"type":"step","episode_id":"e_flush","step_id":"s1","kind":"model","timestamp":101,"content":"{\"completion\":\"flushed_output\"}"}"#;

        let tmp = trace_file(&[ep_start, step1])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("flush_me", None).await?;
        assert_eq!(resp.text, "flushed_output");

        Ok(())
    }

    #[tokio::test]
    async fn test_episode_end_with_null_meta_preserves_tool_calls() -> anyhow::Result<()> {
        let ep_start = r#"{"type":"episode_start","episode_id":"e_meta_null","timestamp":100,"input":{"prompt":"meta_null_prompt"},"meta":null}"#;
        let tool_call = r#"{"type":"tool_call","episode_id":"e_meta_null","step_id":"s1","timestamp":101,"tool_name":"fs.read","call_index":0,"args":{"path":"/tmp/demo.txt"}}"#;
        let ep_end = r#"{"type":"episode_end","episode_id":"e_meta_null","timestamp":102,"final_output":"done"}"#;

        let tmp = trace_file(&[ep_start, tool_call, ep_end])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("meta_null_prompt", None).await?;
        assert_eq!(resp.text, "done");
        assert_eq!(
            resp.meta
                .pointer("/tool_calls")
                .and_then(|v| v.as_array())
                .map(|a| a.len()),
            Some(1)
        );
        assert_eq!(
            resp.meta
                .pointer("/tool_calls/0/tool_name")
                .and_then(|v| v.as_str()),
            Some("fs.read")
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_episode_end_propagates_step_model_to_response() -> anyhow::Result<()> {
        let ep_start = r#"{"type":"episode_start","episode_id":"e_model","timestamp":100,"input":{"prompt":"model_prompt"}}"#;
        let step1 = r#"{"type":"step","episode_id":"e_model","step_id":"s1","kind":"model","timestamp":101,"content":"{\"completion\":\"model_output\",\"model\":\"gpt-4o-mini\"}"}"#;
        let ep_end = r#"{"type":"episode_end","episode_id":"e_model","timestamp":102}"#;

        let tmp = trace_file(&[ep_start, step1, ep_end])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("model_prompt", None).await?;
        assert_eq!(resp.text, "model_output");
        assert_eq!(resp.model, "gpt-4o-mini");

        Ok(())
    }

    #[tokio::test]
    async fn test_eof_flush_preserves_tool_calls_in_meta() -> anyhow::Result<()> {
        let ep_start = r#"{"type":"episode_start","episode_id":"e_eof_tools","timestamp":100,"input":{"prompt":"eof_tools_prompt"}}"#;
        let tool_call = r#"{"type":"tool_call","episode_id":"e_eof_tools","step_id":"s1","timestamp":101,"tool_name":"fs.write","call_index":0,"args":{"path":"/tmp/out.txt"}}"#;
        let step1 = r#"{"type":"step","episode_id":"e_eof_tools","step_id":"s2","kind":"model","timestamp":102,"content":"{\"completion\":\"eof_output\"}"}"#;
        // Intentionally no episode_end: exercises the EOF flush path.

        let tmp = trace_file(&[ep_start, tool_call, step1])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("eof_tools_prompt", None).await?;
        assert_eq!(resp.text, "eof_output");
        assert_eq!(
            resp.meta
                .pointer("/tool_calls")
                .and_then(|v| v.as_array())
                .map(|a| a.len()),
            Some(1)
        );
        assert_eq!(
            resp.meta
                .pointer("/tool_calls/0/tool_name")
                .and_then(|v| v.as_str()),
            Some("fs.write")
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_from_path_invalid_json_has_line_context() -> anyhow::Result<()> {
        let tmp = trace_file(&["not-json"])?;

        let err = match TraceClient::from_path(tmp.path()) {
            Ok(_) => panic!("invalid JSON must fail"),
            Err(e) => e.to_string(),
        };
        assert!(err.contains("Invalid trace format"));
        assert!(err.contains("line 1"));
        assert!(err.contains("Content: not-json"));

        Ok(())
    }

    #[tokio::test]
    async fn test_legacy_tool_fields_promote_to_tool_calls_meta() -> anyhow::Result<()> {
        let tmp = trace_file(&[
            r#"{"prompt":"legacy_tool","response":"ok","tool":"fs.read","args":{"path":"/tmp/demo.txt"}}"#,
        ])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("legacy_tool", None).await?;
        assert_eq!(resp.text, "ok");
        assert_eq!(
            resp.meta
                .pointer("/tool_calls")
                .and_then(|v| v.as_array())
                .map(|a| a.len()),
            Some(1)
        );
        assert_eq!(
            resp.meta
                .pointer("/tool_calls/0/tool_name")
                .and_then(|v| v.as_str()),
            Some("fs.read")
        );
        assert_eq!(
            resp.meta
                .pointer("/tool_calls/0/args/path")
                .and_then(|v| v.as_str()),
            Some("/tmp/demo.txt")
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_legacy_preexisting_tool_calls_are_preserved_without_duplication(
    ) -> anyhow::Result<()> {
        let tmp = trace_file(&[
            r#"{"prompt":"legacy_with_calls","response":"ok","tool_calls":[{"tool_name":"fs.read","args":{"path":"/tmp/a"}},{"tool_name":"fs.write","args":{"path":"/tmp/b"}}]}"#,
        ])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("legacy_with_calls", None).await?;
        assert_eq!(resp.text, "ok");
        assert_eq!(
            resp.meta
                .pointer("/tool_calls")
                .and_then(|v| v.as_array())
                .map(|a| a.len()),
            Some(2)
        );
        assert_eq!(
            resp.meta
                .pointer("/tool_calls/1/tool_name")
                .and_then(|v| v.as_str()),
            Some("fs.write")
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_legacy_tool_only_record_uses_ignore_fallback_prompt() -> anyhow::Result<()> {
        let tmp = trace_file(&[
            r#"{"tool":"fs.read","args":{"path":"/tmp/input.txt"},"result":"ok"}"#,
        ])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("ignore", None).await?;
        assert_eq!(resp.text, "ok");

        assert_eq!(
            resp.meta
                .pointer("/tool_calls/0/tool_name")
                .and_then(|v| v.as_str()),
            Some("fs.read")
        );
        assert_eq!(
            resp.meta
                .pointer("/tool_calls/0/args/path")
                .and_then(|v| v.as_str()),
            Some("/tmp/input.txt")
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_v2_non_model_prompt_is_only_fallback() -> anyhow::Result<()> {
        let ep_start =
            r#"{"type":"episode_start","episode_id":"e_prio","timestamp":100,"input":null}"#;
        let step_tool = r#"{"type":"step","episode_id":"e_prio","step_id":"s_tool","kind":"tool","timestamp":101,"content":"{\"prompt\":\"fallback_prompt\",\"completion\":\"tool_out\"}","meta":{}}"#;
        let step_model = r#"{"type":"step","episode_id":"e_prio","step_id":"s_model","kind":"model","timestamp":102,"content":"{\"prompt\":\"authoritative_prompt\",\"completion\":\"model_out\"}","meta":{}}"#;
        let ep_end = r#"{"type":"episode_end","episode_id":"e_prio","timestamp":103}"#;

        let tmp = trace_file(&[ep_start, step_tool, step_model, ep_end])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("authoritative_prompt", None).await?;
        assert_eq!(resp.text, "model_out");
        assert!(
            client.complete("fallback_prompt", None).await.is_err(),
            "fallback prompt must not remain addressable after model prompt extraction"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_eof_flush_duplicate_prompt_key_keeps_first_entry() -> anyhow::Result<()> {
        // Entries are keyed by prompt string. When the EOF-flushed episode reuses a prompt that
        // an earlier record already defined, the earlier entry is kept rather than overwritten.
        // This holds even though the two records come from different request_ids/episodes.
        // Contrast test_trace_client_duplicate_prompt: two legacy records with the same prompt
        // fail to load.
        let legacy = r#"{"request_id":"r1","prompt":"dup_prompt","response":"first_response"}"#;
        let ep_start = r#"{"type":"episode_start","episode_id":"e_dup","timestamp":100,"input":{"prompt":"dup_prompt"}}"#;
        let step1 = r#"{"type":"step","episode_id":"e_dup","step_id":"s1","kind":"model","timestamp":101,"content":"{\"completion\":\"second_response\"}"}"#;
        // No episode_end on purpose; this exercises the EOF flush path.

        let tmp = trace_file(&[legacy, ep_start, step1])?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("dup_prompt", None).await?;
        assert_eq!(resp.text, "first_response");

        Ok(())
    }

    #[tokio::test]
    async fn test_from_path_accepts_crlf_jsonl_lines() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        // Written as raw bytes so each record ends in "\r\n" rather than writeln!'s "\n".
        tmp.as_file_mut().write_all(
            b"{\"prompt\":\"crlf_prompt_1\",\"response\":\"ok1\"}\r\n{\"prompt\":\"crlf_prompt_2\",\"response\":\"ok2\"}\r\n",
        )?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp1 = client.complete("crlf_prompt_1", None).await?;
        let resp2 = client.complete("crlf_prompt_2", None).await?;
        assert_eq!(resp1.text, "ok1");
        assert_eq!(resp2.text, "ok2");

        Ok(())
    }
}