use crate::cancel_token::CancellationFlag;
use crate::tool_executor::ToolResult;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
/// Boxed, pinned stream of [`ToolChunk`] items produced by a tool.
///
/// The underlying trait object is required to be `Send + 'static`.
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolChunk> + Send + 'static>>;
/// Reason carried by the terminating chunk of a tool stream.
///
/// Serialized by serde as an internally tagged value: the variant name is written in
/// `snake_case` under a `"kind"` key (for example `{"kind":"stop"}`), and the `Error`
/// variant additionally carries its `message` field.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ToolFinishReason {
    /// The tool finished without reporting a failure.
    Stop,
    /// The tool failed; `message` holds a human-readable description.
    Error { message: String },
    /// The tool run was cancelled.
    Cancelled,
}
impl ToolFinishReason {
    /// Returns the `snake_case` names of all variants, in declaration order.
    ///
    /// The list is written out by hand, so it must be updated together with the enum
    /// whenever a variant is added, removed, or renamed.
    pub fn all_variants() -> &'static [&'static str] {
        static VARIANT_NAMES: [&str; 3] = ["stop", "error", "cancelled"];
        &VARIANT_NAMES
    }
}
/// One element of a tool's output stream.
///
/// A chunk whose `finish_reason` is `Some(..)` is a terminator; every other chunk is
/// intermediate. Field order is part of the serialized JSON layout.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolChunk {
    /// Text carried by this chunk.
    pub delta: String,
    /// Set only on a terminating chunk. `None` is omitted from the serialized output and
    /// is the default when the key is missing on deserialization.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub finish_reason: Option<ToolFinishReason>,
    /// Creation time in milliseconds since the Unix epoch (see `ToolChunk::now_ms`).
    pub timestamp_ms: u64,
}
impl ToolChunk {
    /// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
    ///
    /// Returns `0` when the system clock reports a time before the epoch. If the
    /// millisecond count does not fit in a `u64`, the result saturates at `u64::MAX`
    /// rather than being silently truncated.
    pub fn now_ms() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| {
                // `as_millis` yields a `u128`; convert with a checked conversion so an
                // out-of-range value cannot wrap around. The fully qualified path keeps
                // this independent of which edition's prelude is in effect.
                <u64 as std::convert::TryFrom<u128>>::try_from(elapsed.as_millis())
                    .unwrap_or(u64::MAX)
            })
    }

    /// Builds a non-terminating chunk carrying `delta`, stamped with the current time.
    pub fn intermediate(delta: impl Into<String>) -> Self {
        Self {
            delta: delta.into(),
            finish_reason: None,
            timestamp_ms: Self::now_ms(),
        }
    }

    /// Builds a terminating chunk carrying `delta` and `finish_reason`, stamped with the
    /// current time.
    pub fn terminator(delta: impl Into<String>, finish_reason: ToolFinishReason) -> Self {
        Self {
            delta: delta.into(),
            finish_reason: Some(finish_reason),
            timestamp_ms: Self::now_ms(),
        }
    }

    /// Converts a completed [`ToolResult`] into a single terminating chunk.
    ///
    /// The chunk's `delta` is a copy of `result.output`. A successful result maps to
    /// [`ToolFinishReason::Stop`]; a failed one maps to [`ToolFinishReason::Error`] whose
    /// message names the tool via `result.tool_name`.
    pub fn from_result(result: &ToolResult) -> Self {
        let finish_reason = if result.success {
            ToolFinishReason::Stop
        } else {
            ToolFinishReason::Error {
                message: format!("tool '{}' failed", result.tool_name),
            }
        };
        Self::terminator(result.output.clone(), finish_reason)
    }

    /// Returns `true` when this chunk carries a finish reason, i.e. it ends the stream.
    pub fn is_terminator(&self) -> bool {
        self.finish_reason.is_some()
    }
}
/// Per-call context handed to a tool invocation.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Flag consulted by `ToolContext::is_cancelled` to learn whether the call was cancelled.
    pub cancel: CancellationFlag,
    /// Numeric identifier carried along with the call; this module only stores it.
    pub trace_id: u64,
}
impl ToolContext {
pub fn new(cancel: CancellationFlag, trace_id: u64) -> Self {
ToolContext { cancel, trace_id }
}
pub fn is_cancelled(&self) -> bool {
self.cancel.is_cancelled()
}
}
/// A tool that can be run to completion and, optionally, streamed chunk by chunk.
#[async_trait::async_trait]
pub trait Tool: Send + Sync {
    /// Runs the tool with `args` under `ctx` and returns its complete [`ToolResult`].
    async fn execute(&self, args: String, ctx: ToolContext) -> ToolResult;

