use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::warn;
use crate::error::{InferenceError, Result};
use crate::ner::{rule_based_extract, ExtractedEntity, NerEngine};
/// Structured data produced by an [`ExtractionProvider`] for one piece of text.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ExtractionResult {
    /// Named entities found in the text.
    pub entities: Vec<ExtractedEntity>,
    /// High-level topics — populated by the LLM-backed providers only.
    pub topics: Vec<String>,
    /// Salient phrases — populated by the LLM-backed providers only.
    pub key_phrases: Vec<String>,
    /// Optional brief summary — populated by the LLM-backed providers only.
    pub summary: Option<String>,
    /// Identifier of the provider that produced this result (e.g. "gliner").
    pub provider: String,
}
/// Per-call options for [`ExtractionProvider::extract`].
#[derive(Debug, Clone, Default)]
pub struct ExtractionOpts {
    /// Entity labels to look for; used by the GLiNER provider as its label
    /// set (LLM providers currently ignore this field).
    pub entity_types: Vec<String>,
}
/// Configuration selecting and parameterizing an extraction provider.
///
/// `Debug` is hand-written below so the API key is never logged, and the
/// key is `#[serde(skip)]`-ed so it is never (de)serialized either.
#[derive(Clone, Serialize, Deserialize)]
pub struct ExtractorConfig {
    /// Provider name: "none", "gliner", "openai", "openrouter", "ollama",
    /// or "anthropic" (anything else falls back to the no-op provider).
    pub provider: String,
    /// Model override; each provider supplies its own default when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Endpoint override; only consulted by the "ollama" provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Secret API key. Never serialized; may also come from env vars at
    /// build time (see `build_provider`).
    #[serde(skip)]
    pub api_key: Option<String>,
}
impl std::fmt::Debug for ExtractorConfig {
    /// Manual `Debug` so the API key never leaks into logs: it is rendered
    /// as `Some("[REDACTED]")` when present and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let masked_key = self.api_key.as_ref().map(|_| "[REDACTED]");
        f.debug_struct("ExtractorConfig")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_key", &masked_key)
            .finish()
    }
}
impl ExtractorConfig {
    /// Shared constructor: a config naming a provider with no other options.
    fn bare(provider: &str) -> Self {
        Self {
            provider: provider.to_string(),
            model: None,
            base_url: None,
            api_key: None,
        }
    }

    /// Configuration that disables extraction entirely.
    pub fn none() -> Self {
        Self::bare("none")
    }

    /// Configuration for the local GLiNER-based NER provider.
    pub fn gliner() -> Self {
        Self::bare("gliner")
    }
}
/// A pluggable backend that extracts structured data from free text.
#[async_trait]
pub trait ExtractionProvider: Send + Sync {
    /// Extracts entities (and, for LLM providers, topics, key phrases, and
    /// a summary) from `text`.
    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult>;
    /// Stable identifier for this provider (e.g. "openai").
    fn provider_name(&self) -> &'static str;
}
pub struct NoneExtractor;
#[async_trait]
impl ExtractionProvider for NoneExtractor {
async fn extract(&self, _text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
Ok(ExtractionResult {
provider: "none".to_string(),
..Default::default()
})
}
fn provider_name(&self) -> &'static str {
"none"
}
}
/// Provider backed by a shared, possibly-absent NER engine.
///
/// The engine sits behind `Arc<RwLock<Option<…>>>`; `None` means no model is
/// loaded and extraction falls back to rules (see the trait impl below).
pub struct GlinerExtractor {
    ner: Arc<RwLock<Option<NerEngine>>>,
}

