1use crate::backends::inference::ZeroShotNER;
52use crate::offset::TextSpan;
53use crate::{Entity, EntityType, Model, Result};
54
/// Zero-shot NER backend that delegates entity extraction to a hosted LLM
/// (OpenAI, Anthropic, or OpenRouter) when the `llm` feature is enabled.
///
/// Construction never fails; availability is decided once at construction
/// time from the environment and reported via `is_available()`.
#[derive(Debug, Clone, Copy)]
pub struct UniversalNER {
    // True when the `llm` feature is compiled in AND a non-empty API key was
    // found in the environment at construction time.
    llm_available: bool,
}
63
impl UniversalNER {
    /// Creates a new `UniversalNER`, probing the environment for credentials.
    ///
    /// Loads `.env` first, then marks the backend available only when the
    /// crate was built with the `llm` feature AND a non-empty API key exists
    /// (either one recognized by `crate::env::has_llm_api_key()` or a
    /// dedicated `UNIVERSAL_NER_API_KEY`).
    ///
    /// NOTE(review): `extract_with_llm` resolves its key via
    /// `crate::env::llm_api_key()` only; if *just* `UNIVERSAL_NER_API_KEY`
    /// is set, `is_available()` can report true while extraction still fails
    /// with "No LLM API key found" — confirm `llm_api_key()` also honors
    /// that variable.
    pub fn new() -> Result<Self> {
        crate::env::load_dotenv();

        // An empty or whitespace-only value counts as "not set".
        let universal_key = std::env::var("UNIVERSAL_NER_API_KEY")
            .ok()
            .is_some_and(|v| !v.trim().is_empty());
        let llm_available =
            cfg!(feature = "llm") && (crate::env::has_llm_api_key() || universal_key);

        Ok(Self { llm_available })
    }

    /// Sends `text` to a hosted LLM, asking for entities of `entity_types`
    /// as a JSON array, then delegates offset resolution to
    /// `parse_llm_response`.
    ///
    /// # Errors
    /// - `FeatureNotAvailable` when no API key is found or the detected
    ///   provider is unsupported.
    /// - `Inference` when the HTTP request fails.
    /// - `Parse` when the response body is not valid JSON.
    #[cfg(feature = "llm")]
    fn extract_with_llm(&self, text: &str, entity_types: &[&str]) -> Result<Vec<Entity>> {
        let (api_key, provider) = crate::env::llm_api_key().ok_or_else(|| {
            crate::Error::FeatureNotAvailable(
                "No LLM API key found. Set OPENAI_API_KEY, ANTHROPIC_API_KEY, or similar.".into(),
            )
        })?;

        // NOTE(review): `text` is interpolated verbatim inside quotes below,
        // so input containing quotes/braces can distort the prompt (prompt
        // injection is inherent to this approach — not fixed here).
        let types_str = entity_types.join(", ");
        let prompt = format!(
            r#"Extract named entities from the following text. Return ONLY a JSON array of objects with "text", "type", "start", "end" fields.

Entity types to extract: {types_str}

Text: "{text}"

Example output: [{{"text": "John Smith", "type": "person", "start": 0, "end": 10}}]

Return ONLY the JSON array, no other text:"#
        );

        // Per-provider endpoint, default model, and auth header *value*.
        // Anthropic takes the bare key in an `x-api-key` header; the
        // OpenAI-compatible providers take `Bearer <key>` in Authorization.
        let (url, model, auth_header) = match provider {
            "openai" => (
                "https://api.openai.com/v1/chat/completions",
                "gpt-4o-mini",
                format!("Bearer {}", api_key),
            ),
            "anthropic" => (
                "https://api.anthropic.com/v1/messages",
                "claude-3-haiku-20240307",
                api_key.clone(),
            ),
            "openrouter" => (
                "https://openrouter.ai/api/v1/chat/completions",
                "openai/gpt-4o-mini",
                format!("Bearer {}", api_key),
            ),
            other => {
                return Err(crate::Error::FeatureNotAvailable(format!(
                    "UniversalNER provider '{}' is not supported by this build. Supported: openai, anthropic, openrouter.",
                    other
                )));
            }
        };

        // Anthropic's Messages API has a different body schema (and requires
        // `max_tokens`); the OpenAI-style endpoints share one schema.
        // NOTE(review): `temperature` is pinned to 0.0 only on the
        // OpenAI-style path, not for Anthropic — confirm whether that
        // asymmetry is intentional.
        let response = if provider == "anthropic" {
            let body = serde_json::json!({
                "model": model,
                "max_tokens": 1024,
                "messages": [{"role": "user", "content": prompt}]
            });
            ureq::post(url)
                .set("x-api-key", &auth_header)
                .set("anthropic-version", "2023-06-01")
                .set("content-type", "application/json")
                .send_json(body)
        } else {
            let body = serde_json::json!({
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.0
            });
            ureq::post(url)
                .set("Authorization", &auth_header)
                .set("content-type", "application/json")
                .send_json(body)
        };

        let response =
            response.map_err(|e| crate::Error::Inference(format!("LLM API error: {}", e)))?;
        let json: serde_json::Value = response
            .into_json()
            .map_err(|e| crate::Error::Parse(format!("LLM response parse error: {}", e)))?;

        // Pull the assistant text out of the provider-specific envelope;
        // an unexpected shape degrades to "[]" (no entities) rather than
        // an error.
        let content = if provider == "anthropic" {
            json["content"][0]["text"].as_str().unwrap_or("[]")
        } else {
            json["choices"][0]["message"]["content"]
                .as_str()
                .unwrap_or("[]")
        };

        self.parse_llm_response(content, text)
    }

