
datasynth_core/llm/provider.rs

use serde::{Deserialize, Serialize};

use crate::error::SynthError;

/// LLM provider trait for AI-augmented generation.
pub trait LlmProvider: Send + Sync {
    /// Provider name.
    fn name(&self) -> &str;
    /// Complete a single request.
    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, SynthError>;
    /// Complete a batch of requests.
    fn complete_batch(&self, requests: &[LlmRequest]) -> Result<Vec<LlmResponse>, SynthError> {
        requests.iter().map(|r| self.complete(r)).collect()
    }
}
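
// A minimal sketch of implementing `LlmProvider`: `EchoProvider` is a
// hypothetical provider (not a real backend) that returns the prompt
// unchanged and relies on the trait's default `complete_batch`.
pub struct EchoProvider;

impl LlmProvider for EchoProvider {
    fn name(&self) -> &str {
        "echo"
    }

    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, SynthError> {
        Ok(LlmResponse {
            content: request.prompt.clone(),
            usage: TokenUsage::default(),
            cached: false,
        })
    }
}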

/// A request to an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// The user prompt.
    pub prompt: String,
    /// Optional system prompt.
    pub system: Option<String>,
    /// Maximum tokens in the response.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f64,
    /// Optional seed for deterministic output.
    pub seed: Option<u64>,
}

impl LlmRequest {
    /// Create a new request with the given prompt.
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            system: None,
            max_tokens: 1024,
            temperature: 0.7,
            seed: None,
        }
    }

    /// Set the system prompt.
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Set the seed for deterministic output.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Set the maximum number of tokens.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Set the sampling temperature.
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }
}

/// A response from an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// The generated content.
    pub content: String,
    /// Token usage statistics.
    pub usage: TokenUsage,
    /// Whether this response was served from cache.
    pub cached: bool,
}

/// Token usage statistics for an LLM request.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Number of input (prompt) tokens.
    pub input_tokens: u32,
    /// Number of output (completion) tokens.
    pub output_tokens: u32,
}

/// LLM provider type selection.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProviderType {
    /// Deterministic mock provider (no network calls).
    #[default]
    Mock,
    /// OpenAI-compatible API provider.
    OpenAi,
    /// Anthropic API provider.
    Anthropic,
    /// Custom provider with user-specified base URL.
    Custom,
}

/// LLM configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Which provider type to use.
    pub provider: LlmProviderType,
    /// Model name/ID.
    #[serde(default = "default_llm_model")]
    pub model: String,
    /// Environment variable containing the API key.
    #[serde(default)]
    pub api_key_env: String,
    /// Custom API base URL (overrides provider default).
    #[serde(default)]
    pub base_url: Option<String>,
    /// Maximum retry attempts for failed requests.
    #[serde(default = "default_max_retries")]
    pub max_retries: u8,
    /// Request timeout in seconds.
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
    /// Whether to cache responses.
    #[serde(default = "default_true_val")]
    pub cache_enabled: bool,
}

fn default_llm_model() -> String {
    "gpt-4o-mini".to_string()
}
fn default_max_retries() -> u8 {
    3
}
fn default_timeout_secs() -> u64 {
    30
}
fn default_true_val() -> bool {
    true
}

impl Default for LlmConfig {
    fn default() -> Self {
        Self {
            provider: LlmProviderType::default(),
            model: default_llm_model(),
            api_key_env: String::new(),
            base_url: None,
            max_retries: default_max_retries(),
            timeout_secs: default_timeout_secs(),
            cache_enabled: true,
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_builder() {
        let req = LlmRequest::new("test prompt")
            .with_system("system prompt")
            .with_seed(42)
            .with_max_tokens(512)
            .with_temperature(0.5);
        assert_eq!(req.prompt, "test prompt");
        assert_eq!(req.system, Some("system prompt".to_string()));
        assert_eq!(req.seed, Some(42));
        assert_eq!(req.max_tokens, 512);
        assert!((req.temperature - 0.5).abs() < f64::EPSILON);
    }
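
    #[test]
    fn test_default_complete_batch() {
        // Exercises the trait's default `complete_batch` via the illustrative
        // `EchoProvider` sketched above.
        let provider = EchoProvider;
        let requests = vec![LlmRequest::new("a"), LlmRequest::new("b")];
        let responses = provider.complete_batch(&requests).unwrap();
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].content, "a");
        assert_eq!(responses[1].content, "b");
    }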

    #[test]
    fn test_llm_request_serde_roundtrip() {
        let req = LlmRequest::new("test").with_seed(42);
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: LlmRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.prompt, "test");
        assert_eq!(deserialized.seed, Some(42));
    }

    #[test]
    fn test_llm_response_serde_roundtrip() {
        let resp = LlmResponse {
            content: "output".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 20,
            },
            cached: false,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: LlmResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.content, "output");
        assert_eq!(deserialized.usage.input_tokens, 10);
    }
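
    #[test]
    fn test_llm_provider_type_snake_case() {
        // Illustrative check of the `rename_all = "snake_case"` attribute:
        // `OpenAi` is expected to serialize as the string "open_ai".
        let json = serde_json::to_string(&LlmProviderType::OpenAi).unwrap();
        assert_eq!(json, "\"open_ai\"");
    }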

    #[test]
    fn test_llm_config_default() {
        let config = LlmConfig::default();
        assert!(matches!(config.provider, LlmProviderType::Mock));
        assert!(config.cache_enabled);
        assert_eq!(config.max_retries, 3);
    }
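
    #[test]
    fn test_llm_config_deserialize_defaults() {
        // Sketch of partial deserialization: only `provider` is supplied, so
        // the remaining fields should fall back to their serde defaults.
        let config: LlmConfig = serde_json::from_str(r#"{"provider":"anthropic"}"#).unwrap();
        assert!(matches!(config.provider, LlmProviderType::Anthropic));
        assert_eq!(config.model, "gpt-4o-mini");
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.cache_enabled);
    }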
}