// phago_llm/backend.rs

//! Core LLM backend trait.

use async_trait::async_trait;
use thiserror::Error;

use crate::types::{Concept, ExtractionResponse, Relationship};

/// LLM-related errors.
///
/// Covers transport, provider, and response-handling failures for any
/// `LlmBackend` implementation.
#[derive(Debug, Error)]
pub enum LlmError {
    /// The provider returned an application-level error payload.
    #[error("API error: {0}")]
    ApiError(String),

    /// The backend could not be reached at all.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// The provider asked us to back off; payload is the retry delay in seconds.
    #[error("Rate limited: retry after {0} seconds")]
    RateLimited(u32),

    /// A response arrived but was not in the expected form.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),

    /// The response text could not be parsed into structured output.
    #[error("Parsing failed: {0}")]
    ParseError(String),

    /// The requested model identifier is unknown to the provider.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// Prompt exceeded the model's context window: (actual, max) token counts.
    #[error("Context too long: {0} tokens (max: {1})")]
    ContextTooLong(usize, usize),

    /// Credentials were rejected by the provider.
    #[error("Authentication failed")]
    AuthenticationFailed,

    /// The request did not complete within the given number of seconds.
    #[error("Timeout after {0} seconds")]
    Timeout(u32),

    /// Underlying I/O failure; `#[from]` lets `?` convert automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}

/// Result type for LLM operations; the error is always [`LlmError`].
pub type LlmResult<T> = Result<T, LlmError>;

/// Configuration for LLM requests.
///
/// Construct via [`Default`], a provider preset (`claude`, `openai`,
/// `ollama`), or chain the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct LlmConfig {
    /// Model name/identifier.
    pub model: String,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Temperature (0.0 = deterministic, 1.0 = creative).
    /// The `with_temperature` builder clamps values to [0.0, 2.0].
    pub temperature: f32,
    /// Request timeout in seconds.
    pub timeout_secs: u32,
    /// Whether to include reasoning/explanation.
    pub include_reasoning: bool,
}

59impl Default for LlmConfig {
60    fn default() -> Self {
61        Self {
62            model: "default".to_string(),
63            max_tokens: 1024,
64            temperature: 0.0,
65            timeout_secs: 30,
66            include_reasoning: false,
67        }
68    }
69}
70
71impl LlmConfig {
72    /// Create config for Claude.
73    pub fn claude() -> Self {
74        Self {
75            model: "claude-3-haiku-20240307".to_string(),
76            max_tokens: 1024,
77            temperature: 0.0,
78            timeout_secs: 30,
79            include_reasoning: false,
80        }
81    }
82
83    /// Create config for OpenAI.
84    pub fn openai() -> Self {
85        Self {
86            model: "gpt-4o-mini".to_string(),
87            max_tokens: 1024,
88            temperature: 0.0,
89            timeout_secs: 30,
90            include_reasoning: false,
91        }
92    }
93
94    /// Create config for Ollama.
95    pub fn ollama() -> Self {
96        Self {
97            model: "llama3.2".to_string(),
98            max_tokens: 1024,
99            temperature: 0.0,
100            timeout_secs: 60, // Local models can be slower
101            include_reasoning: false,
102        }
103    }
104
105    /// Set the model.
106    pub fn with_model(mut self, model: impl Into<String>) -> Self {
107        self.model = model.into();
108        self
109    }
110
111    /// Set max tokens.
112    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
113        self.max_tokens = max_tokens;
114        self
115    }
116
117    /// Set temperature.
118    pub fn with_temperature(mut self, temperature: f32) -> Self {
119        self.temperature = temperature.clamp(0.0, 2.0);
120        self
121    }
122
123    /// Set timeout.
124    pub fn with_timeout(mut self, timeout_secs: u32) -> Self {
125        self.timeout_secs = timeout_secs;
126        self
127    }
128}
129
130/// Core trait for LLM backends.
131///
132/// Implementors provide concept extraction and relationship identification
133/// using various LLM providers.
134#[async_trait]
135pub trait LlmBackend: Send + Sync {
136    /// Get the backend name.
137    fn name(&self) -> &str;
138
139    /// Get the current configuration.
140    fn config(&self) -> &LlmConfig;
141
142    /// Generate a completion for a prompt.
143    async fn complete(&self, prompt: &str) -> LlmResult<String>;
144
145    /// Extract concepts from text.
146    async fn extract_concepts(&self, text: &str) -> LlmResult<Vec<Concept>>;
147
148    /// Identify relationships between concepts.
149    async fn identify_relationships(
150        &self,
151        text: &str,
152        concepts: &[Concept],
153    ) -> LlmResult<Vec<Relationship>>;
154
155    /// Full extraction: concepts and relationships.
156    async fn extract(&self, text: &str) -> LlmResult<ExtractionResponse> {
157        let concepts = self.extract_concepts(text).await?;
158        let relationships = self.identify_relationships(text, &concepts).await?;
159        Ok(ExtractionResponse {
160            concepts,
161            relationships,
162            raw_response: None,
163            tokens_used: None,
164        })
165    }
166
167    /// Expand a query for better recall.
168    async fn expand_query(&self, query: &str) -> LlmResult<Vec<String>> {
169        // Default implementation: return the query as-is
170        Ok(vec![query.to_string()])
171    }
172
173    /// Summarize a cluster of concepts.
174    async fn summarize_cluster(&self, concepts: &[&str]) -> LlmResult<String> {
175        // Default implementation: join concepts
176        Ok(concepts.join(", "))
177    }
178
179    /// Check if the backend is available.
180    async fn health_check(&self) -> LlmResult<bool> {
181        // Default: try a simple completion
182        match self.complete("ping").await {
183            Ok(_) => Ok(true),
184            Err(e) => {
185                // Connection errors mean unavailable, other errors might be OK
186                match e {
187                    LlmError::ConnectionFailed(_) => Ok(false),
188                    LlmError::AuthenticationFailed => Ok(false),
189                    _ => Ok(true),
190                }
191            }
192        }
193    }
194}
195
/// A mock backend for testing.
///
/// Serves canned responses keyed by substring patterns; performs no I/O.
pub struct MockBackend {
    // Configuration reported by `config()`; set to `LlmConfig::default()`.
    config: LlmConfig,
    // Substring pattern -> canned response, consulted by `complete`.
    responses: std::collections::HashMap<String, String>,
}

