// anno/backends/universal_ner.rs
//! UniversalNER: LLM-based Zero-Shot NER
//!
//! UniversalNER uses instruction-tuned LLMs (LLaMA-based) for open NER,
//! supporting 45+ entity types without retraining.
//!
//! # Architecture
//!
//! UniversalNER is fundamentally different from transformer-based NER:
//! - **LLM-based**: Uses large language models (LLaMA) with instruction tuning
//! - **Prompt-based**: Extracts entities via natural language prompts
//! - **Very flexible**: Supports any entity type via prompt engineering
//! - **Expensive**: Slower and more costly than transformer models
//!
//! # Research
//!
//! - **Paper**: [UniversalNER](https://universal-ner.github.io)
//! - **Performance**: Competitive with ChatGPT on NER tasks
//! - **Capabilities**: 45 entity types, unlimited via prompts
//!
//! # Usage
//!
//! ```rust,ignore
//! use anno::backends::universal_ner::UniversalNER;
//! use anno::DynamicLabels;
//!
//! let model = UniversalNER::new()?;
//! let entities = model.extract_with_labels(
//!     "Steve Jobs founded Apple in 1976.",
//!     &["person", "organization", "date"],
//!     None,
//! )?;
//! ```
//!
//! # Implementation Status
//!
//! This backend is LLM-backed and requires:
//! - A supported API provider (OpenAI / Anthropic / OpenRouter)
//! - An API key in the environment (loaded from `.env` if present)
//! - The `llm` feature for HTTP calls (`ureq`)
//!
//! Behavior is **explicit**:
//! - If unavailable, `extract_*` returns `FeatureNotAvailable` (no silent empty fallback).
//!
//! # Environment Variables
//!
//! Automatically loads from `.env` if present. Supported keys:
//! - `OPENAI_API_KEY` - OpenAI API
//! - `OPENROUTER_API_KEY` - OpenRouter API
//! - `GEMINI_API_KEY` - Google Gemini API
//! - `ANTHROPIC_API_KEY` - Anthropic API
//! - `UNIVERSAL_NER_API_KEY` - Dedicated UniversalNER key
use crate::backends::inference::ZeroShotNER;
use crate::offset::TextSpan;
use crate::{Entity, EntityType, Model, Result};
54
/// LLM-backed zero-shot NER in the UniversalNER style.
///
/// API keys are discovered from the environment (a `.env` file is loaded
/// opportunistically at construction). Extraction methods never fall back
/// silently: when the backend is unusable they return explicit errors, so
/// callers should consult `is_available()` first.
pub struct UniversalNER {
    /// True only when both the `llm` feature and a usable API key are present.
    llm_available: bool,
}
63
impl UniversalNER {
    /// Create a new UniversalNER instance.
    ///
    /// Opportunistically loads `.env` file to check for API keys.
    /// Check `is_available()` before use. Returns an explicit error when unavailable.
    ///
    /// Note: this constructor is currently infallible; the `Result` return
    /// type is kept for signature parity with other backends.
    pub fn new() -> Result<Self> {
        // Load .env if present (idempotent)
        crate::env::load_dotenv();

        // LLM availability depends on:
        // - compile-time feature (`llm`) for HTTP support
        // - runtime configuration (API key)
        // An all-whitespace UNIVERSAL_NER_API_KEY is treated as absent.
        let universal_key = std::env::var("UNIVERSAL_NER_API_KEY")
            .ok()
            .is_some_and(|v| !v.trim().is_empty());
        let llm_available =
            cfg!(feature = "llm") && (crate::env::has_llm_api_key() || universal_key);

        Ok(Self { llm_available })
    }

