use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use talk::context::Message;
use talk::error::AgentError;
use talk::provider::{LLMProvider, ProviderConfig, StreamChunk};
/// Test double for [`LLMProvider`] that replays a canned response.
///
/// When `should_fail` is set, `complete` and `stream` both return
/// `AgentError::LLMProvider` instead of the canned response.
#[derive(Clone, Debug)]
struct MockProvider {
    /// Configuration reported through `LLMProvider::config` (model "mock-model").
    config: ProviderConfig,
    /// Text returned by `complete` and emitted one `char` per chunk by `stream`.
    response: String,
    /// When `true`, every provider call fails with a provider error.
    should_fail: bool,
}

impl MockProvider {
    /// Builds a provider that succeeds with `response` for every request.
    fn new(response: impl Into<String>) -> Self {
        let config = ProviderConfig::new("mock-model");
        let response = response.into();
        Self {
            config,
            response,
            should_fail: false,
        }
    }

    /// Consumes the provider and returns it configured to fail every call.
    fn with_failure(self) -> Self {
        Self {
            should_fail: true,
            ..self
        }
    }
}
#[async_trait]
impl LLMProvider for MockProvider {
    /// Returns a clone of the canned response, ignoring the conversation.
    ///
    /// # Errors
    /// Returns `AgentError::LLMProvider` when the mock was built with `with_failure`.
    async fn complete(&self, _messages: Vec<Message>) -> Result<String, AgentError> {
        if self.should_fail {
            return Err(AgentError::LLMProvider("Mock provider error".into()));
        }
        Ok(self.response.clone())
    }

    /// Streams the canned response one `char` at a time, each as an `Ok` chunk.
    ///
    /// # Errors
    /// Returns `AgentError::LLMProvider` up front (no stream is produced) when
    /// the mock was built with `with_failure`.
    async fn stream(
        &self,
        _messages: Vec<Message>,
    ) -> Result<Pin<Box<dyn Stream<Item = StreamChunk> + Send>>, AgentError> {
        if self.should_fail {
            return Err(AgentError::LLMProvider("Mock provider error".into()));
        }
        // `Box<dyn Stream>` carries an implicit `'static` bound, so the chunks
        // are materialised as owned values instead of borrowing `self.response`.
        let pieces: Vec<StreamChunk> = self
            .response
            .chars()
            .map(|ch| Ok(ch.to_string()))
            .collect();
        Ok(Box::pin(futures::stream::iter(pieces)))
    }

    /// Fixed, non-empty provider name.
    fn name(&self) -> &str {
        "MockProvider"
    }

