//! mentedb_extraction/provider.rs — LLM extraction providers (OpenAI, Anthropic, Ollama, mock).
use crate::config::{ExtractionConfig, LlmProvider};
use crate::error::ExtractionError;

4/// Classify an HTTP error response into a specific ExtractionError variant.
5fn classify_api_error(
6    status: reqwest::StatusCode,
7    body: &str,
8    provider: &str,
9    model: &str,
10) -> ExtractionError {
11    let code = status.as_u16();
12    match code {
13        401 => ExtractionError::AuthError(format!(
14            "{provider} returned 401 Unauthorized. Check your API key (MENTEDB_LLM_API_KEY). \
15             Current provider: {provider}, model: {model}"
16        )),
17        403 => ExtractionError::AuthError(format!(
18            "{provider} returned 403 Forbidden. Your API key may lack permissions for model '{model}'."
19        )),
20        404 => ExtractionError::ModelNotFound(format!(
21            "{provider} returned 404. Model '{model}' may not exist or is not available on your account."
22        )),
23        _ => ExtractionError::ProviderError(format!("{provider} API returned {status}: {body}")),
24    }
25}
26
27/// Trait for LLM providers that can extract memories from conversation text.
28pub trait ExtractionProvider: Send + Sync {
29    /// Send a conversation to the LLM with the given system prompt and return
30    /// the raw response text (expected to be JSON).
31    fn extract(
32        &self,
33        conversation: &str,
34        system_prompt: &str,
35    ) -> impl std::future::Future<Output = Result<String, ExtractionError>> + Send;
36}
37
38/// HTTP-based extraction provider that calls OpenAI, Anthropic, or Ollama APIs.
39pub struct HttpExtractionProvider {
40    client: reqwest::Client,
41    config: ExtractionConfig,
42}
43
44impl HttpExtractionProvider {
45    pub fn new(config: ExtractionConfig) -> Result<Self, ExtractionError> {
46        if config.provider != LlmProvider::Ollama && config.api_key.is_none() {
47            return Err(ExtractionError::ConfigError(
48                "API key is required for this provider".to_string(),
49            ));
50        }
51        let client = reqwest::Client::builder()
52            .timeout(std::time::Duration::from_secs(120))
53            .connect_timeout(std::time::Duration::from_secs(30))
54            .build()
55            .map_err(|e| ExtractionError::ConfigError(format!("HTTP client error: {}", e)))?;
56        Ok(Self { client, config })
57    }
58
59    /// Expand a search query into multiple sub-queries via LLM.
60    ///
61    /// Given a natural language question, identifies the expected answer type
62    /// and extracts 2-3 targeted search queries. The first line of the response
63    /// is the answer type (PLACE, DATE, NUMBER, NAME, PERSON, BRAND, etc.),
64    /// followed by the search queries.
65    ///
66    /// For counting/aggregation/comparison queries, also generates comprehensive
67    /// category synonyms for exhaustive BM25 sweep.
68    pub async fn expand_query(&self, query: &str) -> Result<Vec<String>, ExtractionError> {
69        let system_prompt = "You help search a memory database. Given a question, return a JSON object with:\n\
70            - \"answer_type\": one of PLACE, DATE, TIME, NUMBER, NAME, PERSON, BRAND, ITEM, ACTIVITY, COUNTING, OTHER\n\
71            - \"queries\": array of 2-3 short search queries\n\
72            - For COUNTING only, also include:\n\
73              - \"item_keywords\": comma-separated specific subtypes/instances that would be individually counted\n\
74              - \"broad_keywords\": comma-separated category terms, action verbs, and general synonyms\n\n\
75            Use COUNTING when the question requires COMPLETENESS — counting, listing, aggregating, totaling, \
76            or comparing to find a superlative (most, least, best, worst, first, last, biggest, highest, lowest).\n\n\
77            The distinction matters:\n\
78            - item_keywords: specific things you would COUNT (types of the thing being asked about)\n\
79            - broad_keywords: general terms that help FIND memories but aren't counted themselves\n\n\
80            Examples:\n\
81            Q: \"Where do I take yoga classes?\"\n\
82            {\"answer_type\": \"PLACE\", \"queries\": [\"yoga studio name\", \"yoga class location\"]}\n\n\
83            Q: \"How many doctors did I visit?\"\n\
84            {\"answer_type\": \"COUNTING\", \"queries\": [\"doctor visits appointments\", \"medical specialist visits\"], \
85            \"item_keywords\": \"doctor, Dr., physician, specialist, dermatologist, cardiologist, dentist, surgeon, pediatrician, orthopedist, ophthalmologist\", \
86            \"broad_keywords\": \"medical, clinic, appointment, visit, diagnosed, prescribed, referred, checkup, exam\"}\n\n\
87            Q: \"Which platform did I gain the most followers on?\"\n\
88            {\"answer_type\": \"COUNTING\", \"queries\": [\"social media follower growth\", \"follower count increase\"], \
89            \"item_keywords\": \"TikTok, Instagram, Twitter, YouTube, Facebook, LinkedIn, Snapchat, Reddit, Twitch\", \
90            \"broad_keywords\": \"followers, follower count, gained, growth, platform, social media, increase, jumped, grew\"}";
91        let result = self.call_with_retry(query, system_prompt).await?;
92
93        // Parse JSON response (call_openai forces json_object response format)
94        let mut lines: Vec<String> = Vec::new();
95        let cleaned = result
96            .trim()
97            .trim_start_matches("```json")
98            .trim_end_matches("```")
99            .trim();
100        if let Ok(json) = serde_json::from_str::<serde_json::Value>(cleaned) {
101            if let Some(answer_type) = json.get("answer_type").and_then(|v| v.as_str()) {
102                lines.push(answer_type.to_string());
103            }
104            if let Some(queries) = json.get("queries").and_then(|v| v.as_array()) {
105                for q in queries {
106                    if let Some(s) = q.as_str() {
107                        lines.push(s.to_string());
108                    }
109                }
110            }
111            if let Some(item_kw) = json.get("item_keywords").and_then(|v| v.as_str()) {
112                lines.push(format!("ITEM_KEYWORDS: {}", item_kw));
113            }
114            if let Some(broad_kw) = json.get("broad_keywords").and_then(|v| v.as_str()) {
115                lines.push(format!("BROAD_KEYWORDS: {}", broad_kw));
116            }
117            // Fallback: old single "keywords" field → treat all as item keywords
118            if let Some(keywords) = json.get("keywords").and_then(|v| v.as_str())
119                && json.get("item_keywords").is_none()
120            {
121                lines.push(format!("ITEM_KEYWORDS: {}", keywords));
122            }
123        } else {
124            // Fallback: parse as plain text lines
125            lines = result
126                .lines()
127                .map(|l| l.trim().to_string())
128                .filter(|l| !l.is_empty())
129                .collect();
130        }
131        if std::env::var("MENTEDB_DEBUG").is_ok() {
132            eprintln!("[expand_query] input={:?} parsed={:?}", query, lines);
133        }
134        Ok(lines)
135    }
136
137    async fn call_openai(
138        &self,
139        conversation: &str,
140        system_prompt: &str,
141    ) -> Result<String, ExtractionError> {
142        let body = serde_json::json!({
143            "model": self.config.model,
144            "response_format": { "type": "json_object" },
145            "messages": [
146                { "role": "system", "content": system_prompt },
147                { "role": "user", "content": conversation }
148            ]
149        });
150
151        let api_key = self.config.api_key.as_deref().unwrap_or_default();
152
153        let resp = self
154            .client
155            .post(&self.config.api_url)
156            .header("Authorization", format!("Bearer {api_key}"))
157            .header("Content-Type", "application/json")
158            .json(&body)
159            .send()
160            .await?;
161
162        let status = resp.status();
163        let text = resp.text().await?;
164
165        if !status.is_success() {
166            return Err(classify_api_error(
167                status,
168                &text,
169                "OpenAI",
170                &self.config.model,
171            ));
172        }
173
174        let parsed: serde_json::Value = serde_json::from_str(&text)?;
175        parsed["choices"][0]["message"]["content"]
176            .as_str()
177            .map(|s| s.to_string())
178            .ok_or_else(|| {
179                ExtractionError::ParseError("Missing content in OpenAI response".to_string())
180            })
181    }
182
183    /// OpenAI call without forced JSON response format.
184    /// Used for plain text outputs (synthesis, re-ranking, key noun extraction).
185    async fn call_openai_text(
186        &self,
187        conversation: &str,
188        system_prompt: &str,
189    ) -> Result<String, ExtractionError> {
190        let body = serde_json::json!({
191            "model": self.config.model,
192            "messages": [
193                { "role": "system", "content": system_prompt },
194                { "role": "user", "content": conversation }
195            ]
196        });
197
198        let api_key = self.config.api_key.as_deref().unwrap_or_default();
199
200        let resp = self
201            .client
202            .post(&self.config.api_url)
203            .header("Authorization", format!("Bearer {api_key}"))
204            .header("Content-Type", "application/json")
205            .json(&body)
206            .send()
207            .await?;
208
209        let status = resp.status();
210        let text = resp.text().await?;
211
212        if !status.is_success() {
213            return Err(classify_api_error(
214                status,
215                &text,
216                "OpenAI",
217                &self.config.model,
218            ));
219        }
220
221        let parsed: serde_json::Value = serde_json::from_str(&text)?;
222        parsed["choices"][0]["message"]["content"]
223            .as_str()
224            .map(|s| s.to_string())
225            .ok_or_else(|| {
226                ExtractionError::ParseError("Missing content in OpenAI response".to_string())
227            })
228    }
229
230    async fn call_anthropic(
231        &self,
232        conversation: &str,
233        system_prompt: &str,
234    ) -> Result<String, ExtractionError> {
235        let body = serde_json::json!({
236            "model": self.config.model,
237            "max_tokens": 4096,
238            "system": system_prompt,
239            "messages": [
240                { "role": "user", "content": conversation }
241            ]
242        });
243
244        let api_key = self.config.api_key.as_deref().unwrap_or_default();
245
246        let resp = self
247            .client
248            .post(&self.config.api_url)
249            .header("x-api-key", api_key)
250            .header("anthropic-version", "2023-06-01")
251            .header("Content-Type", "application/json")
252            .json(&body)
253            .send()
254            .await?;
255
256        let status = resp.status();
257        let text = resp.text().await?;
258
259        if !status.is_success() {
260            return Err(classify_api_error(
261                status,
262                &text,
263                "Anthropic",
264                &self.config.model,
265            ));
266        }
267
268        let parsed: serde_json::Value = serde_json::from_str(&text)?;
269
270        // Anthropic may return multiple content blocks; find the first text block
271        let content_text = parsed["content"]
272            .as_array()
273            .and_then(|blocks| {
274                blocks.iter().find_map(|block| {
275                    if block["type"].as_str() == Some("text") {
276                        block["text"].as_str().map(|s| s.to_string())
277                    } else {
278                        None
279                    }
280                })
281            })
282            .or_else(|| {
283                // Fallback: try the old path for backwards compat
284                parsed["content"][0]["text"].as_str().map(|s| s.to_string())
285            });
286
287        match content_text {
288            Some(t) if !t.trim().is_empty() => Ok(t),
289            Some(_) => {
290                tracing::warn!(
291                    model = %self.config.model,
292                    "Anthropic returned empty text content"
293                );
294                Ok("{\"memories\": []}".to_string())
295            }
296            None => {
297                tracing::warn!(
298                    model = %self.config.model,
299                    response_preview = &text[..text.len().min(300)],
300                    "No text block found in Anthropic response"
301                );
302                Ok("{\"memories\": []}".to_string())
303            }
304        }
305    }
306
307    async fn call_ollama(
308        &self,
309        conversation: &str,
310        system_prompt: &str,
311    ) -> Result<String, ExtractionError> {
312        let body = serde_json::json!({
313            "model": self.config.model,
314            "stream": false,
315            "format": "json",
316            "messages": [
317                { "role": "system", "content": system_prompt },
318                { "role": "user", "content": conversation }
319            ]
320        });
321
322        let resp = self
323            .client
324            .post(&self.config.api_url)
325            .header("Content-Type", "application/json")
326            .json(&body)
327            .send()
328            .await?;
329
330        let status = resp.status();
331        let text = resp.text().await?;
332
333        if !status.is_success() {
334            return Err(classify_api_error(
335                status,
336                &text,
337                "Ollama",
338                &self.config.model,
339            ));
340        }
341
342        let parsed: serde_json::Value = serde_json::from_str(&text)?;
343        parsed["message"]["content"]
344            .as_str()
345            .map(|s| s.to_string())
346            .ok_or_else(|| {
347                ExtractionError::ParseError("Missing content in Ollama response".to_string())
348            })
349    }
350
351    /// Execute a request with retry logic for rate limits (HTTP 429).
352    /// Uses exponential backoff: 1s, 2s, 4s.
353    pub async fn call_with_retry(
354        &self,
355        conversation: &str,
356        system_prompt: &str,
357    ) -> Result<String, ExtractionError> {
358        self.call_with_retry_inner(conversation, system_prompt, true)
359            .await
360    }
361
362    /// Like call_with_retry but without forcing JSON response format.
363    /// Use for prompts that expect plain text output (synthesis, re-ranking, etc).
364    pub async fn call_text_with_retry(
365        &self,
366        conversation: &str,
367        system_prompt: &str,
368    ) -> Result<String, ExtractionError> {
369        self.call_with_retry_inner(conversation, system_prompt, false)
370            .await
371    }
372
373    async fn call_with_retry_inner(
374        &self,
375        conversation: &str,
376        system_prompt: &str,
377        force_json: bool,
378    ) -> Result<String, ExtractionError> {
379        let max_attempts = 3;
380        let mut last_err = None;
381
382        for attempt in 0..max_attempts {
383            if attempt > 0 {
384                let delay = std::time::Duration::from_secs(1 << attempt);
385                tracing::warn!(
386                    attempt,
387                    delay_secs = delay.as_secs(),
388                    "retrying after rate limit"
389                );
390                tokio::time::sleep(delay).await;
391            }
392
393            tracing::info!(
394                provider = ?self.config.provider,
395                model = %self.config.model,
396                attempt = attempt + 1,
397                "calling LLM extraction API"
398            );
399
400            let result = match self.config.provider {
401                LlmProvider::OpenAI | LlmProvider::Custom => {
402                    if force_json {
403                        self.call_openai(conversation, system_prompt).await
404                    } else {
405                        self.call_openai_text(conversation, system_prompt).await
406                    }
407                }
408                LlmProvider::Anthropic => self.call_anthropic(conversation, system_prompt).await,
409                LlmProvider::Ollama => self.call_ollama(conversation, system_prompt).await,
410            };
411
412            match result {
413                Ok(text) => {
414                    tracing::info!(response_len = text.len(), "LLM extraction complete");
415                    return Ok(text);
416                }
417                Err(ExtractionError::ProviderError(ref msg))
418                    if msg.contains("429")
419                        || msg.contains("500")
420                        || msg.contains("502")
421                        || msg.contains("503")
422                        || msg.contains("529")
423                        || msg.contains("timeout")
424                        || msg.contains("connection")
425                        || msg.contains("overloaded") =>
426                {
427                    tracing::warn!(attempt = attempt + 1, error = %msg, "retrying transient LLM error");
428                    last_err = Some(result.unwrap_err());
429                    continue;
430                }
431                Err(e) => {
432                    tracing::error!(error = %e, "LLM extraction failed (non-retryable)");
433                    return Err(e);
434                }
435            }
436        }
437
438        match last_err {
439            Some(e) => Err(e),
440            None => Err(ExtractionError::RateLimitExceeded {
441                attempts: max_attempts,
442            }),
443        }
444    }
445}
446
447impl ExtractionProvider for HttpExtractionProvider {
448    async fn extract(
449        &self,
450        conversation: &str,
451        system_prompt: &str,
452    ) -> Result<String, ExtractionError> {
453        self.call_with_retry(conversation, system_prompt).await
454    }
455}
456
/// Mock extraction provider for testing. Returns a predefined JSON response.
pub struct MockExtractionProvider {
    /// Canned JSON string returned verbatim by every `extract` call.
    response: String,
}
461
462impl MockExtractionProvider {
463    /// Create a mock provider that always returns the given JSON string.
464    pub fn new(response: impl Into<String>) -> Self {
465        Self {
466            response: response.into(),
467        }
468    }
469
470    /// Create a mock provider with a realistic extraction response.
471    pub fn with_realistic_response() -> Self {
472        let response = serde_json::json!({
473            "memories": [
474                {
475                    "content": "The team decided to use PostgreSQL 15 as the primary database for the REST API project",
476                    "memory_type": "decision",
477                    "confidence": 0.95,
478                    "entities": ["PostgreSQL", "REST API"],
479                    "tags": ["database", "architecture"],
480                    "reasoning": "Explicitly decided after comparing options"
481                },
482                {
483                    "content": "REST endpoints should follow the /api/v1/ prefix convention",
484                    "memory_type": "decision",
485                    "confidence": 0.9,
486                    "entities": ["REST API"],
487                    "tags": ["api-design", "conventions"],
488                    "reasoning": "Team agreed on URL structure"
489                },
490                {
491                    "content": "User prefers Rust over Go for backend services due to memory safety guarantees",
492                    "memory_type": "preference",
493                    "confidence": 0.85,
494                    "entities": ["Rust", "Go"],
495                    "tags": ["language", "backend"],
496                    "reasoning": "Explicitly stated preference with clear reasoning"
497                },
498                {
499                    "content": "The initial plan to use MongoDB was incorrect; PostgreSQL is the right choice for relational data",
500                    "memory_type": "correction",
501                    "confidence": 0.9,
502                    "entities": ["MongoDB", "PostgreSQL"],
503                    "tags": ["database", "correction"],
504                    "reasoning": "Corrected an earlier wrong assumption"
505                },
506                {
507                    "content": "The project deadline is March 15, 2025",
508                    "memory_type": "fact",
509                    "confidence": 0.8,
510                    "entities": ["REST API project"],
511                    "tags": ["timeline"],
512                    "reasoning": "Confirmed date mentioned in discussion"
513                },
514                {
515                    "content": "Using global mutable state for database connections caused race conditions in testing",
516                    "memory_type": "anti_pattern",
517                    "confidence": 0.85,
518                    "entities": [],
519                    "tags": ["testing", "concurrency"],
520                    "reasoning": "Documented failure pattern to avoid repeating"
521                },
522                {
523                    "content": "Low confidence speculation about maybe using Redis",
524                    "memory_type": "fact",
525                    "confidence": 0.3,
526                    "entities": ["Redis"],
527                    "tags": ["cache"],
528                    "reasoning": "Mentioned but not confirmed"
529                }
530            ]
531        });
532        Self::new(response.to_string())
533    }
534}
535
536impl ExtractionProvider for MockExtractionProvider {
537    async fn extract(
538        &self,
539        _conversation: &str,
540        _system_prompt: &str,
541    ) -> Result<String, ExtractionError> {
542        Ok(self.response.clone())
543    }
544}