    /// Extract entities using LLM-based prompt engineering.
    ///
    /// Calls OpenAI-compatible API with structured NER prompt.
    /// Requires `llm` feature for HTTP client (ureq).
    ///
    /// # Errors
    /// - `FeatureNotAvailable` when no API key or an unsupported provider is configured.
    /// - `Inference` when the HTTP request fails.
    /// - `Parse` when the response body or embedded JSON cannot be parsed.
    #[cfg(feature = "llm")]
    fn extract_with_llm(&self, text: &str, entity_types: &[&str]) -> Result<Vec<Entity>> {
        let (api_key, provider) = crate::env::llm_api_key().ok_or_else(|| {
            crate::Error::FeatureNotAvailable(
                "No LLM API key found. Set OPENAI_API_KEY, ANTHROPIC_API_KEY, or similar.".into(),
            )
        })?;

        // The prompt demands a bare JSON array so `parse_llm_response` can
        // extract it even if the model adds fences or commentary.
        let types_str = entity_types.join(", ");
        let prompt = format!(
            r#"Extract named entities from the following text. Return ONLY a JSON array of objects with "text", "type", "start", "end" fields.

Entity types to extract: {types_str}

Text: "{text}"

Example output: [{{"text": "John Smith", "type": "person", "start": 0, "end": 10}}]

Return ONLY the JSON array, no other text:"#
        );

        // Per-provider endpoint, default model, and auth header *value*.
        // Anthropic sends the raw key via `x-api-key`; the OpenAI-compatible
        // providers use a `Bearer` token in `Authorization`.
        let (url, model, auth_header) = match provider {
            "openai" => (
                "https://api.openai.com/v1/chat/completions",
                "gpt-4o-mini",
                format!("Bearer {}", api_key),
            ),
            "anthropic" => (
                "https://api.anthropic.com/v1/messages",
                "claude-3-haiku-20240307",
                api_key.clone(),
            ),
            "openrouter" => (
                "https://openrouter.ai/api/v1/chat/completions",
                "openai/gpt-4o-mini",
                format!("Bearer {}", api_key),
            ),
            other => {
                return Err(crate::Error::FeatureNotAvailable(format!(
                    "UniversalNER provider '{}' is not supported by this build. Supported: openai, anthropic, openrouter.",
                    other
                )));
            }
        };

        let response = if provider == "anthropic" {
            // Anthropic uses different API format (Messages API).
            let body = serde_json::json!({
                "model": model,
                "max_tokens": 1024,
                "messages": [{"role": "user", "content": prompt}]
            });
            ureq::post(url)
                .set("x-api-key", &auth_header)
                .set("anthropic-version", "2023-06-01")
                .set("content-type", "application/json")
                .send_json(body)
        } else {
            // OpenAI-compatible format; temperature 0.0 for determinism.
            let body = serde_json::json!({
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.0
            });
            ureq::post(url)
                .set("Authorization", &auth_header)
                .set("content-type", "application/json")
                .send_json(body)
        };

        let response =
            response.map_err(|e| crate::Error::Inference(format!("LLM API error: {}", e)))?;
        let json: serde_json::Value = response
            .into_json()
            .map_err(|e| crate::Error::Parse(format!("LLM response parse error: {}", e)))?;

        // Extract content from response. Missing fields degrade to "[]"
        // (an empty entity list), since "no entities" is a valid outcome.
        let content = if provider == "anthropic" {
            json["content"][0]["text"].as_str().unwrap_or("[]")
        } else {
            json["choices"][0]["message"]["content"]
                .as_str()
                .unwrap_or("[]")
        };

        // Parse JSON array of entities
        self.parse_llm_response(content, text)
    }

    /// Fallback when `llm` feature is not enabled.
    ///
    /// Always returns `FeatureNotAvailable` — never a silent empty result.
    #[cfg(not(feature = "llm"))]
    fn extract_with_llm(&self, _text: &str, _entity_types: &[&str]) -> Result<Vec<Entity>> {
        Err(crate::Error::FeatureNotAvailable(
            "UniversalNER requires the 'llm' feature to make HTTP requests (ureq). Rebuild with --features llm and provide an API key via .env."
                .into(),
        ))
    }

