use crate::builtin::{BuiltinTool, BuiltinToolError, ToolOutput};
use async_trait::async_trait;
use meerkat_core::ToolDef;
use meerkat_core::time_compat::{Duration, Instant};
use meerkat_core::wait_interrupt::WaitInterruptReceiver;
use serde::Deserialize;
use serde_json::{Value, json};
pub use meerkat_core::wait_interrupt::WaitInterrupt;
/// Upper bound, in seconds, that every requested wait is clamped to.
const MAX_WAIT_SECONDS: f64 = 60.0_f64;
/// Built-in `wait` tool: pauses execution for a bounded number of seconds.
///
/// When built with [`WaitTool::with_interrupt`], a wait can end early once an
/// interrupt is published on the attached receiver.
#[derive(Clone, Debug)]
pub struct WaitTool {
    /// Source of early-wake signals; `None` means every wait runs to completion.
    interrupt_rx: Option<WaitInterruptReceiver>,
}
impl WaitTool {
pub fn new() -> Self {
Self { interrupt_rx: None }
}
pub fn with_interrupt(rx: WaitInterruptReceiver) -> Self {
Self {
interrupt_rx: Some(rx),
}
}
pub fn interrupt_receiver(&self) -> Option<WaitInterruptReceiver> {
self.interrupt_rx.clone()
}
}
impl Default for WaitTool {
fn default() -> Self {
Self::new()
}
}
// Arguments accepted by the `wait` tool. Its JSON schema is produced by
// `crate::schema::schema_for::<WaitArgs>()` and published as the tool's input schema.
// NOTE: plain `//` comments on purpose. The schemars derive reads `///` doc comments
// into the generated schema (title/description), which would change what is published.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
struct WaitArgs {
    // Requested wait in seconds. Deserialization does not enforce the advertised
    // 0.1..=60 range; `call` clamps the value to 0.0..=MAX_WAIT_SECONDS instead,
    // so e.g. 0.0 or negative values are accepted and become a zero-length wait.
    #[schemars(
        description = "Number of seconds to wait (0.1 to 60)",
        range(min = 0.1, max = 60.0)
    )]
    seconds: f64,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl BuiltinTool for WaitTool {
    fn name(&self) -> &'static str {
        "wait"
    }

    fn def(&self) -> ToolDef {
        ToolDef {
            name: "wait".into(),
            description: "Pause execution for the specified number of seconds. Use this to wait between status checks on async operations like delegated work or long-running background tasks. Wait is interrupted early when peer messages arrive. Maximum wait time is 60 seconds (1 minute).".into(),
            input_schema: crate::schema::schema_for::<WaitArgs>(),
        }
    }

    fn default_enabled(&self) -> bool {
        true
    }

    /// Sleeps for `seconds`, clamped to `0.0..=MAX_WAIT_SECONDS`.
    ///
    /// With an interrupt receiver attached, the wait ends early only when a
    /// `Some(WaitInterrupt)` is published after the call starts. In that case the
    /// result has `"status": "interrupted"`, the actual elapsed time and the reason.
    /// Otherwise the full duration elapses and `"status": "complete"` is returned:
    /// - a reset to `None` does not end the wait;
    /// - if all senders are dropped, the wait still runs to its deadline.
    ///
    /// # Errors
    /// Returns `BuiltinToolError::invalid_args` when `args` does not deserialize
    /// into `WaitArgs`.
    async fn call(&self, args: Value) -> Result<ToolOutput, BuiltinToolError> {
        #[cfg(target_arch = "wasm32")]
        use crate::tokio::time::sleep;
        #[cfg(not(target_arch = "wasm32"))]
        use tokio::time::sleep;

        let args: WaitArgs = serde_json::from_value(args)
            .map_err(|e| BuiltinToolError::invalid_args(format!("Invalid arguments: {e}")))?;
        let seconds = args.seconds.clamp(0.0, MAX_WAIT_SECONDS);
        let duration = Duration::from_secs_f64(seconds);
        let start = Instant::now();

        if let Some(ref rx) = self.interrupt_rx {
            let mut rx = rx.clone();
            // Mark the current channel value as seen so an interrupt published
            // before this call does not cut this wait short.
            rx.borrow_and_update();

            // Pinned once, outside the loop, so the deadline stays fixed even if we
            // have to go back to waiting after a non-interrupt channel update.
            let sleep_fut = sleep(duration);
            futures::pin_mut!(sleep_fut);

            loop {
                // `None` = the timer elapsed; `Some(result)` = `changed()` finished
                // first. Scoped so the `changed()` future, which mutably borrows `rx`,
                // is dropped before `rx` is read below.
                let changed = {
                    let changed_fut = rx.changed();
                    futures::pin_mut!(changed_fut);
                    match futures::future::select(sleep_fut.as_mut(), changed_fut).await {
                        futures::future::Either::Left(_) => None,
                        futures::future::Either::Right((result, _)) => Some(result),
                    }
                };

                match changed {
                    None => break,
                    Some(Ok(())) => {
                        // Copy the reason out so the watch borrow guard is released at
                        // the end of this statement and never held near an await.
                        let reason = rx
                            .borrow_and_update()
                            .as_ref()
                            .map(|interrupt| interrupt.reason.clone());
                        if let Some(reason) = reason {
                            let waited = start.elapsed().as_secs_f64();
                            return Ok(ToolOutput::Json(json!({
                                "waited_seconds": waited,
                                "requested_seconds": seconds,
                                "status": "interrupted",
                                "reason": format!("Wait interrupted after {:.1}s: {}", waited, reason)
                            })));
                        }
                        // The value was reset to `None`; that is not an interrupt, so
                        // keep waiting until the original deadline.
                    }
                    Some(Err(_)) => {
                        // Every sender is gone, so no interrupt can arrive any more.
                        // Finish the remaining wait instead of returning early while
                        // claiming the full duration elapsed.
                        sleep_fut.as_mut().await;
                        break;
                    }
                }
            }
        } else {
            sleep(duration).await;
        }

        Ok(ToolOutput::Json(json!({
            "waited_seconds": seconds,
            "status": "complete"
        })))
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn test_wait_tool_name() {
        let tool = WaitTool::new();
        assert_eq!(tool.name(), "wait");
    }

    #[tokio::test]
    async fn test_wait_tool_interrupted_by_message() {
        let (tx, rx) = tokio::sync::watch::channel(None::<WaitInterrupt>);
        let tool = WaitTool::with_interrupt(rx);
        let start = Instant::now();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            let _ = tx.send(Some(WaitInterrupt {
                reason: "Received message from delegated worker: Task completed".to_string(),
            }));
        });
        let result = tool
            .call(json!({"seconds": 10.0}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(1),
            "Should be interrupted quickly"
        );
        assert_eq!(result["status"], "interrupted");
        assert!(result["waited_seconds"].as_f64().unwrap() < 1.0);
        assert!(
            result["reason"]
                .as_str()
                .unwrap()
                .contains("delegated worker")
        );
    }

    #[tokio::test]
    async fn test_wait_tool_completes_without_interrupt() {
        let (_tx, rx) = tokio::sync::watch::channel(None::<WaitInterrupt>);
        let tool = WaitTool::with_interrupt(rx);
        let result = tool
            .call(json!({"seconds": 0.1}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        assert_eq!(result["status"], "complete");
        assert_eq!(result["waited_seconds"], 0.1);
    }

    #[tokio::test]
    async fn test_wait_tool_without_interrupt_receiver() {
        let tool = WaitTool::new();
        let result = tool
            .call(json!({"seconds": 0.1}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        assert_eq!(result["status"], "complete");
    }

    #[test]
    fn test_wait_tool_default_enabled() {
        let tool = WaitTool::new();
        assert!(tool.default_enabled());
    }

    #[test]
    fn test_wait_tool_def() {
        let tool = WaitTool::new();
        let def = tool.def();
        assert_eq!(def.name, "wait");
        assert!(def.description.contains("Pause execution"));
        assert!(def.input_schema.get("properties").is_some());
    }

    #[tokio::test]
    async fn test_wait_tool_short_wait() {
        let tool = WaitTool::new();
        let start = Instant::now();
        let result = tool
            .call(json!({"seconds": 0.1}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(100));
        // Generous upper bound: scheduler jitter on a loaded CI machine can easily
        // add more than 100ms, and this test only needs to show the wait is bounded.
        assert!(
            elapsed < Duration::from_secs(1),
            "0.1s wait took too long: {elapsed:?}"
        );
        assert_eq!(result["status"], "complete");
        assert_eq!(result["waited_seconds"], 0.1);
    }

    #[test]
    fn test_wait_tool_clamps_max_value() {
        // Exercising the tool itself would take a full minute, so this only pins
        // the clamp arithmetic used by `call`.
        let seconds = 9999.0_f64;
        let clamped = seconds.clamp(0.0, MAX_WAIT_SECONDS);
        assert_eq!(clamped, MAX_WAIT_SECONDS);
    }

    #[tokio::test]
    async fn test_wait_tool_clamps_negative_value() {
        // A negative request goes through the tool, is clamped to zero and
        // returns immediately.
        let tool = WaitTool::new();
        let result = tool
            .call(json!({"seconds": -5.0}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        assert_eq!(result["status"], "complete");
        assert_eq!(result["waited_seconds"], 0.0);
    }

    #[tokio::test]
    async fn test_wait_tool_invalid_args() {
        let tool = WaitTool::new();
        let result = tool.call(json!({"invalid": "args"})).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, BuiltinToolError::InvalidArgs(_)));
    }

    #[tokio::test]
    async fn test_wait_tool_missing_seconds() {
        let tool = WaitTool::new();
        let result = tool.call(json!({})).await;
        assert!(matches!(result, Err(BuiltinToolError::InvalidArgs(_))));
    }

    #[tokio::test]
    async fn test_wait_tool_stale_interrupt_does_not_affect_subsequent_waits() {
        let (tx, rx) = tokio::sync::watch::channel(None::<WaitInterrupt>);
        let tool = WaitTool::with_interrupt(rx);
        tx.send(Some(WaitInterrupt {
            reason: "stale interrupt".to_string(),
        }))
        .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        let start = Instant::now();
        let result = tool
            .call(json!({"seconds": 0.2}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        let elapsed = start.elapsed();
        assert_eq!(result["status"], "complete");
        assert!(
            elapsed >= Duration::from_millis(180),
            "Should wait full duration, got {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_wait_tool_interrupt_returns_interrupted_status_with_reason() {
        let (tx, rx) = tokio::sync::watch::channel(None::<WaitInterrupt>);
        let tool = WaitTool::with_interrupt(rx);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _ = tx.send(Some(WaitInterrupt {
                reason: "Incoming peer message".to_string(),
            }));
        });
        let result = tool
            .call(json!({"seconds": 30.0}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        assert_eq!(result["status"], "interrupted");
        assert!(
            result["reason"]
                .as_str()
                .unwrap()
                .contains("Incoming peer message"),
            "reason must include the interrupt source"
        );
        assert!(
            result["waited_seconds"].as_f64().unwrap() < 1.0,
            "should have been interrupted well before the full wait"
        );
        assert_eq!(result["requested_seconds"], 30.0);
    }

    #[tokio::test]
    async fn test_wait_tool_multiple_sequential_interrupts() {
        let (tx, rx) = tokio::sync::watch::channel(None::<WaitInterrupt>);
        let tool = WaitTool::with_interrupt(rx);
        let tx_clone = tx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _ = tx_clone.send(Some(WaitInterrupt {
                reason: "First interrupt".to_string(),
            }));
        });
        let result1 = tool
            .call(json!({"seconds": 10.0}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        assert_eq!(result1["status"], "interrupted");
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _ = tx.send(Some(WaitInterrupt {
                reason: "Second interrupt".to_string(),
            }));
        });
        let result2 = tool
            .call(json!({"seconds": 10.0}))
            .await
            .unwrap()
            .into_json()
            .unwrap();
        assert_eq!(result2["status"], "interrupted");
        assert!(
            result2["reason"]
                .as_str()
                .unwrap()
                .contains("Second interrupt")
        );
    }

    #[test]
    fn test_max_wait_seconds_is_60() {
        assert_eq!(MAX_WAIT_SECONDS, 60.0);
        let seconds = 120.0_f64;
        let clamped = seconds.clamp(0.0, MAX_WAIT_SECONDS);
        assert_eq!(clamped, 60.0);
    }
}