1use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use tracing::warn;
30
31use crate::error::{InferenceError, Result};
32use crate::ner::{rule_based_extract, ExtractedEntity, NerEngine};
33
34#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct ExtractionResult {
41 pub entities: Vec<ExtractedEntity>,
42 pub topics: Vec<String>,
43 pub key_phrases: Vec<String>,
44 pub summary: Option<String>,
45 pub provider: String,
47}
48
/// Per-call options for an extraction request.
#[derive(Clone, Debug, Default)]
pub struct ExtractionOpts {
    /// Entity types the caller wants extracted.
    ///
    /// The default is an empty list; the LLM prompt builder in this file
    /// substitutes its built-in default types when the list is empty.
    pub entity_types: Vec<String>,
}
55
56#[derive(Clone, Serialize, Deserialize)]
61pub struct ExtractorConfig {
62 pub provider: String,
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub model: Option<String>,
68 #[serde(skip_serializing_if = "Option::is_none")]
70 pub base_url: Option<String>,
71 #[serde(skip)]
74 pub api_key: Option<String>,
75}
76
77impl std::fmt::Debug for ExtractorConfig {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("ExtractorConfig")
80 .field("provider", &self.provider)
81 .field("model", &self.model)
82 .field("base_url", &self.base_url)
83 .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
84 .finish()
85 }
86}
87
88impl ExtractorConfig {
89 pub fn none() -> Self {
90 Self {
91 provider: "none".to_string(),
92 model: None,
93 base_url: None,
94 api_key: None,
95 }
96 }
97
98 pub fn gliner() -> Self {
99 Self {
100 provider: "gliner".to_string(),
101 model: None,
102 base_url: None,
103 api_key: None,
104 }
105 }
106}
107
108#[async_trait]
113pub trait ExtractionProvider: Send + Sync {
114 async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult>;
115 fn provider_name(&self) -> &'static str;
116}
117
118pub struct NoneExtractor;
123
124#[async_trait]
125impl ExtractionProvider for NoneExtractor {
126 async fn extract(&self, _text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
127 Ok(ExtractionResult {
128 provider: "none".to_string(),
129 ..Default::default()
130 })
131 }
132 fn provider_name(&self) -> &'static str {
133 "none"
134 }
135}
136
137pub struct GlinerExtractor {
142 ner: Arc<RwLock<Option<NerEngine>>>,
143}
144
145impl GlinerExtractor {
146 pub fn new(ner: Arc<RwLock<Option<NerEngine>>>) -> Self {
147 Self { ner }
148 }
149}
150
151#[async_trait]
152impl ExtractionProvider for GlinerExtractor {
153 async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
154 let guard = self.ner.read().await;
155 let type_refs: Vec<&str> = opts.entity_types.iter().map(|s| s.as_str()).collect();
156 let entities = if let Some(ref engine) = *guard {
157 engine.extract(text, &type_refs).await
158 } else {
159 rule_based_extract(text)
160 };
161 Ok(ExtractionResult {
162 entities,
163 provider: "gliner".to_string(),
164 ..Default::default()
165 })
166 }
167 fn provider_name(&self) -> &'static str {
168 "gliner"
169 }
170}
171
/// System prompt shared by the LLM-backed extractors; it instructs the model
/// to answer with bare JSON.
const EXTRACT_SYSTEM: &str = concat!(
    "You are a precise information extractor. ",
    "Extract structured data from the given text. ",
    "Respond with valid JSON only — no markdown, no explanation.",
);
179
/// Entity types requested from the LLM when the caller does not name any.
const DEFAULT_ENTITY_TYPES: &[&str] = &[
    "person", "organization", "location", "date", "url", "email", "uuid", "ip",
];

