use anyhow::Result;
use async_trait::async_trait;
use opencrabs::llm::provider::{
error::{ProviderError, Result as ProviderResult},
types::{
ContentBlock, ContentDelta, LLMRequest, LLMResponse, MessageDelta, Role, StopReason,
StreamEvent, StreamMessage, TokenUsage,
},
Provider, ProviderStream,
};
use futures::{stream, StreamExt};
/// Mock provider that replays a fixed, pre-built list of stream events.
struct StreamingMockProvider {
/// Scripted events; `stream` hands out a clone of this list, in order, on every call.
events: Vec<StreamEvent>,
}
impl StreamingMockProvider {
fn new(text_chunks: Vec<&str>) -> Self {
let mut events = vec![StreamEvent::MessageStart {
message: StreamMessage {
id: "msg-test".to_string(),
model: "mock-model".to_string(),
role: Role::Assistant,
usage: TokenUsage {
input_tokens: 10,
output_tokens: 0, ..Default::default() },
},
}];
events.push(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text {
text: String::new(),
},
});
for chunk in text_chunks {
events.push(StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentDelta::TextDelta {
text: chunk.to_string(),
},
});
}
events.push(StreamEvent::ContentBlockStop { index: 0 });
events.push(StreamEvent::MessageDelta {
delta: MessageDelta {
stop_reason: Some(StopReason::EndTurn),
stop_sequence: None,
},
usage: TokenUsage {
input_tokens: 10,
output_tokens: 20, ..Default::default() },
});
events.push(StreamEvent::MessageStop);
Self { events }
}
fn with_error(error_message: &str) -> Self {
Self {
events: vec![StreamEvent::Error {
error: error_message.to_string(),
}],
}
}
}
#[async_trait]
impl Provider for StreamingMockProvider {
    /// Always returns `ProviderError::StreamingNotSupported`; the request is ignored.
    async fn complete(&self, _request: LLMRequest) -> ProviderResult<LLMResponse> {
        Err(ProviderError::StreamingNotSupported)
    }

    /// Returns a stream that yields a clone of every scripted event, each wrapped in `Ok`.
    ///
    /// The request is ignored, and the script is cloned so the provider can be streamed
    /// from more than once.
    async fn stream(&self, _request: LLMRequest) -> ProviderResult<ProviderStream> {
        let events = self.events.clone();
        let stream = stream::iter(events.into_iter().map(Ok));
        Ok(Box::pin(stream))
    }

    /// Fixed provider name used by this mock.
    fn name(&self) -> &str {
        "streaming-mock"
    }

    /// The single model this mock reports as its default.
    fn default_model(&self) -> &str {
        "mock-model"
    }

    /// Lists the one model name this mock reports.
    fn supported_models(&self) -> Vec<String> {
        vec!["mock-model".to_string()]
    }

    /// Reports a context window of 8192 regardless of the model name.
    fn context_window(&self, _model: &str) -> Option<u32> {
        Some(8192)
    }

    /// Returns 0.001 per 1000 tokens over the combined input and output token count.
    ///
    /// Each count is widened to `f64` before the addition so that two large `u32`
    /// values cannot overflow (which would panic in debug builds and wrap in release).
    fn calculate_cost(&self, _model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
        let total_tokens = f64::from(input_tokens) + f64::from(output_tokens);
        (total_tokens / 1000.0) * 0.001
    }

