// inference/extraction.rs
1//! EXT-1 — External Extraction Providers.
2//!
//! Implements an `ExtractionProvider` trait with six provider backends:
4//!
5//! | Provider | Transport | Auth |
6//! |----------|-----------|------|
7//! | `none` | — | — |
8//! | `gliner` | In-process ONNX | — |
9//! | `openai` | HTTPS | Bearer key |
10//! | `openrouter` | HTTPS (base_url override) | Bearer key |
11//! | `ollama` | HTTP (local) | — |
12//! | `anthropic` | HTTPS | x-api-key header |
13//!
14//! **Security contract:** `api_key` fields redact themselves in `Debug` output
15//! and are never serialized (only used at call time via per-request override or
16//! server env). They are NOT written to any storage layer.
17//!
18//! **Provider hierarchy** (highest priority wins):
19//! 1. Per-request `extractor_override` in request body
20//! 2. Per-namespace default (stored in `_dakera_namespace_configs`)
21//! 3. Server default (`[extractor]` in config.toml)
22//! 4. GLiNER local (if namespace has `extract_entities = true`)
23//! 5. `none` (default — backward-compatible)
24
25use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use tracing::warn;
30
31use crate::error::{InferenceError, Result};
32use crate::ner::{rule_based_extract, ExtractedEntity, NerEngine};
33
34// ─────────────────────────────────────────────────────────────
35// Public types
36// ─────────────────────────────────────────────────────────────
37
38/// Result returned by any extraction provider.
/// Result returned by any extraction provider.
///
/// Every field defaults to empty so a provider fills in only what it
/// supports (e.g. the GLiNER backend populates `entities` only).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ExtractionResult {
    /// Named entities found in the text.
    pub entities: Vec<ExtractedEntity>,
    /// High-level topics — populated by the LLM-backed providers.
    pub topics: Vec<String>,
    /// Salient phrases — populated by the LLM-backed providers.
    pub key_phrases: Vec<String>,
    /// Brief summary of the text, when the provider produces one.
    pub summary: Option<String>,
    /// Which provider produced this result.
    pub provider: String,
}
48
49/// Options passed into `ExtractionProvider::extract`.
/// Options passed into `ExtractionProvider::extract`.
///
/// Currently only the GLiNER backend consumes these; the LLM providers
/// take their extraction schema from the fixed prompt template instead.
#[derive(Debug, Clone, Default)]
pub struct ExtractionOpts {
    /// GLiNER entity types (e.g. `["person", "org"]`). Ignored by LLM providers.
    pub entity_types: Vec<String>,
}
55
/// Serialisable configuration stored per-namespace or sent per-request.
///
/// The `api_key` field has a custom `Debug` impl that redacts the value
/// and is tagged `#[serde(skip)]` — it is never written to storage.
#[derive(Clone, Serialize, Deserialize)]
pub struct ExtractorConfig {
    /// Provider identifier: `none`, `gliner`, `openai`, `anthropic`,
    /// `openrouter`, `ollama`. Any other value resolves to the no-op
    /// provider in `build_provider`.
    pub provider: String,
    /// Model name (provider-specific). `None` → use the recommended default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Base URL override — used for `openrouter` and `ollama`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// API key — NEVER persisted, NEVER logged. Present only in per-request
    /// overrides or resolved from server env at call time.
    #[serde(skip)]
    pub api_key: Option<String>,
}
76
77impl std::fmt::Debug for ExtractorConfig {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("ExtractorConfig")
80            .field("provider", &self.provider)
81            .field("model", &self.model)
82            .field("base_url", &self.base_url)
83            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
84            .finish()
85    }
86}
87
88impl ExtractorConfig {
89    pub fn none() -> Self {
90        Self {
91            provider: "none".to_string(),
92            model: None,
93            base_url: None,
94            api_key: None,
95        }
96    }
97
98    pub fn gliner() -> Self {
99        Self {
100            provider: "gliner".to_string(),
101            model: None,
102            base_url: None,
103            api_key: None,
104        }
105    }
106}
107
108// ─────────────────────────────────────────────────────────────
109// Provider trait
110// ─────────────────────────────────────────────────────────────
111
#[async_trait]
pub trait ExtractionProvider: Send + Sync {
    /// Run extraction over `text`. `opts` currently only carries GLiNER
    /// entity types; the LLM-backed implementations ignore it.
    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult>;
    /// Stable provider identifier (`"none"`, `"gliner"`, `"openai"`, …),
    /// also echoed into `ExtractionResult::provider`.
    fn provider_name(&self) -> &'static str;
}
117
118// ─────────────────────────────────────────────────────────────
119// NoneExtractor — no-op, backward-compatible default
120// ─────────────────────────────────────────────────────────────
121
122pub struct NoneExtractor;
123
124#[async_trait]
125impl ExtractionProvider for NoneExtractor {
126    async fn extract(&self, _text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
127        Ok(ExtractionResult {
128            provider: "none".to_string(),
129            ..Default::default()
130        })
131    }
132    fn provider_name(&self) -> &'static str {
133        "none"
134    }
135}
136
137// ─────────────────────────────────────────────────────────────
138// GlinerExtractor — wraps CE-4 NerEngine
139// ─────────────────────────────────────────────────────────────
140
/// In-process extractor wrapping the CE-4 `NerEngine`.
pub struct GlinerExtractor {
    // Shared engine slot; `None` means no model is loaded, in which case
    // `extract` falls back to `rule_based_extract`.
    ner: Arc<RwLock<Option<NerEngine>>>,
}

