1use crate::config::{ExtractionConfig, LlmProvider};
2use crate::error::ExtractionError;
3
/// Abstraction over an LLM backend used to extract memories from a
/// conversation transcript.
///
/// Implementors resolve to the raw model output string (callers treat it
/// as JSON) or an `ExtractionError`. `Send + Sync` plus the `Send` bound
/// on the returned future allow use across async tasks.
pub trait ExtractionProvider: Send + Sync {
    /// Sends `conversation` together with `system_prompt` to the backing
    /// LLM and resolves to the model's text response.
    fn extract(
        &self,
        conversation: &str,
        system_prompt: &str,
    ) -> impl std::future::Future<Output = Result<String, ExtractionError>> + Send;
}
14
/// `ExtractionProvider` backed by an HTTP LLM API (OpenAI, Anthropic,
/// Ollama, or an OpenAI-compatible custom endpoint).
pub struct HttpExtractionProvider {
    // Single HTTP client reused for every request.
    client: reqwest::Client,
    // Provider selection, model name, API URL, and optional API key.
    config: ExtractionConfig,
}
20
21impl HttpExtractionProvider {
22 pub fn new(config: ExtractionConfig) -> Result<Self, ExtractionError> {
23 if config.provider != LlmProvider::Ollama && config.api_key.is_none() {
24 return Err(ExtractionError::ConfigError(
25 "API key is required for this provider".to_string(),
26 ));
27 }
28 let client = reqwest::Client::new();
29 Ok(Self { client, config })
30 }
31
32 async fn call_openai(
33 &self,
34 conversation: &str,
35 system_prompt: &str,
36 ) -> Result<String, ExtractionError> {
37 let body = serde_json::json!({
38 "model": self.config.model,
39 "response_format": { "type": "json_object" },
40 "messages": [
41 { "role": "system", "content": system_prompt },
42 { "role": "user", "content": conversation }
43 ]
44 });
45
46 let api_key = self.config.api_key.as_deref().unwrap_or_default();
47
48 let resp = self
49 .client
50 .post(&self.config.api_url)
51 .header("Authorization", format!("Bearer {api_key}"))
52 .header("Content-Type", "application/json")
53 .json(&body)
54 .send()
55 .await?;
56
57 let status = resp.status();
58 let text = resp.text().await?;
59
60 if !status.is_success() {
61 return Err(ExtractionError::ProviderError(format!(
62 "OpenAI API returned {status}: {text}"
63 )));
64 }
65
66 let parsed: serde_json::Value = serde_json::from_str(&text)?;
67 parsed["choices"][0]["message"]["content"]
68 .as_str()
69 .map(|s| s.to_string())
70 .ok_or_else(|| {
71 ExtractionError::ParseError("Missing content in OpenAI response".to_string())
72 })
73 }
74
75 async fn call_anthropic(
76 &self,
77 conversation: &str,
78 system_prompt: &str,
79 ) -> Result<String, ExtractionError> {
80 let body = serde_json::json!({
81 "model": self.config.model,
82 "max_tokens": 4096,
83 "system": system_prompt,
84 "messages": [
85 { "role": "user", "content": conversation }
86 ]
87 });
88
89 let api_key = self.config.api_key.as_deref().unwrap_or_default();
90
91 let resp = self
92 .client
93 .post(&self.config.api_url)
94 .header("x-api-key", api_key)
95 .header("anthropic-version", "2023-06-01")
96 .header("Content-Type", "application/json")
97 .json(&body)
98 .send()
99 .await?;
100
101 let status = resp.status();
102 let text = resp.text().await?;
103
104 if !status.is_success() {
105 return Err(ExtractionError::ProviderError(format!(
106 "Anthropic API returned {status}: {text}"
107 )));
108 }
109
110 let parsed: serde_json::Value = serde_json::from_str(&text)?;
111 parsed["content"][0]["text"]
112 .as_str()
113 .map(|s| s.to_string())
114 .ok_or_else(|| {
115 ExtractionError::ParseError("Missing text in Anthropic response".to_string())
116 })
117 }
118
119 async fn call_ollama(
120 &self,
121 conversation: &str,
122 system_prompt: &str,
123 ) -> Result<String, ExtractionError> {
124 let body = serde_json::json!({
125 "model": self.config.model,
126 "stream": false,
127 "format": "json",
128 "messages": [
129 { "role": "system", "content": system_prompt },
130 { "role": "user", "content": conversation }
131 ]
132 });
133
134 let resp = self
135 .client
136 .post(&self.config.api_url)
137 .header("Content-Type", "application/json")
138 .json(&body)
139 .send()
140 .await?;
141
142 let status = resp.status();
143 let text = resp.text().await?;
144
145 if !status.is_success() {
146 return Err(ExtractionError::ProviderError(format!(
147 "Ollama API returned {status}: {text}"
148 )));
149 }
150
151 let parsed: serde_json::Value = serde_json::from_str(&text)?;
152 parsed["message"]["content"]
153 .as_str()
154 .map(|s| s.to_string())
155 .ok_or_else(|| {
156 ExtractionError::ParseError("Missing content in Ollama response".to_string())
157 })
158 }
159
160 async fn call_with_retry(
163 &self,
164 conversation: &str,
165 system_prompt: &str,
166 ) -> Result<String, ExtractionError> {
167 let max_attempts = 3;
168 let mut last_err = None;
169
170 for attempt in 0..max_attempts {
171 if attempt > 0 {
172 let delay = std::time::Duration::from_secs(1 << attempt);
173 tracing::warn!(
174 attempt,
175 delay_secs = delay.as_secs(),
176 "retrying after rate limit"
177 );
178 tokio::time::sleep(delay).await;
179 }
180
181 tracing::info!(
182 provider = ?self.config.provider,
183 model = %self.config.model,
184 attempt = attempt + 1,
185 "calling LLM extraction API"
186 );
187
188 let result = match self.config.provider {
189 LlmProvider::OpenAI | LlmProvider::Custom => {
190 self.call_openai(conversation, system_prompt).await
191 }
192 LlmProvider::Anthropic => self.call_anthropic(conversation, system_prompt).await,
193 LlmProvider::Ollama => self.call_ollama(conversation, system_prompt).await,
194 };
195
196 match result {
197 Ok(text) => {
198 tracing::info!(response_len = text.len(), "LLM extraction complete");
199 return Ok(text);
200 }
201 Err(ExtractionError::ProviderError(ref msg)) if msg.contains("429") => {
202 tracing::warn!(attempt = attempt + 1, "rate limited by provider");
203 last_err = Some(result.unwrap_err());
204 continue;
205 }
206 Err(e) => {
207 tracing::error!(error = %e, "LLM extraction failed");
208 return Err(e);
209 }
210 }
211 }
212
213 match last_err {
214 Some(e) => Err(e),
215 None => Err(ExtractionError::RateLimitExceeded {
216 attempts: max_attempts,
217 }),
218 }
219 }
220}
221
impl ExtractionProvider for HttpExtractionProvider {
    /// Delegates to `call_with_retry`, which dispatches on the configured
    /// provider and retries on rate limiting.
    async fn extract(
        &self,
        conversation: &str,
        system_prompt: &str,
    ) -> Result<String, ExtractionError> {
        self.call_with_retry(conversation, system_prompt).await
    }
}
231
/// Test double for `ExtractionProvider` that returns a fixed response
/// without performing any network I/O.
pub struct MockExtractionProvider {
    // The canned string handed back by every `extract` call.
    response: String,
}
236
impl MockExtractionProvider {
    /// Creates a mock that always returns `response`.
    pub fn new(response: impl Into<String>) -> Self {
        Self {
            response: response.into(),
        }
    }

