use crate::messages::StopReason;
use crate::types::CorrelationId;
/// Value returned by [`StreamHandler::on_end`] once a stream has finished.
///
/// `Continue` is the [`Default`] variant. How each variant is acted upon is
/// decided by the code that drives the stream, which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum StreamAction {
    /// The default variant; also what `StreamHandler::on_end` returns unless
    /// overridden.
    #[default]
    Continue,
    /// Handler reports the stream as complete.
    Complete,
    /// Handler requests a stop.
    Stop,
}

impl StreamAction {
    /// Returns `true` only for [`StreamAction::Continue`].
    #[must_use]
    pub fn is_continue(&self) -> bool {
        *self == Self::Continue
    }

    /// Returns `true` only for [`StreamAction::Complete`].
    #[must_use]
    pub fn is_complete(&self) -> bool {
        *self == Self::Complete
    }

    /// Returns `true` only for [`StreamAction::Stop`].
    #[must_use]
    pub fn is_stop(&self) -> bool {
        *self == Self::Stop
    }
}
/// Callback interface for consuming a streamed response.
///
/// Only [`on_token`](Self::on_token) must be implemented; the other hooks
/// have no-op / default implementations. The `Send + 'static` bound means a
/// handler may be moved across threads and cannot borrow non-`'static` data.
pub trait StreamHandler: Send + 'static {
    /// Start-of-stream hook, given the stream's correlation id.
    ///
    /// Does nothing by default.
    fn on_start(&mut self, _correlation_id: &CorrelationId) {}

    /// Receives one chunk of streamed text. This is the only required method.
    fn on_token(&mut self, token: &str);

    /// End-of-stream hook, given the reason the stream stopped.
    ///
    /// The default implementation ignores the stop reason and returns
    /// `StreamAction::default()`, which is [`StreamAction::Continue`].
    fn on_end(&mut self, _stop_reason: StopReason) -> StreamAction {
        StreamAction::default()
    }
}
/// Record of one tool invocation: its id, the tool name, the JSON arguments it
/// was called with, and its outcome.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutedToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Name of the invoked tool.
    pub name: String,
    /// JSON arguments passed to the tool.
    pub arguments: serde_json::Value,
    /// `Ok` with the tool's JSON output, or `Err` with an error message.
    pub result: Result<serde_json::Value, String>,
}

impl ExecutedToolCall {
    /// Assembles a record from already-converted parts; shared by the public
    /// constructors.
    fn from_parts(
        id: String,
        name: String,
        arguments: serde_json::Value,
        result: Result<serde_json::Value, String>,
    ) -> Self {
        Self {
            id,
            name,
            arguments,
            result,
        }
    }

    /// Builds a record for a tool call that produced `result`.
    #[must_use]
    pub fn success(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
        result: serde_json::Value,
    ) -> Self {
        Self::from_parts(id.into(), name.into(), arguments, Ok(result))
    }

    /// Builds a record for a tool call that failed with the message `error`.
    #[must_use]
    pub fn error(
        id: impl Into<String>,
        name: impl Into<String>,
        arguments: serde_json::Value,
        error: impl Into<String>,
    ) -> Self {
        Self::from_parts(id.into(), name.into(), arguments, Err(error.into()))
    }

    /// Returns `true` when the call holds an `Ok` result.
    #[must_use]
    pub fn is_success(&self) -> bool {
        self.result.is_ok()
    }

    /// Returns `true` when the call holds an `Err` result (the negation of
    /// [`is_success`](Self::is_success)).
    #[must_use]
    pub fn is_error(&self) -> bool {
        !self.is_success()
    }
}
/// Aggregate of a finished streamed response: its text, stop reason, token
/// count, and any tool calls executed along the way.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CollectedResponse {
    /// Collected response text.
    pub text: String,
    /// Reason generation stopped.
    pub stop_reason: StopReason,
    /// Token count supplied by whoever built this value.
    pub token_count: usize,
    /// Executed tool calls; empty when built via [`CollectedResponse::new`].
    pub tool_calls: Vec<ExecutedToolCall>,
}

impl CollectedResponse {
    /// Builds a response without any executed tool calls.
    #[must_use]
    pub fn new(text: String, stop_reason: StopReason, token_count: usize) -> Self {
        Self::with_tool_calls(text, stop_reason, token_count, Vec::new())
    }

    /// Builds a response carrying the given executed tool calls.
    #[must_use]
    pub fn with_tool_calls(
        text: String,
        stop_reason: StopReason,
        token_count: usize,
        tool_calls: Vec<ExecutedToolCall>,
    ) -> Self {
        Self {
            text,
            stop_reason,
            token_count,
            tool_calls,
        }
    }

    /// Returns `true` if at least one tool call is recorded.
    #[must_use]
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }

    /// Returns `true` when the stop reason is `StopReason::EndTurn`.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        matches!(self.stop_reason, StopReason::EndTurn)
    }

    /// Returns `true` when the stop reason is `StopReason::MaxTokens`.
    #[must_use]
    pub fn is_truncated(&self) -> bool {
        matches!(self.stop_reason, StopReason::MaxTokens)
    }

    /// Returns `true` when the stop reason is `StopReason::ToolUse`.
    #[must_use]
    pub fn needs_tool_call(&self) -> bool {
        matches!(self.stop_reason, StopReason::ToolUse)
    }
}

