1use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use tracing::warn;
30
31use crate::error::{InferenceError, Result};
32use crate::ner::{rule_based_extract, ExtractedEntity, NerEngine};
33
/// Structured information pulled out of a piece of text by one provider.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ExtractionResult {
    /// Named entities found in the text.
    pub entities: Vec<ExtractedEntity>,
    /// High-level topics; only populated by the LLM-backed providers.
    pub topics: Vec<String>,
    /// Salient phrases; only populated by the LLM-backed providers.
    pub key_phrases: Vec<String>,
    /// Brief summary, when the provider produces one.
    pub summary: Option<String>,
    /// Identifier of the provider that produced this result (e.g. "gliner").
    pub provider: String,
}
48
/// Caller-supplied options for a single extraction run.
#[derive(Debug, Clone, Default)]
pub struct ExtractionOpts {
    /// Entity labels to look for. Used by the GLiNER provider; the LLM
    /// providers in this module ignore it (they receive `_opts`).
    pub entity_types: Vec<String>,
}
55
/// Serializable configuration selecting and parameterizing an extraction provider.
#[derive(Clone, Serialize, Deserialize)]
pub struct ExtractorConfig {
    /// Provider id: "none", "gliner", "openai", "openrouter", "ollama", or "anthropic".
    pub provider: String,
    /// Model override; each provider supplies its own default when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Base URL override (consumed by the "ollama" provider).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// API key; never serialized, and redacted by the manual `Debug` impl.
    #[serde(skip)]
    pub api_key: Option<String>,
}
76
77impl std::fmt::Debug for ExtractorConfig {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("ExtractorConfig")
80 .field("provider", &self.provider)
81 .field("model", &self.model)
82 .field("base_url", &self.base_url)
83 .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
84 .finish()
85 }
86}
87
88impl ExtractorConfig {
89 pub fn none() -> Self {
90 Self {
91 provider: "none".to_string(),
92 model: None,
93 base_url: None,
94 api_key: None,
95 }
96 }
97
98 pub fn gliner() -> Self {
99 Self {
100 provider: "gliner".to_string(),
101 model: None,
102 base_url: None,
103 api_key: None,
104 }
105 }
106}
107
/// A pluggable backend that extracts structured information from text.
#[async_trait]
pub trait ExtractionProvider: Send + Sync {
    /// Extract entities (and, for LLM providers, topics/phrases/summary) from `text`.
    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult>;
    /// Stable identifier for this provider; matches `ExtractionResult::provider`.
    fn provider_name(&self) -> &'static str;
}
117
/// No-op provider: returns an empty result without inspecting the text.
pub struct NoneExtractor;
123
124#[async_trait]
125impl ExtractionProvider for NoneExtractor {
126 async fn extract(&self, _text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
127 Ok(ExtractionResult {
128 provider: "none".to_string(),
129 ..Default::default()
130 })
131 }
132 fn provider_name(&self) -> &'static str {
133 "none"
134 }
135}
136
/// Extractor backed by a local GLiNER NER engine; falls back to rule-based
/// extraction while the engine handle holds `None` (e.g. not yet loaded).
pub struct GlinerExtractor {
    // Shared handle to the (possibly absent) NER engine.
    ner: Arc<RwLock<Option<NerEngine>>>,
}
144
145impl GlinerExtractor {
146 pub fn new(ner: Arc<RwLock<Option<NerEngine>>>) -> Self {
147 Self { ner }
148 }
149}
150
151#[async_trait]
152impl ExtractionProvider for GlinerExtractor {
153 async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
154 let guard = self.ner.read().await;
155 let type_refs: Vec<&str> = opts.entity_types.iter().map(|s| s.as_str()).collect();
156 let entities = if let Some(ref engine) = *guard {
157 engine.extract(text, &type_refs).await
158 } else {
159 rule_based_extract(text)
160 };
161 Ok(ExtractionResult {
162 entities,
163 provider: "gliner".to_string(),
164 ..Default::default()
165 })
166 }
167 fn provider_name(&self) -> &'static str {
168 "gliner"
169 }
170}
171
/// System prompt shared by all LLM-backed extractors.
const EXTRACT_SYSTEM: &str =
    "You are a precise information extractor. Extract structured data from the given text. \
     Respond with valid JSON only — no markdown, no explanation.";

