mod sim;
#[cfg(feature = "anthropic")]
mod anthropic;
#[cfg(feature = "openai")]
mod openai;
pub use sim::SimLLMProvider;
#[cfg(feature = "anthropic")]
pub use anthropic::AnthropicProvider;
#[cfg(feature = "openai")]
pub use openai::OpenAIProvider;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use crate::constants::{LLM_PROMPT_BYTES_MAX, LLM_RESPONSE_BYTES_MAX};
/// Errors returned by [`LLMProvider`] implementations.
///
/// Transient failures (`Timeout`, `RateLimit`, `ServiceUnavailable`) report
/// `true` from [`ProviderError::is_retryable`]; all other variants are
/// treated as permanent for the given request.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProviderError {
/// The request did not complete within the provider's deadline.
#[error("Request timed out")]
Timeout,
/// The provider rejected the request due to rate limiting.
#[error("Rate limit exceeded, retry after {retry_after_secs:?}s")]
RateLimit {
/// Suggested wait before retrying, if the provider supplied one.
retry_after_secs: Option<u64>,
},
/// The prompt (plus any context) exceeded the model's context window.
#[error("Context length exceeded: {tokens} tokens")]
ContextOverflow {
/// Token count that triggered the overflow.
tokens: usize,
},
/// The provider returned a response that could not be used
/// (malformed, truncated, or otherwise unexpected).
#[error("Invalid response: {message}")]
InvalidResponse {
message: String,
},
/// The provider service is temporarily unavailable (e.g. 5xx).
#[error("Service unavailable: {message}")]
ServiceUnavailable {
message: String,
},
/// API credentials were missing, invalid, or expired.
#[error("Authentication failed")]
AuthenticationFailed,
/// The response body failed JSON deserialization.
#[error("JSON error: {message}")]
JsonError {
message: String,
},
/// A transport-level failure (DNS, TLS, connection reset, etc.).
#[error("Network error: {message}")]
NetworkError {
message: String,
},
/// The request was rejected as malformed before (or by) the provider.
#[error("Invalid request: {message}")]
InvalidRequest {
message: String,
},
}
impl ProviderError {
    /// Builds a [`ProviderError::Timeout`].
    #[must_use]
    pub fn timeout() -> Self {
        Self::Timeout
    }

    /// Builds a [`ProviderError::RateLimit`], optionally carrying the
    /// provider-suggested retry delay in seconds.
    #[must_use]
    pub fn rate_limit(retry_after_secs: Option<u64>) -> Self {
        Self::RateLimit { retry_after_secs }
    }

    /// Builds a [`ProviderError::ContextOverflow`] for the given token count.
    #[must_use]
    pub fn context_overflow(tokens: usize) -> Self {
        Self::ContextOverflow { tokens }
    }

    /// Builds a [`ProviderError::InvalidResponse`] with the given detail.
    #[must_use]
    pub fn invalid_response(message: impl Into<String>) -> Self {
        Self::InvalidResponse { message: message.into() }
    }

    /// Builds a [`ProviderError::ServiceUnavailable`] with the given detail.
    #[must_use]
    pub fn service_unavailable(message: impl Into<String>) -> Self {
        Self::ServiceUnavailable { message: message.into() }
    }

    /// Builds a [`ProviderError::JsonError`] with the given detail.
    #[must_use]
    pub fn json_error(message: impl Into<String>) -> Self {
        Self::JsonError { message: message.into() }
    }

    /// Builds a [`ProviderError::NetworkError`] with the given detail.
    #[must_use]
    pub fn network_error(message: impl Into<String>) -> Self {
        Self::NetworkError { message: message.into() }
    }