    /// Fallback when built without the `llm` feature: always errors, since
    /// HTTP support (ureq) is compiled out.
    #[cfg(not(feature = "llm"))]
    fn extract_with_llm(&self, _text: &str, _entity_types: &[&str]) -> Result<Vec<Entity>> {
        Err(crate::Error::FeatureNotAvailable(
            "UniversalNER requires the 'llm' feature to make HTTP requests (ureq). Rebuild with --features llm and provide an API key via .env."
                .into(),
        ))
    }

    /// Parses the LLM's reply into `Entity` values with trustworthy offsets.
    ///
    /// Steps:
    /// 1. Strip Markdown code fences and isolate the outermost `[...]` span.
    /// 2. Deserialize it as a list of `{text, type, start, end}` objects.
    /// 3. Re-anchor each entity in `original_text`: the LLM-reported offsets
    ///    are treated only as hints; the span actually used is the literal
    ///    occurrence of the surface text closest to those hints (char-based,
    ///    via `TextSpan`). Items whose text cannot be located anywhere are
    ///    dropped rather than emitted with bogus offsets.
    ///
    /// `dead_code` is allowed because the only caller is behind
    /// `#[cfg(feature = "llm")]`, while the unit tests call it directly.
    #[allow(dead_code)]
    fn parse_llm_response(&self, content: &str, original_text: &str) -> Result<Vec<Entity>> {
        let json_str = content.trim();
        // Remove a leading ```json / ```JSON / ``` fence and a trailing ```.
        let json_str = json_str
            .strip_prefix("```json")
            .or_else(|| json_str.strip_prefix("```JSON"))
            .or_else(|| json_str.strip_prefix("```"))
            .unwrap_or(json_str)
            .trim();
        let json_str = json_str.strip_suffix("```").unwrap_or(json_str).trim();

        // If the reply is not a bare array, salvage the outermost [...]
        // span; otherwise fail with a short excerpt for diagnostics.
        let json_str = if json_str.starts_with('[') {
            json_str.to_string()
        } else if let Some(start) = json_str.find('[') {
            if let Some(end) = json_str.rfind(']') {
                json_str[start..=end].to_string()
            } else {
                return Err(crate::Error::Parse(format!(
                    "UniversalNER LLM response did not contain a complete JSON array. Response begins: {:?}",
                    json_str.chars().take(200).collect::<String>()
                )));
            }
        } else {
            return Err(crate::Error::Parse(format!(
                "UniversalNER LLM response did not contain a JSON array. Response begins: {:?}",
                json_str.chars().take(200).collect::<String>()
            )));
        };

        let items: Vec<serde_json::Value> = serde_json::from_str(&json_str).map_err(|e| {
            crate::Error::Parse(format!(
                "UniversalNER failed to parse JSON array from LLM response: {}. Extracted JSON begins: {:?}",
                e,
                json_str.chars().take(200).collect::<String>()
            ))
        })?;

        let mut entities = Vec::new();
        for item in items {
            let text = item["text"].as_str().unwrap_or("");
            let type_str = item["type"].as_str().unwrap_or("misc");
            // start/end are char-offset *hints* only; LLMs often miscount.
            let hint_start = item["start"].as_u64().unwrap_or(0) as usize;
            let hint_end = item["end"].as_u64().unwrap_or(0) as usize;

            // Skip degenerate items (no text, or an empty/inverted span).
            if text.is_empty() || hint_end <= hint_start {
                continue;
            }

            // Collect every literal occurrence of the surface text in the
            // original, converted from byte to char spans via TextSpan.
            let mut occurrences: Vec<(usize, usize)> = Vec::new();
            for (start_byte, _) in original_text.match_indices(text) {
                let span = TextSpan::from_bytes(original_text, start_byte, start_byte + text.len());
                occurrences.push((span.char_start, span.char_end));
            }

            let (actual_start, actual_end) = if !occurrences.is_empty() {
                // Pick the occurrence nearest the hinted offsets: minimize
                // the sum of absolute start/end distances, with position as
                // the tie-breaker (the extra tuple fields).
                *occurrences
                    .iter()
                    .min_by_key(|(s, e)| {
                        let ds = (*s as isize - hint_start as isize).unsigned_abs();
                        let de = (*e as isize - hint_end as isize).unsigned_abs();
                        (ds + de, *s, *e)
                    })
                    .expect("non-empty occurrences")
            } else {
                // No literal match: trust the hints only when the hinted
                // char span reproduces the reported text exactly; otherwise
                // drop the item.
                let char_count = original_text.chars().count();
                if hint_end <= char_count {
                    let extracted = TextSpan::from_chars(original_text, hint_start, hint_end)
                        .extract(original_text);
                    if extracted == text {
                        (hint_start, hint_end)
                    } else {
                        continue;
                    }
                } else {
                    continue;
                }
            };

            // Map the free-form label onto the crate's enum; unknown labels
            // are preserved verbatim in `EntityType::Other`.
            let entity_type = match type_str.to_lowercase().as_str() {
                "person" | "per" => EntityType::Person,
                "organization" | "org" => EntityType::Organization,
                "location" | "loc" | "gpe" => EntityType::Location,
                "date" | "time" => EntityType::Date,
                "money" | "currency" => EntityType::Money,
                _ => EntityType::Other(type_str.to_string()),
            };

            let mut entity = Entity::new(
                text.to_string(),
                entity_type,
                actual_start,
                actual_end,
                0.9, // fixed confidence: LLM output carries no usable score
            );
            entity.provenance = Some(crate::Provenance::ml("universal_ner", entity.confidence));
            entities.push(entity);
        }

        Ok(entities)
    }
}
299
300impl Model for UniversalNER {
301 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
302 if !self.llm_available {
303 return Err(crate::Error::FeatureNotAvailable(
304 "UniversalNER requires an LLM API key. Set one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, or UNIVERSAL_NER_API_KEY (loaded from .env if present)."
305 .into(),
306 ));
307 }
308
309 self.extract_with_llm(text, &["person", "organization", "location"])
310 }
311
312 fn supported_types(&self) -> Vec<EntityType> {
313 vec![
314 EntityType::Person,
315 EntityType::Organization,
316 EntityType::Location,
317 ]
318 }
319
320 fn is_available(&self) -> bool {
321 self.llm_available
322 }
323
324 fn name(&self) -> &'static str {
325 "universal_ner"
326 }
327
328 fn description(&self) -> &'static str {
329 "UniversalNER: LLM-based zero-shot NER (requires `llm` feature + API key)"
330 }
331}
332
333impl ZeroShotNER for UniversalNER {
334 fn default_types(&self) -> &[&'static str] {
335 &["person", "organization", "location"]
336 }
337
338 fn extract_with_types(
339 &self,
340 text: &str,
341 entity_types: &[&str],
342 _threshold: f32,
343 ) -> Result<Vec<Entity>> {
344 if !self.llm_available {
345 return Err(crate::Error::FeatureNotAvailable(
346 "UniversalNER requires an LLM API key. Set one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, or UNIVERSAL_NER_API_KEY (loaded from .env if present)."
347 .into(),
348 ));
349 }
350 self.extract_with_llm(text, entity_types)
351 }
352
353 fn extract_with_descriptions(
354 &self,
355 text: &str,
356 descriptions: &[&str],
357 threshold: f32,
358 ) -> Result<Vec<Entity>> {
359 self.extract_with_types(text, descriptions, threshold)
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Construction never fails and reports the expected backend name.
    #[test]
    fn test_universal_ner_creation() {
        let model = UniversalNER::new().unwrap();
        assert_eq!(model.name(), "universal_ner");
    }

    /// Availability must track API-key presence: empty values count as
    /// "not set", and a non-empty key makes the model available iff the
    /// `llm` feature is compiled in.
    #[test]
    fn test_universal_ner_availability_reflects_api_key() {
        // Serialize env mutation; a poisoned lock is recovered because the
        // env state is fully re-established below anyway.
        // NOTE(review): other tests in this module call `new()` (which reads
        // the env) without taking this lock, so a parallel test run can
        // still race — confirm whether the harness runs single-threaded.
        static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        let _guard = ENV_LOCK
            .get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        // Setting "" simulates "unset" because `new()` trims and rejects
        // empty values (set_var cannot remove a variable).
        for k in [
            "OPENAI_API_KEY",
            "ANTHROPIC_API_KEY",
            "OPENROUTER_API_KEY",
            "GEMINI_API_KEY",
            "UNIVERSAL_NER_API_KEY",
        ] {
            std::env::set_var(k, "");
        }

        let model = UniversalNER::new().unwrap();
        assert!(
            !model.is_available(),
            "Empty keys must not count as available"
        );

        // A non-empty dedicated key flips availability, gated on the feature.
        std::env::set_var("UNIVERSAL_NER_API_KEY", "dummy");
        let model2 = UniversalNER::new().unwrap();
        assert_eq!(model2.is_available(), cfg!(feature = "llm"));
    }

    /// Without credentials, extraction must fail rather than silently
    /// return no entities. (Conditioned on availability so the test also
    /// passes in environments that do have a key configured.)
    #[test]
    fn test_universal_ner_errors_without_llm() {
        let model = UniversalNER::new().unwrap();
        if !model.is_available() {
            let result = model.extract_entities("Steve Jobs founded Apple", None);
            assert!(result.is_err());
        }
    }

    /// Fenced JSON plus mixed CJK/Latin/Arabic/emoji text: every parsed
    /// span must round-trip through `TextSpan` back to the entity's
    /// surface text (i.e. offsets are char-based and correct).
    #[test]
    fn test_parse_llm_response_handles_code_fences_and_multiscript() {
        let model = UniversalNER::new().unwrap();
        let text = "李明 met Müller in الرياض. 😀";
        let response = r#"```json
[
  {"text":"李明","type":"person","start":0,"end":2},
  {"text":"Müller","type":"person","start":7,"end":13},
  {"text":"الرياض","type":"location","start":17,"end":23},
  {"text":"😀","type":"misc","start":25,"end":26}
]
```"#;
        let ents = model.parse_llm_response(response, text).expect("parse");
        assert!(!ents.is_empty());

        for e in ents {
            let extracted = TextSpan::from_chars(text, e.start, e.end).extract(text);
            assert_eq!(extracted, e.text, "entity span should round-trip");
        }
    }

    /// Three identical "Apple" mentions with distinct hint offsets must be
    /// anchored to three distinct occurrences — not all snapped to the
    /// first match.
    #[test]
    fn test_parse_llm_response_repeated_surface_form_uses_hint_offsets() {
        let model = UniversalNER::new().unwrap();
        let text = "Apple met Apple in Apple Park.";
        let response = r#"[{"text":"Apple","type":"org","start":0,"end":5},{"text":"Apple","type":"org","start":10,"end":15},{"text":"Apple","type":"org","start":19,"end":24}]"#;
        let ents = model.parse_llm_response(response, text).expect("parse");

        let apples: Vec<_> = ents.into_iter().filter(|e| e.text == "Apple").collect();
        assert_eq!(apples.len(), 3);
        let mut starts: Vec<usize> = apples.iter().map(|e| e.start).collect();
        starts.sort_unstable();
        starts.dedup();
        assert_eq!(
            starts.len(),
            3,
            "each Apple should map to a distinct occurrence"
        );
    }
}