    /// Runs the tool and exposes its output as a [`ToolStream`].
    ///
    /// The default implementation awaits [`Tool::execute`] to completion and then returns
    /// a stream that yields exactly one chunk. That chunk is built with
    /// [`ToolChunk::from_result`] when the stream is polled, so it is always a terminator.
    async fn stream(&self, args: String, ctx: ToolContext) -> ToolStream {
        let outcome = self.execute(args, ctx).await;
        let single_chunk = futures::stream::once(async move {
            // Built lazily so the chunk is stamped when the consumer polls the stream.
            ToolChunk::from_result(&outcome)
        });
        Box::pin(single_chunk)
    }

    /// Whether this tool emits multiple chunks from [`Tool::stream`].
    ///
    /// Defaults to `false`; implementations that override `stream` to emit several
    /// chunks are expected to override this as well.
    fn is_streaming(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// An intermediate chunk omits `finish_reason` from its JSON and round-trips intact.
    #[test]
    fn cell_01_toolchunk_serde_round_trip_intermediate() {
        let chunk = ToolChunk::intermediate("Hola");
        let json = serde_json::to_string(&chunk).unwrap();
        assert!(
            !json.contains("finish_reason"),
            "intermediate chunk's None finish_reason MUST be elided. \
             Got JSON: {json}"
        );
        let back: ToolChunk = serde_json::from_str(&json).unwrap();
        assert_eq!(back.delta, "Hola");
        assert_eq!(back.finish_reason, None);
        assert_eq!(back.timestamp_ms, chunk.timestamp_ms);
    }

    /// A `Stop` terminator serializes with `"kind":"stop"` and round-trips.
    #[test]
    fn cell_02_toolchunk_serde_round_trip_terminator_stop() {
        let chunk = ToolChunk::terminator("", ToolFinishReason::Stop);
        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("\"finish_reason\""));
        assert!(json.contains("\"kind\":\"stop\""));
        let back: ToolChunk = serde_json::from_str(&json).unwrap();
        assert_eq!(back.finish_reason, Some(ToolFinishReason::Stop));
    }

