//! LLM extraction providers for the `mentedb_extraction` crate (provider.rs).
1use crate::config::{ExtractionConfig, LlmProvider};
2use crate::error::ExtractionError;
3
/// Trait for LLM providers that can extract memories from conversation text.
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks. The method returns an explicit `impl Future … + Send`
/// instead of using `async fn` in the trait, because `async fn` would not
/// let us require the returned future to be `Send`.
pub trait ExtractionProvider: Send + Sync {
    /// Send a conversation to the LLM with the given system prompt and return
    /// the raw response text (expected to be JSON).
    ///
    /// # Errors
    /// Returns an `ExtractionError` when the underlying call or response
    /// parsing fails.
    fn extract(
        &self,
        conversation: &str,
        system_prompt: &str,
    ) -> impl std::future::Future<Output = Result<String, ExtractionError>> + Send;
}
14
/// HTTP-based extraction provider that calls OpenAI, Anthropic, or Ollama APIs.
pub struct HttpExtractionProvider {
    // Shared HTTP client; reqwest pools connections internally, so one
    // client instance serves all requests made by this provider.
    client: reqwest::Client,
    // Provider selection, model name, endpoint URL, and optional API key.
    config: ExtractionConfig,
}
20
21impl HttpExtractionProvider {
22    pub fn new(config: ExtractionConfig) -> Result<Self, ExtractionError> {
23        if config.provider != LlmProvider::Ollama && config.api_key.is_none() {
24            return Err(ExtractionError::ConfigError(
25                "API key is required for this provider".to_string(),
26            ));
27        }
28        let client = reqwest::Client::new();
29        Ok(Self { client, config })
30    }
31
32    async fn call_openai(
33        &self,
34        conversation: &str,
35        system_prompt: &str,
36    ) -> Result<String, ExtractionError> {
37        let body = serde_json::json!({
38            "model": self.config.model,
39            "response_format": { "type": "json_object" },
40            "messages": [
41                { "role": "system", "content": system_prompt },
42                { "role": "user", "content": conversation }
43            ]
44        });
45
46        let api_key = self.config.api_key.as_deref().unwrap_or_default();
47
48        let resp = self
49            .client
50            .post(&self.config.api_url)
51            .header("Authorization", format!("Bearer {api_key}"))
52            .header("Content-Type", "application/json")
53            .json(&body)
54            .send()
55            .await?;
56
57        let status = resp.status();
58        let text = resp.text().await?;
59
60        if !status.is_success() {
61            return Err(ExtractionError::ProviderError(format!(
62                "OpenAI API returned {status}: {text}"
63            )));
64        }
65
66        let parsed: serde_json::Value = serde_json::from_str(&text)?;
67        parsed["choices"][0]["message"]["content"]
68            .as_str()
69            .map(|s| s.to_string())
70            .ok_or_else(|| {
71                ExtractionError::ParseError("Missing content in OpenAI response".to_string())
72            })
73    }
74
75    async fn call_anthropic(
76        &self,
77        conversation: &str,
78        system_prompt: &str,
79    ) -> Result<String, ExtractionError> {
80        let body = serde_json::json!({
81            "model": self.config.model,
82            "max_tokens": 4096,
83            "system": system_prompt,
84            "messages": [
85                { "role": "user", "content": conversation }
86            ]
87        });
88
89        let api_key = self.config.api_key.as_deref().unwrap_or_default();
90
91        let resp = self
92            .client
93            .post(&self.config.api_url)
94            .header("x-api-key", api_key)
95            .header("anthropic-version", "2023-06-01")
96            .header("Content-Type", "application/json")
97            .json(&body)
98            .send()
99            .await?;
100
101        let status = resp.status();
102        let text = resp.text().await?;
103
104        if !status.is_success() {
105            return Err(ExtractionError::ProviderError(format!(
106                "Anthropic API returned {status}: {text}"
107            )));
108        }
109
110        let parsed: serde_json::Value = serde_json::from_str(&text)?;
111        parsed["content"][0]["text"]
112            .as_str()
113            .map(|s| s.to_string())
114            .ok_or_else(|| {
115                ExtractionError::ParseError("Missing text in Anthropic response".to_string())
116            })
117    }
118
119    async fn call_ollama(
120        &self,
121        conversation: &str,
122        system_prompt: &str,
123    ) -> Result<String, ExtractionError> {
124        let body = serde_json::json!({
125            "model": self.config.model,
126            "stream": false,
127            "format": "json",
128            "messages": [
129                { "role": "system", "content": system_prompt },
130                { "role": "user", "content": conversation }
131            ]
132        });
133
134        let resp = self
135            .client
136            .post(&self.config.api_url)
137            .header("Content-Type", "application/json")
138            .json(&body)
139            .send()
140            .await?;
141
142        let status = resp.status();
143        let text = resp.text().await?;
144
145        if !status.is_success() {
146            return Err(ExtractionError::ProviderError(format!(
147                "Ollama API returned {status}: {text}"
148            )));
149        }
150
151        let parsed: serde_json::Value = serde_json::from_str(&text)?;
152        parsed["message"]["content"]
153            .as_str()
154            .map(|s| s.to_string())
155            .ok_or_else(|| {
156                ExtractionError::ParseError("Missing content in Ollama response".to_string())
157            })
158    }
159
160    /// Execute a request with retry logic for rate limits (HTTP 429).
161    /// Uses exponential backoff: 1s, 2s, 4s.
162    async fn call_with_retry(
163        &self,
164        conversation: &str,
165        system_prompt: &str,
166    ) -> Result<String, ExtractionError> {
167        let max_attempts = 3;
168        let mut last_err = None;
169
170        for attempt in 0..max_attempts {
171            if attempt > 0 {
172                let delay = std::time::Duration::from_secs(1 << attempt);
173                tracing::warn!(
174                    attempt,
175                    delay_secs = delay.as_secs(),
176                    "retrying after rate limit"
177                );
178                tokio::time::sleep(delay).await;
179            }
180
181            tracing::info!(
182                provider = ?self.config.provider,
183                model = %self.config.model,
184                attempt = attempt + 1,
185                "calling LLM extraction API"
186            );
187
188            let result = match self.config.provider {
189                LlmProvider::OpenAI | LlmProvider::Custom => {
190                    self.call_openai(conversation, system_prompt).await
191                }
192                LlmProvider::Anthropic => self.call_anthropic(conversation, system_prompt).await,
193                LlmProvider::Ollama => self.call_ollama(conversation, system_prompt).await,
194            };
195
196            match result {
197                Ok(text) => {
198                    tracing::info!(response_len = text.len(), "LLM extraction complete");
199                    return Ok(text);
200                }
201                Err(ExtractionError::ProviderError(ref msg)) if msg.contains("429") => {
202                    tracing::warn!(attempt = attempt + 1, "rate limited by provider");
203                    last_err = Some(result.unwrap_err());
204                    continue;
205                }
206                Err(e) => {
207                    tracing::error!(error = %e, "LLM extraction failed");
208                    return Err(e);
209                }
210            }
211        }
212
213        match last_err {
214            Some(e) => Err(e),
215            None => Err(ExtractionError::RateLimitExceeded {
216                attempts: max_attempts,
217            }),
218        }
219    }
220}
221
impl ExtractionProvider for HttpExtractionProvider {
    /// Extract memories by calling the configured LLM API, retrying on
    /// rate-limit (HTTP 429) responses with exponential backoff.
    async fn extract(
        &self,
        conversation: &str,
        system_prompt: &str,
    ) -> Result<String, ExtractionError> {
        // Provider dispatch and retry behavior live in `call_with_retry`.
        self.call_with_retry(conversation, system_prompt).await
    }
}
231
/// Mock extraction provider for testing. Returns a predefined JSON response.
pub struct MockExtractionProvider {
    // Canned JSON string handed back verbatim (cloned) by every `extract` call.
    response: String,
}
236
237impl MockExtractionProvider {
238    /// Create a mock provider that always returns the given JSON string.
239    pub fn new(response: impl Into<String>) -> Self {
240        Self {
241            response: response.into(),
242        }
243    }
244
245    /// Create a mock provider with a realistic extraction response.
246    pub fn with_realistic_response() -> Self {
247        let response = serde_json::json!({
248            "memories": [
249                {
250                    "content": "The team decided to use PostgreSQL 15 as the primary database for the REST API project",
251                    "memory_type": "decision",
252                    "confidence": 0.95,
253                    "entities": ["PostgreSQL", "REST API"],
254                    "tags": ["database", "architecture"],
255                    "reasoning": "Explicitly decided after comparing options"
256                },
257                {
258                    "content": "REST endpoints should follow the /api/v1/ prefix convention",
259                    "memory_type": "decision",
260                    "confidence": 0.9,
261                    "entities": ["REST API"],
262                    "tags": ["api-design", "conventions"],
263                    "reasoning": "Team agreed on URL structure"
264                },
265                {
266                    "content": "User prefers Rust over Go for backend services due to memory safety guarantees",
267                    "memory_type": "preference",
268                    "confidence": 0.85,
269                    "entities": ["Rust", "Go"],
270                    "tags": ["language", "backend"],
271                    "reasoning": "Explicitly stated preference with clear reasoning"
272                },
273                {
274                    "content": "The initial plan to use MongoDB was incorrect; PostgreSQL is the right choice for relational data",
275                    "memory_type": "correction",
276                    "confidence": 0.9,
277                    "entities": ["MongoDB", "PostgreSQL"],
278                    "tags": ["database", "correction"],
279                    "reasoning": "Corrected an earlier wrong assumption"
280                },
281                {
282                    "content": "The project deadline is March 15, 2025",
283                    "memory_type": "fact",
284                    "confidence": 0.8,
285                    "entities": ["REST API project"],
286                    "tags": ["timeline"],
287                    "reasoning": "Confirmed date mentioned in discussion"
288                },
289                {
290                    "content": "Using global mutable state for database connections caused race conditions in testing",
291                    "memory_type": "anti_pattern",
292                    "confidence": 0.85,
293                    "entities": [],
294                    "tags": ["testing", "concurrency"],
295                    "reasoning": "Documented failure pattern to avoid repeating"
296                },
297                {
298                    "content": "Low confidence speculation about maybe using Redis",
299                    "memory_type": "fact",
300                    "confidence": 0.3,
301                    "entities": ["Redis"],
302                    "tags": ["cache"],
303                    "reasoning": "Mentioned but not confirmed"
304                }
305            ]
306        });
307        Self::new(response.to_string())
308    }
309}
310
impl ExtractionProvider for MockExtractionProvider {
    /// Return a clone of the canned response, ignoring both inputs.
    /// Never fails, so tests can rely on a deterministic `Ok` result.
    async fn extract(
        &self,
        _conversation: &str,
        _system_prompt: &str,
    ) -> Result<String, ExtractionError> {
        Ok(self.response.clone())
    }
}