impl GlinerExtractor {
    /// Wraps the shared NER engine handle.
    pub fn new(ner: Arc<RwLock<Option<NerEngine>>>) -> Self {
        Self { ner }
    }
}
#[async_trait]
impl ExtractionProvider for GlinerExtractor {
    /// Runs NER over `text`: uses the loaded GLiNER engine when one is
    /// present, otherwise falls back to the rule-based extractor.
    async fn extract(&self, text: &str, opts: &ExtractionOpts) -> Result<ExtractionResult> {
        let labels: Vec<&str> = opts.entity_types.iter().map(String::as_str).collect();
        let guard = self.ner.read().await;
        let entities = match guard.as_ref() {
            Some(engine) => engine.extract(text, &labels).await,
            None => rule_based_extract(text),
        };
        Ok(ExtractionResult {
            entities,
            provider: "gliner".to_string(),
            ..Default::default()
        })
    }

    fn provider_name(&self) -> &'static str {
        "gliner"
    }
}
/// System prompt shared by all LLM-backed extractors; demands JSON-only output.
const EXTRACT_SYSTEM: &str =
    "You are a precise information extractor. Extract structured data from the given text. \
    Respond with valid JSON only — no markdown, no explanation.";
/// User-prompt template for LLM providers; the text to analyse is appended
/// directly after the trailing "Text:\n". The JSON skeleton here must match
/// what `parse_llm_json` reads back.
const EXTRACT_PROMPT_TMPL: &str =
    "Extract entities, topics, key phrases, and a brief summary from the text below.\n\
    Respond ONLY with this JSON structure:\n\
    {\"entities\":[{\"entity_type\":\"person|org|location|date|url|email|uuid|ip\",\
    \"value\":\"...\",\"score\":0.9,\"start\":0,\"end\":5}],\
    \"topics\":[\"...\"],\"key_phrases\":[\"...\"],\"summary\":\"...\"}\n\n\
    Text:\n";

