Skip to main content

edda_transcript/
filter.rs

1use serde_json::Value;
2use std::collections::HashMap;
3
/// Clamp `i` down to the nearest UTF-8 character boundary of `s`.
///
/// Returns `s.len()` when `i` is at or beyond the end of the string; otherwise
/// returns the largest byte index not exceeding `i` at which `s` may be split.
/// Hand-rolled equivalent of `str::floor_char_boundary` for toolchains where
/// that method is unavailable.
fn floor_char_boundary(s: &str, i: usize) -> usize {
    let total = s.len();
    if i >= total {
        return total;
    }
    // Walk backwards from `i`; index 0 is always a boundary, so the search
    // always finds something and the fallback is never actually needed.
    (0..=i)
        .rev()
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(0)
}
16
/// Outcome of classifying a single transcript record with `classify_record`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FilterAction {
    /// Returned for `user`, `assistant`, `file-history-snapshot` and
    /// `queue-operation` records, and for `system` records whose subtype is
    /// anything other than `turn_duration`.
    Keep,
    /// Returned for `progress` records.
    Progress,
    /// Returned for `system` records with subtype `turn_duration` and for
    /// records whose type is missing or not recognised.
    Drop,
}
23
24/// Classify a transcript JSONL record.
25pub fn classify_record(json: &Value) -> FilterAction {
26    let record_type = json.get("type").and_then(|v| v.as_str()).unwrap_or("");
27
28    match record_type {
29        "user" | "assistant" => FilterAction::Keep,
30        "system" => {
31            let subtype = json.get("subtype").and_then(|v| v.as_str()).unwrap_or("");
32            if subtype == "turn_duration" {
33                FilterAction::Drop
34            } else {
35                FilterAction::Keep
36            }
37        }
38        "progress" => FilterAction::Progress,
39        "file-history-snapshot" | "queue-operation" => FilterAction::Keep,
40        _ => FilterAction::Drop,
41    }
42}
43
44/// Progress Strategy 3: per-toolUseID, keep only the latest record.
45/// Truncate data.output to max chars and limit total entries.
46pub fn update_progress_last(progress_map: &mut HashMap<String, Value>, record: &Value) {
47    let tool_use_id = record
48        .get("toolUseID")
49        .or_else(|| record.get("tool_use_id"))
50        .and_then(|v| v.as_str())
51        .unwrap_or("")
52        .to_string();
53
54    if tool_use_id.is_empty() {
55        return;
56    }
57
58    let max_output_chars: usize = std::env::var("EDDA_PROGRESS_OUTPUT_CHARS")
59        .ok()
60        .and_then(|v| v.parse().ok())
61        .unwrap_or(600);
62
63    let max_tools: usize = std::env::var("EDDA_PROGRESS_MAX_TOOLS")
64        .ok()
65        .and_then(|v| v.parse().ok())
66        .unwrap_or(200);
67
68    // Truncate data.output if present
69    let mut record = record.clone();
70    if let Some(data) = record.get_mut("data") {
71        let needs_truncate = data
72            .get("output")
73            .and_then(|v| v.as_str())
74            .map(|s| s.len() > max_output_chars)
75            .unwrap_or(false);
76        if needs_truncate {
77            let output = data["output"].as_str().unwrap();
78            // Find a valid char boundary at or before max_output_chars
79            let end = floor_char_boundary(output, max_output_chars);
80            let truncated = output[..end].to_string();
81            data.as_object_mut()
82                .unwrap()
83                .insert("output".into(), Value::String(truncated));
84        }
85    }
86
87    progress_map.insert(tool_use_id, record);
88
89    // Enforce map size limit
90    while progress_map.len() > max_tools {
91        // Remove oldest (arbitrary key since HashMap is unordered)
92        if let Some(key) = progress_map.keys().next().cloned() {
93            progress_map.remove(&key);
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    // --- classify_record -------------------------------------------------

    #[test]
    fn classify_user_keep() {
        let record = serde_json::json!({"type": "user", "message": "hello"});
        let action = classify_record(&record);
        assert_eq!(action, FilterAction::Keep);
    }

    #[test]
    fn classify_assistant_keep() {
        let record = serde_json::json!({"type": "assistant"});
        let action = classify_record(&record);
        assert_eq!(action, FilterAction::Keep);
    }

    #[test]
    fn classify_system_keep() {
        let record = serde_json::json!({"type": "system", "subtype": "init"});
        let action = classify_record(&record);
        assert_eq!(action, FilterAction::Keep);
    }

    #[test]
    fn classify_system_turn_duration_drop() {
        let record = serde_json::json!({"type": "system", "subtype": "turn_duration"});
        let action = classify_record(&record);
        assert_eq!(action, FilterAction::Drop);
    }

    #[test]
    fn classify_progress() {
        let record = serde_json::json!({"type": "progress"});
        let action = classify_record(&record);
        assert_eq!(action, FilterAction::Progress);
    }

    #[test]
    fn classify_unknown_drop() {
        let record = serde_json::json!({"type": "unknown_type"});
        let action = classify_record(&record);
        assert_eq!(action, FilterAction::Drop);
    }

    // --- update_progress_last / floor_char_boundary ----------------------

    #[test]
    fn truncate_respects_char_boundary() {
        // 598 ASCII bytes followed by three 3-byte '後' characters: the first
        // '後' occupies bytes 598..601, so byte index 600 falls inside it.
        let ascii_len = 598;
        let output = format!("{}後後後 tail", "x".repeat(ascii_len));
        assert!(!output.is_char_boundary(600)); // confirm the setup

        let record = serde_json::json!({
            "toolUseID": "t_utf8",
            "data": { "output": output }
        });

        // With the default limit of 600 this must not panic.
        let mut progress = HashMap::new();
        update_progress_last(&mut progress, &record);

        // The cut backs up to byte 598, leaving only the ASCII prefix.
        let stored = progress["t_utf8"]["data"]["output"].as_str().unwrap();
        assert_eq!(stored.len(), 598);
        assert!(stored.chars().all(|c| c == 'x'));
    }

    #[test]
    fn floor_char_boundary_basic() {
        // "ab後cd": 'a'=0, 'b'=1, '後'=2..5, 'c'=5, 'd'=6
        let cases = [
            ("hello", 3, 3),
            ("hello", 100, 5),  // past the end clamps to len
            ("ab後cd", 3, 2),   // mid-'後' backs up to 2
            ("ab後cd", 4, 2),   // still mid-'後'
            ("ab後cd", 5, 5),   // at 'c', already a boundary
        ];
        for &(text, index, expected) in cases.iter() {
            assert_eq!(super::floor_char_boundary(text, index), expected);
        }
    }

    #[test]
    fn progress_last_keeps_latest() {
        let older = serde_json::json!({"toolUseID": "t1", "data": {"output": "old"}});
        let newer = serde_json::json!({"toolUseID": "t1", "data": {"output": "new"}});

        let mut progress = HashMap::new();
        for record in [&older, &newer].iter() {
            update_progress_last(&mut progress, record);
        }

        assert_eq!(progress.len(), 1);
        assert_eq!(progress["t1"]["data"]["output"], "new");
    }
}