Skip to main content

inference/
extraction.rs

1//! EXT-1 — External Extraction Providers.
2//!
//! Implements an `ExtractionProvider` trait with six selectable providers
//! (five real backends plus the no-op `none`):
4//!
5//! | Provider | Transport | Auth |
6//! |----------|-----------|------|
7//! | `none` | — | — |
8//! | `gliner` | In-process ONNX | — |
9//! | `openai` | HTTPS | Bearer key |
10//! | `openrouter` | HTTPS (base_url override) | Bearer key |
11//! | `ollama` | HTTP (local) | — |
12//! | `anthropic` | HTTPS | x-api-key header |
13//!
14//! **Security contract:** `api_key` fields redact themselves in `Debug` output
15//! and are never serialized (only used at call time via per-request override or
16//! server env). They are NOT written to any storage layer.
17//!
18//! **Provider hierarchy** (highest priority wins):
19//! 1. Per-request `extractor_override` in request body
20//! 2. Per-namespace default (stored in `_dakera_namespace_configs`)
21//! 3. Server default (`[extractor]` in config.toml)
22//! 4. GLiNER local (if namespace has `extract_entities = true`)
23//! 5. `none` (default — backward-compatible)
24
25use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use tracing::warn;
30
31use crate::error::{InferenceError, Result};
32use crate::ner::{rule_based_extract, ExtractedEntity, NerEngine};
33
34// ─────────────────────────────────────────────────────────────
35// Public types
36// ─────────────────────────────────────────────────────────────
37
38/// Result returned by any extraction provider.
39#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct ExtractionResult {
41    pub entities: Vec<ExtractedEntity>,
42    pub topics: Vec<String>,
43    pub key_phrases: Vec<String>,
44    pub summary: Option<String>,
45    /// Which provider produced this result.
46    pub provider: String,
47}
48
/// Per-call options handed to `ExtractionProvider::extract`.
#[derive(Clone, Debug, Default)]
pub struct ExtractionOpts {
    /// Entity types to look for (e.g. `["person", "org"]`).
    ///
    /// `GlinerExtractor` forwards them to the NER engine; the LLM providers
    /// embed them in the extraction prompt, which substitutes a built-in
    /// default list when this vector is empty.
    pub entity_types: Vec<String>,
}
55
56/// Serialisable configuration stored per-namespace or sent per-request.
57///
58/// The `api_key` field has a custom `Debug` impl that redacts the value
59/// and is tagged `#[serde(skip)]` — it is never written to storage.
60#[derive(Clone, Serialize, Deserialize)]
61pub struct ExtractorConfig {
62    /// Provider identifier: `none`, `gliner`, `openai`, `anthropic`,
63    /// `openrouter`, `ollama`.
64    pub provider: String,
65    /// Model name (provider-specific). `None` → use the recommended default.
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub model: Option<String>,
68    /// Base URL override — used for `openrouter` and `ollama`.
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub base_url: Option<String>,
71    /// API key — NEVER persisted, NEVER logged. Present only in per-request
72    /// overrides or resolved from server env at call time.
73    #[serde(skip)]
74    pub api_key: Option<String>,
75}
76
77impl std::fmt::Debug for ExtractorConfig {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("ExtractorConfig")
80            .field("provider", &self.provider)
81            .field("model", &self.model)
82            .field("base_url", &self.base_url)
83            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
84            .finish()
85    }
86}
87
88impl ExtractorConfig {
89    pub fn none() -> Self {
90        Self {
91            provider: "none".to_string(),
92            model: None,
93            base_url: None,
94            api_key: None,
95        }
96    }
97
98    pub fn gliner() -> Self {
99        Self {
100            provider: "gliner".to_string(),
101            model: None,
102            base_url: None,
103            api_key: None,
104        }
105    }
106}
107
108// ─────────────────────────────────────────────────────────────
109// Provider trait
110// ─────────────────────────────────────────────────────────────
111
112#[async_trait]
113pub trait ExtractionProvider: Send + Sync {
114    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult>;
115    fn provider_name(&self) -> &'static str;
116}
117
118// ─────────────────────────────────────────────────────────────
119// NoneExtractor — no-op, backward-compatible default
120// ─────────────────────────────────────────────────────────────
121
122pub struct NoneExtractor;
123
124#[async_trait]
125impl ExtractionProvider for NoneExtractor {
126    async fn extract(&self, _text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
127        Ok(ExtractionResult {
128            provider: "none".to_string(),
129            ..Default::default()
130        })
131    }
132    fn provider_name(&self) -> &'static str {
133        "none"
134    }
135}
136
137// ─────────────────────────────────────────────────────────────
138// GlinerExtractor — wraps CE-4 NerEngine
139// ─────────────────────────────────────────────────────────────
140
141pub struct GlinerExtractor {
142    ner: Arc<RwLock<Option<NerEngine>>>,
143}
144
145impl GlinerExtractor {
146    pub fn new(ner: Arc<RwLock<Option<NerEngine>>>) -> Self {
147        Self { ner }
148    }
149}
150
151#[async_trait]
152impl ExtractionProvider for GlinerExtractor {
153    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
154        let guard = self.ner.read().await;
155        let type_refs: Vec<&str> = opts.entity_types.iter().map(|s| s.as_str()).collect();
156        let entities = if let Some(ref engine) = *guard {
157            engine.extract(text, &type_refs).await
158        } else {
159            rule_based_extract(text)
160        };
161        Ok(ExtractionResult {
162            entities,
163            provider: "gliner".to_string(),
164            ..Default::default()
165        })
166    }
167    fn provider_name(&self) -> &'static str {
168        "gliner"
169    }
170}
171
172// ─────────────────────────────────────────────────────────────
173// LLM extraction prompt
174// ─────────────────────────────────────────────────────────────
175
/// System prompt shared by all LLM-backed providers.
const EXTRACT_SYSTEM: &str = concat!(
    "You are a precise information extractor. ",
    "Extract structured data from the given text. ",
    "Respond with valid JSON only — no markdown, no explanation.",
);
179
/// Entity types requested from the model when the caller names none.
const DEFAULT_ENTITY_TYPES: &[&str] = &[
    "person",
    "organization",
    "location",
    "date",
    "url",
    "email",
    "uuid",
    "ip",
];