    /// Borrow of the configuration created in `MockProvider::new`.
    fn config(&self) -> &ProviderConfig {
        &self.config
    }
}
/// Contract: `complete` succeeds on a valid conversation and yields the
/// provider's full, non-empty response.
#[tokio::test]
async fn test_llm_provider_complete_contract() {
    let provider = MockProvider::new("This is a test response");
    let conversation = vec![
        Message::system("You are a helpful assistant"),
        Message::user("Hello"),
    ];

    let response = match provider.complete(conversation).await {
        Ok(text) => text,
        Err(_) => panic!("LLMProvider::complete should succeed with valid messages"),
    };

    assert!(
        !response.is_empty(),
        "LLMProvider::complete should return a non-empty response"
    );
    assert_eq!(
        response, "This is a test response",
        "LLMProvider::complete should return the expected response"
    );
}
/// Contract: a failing provider surfaces its failure from `complete` as
/// `AgentError::LLMProvider`.
#[tokio::test]
async fn test_llm_provider_complete_error_contract() {
    let provider = MockProvider::new("Response").with_failure();
    let outcome = provider.complete(vec![Message::user("Hello")]).await;

    let err = match outcome {
        Err(err) => err,
        Ok(_) => panic!("LLMProvider::complete should return error when provider fails"),
    };
    assert!(
        matches!(err, AgentError::LLMProvider(_)),
        "Error should be AgentError::LLMProvider"
    );
}
/// Contract: `stream` yields only successful chunks, at least one of them,
/// and their concatenation equals the full response.
#[tokio::test]
async fn test_llm_provider_stream_contract() {
    use futures::StreamExt;

    let provider = MockProvider::new("Hello");
    let mut stream = match provider.stream(vec![Message::user("Hi")]).await {
        Ok(stream) => stream,
        Err(_) => panic!("LLMProvider::stream should succeed with valid messages"),
    };

    let mut pieces = Vec::new();
    while let Some(item) = stream.next().await {
        match item {
            Ok(piece) => pieces.push(piece),
            Err(_) => panic!("Stream chunks should not contain errors"),
        }
    }

    assert!(
        !pieces.is_empty(),
        "LLMProvider::stream should produce at least one chunk"
    );
    let combined = pieces.concat();
    assert_eq!(
        combined, "Hello",
        "Stream chunks should combine to form the complete response"
    );
}
/// Contract: a failing provider refuses to open a stream and reports
/// `AgentError::LLMProvider`.
#[tokio::test]
async fn test_llm_provider_stream_error_contract() {
    let provider = MockProvider::new("Response").with_failure();
    let outcome = provider.stream(vec![Message::user("Hello")]).await;
    assert!(
        outcome.is_err(),
        "LLMProvider::stream should return error when provider fails"
    );

    // The `Ok` payload is a boxed stream without `Debug`, so extract the error by hand.
    let err = match outcome {
        Err(err) => err,
        Ok(_) => panic!("Expected error, got Ok"),
    };
    if !matches!(err, AgentError::LLMProvider(_)) {
        panic!("Expected AgentError::LLMProvider, got {:?}", err);
    }
}
/// Contract: `name` is non-empty and stable across calls.
///
/// `name` is synchronous, so this runs as a plain `#[test]` rather than an
/// `async fn` that never awaits (clippy `unused_async`) inside a Tokio runtime.
#[test]
fn test_llm_provider_name_contract() {
    let provider = MockProvider::new("Response");
    let name = provider.name();
    assert!(
        !name.is_empty(),
        "LLMProvider::name should return a non-empty string"
    );
    assert_eq!(
        provider.name(),
        name,
        "LLMProvider::name should return consistent values"
    );
}
/// Contract: the provider's configuration names a model and keeps its
/// temperature within the inclusive range `0.0..=2.0`.
///
/// `config` is synchronous, so this is a plain `#[test]` (no runtime needed).
#[test]
fn test_llm_provider_config_contract() {
    let provider = MockProvider::new("Response");
    let config = provider.config();
    assert!(
        !config.model.is_empty(),
        "ProviderConfig should have a non-empty model name"
    );
    // Inclusive range check; a NaN temperature fails it, as with the manual comparison.
    assert!(
        (0.0..=2.0).contains(&config.temperature),
        "ProviderConfig temperature should be in valid range"
    );
}
/// Contract: `complete` accepts a conversation mixing system, user, and
/// assistant messages.
#[tokio::test]
async fn test_llm_provider_message_types_contract() {
    let provider = MockProvider::new("Response");
    let history = vec![
        Message::system("You are a helpful assistant"),
        Message::user("Hello"),
        Message::assistant("Hi there!"),
        Message::user("How are you?"),
    ];

    let handled = provider.complete(history).await.is_ok();
    assert!(handled, "LLMProvider should handle all message types");
}
/// Contract: completing an empty conversation must not panic.
///
/// Reaching the `match` already proves the call returned instead of
/// panicking. `MockProvider::complete` ignores its messages, so the outcome
/// is fully determined and the test pins it exactly. A bare
/// `is_ok() || is_err()` check would be a tautology that can never fail.
#[tokio::test]
async fn test_llm_provider_empty_messages_contract() {
    let provider = MockProvider::new("Response");
    match provider.complete(Vec::new()).await {
        Ok(response) => assert_eq!(
            response, "Response",
            "LLMProvider should handle empty messages without panicking"
        ),
        Err(e) => panic!(
            "LLMProvider should handle empty messages without panicking, got {:?}",
            e
        ),
    }
}
/// Contract: providers can be shared across threads and tasks.
///
/// This is purely a compile-time check (the generic bound fails to compile
/// if `MockProvider` is not `Send + Sync`). It therefore runs as a plain
/// `#[test]` instead of an `async fn` that never awaits inside a Tokio runtime.
#[test]
fn test_llm_provider_thread_safety_contract() {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<MockProvider>();
}
/// Contract: independent clones of a provider can serve `complete` calls
/// from concurrently spawned Tokio tasks.
#[tokio::test]
async fn test_llm_provider_concurrent_calls_contract() {
    const TASKS: usize = 10;

    let provider = MockProvider::new("Response");
    let mut handles = Vec::with_capacity(TASKS);
    for _ in 0..TASKS {
        // Each task owns its own clone because spawned futures must be `'static`.
        let worker = provider.clone();
        handles.push(tokio::spawn(async move {
            worker.complete(vec![Message::user("Test")]).await
        }));
    }

    for handle in handles {
        let outcome = handle.await.unwrap();
        assert!(
            outcome.is_ok(),
            "Concurrent LLMProvider calls should succeed"
        );
    }
}