/// Builds the full user prompt: template followed immediately by `text`.
fn build_extraction_prompt(text: &str) -> String {
    let mut prompt = String::with_capacity(EXTRACT_PROMPT_TMPL.len() + text.len());
    prompt.push_str(EXTRACT_PROMPT_TMPL);
    prompt.push_str(text);
    prompt
}
/// Parses the JSON an LLM returned into an [`ExtractionResult`].
///
/// Tolerates a Markdown code fence (```json … ```) around the payload and
/// silently drops individual entities that fail to deserialize; missing or
/// wrongly-typed fields degrade to empty/`None`. Returns an
/// `ExtractionFailed` error only when the payload itself is not valid JSON.
fn parse_llm_json(content: &str, provider: &str) -> Result<ExtractionResult> {
    // Strip an optional Markdown code fence before parsing.
    let raw = content
        .trim()
        .trim_start_matches("```json")
        .trim_start_matches("```")
        .trim_end_matches("```")
        .trim();
    let parsed: serde_json::Value = serde_json::from_str(raw).map_err(|e| {
        InferenceError::ExtractionFailed(format!("JSON parse error from {provider}: {e}"))
    })?;
    // Reads `key` as an array of strings, skipping non-string items.
    let string_list = |key: &str| -> Vec<String> {
        parsed[key]
            .as_array()
            .into_iter()
            .flatten()
            .filter_map(|item| item.as_str())
            .map(str::to_string)
            .collect()
    };
    let entities: Vec<ExtractedEntity> = parsed["entities"]
        .as_array()
        .into_iter()
        .flatten()
        .filter_map(|item| serde_json::from_value(item.clone()).ok())
        .collect();
    Ok(ExtractionResult {
        entities,
        topics: string_list("topics"),
        key_phrases: string_list("key_phrases"),
        summary: parsed["summary"].as_str().map(str::to_string),
        provider: provider.to_string(),
    })
}
/// Provider for OpenAI-compatible chat-completions endpoints; also used for
/// OpenRouter and Ollama via their OpenAI-compatible APIs.
pub struct OpenAIExtractor {
    // Secret; redacted in the Debug impl below.
    api_key: String,
    // Base URL without the trailing "/chat/completions" segment.
    base_url: String,
    model: String,
    // Which logical provider this instance represents: "openai",
    // "openrouter", or "ollama".
    provider_id: &'static str,
    client: reqwest::Client,
}
impl std::fmt::Debug for OpenAIExtractor {
    /// Manual `Debug` that redacts the API key. `provider_id` is included so
    /// the openai/openrouter/ollama variants are distinguishable in logs
    /// (matching `ExtractorConfig`'s Debug, which shows every non-secret
    /// field).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIExtractor")
            .field("provider_id", &self.provider_id)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("api_key", &"[REDACTED]")
            .finish()
    }
}
impl OpenAIExtractor {
    /// Extractor for the official OpenAI API; defaults to "gpt-4o-mini".
    pub fn openai(api_key: String, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| "gpt-4o-mini".to_string());
        Self::with_base_url(
            api_key,
            "https://api.openai.com/v1".to_string(),
            model,
            "openai",
        )
    }

    /// Extractor for OpenRouter; defaults to "anthropic/claude-3-haiku".
    pub fn openrouter(api_key: String, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| "anthropic/claude-3-haiku".to_string());
        Self::with_base_url(
            api_key,
            "https://openrouter.ai/api/v1".to_string(),
            model,
            "openrouter",
        )
    }

    /// Extractor for a local Ollama server; defaults to
    /// http://localhost:11434/v1 and "llama3.1:8b". Ollama ignores the key,
    /// so a placeholder is sent.
    pub fn ollama(base_url: Option<String>, model: Option<String>) -> Self {
        let base_url = base_url.unwrap_or_else(|| "http://localhost:11434/v1".to_string());
        let model = model.unwrap_or_else(|| "llama3.1:8b".to_string());
        Self::with_base_url("ollama".to_string(), base_url, model, "ollama")
    }

    /// Common constructor behind the per-provider helpers above.
    fn with_base_url(
        api_key: String,
        base_url: String,
        model: String,
        provider_id: &'static str,
    ) -> Self {
        Self {
            api_key,
            base_url,
            model,
            provider_id,
            client: reqwest::Client::new(),
        }
    }
}
#[async_trait]
impl ExtractionProvider for OpenAIExtractor {
    /// Sends the extraction prompt to an OpenAI-compatible
    /// `/chat/completions` endpoint and parses the model's JSON reply.
    ///
    /// # Errors
    /// Returns `ExtractionFailed` on transport errors, non-2xx responses
    /// (now including a truncated slice of the response body — that is where
    /// these APIs put the actual failure reason), or an unparseable reply.
    async fn extract(&self, text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
        let prompt = build_extraction_prompt(text);
        let body = serde_json::json!({
            "model": self.model,
            "messages": [
                {"role": "system", "content": EXTRACT_SYSTEM},
                {"role": "user", "content": prompt}
            ],
            "temperature": 0,
            // Ask for guaranteed-JSON output where the server supports it.
            "response_format": {"type": "json_object"}
        });
        let resp = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
        let status = resp.status();
        if !status.is_success() {
            // Surface the body (truncated) instead of discarding it: the
            // status line alone ("HTTP 401") gives no actionable detail.
            let detail: String = resp
                .text()
                .await
                .unwrap_or_default()
                .chars()
                .take(200)
                .collect();
            return Err(InferenceError::ExtractionFailed(format!(
                "{} returned HTTP {}: {detail}",
                self.provider_id,
                status.as_u16()
            )));
        }
        let json: serde_json::Value = resp
            .json()
            .await
            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
        // A missing/odd response shape degrades to "{}", which parses as an
        // empty result rather than erroring.
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("{}");
        parse_llm_json(content, self.provider_id)
    }