202impl MockBackend {
203    /// Create a new mock backend.
204    pub fn new() -> Self {
205        Self {
206            config: LlmConfig::default(),
207            responses: std::collections::HashMap::new(),
208        }
209    }
210
211    /// Add a canned response for a prompt pattern.
212    pub fn with_response(mut self, pattern: &str, response: &str) -> Self {
213        self.responses.insert(pattern.to_string(), response.to_string());
214        self
215    }
216}
217
218impl Default for MockBackend {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224#[async_trait]
225impl LlmBackend for MockBackend {
226    fn name(&self) -> &str {
227        "mock"
228    }
229
230    fn config(&self) -> &LlmConfig {
231        &self.config
232    }
233
234    async fn complete(&self, prompt: &str) -> LlmResult<String> {
235        // Check for matching pattern
236        for (pattern, response) in &self.responses {
237            if prompt.contains(pattern) {
238                return Ok(response.clone());
239            }
240        }
241        Ok("Mock response".to_string())
242    }
243
244    async fn extract_concepts(&self, text: &str) -> LlmResult<Vec<Concept>> {
245        // Simple keyword extraction for testing
246        let words: Vec<&str> = text
247            .split(|c: char| !c.is_alphanumeric())
248            .filter(|w| w.len() >= 4)
249            .collect();
250
251        let concepts: Vec<Concept> = words
252            .into_iter()
253            .take(5)
254            .map(|w| Concept::new(w.to_lowercase()))
255            .collect();
256
257        Ok(concepts)
258    }
259
260    async fn identify_relationships(
261        &self,
262        _text: &str,
263        concepts: &[Concept],
264    ) -> LlmResult<Vec<Relationship>> {
265        // Create simple relationships between consecutive concepts
266        let relationships: Vec<Relationship> = concepts
267            .windows(2)
268            .map(|pair| {
269                Relationship::new(&pair[0].label, &pair[1].label, "related_to")
270            })
271            .collect();
272
273        Ok(relationships)
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered pattern contained in the prompt selects its canned reply.
    #[tokio::test]
    async fn test_mock_backend() {
        let mock = MockBackend::new().with_response("test", "Test response");
        let reply = mock.complete("This is a test").await.unwrap();
        assert_eq!(reply, "Test response");
    }

    /// Keyword extraction picks up long-enough words, lowercased.
    #[tokio::test]
    async fn test_mock_extract_concepts() {
        let mock = MockBackend::new();
        let found = mock
            .extract_concepts("The mitochondria produces ATP in the cell")
            .await
            .unwrap();

        assert!(!found.is_empty());
        assert!(found.iter().any(|c| c.label == "mitochondria"));
    }

    /// Consecutive concepts are linked pairwise: n concepts -> n-1 edges.
    #[tokio::test]
    async fn test_mock_relationships() {
        let mock = MockBackend::new();
        let concepts = [
            Concept::new("mitochondria"),
            Concept::new("ATP"),
            Concept::new("cell"),
        ];

        let edges = mock
            .identify_relationships("", &concepts)
            .await
            .unwrap();

        assert_eq!(edges.len(), 2);
        assert_eq!(edges[0].source, "mitochondria");
        assert_eq!(edges[0].target, "ATP");
    }

    /// Each provider preset names the expected model family.
    #[test]
    fn test_config_builders() {
        for (cfg, needle) in [
            (LlmConfig::claude(), "claude"),
            (LlmConfig::openai(), "gpt"),
            (LlmConfig::ollama(), "llama"),
        ] {
            assert!(cfg.model.contains(needle));
        }
    }
}