use std::pin::Pin;
use std::time::Duration;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::tool::Tool;
use crate::tool_error::ToolError;
use crate::tool_result::ToolResult;
/// One item emitted on a [`ToolStream`] by a streaming tool execution.
///
/// Serialized with serde's default (externally tagged) enum representation.
/// `Complete` and `Error` are the variants reported as terminal by
/// [`ToolChunk::is_terminal`]; `Progress` and `Partial` are not.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolChunk {
    /// Progress report for a running tool.
    Progress {
        /// Completion percentage. [`ToolChunk::progress`] clamps it into `0.0..=100.0`;
        /// constructing the variant directly performs no check.
        percent: f32,
        /// Human-readable status text.
        message: String,
    },
    /// An intermediate piece of output.
    Partial {
        /// The partial payload as arbitrary JSON.
        data: Value,
        /// Presumably the position of this partial within the stream — confirm with producers.
        index: usize,
    },
    /// Final successful result of the execution.
    Complete {
        result: ToolResult,
    },
    /// Failure report.
    Error {
        /// Name of the failing tool. When built by [`ToolChunk::from_error`] this may instead be
        /// a category label (`"serialization"` or `"audit"`) for errors not tied to one tool.
        tool: String,
        /// Human-readable error text.
        message: String,
        /// Retryability flag: [`ToolChunk::error`] always sets `false`, while
        /// [`ToolChunk::from_error`] copies `ToolError::is_retryable`.
        retryable: bool,
    },
}
impl ToolChunk {
    /// Builds a [`ToolChunk::Progress`] chunk.
    ///
    /// `percent` is clamped into `0.0..=100.0`. NaN is mapped to `0.0`. Without that,
    /// `f32::clamp` would return NaN unchanged, and serde_json writes non-finite floats as
    /// `null`, which then fails to deserialize back into the `f32` field.
    pub fn progress(percent: f32, message: impl Into<String>) -> Self {
        let percent = if percent.is_nan() {
            0.0
        } else {
            percent.clamp(0.0, 100.0)
        };
        Self::Progress {
            percent,
            message: message.into(),
        }
    }

    /// Builds a [`ToolChunk::Partial`] chunk carrying `data` at position `index`.
    pub fn partial(data: Value, index: usize) -> Self {
        Self::Partial { data, index }
    }

    /// Builds a [`ToolChunk::Complete`] chunk wrapping the final `result`.
    pub fn complete(result: ToolResult) -> Self {
        Self::Complete { result }
    }

    /// Converts a [`ToolError`] into a [`ToolChunk::Error`] chunk.
    ///
    /// `tool` is the tool name carried by the error. For errors not tied to a specific tool
    /// it is a fixed label: `"serialization"` or `"audit"`. `message` is the error's
    /// `to_string()` text, and `retryable` mirrors `ToolError::is_retryable`.
    pub fn from_error(error: &ToolError) -> Self {
        // Deliberately exhaustive (no `_` arm): adding a ToolError variant should fail to
        // compile here until someone decides which name to report for it.
        let tool = match error {
            ToolError::NotFound { name } | ToolError::Unavailable { name, .. } => name.clone(),
            ToolError::InvalidArguments { tool, .. }
            | ToolError::ExecutionFailed { tool, .. }
            | ToolError::Timeout { tool, .. } => tool.clone(),
            ToolError::Serialization(_) => "serialization".to_string(),
            ToolError::AuditFailed(_) => "audit".to_string(),
        };
        Self::Error {
            tool,
            message: error.to_string(),
            retryable: error.is_retryable(),
        }
    }

    /// Builds a non-retryable [`ToolChunk::Error`] chunk from a tool name and message.
    pub fn error(tool: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Error {
            tool: tool.into(),
            message: message.into(),
            retryable: false,
        }
    }

