use serde::{Deserialize, Serialize};

use crate::error::SynthError;

/// Abstraction over a large-language-model backend.
pub trait LlmProvider: Send + Sync {
    /// Human-readable name identifying this provider.
    fn name(&self) -> &str;

    /// Execute a single completion request.
    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, SynthError>;

    /// Execute a batch of requests. The default implementation completes
    /// them sequentially, in order, and fails fast on the first error.
    fn complete_batch(&self, requests: &[LlmRequest]) -> Result<Vec<LlmResponse>, SynthError> {
        requests.iter().map(|r| self.complete(r)).collect()
    }
}

/// A single completion request sent to an [`LlmProvider`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// The user prompt.
    pub prompt: String,
    /// Optional system prompt.
    pub system: Option<String>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f64,
    /// Optional seed for reproducible sampling, where the backend supports it.
    pub seed: Option<u64>,
}

impl LlmRequest {
    /// Create a request with the given prompt and default parameters
    /// (no system prompt, 1024 max tokens, temperature 0.7, no seed).
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            system: None,
            max_tokens: 1024,
            temperature: 0.7,
            seed: None,
        }
    }

    /// Set the system prompt.
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Set the sampling seed.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Set the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Set the sampling temperature.
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }
}

/// A completion returned by an [`LlmProvider`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// The generated text.
    pub content: String,
    /// Token counts reported for this request.
    pub usage: TokenUsage,
    /// Whether this response was served from the cache.
    pub cached: bool,
}

/// Token counts for a single request/response pair.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
}

/// Which LLM backend to use.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProviderType {
    #[default]
    Mock,
    OpenAi,
    Anthropic,
    Custom,
}

/// Configuration for the LLM client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    pub provider: LlmProviderType,
    /// Model identifier passed to the backend.
    #[serde(default = "default_llm_model")]
    pub model: String,
    /// Name of the environment variable holding the API key.
    #[serde(default)]
    pub api_key_env: String,
    /// Optional override for the provider's base URL.
    #[serde(default)]
    pub base_url: Option<String>,
    #[serde(default = "default_max_retries")]
    pub max_retries: u8,
    /// Request timeout in seconds.
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
    #[serde(default = "default_true_val")]
    pub cache_enabled: bool,
}

fn default_llm_model() -> String {
    "gpt-4o-mini".to_string()
}

fn default_max_retries() -> u8 {
    3
}

fn default_timeout_secs() -> u64 {
    30
}

fn default_true_val() -> bool {
    true
}

/// Mirrors the serde default functions so that `LlmConfig::default()` and
/// deserialization of an empty config cannot drift apart.
impl Default for LlmConfig {
    fn default() -> Self {
        Self {
            provider: LlmProviderType::default(),
            model: default_llm_model(),
            api_key_env: String::new(),
            base_url: None,
            max_retries: default_max_retries(),
            timeout_secs: default_timeout_secs(),
            cache_enabled: default_true_val(),
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_builder() {
        let req = LlmRequest::new("test prompt")
            .with_system("system prompt")
            .with_seed(42)
            .with_max_tokens(512)
            .with_temperature(0.5);
        assert_eq!(req.prompt, "test prompt");
        assert_eq!(req.system, Some("system prompt".to_string()));
        assert_eq!(req.seed, Some(42));
        assert_eq!(req.max_tokens, 512);
        assert!((req.temperature - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_llm_request_serde_roundtrip() {
        let req = LlmRequest::new("test").with_seed(42);
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: LlmRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.prompt, "test");
        assert_eq!(deserialized.seed, Some(42));
    }
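
    // Illustrative check (not in the original suite): with
    // `rename_all = "snake_case"`, the `OpenAi` variant serializes to the
    // string "open_ai" and "anthropic" parses back to `Anthropic`.
    #[test]
    fn test_provider_type_serde_snake_case() {
        let json = serde_json::to_string(&LlmProviderType::OpenAi).unwrap();
        assert_eq!(json, "\"open_ai\"");
        let parsed: LlmProviderType = serde_json::from_str("\"anthropic\"").unwrap();
        assert!(matches!(parsed, LlmProviderType::Anthropic));
    }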

    #[test]
    fn test_llm_response_serde_roundtrip() {
        let resp = LlmResponse {
            content: "output".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 20,
            },
            cached: false,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: LlmResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.content, "output");
        assert_eq!(deserialized.usage.input_tokens, 10);
    }
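
    // Illustrative check (not in the original suite): deserializing a config
    // that sets only the required `provider` field should fill every other
    // field from its serde default.
    #[test]
    fn test_llm_config_deserialize_defaults() {
        let config: LlmConfig = serde_json::from_str(r#"{"provider": "mock"}"#).unwrap();
        assert_eq!(config.model, "gpt-4o-mini");
        assert_eq!(config.api_key_env, "");
        assert!(config.base_url.is_none());
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.cache_enabled);
    }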

    #[test]
    fn test_llm_config_default() {
        let config = LlmConfig::default();
        assert!(matches!(config.provider, LlmProviderType::Mock));
        assert!(config.cache_enabled);
        assert_eq!(config.max_retries, 3);
    }
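
    // Illustrative sketch (the `EchoProvider` name is hypothetical, not part
    // of the module's API): a minimal `LlmProvider` impl used to exercise the
    // trait's default `complete_batch`, which should complete requests
    // sequentially and preserve their order.
    struct EchoProvider;

    impl LlmProvider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, SynthError> {
            // Echo the prompt back as the completion.
            Ok(LlmResponse {
                content: request.prompt.clone(),
                usage: TokenUsage::default(),
                cached: false,
            })
        }
    }

    #[test]
    fn test_complete_batch_default_impl() {
        let provider = EchoProvider;
        let requests = vec![LlmRequest::new("a"), LlmRequest::new("b")];
        let responses = provider.complete_batch(&requests).unwrap();
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].content, "a");
        assert_eq!(responses[1].content, "b");
    }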
}