use mockito::Server;
use openrouter_rust::{
OpenRouterClient,
ChatCompletionBuilder,
ChatCompletionChunk,
collect_stream,
};
use futures::StreamExt;
use serde_json::json;
#[tokio::test]
async fn test_streaming_chat_completion() {
let mut server = Server::new_async().await;
let chunk1 = json!({
"id": "gen-stream-1",
"object": "chat.completion.chunk",
"created": 1704067200,
"model": "openai/gpt-3.5-turbo",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello"
},
"finish_reason": null
}
]
});
let chunk2 = json!({
"id": "gen-stream-1",
"object": "chat.completion.chunk",
"created": 1704067200,
"model": "openai/gpt-3.5-turbo",
"choices": [
{
"index": 0,
"delta": {
"content": " world"
},
"finish_reason": null
}
]
});
let chunk3 = json!({
"id": "gen-stream-1",
"object": "chat.completion.chunk",
"created": 1704067200,
"model": "openai/gpt-3.5-turbo",
"choices": [
{
"index": 0,
"delta": {
"content": "!"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
});
let sse_body = format!(
"data: {}\n\ndata: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
chunk1.to_string(),
chunk2.to_string(),
chunk3.to_string()
);
let _m = server.mock("POST", "/chat/completions")
.match_header("authorization", "Bearer test-key")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let client = OpenRouterClient::builder()
.api_key("test-key")
.base_url(&server.url())
.build()
.unwrap();
let request = ChatCompletionBuilder::new("openai/gpt-3.5-turbo")
.user_message("Say hello")
.build();
let stream = client.chat_completion_stream(request).await.unwrap();
let mut stream = stream;
let mut contents = Vec::new();
let mut usage_received = false;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
for choice in &chunk.choices {
if let Some(ref content) = choice.delta.content {
contents.push(content.clone());
}
if choice.finish_reason == Some("stop".to_string()) {
if chunk.usage.is_some() {
usage_received = true;
}
}
}
}
Err(e) => panic!("Stream error: {}", e),
}
}
assert_eq!(contents, vec!["Hello", " world", "!"]);
assert!(usage_received);
}
/// Verifies that `collect_stream` folds a streamed completion back into a
/// single response: the content deltas are concatenated, the finish reason
/// is kept, and the usage numbers sent on the last chunk are exposed.
#[tokio::test]
async fn test_streaming_collect_function() {
    let mut server = Server::new_async().await;

    // Three deltas that concatenate to "The answer is 42."; only the closing
    // one carries a finish reason and a usage block.
    let opening = json!({
        "id": "gen-collect",
        "object": "chat.completion.chunk",
        "created": 1704067200,
        "model": "openai/gpt-3.5-turbo",
        "choices": [{"index": 0, "delta": {"role": "assistant", "content": "The"}, "finish_reason": null}]
    });
    let middle = json!({
        "id": "gen-collect",
        "object": "chat.completion.chunk",
        "created": 1704067200,
        "model": "openai/gpt-3.5-turbo",
        "choices": [{"index": 0, "delta": {"content": " answer"}, "finish_reason": null}]
    });
    let closing = json!({
        "id": "gen-collect",
        "object": "chat.completion.chunk",
        "created": 1704067200,
        "model": "openai/gpt-3.5-turbo",
        "choices": [{"index": 0, "delta": {"content": " is 42."}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
    });
    let body = format!(
        "data: {}\n\ndata: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
        opening, middle, closing
    );

    let _m = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let client = OpenRouterClient::builder()
        .api_key("test-key")
        .base_url(&server.url())
        .build()
        .unwrap();

    let question = ChatCompletionBuilder::new("openai/gpt-3.5-turbo")
        .user_message("What is the answer?")
        .build();

    let chunks = client.chat_completion_stream(question).await.unwrap();
    let collected = collect_stream(chunks).await.unwrap();

    assert_eq!(collected.choices.len(), 1);
    let only = &collected.choices[0];
    assert_eq!(only.message.content, Some("The answer is 42.".to_string()));
    assert_eq!(only.finish_reason, Some("stop".to_string()));

    let tokens = collected.usage.unwrap();
    assert_eq!(tokens.prompt_tokens, 5);
    assert_eq!(tokens.completion_tokens, 5);
    assert_eq!(tokens.total_tokens, 10);
}
/// Ensures that a chunk carrying an `error` object (after one normal chunk)
/// makes the stream yield at least one `Err` item to the caller.
#[tokio::test]
async fn test_streaming_error_midstream() {
    let mut server = Server::new_async().await;

    let healthy = json!({
        "id": "gen-error",
        "object": "chat.completion.chunk",
        "created": 1704067200,
        "model": "openai/gpt-3.5-turbo",
        "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]
    });
    let failing = json!({
        "id": "gen-error",
        "object": "chat.completion.chunk",
        "created": 1704067200,
        "model": "openai/gpt-3.5-turbo",
        "error": {
            "code": 500,
            "message": "Internal server error"
        },
        "choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "error"}]
    });
    // No `[DONE]` event: the body ends right after the error chunk.
    let body = format!("data: {}\n\ndata: {}\n\n", healthy, failing);

    let _m = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let client = OpenRouterClient::builder()
        .api_key("test-key")
        .base_url(&server.url())
        .build()
        .unwrap();

    let request = ChatCompletionBuilder::new("openai/gpt-3.5-turbo")
        .user_message("Test")
        .build();

    let mut events = client.chat_completion_stream(request).await.unwrap();

    // Drain successful items until the first error shows up, then stop polling.
    let mut received_error = false;
    while let Some(item) = events.next().await {
        if item.is_err() {
            received_error = true;
            break;
        }
    }

    assert!(received_error);
}
/// Checks that an SSE comment line (`: OPENROUTER PROCESSING`, a line starting
/// with `:`) placed before the first data event is skipped by the stream:
/// it must not surface as an error or as an extra chunk. The data event that
/// follows it must still be delivered.
#[tokio::test]
async fn test_streaming_sse_comments() {
    let mut server = Server::new_async().await;

    let chunk1 = json!({
        "id": "gen-comment",
        "object": "chat.completion.chunk",
        "created": 1704067200,
        "model": "openai/gpt-3.5-turbo",
        "choices": [{"index": 0, "delta": {"content": "Test"}, "finish_reason": null}]
    });
    let sse_body = format!(
        ": OPENROUTER PROCESSING\n\ndata: {}\n\ndata: [DONE]\n\n",
        chunk1
    );

    let _m = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(sse_body)
        .create_async()
        .await;

    let client = OpenRouterClient::builder()
        .api_key("test-key")
        .base_url(&server.url())
        .build()
        .unwrap();

    let request = ChatCompletionBuilder::new("openai/gpt-3.5-turbo")
        .user_message("Test")
        .build();

    let mut stream = client.chat_completion_stream(request).await.unwrap();

    let mut contents = Vec::new();
    while let Some(result) = stream.next().await {
        // If the comment line were not skipped, it would most likely surface
        // here as a parse error. Fail loudly rather than ignoring it.
        let chunk = match result {
            Ok(chunk) => chunk,
            Err(e) => panic!("Stream error: {}", e),
        };
        contents.extend(
            chunk
                .choices
                .iter()
                .filter_map(|choice| choice.delta.content.clone()),
        );
    }

    // Exactly the one real data event; the comment contributed nothing.
    assert_eq!(contents, vec!["Test"]);
}