use async_trait::async_trait;
#[cfg(test)]
use super::error::LlmError;
use super::error::LlmResult;
/// Common asynchronous interface for LLM backends.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
#[async_trait]
pub trait LlmClient: Send + Sync {
/// Asynchronously produces a completion for `prompt`.
///
/// Returns the resulting text on success; failures are reported through
/// the error side of `LlmResult`.
async fn complete(&self, prompt: &str) -> LlmResult<String>;
/// Returns the name of the model behind this client.
fn model_name(&self) -> &str;
/// Returns the token limit associated with this client.
///
/// NOTE(review): the trait does not state whether this bounds the prompt,
/// the completion, or both — confirm against the implementors.
fn max_tokens(&self) -> usize;
/// Asynchronously reports whether the client considers itself ready.
///
/// NOTE(review): what "ready" means is left to each implementor — confirm.
async fn is_ready(&self) -> bool;
}
/// A completion's text together with optional metadata about the call.
///
/// Each metadata field is `None` when the corresponding value was never set.
/// `PartialEq`/`Eq` allow whole-value comparison (e.g. in `assert_eq!`), and
/// `Default` yields an empty `text` with every metadata field set to `None`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CompletionResponse {
    /// The completion text.
    pub text: String,
    /// Token count attributed to the prompt, if recorded.
    pub prompt_tokens: Option<usize>,
    /// Token count attributed to the completion, if recorded.
    pub completion_tokens: Option<usize>,
    /// A duration in milliseconds, if recorded.
    ///
    /// NOTE(review): presumably the elapsed time of the request — confirm with callers.
    pub duration_ms: Option<u64>,
}
impl CompletionResponse {
    /// Builds a response that carries only `text`.
    ///
    /// `prompt_tokens`, `completion_tokens` and `duration_ms` all start as `None`.
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            text,
            prompt_tokens: None,
            completion_tokens: None,
            duration_ms: None,
        }
    }

    /// Consumes `self` and returns it with both token counts set.
    ///
    /// `prompt` is stored in `prompt_tokens`, `completion` in `completion_tokens`;
    /// all other fields are carried over unchanged.
    pub fn with_tokens(self, prompt: usize, completion: usize) -> Self {
        Self {
            prompt_tokens: Some(prompt),
            completion_tokens: Some(completion),
            ..self
        }
    }

    /// Consumes `self` and returns it with `duration_ms` set to `ms`.
    ///
    /// All other fields are carried over unchanged.
    pub fn with_duration(self, ms: u64) -> Self {
        Self {
            duration_ms: Some(ms),
            ..self
        }
    }
}
/// Test-only `LlmClient` stand-in with a canned response and a failure switch.
///
/// Derives `Debug` so instances can appear in assertion/diagnostic output and
/// `Clone` so one configured mock can be duplicated across test cases.
#[cfg(test)]
#[derive(Debug, Clone)]
pub struct MockLlmClient {
    /// Text handed back by `complete` when the mock is not failing.
    response: String,
    /// Value reported by `model_name`.
    model: String,
    /// Value reported by `max_tokens`.
    max_tokens: usize,
    /// When `true`, `complete` returns an error and `is_ready` returns `false`.
    should_fail: bool,
}
#[cfg(test)]
impl MockLlmClient {
    /// Creates a non-failing mock whose canned response is `response`.
    ///
    /// The model name is `"mock-model"` and the token limit is `4096`.
    pub fn new(response: impl Into<String>) -> Self {
        let response = response.into();
        let model = String::from("mock-model");
        Self {
            response,
            model,
            max_tokens: 4096,
            should_fail: false,
        }
    }

    /// Creates a mock flagged to fail, with an empty canned response.
    ///
    /// Shares the model name and token limit used by [`MockLlmClient::new`].
    pub fn failing() -> Self {
        Self {
            should_fail: true,
            ..Self::new(String::new())
        }
    }
}
#[cfg(test)]
#[async_trait]
impl LlmClient for MockLlmClient {
    /// Ignores the prompt and returns a clone of the canned response, or a
    /// `LlmError::ConnectionError` when the mock is flagged to fail.
    async fn complete(&self, _prompt: &str) -> LlmResult<String> {
        if !self.should_fail {
            return Ok(self.response.clone());
        }
        Err(LlmError::ConnectionError("Mock failure".to_string()))
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the configured token limit.
    fn max_tokens(&self) -> usize {
        self.max_tokens
    }

    /// Returns `true` exactly when the mock is not flagged to fail.
    async fn is_ready(&self) -> bool {
        let failing = self.should_fail;
        !failing
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A mock built with `new` is ready, reports its defaults, and returns
    /// the canned response from `complete`.
    #[tokio::test]
    async fn test_mock_client_success() {
        let client = MockLlmClient::new("Test response");

        assert!(client.is_ready().await);
        assert_eq!(client.model_name(), "mock-model");
        assert_eq!(client.max_tokens(), 4096);

        let response = client.complete("Test prompt").await.unwrap();
        assert_eq!(response, "Test response");
    }

    /// A mock built with `failing` is not ready and fails `complete` with a
    /// connection error.
    #[tokio::test]
    async fn test_mock_client_failure() {
        let client = MockLlmClient::failing();

        assert!(!client.is_ready().await);

        let result = client.complete("Test prompt").await;
        // Pin the variant rather than accepting any error, so a change in the
        // kind of failure the mock produces is caught here.
        assert!(matches!(result, Err(LlmError::ConnectionError(_))));
    }

    /// The builder methods populate exactly the fields they are named after.
    #[test]
    fn test_completion_response() {
        let response = CompletionResponse::new("Generated text")
            .with_tokens(100, 50)
            .with_duration(1500);

        assert_eq!(response.text, "Generated text");
        assert_eq!(response.prompt_tokens, Some(100));
        assert_eq!(response.completion_tokens, Some(50));
        assert_eq!(response.duration_ms, Some(1500));
    }

    /// `new` alone leaves every optional metadata field unset.
    #[test]
    fn test_completion_response_defaults() {
        let response = CompletionResponse::new("Generated text");

        assert_eq!(response.text, "Generated text");
        assert_eq!(response.prompt_tokens, None);
        assert_eq!(response.completion_tokens, None);
        assert_eq!(response.duration_ms, None);
    }
}