impl Default for CollectedResponse {
    /// Empty text, `StopReason::EndTurn`, zero tokens, and no tool calls.
    fn default() -> Self {
        Self::new(String::new(), StopReason::EndTurn, 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Handler implementing only the required hook, so the trait's default
    // `on_end` is what gets exercised.
    struct TokenRecorder {
        tokens: Vec<String>,
    }

    impl StreamHandler for TokenRecorder {
        fn on_token(&mut self, token: &str) {
            self.tokens.push(token.to_owned());
        }
    }

    #[test]
    fn stream_action_default_is_continue() {
        assert_eq!(StreamAction::default(), StreamAction::Continue);
    }

    #[test]
    fn stream_action_is_continue() {
        assert!(StreamAction::Continue.is_continue());
        assert!(!StreamAction::Complete.is_continue());
        assert!(!StreamAction::Stop.is_continue());
    }

    #[test]
    fn stream_action_is_complete() {
        assert!(!StreamAction::Continue.is_complete());
        assert!(StreamAction::Complete.is_complete());
        assert!(!StreamAction::Stop.is_complete());
    }

    #[test]
    fn stream_action_is_stop() {
        assert!(!StreamAction::Continue.is_stop());
        assert!(!StreamAction::Complete.is_stop());
        assert!(StreamAction::Stop.is_stop());
    }

    #[test]
    fn stream_handler_default_on_end_returns_continue() {
        let mut handler = TokenRecorder { tokens: Vec::new() };
        handler.on_token("Hel");
        handler.on_token("lo");
        assert_eq!(handler.tokens, ["Hel", "lo"]);
        // The default hook ignores the stop reason entirely.
        assert_eq!(handler.on_end(StopReason::EndTurn), StreamAction::Continue);
        assert_eq!(handler.on_end(StopReason::MaxTokens), StreamAction::Continue);
    }

    #[test]
    fn collected_response_new() {
        let response = CollectedResponse::new("Hello".to_string(), StopReason::EndTurn, 1);
        assert_eq!(response.text, "Hello");
        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.token_count, 1);
    }

    #[test]
    fn collected_response_is_complete() {
        let complete = CollectedResponse::new("test".to_string(), StopReason::EndTurn, 1);
        assert!(complete.is_complete());
        let truncated = CollectedResponse::new("test".to_string(), StopReason::MaxTokens, 1);
        assert!(!truncated.is_complete());
    }

    #[test]
    fn collected_response_is_truncated() {
        let truncated = CollectedResponse::new("test".to_string(), StopReason::MaxTokens, 1);
        assert!(truncated.is_truncated());
        let complete = CollectedResponse::new("test".to_string(), StopReason::EndTurn, 1);
        assert!(!complete.is_truncated());
    }

    #[test]
    fn collected_response_needs_tool_call() {
        let tool_use = CollectedResponse::new(String::new(), StopReason::ToolUse, 0);
        assert!(tool_use.needs_tool_call());
        let complete = CollectedResponse::new("test".to_string(), StopReason::EndTurn, 1);
        assert!(!complete.needs_tool_call());
        let truncated = CollectedResponse::new("test".to_string(), StopReason::MaxTokens, 1);
        assert!(!truncated.needs_tool_call());
    }

    #[test]
    fn collected_response_default() {
        let response = CollectedResponse::default();
        assert!(response.text.is_empty());
        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.token_count, 0);
        assert!(response.tool_calls.is_empty());
        // Default must agree with the primary constructor.
        assert_eq!(
            response,
            CollectedResponse::new(String::new(), StopReason::EndTurn, 0)
        );
    }

    #[test]
    fn collected_response_new_has_empty_tool_calls() {
        let response = CollectedResponse::new("test".to_string(), StopReason::EndTurn, 1);
        assert!(response.tool_calls.is_empty());
        assert!(!response.has_tool_calls());
    }

    #[test]
    fn collected_response_with_tool_calls() {
        let tool_call = ExecutedToolCall::success(
            "tc_1",
            "calculator",
            serde_json::json!({"expr": "2+2"}),
            serde_json::json!({"result": 4}),
        );
        let response = CollectedResponse::with_tool_calls(
            "The result is 4".to_string(),
            StopReason::EndTurn,
            5,
            vec![tool_call],
        );
        assert!(response.has_tool_calls());
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert!(response.tool_calls[0].is_success());
    }

    #[test]
    fn executed_tool_call_success() {
        let call = ExecutedToolCall::success(
            "tc_1",
            "my_tool",
            serde_json::json!({"arg": "value"}),
            serde_json::json!({"result": "ok"}),
        );
        assert!(call.is_success());
        assert!(!call.is_error());
        assert_eq!(call.id, "tc_1");
        assert_eq!(call.name, "my_tool");
        assert_eq!(call.arguments, serde_json::json!({"arg": "value"}));
        assert_eq!(call.result, Ok(serde_json::json!({"result": "ok"})));
    }

    #[test]
    fn executed_tool_call_error() {
        let call = ExecutedToolCall::error(
            "tc_2",
            "failing_tool",
            serde_json::json!({}),
            "Tool execution failed",
        );
        assert!(!call.is_success());
        assert!(call.is_error());
        assert_eq!(call.id, "tc_2");
        assert_eq!(call.name, "failing_tool");
        assert_eq!(call.result.unwrap_err(), "Tool execution failed");
    }
}