impl GlinerExtractor {
    /// Wrap an existing (possibly empty) NER engine slot.
    pub fn new(ner: Arc<RwLock<Option<NerEngine>>>) -> Self {
        Self { ner }
    }
}
150
151#[async_trait]
152impl ExtractionProvider for GlinerExtractor {
153    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
154        let guard = self.ner.read().await;
155        let type_refs: Vec<&str> = opts.entity_types.iter().map(|s| s.as_str()).collect();
156        let entities = if let Some(ref engine) = *guard {
157            engine.extract(text, &type_refs).await
158        } else {
159            rule_based_extract(text)
160        };
161        Ok(ExtractionResult {
162            entities,
163            provider: "gliner".to_string(),
164            ..Default::default()
165        })
166    }
167    fn provider_name(&self) -> &'static str {
168        "gliner"
169    }
170}
171
172// ─────────────────────────────────────────────────────────────
173// LLM extraction prompt
174// ─────────────────────────────────────────────────────────────
175
/// System prompt shared by all LLM-backed providers.
const EXTRACT_SYSTEM: &str =
    "You are a precise information extractor. Extract structured data from the given text. \
     Respond with valid JSON only — no markdown, no explanation.";

/// User-prompt template; the raw input text is appended verbatim after it.
const EXTRACT_PROMPT_TMPL: &str =
    "Extract entities, topics, key phrases, and a brief summary from the text below.\n\
     Respond ONLY with this JSON structure:\n\
     {\"entities\":[{\"entity_type\":\"person|org|location|date|url|email|uuid|ip\",\
     \"value\":\"...\",\"score\":0.9,\"start\":0,\"end\":5}],\
     \"topics\":[\"...\"],\"key_phrases\":[\"...\"],\"summary\":\"...\"}\n\n\
     Text:\n";