    /// Builds a [`ProviderError::InvalidRequest`] with the given detail.
    #[must_use]
    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::InvalidRequest { message: message.into() }
    }

    /// Returns `true` for transient failures worth retrying:
    /// timeouts, rate limits, and service unavailability.
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        matches!(
            self,
            Self::Timeout | Self::RateLimit { .. } | Self::ServiceUnavailable { .. }
        )
    }
}
/// Parameters for a single LLM completion call.
///
/// Construct with [`CompletionRequest::new`], then refine with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// User prompt text; `new` enforces non-empty and at most
/// `LLM_PROMPT_BYTES_MAX` bytes.
pub prompt: String,
/// Optional system prompt, set via `with_system`.
pub system: Option<String>,
/// Optional completion-length cap; `None` leaves it to the provider.
pub max_tokens: Option<usize>,
/// Optional sampling temperature; `with_temperature` restricts it to
/// [0.0, 1.0]. `None` uses the provider default.
pub temperature: Option<f32>,
/// When `true`, asks the provider to emit JSON output.
pub json_mode: bool,
}
impl CompletionRequest {
    /// Creates a request for `prompt` with all optional settings unset.
    ///
    /// # Panics
    ///
    /// Panics if `prompt` is empty or longer than `LLM_PROMPT_BYTES_MAX`
    /// bytes.
    #[must_use]
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt = prompt.into();
        assert!(!prompt.is_empty(), "prompt must not be empty");
        assert!(
            prompt.len() <= LLM_PROMPT_BYTES_MAX,
            "prompt exceeds {} bytes",
            LLM_PROMPT_BYTES_MAX
        );
        Self {
            prompt,
            system: None,
            max_tokens: None,
            temperature: None,
            json_mode: false,
        }
    }

    /// Sets the system prompt and returns the updated request.
    #[must_use]
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Caps the completion length and returns the updated request.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature and returns the updated request.
    ///
    /// # Panics
    ///
    /// Panics if `temperature` lies outside `[0.0, 1.0]`.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&temperature),
            "temperature must be in [0.0, 1.0]"
        );
        self.temperature = Some(temperature);
        self
    }

    /// Requests JSON-formatted output and returns the updated request.
    #[must_use]
    pub fn with_json_mode(mut self) -> Self {
        self.json_mode = true;
        self
    }
}
/// Abstraction over an LLM backend (real API client or simulation).
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Runs one completion and returns the raw response text.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] describing the failure; callers may
    /// retry when [`ProviderError::is_retryable`] is `true`.
    async fn complete(&self, request: &CompletionRequest) -> Result<String, ProviderError>;

    /// Runs one completion and deserializes the response as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::complete`]; additionally returns
    /// [`ProviderError::InvalidResponse`] when the response exceeds
    /// `LLM_RESPONSE_BYTES_MAX`, or [`ProviderError::JsonError`] when the
    /// response is not valid JSON for `T`.
    async fn complete_json<T: DeserializeOwned + Send>(
        &self,
        request: &CompletionRequest,
    ) -> Result<T, ProviderError> {
        let response = self.complete(request).await?;
        // Enforce the response-size cap in every build profile. The previous
        // debug_assert! silently skipped this check in release builds, so an
        // oversized (potentially hostile) response would reach the JSON
        // parser unchecked.
        if response.len() > LLM_RESPONSE_BYTES_MAX {
            return Err(ProviderError::invalid_response(format!(
                "response is {} bytes, exceeding the {} byte limit",
                response.len(),
                LLM_RESPONSE_BYTES_MAX
            )));
        }
        serde_json::from_str(&response).map_err(|e| ProviderError::json_error(e.to_string()))
    }

    /// Human-readable provider name (e.g. for logging).
    fn name(&self) -> &'static str;

    /// `true` when this provider is a local simulation rather than a real
    /// backend.
    fn is_simulation(&self) -> bool;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh request carries only the prompt; everything else is unset.
    #[test]
    fn test_completion_request_new() {
        let req = CompletionRequest::new("Hello, world!");
        assert_eq!(req.prompt, "Hello, world!");
        assert!(req.system.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(!req.json_mode);
    }

    /// Each builder method stores its value without disturbing the rest.
    #[test]
    fn test_completion_request_builder() {
        let req = CompletionRequest::new("Hello")
            .with_system("You are a helpful assistant")
            .with_max_tokens(100)
            .with_temperature(0.7)
            .with_json_mode();
        assert_eq!(req.prompt, "Hello");
        assert_eq!(req.system, Some("You are a helpful assistant".into()));
        assert_eq!(req.max_tokens, Some(100));
        assert_eq!(req.temperature, Some(0.7));
        assert!(req.json_mode);
    }

    /// An empty prompt is rejected at construction time.
    #[test]
    #[should_panic(expected = "prompt must not be empty")]
    fn test_completion_request_empty_prompt() {
        let _ = CompletionRequest::new("");
    }

    /// Temperatures outside [0.0, 1.0] are rejected.
    #[test]
    #[should_panic(expected = "temperature must be in")]
    fn test_completion_request_invalid_temperature() {
        let _ = CompletionRequest::new("Hello").with_temperature(1.5);
    }

    /// Only timeout / rate-limit / unavailable errors are retryable.
    #[test]
    fn test_provider_error_is_retryable() {
        let retryable = [
            ProviderError::timeout(),
            ProviderError::rate_limit(Some(60)),
            ProviderError::service_unavailable("down"),
        ];
        for err in retryable {
            assert!(err.is_retryable());
        }
        assert!(!ProviderError::AuthenticationFailed.is_retryable());
        assert!(!ProviderError::json_error("parse failed").is_retryable());
    }

    /// Constructor helpers produce the matching enum variants.
    #[test]
    fn test_provider_error_constructors() {
        assert!(matches!(
            ProviderError::context_overflow(10000),
            ProviderError::ContextOverflow { tokens: 10000 }
        ));
        assert!(matches!(
            ProviderError::invalid_response("bad format"),
            ProviderError::InvalidResponse { .. }
        ));
        assert!(matches!(
            ProviderError::network_error("connection refused"),
            ProviderError::NetworkError { .. }
        ));
    }
}