    fn provider_name(&self) -> &'static str {
        self.provider_id
    }
}
/// Provider for the Anthropic Messages API.
pub struct AnthropicExtractor {
    // Secret; redacted in the Debug impl below.
    api_key: String,
    model: String,
    client: reqwest::Client,
}
impl std::fmt::Debug for AnthropicExtractor {
    /// Manual `Debug` so the API key never appears in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("AnthropicExtractor");
        dbg.field("model", &self.model);
        dbg.field("api_key", &"[REDACTED]");
        dbg.finish()
    }
}
impl AnthropicExtractor {
    /// Creates an extractor for the Anthropic Messages API; defaults to
    /// "claude-3-haiku-20240307" when no model is given.
    pub fn new(api_key: String, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| "claude-3-haiku-20240307".to_string());
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
        }
    }
}
#[async_trait]
impl ExtractionProvider for AnthropicExtractor {
    /// Sends the extraction prompt to the Anthropic Messages API and parses
    /// the model's JSON reply.
    ///
    /// # Errors
    /// Returns `ExtractionFailed` on transport errors, non-2xx responses
    /// (now including a truncated slice of the response body — Anthropic
    /// puts the failure reason there, not in the status line), or an
    /// unparseable reply.
    async fn extract(&self, text: &str, _opts: &ExtractionOpts) -> Result<ExtractionResult> {
        let prompt = build_extraction_prompt(text);
        let body = serde_json::json!({
            "model": self.model,
            "max_tokens": 1024,
            "system": EXTRACT_SYSTEM,
            "messages": [{"role": "user", "content": prompt}]
        });
        let resp = self
            .client
            .post("https://api.anthropic.com/v1/messages")
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
        let status = resp.status();
        if !status.is_success() {
            // Surface the body (truncated) instead of discarding it.
            let detail: String = resp
                .text()
                .await
                .unwrap_or_default()
                .chars()
                .take(200)
                .collect();
            return Err(InferenceError::ExtractionFailed(format!(
                "anthropic returned HTTP {}: {detail}",
                status.as_u16()
            )));
        }
        let json: serde_json::Value = resp
            .json()
            .await
            .map_err(|e| InferenceError::ExtractionFailed(e.to_string()))?;
        // A missing/odd response shape degrades to "{}", which parses as an
        // empty result rather than erroring.
        let content = json["content"][0]["text"].as_str().unwrap_or("{}");
        parse_llm_json(content, "anthropic")
    }

    fn provider_name(&self) -> &'static str {
        "anthropic"
    }
}
/// Builds the extraction provider selected by `config.provider`.
///
/// Unknown provider names fall back to [`NoneExtractor`]. API keys are taken
/// from the config first, then from the provider's conventional environment
/// variable. A missing key no longer passes silently: it is logged via
/// `warn!` at construction time (the provider is still built, matching the
/// old behavior, but later requests would otherwise fail with an
/// unexplained HTTP 401).
pub fn build_provider(
    config: &ExtractorConfig,
    ner_engine: Option<Arc<RwLock<Option<NerEngine>>>>,
) -> Box<dyn ExtractionProvider> {
    // Resolve an API key from config or the given env var, warning when
    // neither is set.
    let resolve_key = |env_var: &str| -> String {
        let key = config
            .api_key
            .clone()
            .or_else(|| std::env::var(env_var).ok())
            .unwrap_or_default();
        if key.is_empty() {
            warn!(
                "no API key for provider {:?} (set {env_var} or configure api_key)",
                config.provider
            );
        }
        key
    };
    match config.provider.as_str() {
        "gliner" => {
            if let Some(ner) = ner_engine {
                Box::new(GlinerExtractor::new(ner))
            } else {
                warn!("gliner provider requested but NER engine not available — using rule-based");
                Box::new(GlinerExtractor::new(Arc::new(RwLock::new(None))))
            }
        }
        "openai" => Box::new(OpenAIExtractor::openai(
            resolve_key("OPENAI_API_KEY"),
            config.model.clone(),
        )),
        "openrouter" => Box::new(OpenAIExtractor::openrouter(
            resolve_key("OPENROUTER_API_KEY"),
            config.model.clone(),
        )),
        "ollama" => Box::new(OpenAIExtractor::ollama(
            config.base_url.clone(),
            config.model.clone(),
        )),
        "anthropic" => Box::new(AnthropicExtractor::new(
            resolve_key("ANTHROPIC_API_KEY"),
            config.model.clone(),
        )),
        _ => Box::new(NoneExtractor),
    }
}