    /// Parse LLM response into entities.
    ///
    /// This is **pure** (no HTTP) and therefore always compiled so we can unit test it
    /// without network access.
    ///
    /// Offsets supplied by the LLM are treated only as *character-offset
    /// hints*; real spans are recomputed against `original_text` whenever the
    /// claimed surface form can actually be found in it.
    #[allow(dead_code)] // Used by `extract_with_llm` (when enabled) and unit tests.
    fn parse_llm_response(&self, content: &str, original_text: &str) -> Result<Vec<Entity>> {
        // Try to extract JSON array from response. Some providers wrap responses in
        // markdown/code fences or include extra explanation text.
        let json_str = content.trim();
        let json_str = json_str
            .strip_prefix("```json")
            .or_else(|| json_str.strip_prefix("```JSON"))
            .or_else(|| json_str.strip_prefix("```"))
            .unwrap_or(json_str)
            .trim();
        let json_str = json_str.strip_suffix("```").unwrap_or(json_str).trim();

        // Narrow to the outermost `[` .. `]` if the model wrapped the array
        // in explanatory prose.
        let json_str = if json_str.starts_with('[') {
            json_str.to_string()
        } else if let Some(start) = json_str.find('[') {
            if let Some(end) = json_str.rfind(']') {
                json_str[start..=end].to_string()
            } else {
                return Err(crate::Error::Parse(format!(
                    "UniversalNER LLM response did not contain a complete JSON array. Response begins: {:?}",
                    json_str.chars().take(200).collect::<String>()
                )));
            }
        } else {
            return Err(crate::Error::Parse(format!(
                "UniversalNER LLM response did not contain a JSON array. Response begins: {:?}",
                json_str.chars().take(200).collect::<String>()
            )));
        };

        let items: Vec<serde_json::Value> = serde_json::from_str(&json_str).map_err(|e| {
            crate::Error::Parse(format!(
                "UniversalNER failed to parse JSON array from LLM response: {}. Extracted JSON begins: {:?}",
                e,
                json_str.chars().take(200).collect::<String>()
            ))
        })?;

        let mut entities = Vec::new();
        for item in items {
            let text = item["text"].as_str().unwrap_or("");
            let type_str = item["type"].as_str().unwrap_or("misc");
            // Treat provided offsets as **character offsets** hints (LLMs are often wrong).
            let hint_start = item["start"].as_u64().unwrap_or(0) as usize;
            let hint_end = item["end"].as_u64().unwrap_or(0) as usize;

            // Skip malformed items: empty text, or an empty/inverted hint span
            // (which also covers items with missing offsets, defaulted to 0).
            if text.is_empty() || hint_end <= hint_start {
                continue;
            }

            // Prefer exact substring matches in the original text; choose the occurrence that
            // best matches the hint offsets. This avoids the "first occurrence" bug when the
            // same surface form appears multiple times.
            let mut occurrences: Vec<(usize, usize)> = Vec::new();
            for (start_byte, _) in original_text.match_indices(text) {
                let span = TextSpan::from_bytes(original_text, start_byte, start_byte + text.len());
                occurrences.push((span.char_start, span.char_end));
            }

            let (actual_start, actual_end) = if !occurrences.is_empty() {
                *occurrences
                    .iter()
                    .min_by_key(|(s, e)| {
                        // Rank by total distance from the hint; ties break on
                        // position so the selection is deterministic.
                        let ds = (*s as isize - hint_start as isize).unsigned_abs();
                        let de = (*e as isize - hint_end as isize).unsigned_abs();
                        (ds + de, *s, *e)
                    })
                    .expect("non-empty occurrences")
            } else {
                // Fallback: accept hint offsets only if they round-trip to the claimed text.
                let char_count = original_text.chars().count();
                if hint_end <= char_count {
                    let extracted = TextSpan::from_chars(original_text, hint_start, hint_end)
                        .extract(original_text);
                    if extracted == text {
                        (hint_start, hint_end)
                    } else {
                        continue;
                    }
                } else {
                    continue;
                }
            };

            // Map common aliases onto canonical types; anything unrecognized
            // is preserved verbatim as `EntityType::Other`.
            let entity_type = match type_str.to_lowercase().as_str() {
                "person" | "per" => EntityType::Person,
                "organization" | "org" => EntityType::Organization,
                "location" | "loc" | "gpe" => EntityType::Location,
                "date" | "time" => EntityType::Date,
                "money" | "currency" => EntityType::Money,
                _ => EntityType::Other(type_str.to_string()),
            };

            let mut entity = Entity::new(
                text.to_string(),
                entity_type,
                actual_start,
                actual_end,
                0.9, // LLM-based, high confidence
            );
            entity.provenance = Some(crate::Provenance::ml("universal_ner", entity.confidence));
            entities.push(entity);
        }

        Ok(entities)
    }
}
299
300impl Model for UniversalNER {
301    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
302        if !self.llm_available {
303            return Err(crate::Error::FeatureNotAvailable(
304                "UniversalNER requires an LLM API key. Set one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, or UNIVERSAL_NER_API_KEY (loaded from .env if present)."
305                    .into(),
306            ));
307        }
308
309        self.extract_with_llm(text, &["person", "organization", "location"])
310    }
311
312    fn supported_types(&self) -> Vec<EntityType> {
313        vec![
314            EntityType::Person,
315            EntityType::Organization,
316            EntityType::Location,
317        ]
318    }
319
320    fn is_available(&self) -> bool {
321        self.llm_available
322    }
323
324    fn name(&self) -> &'static str {
325        "universal_ner"
326    }
327
328    fn description(&self) -> &'static str {
329        "UniversalNER: LLM-based zero-shot NER (requires `llm` feature + API key)"
330    }
331
332    fn capabilities(&self) -> crate::ModelCapabilities {
333        crate::ModelCapabilities {
334            dynamic_labels: true,
335            ..Default::default()
336        }
337    }
338}
339
// Marker impl: advertises that UniversalNER can serve as a named-entity extractor.
impl crate::NamedEntityCapable for UniversalNER {}
341
342impl crate::DynamicLabels for UniversalNER {
343    fn extract_with_labels(
344        &self,
345        text: &str,
346        labels: &[&str],
347        _language: Option<&str>,
348    ) -> crate::Result<Vec<Entity>> {
349        <Self as ZeroShotNER>::extract_with_types(self, text, labels, 0.3)
350    }
351}
352
353impl ZeroShotNER for UniversalNER {
354    fn default_types(&self) -> &[&'static str] {
355        &["person", "organization", "location"]
356    }
357
358    fn extract_with_types(
359        &self,
360        text: &str,
361        entity_types: &[&str],
362        _threshold: f32,
363    ) -> Result<Vec<Entity>> {
364        if !self.llm_available {
365            return Err(crate::Error::FeatureNotAvailable(
366                "UniversalNER requires an LLM API key. Set one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, or UNIVERSAL_NER_API_KEY (loaded from .env if present)."
367                    .into(),
368            ));
369        }
370        self.extract_with_llm(text, entity_types)
371    }
372
373    fn extract_with_descriptions(
374        &self,
375        text: &str,
376        descriptions: &[&str],
377        threshold: f32,
378    ) -> Result<Vec<Entity>> {
379        // For UniversalNER, descriptions are treated as entity types
380        self.extract_with_types(text, descriptions, threshold)
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    #[test]
    fn test_universal_ner_creation() {
        let model = UniversalNER::new().unwrap();
        assert_eq!(model.name(), "universal_ner");
    }

    /// Availability must track the presence of a non-empty API key.
    ///
    /// Fix over the previous version: the test now snapshots and restores all
    /// key variables, so the `""` overrides and the `dummy` key no longer leak
    /// into other tests in the same process (env vars are process-global).
    #[test]
    fn test_universal_ner_availability_reflects_api_key() {
        // Env vars are global; serialize to avoid interference with other tests.
        static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        let _guard = ENV_LOCK
            .get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        const KEYS: [&str; 5] = [
            "OPENAI_API_KEY",
            "ANTHROPIC_API_KEY",
            "OPENROUTER_API_KEY",
            "GEMINI_API_KEY",
            "UNIVERSAL_NER_API_KEY",
        ];

        // Snapshot current values so this test cannot pollute global state.
        let saved: Vec<Option<String>> = KEYS.iter().map(|k| std::env::var(k).ok()).collect();

        // Override any `.env` values (dotenv only sets if unset).
        for k in KEYS {
            std::env::set_var(k, "");
        }
        let model = UniversalNER::new().unwrap();
        let available_with_empty_keys = model.is_available();

        std::env::set_var("UNIVERSAL_NER_API_KEY", "dummy");
        let model2 = UniversalNER::new().unwrap();
        let available_with_dummy_key = model2.is_available();

        // Restore the environment *before* asserting so a failing assertion
        // (panic) cannot skip cleanup.
        for (k, v) in KEYS.iter().zip(saved) {
            match v {
                Some(val) => std::env::set_var(k, val),
                None => std::env::remove_var(k),
            }
        }

        assert!(
            !available_with_empty_keys,
            "Empty keys must not count as available"
        );
        assert_eq!(available_with_dummy_key, cfg!(feature = "llm"));
    }

    #[test]
    fn test_universal_ner_errors_without_llm() {
        let model = UniversalNER::new().unwrap();
        if !model.is_available() {
            // Without LLM, should return explicit error (not silent empty).
            let result = model.extract_entities("Steve Jobs founded Apple", None);
            assert!(result.is_err());
        }
    }

    /// Code fences must be stripped and multi-script spans must round-trip
    /// through character offsets.
    #[test]
    fn test_parse_llm_response_handles_code_fences_and_multiscript() {
        let model = UniversalNER::new().unwrap();
        let text = "李明 met Müller in الرياض. 😀";
        let response = r#"```json
[
  {"text":"李明","type":"person","start":0,"end":2},
  {"text":"Müller","type":"person","start":7,"end":13},
  {"text":"الرياض","type":"location","start":17,"end":23},
  {"text":"😀","type":"misc","start":25,"end":26}
]
```"#;
        let ents = model.parse_llm_response(response, text).expect("parse");
        assert!(!ents.is_empty());

        for e in ents {
            let extracted = TextSpan::from_chars(text, e.start, e.end).extract(text);
            assert_eq!(extracted, e.text, "entity span should round-trip");
        }
    }

    /// Repeated surface forms must be disambiguated via the hint offsets.
    #[test]
    fn test_parse_llm_response_repeated_surface_form_uses_hint_offsets() {
        let model = UniversalNER::new().unwrap();
        let text = "Apple met Apple in Apple Park.";
        // Intentionally provide multiple occurrences with different hint offsets.
        let response = r#"[{"text":"Apple","type":"org","start":0,"end":5},{"text":"Apple","type":"org","start":10,"end":15},{"text":"Apple","type":"org","start":19,"end":24}]"#;
        let ents = model.parse_llm_response(response, text).expect("parse");

        let apples: Vec<_> = ents.into_iter().filter(|e| e.text == "Apple").collect();
        assert_eq!(apples.len(), 3);
        let mut starts: Vec<usize> = apples.iter().map(|e| e.start).collect();
        starts.sort_unstable();
        starts.dedup();
        assert_eq!(
            starts.len(),
            3,
            "each Apple should map to a distinct occurrence"
        );
    }
}