    /// An `Error` terminator serializes its kind and message and round-trips.
    #[test]
    fn cell_03_toolchunk_serde_round_trip_terminator_error() {
        let chunk = ToolChunk::terminator(
            "",
            ToolFinishReason::Error {
                message: "upstream timeout".to_string(),
            },
        );
        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("\"kind\":\"error\""));
        assert!(json.contains("upstream timeout"));
        let back: ToolChunk = serde_json::from_str(&json).unwrap();
        assert_eq!(
            back.finish_reason,
            Some(ToolFinishReason::Error {
                message: "upstream timeout".to_string()
            })
        );
    }

    /// A `Cancelled` terminator serializes with `"kind":"cancelled"` and round-trips.
    #[test]
    fn cell_04_toolchunk_serde_round_trip_terminator_cancelled() {
        let chunk = ToolChunk::terminator("", ToolFinishReason::Cancelled);
        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("\"kind\":\"cancelled\""));
        let back: ToolChunk = serde_json::from_str(&json).unwrap();
        assert_eq!(back.finish_reason, Some(ToolFinishReason::Cancelled));
    }

    /// `intermediate` sets the delta, leaves `finish_reason` empty, and stamps a time.
    #[test]
    fn cell_05_toolchunk_intermediate_constructor() {
        let chunk = ToolChunk::intermediate("partial");
        assert_eq!(chunk.delta, "partial");
        assert_eq!(chunk.finish_reason, None);
        assert!(chunk.timestamp_ms > 0);
        assert!(!chunk.is_terminator());
    }

    /// `terminator` sets the delta and the finish reason.
    #[test]
    fn cell_06_toolchunk_terminator_constructor() {
        let chunk = ToolChunk::terminator("final", ToolFinishReason::Stop);
        assert_eq!(chunk.delta, "final");
        assert_eq!(chunk.finish_reason, Some(ToolFinishReason::Stop));
        assert!(chunk.is_terminator());
    }

    /// A successful `ToolResult` becomes a `Stop` terminator carrying the output.
    #[test]
    fn cell_07_toolchunk_from_result_success_maps_to_stop() {
        let result = ToolResult {
            success: true,
            output: "42".to_string(),
            tool_name: "Calculator".to_string(),
        };
        let chunk = ToolChunk::from_result(&result);
        assert_eq!(chunk.delta, "42");
        assert_eq!(chunk.finish_reason, Some(ToolFinishReason::Stop));
        assert!(chunk.is_terminator());
    }

    /// A failed `ToolResult` becomes an `Error` terminator whose message names the tool.
    #[test]
    fn cell_08_toolchunk_from_result_failure_maps_to_error() {
        let result = ToolResult {
            success: false,
            output: "division by zero".to_string(),
            tool_name: "Calculator".to_string(),
        };
        let chunk = ToolChunk::from_result(&result);
        assert_eq!(chunk.delta, "division by zero");
        match chunk.finish_reason {
            Some(ToolFinishReason::Error { ref message }) => {
                assert!(message.contains("Calculator"));
            }
            other => panic!("expected Error finish_reason, got {other:?}"),
        }
    }

    /// Pins the variant catalog both at run time and at compile time.
    #[test]
    fn cell_09_toolfinishreason_all_variants_pinned() {
        let variants = ToolFinishReason::all_variants();
        assert_eq!(
            variants,
            &["stop", "error", "cancelled"],
            "33.b D1 closed-catalog: ToolFinishReason has EXACTLY 3 \
             reachable states (stop / error / cancelled). Adding a 4th \
             requires a deliberate sub-fase + cross-stack drift gate \
             update + adopter docs update."
        );
        // Exhaustive match without a wildcard arm: adding a variant to the enum makes
        // this test stop compiling until the catalog above is revisited.
        match ToolFinishReason::Stop {
            ToolFinishReason::Stop => {}
            ToolFinishReason::Error { .. } => unreachable!(),
            ToolFinishReason::Cancelled => unreachable!(),
        }
    }

    /// Tool that only implements `execute` and relies on the trait defaults.
    struct SyncTool;

    #[async_trait::async_trait]
    impl Tool for SyncTool {
        async fn execute(&self, args: String, _ctx: ToolContext) -> ToolResult {
            ToolResult {
                success: true,
                output: format!("sync({args})"),
                tool_name: "SyncTool".to_string(),
            }
        }
    }

    /// The default `is_streaming` reports `false`.
    #[test]
    fn cell_10_tool_trait_default_is_streaming_is_false() {
        let tool = SyncTool;
        assert!(
            !tool.is_streaming(),
            "33.b D1 + D9 default: tools that don't declare a stream \
             effect have is_streaming() == false. Backwards-compat: \
             every existing tool stays out of the streaming dispatch \
             path."
        );
    }

    /// The default `stream` yields exactly one terminating chunk built from `execute`.
    #[tokio::test]
    async fn cell_11_tool_trait_default_stream_wraps_execute_one_chunk() {
        let tool = SyncTool;
        let cancel = CancellationFlag::new();
        let ctx = ToolContext::new(cancel, 0xDEAD_BEEF);
        let mut stream = tool.stream("hello".to_string(), ctx).await;
        let first = stream.next().await.expect("at least one chunk");
        assert_eq!(first.delta, "sync(hello)");
        assert_eq!(first.finish_reason, Some(ToolFinishReason::Stop));
        assert!(first.is_terminator());
        let second = stream.next().await;
        assert!(
            second.is_none(),
            "33.b D9: default stream() emits EXACTLY 1 chunk. \
             Got a second chunk: {second:?}"
        );
    }

    /// Tool that overrides `stream` to emit three intermediate chunks and a terminator.
    struct StreamingTool;

    #[async_trait::async_trait]
    impl Tool for StreamingTool {
        async fn execute(&self, _args: String, _ctx: ToolContext) -> ToolResult {
            ToolResult {
                success: true,
                output: "materialized fallback".to_string(),
                tool_name: "StreamingTool".to_string(),
            }
        }

        async fn stream(&self, _args: String, _ctx: ToolContext) -> ToolStream {
            let chunks = vec![
                ToolChunk::intermediate("alpha "),
                ToolChunk::intermediate("beta "),
                ToolChunk::intermediate("gamma"),
                ToolChunk::terminator("", ToolFinishReason::Stop),
            ];
            Box::pin(futures::stream::iter(chunks))
        }

        fn is_streaming(&self) -> bool {
            true
        }
    }

    /// An overriding tool can report itself as streaming.
    #[test]
    fn cell_12_tool_trait_override_is_streaming_returns_true() {
        let tool = StreamingTool;
        assert!(
            tool.is_streaming(),
            "33.b D1: tools that override stream() to emit multiple \
             chunks SHOULD override is_streaming() to return true"
        );
    }

    /// An overriding tool's stream delivers every chunk in order, terminator last.
    #[tokio::test]
    async fn cell_13_tool_trait_override_stream_emits_multiple_chunks() {
        let tool = StreamingTool;
        let cancel = CancellationFlag::new();
        let ctx = ToolContext::new(cancel, 0xDEAD_BEEF);
        let stream = tool.stream(String::new(), ctx).await;
        // Drain the whole stream in one step instead of a manual push loop.
        let collected: Vec<ToolChunk> = stream.collect().await;
        assert_eq!(
            collected.len(),
            4,
            "33.b: StreamingTool override emits exactly 4 chunks \
             (3 intermediate + 1 terminator). Got {} chunks.",
            collected.len()
        );
        assert_eq!(collected[0].delta, "alpha ");
        assert_eq!(collected[1].delta, "beta ");
        assert_eq!(collected[2].delta, "gamma");
        assert_eq!(collected[3].delta, "");
        assert!(!collected[0].is_terminator());
        assert!(!collected[1].is_terminator());
        assert!(!collected[2].is_terminator());
        assert!(collected[3].is_terminator());
        assert_eq!(collected[3].finish_reason, Some(ToolFinishReason::Stop));
    }

    /// Pins the exact JSON bytes of a chunk without a finish reason.
    #[test]
    fn cell_14_toolchunk_d4_byte_compat_finish_reason_elided() {
        let chunk = ToolChunk {
            delta: "x".to_string(),
            finish_reason: None,
            timestamp_ms: 42,
        };
        let json = serde_json::to_string(&chunk).unwrap();
        assert_eq!(
            json, r#"{"delta":"x","timestamp_ms":42}"#,
            "33.b D4: serialized JSON MUST elide None finish_reason. \
             Got: {json}"
        );
    }

    /// The context stores the trace id and observes cancellation through a cloned flag.
    #[test]
    fn cell_15_toolcontext_constructor_and_field_access() {
        let cancel = CancellationFlag::new();
        let ctx = ToolContext::new(cancel.clone(), 0xCAFE_BABE);
        assert_eq!(ctx.trace_id, 0xCAFE_BABE);
        assert!(!ctx.is_cancelled());
        cancel.cancel();
        assert!(ctx.is_cancelled());
    }
}