use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::Result;
use crate::identity::validate_request_identifier;
/// Coarse status carried by a [`ToolProgress`] update.
///
/// Serialized as a snake_case string (`"started"`, `"running"`, `"completed"`,
/// `"failed"`). The enum is `#[non_exhaustive]`, so matches outside this crate
/// need a wildcard arm. NOTE(review): the lifecycle ordering between these states
/// is not enforced in this module — confirm with the producers of `ToolProgress`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolProgressStatus {
/// Serialized as `"started"`.
Started,
/// Serialized as `"running"`.
Running,
/// Serialized as `"completed"`.
Completed,
/// Serialized as `"failed"`.
Failed,
}
/// One progress update for a tool invocation, as passed to
/// [`ToolProgressSink::record_progress`].
///
/// The struct is `#[non_exhaustive]`. Code outside this crate can read the fields and
/// deserialize values, but cannot build one with a struct literal. Fields serialize
/// under their Rust names, in declaration order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ToolProgress {
/// Identifier of the run this update belongs to.
/// NOTE(review): not validated in this module — confirm validation happens upstream.
pub run_id: String,
/// Tool-use identifier; presumably the value from `CurrentToolInvocation::tool_use_id`.
pub tool_use_id: String,
/// Tool name; presumably the value from `CurrentToolInvocation::tool_name`.
pub tool_name: String,
/// Free-form phase label; the allowed values are not defined in this module.
pub phase: String,
/// Coarse status of the invocation at the time of this update.
pub status: ToolProgressStatus,
/// Elapsed milliseconds; presumably taken from
/// `CurrentToolInvocation::dispatch_elapsed_ms` — confirm at the call site.
pub dispatch_elapsed_ms: u64,
/// Arbitrary JSON payload attached to the update.
pub metadata: Value,
}
/// Validated identity of a tool invocation, plus the instant its record was created.
///
/// The fields are private, so code outside this module goes through
/// [`CurrentToolInvocation::new`], which validates both identifiers.
#[derive(Clone, Debug)]
pub struct CurrentToolInvocation {
// Tool-use identifier, already checked by `validate_request_identifier`.
tool_use_id: String,
// Tool name, already checked by `validate_request_identifier`.
tool_name: String,
// Captured in `new` after validation succeeds; `dispatch_elapsed_ms` measures from here.
started_at: Instant,
}
impl CurrentToolInvocation {
    /// Validates both identifiers and starts the elapsed-time clock.
    ///
    /// The clock starts only after validation has succeeded.
    ///
    /// # Errors
    ///
    /// Returns the validation error for the first identifier that fails.
    /// `tool_use_id` is checked before `tool_name`. Judging by this module's tests,
    /// blank values and values containing control characters are rejected.
    pub fn new(tool_use_id: impl Into<String>, tool_name: impl Into<String>) -> Result<Self> {
        const SURFACE: &str = "CurrentToolInvocation::new";

        let checked_id = validated_identifier(SURFACE, "tool_use_id", tool_use_id)?;
        let checked_name = validated_identifier(SURFACE, "tool_name", tool_name)?;
        let started_at = Instant::now();

        Ok(Self {
            tool_use_id: checked_id,
            tool_name: checked_name,
            started_at,
        })
    }

    /// Returns the validated tool-use identifier.
    #[must_use]
    pub fn tool_use_id(&self) -> &str {
        self.tool_use_id.as_str()
    }

    /// Returns the validated tool name.
    #[must_use]
    pub fn tool_name(&self) -> &str {
        self.tool_name.as_str()
    }

    /// Returns the whole milliseconds elapsed since `new` succeeded.
    ///
    /// The value saturates at `u64::MAX` instead of truncating the `u128` millisecond count.
    #[must_use]
    pub fn dispatch_elapsed_ms(&self) -> u64 {
        let elapsed_millis = self.started_at.elapsed().as_millis();
        u64::try_from(elapsed_millis).unwrap_or(u64::MAX)
    }
}
/// Converts `value` into an owned `String` and returns it once it passes
/// `validate_request_identifier`.
///
/// The validator receives the label `"{surface}: {field}"`. This label presumably
/// prefixes any error message — confirm in `crate::identity`.
///
/// # Errors
///
/// Propagates the validator's error (via `?`) when `value` is rejected.
fn validated_identifier(surface: &str, field: &str, value: impl Into<String>) -> Result<String> {
    let candidate: String = value.into();
    let label = format!("{surface}: {field}");
    validate_request_identifier(&label, &candidate)?;
    Ok(candidate)
}
/// Asynchronous receiver for [`ToolProgress`] updates.
///
/// The `Send + Sync + 'static` bounds let an implementation be stored as an
/// `Arc<dyn ToolProgressSink>` and shared, as `ToolProgressSinkHandle` does.
/// Implementations do not need to be `Debug`.
#[async_trait]
pub trait ToolProgressSink: Send + Sync + 'static {
/// Records one progress update.
///
/// The method returns `()`, so callers cannot observe failures. Any error handling
/// (logging, dropping, retrying) is up to the implementation.
async fn record_progress(&self, progress: ToolProgress);
}
/// Type-erased, shareable handle to a [`ToolProgressSink`].
///
/// The sink is stored in an `Arc`, so cloning the handle shares the same sink
/// instead of duplicating it.
#[derive(Clone)]
pub struct ToolProgressSinkHandle {
// Shared, dynamically dispatched sink.
sink: Arc<dyn ToolProgressSink>,
}
impl ToolProgressSinkHandle {
#[must_use]
pub fn new<S>(sink: S) -> Self
where
S: ToolProgressSink,
{
Self {
sink: Arc::new(sink),
}
}
#[must_use]
pub fn from_arc(sink: Arc<dyn ToolProgressSink>) -> Self {
Self { sink }
}
#[must_use]
pub fn inner(&self) -> &Arc<dyn ToolProgressSink> {
&self.sink
}
}
impl std::fmt::Debug for ToolProgressSinkHandle {
    /// Prints `ToolProgressSinkHandle { .. }`.
    ///
    /// The wrapped sink is not required to implement `Debug`, so its contents are
    /// deliberately elided.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = formatter.debug_struct("ToolProgressSinkHandle");
        view.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A blank tool-use id and a tool name containing a newline must both be refused,
    /// each with an error naming the offending field.
    #[test]
    fn current_tool_invocation_rejects_invalid_identity() {
        let blank_id_err = match CurrentToolInvocation::new(" ", "echo") {
            Err(err) => err,
            Ok(_) => panic!("expected invalid tool_use_id to fail"),
        };
        assert!(blank_id_err
            .to_string()
            .contains("tool_use_id must not be empty"));

        let control_char_err = match CurrentToolInvocation::new("tu-1", "echo\nnext") {
            Err(err) => err,
            Ok(_) => panic!("expected invalid tool_name to fail"),
        };
        assert!(control_char_err
            .to_string()
            .contains("tool_name must not contain control characters"));
    }

    /// Well-formed identifiers are accepted and returned unchanged by the getters.
    #[test]
    fn current_tool_invocation_accepts_valid_identity() -> Result<()> {
        let invocation = CurrentToolInvocation::new("tu-1", "echo")?;
        let observed = (invocation.tool_use_id(), invocation.tool_name());
        assert_eq!(observed, ("tu-1", "echo"));
        Ok(())
    }
}