use async_trait::async_trait;
use crate::types::completion::{CompletionRequest, CompletionResponse};
use crate::types::stream::CompletionStream;
use crate::Result;
pub use crate::types::model_info::{ModelInfo, ModelTier};
/// Common interface for a completion backend.
///
/// A provider exposes a one-shot completion call, a streaming variant, and
/// read-only access to the metadata of the model it serves. Implementors must
/// be `Send + Sync + 'static`, as required by the supertrait bounds below.
#[async_trait]
pub trait Provider: Send + Sync + 'static {
    /// Runs `request` to completion and returns the whole response at once.
    ///
    /// # Errors
    ///
    /// Returns the crate-level error type when the implementation cannot
    /// produce a response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;

    /// Starts `request` and returns a [`CompletionStream`] for it instead of a
    /// finished response.
    ///
    /// # Errors
    ///
    /// Returns the crate-level error type when the stream cannot be set up.
    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;

    /// Returns the metadata describing the model behind this provider.
    fn model_info(&self) -> &ModelInfo;
}
#[cfg(test)]
mod tests {
    //! Unit tests for the `Provider` trait, exercised through an in-memory
    //! mock so that no real backend is needed.

    use super::*;
    use crate::types::completion::{ResponseContent, Usage};
    use crate::types::message::Message;
    use crate::types::stream::StreamEvent;

    /// Model name reported by the mock and used in every test request.
    const MOCK_MODEL: &str = "mock-model";

    /// Test double that answers every request with a fixed piece of text.
    struct MockProvider {
        /// Metadata handed out by `model_info`.
        info: ModelInfo,
        /// Canned text returned by both `complete` and `stream`.
        response: String,
    }

    impl MockProvider {
        /// Creates a mock that replies with `response` and reports a fixed,
        /// medium-tier model description.
        fn new(response: impl Into<String>) -> Self {
            Self {
                info: ModelInfo {
                    name: MOCK_MODEL.to_string(),
                    tier: ModelTier::Medium,
                    context_window: 8_192,
                    supports_tools: true,
                    supports_vision: false,
                    supports_structured: true,
                },
                response: response.into(),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        /// Ignores the request and returns the canned text with fixed usage
        /// numbers (10 prompt + 5 completion = 15 total).
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                content: ResponseContent::Text(self.response.clone()),
                usage: Usage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
            })
        }

        /// Ignores the request and returns a stream that yields the canned
        /// text as a single `TextDelta` event.
        async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
            let text = self.response.clone();
            // Bound to a local first so the stream's item type is inferred
            // from the return type when the box is coerced.
            let stream = tokio_stream::once(Ok(StreamEvent::TextDelta(text)));
            Ok(Box::pin(stream))
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    /// Builds a minimal single-user-message request for the mock model.
    ///
    /// Every optional knob is left unset; only the prompt text and the
    /// `stream` flag vary between tests. The prompt is `&'static str` because
    /// the tests only ever pass string literals.
    fn request(prompt: &'static str, stream: bool) -> CompletionRequest {
        CompletionRequest {
            model: MOCK_MODEL.to_string(),
            messages: vec![Message::user(prompt)],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            response_format: None,
            stream,
        }
    }

    /// `complete` returns the canned text together with the full usage record.
    #[tokio::test]
    async fn test_mock_provider_complete_ac1() {
        let provider = MockProvider::new("Hello from LLM!");

        let response = provider.complete(request("Hi", false)).await.unwrap();

        match response.content {
            ResponseContent::Text(t) => assert_eq!(t, "Hello from LLM!"),
            ResponseContent::ToolCalls(_) => panic!("expected Text"),
        }
        // Check every usage field, not only the total, so a mismatch between
        // the parts and the sum is caught.
        assert_eq!(response.usage.prompt_tokens, 10);
        assert_eq!(response.usage.completion_tokens, 5);
        assert_eq!(response.usage.total_tokens, 15);
    }

    /// `stream` yields exactly one `TextDelta` carrying the canned text.
    #[tokio::test]
    async fn test_mock_provider_stream_ac2() {
        use tokio_stream::StreamExt;

        let provider = MockProvider::new("streamed chunk");

        let mut stream = provider.stream(request("stream test", true)).await.unwrap();
        let event = stream
            .next()
            .await
            .expect("expected at least one event")
            .unwrap();
        match event {
            StreamEvent::TextDelta(t) => assert_eq!(t, "streamed chunk"),
            _ => panic!("expected TextDelta"),
        }

        // The mock is built from a single-item stream, so nothing may follow.
        assert!(
            stream.next().await.is_none(),
            "expected the stream to end after a single event"
        );
    }

    /// `model_info` exposes every field the mock was constructed with.
    #[test]
    fn test_mock_provider_model_info_ac3() {
        let provider = MockProvider::new("x");
        let info = provider.model_info();

        assert_eq!(info.name, MOCK_MODEL);
        assert_eq!(info.tier, ModelTier::Medium);
        assert_eq!(info.context_window, 8_192);
        assert!(info.supports_tools);
        assert!(!info.supports_vision);
        assert!(info.supports_structured);
    }

    /// Compile-time check: a `Provider` implementor is `Send + Sync + 'static`.
    #[test]
    fn test_provider_send_sync_static_bounds_ac5() {
        // Including `Provider` in the bound ties the check to the trait: the
        // test stops compiling if the mock no longer implements it.
        fn assert_bounds<T: Provider + Send + Sync + 'static>() {}
        assert_bounds::<MockProvider>();
    }
}