1use crate::errors::{diagnostic::codes, similarity::closest_prompt, Diagnostic};
2use crate::model::LlmResponse;
3use crate::providers::llm::LlmClient;
4use async_trait::async_trait;
5use sha2::Digest;
6use std::collections::{HashMap, HashSet};
7use std::fs::File;
8use std::io::BufRead;
9use std::path::Path;
10use std::sync::Arc;
11
/// Replay-only LLM client backed by a recorded JSONL trace file.
///
/// Lookups are exact-match on the prompt string. Cloning is cheap: the
/// prompt -> response map is shared behind an `Arc`.
#[derive(Clone)]
pub struct TraceClient {
    // Exact prompt text -> recorded response.
    traces: Arc<HashMap<String, LlmResponse>>,
    // SHA-256 hex digest over the sorted prompts plus each response's text
    // and model; lets callers detect when trace content changes.
    fingerprint: String,
}
18impl TraceClient {
19 pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
20 let file = File::open(path.as_ref()).map_err(|e| {
21 anyhow::anyhow!(
22 "failed to open trace file '{}': {}",
23 path.as_ref().display(),
24 e
25 )
26 })?;
27 let reader = std::io::BufReader::new(file);
28
29 let mut traces = HashMap::new();
30 let mut request_ids = HashSet::new();
31
32 struct EpisodeState {
34 input: Option<String>,
35 output: Option<String>,
36 model: Option<String>,
37 meta: serde_json::Value,
38 input_is_model: bool,
39 tool_calls: Vec<crate::model::ToolCallRecord>,
40 }
41 let mut active_episodes: HashMap<String, EpisodeState> = HashMap::new();
42
43 for (i, line_res) in reader.lines().enumerate() {
44 let line = line_res?;
45 if line.trim().is_empty() {
46 continue;
47 }
48
49 let v: serde_json::Value = serde_json::from_str(&line).map_err(|e| {
56 anyhow::anyhow!(
57 "line {}: Invalid trace format. Expected JSONL object.\n Error: {}\n Content: {}",
58 i + 1,
59 e,
60 line.chars().take(50).collect::<String>()
61 )
62 })?;
63
64 let mut prompt_opt = None;
66 let mut response_opt = None;
67 let mut model = "trace".to_string();
68 let mut meta = serde_json::json!({});
69 let mut request_id_check = None;
70
71 if let Some(t) = v.get("type").and_then(|t| t.as_str()) {
72 match t {
73 "assay.trace" => {
74 prompt_opt = v.get("prompt").and_then(|s| s.as_str()).map(String::from);
76 response_opt = v
77 .get("response")
78 .or(v.get("text"))
79 .and_then(|s| s.as_str())
80 .map(String::from);
81 if let Some(m) = v.get("model").and_then(|s| s.as_str()) {
82 model = m.to_string();
83 }
84 if let Some(m) = v.get("meta") {
85 meta = m.clone();
86 }
87 if let Some(r) = v.get("request_id").and_then(|s| s.as_str()) {
88 request_id_check = Some(r.to_string());
89 }
90 }
91 "episode_start" => {
92 if let Ok(ev) =
94 serde_json::from_value::<crate::trace::schema::EpisodeStart>(v.clone())
95 {
96 let input_prompt = ev
97 .input
98 .get("prompt")
99 .and_then(|s| s.as_str())
100 .map(String::from);
101 let has_input = input_prompt.is_some();
102 let state = EpisodeState {
103 input: input_prompt,
104 output: None, model: None, meta: ev.meta,
107 input_is_model: has_input, tool_calls: Vec::new(),
109 };
110 active_episodes.insert(ev.episode_id, state);
111 continue; }
113 }
114 "tool_call" => {
115 if let Ok(ev) =
116 serde_json::from_value::<crate::trace::schema::ToolCallEntry>(v.clone())
117 {
118 if let Some(state) = active_episodes.get_mut(&ev.episode_id) {
119 state.tool_calls.push(crate::model::ToolCallRecord {
120 id: format!("{}-{}", ev.step_id, ev.call_index.unwrap_or(0)),
121 tool_name: ev.tool_name,
122 args: ev.args,
123 result: ev.result,
124 error: ev.error.map(serde_json::Value::String),
125 index: state.tool_calls.len(), ts_ms: ev.timestamp,
127 });
128 }
129 }
130 }
131 "episode_end" => {
132 if let Ok(ev) =
134 serde_json::from_value::<crate::trace::schema::EpisodeEnd>(v.clone())
135 {
136 if let Some(mut state) = active_episodes.remove(&ev.episode_id) {
137 if let Some(out) = ev.final_output {
139 state.output = Some(out);
140 }
141
142 if let Some(p) = state.input {
143 prompt_opt = Some(p);
144 response_opt = state.output;
145
146 if !state.tool_calls.is_empty() {
148 state.meta["tool_calls"] =
149 serde_json::to_value(&state.tool_calls)
150 .unwrap_or_default();
151 }
152
153 meta = state.meta;
154 }
156 }
157 }
158 }
159
160 "step" => {
161 if let Ok(ev) =
162 serde_json::from_value::<crate::trace::schema::StepEntry>(v.clone())
163 {
164 if let Some(state) = active_episodes.get_mut(&ev.episode_id) {
165 let is_model = ev.kind == "model";
172 let can_extract = if is_model {
173 !state.input_is_model
176 } else {
177 state.input.is_none()
179 };
180
181 if can_extract {
182 let mut found_prompt = None;
183
184 if let Some(c) = &ev.content {
185 if let Ok(c_json) =
186 serde_json::from_str::<serde_json::Value>(c)
187 {
188 if let Some(p) =
189 c_json.get("prompt").and_then(|s| s.as_str())
190 {
191 found_prompt = Some(p.to_string());
192 }
193 }
194 }
195 if found_prompt.is_none() {
196 if let Some(p) =
197 ev.meta.get("gen_ai.prompt").and_then(|s| s.as_str())
198 {
199 found_prompt = Some(p.to_string());
200 }
201 }
202
203 if let Some(p) = found_prompt {
204 state.input = Some(p);
205 if is_model {
206 state.input_is_model = true;
207 }
208 }
213 }
214
215 if let Some(c) = &ev.content {
218 let mut extracted = None;
219 if let Ok(c_json) = serde_json::from_str::<serde_json::Value>(c)
220 {
221 if let Some(resp) =
222 c_json.get("completion").and_then(|s| s.as_str())
223 {
224 extracted = Some(resp.to_string());
225 if let Some(m) =
227 c_json.get("model").and_then(|s| s.as_str())
228 {
229 state.model = Some(m.to_string());
230 }
231 }
232 }
233
234 if let Some(out) = extracted {
235 state.output = Some(out);
236 } else {
237 state.output = Some(c.clone());
239 }
240 }
241 if let Some(resp) =
243 ev.meta.get("gen_ai.completion").and_then(|s| s.as_str())
244 {
245 state.output = Some(resp.to_string());
246 }
247 if let Some(m) = ev
248 .meta
249 .get("gen_ai.request.model")
250 .or(ev.meta.get("gen_ai.response.model"))
251 .and_then(|s| s.as_str())
252 {
253 state.model = Some(m.to_string());
254 }
255 }
256 }
257 continue;
258 }
259 _ => {
260 continue;
261 }
262 }
263 } else {
264 prompt_opt = v.get("prompt").and_then(|s| s.as_str()).map(String::from);
266 response_opt = v
267 .get("response")
268 .or(v.get("text"))
269 .and_then(|s| s.as_str())
270 .map(String::from);
271 if let Some(m) = v.get("model").and_then(|s| s.as_str()) {
273 model = m.to_string();
274 }
275 if let Some(r) = v.get("request_id").and_then(|s| s.as_str()) {
276 request_id_check = Some(r.to_string());
277 }
278
279 let tool_name = v.get("tool").and_then(|s| s.as_str()).map(String::from);
281 let tool_args = v.get("args").cloned();
282
283 if let Some(tool) = tool_name {
284 let record = crate::model::ToolCallRecord {
285 id: "legacy-v1".to_string(),
286 tool_name: tool,
287 args: tool_args.unwrap_or(serde_json::json!({})),
288 result: None,
289 error: None,
290 index: 0,
291 ts_ms: 0,
292 };
293 meta["tool_calls"] = serde_json::json!([record]);
294 } else if let Some(calls) = v.get("tool_calls").and_then(|v| v.as_array()) {
295 meta["tool_calls"] = serde_json::Value::Array(calls.clone());
297 }
298 }
299
300 if let (Some(p), Some(r)) = (prompt_opt, response_opt) {
301 if let Some(rid) = &request_id_check {
304 if request_ids.contains(rid) {
305 return Err(anyhow::anyhow!(
306 "line {}: Duplicate request_id {}",
307 i + 1,
308 rid
309 ));
310 }
311 request_ids.insert(rid.clone());
312 }
313
314 if traces.contains_key(&p) {
315 return Err(anyhow::anyhow!(
318 "Duplicate prompt found in trace file: {}",
319 p
320 ));
321 }
322
323 traces.insert(
324 p,
325 LlmResponse {
326 text: r,
327 meta,
328 model,
329 provider: "trace".to_string(),
330 ..Default::default()
331 },
332 );
333 }
334 }
335
336 for (id, state) in active_episodes {
338 if let (Some(p), Some(r)) = (state.input.clone(), state.output.clone()) {
339 if traces.contains_key(&p) {
342 eprintln!("Warning: Duplicate prompt skipped at EOF for id {}", id);
343 continue;
344 }
345 traces.insert(
346 p,
347 LlmResponse {
348 text: r,
349 meta: state.meta,
350 model: state.model.unwrap_or_else(|| "trace".to_string()),
351 provider: "trace".to_string(),
352 ..Default::default()
353 },
354 );
355 }
356 }
357
358 let mut keys: Vec<&String> = traces.keys().collect();
360 keys.sort();
361 let mut hasher = sha2::Sha256::new();
362 for k in keys {
363 use sha2::Digest;
364 hasher.update(k.as_bytes());
365 if let Some(v) = traces.get(k) {
366 hasher.update(v.text.as_bytes());
368 hasher.update(v.model.as_bytes());
370 }
371 }
372 let fingerprint = hex::encode(hasher.finalize());
373
374 Ok(Self {
375 traces: Arc::new(traces),
376 fingerprint,
377 })
378 }
379}
380
381#[async_trait]
382impl LlmClient for TraceClient {
383 async fn complete(
384 &self,
385 prompt: &str,
386 _context: Option<&[String]>,
387 ) -> anyhow::Result<LlmResponse> {
388 if let Some(resp) = self.traces.get(prompt) {
389 Ok(resp.clone())
390 } else {
391 let closest = closest_prompt(prompt, self.traces.keys());
393
394 let mut diag = Diagnostic::new(
395 codes::E_TRACE_MISS,
396 "Trace miss: prompt not found in loaded traces".to_string(),
397 )
398 .with_source("trace")
399 .with_context(serde_json::json!({
400 "prompt": prompt,
401 "closest_match": closest
402 }));
403
404 if let Some(match_) = closest {
405 diag = diag.with_fix_step(format!(
406 "Did you mean '{}'? (similarity: {:.2})",
407 match_.prompt, match_.similarity
408 ));
409 diag = diag.with_fix_step("Update your input prompt to match the trace exactly");
410 } else {
411 diag = diag.with_fix_step("No similar prompts found in trace file");
412 }
413
414 diag = diag.with_fix_step("Regenerate the trace file: assay trace ingest ...");
415
416 Err(anyhow::Error::new(diag))
417 }
418 }
419
420 fn provider_name(&self) -> &'static str {
421 "trace"
422 }
423
424 fn fingerprint(&self) -> Option<String> {
425 Some(self.fingerprint.clone())
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Two legacy v1 lines replay with the recorded text/model/provider.
    #[tokio::test]
    async fn test_trace_client_happy_path() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(
            tmp,
            r#"{{"prompt": "hello", "response": "world", "model": "gpt-4"}}"#
        )?;
        writeln!(tmp, r#"{{"prompt": "foo", "response": "bar"}}"#)?;

        let client = TraceClient::from_path(tmp.path())?;

        let resp1 = client.complete("hello", None).await?;
        assert_eq!(resp1.text, "world");
        assert_eq!(resp1.model, "gpt-4");

        let resp2 = client.complete("foo", None).await?;
        assert_eq!(resp2.text, "bar");
        assert_eq!(resp2.provider, "trace");
        Ok(())
    }

    /// An unrecorded prompt is an error (trace miss), not a fallback.
    #[tokio::test]
    async fn test_trace_client_miss() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(tmp, r#"{{"prompt": "exists", "response": "yes"}}"#)?;

        let client = TraceClient::from_path(tmp.path())?;
        let result = client.complete("does not exist", None).await;
        assert!(result.is_err());
        Ok(())
    }

    /// The same prompt twice in one file fails the load.
    #[tokio::test]
    async fn test_trace_client_duplicate_prompt() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(tmp, r#"{{"prompt": "dup", "response": "1"}}"#)?;
        writeln!(tmp, r#"{{"prompt": "dup", "response": "2"}}"#)?;

        let result = TraceClient::from_path(tmp.path());
        assert!(result.is_err());
        Ok(())
    }

    /// A repeated request_id fails the load with a descriptive message.
    #[tokio::test]
    async fn test_trace_client_duplicate_request_id() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        writeln!(
            tmp,
            r#"{{"request_id": "id1", "prompt": "p1", "response": "1"}}"#
        )?;
        writeln!(
            tmp,
            r#"{{"request_id": "id1", "prompt": "p2", "response": "2"}}"#
        )?;

        let result = TraceClient::from_path(tmp.path());
        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("Duplicate request_id"));
        Ok(())
    }

    /// Lines that carry no replayable (prompt, response) pair load cleanly
    /// but do not make the prompt available.
    #[tokio::test]
    async fn test_trace_schema_validation() -> anyhow::Result<()> {
        // No "response" field: the prompt is never stored.
        let mut tmp = NamedTempFile::new()?;
        writeln!(tmp, r#"{{"schema_version": 2, "prompt": "p"}}"#)?;
        let client = TraceClient::from_path(tmp.path())?;
        assert!(client.complete("p", None).await.is_err());

        // Unknown "type": the whole record is skipped.
        let mut tmp2 = NamedTempFile::new()?;
        writeln!(
            tmp2,
            r#"{{"type": "wrong", "prompt": "p", "response": "r"}}"#
        )?;
        let client = TraceClient::from_path(tmp2.path())?;
        assert!(client.complete("p", None).await.is_err());

        // Prompt only, no other fields: also not replayable.
        let mut tmp3 = NamedTempFile::new()?;
        writeln!(tmp3, r#"{{"prompt": "p"}}"#)?;
        let client = TraceClient::from_path(tmp3.path())?;
        assert!(client.complete("p", None).await.is_err());
        Ok(())
    }

    /// The `meta` object from an assay.trace record survives replay intact.
    #[tokio::test]
    async fn test_trace_meta_preservation() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        let json = r#"{"schema_version":1,"type":"assay.trace","request_id":"test-1","prompt":"Say hello","response":"Hello world","meta":{"assay":{"embeddings":{"model":"text-embedding-3-small","response":[0.1],"reference":[0.1]}}}}"#;
        writeln!(tmp, "{}", json)?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("Say hello", None).await?;

        println!("Meta from test: {}", resp.meta);
        assert!(
            resp.meta.pointer("/assay/embeddings/response").is_some(),
            "Meta embeddings missing!"
        );
        Ok(())
    }

    /// v2 precedence: the first model step's prompt wins, and a later
    /// `gen_ai.completion` meta value overrides earlier step outputs.
    #[tokio::test]
    async fn test_v2_replay_precedence() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        let ep_start = r#"{"type":"episode_start","episode_id":"e1","timestamp":100,"input":null}"#;
        let step1 = r#"{"type":"step","episode_id":"e1","step_id":"s1","kind":"model","timestamp":101,"content":"{\"prompt\":\"original_prompt\",\"completion\":\"output_1\"}"}"#;
        let step2 = r#"{"type":"step","episode_id":"e1","step_id":"s2","kind":"model","timestamp":102,"content":"{\"prompt\":\"ignored\",\"completion\":\"final_output\"}"}"#;
        let step3 = r#"{"type":"step","episode_id":"e1","step_id":"s3","kind":"model","timestamp":103,"content":null,"meta":{"gen_ai.completion":"meta_final"}}"#;

        let ep_end = r#"{"type":"episode_end","episode_id":"e1","timestamp":104}"#;

        writeln!(tmp, "{}", ep_start)?;
        writeln!(tmp, "{}", step1)?;
        writeln!(tmp, "{}", step2)?;
        writeln!(tmp, "{}", step3)?;
        writeln!(tmp, "{}", ep_end)?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("original_prompt", None).await?;
        assert_eq!(resp.text, "meta_final");

        Ok(())
    }

    /// An episode with no episode_end is still flushed at EOF.
    #[tokio::test]
    async fn test_eof_flush_partial_episode() -> anyhow::Result<()> {
        let mut tmp = NamedTempFile::new()?;
        let ep_start = r#"{"type":"episode_start","episode_id":"e_flush","timestamp":100,"input":{"prompt":"flush_me"}}"#;
        let step1 = r#"{"type":"step","episode_id":"e_flush","step_id":"s1","kind":"model","timestamp":101,"content":"{\"completion\":\"flushed_output\"}"}"#;

        writeln!(tmp, "{}", ep_start)?;
        writeln!(tmp, "{}", step1)?;

        let client = TraceClient::from_path(tmp.path())?;
        let resp = client.complete("flush_me", None).await?;
        assert_eq!(resp.text, "flushed_output");

        Ok(())
    }
}