/// Concatenate the fixed template with the input text into the user prompt.
fn build_extraction_prompt(text: &str) -> String {
    let mut prompt = String::with_capacity(EXTRACT_PROMPT_TMPL.len() + text.len());
    prompt.push_str(EXTRACT_PROMPT_TMPL);
    prompt.push_str(text);
    prompt
}
191
192fn parse_llm_json(content: &str, provider: &str) -> Result<ExtractionResult> {
193    // Strip markdown code fences if present
194    let raw = content
195        .trim()
196        .trim_start_matches("```json")
197        .trim_start_matches("```")
198        .trim_end_matches("```")
199        .trim();
200
201    let v: serde_json::Value = serde_json::from_str(raw).map_err(|e| {
202        InferenceError::ExtractionFailed(format!("JSON parse error from {provider}: {e}"))
203    })?;
204
205    let entities: Vec<ExtractedEntity> = v["entities"]
206        .as_array()
207        .map(|arr| {
208            arr.iter()
209                .filter_map(|e| serde_json::from_value(e.clone()).ok())
210                .collect()
211        })
212        .unwrap_or_default();
213
214    let topics: Vec<String> = v["topics"]
215        .as_array()
216        .map(|arr| {
217            arr.iter()
218                .filter_map(|t| t.as_str().map(|s| s.to_string()))
219                .collect()
220        })
221        .unwrap_or_default();
222
223    let key_phrases: Vec<String> = v["key_phrases"]
224        .as_array()
225        .map(|arr| {
226            arr.iter()
227                .filter_map(|t| t.as_str().map(|s| s.to_string()))
228                .collect()
229        })
230        .unwrap_or_default();
231
232    let summary = v["summary"].as_str().map(|s| s.to_string());
233
234    Ok(ExtractionResult {
235        entities,
236        topics,
237        key_phrases,
238        summary,
239        provider: provider.to_string(),
240    })
241}
242
243// ─────────────────────────────────────────────────────────────
244// OpenAIExtractor — openai + openrouter + ollama (base_url override)
245// ─────────────────────────────────────────────────────────────
246
/// Extractor for OpenAI-compatible chat-completions APIs.
///
/// One type serves three provider ids (`openai`, `openrouter`, `ollama`);
/// they differ only in `base_url`, default model, and auth.
pub struct OpenAIExtractor {
    /// `api_key` is runtime-only — never stored, redacted in Debug.
    api_key: String,
    /// API root, e.g. `https://api.openai.com/v1`; a trailing slash is tolerated.
    base_url: String,
    /// Chat model identifier sent in the request body.
    model: String,
    /// Logical provider id echoed into results and error messages.
    provider_id: &'static str,
    client: reqwest::Client,
}
255
256impl std::fmt::Debug for OpenAIExtractor {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.debug_struct("OpenAIExtractor")
259            .field("base_url", &self.base_url)
260            .field("model", &self.model)
261            .field("api_key", &"[REDACTED]")
262            .finish()
263    }
264}
265
266impl OpenAIExtractor {
267    pub fn openai(api_key: String, model: Option<String>) -> Self {
268        Self::with_base_url(
269            api_key,
270            "https://api.openai.com/v1".to_string(),
271            model.unwrap_or_else(|| "gpt-4o-mini".to_string()),
272            "openai",
273        )
274    }
275
276    pub fn openrouter(api_key: String, model: Option<String>) -> Self {
277        Self::with_base_url(
278            api_key,
279            "https://openrouter.ai/api/v1".to_string(),
280            model.unwrap_or_else(|| "anthropic/claude-3-haiku".to_string()),
281            "openrouter",
282        )
283    }
284
285    /// Ollama — local OpenAI-compatible server, no auth required.
286    pub fn ollama(base_url: Option<String>, model: Option<String>) -> Self {
287        Self::with_base_url(
288            "ollama".to_string(),
289            base_url.unwrap_or_else(|| "http://localhost:11434/v1".to_string()),
290            model.unwrap_or_else(|| "llama3.1:8b".to_string()),
291            "ollama",
292        )
293    }
294
295    fn with_base_url(
296        api_key: String,
297        base_url: String,
298        model: String,
299        provider_id: &'static str,
300    ) -> Self {
301        Self {
302            api_key,
303            base_url,
304            model,
305            provider_id,
306            client: reqwest::Client::new(),
307        }
308    }
309}
310
311#[async_trait]
312impl ExtractionProvider for OpenAIExtractor {
313    async fn extract(&self, text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
314        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
315        let prompt = build_extraction_prompt(text);
316
317        let body = serde_json::json!({
318            "model": self.model,
319            "messages": [
320                {"role": "system", "content": EXTRACT_SYSTEM},
321                {"role": "user", "content": prompt}
322            ],
323            "temperature": 0,
324            "response_format": {"type": "json_object"}
325        });
326
327        let resp = self
328            .client
329            .post(&url)
330            .header("Authorization", format!("Bearer {}", self.api_key))
331            .header("Content-Type", "application/json")
332            .json(&body)
333            .send()
334            .await
335            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
336
337        if !resp.status().is_success() {
338            let status = resp.status().as_u16();
339            return Err(InferenceError::ExtractionFailed(format!(
340                "{} returned HTTP {status}",
341                self.provider_id
342            )));
343        }
344
345        let json: serde_json::Value = resp
346            .json()
347            .await
348            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
349
350        let content = json["choices"][0]["message"]["content"]
351            .as_str()
352            .unwrap_or("{}");
353
354        parse_llm_json(content, self.provider_id)
355    }
356
357    fn provider_name(&self) -> &'static str {
358        self.provider_id
359    }
360}
361
362// ─────────────────────────────────────────────────────────────
363// AnthropicExtractor
364// ─────────────────────────────────────────────────────────────
365
/// Extractor backed by the Anthropic Messages API (`x-api-key` auth).
pub struct AnthropicExtractor {
    // Runtime-only secret; redacted by the manual Debug impl.
    api_key: String,
    // Model id, e.g. `claude-3-haiku-20240307`.
    model: String,
    client: reqwest::Client,
}
371
372impl std::fmt::Debug for AnthropicExtractor {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        f.debug_struct("AnthropicExtractor")
375            .field("model", &self.model)
376            .field("api_key", &"[REDACTED]")
377            .finish()
378    }
379}
380
381impl AnthropicExtractor {
382    pub fn new(api_key: String, model: Option<String>) -> Self {
383        Self {
384            api_key,
385            model: model.unwrap_or_else(|| "claude-3-haiku-20240307".to_string()),
386            client: reqwest::Client::new(),
387        }
388    }
389}
390
391#[async_trait]
392impl ExtractionProvider for AnthropicExtractor {
393    async fn extract(&self, text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
394        let prompt = build_extraction_prompt(text);
395
396        let body = serde_json::json!({
397            "model": self.model,
398            "max_tokens": 1024,
399            "system": EXTRACT_SYSTEM,
400            "messages": [{"role": "user", "content": prompt}]
401        });
402
403        let resp = self
404            .client
405            .post("https://api.anthropic.com/v1/messages")
406            .header("x-api-key", &self.api_key)
407            .header("anthropic-version", "2023-06-01")
408            .header("Content-Type", "application/json")
409            .json(&body)
410            .send()
411            .await
412            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
413
414        if !resp.status().is_success() {
415            let status = resp.status().as_u16();
416            return Err(InferenceError::ExtractionFailed(format!(
417                "anthropic returned HTTP {status}"
418            )));
419        }
420
421        let json: serde_json::Value = resp
422            .json()
423            .await
424            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
425
426        let content = json["content"][0]["text"].as_str().unwrap_or("{}");
427
428        parse_llm_json(content, "anthropic")
429    }
430
431    fn provider_name(&self) -> &'static str {
432        "anthropic"
433    }
434}
435
436// ─────────────────────────────────────────────────────────────
437// Factory — build a boxed provider from ExtractorConfig
438// ─────────────────────────────────────────────────────────────
439
440/// Build a `Box<dyn ExtractionProvider>` from a config + optional NER engine.
441///
442/// `api_key` in `config` takes precedence over env vars.
443/// For `gliner`, `ner_engine` must be `Some`; if not, falls back to rule-based.
444pub fn build_provider(
445    config: &ExtractorConfig,
446    ner_engine: Option<Arc<RwLock<Option<NerEngine>>>>,
447) -> Box<dyn ExtractionProvider> {
448    match config.provider.as_str() {
449        "gliner" => {
450            if let Some(ner) = ner_engine {
451                Box::new(GlinerExtractor::new(ner))
452            } else {
453                // No NER engine available — run rule-based only via None+fallback
454                warn!("gliner provider requested but NER engine not available — using rule-based");
455                // Synthesise a GlinerExtractor with an empty engine slot
456                Box::new(GlinerExtractor::new(Arc::new(RwLock::new(None))))
457            }
458        }
459        "openai" => {
460            let key = config
461                .api_key
462                .clone()
463                .or_else(|| std::env::var("OPENAI_API_KEY").ok())
464                .unwrap_or_default();
465            Box::new(OpenAIExtractor::openai(key, config.model.clone()))
466        }
467        "openrouter" => {
468            let key = config
469                .api_key
470                .clone()
471                .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
472                .unwrap_or_default();
473            Box::new(OpenAIExtractor::openrouter(key, config.model.clone()))
474        }
475        "ollama" => Box::new(OpenAIExtractor::ollama(
476            config.base_url.clone(),
477            config.model.clone(),
478        )),
479        "anthropic" => {
480            let key = config
481                .api_key
482                .clone()
483                .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
484                .unwrap_or_default();
485            Box::new(AnthropicExtractor::new(key, config.model.clone()))
486        }
487        _ => Box::new(NoneExtractor),
488    }
489}