capo_agent/tools/
context.rs1use std::collections::HashSet;
2use std::path::{Path, PathBuf};
3use std::sync::{Arc, Mutex as StdMutex};
4
5use tokio::sync::{mpsc, Mutex};
6
7use crate::permissions::PermissionGate;
8
/// One piece of incremental progress output produced while a tool runs.
#[derive(Clone, Debug)]
pub enum ToolProgressChunk {
    /// A chunk of bytes attributed to standard output.
    Stdout(Vec<u8>),
    /// A chunk of bytes attributed to standard error.
    Stderr(Vec<u8>),
    /// A free-form textual status update.
    Status(String),
}
15
16#[derive(Clone)]
17pub struct SharedCancelToken {
18 inner: Arc<StdMutex<motosan_agent_loop::CancellationToken>>,
19}
20
21impl SharedCancelToken {
22 pub fn new() -> Self {
23 Self {
24 inner: Arc::new(StdMutex::new(motosan_agent_loop::CancellationToken::new())),
25 }
26 }
27
28 pub fn reset(&self) -> motosan_agent_loop::CancellationToken {
29 let mut guard = match self.inner.lock() {
30 Ok(guard) => guard,
31 Err(poisoned) => poisoned.into_inner(),
32 };
33 *guard = motosan_agent_loop::CancellationToken::new();
34 guard.clone()
35 }
36
37 pub fn cancel(&self) {
38 self.current().cancel();
39 }
40
41 pub async fn cancelled(&self) {
42 self.current().cancelled().await;
43 }
44
45 fn current(&self) -> motosan_agent_loop::CancellationToken {
46 match self.inner.lock() {
47 Ok(guard) => guard.clone(),
48 Err(poisoned) => poisoned.into_inner().clone(),
49 }
50 }
51}
52
53impl Default for SharedCancelToken {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59#[derive(Clone)]
60pub struct ToolCtx {
61 pub cwd: PathBuf,
62 pub read_files: Arc<Mutex<HashSet<PathBuf>>>,
63 pub permission_gate: Arc<dyn PermissionGate>,
64 pub progress_tx: mpsc::Sender<ToolProgressChunk>,
65 pub cancel_token: SharedCancelToken,
66}
67
68impl ToolCtx {
69 pub fn new(
70 cwd: impl AsRef<Path>,
71 permission_gate: Arc<dyn PermissionGate>,
72 progress_tx: mpsc::Sender<ToolProgressChunk>,
73 ) -> Self {
74 Self::new_with_cancel_token(cwd, permission_gate, progress_tx, SharedCancelToken::new())
75 }
76
77 pub fn new_with_cancel_token(
78 cwd: impl AsRef<Path>,
79 permission_gate: Arc<dyn PermissionGate>,
80 progress_tx: mpsc::Sender<ToolProgressChunk>,
81 cancel_token: SharedCancelToken,
82 ) -> Self {
83 Self {
84 cwd: cwd.as_ref().to_path_buf(),
85 read_files: Arc::new(Mutex::new(HashSet::new())),
86 permission_gate,
87 progress_tx,
88 cancel_token,
89 }
90 }
91
92 pub async fn mark_read(&self, path: &Path) {
98 self.read_files.lock().await.insert(path.to_path_buf());
99 }
100
101 pub async fn has_been_read(&self, path: &Path) -> bool {
102 self.read_files.lock().await.contains(path)
103 }
104}