1use crate::config::{ExtractionConfig, LlmProvider};
2use crate::error::ExtractionError;
3
4fn classify_api_error(
6 status: reqwest::StatusCode,
7 body: &str,
8 provider: &str,
9 model: &str,
10) -> ExtractionError {
11 let code = status.as_u16();
12 match code {
13 401 => ExtractionError::AuthError(format!(
14 "{provider} returned 401 Unauthorized. Check your API key (MENTEDB_LLM_API_KEY). \
15 Current provider: {provider}, model: {model}"
16 )),
17 403 => ExtractionError::AuthError(format!(
18 "{provider} returned 403 Forbidden. Your API key may lack permissions for model '{model}'."
19 )),
20 404 => ExtractionError::ModelNotFound(format!(
21 "{provider} returned 404. Model '{model}' may not exist or is not available on your account."
22 )),
23 _ => ExtractionError::ProviderError(format!("{provider} API returned {status}: {body}")),
24 }
25}
26
/// Abstraction over an LLM backend: given a conversation and a system prompt,
/// produce the model's raw text response.
///
/// Declared with return-position `impl Future … + Send` so implementors get a
/// `Send` future without boxing; `Send + Sync` on the trait lets providers be
/// shared across async tasks.
pub trait ExtractionProvider: Send + Sync {
    /// Runs one extraction call and returns the model's raw output
    /// (typically JSON text to be parsed by the caller).
    fn extract(
        &self,
        conversation: &str,
        system_prompt: &str,
    ) -> impl std::future::Future<Output = Result<String, ExtractionError>> + Send;
}
37
/// HTTP-backed extraction provider that speaks the OpenAI, Anthropic, or
/// Ollama wire format depending on the configured `LlmProvider`.
pub struct HttpExtractionProvider {
    // Shared HTTP client, reused for every request.
    client: reqwest::Client,
    // Provider kind, model name, API URL, and optional API key.
    config: ExtractionConfig,
}
43
44impl HttpExtractionProvider {
45 pub fn new(config: ExtractionConfig) -> Result<Self, ExtractionError> {
46 if config.provider != LlmProvider::Ollama && config.api_key.is_none() {
47 return Err(ExtractionError::ConfigError(
48 "API key is required for this provider".to_string(),
49 ));
50 }
51 let client = reqwest::Client::builder()
52 .timeout(std::time::Duration::from_secs(120))
53 .connect_timeout(std::time::Duration::from_secs(30))
54 .build()
55 .map_err(|e| ExtractionError::ConfigError(format!("HTTP client error: {}", e)))?;
56 Ok(Self { client, config })
57 }
58
59 pub async fn expand_query(&self, query: &str) -> Result<Vec<String>, ExtractionError> {
69 let system_prompt = "You help search a memory database. Given a question, return a JSON object with:\n\
70 - \"answer_type\": one of PLACE, DATE, TIME, NUMBER, NAME, PERSON, BRAND, ITEM, ACTIVITY, COUNTING, OTHER\n\
71 - \"queries\": array of 2-3 short search queries\n\
72 - For COUNTING only, also include:\n\
73 - \"item_keywords\": comma-separated specific subtypes/instances that would be individually counted\n\
74 - \"broad_keywords\": comma-separated category terms, action verbs, and general synonyms\n\n\
75 Use COUNTING when the question requires COMPLETENESS — counting, listing, aggregating, totaling, \
76 or comparing to find a superlative (most, least, best, worst, first, last, biggest, highest, lowest).\n\n\
77 The distinction matters:\n\
78 - item_keywords: specific things you would COUNT (types of the thing being asked about)\n\
79 - broad_keywords: general terms that help FIND memories but aren't counted themselves\n\n\
80 Examples:\n\
81 Q: \"Where do I take yoga classes?\"\n\
82 {\"answer_type\": \"PLACE\", \"queries\": [\"yoga studio name\", \"yoga class location\"]}\n\n\
83 Q: \"How many doctors did I visit?\"\n\
84 {\"answer_type\": \"COUNTING\", \"queries\": [\"doctor visits appointments\", \"medical specialist visits\"], \
85 \"item_keywords\": \"doctor, Dr., physician, specialist, dermatologist, cardiologist, dentist, surgeon, pediatrician, orthopedist, ophthalmologist\", \
86 \"broad_keywords\": \"medical, clinic, appointment, visit, diagnosed, prescribed, referred, checkup, exam\"}\n\n\
87 Q: \"Which platform did I gain the most followers on?\"\n\
88 {\"answer_type\": \"COUNTING\", \"queries\": [\"social media follower growth\", \"follower count increase\"], \
89 \"item_keywords\": \"TikTok, Instagram, Twitter, YouTube, Facebook, LinkedIn, Snapchat, Reddit, Twitch\", \
90 \"broad_keywords\": \"followers, follower count, gained, growth, platform, social media, increase, jumped, grew\"}";
91 let result = self.call_with_retry(query, system_prompt).await?;
92
93 let mut lines: Vec<String> = Vec::new();
95 let cleaned = result
96 .trim()
97 .trim_start_matches("```json")
98 .trim_end_matches("```")
99 .trim();
100 if let Ok(json) = serde_json::from_str::<serde_json::Value>(cleaned) {
101 if let Some(answer_type) = json.get("answer_type").and_then(|v| v.as_str()) {
102 lines.push(answer_type.to_string());
103 }
104 if let Some(queries) = json.get("queries").and_then(|v| v.as_array()) {
105 for q in queries {
106 if let Some(s) = q.as_str() {
107 lines.push(s.to_string());
108 }
109 }
110 }
111 if let Some(item_kw) = json.get("item_keywords").and_then(|v| v.as_str()) {
112 lines.push(format!("ITEM_KEYWORDS: {}", item_kw));
113 }
114 if let Some(broad_kw) = json.get("broad_keywords").and_then(|v| v.as_str()) {
115 lines.push(format!("BROAD_KEYWORDS: {}", broad_kw));
116 }
117 if let Some(keywords) = json.get("keywords").and_then(|v| v.as_str())
119 && json.get("item_keywords").is_none()
120 {
121 lines.push(format!("ITEM_KEYWORDS: {}", keywords));
122 }
123 } else {
124 lines = result
126 .lines()
127 .map(|l| l.trim().to_string())
128 .filter(|l| !l.is_empty())
129 .collect();
130 }
131 if std::env::var("MENTEDB_DEBUG").is_ok() {
132 eprintln!("[expand_query] input={:?} parsed={:?}", query, lines);
133 }
134 Ok(lines)
135 }
136
137 async fn call_openai(
138 &self,
139 conversation: &str,
140 system_prompt: &str,
141 ) -> Result<String, ExtractionError> {
142 let body = serde_json::json!({
143 "model": self.config.model,
144 "response_format": { "type": "json_object" },
145 "messages": [
146 { "role": "system", "content": system_prompt },
147 { "role": "user", "content": conversation }
148 ]
149 });
150
151 let api_key = self.config.api_key.as_deref().unwrap_or_default();
152
153 let resp = self
154 .client
155 .post(&self.config.api_url)
156 .header("Authorization", format!("Bearer {api_key}"))
157 .header("Content-Type", "application/json")
158 .json(&body)
159 .send()
160 .await?;
161
162 let status = resp.status();
163 let text = resp.text().await?;
164
165 if !status.is_success() {
166 return Err(classify_api_error(
167 status,
168 &text,
169 "OpenAI",
170 &self.config.model,
171 ));
172 }
173
174 let parsed: serde_json::Value = serde_json::from_str(&text)?;
175 parsed["choices"][0]["message"]["content"]
176 .as_str()
177 .map(|s| s.to_string())
178 .ok_or_else(|| {
179 ExtractionError::ParseError("Missing content in OpenAI response".to_string())
180 })
181 }
182
183 async fn call_openai_text(
186 &self,
187 conversation: &str,
188 system_prompt: &str,
189 ) -> Result<String, ExtractionError> {
190 let body = serde_json::json!({
191 "model": self.config.model,
192 "messages": [
193 { "role": "system", "content": system_prompt },
194 { "role": "user", "content": conversation }
195 ]
196 });
197
198 let api_key = self.config.api_key.as_deref().unwrap_or_default();
199
200 let resp = self
201 .client
202 .post(&self.config.api_url)
203 .header("Authorization", format!("Bearer {api_key}"))
204 .header("Content-Type", "application/json")
205 .json(&body)
206 .send()
207 .await?;
208
209 let status = resp.status();
210 let text = resp.text().await?;
211
212 if !status.is_success() {
213 return Err(classify_api_error(
214 status,
215 &text,
216 "OpenAI",
217 &self.config.model,
218 ));
219 }
220
221 let parsed: serde_json::Value = serde_json::from_str(&text)?;
222 parsed["choices"][0]["message"]["content"]
223 .as_str()
224 .map(|s| s.to_string())
225 .ok_or_else(|| {
226 ExtractionError::ParseError("Missing content in OpenAI response".to_string())
227 })
228 }
229
230 async fn call_anthropic(
231 &self,
232 conversation: &str,
233 system_prompt: &str,
234 ) -> Result<String, ExtractionError> {
235 let body = serde_json::json!({
236 "model": self.config.model,
237 "max_tokens": 4096,
238 "system": system_prompt,
239 "messages": [
240 { "role": "user", "content": conversation }
241 ]
242 });
243
244 let api_key = self.config.api_key.as_deref().unwrap_or_default();
245
246 let resp = self
247 .client
248 .post(&self.config.api_url)
249 .header("x-api-key", api_key)
250 .header("anthropic-version", "2023-06-01")
251 .header("Content-Type", "application/json")
252 .json(&body)
253 .send()
254 .await?;
255
256 let status = resp.status();
257 let text = resp.text().await?;
258
259 if !status.is_success() {
260 return Err(classify_api_error(
261 status,
262 &text,
263 "Anthropic",
264 &self.config.model,
265 ));
266 }
267
268 let parsed: serde_json::Value = serde_json::from_str(&text)?;
269
270 let content_text = parsed["content"]
272 .as_array()
273 .and_then(|blocks| {
274 blocks.iter().find_map(|block| {
275 if block["type"].as_str() == Some("text") {
276 block["text"].as_str().map(|s| s.to_string())
277 } else {
278 None
279 }
280 })
281 })
282 .or_else(|| {
283 parsed["content"][0]["text"].as_str().map(|s| s.to_string())
285 });
286
287 match content_text {
288 Some(t) if !t.trim().is_empty() => Ok(t),
289 Some(_) => {
290 tracing::warn!(
291 model = %self.config.model,
292 "Anthropic returned empty text content"
293 );
294 Ok("{\"memories\": []}".to_string())
295 }
296 None => {
297 tracing::warn!(
298 model = %self.config.model,
299 response_preview = &text[..text.len().min(300)],
300 "No text block found in Anthropic response"
301 );
302 Ok("{\"memories\": []}".to_string())
303 }
304 }
305 }
306
307 async fn call_ollama(
308 &self,
309 conversation: &str,
310 system_prompt: &str,
311 ) -> Result<String, ExtractionError> {
312 let body = serde_json::json!({
313 "model": self.config.model,
314 "stream": false,
315 "format": "json",
316 "messages": [
317 { "role": "system", "content": system_prompt },
318 { "role": "user", "content": conversation }
319 ]
320 });
321
322 let resp = self
323 .client
324 .post(&self.config.api_url)
325 .header("Content-Type", "application/json")
326 .json(&body)
327 .send()
328 .await?;
329
330 let status = resp.status();
331 let text = resp.text().await?;
332
333 if !status.is_success() {
334 return Err(classify_api_error(
335 status,
336 &text,
337 "Ollama",
338 &self.config.model,
339 ));
340 }
341
342 let parsed: serde_json::Value = serde_json::from_str(&text)?;
343 parsed["message"]["content"]
344 .as_str()
345 .map(|s| s.to_string())
346 .ok_or_else(|| {
347 ExtractionError::ParseError("Missing content in Ollama response".to_string())
348 })
349 }
350
351 pub async fn call_with_retry(
354 &self,
355 conversation: &str,
356 system_prompt: &str,
357 ) -> Result<String, ExtractionError> {
358 self.call_with_retry_inner(conversation, system_prompt, true)
359 .await
360 }
361
362 pub async fn call_text_with_retry(
365 &self,
366 conversation: &str,
367 system_prompt: &str,
368 ) -> Result<String, ExtractionError> {
369 self.call_with_retry_inner(conversation, system_prompt, false)
370 .await
371 }
372
373 async fn call_with_retry_inner(
374 &self,
375 conversation: &str,
376 system_prompt: &str,
377 force_json: bool,
378 ) -> Result<String, ExtractionError> {
379 let max_attempts = 3;
380 let mut last_err = None;
381
382 for attempt in 0..max_attempts {
383 if attempt > 0 {
384 let delay = std::time::Duration::from_secs(1 << attempt);
385 tracing::warn!(
386 attempt,
387 delay_secs = delay.as_secs(),
388 "retrying after rate limit"
389 );
390 tokio::time::sleep(delay).await;
391 }
392
393 tracing::info!(
394 provider = ?self.config.provider,
395 model = %self.config.model,
396 attempt = attempt + 1,
397 "calling LLM extraction API"
398 );
399
400 let result = match self.config.provider {
401 LlmProvider::OpenAI | LlmProvider::Custom => {
402 if force_json {
403 self.call_openai(conversation, system_prompt).await
404 } else {
405 self.call_openai_text(conversation, system_prompt).await
406 }
407 }
408 LlmProvider::Anthropic => self.call_anthropic(conversation, system_prompt).await,
409 LlmProvider::Ollama => self.call_ollama(conversation, system_prompt).await,
410 };
411
412 match result {
413 Ok(text) => {
414 tracing::info!(response_len = text.len(), "LLM extraction complete");
415 return Ok(text);
416 }
417 Err(ExtractionError::ProviderError(ref msg))
418 if msg.contains("429")
419 || msg.contains("500")
420 || msg.contains("502")
421 || msg.contains("503")
422 || msg.contains("529")
423 || msg.contains("timeout")
424 || msg.contains("connection")
425 || msg.contains("overloaded") =>
426 {
427 tracing::warn!(attempt = attempt + 1, error = %msg, "retrying transient LLM error");
428 last_err = Some(result.unwrap_err());
429 continue;
430 }
431 Err(e) => {
432 tracing::error!(error = %e, "LLM extraction failed (non-retryable)");
433 return Err(e);
434 }
435 }
436 }
437
438 match last_err {
439 Some(e) => Err(e),
440 None => Err(ExtractionError::RateLimitExceeded {
441 attempts: max_attempts,
442 }),
443 }
444 }
445}
446
impl ExtractionProvider for HttpExtractionProvider {
    /// Delegates to `call_with_retry`, which dispatches on the configured
    /// provider and retries transient failures with exponential backoff.
    async fn extract(
        &self,
        conversation: &str,
        system_prompt: &str,
    ) -> Result<String, ExtractionError> {
        self.call_with_retry(conversation, system_prompt).await
    }
}
456
/// Test double for `ExtractionProvider` that returns a canned response
/// without performing any network I/O.
pub struct MockExtractionProvider {
    // The fixed string handed back by every `extract` call.
    response: String,
}
461
impl MockExtractionProvider {
    /// Creates a mock that always returns `response`.
    pub fn new(response: impl Into<String>) -> Self {
        Self {
            response: response.into(),
        }
    }