/// Builds the user prompt sent to LLM-backed extractors.
///
/// The prompt lists the requested entity types (or [`DEFAULT_ENTITY_TYPES`]
/// when `entity_types` is empty), shows the exact JSON shape the model must
/// return, and ends with `text` verbatim.
fn build_extraction_prompt(text: &str, entity_types: &[String]) -> String {
    // JSON skeleton the model is asked to fill in, kept as raw strings so the
    // quotes and braces read literally.
    const RESPONSE_SHAPE: &str = concat!(
        r#"{"entities":[{"entity_type":"<one of the requested types>","#,
        r#""value":"...","score":0.9,"start":0,"end":5}],"#,
        r#""topics":["..."],"key_phrases":["..."],"summary":"..."}"#,
    );

    let type_spec = if entity_types.is_empty() {
        DEFAULT_ENTITY_TYPES.join(", ")
    } else {
        entity_types.join(", ")
    };

    format!(
        "Extract entities of the following types: {type_spec}.\n\
         Also extract topics, key phrases, and a brief summary.\n\
         Respond ONLY with this JSON structure (no markdown):\n\
         {RESPONSE_SHAPE}\n\n\
         Text:\n{text}"
    )
}
218
219fn parse_llm_json(content: &str, provider: &str) -> Result<ExtractionResult> {
220 let raw = content
222 .trim()
223 .trim_start_matches("```json")
224 .trim_start_matches("```")
225 .trim_end_matches("```")
226 .trim();
227
228 let v: serde_json::Value = serde_json::from_str(raw).map_err(|e| {
229 InferenceError::ExtractionFailed(format!("JSON parse error from {provider}: {e}"))
230 })?;
231
232 let entities: Vec<ExtractedEntity> = v["entities"]
233 .as_array()
234 .map(|arr| {
235 arr.iter()
236 .filter_map(|e| serde_json::from_value(e.clone()).ok())
237 .collect()
238 })
239 .unwrap_or_default();
240
241 let topics: Vec<String> = v["topics"]
242 .as_array()
243 .map(|arr| {
244 arr.iter()
245 .filter_map(|t| t.as_str().map(|s| s.to_string()))
246 .collect()
247 })
248 .unwrap_or_default();
249
250 let key_phrases: Vec<String> = v["key_phrases"]
251 .as_array()
252 .map(|arr| {
253 arr.iter()
254 .filter_map(|t| t.as_str().map(|s| s.to_string()))
255 .collect()
256 })
257 .unwrap_or_default();
258
259 let summary = v["summary"].as_str().map(|s| s.to_string());
260
261 Ok(ExtractionResult {
262 entities,
263 topics,
264 key_phrases,
265 summary,
266 provider: provider.to_string(),
267 })
268}
269
270pub struct OpenAIExtractor {
275 api_key: String,
277 base_url: String,
278 model: String,
279 provider_id: &'static str,
280 client: reqwest::Client,
281}
282
283impl std::fmt::Debug for OpenAIExtractor {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 f.debug_struct("OpenAIExtractor")
286 .field("base_url", &self.base_url)
287 .field("model", &self.model)
288 .field("api_key", &"[REDACTED]")
289 .finish()
290 }
291}
292
293impl OpenAIExtractor {
294 pub fn openai(api_key: String, model: Option<String>) -> Self {
295 Self::with_base_url(
296 api_key,
297 "https://api.openai.com/v1".to_string(),
298 model.unwrap_or_else(|| "gpt-4o-mini".to_string()),
299 "openai",
300 )
301 }
302
303 pub fn openrouter(api_key: String, model: Option<String>) -> Self {
304 Self::with_base_url(
305 api_key,
306 "https://openrouter.ai/api/v1".to_string(),
307 model.unwrap_or_else(|| "anthropic/claude-3-haiku".to_string()),
308 "openrouter",
309 )
310 }
311
312 pub fn ollama(base_url: Option<String>, model: Option<String>) -> Self {
314 Self::with_base_url(
315 "ollama".to_string(),
316 base_url.unwrap_or_else(|| "http://localhost:11434/v1".to_string()),
317 model.unwrap_or_else(|| "llama3.1:8b".to_string()),
318 "ollama",
319 )
320 }
321
322 fn with_base_url(
323 api_key: String,
324 base_url: String,
325 model: String,
326 provider_id: &'static str,
327 ) -> Self {
328 Self {
329 api_key,
330 base_url,
331 model,
332 provider_id,
333 client: reqwest::Client::new(),
334 }
335 }
336}
337
338#[async_trait]
339impl ExtractionProvider for OpenAIExtractor {
340 async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
341 let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
342 let prompt = build_extraction_prompt(text, &opts.entity_types);
343
344 let body = serde_json::json!({
345 "model": self.model,
346 "messages": [
347 {"role": "system", "content": EXTRACT_SYSTEM},
348 {"role": "user", "content": prompt}
349 ],
350 "temperature": 0,
351 "response_format": {"type": "json_object"}
352 });
353
354 let resp = self
355 .client
356 .post(&url)
357 .header("Authorization", format!("Bearer {}", self.api_key))
358 .header("Content-Type", "application/json")
359 .json(&body)
360 .send()
361 .await
362 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
363
364 if !resp.status().is_success() {
365 let status = resp.status().as_u16();
366 return Err(InferenceError::ExtractionFailed(format!(
367 "{} returned HTTP {status}",
368 self.provider_id
369 )));
370 }
371
372 let json: serde_json::Value = resp
373 .json()
374 .await
375 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
376
377 let content = json["choices"][0]["message"]["content"]
378 .as_str()
379 .unwrap_or("{}");
380
381 parse_llm_json(content, self.provider_id)
382 }
383
384 fn provider_name(&self) -> &'static str {
385 self.provider_id
386 }
387}
388
389pub struct AnthropicExtractor {
394 api_key: String,
395 model: String,
396 client: reqwest::Client,
397}
398
399impl std::fmt::Debug for AnthropicExtractor {
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 f.debug_struct("AnthropicExtractor")
402 .field("model", &self.model)
403 .field("api_key", &"[REDACTED]")
404 .finish()
405 }
406}
407
408impl AnthropicExtractor {
409 pub fn new(api_key: String, model: Option<String>) -> Self {
410 Self {
411 api_key,
412 model: model.unwrap_or_else(|| "claude-3-haiku-20240307".to_string()),
413 client: reqwest::Client::new(),
414 }
415 }
416}
417
418#[async_trait]
419impl ExtractionProvider for AnthropicExtractor {
420 async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
421 let prompt = build_extraction_prompt(text, &opts.entity_types);
422
423 let body = serde_json::json!({
424 "model": self.model,
425 "max_tokens": 1024,
426 "system": EXTRACT_SYSTEM,
427 "messages": [{"role": "user", "content": prompt}]
428 });
429
430 let resp = self
431 .client
432 .post("https://api.anthropic.com/v1/messages")
433 .header("x-api-key", &self.api_key)
434 .header("anthropic-version", "2023-06-01")
435 .header("Content-Type", "application/json")
436 .json(&body)
437 .send()
438 .await
439 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
440
441 if !resp.status().is_success() {
442 let status = resp.status().as_u16();
443 return Err(InferenceError::ExtractionFailed(format!(
444 "anthropic returned HTTP {status}"
445 )));
446 }
447
448 let json: serde_json::Value = resp
449 .json()
450 .await
451 .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
452
453 let content = json["content"][0]["text"].as_str().unwrap_or("{}");
454
455 parse_llm_json(content, "anthropic")
456 }
457
458 fn provider_name(&self) -> &'static str {
459 "anthropic"
460 }
461}
462
463pub fn build_provider(
472 config: &ExtractorConfig,
473 ner_engine: Option<Arc<RwLock<Option<NerEngine>>>>,
474) -> Box<dyn ExtractionProvider> {
475 match config.provider.as_str() {
476 "gliner" => {
477 if let Some(ner) = ner_engine {
478 Box::new(GlinerExtractor::new(ner))
479 } else {
480 warn!("gliner provider requested but NER engine not available — using rule-based");
482 Box::new(GlinerExtractor::new(Arc::new(RwLock::new(None))))
484 }
485 }
486 "openai" => {
487 let key = config
488 .api_key
489 .clone()
490 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
491 .unwrap_or_default();
492 Box::new(OpenAIExtractor::openai(key, config.model.clone()))
493 }
494 "openrouter" => {
495 let key = config
496 .api_key
497 .clone()
498 .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
499 .unwrap_or_default();
500 Box::new(OpenAIExtractor::openrouter(key, config.model.clone()))
501 }
502 "ollama" => Box::new(OpenAIExtractor::ollama(
503 config.base_url.clone(),
504 config.model.clone(),
505 )),
506 "anthropic" => {
507 let key = config
508 .api_key
509 .clone()
510 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
511 .unwrap_or_default();
512 Box::new(AnthropicExtractor::new(key, config.model.clone()))
513 }
514 _ => Box::new(NoneExtractor),
515 }
516}