Skip to main content

byokey_proxy/handler/amp/
threads.rs

1//! Amp thread reading — parses local Amp CLI thread JSON files.
2//!
3//! Reads thread JSON files from `~/.local/share/amp/threads/` into
4//! in-memory domain types. The management `ConnectRPC` service (see
5//! `handler::management`) converts these to proto types at its boundary.
6
7use arc_swap::ArcSwap;
8use serde::Deserialize;
9use serde_json::Value;
10use std::{
11    fs::{self, File},
12    io::BufReader,
13    path::PathBuf,
14    sync::Arc,
15};
16
17// ── Threads directory resolution ─────────────────────────────────────
18
/// Resolve the Amp threads directory.
///
/// Amp CLI uses `~/.local/share/amp/threads/` on both macOS and Linux
/// (XDG data dir, not `~/Library`).
///
/// `HOME` is read with `var_os`, so a home directory whose path is not valid
/// Unicode is still honoured.  When `HOME` is unset or empty the base
/// directory falls back to `/tmp` instead of producing a path relative to the
/// current working directory.
pub(crate) fn threads_dir() -> PathBuf {
    let home = std::env::var_os("HOME")
        // An empty HOME would otherwise yield the relative path
        // `.local/share/amp/threads`.
        .filter(|home| !home.is_empty())
        .map_or_else(|| PathBuf::from("/tmp"), PathBuf::from);
    home.join(".local/share/amp/threads")
}
27
/// Check that a thread ID is safe to use as a file name (no path traversal).
///
/// An ID is accepted when it consists of the literal prefix `T-` followed by
/// at least one character, every one of which is an ASCII hex digit or a
/// hyphen (UUID-style).
pub(crate) fn is_valid_thread_id(id: &str) -> bool {
    match id.strip_prefix("T-") {
        Some(suffix) if !suffix.is_empty() => suffix
            .bytes()
            .all(|b| b == b'-' || b.is_ascii_hexdigit()),
        _ => false,
    }
}
35
36// ── Internal deserialization types (camelCase, matching Amp JSON) ─────
37
38#[derive(Deserialize)]
39#[serde(rename_all = "camelCase")]
40struct RawThreadSummary {
41    id: String,
42    created: u64,
43    #[serde(default)]
44    title: Option<String>,
45    #[serde(default)]
46    messages: Vec<RawMessageStub>,
47    #[serde(default)]
48    agent_mode: Option<String>,
49}
50
51#[derive(Deserialize)]
52struct RawMessageStub {
53    role: String,
54    #[serde(default)]
55    usage: Option<RawUsageStub>,
56}
57
58#[derive(Deserialize)]
59#[serde(rename_all = "camelCase")]
60struct RawUsageStub {
61    model: Option<String>,
62    input_tokens: Option<u64>,
63    output_tokens: Option<u64>,
64}
65
66#[derive(Deserialize)]
67#[serde(rename_all = "camelCase")]
68struct RawThread {
69    v: u64,
70    id: String,
71    created: u64,
72    #[serde(default)]
73    title: Option<String>,
74    #[serde(default)]
75    messages: Vec<RawMessage>,
76    #[serde(default)]
77    agent_mode: Option<String>,
78    #[serde(default)]
79    relationships: Vec<RawRelationship>,
80    #[serde(default)]
81    env: Option<Value>,
82}
83
84#[derive(Deserialize)]
85#[serde(rename_all = "camelCase")]
86struct RawMessage {
87    role: String,
88    message_id: u64,
89    #[serde(default)]
90    content: Vec<Value>,
91    #[serde(default)]
92    usage: Option<RawUsage>,
93    #[serde(default)]
94    state: Option<RawMessageState>,
95}
96
97#[derive(Deserialize)]
98#[serde(rename_all = "camelCase")]
99struct RawUsage {
100    #[serde(default)]
101    model: Option<String>,
102    #[serde(default)]
103    input_tokens: Option<u64>,
104    #[serde(default)]
105    output_tokens: Option<u64>,
106    #[serde(default)]
107    cache_creation_input_tokens: Option<u64>,
108    #[serde(default)]
109    cache_read_input_tokens: Option<u64>,
110    #[serde(default)]
111    total_input_tokens: Option<u64>,
112}
113
114#[derive(Deserialize)]
115#[serde(rename_all = "camelCase")]
116struct RawMessageState {
117    #[serde(rename = "type")]
118    state_type: String,
119    #[serde(default)]
120    stop_reason: Option<String>,
121}
122
123#[derive(Deserialize)]
124#[serde(rename_all = "camelCase")]
125struct RawRelationship {
126    #[serde(rename = "threadID")]
127    thread_id: String,
128    #[serde(rename = "type")]
129    rel_type: String,
130    #[serde(default)]
131    role: Option<String>,
132}
133
134// ── Internal domain types (snake_case) ────────────────────────────────
135
/// Summary of a single Amp thread (excludes message bodies).
///
/// Implements `Debug` so summaries can be logged and used in assertions, and
/// `PartialEq`/`Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AmpThreadSummary {
    /// Thread identifier as stored in the thread file.
    pub id: String,
    /// Creation timestamp (Unix epoch milliseconds).
    pub created: u64,
    /// Thread title, if the file has one.
    pub title: Option<String>,
    /// Number of messages in the thread.
    pub message_count: usize,
    /// Agent mode recorded in the thread file, if any.
    pub agent_mode: Option<String>,
    /// Model used in the last assistant response.
    pub last_model: Option<String>,
    /// Sum of input tokens across all assistant turns; `None` when no
    /// assistant message carries usage data.
    pub total_input_tokens: Option<u64>,
    /// Sum of output tokens across all assistant turns; `None` when no
    /// assistant message carries usage data.
    pub total_output_tokens: Option<u64>,
    /// File size on disk (bytes).
    pub file_size_bytes: u64,
}
155
156/// Full Amp thread with all messages.
157pub struct AmpThreadDetail {
158    pub id: String,
159    /// Mutation counter (incremented on every thread change).
160    pub v: u64,
161    /// Creation timestamp (Unix epoch milliseconds).
162    pub created: u64,
163    pub title: Option<String>,
164    pub agent_mode: Option<String>,
165    pub messages: Vec<AmpMessage>,
166    pub relationships: Vec<AmpRelationship>,
167    /// Thread environment context (opaque JSON).
168    pub env: Option<Value>,
169}
170
171/// A single message within an Amp thread.
172pub struct AmpMessage {
173    /// `"user"`, `"assistant"`, or `"info"`.
174    pub role: String,
175    pub message_id: u64,
176    pub content: Vec<AmpContentBlock>,
177    pub usage: Option<AmpUsage>,
178    pub state: Option<AmpMessageState>,
179}
180
181/// A content block within a message.
182pub enum AmpContentBlock {
183    Text {
184        text: String,
185    },
186    Thinking {
187        thinking: String,
188    },
189    ToolUse {
190        id: String,
191        name: String,
192        input: Value,
193    },
194    ToolResult {
195        tool_use_id: String,
196        run: AmpToolRun,
197    },
198    /// Content block type not recognized by this parser.
199    Unknown {
200        original_type: Option<String>,
201    },
202}
203
204/// Tool execution result.
205pub struct AmpToolRun {
206    /// `"done"`, `"error"`, `"cancelled"`, `"rejected-by-user"`, or `"blocked-on-user"`.
207    pub status: String,
208    pub result: Option<Value>,
209    pub error: Option<Value>,
210}
211
212/// Token usage for an assistant turn.
213pub struct AmpUsage {
214    pub model: String,
215    pub input_tokens: Option<u64>,
216    pub output_tokens: Option<u64>,
217    pub cache_creation_input_tokens: Option<u64>,
218    pub cache_read_input_tokens: Option<u64>,
219    pub total_input_tokens: Option<u64>,
220}
221
222/// Assistant message state.
223pub struct AmpMessageState {
224    pub state_type: String,
225    pub stop_reason: Option<String>,
226}
227
228/// Relationship to another thread (handoff, fork, or mention).
229pub struct AmpRelationship {
230    pub thread_id: String,
231    /// `"handoff"`, `"fork"`, or `"mention"`.
232    pub rel_type: String,
233    /// `"parent"` or `"child"`.
234    pub role: Option<String>,
235}
236
237// ── Parsing logic ────────────────────────────────────────────────────
238
239fn parse_summary(path: &std::path::Path) -> Option<AmpThreadSummary> {
240    let file = File::open(path).ok()?;
241    let file_size = file.metadata().ok()?.len();
242    let raw: RawThreadSummary = serde_json::from_reader(BufReader::new(file)).ok()?;
243
244    let mut last_model: Option<String> = None;
245    let mut sum_input: u64 = 0;
246    let mut sum_output: u64 = 0;
247    let mut has_usage = false;
248
249    for msg in &raw.messages {
250        if msg.role == "assistant"
251            && let Some(u) = &msg.usage
252        {
253            if let Some(m) = &u.model {
254                last_model = Some(m.clone());
255            }
256            sum_input += u.input_tokens.unwrap_or(0);
257            sum_output += u.output_tokens.unwrap_or(0);
258            has_usage = true;
259        }
260    }
261
262    Some(AmpThreadSummary {
263        message_count: raw.messages.len(),
264        id: raw.id,
265        created: raw.created,
266        title: raw.title,
267        agent_mode: raw.agent_mode,
268        last_model,
269        total_input_tokens: has_usage.then_some(sum_input),
270        total_output_tokens: has_usage.then_some(sum_output),
271        file_size_bytes: file_size,
272    })
273}
274
275/// Convert a raw JSON `Value` content block into a typed `AmpContentBlock`.
276fn convert_content_block(v: &Value) -> AmpContentBlock {
277    let block_type = v.get("type").and_then(Value::as_str).unwrap_or("");
278    match block_type {
279        "text" => AmpContentBlock::Text {
280            text: v
281                .get("text")
282                .and_then(Value::as_str)
283                .unwrap_or("")
284                .to_string(),
285        },
286        "thinking" | "redacted_thinking" => AmpContentBlock::Thinking {
287            thinking: v
288                .get("thinking")
289                .or_else(|| v.get("data"))
290                .and_then(Value::as_str)
291                .unwrap_or("")
292                .to_string(),
293        },
294        "tool_use" => AmpContentBlock::ToolUse {
295            id: v
296                .get("id")
297                .and_then(Value::as_str)
298                .unwrap_or("")
299                .to_string(),
300            name: v
301                .get("name")
302                .and_then(Value::as_str)
303                .unwrap_or("")
304                .to_string(),
305            input: v.get("input").cloned().unwrap_or(Value::Null),
306        },
307        "tool_result" => {
308            let run_val = v.get("run");
309            AmpContentBlock::ToolResult {
310                tool_use_id: v
311                    .get("toolUseID")
312                    .and_then(Value::as_str)
313                    .unwrap_or("")
314                    .to_string(),
315                run: AmpToolRun {
316                    status: run_val
317                        .and_then(|r| r.get("status"))
318                        .and_then(Value::as_str)
319                        .unwrap_or("unknown")
320                        .to_string(),
321                    result: run_val.and_then(|r| r.get("result")).cloned(),
322                    error: run_val.and_then(|r| r.get("error")).cloned(),
323                },
324            }
325        }
326        _ => AmpContentBlock::Unknown {
327            original_type: Some(block_type.to_string()),
328        },
329    }
330}
331
332fn convert_message(raw: RawMessage) -> AmpMessage {
333    AmpMessage {
334        role: raw.role,
335        message_id: raw.message_id,
336        content: raw.content.iter().map(convert_content_block).collect(),
337        usage: raw.usage.map(|u| AmpUsage {
338            model: u.model.unwrap_or_default(),
339            input_tokens: u.input_tokens,
340            output_tokens: u.output_tokens,
341            cache_creation_input_tokens: u.cache_creation_input_tokens,
342            cache_read_input_tokens: u.cache_read_input_tokens,
343            total_input_tokens: u.total_input_tokens,
344        }),
345        state: raw.state.map(|s| AmpMessageState {
346            state_type: s.state_type,
347            stop_reason: s.stop_reason,
348        }),
349    }
350}
351
352pub(crate) fn parse_detail(path: &std::path::Path) -> Result<AmpThreadDetail, String> {
353    let file = File::open(path).map_err(|e| e.to_string())?;
354    let raw: RawThread =
355        serde_json::from_reader(BufReader::new(file)).map_err(|e| e.to_string())?;
356
357    Ok(AmpThreadDetail {
358        id: raw.id,
359        v: raw.v,
360        created: raw.created,
361        title: raw.title,
362        agent_mode: raw.agent_mode,
363        messages: raw.messages.into_iter().map(convert_message).collect(),
364        relationships: raw
365            .relationships
366            .into_iter()
367            .map(|r| AmpRelationship {
368                thread_id: r.thread_id,
369                rel_type: r.rel_type,
370                role: r.role,
371            })
372            .collect(),
373        env: raw.env,
374    })
375}
376
377// ── In-memory thread index with file watching ───────────────────────
378
379/// Pre-sorted, in-memory index of all Amp thread summaries.
380///
381/// Built once at startup by scanning `~/.local/share/amp/threads/`, then
382/// kept up-to-date via `notify` file-system events.  The inner `ArcSwap`
383/// allows lock-free reads from handlers while the watcher task atomically
384/// swaps in a new snapshot on every change.
385pub struct AmpThreadIndex {
386    summaries: ArcSwap<Vec<AmpThreadSummary>>,
387}
388
389impl AmpThreadIndex {
390    /// Build the initial index by scanning the threads directory.
391    ///
392    /// This performs synchronous filesystem I/O and should be called from
393    /// within `spawn_blocking` or at startup before the server binds.
394    #[must_use]
395    pub fn build() -> Self {
396        let summaries = scan_all_summaries();
397        Self {
398            summaries: ArcSwap::from_pointee(summaries),
399        }
400    }
401
402    /// Create an empty index (for tests or when the directory is absent).
403    #[must_use]
404    pub fn empty() -> Self {
405        Self {
406            summaries: ArcSwap::from_pointee(Vec::new()),
407        }
408    }
409
410    /// Return a snapshot of all cached summaries (sorted by `created` desc).
411    pub fn list(&self) -> arc_swap::Guard<Arc<Vec<AmpThreadSummary>>> {
412        self.summaries.load()
413    }
414
415    /// Start background file watching.
416    ///
417    /// Watches `~/.local/share/amp/threads/` for create / modify / remove
418    /// events and rebuilds the index on each change.  Events are debounced
419    /// (500 ms) so rapid writes from Amp don't cause redundant re-scans.
420    ///
421    /// # Panics
422    ///
423    /// Panics if the OS file watcher cannot be created or the directory
424    /// cannot be registered for watching.
425    pub fn watch(self: &Arc<Self>) {
426        use notify::{RecursiveMode, Watcher as _};
427
428        let index = Arc::clone(self);
429        let dir = threads_dir();
430
431        tokio::task::spawn_blocking(move || {
432            if !dir.is_dir() {
433                tracing::debug!(path = %dir.display(), "amp threads dir not found, skipping watch");
434                return;
435            }
436
437            let (tx, rx) = std::sync::mpsc::channel();
438
439            let mut watcher =
440                notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
441                    if let Ok(ev) = res {
442                        // Only react to JSON file changes.
443                        let dominated_by_json = ev.paths.iter().any(|p| {
444                            p.extension()
445                                .is_some_and(|e| e.eq_ignore_ascii_case("json"))
446                        });
447                        if dominated_by_json {
448                            let _ = tx.send(());
449                        }
450                    }
451                })
452                .expect("failed to create file watcher");
453
454            watcher
455                .watch(&dir, RecursiveMode::NonRecursive)
456                .expect("failed to watch amp threads directory");
457
458            tracing::info!(path = %dir.display(), "watching amp threads directory");
459
460            // Debounce: drain all pending signals, then rebuild once.
461            while rx.recv().is_ok() {
462                // Drain any events that arrived while we were scanning.
463                while rx.try_recv().is_ok() {}
464
465                // Small delay to let Amp finish writing.
466                std::thread::sleep(std::time::Duration::from_millis(500));
467
468                // Drain again after the delay.
469                while rx.try_recv().is_ok() {}
470
471                let new = scan_all_summaries();
472                tracing::debug!(count = new.len(), "amp thread index rebuilt");
473                index.summaries.store(Arc::new(new));
474            }
475        });
476    }
477}
478
479/// Scan the threads directory and return all parseable summaries, sorted
480/// by `created` descending (newest first).
481fn scan_all_summaries() -> Vec<AmpThreadSummary> {
482    let dir = threads_dir();
483    let Ok(entries) = fs::read_dir(&dir) else {
484        return Vec::new();
485    };
486
487    let mut summaries: Vec<AmpThreadSummary> = entries
488        .filter_map(|entry| {
489            let entry = entry.ok()?;
490            let name = entry.file_name().to_string_lossy().to_string();
491            if !name.starts_with("T-")
492                || !std::path::Path::new(&name)
493                    .extension()
494                    .is_some_and(|ext| ext.eq_ignore_ascii_case("json"))
495            {
496                return None;
497            }
498            parse_summary(&entry.path())
499        })
500        .collect();
501
502    summaries.sort_unstable_by_key(|s| std::cmp::Reverse(s.created));
503    summaries
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Well-formed IDs: `T-` followed by hex digits and hyphens.
    #[test]
    fn valid_thread_ids() {
        for id in [
            "T-019d38dd-45f9-7617-8e7f-03b730ba197a",
            "T-fc68e9f5-9621-4ee2-b8d9-d954ba656de4",
            "T-abcdef0123456789",
        ] {
            assert!(is_valid_thread_id(id));
        }
    }

    /// Empty, prefix-only, traversal and non-hex IDs are all rejected.
    #[test]
    fn invalid_thread_ids() {
        for id in [
            "",
            "T-",
            "../etc/passwd",
            "T-../../foo",
            "T-abc def",
            "not-a-thread",
        ] {
            assert!(!is_valid_thread_id(id));
        }
    }

    /// A thread with no messages and no title still deserializes.
    #[test]
    fn parse_empty_thread_json() {
        let source =
            r#"{"v":0,"id":"T-test-1234","created":1711728000000,"messages":[],"nextMessageId":0}"#;
        let summary: RawThreadSummary = serde_json::from_str(source).unwrap();
        assert_eq!(summary.id, "T-test-1234");
        assert!(summary.messages.is_empty());
        assert!(summary.title.is_none());
    }

    /// Full round trip: raw JSON to `RawThread` to domain types.
    #[test]
    fn parse_thread_with_messages() {
        let document = json!({
            "v": 5,
            "id": "T-test-5678",
            "created": 1_711_728_000_000_u64,
            "messages": [
                {
                    "role": "user",
                    "messageId": 0,
                    "content": [{"type": "text", "text": "hello"}]
                },
                {
                    "role": "assistant",
                    "messageId": 1,
                    "content": [
                        {"type": "thinking", "thinking": "hmm", "signature": "sig"},
                        {"type": "tool_use", "id": "toolu_01", "name": "Bash", "input": {"cmd": "ls"}, "complete": true},
                    ],
                    "usage": {
                        "model": "claude-opus-4-6",
                        "inputTokens": 100,
                        "outputTokens": 50,
                        "cacheCreationInputTokens": 10,
                        "cacheReadInputTokens": 5,
                        "totalInputTokens": 115
                    },
                    "state": {"type": "complete", "stopReason": "tool_use"}
                },
                {
                    "role": "user",
                    "messageId": 2,
                    "content": [{
                        "type": "tool_result",
                        "toolUseID": "toolu_01",
                        "run": {"status": "done", "result": {"output": "file.txt", "exitCode": 0}}
                    }]
                }
            ],
            "agentMode": "smart",
            "title": "Test thread",
            "nextMessageId": 3
        });

        let parsed: RawThread = serde_json::from_value(document).unwrap();
        assert_eq!(parsed.messages.len(), 3);
        assert_eq!(parsed.agent_mode.as_deref(), Some("smart"));

        // Test full conversion.
        let messages: Vec<AmpMessage> =
            parsed.messages.into_iter().map(convert_message).collect();
        let detail = AmpThreadDetail {
            id: parsed.id,
            v: parsed.v,
            created: parsed.created,
            title: parsed.title,
            agent_mode: parsed.agent_mode,
            messages,
            relationships: Vec::new(),
            env: None,
        };

        assert_eq!(detail.messages.len(), 3);
        assert_eq!(detail.messages[0].role, "user");
        assert_eq!(detail.messages[1].role, "assistant");
        assert!(detail.messages[1].usage.is_some());

        let usage = detail.messages[1].usage.as_ref().unwrap();
        assert_eq!(usage.model, "claude-opus-4-6");
        assert_eq!(usage.input_tokens, Some(100));
        assert_eq!(usage.output_tokens, Some(50));

        // Verify content blocks.
        let assistant_blocks = &detail.messages[1].content;
        assert!(matches!(
            &assistant_blocks[0],
            AmpContentBlock::Thinking { .. }
        ));
        assert!(
            matches!(&assistant_blocks[1], AmpContentBlock::ToolUse { name, .. } if name == "Bash")
        );
        assert!(matches!(
            &detail.messages[2].content[0],
            AmpContentBlock::ToolResult { .. }
        ));
    }

    /// Unrecognized block types are preserved by name in `Unknown`.
    #[test]
    fn convert_unknown_content_block() {
        let converted = convert_content_block(&json!({"type": "some_future_type", "data": 42}));
        assert!(
            matches!(converted, AmpContentBlock::Unknown { original_type: Some(t) } if t == "some_future_type")
        );
    }

    /// Ensure `RawThreadSummary` doesn't fail on extra fields (content, env, etc.).
    #[test]
    fn summary_deserialization_skips_heavy_fields() {
        let document = json!({
            "v": 100,
            "id": "T-skip-test",
            "created": 1_711_728_000_000_u64,
            "messages": [{
                "role": "user",
                "messageId": 0,
                "content": [{"type": "text", "text": "this should be skipped by summary parser"}],
                "userState": {"activeEditor": "foo.rs"},
                "fileMentions": {"files": []}
            }],
            "nextMessageId": 1,
            "env": {"initial": {"platform": {"os": "darwin"}}},
            "meta": {"traces": []},
            "~debug": {"something": true}
        });

        let summary: RawThreadSummary = serde_json::from_value(document).unwrap();
        assert_eq!(summary.id, "T-skip-test");
        assert_eq!(summary.messages.len(), 1);
    }
}
655}