    /// Creates a mock whose response resembles a real extraction payload:
    /// a `memories` array covering every memory type (decision, preference,
    /// correction, fact, anti_pattern) plus one low-confidence entry so
    /// confidence-threshold filtering can be exercised.
    pub fn with_realistic_response() -> Self {
        let response = serde_json::json!({
            "memories": [
                {
                    "content": "The team decided to use PostgreSQL 15 as the primary database for the REST API project",
                    "memory_type": "decision",
                    "confidence": 0.95,
                    "entities": ["PostgreSQL", "REST API"],
                    "tags": ["database", "architecture"],
                    "reasoning": "Explicitly decided after comparing options"
                },
                {
                    "content": "REST endpoints should follow the /api/v1/ prefix convention",
                    "memory_type": "decision",
                    "confidence": 0.9,
                    "entities": ["REST API"],
                    "tags": ["api-design", "conventions"],
                    "reasoning": "Team agreed on URL structure"
                },
                {
                    "content": "User prefers Rust over Go for backend services due to memory safety guarantees",
                    "memory_type": "preference",
                    "confidence": 0.85,
                    "entities": ["Rust", "Go"],
                    "tags": ["language", "backend"],
                    "reasoning": "Explicitly stated preference with clear reasoning"
                },
                {
                    "content": "The initial plan to use MongoDB was incorrect; PostgreSQL is the right choice for relational data",
                    "memory_type": "correction",
                    "confidence": 0.9,
                    "entities": ["MongoDB", "PostgreSQL"],
                    "tags": ["database", "correction"],
                    "reasoning": "Corrected an earlier wrong assumption"
                },
                {
                    "content": "The project deadline is March 15, 2025",
                    "memory_type": "fact",
                    "confidence": 0.8,
                    "entities": ["REST API project"],
                    "tags": ["timeline"],
                    "reasoning": "Confirmed date mentioned in discussion"
                },
                {
                    "content": "Using global mutable state for database connections caused race conditions in testing",
                    "memory_type": "anti_pattern",
                    "confidence": 0.85,
                    "entities": [],
                    "tags": ["testing", "concurrency"],
                    "reasoning": "Documented failure pattern to avoid repeating"
                },
                {
                    "content": "Low confidence speculation about maybe using Redis",
                    "memory_type": "fact",
                    "confidence": 0.3,
                    "entities": ["Redis"],
                    "tags": ["cache"],
                    "reasoning": "Mentioned but not confirmed"
                }
            ]
        });
        Self::new(response.to_string())
    }
}
535
536impl ExtractionProvider for MockExtractionProvider {
537 async fn extract(
538 &self,
539 _conversation: &str,
540 _system_prompt: &str,
541 ) -> Result<String, ExtractionError> {
542 Ok(self.response.clone())
543 }
544}