use async_trait::async_trait;
use crate::message::Message;
use crate::tool::ToolDefinition;
/// A single request handed to a [`Provider`].
///
/// Apart from `messages`, every field is optional configuration, so the type
/// derives [`Default`]: callers can write
/// `CompletionRequest { messages, ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct CompletionRequest {
    /// Conversation messages to send with the request.
    pub messages: Vec<Message>,
    /// Tool definitions to advertise for this request (may be empty).
    pub tools: Vec<ToolDefinition>,
    /// Optional cap on generated tokens; how `None` is treated is up to the
    /// `Provider` implementation.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; how `None` is treated is up to the
    /// `Provider` implementation.
    pub temperature: Option<f32>,
    /// Optional model identifier; how `None` is treated is up to the
    /// `Provider` implementation.
    pub model: Option<String>,
}
/// Successful result of [`Provider::complete`].
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The message returned by the provider.
    pub message: Message,
    /// Token accounting for this completion.
    pub usage: Usage,
    /// The reason the provider gave for stopping.
    pub finish_reason: FinishReason,
}
/// Token accounting reported for one completion.
///
/// `Default` yields zero for both counters. `PartialEq`/`Eq` allow usage
/// values to be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Tokens counted on the input (prompt) side.
    pub input_tokens: u32,
    /// Tokens counted on the output (generated) side.
    pub output_tokens: u32,
}

impl Usage {
    /// Returns `input_tokens + output_tokens`.
    ///
    /// The sum saturates at [`u32::MAX`] rather than overflowing. A plain `+`
    /// would panic in debug builds and silently wrap to a small number in
    /// release builds, both of which are worse than a clamped total.
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// The reason a provider gave for ending a completion.
#[derive(Debug, Clone, PartialEq)]
pub enum FinishReason {
    /// The provider reported a normal stop.
    Stop,
    /// The provider reported stopping in order to use a tool.
    ToolUse,
    /// The provider reported hitting a token limit.
    MaxTokens,
    /// Any other reason, carried as a raw string (presumably the provider's
    /// own label — confirm in provider implementations).
    Other(String),
}
/// A backend that answers a [`CompletionRequest`] with one complete
/// [`CompletionResponse`].
///
/// Implementors must be `Send + Sync`, so one provider can be shared across
/// threads and tasks.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Sends `request` and awaits the full response.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the completion fails.
    async fn complete(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse, ProviderError>;
}
/// One item delivered on the channel returned by [`StreamingProvider::stream`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of generated text.
    TextDelta(String),
    /// Announces a tool call with identifier `id` for the tool called `name`.
    ToolCallStart { id: String, name: String },
    /// A fragment of the arguments for the tool call `id`. Fragments are
    /// presumably concatenated in arrival order — confirm in provider impls.
    ToolCallDelta { id: String, arguments_delta: String },
    /// Final token usage and stop reason; presumably the last event sent.
    Done {
        usage: Usage,
        finish_reason: FinishReason,
    },
    /// An error reported on the stream, as a message string.
    Error(String),
}
/// A [`Provider`] that can also deliver its output incrementally.
#[async_trait]
pub trait StreamingProvider: Provider {
    /// Starts a streaming completion for `request`.
    ///
    /// On success, returns the receiving half of a Tokio mpsc channel on which
    /// [`StreamEvent`]s arrive.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the stream cannot be obtained.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError>;
}
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
#[error("authentication failed: {0}")]
Auth(String),
#[error("rate limited, retry after {retry_after_ms:?}ms")]
RateLimited { retry_after_ms: Option<u64> },
#[error("context length exceeded: {0}")]
ContextLengthExceeded(String),
#[error("provider error: {0}")]
Other(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a request with a single "Hi" user message, no tools,
    /// and every optional setting left as `None`.
    fn hi_request() -> CompletionRequest {
        CompletionRequest {
            messages: vec![Message::user("Hi")],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            model: None,
        }
    }

    #[test]
    fn usage_total_tokens() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
        };
        assert_eq!(usage.total_tokens(), 150);
    }

    #[test]
    fn default_usage() {
        let usage = Usage::default();
        assert_eq!((usage.input_tokens, usage.output_tokens), (0, 0));
        assert_eq!(usage.total_tokens(), 0);
    }

    #[test]
    fn finish_reason_equality() {
        let equal_pairs = [
            (FinishReason::Stop, FinishReason::Stop),
            (FinishReason::ToolUse, FinishReason::ToolUse),
            (
                FinishReason::Other("foo".into()),
                FinishReason::Other("foo".into()),
            ),
        ];
        for (lhs, rhs) in &equal_pairs {
            assert_eq!(lhs, rhs);
        }
        assert_ne!(FinishReason::Stop, FinishReason::ToolUse);
    }

    /// Provider stub that always succeeds with a clone of its canned response.
    struct MockProvider {
        response: CompletionResponse,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, ProviderError> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn mock_provider_returns_response() {
        let canned = CompletionResponse {
            message: Message::assistant("Hello!"),
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
            finish_reason: FinishReason::Stop,
        };
        let provider = MockProvider { response: canned };

        let response = provider.complete(hi_request()).await.unwrap();

        assert_eq!(response.message.content, "Hello!");
        assert_eq!(response.usage.total_tokens(), 15);
        assert_eq!(response.finish_reason, FinishReason::Stop);
    }

    /// Provider stub that always fails with a 1000 ms rate-limit error.
    struct ErrorProvider;

    #[async_trait]
    impl Provider for ErrorProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, ProviderError> {
            Err(ProviderError::RateLimited {
                retry_after_ms: Some(1000),
            })
        }
    }

    #[tokio::test]
    async fn provider_error_handling() {
        let provider = ErrorProvider;
        let err = provider.complete(hi_request()).await.unwrap_err();
        if let ProviderError::RateLimited { retry_after_ms } = err {
            assert_eq!(retry_after_ms, Some(1000));
        } else {
            panic!("expected RateLimited error");
        }
    }

    #[test]
    fn stream_event_text_delta() {
        let event = StreamEvent::TextDelta("hello".into());
        if let StreamEvent::TextDelta(text) = event {
            assert_eq!(text, "hello");
        } else {
            panic!("expected TextDelta");
        }
    }

    #[test]
    fn stream_event_tool_call_lifecycle() {
        let start = StreamEvent::ToolCallStart {
            id: "c1".into(),
            name: "search".into(),
        };
        let delta = StreamEvent::ToolCallDelta {
            id: "c1".into(),
            arguments_delta: "{\"q\":".into(),
        };
        let done = StreamEvent::Done {
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
            finish_reason: FinishReason::ToolUse,
        };

        if let StreamEvent::ToolCallStart { id, name } = start {
            assert_eq!(id, "c1");
            assert_eq!(name, "search");
        } else {
            panic!("expected ToolCallStart");
        }

        if let StreamEvent::ToolCallDelta { id, arguments_delta } = delta {
            assert_eq!(id, "c1");
            assert_eq!(arguments_delta, "{\"q\":");
        } else {
            panic!("expected ToolCallDelta");
        }

        if let StreamEvent::Done {
            usage,
            finish_reason,
        } = done
        {
            assert_eq!(usage.total_tokens(), 15);
            assert_eq!(finish_reason, FinishReason::ToolUse);
        } else {
            panic!("expected Done");
        }
    }
}