/// User-prompt template; the text to analyze is appended directly after it.
const EXTRACT_PROMPT_TMPL: &str =
    "Extract entities, topics, key phrases, and a brief summary from the text below.\n\
     Respond ONLY with this JSON structure:\n\
     {\"entities\":[{\"entity_type\":\"person|org|location|date|url|email|uuid|ip\",\
     \"value\":\"...\",\"score\":0.9,\"start\":0,\"end\":5}],\
     \"topics\":[\"...\"],\"key_phrases\":[\"...\"],\"summary\":\"...\"}\n\n\
     Text:\n";

/// Builds the full user prompt by appending `text` to the template.
fn build_extraction_prompt(text: &str) -> String {
    let mut prompt = String::with_capacity(EXTRACT_PROMPT_TMPL.len() + text.len());
    prompt.push_str(EXTRACT_PROMPT_TMPL);
    prompt.push_str(text);
    prompt
}
191
192fn parse_llm_json(content: &str, provider: &str) -> Result<ExtractionResult> {
193 let raw = content
195 .trim()
196 .trim_start_matches("```json")
197 .trim_start_matches("```")
198 .trim_end_matches("```")
199 .trim();
200
201 let v: serde_json::Value = serde_json::from_str(raw).map_err(|e| {
202 InferenceError::ExtractionFailed(format!("JSON parse error from {provider}: {e}"))
203 })?;
204
205 let entities: Vec<ExtractedEntity> = v["entities"]
206 .as_array()
207 .map(|arr| {
208 arr.iter()
209 .filter_map(|e| serde_json::from_value(e.clone()).ok())
210 .collect()
211 })
212 .unwrap_or_default();
213
214 let topics: Vec<String> = v["topics"]
215 .as_array()
216 .map(|arr| {
217 arr.iter()
218 .filter_map(|t| t.as_str().map(|s| s.to_string()))
219 .collect()
220 })
221 .unwrap_or_default();
222
223 let key_phrases: Vec<String> = v["key_phrases"]
224 .as_array()
225 .map(|arr| {
226 arr.iter()
227 .filter_map(|t| t.as_str().map(|s| s.to_string()))
228 .collect()
229 })
230 .unwrap_or_default();
231
232 let summary = v["summary"].as_str().map(|s| s.to_string());
233
234 Ok(ExtractionResult {
235 entities,
236 topics,
237 key_phrases,
238 summary,
239 provider: provider.to_string(),
240 })
241}
242
/// Extractor for OpenAI-compatible chat-completion APIs (OpenAI, OpenRouter,
/// Ollama). `provider_id` records which flavor this instance targets.
pub struct OpenAIExtractor {
    // Bearer token sent in the Authorization header; "ollama" placeholder for Ollama.
    api_key: String,
    base_url: String,
    model: String,
    // One of "openai" / "openrouter" / "ollama"; used in results and errors.
    provider_id: &'static str,
    client: reqwest::Client,
}
255
256impl std::fmt::Debug for OpenAIExtractor {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.debug_struct("OpenAIExtractor")
259 .field("base_url", &self.base_url)
260 .field("model", &self.model)
261 .field("api_key", &"[REDACTED]")
262 .finish()
263 }
264}
265
266impl OpenAIExtractor {
267 pub fn openai(api_key: String, model: Option<String>) -> Self {
268 Self::with_base_url(
269 api_key,
270 "https://api.openai.com/v1".to_string(),
271 model.unwrap_or_else(|| "gpt-4o-mini".to_string()),
272 "openai",
273 )
274 }
275
276 pub fn openrouter(api_key: String, model: Option<String>) -> Self {
277 Self::with_base_url(
278 api_key,
279 "https://openrouter.ai/api/v1".to_string(),
280 model.unwrap_or_else(|| "anthropic/claude-3-haiku".to_string()),
281 "openrouter",
282 )
283 }
284
285 pub fn ollama(base_url: Option<String>, model: Option<String>) -> Self {
287 Self::with_base_url(
288 "ollama".to_string(),
289 base_url.unwrap_or_else(|| "http://localhost:11434/v1".to_string()),
290 model.unwrap_or_else(|| "llama3.1:8b".to_string()),
291 "ollama",
292 )
293 }
294
295 fn with_base_url(
296 api_key: String,
297 base_url: String,
298 model: String,
299 provider_id: &'static str,
300 ) -> Self {
301 Self {
302 api_key,
303 base_url,
304 model,
305 provider_id,
306 client: reqwest::Client::new(),
307 }
308 }
309}
310
311#[async_trait]
312impl ExtractionProvider for OpenAIExtractor {
313 async fn extract(&self, text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
314 let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
315 let prompt = build_extraction_prompt(text);
316
317 let body = serde_json::json!({
318 "model": self.model,
319 "messages": [
320 {"role": "system", "content": EXTRACT_SYSTEM},
321 {"role": "user", "content": prompt}
322 ],
323 "temperature": 0,
324 "response_format": {"type": "json_object"}
325 });
326
327 let resp = self
328 .client
329 .post(&url)
330 .header("Authorization", format!("Bearer {}", self.api_key))
331 .header("Content-Type", "application/json")
332 .json(&body)
333 .send()
334 .await
335 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
336
337 if !resp.status().is_success() {
338 let status = resp.status().as_u16();
339 return Err(InferenceError::ExtractionFailed(format!(
340 "{} returned HTTP {status}",
341 self.provider_id
342 )));
343 }
344
345 let json: serde_json::Value = resp
346 .json()
347 .await
348 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
349
350 let content = json["choices"][0]["message"]["content"]
351 .as_str()
352 .unwrap_or("{}");
353
354 parse_llm_json(content, self.provider_id)
355 }
356
357 fn provider_name(&self) -> &'static str {
358 self.provider_id
359 }
360}
361
/// Extractor backed by the Anthropic Messages API.
pub struct AnthropicExtractor {
    // Sent as the `x-api-key` header.
    api_key: String,
    model: String,
    client: reqwest::Client,
}
371
372impl std::fmt::Debug for AnthropicExtractor {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 f.debug_struct("AnthropicExtractor")
375 .field("model", &self.model)
376 .field("api_key", &"[REDACTED]")
377 .finish()
378 }
379}
380
381impl AnthropicExtractor {
382 pub fn new(api_key: String, model: Option<String>) -> Self {
383 Self {
384 api_key,
385 model: model.unwrap_or_else(|| "claude-3-haiku-20240307".to_string()),
386 client: reqwest::Client::new(),
387 }
388 }
389}
390
391#[async_trait]
392impl ExtractionProvider for AnthropicExtractor {
393 async fn extract(&self, text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
394 let prompt = build_extraction_prompt(text);
395
396 let body = serde_json::json!({
397 "model": self.model,
398 "max_tokens": 1024,
399 "system": EXTRACT_SYSTEM,
400 "messages": [{"role": "user", "content": prompt}]
401 });
402
403 let resp = self
404 .client
405 .post("https://api.anthropic.com/v1/messages")
406 .header("x-api-key", &self.api_key)
407 .header("anthropic-version", "2023-06-01")
408 .header("Content-Type", "application/json")
409 .json(&body)
410 .send()
411 .await
412 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
413
414 if !resp.status().is_success() {
415 let status = resp.status().as_u16();
416 return Err(InferenceError::ExtractionFailed(format!(
417 "anthropic returned HTTP {status}"
418 )));
419 }
420
421 let json: serde_json::Value = resp
422 .json()
423 .await
424 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
425
426 let content = json["content"][0]["text"].as_str().unwrap_or("{}");
427
428 parse_llm_json(content, "anthropic")
429 }
430
431 fn provider_name(&self) -> &'static str {
432 "anthropic"
433 }
434}
435
436pub fn build_provider(
445 config: &ExtractorConfig,
446 ner_engine: Option<Arc<RwLock<Option<NerEngine>>>>,
447) -> Box<dyn ExtractionProvider> {
448 match config.provider.as_str() {
449 "gliner" => {
450 if let Some(ner) = ner_engine {
451 Box::new(GlinerExtractor::new(ner))
452 } else {
453 warn!("gliner provider requested but NER engine not available — using rule-based");
455 Box::new(GlinerExtractor::new(Arc::new(RwLock::new(None))))
457 }
458 }
459 "openai" => {
460 let key = config
461 .api_key
462 .clone()
463 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
464 .unwrap_or_default();
465 Box::new(OpenAIExtractor::openai(key, config.model.clone()))
466 }
467 "openrouter" => {
468 let key = config
469 .api_key
470 .clone()
471 .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
472 .unwrap_or_default();
473 Box::new(OpenAIExtractor::openrouter(key, config.model.clone()))
474 }
475 "ollama" => Box::new(OpenAIExtractor::ollama(
476 config.base_url.clone(),
477 config.model.clone(),
478 )),
479 "anthropic" => {
480 let key = config
481 .api_key
482 .clone()
483 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
484 .unwrap_or_default();
485 Box::new(AnthropicExtractor::new(key, config.model.clone()))
486 }
487 _ => Box::new(NoneExtractor),
488 }
489}