1use crate::types::{Concept, ExtractionResponse, Relationship};
4use async_trait::async_trait;
5use thiserror::Error;
6
/// Errors that can arise when talking to an LLM backend.
#[derive(Debug, Error)]
pub enum LlmError {
    /// The provider returned an application-level error message.
    #[error("API error: {0}")]
    ApiError(String),

    /// The backend could not be reached at the network level.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// The provider throttled the request; retry after the given seconds.
    #[error("Rate limited: retry after {0} seconds")]
    RateLimited(u32),

    /// A response arrived but was not in the expected form.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),

    /// The response body could not be parsed.
    #[error("Parsing failed: {0}")]
    ParseError(String),

    /// The requested model name is unknown to the backend.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// The prompt exceeded the model's context window: (actual, max) tokens.
    #[error("Context too long: {0} tokens (max: {1})")]
    ContextTooLong(usize, usize),

    /// Credentials were rejected by the provider.
    #[error("Authentication failed")]
    AuthenticationFailed,

    /// The request did not complete within the given number of seconds.
    #[error("Timeout after {0} seconds")]
    Timeout(u32),

    /// An underlying I/O error (converted automatically via `From`).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
40
/// Convenience alias for results produced by LLM operations.
pub type LlmResult<T> = Result<T, LlmError>;
43
/// Configuration shared by all LLM backends.
#[derive(Debug, Clone)]
pub struct LlmConfig {
    /// Model identifier sent to the provider.
    pub model: String,
    /// Maximum number of tokens to generate per completion.
    pub max_tokens: u32,
    /// Sampling temperature (0.0 = deterministic).
    pub temperature: f32,
    /// Request timeout in seconds.
    pub timeout_secs: u32,
    /// Whether to ask the model to include its reasoning in responses.
    pub include_reasoning: bool,
}
58
59impl Default for LlmConfig {
60 fn default() -> Self {
61 Self {
62 model: "default".to_string(),
63 max_tokens: 1024,
64 temperature: 0.0,
65 timeout_secs: 30,
66 include_reasoning: false,
67 }
68 }
69}
70
71impl LlmConfig {
72 pub fn claude() -> Self {
74 Self {
75 model: "claude-3-haiku-20240307".to_string(),
76 max_tokens: 1024,
77 temperature: 0.0,
78 timeout_secs: 30,
79 include_reasoning: false,
80 }
81 }
82
83 pub fn openai() -> Self {
85 Self {
86 model: "gpt-4o-mini".to_string(),
87 max_tokens: 1024,
88 temperature: 0.0,
89 timeout_secs: 30,
90 include_reasoning: false,
91 }
92 }
93
94 pub fn ollama() -> Self {
96 Self {
97 model: "llama3.2".to_string(),
98 max_tokens: 1024,
99 temperature: 0.0,
100 timeout_secs: 60, include_reasoning: false,
102 }
103 }
104
105 pub fn with_model(mut self, model: impl Into<String>) -> Self {
107 self.model = model.into();
108 self
109 }
110
111 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
113 self.max_tokens = max_tokens;
114 self
115 }
116
117 pub fn with_temperature(mut self, temperature: f32) -> Self {
119 self.temperature = temperature.clamp(0.0, 2.0);
120 self
121 }
122
123 pub fn with_timeout(mut self, timeout_secs: u32) -> Self {
125 self.timeout_secs = timeout_secs;
126 self
127 }
128}
129
130#[async_trait]
135pub trait LlmBackend: Send + Sync {
136 fn name(&self) -> &str;
138
139 fn config(&self) -> &LlmConfig;
141
142 async fn complete(&self, prompt: &str) -> LlmResult<String>;
144
145 async fn extract_concepts(&self, text: &str) -> LlmResult<Vec<Concept>>;
147
148 async fn identify_relationships(
150 &self,
151 text: &str,
152 concepts: &[Concept],
153 ) -> LlmResult<Vec<Relationship>>;
154
155 async fn extract(&self, text: &str) -> LlmResult<ExtractionResponse> {
157 let concepts = self.extract_concepts(text).await?;
158 let relationships = self.identify_relationships(text, &concepts).await?;
159 Ok(ExtractionResponse {
160 concepts,
161 relationships,
162 raw_response: None,
163 tokens_used: None,
164 })
165 }
166
167 async fn expand_query(&self, query: &str) -> LlmResult<Vec<String>> {
169 Ok(vec![query.to_string()])
171 }
172
173 async fn summarize_cluster(&self, concepts: &[&str]) -> LlmResult<String> {
175 Ok(concepts.join(", "))
177 }
178
179 async fn health_check(&self) -> LlmResult<bool> {
181 match self.complete("ping").await {
183 Ok(_) => Ok(true),
184 Err(e) => {
185 match e {
187 LlmError::ConnectionFailed(_) => Ok(false),
188 LlmError::AuthenticationFailed => Ok(false),
189 _ => Ok(true),
190 }
191 }
192 }
193 }
194}
195
/// A test double backend that replies from a fixed pattern → response table.
pub struct MockBackend {
    // Configuration reported by `config()`; always the defaults.
    config: LlmConfig,
    // Substring pattern → canned response, consulted by `complete()`.
    responses: std::collections::HashMap<String, String>,
}
201
202impl MockBackend {
203 pub fn new() -> Self {
205 Self {
206 config: LlmConfig::default(),
207 responses: std::collections::HashMap::new(),
208 }
209 }
210
211 pub fn with_response(mut self, pattern: &str, response: &str) -> Self {
213 self.responses.insert(pattern.to_string(), response.to_string());
214 self
215 }
216}
217
218impl Default for MockBackend {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224#[async_trait]
225impl LlmBackend for MockBackend {
226 fn name(&self) -> &str {
227 "mock"
228 }
229
230 fn config(&self) -> &LlmConfig {
231 &self.config
232 }
233
234 async fn complete(&self, prompt: &str) -> LlmResult<String> {
235 for (pattern, response) in &self.responses {
237 if prompt.contains(pattern) {
238 return Ok(response.clone());
239 }
240 }
241 Ok("Mock response".to_string())
242 }
243
244 async fn extract_concepts(&self, text: &str) -> LlmResult<Vec<Concept>> {
245 let words: Vec<&str> = text
247 .split(|c: char| !c.is_alphanumeric())
248 .filter(|w| w.len() >= 4)
249 .collect();
250
251 let concepts: Vec<Concept> = words
252 .into_iter()
253 .take(5)
254 .map(|w| Concept::new(w.to_lowercase()))
255 .collect();
256
257 Ok(concepts)
258 }
259
260 async fn identify_relationships(
261 &self,
262 _text: &str,
263 concepts: &[Concept],
264 ) -> LlmResult<Vec<Relationship>> {
265 let relationships: Vec<Relationship> = concepts
267 .windows(2)
268 .map(|pair| {
269 Relationship::new(&pair[0].label, &pair[1].label, "related_to")
270 })
271 .collect();
272
273 Ok(relationships)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered pattern should override the generic mock reply.
    #[tokio::test]
    async fn test_mock_backend() {
        let mock = MockBackend::new().with_response("test", "Test response");
        let reply = mock.complete("This is a test").await.unwrap();
        assert_eq!(reply, "Test response");
    }

    #[tokio::test]
    async fn test_mock_extract_concepts() {
        let mock = MockBackend::new();
        let extracted = mock
            .extract_concepts("The mitochondria produces ATP in the cell")
            .await
            .unwrap();

        // "mitochondria" is >= 4 chars and among the first five such words.
        assert!(!extracted.is_empty());
        assert!(extracted.iter().any(|c| c.label == "mitochondria"));
    }

    #[tokio::test]
    async fn test_mock_relationships() {
        let mock = MockBackend::new();
        let concepts = vec![
            Concept::new("mitochondria"),
            Concept::new("ATP"),
            Concept::new("cell"),
        ];

        let edges = mock.identify_relationships("", &concepts).await.unwrap();

        // Three concepts chain into exactly two consecutive pairs.
        assert_eq!(edges.len(), 2);
        assert_eq!(edges[0].source, "mitochondria");
        assert_eq!(edges[0].target, "ATP");
    }

    /// Each preset should name its provider's model family.
    #[test]
    fn test_config_builders() {
        assert!(LlmConfig::claude().model.contains("claude"));
        assert!(LlmConfig::openai().model.contains("gpt"));
        assert!(LlmConfig::ollama().model.contains("llama"));
    }
}