datasynth_core/llm/
provider.rs1use serde::{Deserialize, Serialize};
2
3use crate::error::SynthError;
4
5pub trait LlmProvider: Send + Sync {
7 fn name(&self) -> &str;
9 fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, SynthError>;
11 fn complete_batch(&self, requests: &[LlmRequest]) -> Result<Vec<LlmResponse>, SynthError> {
13 requests.iter().map(|r| self.complete(r)).collect()
14 }
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct LlmRequest {
20 pub prompt: String,
22 pub system: Option<String>,
24 pub max_tokens: u32,
26 pub temperature: f64,
28 pub seed: Option<u64>,
30}
31
32impl LlmRequest {
33 pub fn new(prompt: impl Into<String>) -> Self {
35 Self {
36 prompt: prompt.into(),
37 system: None,
38 max_tokens: 1024,
39 temperature: 0.7,
40 seed: None,
41 }
42 }
43
44 pub fn with_system(mut self, system: impl Into<String>) -> Self {
46 self.system = Some(system.into());
47 self
48 }
49
50 pub fn with_seed(mut self, seed: u64) -> Self {
52 self.seed = Some(seed);
53 self
54 }
55
56 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
58 self.max_tokens = max_tokens;
59 self
60 }
61
62 pub fn with_temperature(mut self, temperature: f64) -> Self {
64 self.temperature = temperature;
65 self
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct LlmResponse {
72 pub content: String,
74 pub usage: TokenUsage,
76 pub cached: bool,
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82pub struct TokenUsage {
83 pub input_tokens: u32,
85 pub output_tokens: u32,
87}
88
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
91#[serde(rename_all = "snake_case")]
92pub enum LlmProviderType {
93 #[default]
95 Mock,
96 OpenAi,
98 Anthropic,
100 Custom,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct LlmConfig {
107 pub provider: LlmProviderType,
109 #[serde(default = "default_llm_model")]
111 pub model: String,
112 #[serde(default)]
114 pub api_key_env: String,
115 #[serde(default)]
117 pub base_url: Option<String>,
118 #[serde(default = "default_max_retries")]
120 pub max_retries: u8,
121 #[serde(default = "default_timeout_secs")]
123 pub timeout_secs: u64,
124 #[serde(default = "default_true_val")]
126 pub cache_enabled: bool,
127}
128
/// Serde default for `LlmConfig::model`.
fn default_llm_model() -> String {
    String::from("gpt-4o-mini")
}

/// Serde default for `LlmConfig::max_retries`: 3 retries.
fn default_max_retries() -> u8 {
    3_u8
}

/// Serde default for `LlmConfig::timeout_secs`: 30 seconds.
fn default_timeout_secs() -> u64 {
    30_u64
}

/// Serde default for flags that start enabled, e.g. `LlmConfig::cache_enabled`.
fn default_true_val() -> bool {
    true
}
141
142impl Default for LlmConfig {
143 fn default() -> Self {
144 Self {
145 provider: LlmProviderType::default(),
146 model: default_llm_model(),
147 api_key_env: String::new(),
148 base_url: None,
149 max_retries: default_max_retries(),
150 timeout_secs: default_timeout_secs(),
151 cache_enabled: true,
152 }
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing floats that pass through serialization.
    const FLOAT_TOL: f64 = 1e-12;

    #[test]
    fn test_llm_request_builder() {
        let req = LlmRequest::new("test prompt")
            .with_system("system prompt")
            .with_seed(42)
            .with_max_tokens(512)
            .with_temperature(0.5);
        assert_eq!(req.prompt, "test prompt");
        assert_eq!(req.system, Some("system prompt".to_string()));
        assert_eq!(req.seed, Some(42));
        assert_eq!(req.max_tokens, 512);
        assert!((req.temperature - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_llm_request_defaults() {
        let req = LlmRequest::new("prompt");
        assert_eq!(req.prompt, "prompt");
        assert!(req.system.is_none());
        assert!(req.seed.is_none());
        assert_eq!(req.max_tokens, 1024);
        assert!((req.temperature - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn test_llm_request_serde_roundtrip() {
        let req = LlmRequest::new("test").with_system("sys").with_seed(42);
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: LlmRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.prompt, "test");
        assert_eq!(deserialized.system.as_deref(), Some("sys"));
        assert_eq!(deserialized.seed, Some(42));
        assert_eq!(deserialized.max_tokens, req.max_tokens);
        assert!((deserialized.temperature - req.temperature).abs() < FLOAT_TOL);
    }

    #[test]
    fn test_llm_response_serde_roundtrip() {
        let resp = LlmResponse {
            content: "output".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 20,
            },
            cached: true,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: LlmResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.content, "output");
        assert_eq!(deserialized.usage.input_tokens, 10);
        assert_eq!(deserialized.usage.output_tokens, 20);
        assert!(deserialized.cached);
    }

    #[test]
    fn test_llm_config_default() {
        let config = LlmConfig::default();
        assert!(matches!(config.provider, LlmProviderType::Mock));
        assert_eq!(config.model, "gpt-4o-mini");
        assert!(config.api_key_env.is_empty());
        assert!(config.base_url.is_none());
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.cache_enabled);
    }

    #[test]
    fn test_llm_config_partial_deserialize_fills_defaults() {
        let config: LlmConfig = serde_json::from_str(r#"{"provider":"anthropic"}"#).unwrap();
        assert!(matches!(config.provider, LlmProviderType::Anthropic));
        assert_eq!(config.model, "gpt-4o-mini");
        assert!(config.api_key_env.is_empty());
        assert!(config.base_url.is_none());
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.cache_enabled);
    }
}