/// Build the user prompt sent to an LLM provider.
///
/// The prompt lists the entity types to extract, shows the exact JSON shape
/// expected in the reply, and ends with `text` verbatim. An empty
/// `entity_types` slice selects `DEFAULT_ENTITY_TYPES`, so the model always
/// receives a concrete list.
fn build_extraction_prompt(text: &str, entity_types: &[String]) -> String {
    let type_spec = match entity_types {
        [] => DEFAULT_ENTITY_TYPES.join(", "),
        requested => requested.join(", "),
    };

    format!(
        "Extract entities of the following types: {type_spec}.\n\
         Also extract topics, key phrases, and a brief summary.\n\
         Respond ONLY with this JSON structure (no markdown):\n\
         {{\"entities\":[{{\"entity_type\":\"<one of the requested types>\",\
         \"value\":\"...\",\"score\":0.9,\"start\":0,\"end\":5}}],\
         \"topics\":[\"...\"],\"key_phrases\":[\"...\"],\"summary\":\"...\"}}\n\n\
         Text:\n{text}"
    )
}
218
219fn parse_llm_json(content: &str, provider: &str) -> Result<ExtractionResult> {
220    // Strip markdown code fences if present
221    let raw = content
222        .trim()
223        .trim_start_matches("```json")
224        .trim_start_matches("```")
225        .trim_end_matches("```")
226        .trim();
227
228    let v: serde_json::Value = serde_json::from_str(raw).map_err(|e| {
229        InferenceError::ExtractionFailed(format!("JSON parse error from {provider}: {e}"))
230    })?;
231
232    let entities: Vec<ExtractedEntity> = v["entities"]
233        .as_array()
234        .map(|arr| {
235            arr.iter()
236                .filter_map(|e| serde_json::from_value(e.clone()).ok())
237                .collect()
238        })
239        .unwrap_or_default();
240
241    let topics: Vec<String> = v["topics"]
242        .as_array()
243        .map(|arr| {
244            arr.iter()
245                .filter_map(|t| t.as_str().map(|s| s.to_string()))
246                .collect()
247        })
248        .unwrap_or_default();
249
250    let key_phrases: Vec<String> = v["key_phrases"]
251        .as_array()
252        .map(|arr| {
253            arr.iter()
254                .filter_map(|t| t.as_str().map(|s| s.to_string()))
255                .collect()
256        })
257        .unwrap_or_default();
258
259    let summary = v["summary"].as_str().map(|s| s.to_string());
260
261    Ok(ExtractionResult {
262        entities,
263        topics,
264        key_phrases,
265        summary,
266        provider: provider.to_string(),
267    })
268}
269
270// ─────────────────────────────────────────────────────────────
271// OpenAIExtractor — openai + openrouter + ollama (base_url override)
272// ─────────────────────────────────────────────────────────────
273
274pub struct OpenAIExtractor {
275    /// `api_key` is runtime-only — never stored, redacted in Debug.
276    api_key: String,
277    base_url: String,
278    model: String,
279    provider_id: &'static str,
280    client: reqwest::Client,
281}
282
283impl std::fmt::Debug for OpenAIExtractor {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.debug_struct("OpenAIExtractor")
286            .field("base_url", &self.base_url)
287            .field("model", &self.model)
288            .field("api_key", &"[REDACTED]")
289            .finish()
290    }
291}
292
293impl OpenAIExtractor {
294    pub fn openai(api_key: String, model: Option<String>) -> Self {
295        Self::with_base_url(
296            api_key,
297            "https://api.openai.com/v1".to_string(),
298            model.unwrap_or_else(|| "gpt-4o-mini".to_string()),
299            "openai",
300        )
301    }
302
303    pub fn openrouter(api_key: String, model: Option<String>) -> Self {
304        Self::with_base_url(
305            api_key,
306            "https://openrouter.ai/api/v1".to_string(),
307            model.unwrap_or_else(|| "anthropic/claude-3-haiku".to_string()),
308            "openrouter",
309        )
310    }
311
312    /// Ollama — local OpenAI-compatible server, no auth required.
313    pub fn ollama(base_url: Option<String>, model: Option<String>) -> Self {
314        Self::with_base_url(
315            "ollama".to_string(),
316            base_url.unwrap_or_else(|| "http://localhost:11434/v1".to_string()),
317            model.unwrap_or_else(|| "llama3.1:8b".to_string()),
318            "ollama",
319        )
320    }
321
322    fn with_base_url(
323        api_key: String,
324        base_url: String,
325        model: String,
326        provider_id: &'static str,
327    ) -> Self {
328        Self {
329            api_key,
330            base_url,
331            model,
332            provider_id,
333            client: reqwest::Client::new(),
334        }
335    }
336}
337
338#[async_trait]
339impl ExtractionProvider for OpenAIExtractor {
340    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
341        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
342        let prompt = build_extraction_prompt(text, &opts.entity_types);
343
344        let body = serde_json::json!({
345            "model": self.model,
346            "messages": [
347                {"role": "system", "content": EXTRACT_SYSTEM},
348                {"role": "user", "content": prompt}
349            ],
350            "temperature": 0,
351            "response_format": {"type": "json_object"}
352        });
353
354        let resp = self
355            .client
356            .post(&url)
357            .header("Authorization", format!("Bearer {}", self.api_key))
358            .header("Content-Type", "application/json")
359            .json(&body)
360            .send()
361            .await
362            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
363
364        if !resp.status().is_success() {
365            let status = resp.status().as_u16();
366            return Err(InferenceError::ExtractionFailed(format!(
367                "{} returned HTTP {status}",
368                self.provider_id
369            )));
370        }
371
372        let json: serde_json::Value = resp
373            .json()
374            .await
375            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
376
377        let content = json["choices"][0]["message"]["content"]
378            .as_str()
379            .unwrap_or("{}");
380
381        parse_llm_json(content, self.provider_id)
382    }
383
384    fn provider_name(&self) -> &'static str {
385        self.provider_id
386    }
387}
388
389// ─────────────────────────────────────────────────────────────
390// AnthropicExtractor
391// ─────────────────────────────────────────────────────────────
392
393pub struct AnthropicExtractor {
394    api_key: String,
395    model: String,
396    client: reqwest::Client,
397}
398
399impl std::fmt::Debug for AnthropicExtractor {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        f.debug_struct("AnthropicExtractor")
402            .field("model", &self.model)
403            .field("api_key", &"[REDACTED]")
404            .finish()
405    }
406}
407
408impl AnthropicExtractor {
409    pub fn new(api_key: String, model: Option<String>) -> Self {
410        Self {
411            api_key,
412            model: model.unwrap_or_else(|| "claude-3-haiku-20240307".to_string()),
413            client: reqwest::Client::new(),
414        }
415    }
416}
417
418#[async_trait]
419impl ExtractionProvider for AnthropicExtractor {
420    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
421        let prompt = build_extraction_prompt(text, &opts.entity_types);
422
423        let body = serde_json::json!({
424            "model": self.model,
425            "max_tokens": 1024,
426            "system": EXTRACT_SYSTEM,
427            "messages": [{"role": "user", "content": prompt}]
428        });
429
430        let resp = self
431            .client
432            .post("https://api.anthropic.com/v1/messages")
433            .header("x-api-key", &self.api_key)
434            .header("anthropic-version", "2023-06-01")
435            .header("Content-Type", "application/json")
436            .json(&body)
437            .send()
438            .await
439            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
440
441        if !resp.status().is_success() {
442            let status = resp.status().as_u16();
443            return Err(InferenceError::ExtractionFailed(format!(
444                "anthropic returned HTTP {status}"
445            )));
446        }
447
448        let json: serde_json::Value = resp
449            .json()
450            .await
451            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
452
453        let content = json["content"][0]["text"].as_str().unwrap_or("{}");
454
455        parse_llm_json(content, "anthropic")
456    }
457
458    fn provider_name(&self) -> &'static str {
459        "anthropic"
460    }
461}
462
463// ─────────────────────────────────────────────────────────────
464// Factory — build a boxed provider from ExtractorConfig
465// ─────────────────────────────────────────────────────────────
466
467/// Build a `Box<dyn ExtractionProvider>` from a config + optional NER engine.
468///
469/// `api_key` in `config` takes precedence over env vars.
470/// For `gliner`, `ner_engine` must be `Some`; if not, falls back to rule-based.
471pub fn build_provider(
472    config: &ExtractorConfig,
473    ner_engine: Option<Arc<RwLock<Option<NerEngine>>>>,
474) -> Box<dyn ExtractionProvider> {
475    match config.provider.as_str() {
476        "gliner" => {
477            if let Some(ner) = ner_engine {
478                Box::new(GlinerExtractor::new(ner))
479            } else {
480                // No NER engine available — run rule-based only via None+fallback
481                warn!("gliner provider requested but NER engine not available — using rule-based");
482                // Synthesise a GlinerExtractor with an empty engine slot
483                Box::new(GlinerExtractor::new(Arc::new(RwLock::new(None))))
484            }
485        }
486        "openai" => {
487            let key = config
488                .api_key
489                .clone()
490                .or_else(|| std::env::var("OPENAI_API_KEY").ok())
491                .unwrap_or_default();
492            Box::new(OpenAIExtractor::openai(key, config.model.clone()))
493        }
494        "openrouter" => {
495            let key = config
496                .api_key
497                .clone()
498                .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
499                .unwrap_or_default();
500            Box::new(OpenAIExtractor::openrouter(key, config.model.clone()))
501        }
502        "ollama" => Box::new(OpenAIExtractor::ollama(
503            config.base_url.clone(),
504            config.model.clone(),
505        )),
506        "anthropic" => {
507            let key = config
508                .api_key
509                .clone()
510                .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
511                .unwrap_or_default();
512            Box::new(AnthropicExtractor::new(key, config.model.clone()))
513        }
514        _ => Box::new(NoneExtractor),
515    }
516}