    /// This mock always reports streaming support.
    fn supports_streaming(&self) -> bool {
        true
    }
}
/// Streams four text chunks and checks the event count, the first and last event
/// kinds, and the text deltas in order.
#[tokio::test]
async fn test_streaming_basic() -> Result<()> {
    let provider = StreamingMockProvider::new(vec!["Hello", " ", "world", "!"]);
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();
    let mut stream = provider.stream(request).await?;

    let mut received = Vec::new();
    while let Some(item) = stream.next().await {
        received.push(item?);
    }

    // Five framing events plus one delta per chunk.
    assert_eq!(received.len(), 9);
    assert!(matches!(
        received.first(),
        Some(StreamEvent::MessageStart { .. })
    ));

    let text_chunks: Vec<String> = received
        .iter()
        .filter_map(|event| match event {
            StreamEvent::ContentBlockDelta {
                delta: ContentDelta::TextDelta { text },
                ..
            } => Some(text.clone()),
            _ => None,
        })
        .collect();
    assert_eq!(text_chunks, vec!["Hello", " ", "world", "!"]);

    assert!(matches!(received.last(), Some(StreamEvent::MessageStop)));
    Ok(())
}
/// A response delivered as one chunk yields that text and six events in total.
#[tokio::test]
async fn test_streaming_single_chunk() -> Result<()> {
    let provider = StreamingMockProvider::new(vec!["Complete response in one go"]);
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();

    // Drain the whole stream first, then inspect the items.
    let items: Vec<_> = provider.stream(request).await?.collect().await;
    let event_count = items.len();

    let mut text_received = String::new();
    for item in items {
        if let StreamEvent::ContentBlockDelta {
            delta: ContentDelta::TextDelta { text },
            ..
        } = item?
        {
            text_received.push_str(&text);
        }
    }

    assert_eq!(text_received, "Complete response in one go");
    assert_eq!(event_count, 6);
    Ok(())
}
/// Eight chunks go in; the same eight text deltas come out, in the same order.
#[tokio::test]
async fn test_streaming_multiple_chunks() -> Result<()> {
    let expected = vec![
        "This",
        " is",
        " a",
        " longer",
        " response",
        " with",
        " many",
        " chunks",
    ];
    let provider = StreamingMockProvider::new(expected.clone());
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();
    let mut stream = provider.stream(request).await?;

    let mut received_chunks = Vec::new();
    while let Some(item) = stream.next().await {
        let event = item?;
        if let StreamEvent::ContentBlockDelta {
            delta: ContentDelta::TextDelta { text },
            ..
        } = event
        {
            received_chunks.push(text);
        }
    }

    assert_eq!(received_chunks, expected);
    Ok(())
}
/// Checks the usage figures reported at both ends of the stream.
///
/// The mock reports 10 input / 0 output tokens on `MessageStart` and
/// 10 input / 20 output tokens on `MessageDelta`.
#[tokio::test]
async fn test_streaming_token_counting() -> Result<()> {
    let provider = StreamingMockProvider::new(vec!["Test"]);
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();
    let mut stream = provider.stream(request).await?;

    let mut start_tokens = None;
    let mut end_tokens = None;
    while let Some(event) = stream.next().await {
        match event? {
            StreamEvent::MessageStart { message } => {
                start_tokens = Some(message.usage);
            }
            StreamEvent::MessageDelta { usage, .. } => {
                end_tokens = Some(usage);
            }
            _ => {}
        }
    }

    // Verify the opening usage values too, not merely that they were present.
    let start = start_tokens.expect("stream should emit a MessageStart event carrying usage");
    assert_eq!(start.input_tokens, 10);
    assert_eq!(start.output_tokens, 0);

    let end = end_tokens.expect("stream should emit a MessageDelta event carrying usage");
    assert_eq!(end.input_tokens, 10);
    assert_eq!(end.output_tokens, 20);
    assert_eq!(end.total(), 30);
    Ok(())
}
/// The stop reason carried by the `MessageDelta` event is `EndTurn`.
#[tokio::test]
async fn test_streaming_stop_reason() -> Result<()> {
    let provider = StreamingMockProvider::new(vec!["Test"]);
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();
    let items = provider.stream(request).await?.collect::<Vec<_>>().await;

    // Keep the stop reason from the most recent `MessageDelta` seen.
    let mut stop_reason = None;
    for item in items {
        if let StreamEvent::MessageDelta { delta, .. } = item? {
            stop_reason = delta.stop_reason;
        }
    }

    assert_eq!(stop_reason, Some(StopReason::EndTurn));
    Ok(())
}
#[tokio::test]
async fn test_streaming_error_handling() -> Result<()> {
let provider = StreamingMockProvider::with_error("Test error message");
let request = LLMRequest::new("mock-model", vec![]).with_streaming();
let mut stream = provider.stream(request).await?;
let mut error_received = None;
while let Some(event) = stream.next().await {
if let StreamEvent::Error { error } = event? {
error_received = Some(error);
break;
}
}
assert!(error_received.is_some());
assert_eq!(error_received.unwrap(), "Test error message");
Ok(())
}
/// With no text chunks, only the five framing events are emitted.
#[tokio::test]
async fn test_streaming_empty_response() -> Result<()> {
    let provider = StreamingMockProvider::new(vec![]);
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();
    let items = provider.stream(request).await?.collect::<Vec<_>>().await;

    // Fail the test on the first `Err` item; otherwise keep every event.
    let events = items.into_iter().collect::<Result<Vec<_>, _>>()?;

    assert_eq!(events.len(), 5);
    Ok(())
}
/// Concatenating every text delta reproduces the full message text.
#[tokio::test]
async fn test_streaming_content_accumulation() -> Result<()> {
    let provider = StreamingMockProvider::new(vec!["Hello", " ", "world", "!"]);
    let request = LLMRequest::new("mock-model", vec![]).with_streaming();

    let items = provider.stream(request).await?.collect::<Vec<_>>().await;
    let events = items.into_iter().collect::<Result<Vec<_>, _>>()?;

    // Join the text of every text delta, in arrival order.
    let accumulated_text: String = events
        .iter()
        .filter_map(|event| match event {
            StreamEvent::ContentBlockDelta {
                delta: ContentDelta::TextDelta { text },
                ..
            } => Some(text.as_str()),
            _ => None,
        })
        .collect();

    assert_eq!(accumulated_text, "Hello world!");
    Ok(())
}
/// `with_streaming` sets the `stream` flag and leaves the model name unchanged.
#[tokio::test]
async fn test_streaming_request_builder() {
    let base = LLMRequest::new("test-model", vec![]);
    let streaming = base.with_streaming();

    assert_eq!(streaming.model, "test-model");
    assert!(streaming.stream);
}
/// The mock provider reports that it supports streaming.
#[tokio::test]
async fn test_provider_supports_streaming() {
    let supports = StreamingMockProvider::new(vec!["test"]).supports_streaming();
    assert!(supports);
}