use crate::llm::config::SamplingParams;
use crate::llm::error::LLMError;
use crate::messages::{Message, StopReason, ToolCall, ToolDefinition};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
/// One event yielded by an [`LLMEventStream`].
///
/// Streams returned by [`LLMClient::send_streaming_request`] yield these wrapped
/// in `Result<_, LLMError>`. The [`LLMStreamEvent::Error`] variant is an in-band
/// error event and is distinct from the `Err` arm of that `Result`.
///
/// This type does not enforce any ordering between variants (e.g. that `Start`
/// comes first or `End` comes last); that is up to each provider implementation.
#[derive(Debug, Clone)]
pub enum LLMStreamEvent {
    /// Beginning-of-stream marker carrying an identifier (presumably the
    /// provider-assigned response id — confirm with implementors).
    Start { id: String },
    /// A piece of generated text.
    Token { text: String },
    /// A tool call, delivered as a whole [`ToolCall`] value rather than as deltas.
    ToolCall { tool_call: ToolCall },
    /// End-of-stream marker recording why generation stopped.
    End { stop_reason: StopReason },
    /// Error reported inside the stream: a short category (e.g. `"rate_limit"`)
    /// plus a human-readable message.
    Error { error_type: String, message: String },
}
/// Complete reply returned by a non-streaming [`LLMClient::send_request`] call.
#[derive(Debug, Clone)]
pub struct LLMClientResponse {
    /// Text content of the reply.
    pub content: String,
    /// Tool calls contained in the reply; empty when there are none.
    pub tool_calls: Vec<ToolCall>,
    /// Why generation stopped (e.g. `StopReason::EndTurn`, `StopReason::ToolUse`).
    pub stop_reason: StopReason,
}
/// Pinned, boxed, `Send` stream of [`LLMStreamEvent`]s, as returned by
/// [`LLMClient::send_streaming_request`].
///
/// Each item is a `Result`. An `Err(LLMError)` item reports a failure that happens
/// while the stream is being consumed. This is separate from in-band
/// [`LLMStreamEvent::Error`] events.
///
/// Because this is a `Box<dyn Stream>` with no explicit lifetime, the trait object
/// is implicitly `'static`. The stream therefore cannot borrow from the request
/// arguments.
pub type LLMEventStream = Pin<Box<dyn Stream<Item = Result<LLMStreamEvent, LLMError>> + Send>>;
/// Provider-agnostic interface to an LLM backend.
///
/// The `Send + Sync + Debug` supertraits let any implementor be shared across
/// threads and printed for diagnostics. The `async fn` methods are lowered by
/// the `#[async_trait]` attribute.
#[async_trait]
pub trait LLMClient: Send + Sync + std::fmt::Debug {
    /// Returns the name of the provider backing this client as a `'static` string.
    fn provider_name(&self) -> &'static str;

    /// Sends the conversation in `messages` and waits for the complete reply.
    ///
    /// `tools` optionally lists tool definitions to offer with the request, and
    /// `sampling` optionally supplies sampling parameters. What `None` means for
    /// either is left to each implementation.
    ///
    /// # Errors
    ///
    /// Returns an [`LLMError`] when the request cannot be completed.
    async fn send_request(
        &self,
        messages: &[Message],
        tools: Option<&[ToolDefinition]>,
        sampling: Option<&SamplingParams>,
    ) -> Result<LLMClientResponse, LLMError>;

    /// Starts a streaming request and returns an [`LLMEventStream`] of events.
    ///
    /// Arguments have the same meaning as for [`LLMClient::send_request`].
    ///
    /// # Errors
    ///
    /// Failures can surface in three places:
    /// - the outer `Result`, when no stream could be produced;
    /// - `Err` items yielded by the stream;
    /// - in-band [`LLMStreamEvent::Error`] events.
    async fn send_streaming_request(
        &self,
        messages: &[Message],
        tools: Option<&[ToolDefinition]>,
        sampling: Option<&SamplingParams>,
    ) -> Result<LLMEventStream, LLMError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived `Debug` output names the variant and includes its payload.
    #[test]
    fn llm_stream_event_is_debug() {
        let event = LLMStreamEvent::Token {
            text: "Hello".to_string(),
        };
        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("Token"));
        assert!(debug_str.contains("Hello"));
    }

    /// Cloning yields an equal payload and leaves the original intact.
    #[test]
    fn llm_stream_event_is_clone() {
        let event = LLMStreamEvent::Start {
            id: "test-id".to_string(),
        };
        let cloned = event.clone();
        assert!(matches!(&cloned, LLMStreamEvent::Start { id } if id == "test-id"));
        assert!(matches!(&event, LLMStreamEvent::Start { id } if id == "test-id"));
    }

    /// Derived `Debug` output includes both the content and the stop reason.
    #[test]
    fn llm_client_response_is_debug() {
        let response = LLMClientResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
        };
        let debug_str = format!("{:?}", response);
        assert!(debug_str.contains("Hello"));
        assert!(debug_str.contains("EndTurn"));
    }

    /// Cloning copies every field and leaves the original intact.
    #[test]
    fn llm_client_response_is_clone() {
        let response = LLMClientResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
        };
        let cloned = response.clone();
        assert_eq!(cloned.content, "Hello");
        assert!(cloned.tool_calls.is_empty());
        assert_eq!(cloned.stop_reason, StopReason::EndTurn);
        // The source value must be unaffected by the clone.
        assert_eq!(response.content, cloned.content);
    }

    /// The in-band `Error` variant carries its category and message verbatim.
    #[test]
    fn llm_stream_event_error_variant() {
        let event = LLMStreamEvent::Error {
            error_type: "rate_limit".to_string(),
            message: "Too many requests".to_string(),
        };
        assert!(matches!(
            event,
            LLMStreamEvent::Error { error_type, message }
                if error_type == "rate_limit" && message == "Too many requests"
        ));
    }

    /// The `End` variant carries the stop reason it was built with.
    #[test]
    fn llm_stream_event_end_variant() {
        let event = LLMStreamEvent::End {
            stop_reason: StopReason::ToolUse,
        };
        assert!(matches!(
            event,
            LLMStreamEvent::End { stop_reason }
                if stop_reason == StopReason::ToolUse
        ));
    }

    /// The `ToolCall` variant carries the whole tool call, including its arguments.
    #[test]
    fn llm_stream_event_tool_call_variant() {
        let tool_call = ToolCall {
            id: "tc_123".to_string(),
            name: "search".to_string(),
            arguments: serde_json::json!({"query": "test"}),
        };
        // `tool_call` is not used again, so it is moved rather than cloned.
        let event = LLMStreamEvent::ToolCall { tool_call };
        assert!(matches!(
            event,
            LLMStreamEvent::ToolCall { tool_call: tc }
                if tc.id == "tc_123"
                    && tc.name == "search"
                    && tc.arguments == serde_json::json!({"query": "test"})
        ));
    }
}