    /// Creates a mock preloaded with a realistic extraction payload:
    /// a `memories` array covering each memory type (decision, preference,
    /// correction, fact, anti_pattern) plus one low-confidence entry
    /// (0.3) so callers can exercise confidence-threshold filtering.
    pub fn with_realistic_response() -> Self {
        let response = serde_json::json!({
            "memories": [
                {
                    "content": "The team decided to use PostgreSQL 15 as the primary database for the REST API project",
                    "memory_type": "decision",
                    "confidence": 0.95,
                    "entities": ["PostgreSQL", "REST API"],
                    "tags": ["database", "architecture"],
                    "reasoning": "Explicitly decided after comparing options"
                },
                {
                    "content": "REST endpoints should follow the /api/v1/ prefix convention",
                    "memory_type": "decision",
                    "confidence": 0.9,
                    "entities": ["REST API"],
                    "tags": ["api-design", "conventions"],
                    "reasoning": "Team agreed on URL structure"
                },
                {
                    "content": "User prefers Rust over Go for backend services due to memory safety guarantees",
                    "memory_type": "preference",
                    "confidence": 0.85,
                    "entities": ["Rust", "Go"],
                    "tags": ["language", "backend"],
                    "reasoning": "Explicitly stated preference with clear reasoning"
                },
                {
                    "content": "The initial plan to use MongoDB was incorrect; PostgreSQL is the right choice for relational data",
                    "memory_type": "correction",
                    "confidence": 0.9,
                    "entities": ["MongoDB", "PostgreSQL"],
                    "tags": ["database", "correction"],
                    "reasoning": "Corrected an earlier wrong assumption"
                },
                {
                    "content": "The project deadline is March 15, 2025",
                    "memory_type": "fact",
                    "confidence": 0.8,
                    "entities": ["REST API project"],
                    "tags": ["timeline"],
                    "reasoning": "Confirmed date mentioned in discussion"
                },
                {
                    "content": "Using global mutable state for database connections caused race conditions in testing",
                    "memory_type": "anti_pattern",
                    "confidence": 0.85,
                    "entities": [],
                    "tags": ["testing", "concurrency"],
                    "reasoning": "Documented failure pattern to avoid repeating"
                },
                {
                    "content": "Low confidence speculation about maybe using Redis",
                    "memory_type": "fact",
                    "confidence": 0.3,
                    "entities": ["Redis"],
                    "tags": ["cache"],
                    "reasoning": "Mentioned but not confirmed"
                }
            ]
        });
        Self::new(response.to_string())
    }
}
310
311impl ExtractionProvider for MockExtractionProvider {
312 async fn extract(
313 &self,
314 _conversation: &str,
315 _system_prompt: &str,
316 ) -> Result<String, ExtractionError> {
317 Ok(self.response.clone())
318 }
319}