    /// Returns `true` for `Complete` and `Error` chunks, the variants this type treats as
    /// terminal, and `false` for `Progress` and `Partial`.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        matches!(self, Self::Complete { .. } | Self::Error { .. })
    }
}
/// Pinned, boxed, `Send` stream of [`ToolChunk`]s, as returned by
/// [`StreamingTool::execute_stream`].
///
/// Boxing erases the concrete stream type, so different implementations can return different
/// streams behind one signature. `Send` lets the stream be moved across threads.
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolChunk> + Send>>;
/// Chunk-count and timing limits handed to [`StreamingTool::execute_stream`].
///
/// Nothing visible in this module enforces these values. Presumably the streaming
/// implementation (or a caller-side wrapper) is responsible for honouring them — TODO confirm.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Presumably the maximum number of chunks one stream may emit — confirm.
    pub max_chunks: usize,
    /// Presumably the longest wait allowed for the next chunk — confirm.
    pub chunk_timeout: Duration,
    /// Presumably the upper bound on a stream's total wall-clock time — confirm.
    pub max_duration: Duration,
    /// Presumably the minimum spacing between `Progress` chunks (rate limit) — confirm.
    pub min_progress_interval: Duration,
}
impl Default for StreamConfig {
    /// Middle-of-the-road limits: `max_chunks` = 1000, `chunk_timeout` = 30 s,
    /// `max_duration` = 5 min, `min_progress_interval` = 100 ms.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        Self {
            min_progress_interval: Duration::from_millis(100),
            max_duration: five_minutes,
            chunk_timeout: Duration::from_secs(30),
            max_chunks: 1_000,
        }
    }
}
impl StreamConfig {
    /// Tight preset: `max_chunks` = 100, `chunk_timeout` = 5 s, `max_duration` = 30 s,
    /// `min_progress_interval` = 50 ms.
    pub fn short() -> Self {
        let chunk_timeout = Duration::from_secs(5);
        let max_duration = Duration::from_secs(30);
        Self {
            min_progress_interval: Duration::from_millis(50),
            max_duration,
            chunk_timeout,
            max_chunks: 100,
        }
    }

    /// Generous preset for long-running tools: `max_chunks` = 10 000, `chunk_timeout` = 60 s,
    /// `max_duration` = 1 h, `min_progress_interval` = 500 ms.
    pub fn long() -> Self {
        let one_minute = Duration::from_secs(60);
        Self {
            min_progress_interval: Duration::from_millis(500),
            max_duration: one_minute * 60,
            chunk_timeout: one_minute,
            max_chunks: 10_000,
        }
    }
}
/// A [`Tool`] that can also deliver its output incrementally as a [`ToolStream`].
#[async_trait]
pub trait StreamingTool: Tool {
    /// Limits this tool prefers for streaming runs.
    ///
    /// Defaults to [`StreamConfig::default`]. Implementors may override it to return, for
    /// example, [`StreamConfig::short`] or [`StreamConfig::long`].
    fn stream_config(&self) -> StreamConfig {
        Default::default()
    }

    /// Starts a streaming execution with JSON `args` under the given `config`.
    ///
    /// Nothing visible in this module enforces `config`. Presumably the implementation is
    /// responsible for honouring it — confirm against callers.
    fn execute_stream(&self, args: Value, config: StreamConfig) -> ToolStream;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Minimal `ToolResult`, used wherever a `Complete` chunk is needed.
    fn dummy_result() -> ToolResult {
        ToolResult::new(
            "test",
            &serde_json::json!({}),
            serde_json::json!({}),
            Duration::from_secs(1),
        )
    }

    #[test]
    fn test_tool_chunk_progress() {
        let chunk = ToolChunk::progress(50.0, "Halfway there");
        if let ToolChunk::Progress { percent, message } = chunk {
            assert_eq!(percent, 50.0);
            assert_eq!(message, "Halfway there");
        } else {
            panic!("Expected Progress chunk");
        }
    }

    #[test]
    fn test_tool_chunk_progress_clamped() {
        if let ToolChunk::Progress { percent, .. } = ToolChunk::progress(150.0, "Over 100") {
            assert_eq!(percent, 100.0);
        } else {
            panic!("Expected Progress chunk");
        }
    }

    #[test]
    fn test_tool_chunk_is_terminal() {
        let progress = ToolChunk::progress(50.0, "");
        let partial = ToolChunk::partial(serde_json::json!({}), 0);
        assert!(!progress.is_terminal());
        assert!(!partial.is_terminal());

        assert!(ToolChunk::complete(dummy_result()).is_terminal());
        assert!(ToolChunk::error("test", "not found").is_terminal());
    }

    #[test]
    fn test_stream_config_default() {
        let StreamConfig {
            max_chunks,
            max_duration,
            ..
        } = StreamConfig::default();
        assert_eq!(max_chunks, 1000);
        assert_eq!(max_duration, Duration::from_secs(300));
    }

    #[test]
    fn test_stream_config_short() {
        let StreamConfig {
            max_chunks,
            max_duration,
            ..
        } = StreamConfig::short();
        assert_eq!(max_chunks, 100);
        assert!(max_duration < Duration::from_secs(60));
    }

    #[test]
    fn test_stream_config_long() {
        let StreamConfig {
            max_chunks,
            max_duration,
            ..
        } = StreamConfig::long();
        assert_eq!(max_chunks, 10000);
        assert_eq!(max_duration, Duration::from_secs(3600));
    }
}