use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex as StdMutex};
use tokio::sync::{mpsc, Mutex};
use crate::permissions::PermissionGate;
/// A piece of incremental output a tool reports while it runs.
///
/// Chunks are sent through [`ToolCtx::progress_tx`].
#[derive(Clone, Debug)]
pub enum ToolProgressChunk {
    /// Raw bytes tagged as standard output.
    Stdout(Vec<u8>),
    /// Raw bytes tagged as standard error.
    Stderr(Vec<u8>),
    /// A free-form textual status update.
    Status(String),
}
/// A cloneable handle to a cancellation token that can be swapped out.
///
/// Every clone points at the same `Arc<StdMutex<_>>` slot. Replacing the token
/// through one clone is therefore visible to all the others.
#[derive(Clone)]
pub struct SharedCancelToken {
    /// Shared slot holding the token currently in effect.
    inner: Arc<StdMutex<motosan_agent_loop::CancellationToken>>,
}
impl SharedCancelToken {
    /// Creates a shared handle holding a fresh token.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(StdMutex::new(motosan_agent_loop::CancellationToken::new())),
        }
    }

    /// Replaces the stored token with a new one and returns a clone of it.
    ///
    /// All clones of this handle see the new token, because they share the slot.
    /// The previous token is simply overwritten, not cancelled. Anything that
    /// captured it earlier, such as a pending [`cancelled`](Self::cancelled) future,
    /// keeps waiting on the old token and is unaffected by later
    /// [`cancel`](Self::cancel) calls.
    pub fn reset(&self) -> motosan_agent_loop::CancellationToken {
        let mut guard = self.lock_slot();
        *guard = motosan_agent_loop::CancellationToken::new();
        guard.clone()
    }

    /// Cancels the token that is current when this is called.
    pub fn cancel(&self) {
        self.current().cancel();
    }

    /// Waits until the token that is current when this is called gets cancelled.
    ///
    /// The token is cloned out of the slot before awaiting, so the std mutex is never
    /// held across the `.await`. If [`reset`](Self::reset) runs while this future is
    /// pending, cancelling the new token will not wake this waiter.
    pub async fn cancelled(&self) {
        self.current().cancelled().await;
    }

    /// Returns a clone of the token currently stored in the shared slot.
    fn current(&self) -> motosan_agent_loop::CancellationToken {
        self.lock_slot().clone()
    }

    /// Locks the shared slot, recovering the guard if a previous holder panicked.
    ///
    /// The only mutation of the slot is a whole-value assignment in `reset`.
    /// Continuing after poisoning therefore matches the original best-effort
    /// behaviour instead of propagating the panic.
    fn lock_slot(&self) -> std::sync::MutexGuard<'_, motosan_agent_loop::CancellationToken> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl Default for SharedCancelToken {
fn default() -> Self {
Self::new()
}
}
/// Per-tool execution context.
///
/// Cloning is cheap. Clones share the same read-file set, permission gate and
/// cancel-token slot, and hold another sender for the same progress channel.
#[derive(Clone)]
pub struct ToolCtx {
    /// Working directory the context was constructed with.
    pub cwd: PathBuf,
    /// Paths recorded via `mark_read`, shared across clones.
    pub read_files: Arc<Mutex<HashSet<PathBuf>>>,
    /// Permission gate handed in by the caller.
    // NOTE(review): presumably consulted by tools before side-effecting operations — confirm.
    pub permission_gate: Arc<dyn PermissionGate>,
    /// Channel for streaming [`ToolProgressChunk`]s while a tool runs.
    pub progress_tx: mpsc::Sender<ToolProgressChunk>,
    /// Resettable cancellation handle.
    pub cancel_token: SharedCancelToken,
}
impl ToolCtx {
pub fn new(
cwd: impl AsRef<Path>,
permission_gate: Arc<dyn PermissionGate>,
progress_tx: mpsc::Sender<ToolProgressChunk>,
) -> Self {
Self::new_with_cancel_token(cwd, permission_gate, progress_tx, SharedCancelToken::new())
}
pub fn new_with_cancel_token(
cwd: impl AsRef<Path>,
permission_gate: Arc<dyn PermissionGate>,
progress_tx: mpsc::Sender<ToolProgressChunk>,
cancel_token: SharedCancelToken,
) -> Self {
Self {
cwd: cwd.as_ref().to_path_buf(),
read_files: Arc::new(Mutex::new(HashSet::new())),
permission_gate,
progress_tx,
cancel_token,
}
}
pub async fn mark_read(&self, path: &Path) {
self.read_files.lock().await.insert(path.to_path_buf());
}
pub async fn has_been_read(&self, path: &Path) -> bool {
self.read_files.lock().await.contains(path)
}
}