datasynth_core/llm/provider.rs

use serde::{Deserialize, Serialize};

use crate::error::SynthError;

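/// Abstraction over a pluggable LLM completion backend.
///
/// `complete` handles a single request; `complete_batch` has a default
/// sequential implementation that providers with a native batch API can
/// override.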
pub trait LlmProvider: Send + Sync {
    fn name(&self) -> &str;

    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, SynthError>;

    fn complete_batch(&self, requests: &[LlmRequest]) -> Result<Vec<LlmResponse>, SynthError> {
        requests.iter().map(|r| self.complete(r)).collect()
    }
}

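/// A single completion request, assembled with `LlmRequest::new` and the
/// `with_*` builder methods below.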
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub prompt: String,
    pub system: Option<String>,
    pub max_tokens: u32,
    pub temperature: f64,
    pub seed: Option<u64>,
}

impl LlmRequest {
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            system: None,
            max_tokens: 1024,
            temperature: 0.7,
            seed: None,
        }
    }

    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }
}

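/// The text produced by a provider, along with token accounting and a flag
/// indicating whether the result came from the cache.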
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: String,
    pub usage: TokenUsage,
    pub cached: bool,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
}

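/// Supported backends. With `rename_all = "snake_case"`, the variants
/// serialize as "mock", "open_ai", "anthropic", and "custom".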
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProviderType {
    #[default]
    Mock,
    OpenAi,
    Anthropic,
    Custom,
}

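/// Configuration for the LLM layer. Every field except `provider` has a
/// serde default, so a config document only needs to name the provider (see
/// the `default_*` functions and the test module below).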
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    pub provider: LlmProviderType,
    #[serde(default = "default_llm_model")]
    pub model: String,
    #[serde(default)]
    pub api_key_env: String,
    #[serde(default)]
    pub base_url: Option<String>,
    #[serde(default = "default_max_retries")]
    pub max_retries: u8,
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
    #[serde(default = "default_true_val")]
    pub cache_enabled: bool,
}

fn default_llm_model() -> String {
    "gpt-4o-mini".to_string()
}
fn default_max_retries() -> u8 {
    3
}
fn default_timeout_secs() -> u64 {
    30
}
fn default_true_val() -> bool {
    true
}

impl Default for LlmConfig {
    fn default() -> Self {
        Self {
            provider: LlmProviderType::default(),
            model: default_llm_model(),
            api_key_env: String::new(),
            base_url: None,
            max_retries: default_max_retries(),
            timeout_secs: default_timeout_secs(),
            cache_enabled: true,
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_builder() {
        let req = LlmRequest::new("test prompt")
            .with_system("system prompt")
            .with_seed(42)
            .with_max_tokens(512)
            .with_temperature(0.5);
        assert_eq!(req.prompt, "test prompt");
        assert_eq!(req.system, Some("system prompt".to_string()));
        assert_eq!(req.seed, Some(42));
        assert_eq!(req.max_tokens, 512);
        assert!((req.temperature - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_llm_request_serde_roundtrip() {
        let req = LlmRequest::new("test").with_seed(42);
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: LlmRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.prompt, "test");
        assert_eq!(deserialized.seed, Some(42));
    }

    #[test]
    fn test_llm_response_serde_roundtrip() {
        let resp = LlmResponse {
            content: "output".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 20,
            },
            cached: false,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: LlmResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.content, "output");
        assert_eq!(deserialized.usage.input_tokens, 10);
    }
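
    // A minimal in-test provider (hypothetical, for illustration only) used to
    // exercise the default `complete_batch` implementation; it is not one of
    // the crate's real providers.
    #[test]
    fn test_complete_batch_default_impl() {
        struct EchoProvider;

        impl LlmProvider for EchoProvider {
            fn name(&self) -> &str {
                "echo"
            }

            fn complete(
                &self,
                request: &LlmRequest,
            ) -> Result<LlmResponse, crate::error::SynthError> {
                Ok(LlmResponse {
                    content: request.prompt.clone(),
                    usage: TokenUsage::default(),
                    cached: false,
                })
            }
        }

        let provider = EchoProvider;
        let requests = vec![LlmRequest::new("a"), LlmRequest::new("b")];
        let responses = provider.complete_batch(&requests).unwrap();
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].content, "a");
        assert_eq!(responses[1].content, "b");
    }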

    #[test]
    fn test_llm_config_default() {
        let config = LlmConfig::default();
        assert!(matches!(config.provider, LlmProviderType::Mock));
        assert!(config.cache_enabled);
        assert_eq!(config.max_retries, 3);
    }
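
    // Sketch of config deserialization: only `provider` is required; the other
    // fields fall back to the serde defaults defined above. The "open_ai" wire
    // name follows from the `rename_all = "snake_case"` attribute.
    #[test]
    fn test_llm_config_partial_deserialization() {
        let json = r#"{"provider": "open_ai"}"#;
        let config: LlmConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config.provider, LlmProviderType::OpenAi));
        assert_eq!(config.model, "gpt-4o-mini");
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.cache_enabled);
        assert!(config.base_url.is_none());
        assert!(config.api_key_env.is_empty());
    }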
}