1use crate::backends::inference::ZeroShotNER;
52use crate::offset::TextSpan;
53use crate::{Entity, EntityType, Model, Result};
54
/// Zero-shot NER backed by a hosted LLM (OpenAI / Anthropic / OpenRouter).
///
/// Availability is decided once, at construction time, from the `llm`
/// feature flag plus the presence of a non-empty API key in the environment
/// (see `UniversalNER::new`).
pub struct UniversalNER {
    /// True only when the crate was built with the `llm` feature AND a
    /// non-empty API key was found by `new()`; gates every extraction call.
    llm_available: bool,
}
63
64impl UniversalNER {
65 pub fn new() -> Result<Self> {
70 crate::env::load_dotenv();
72
73 let universal_key = std::env::var("UNIVERSAL_NER_API_KEY")
77 .ok()
78 .is_some_and(|v| !v.trim().is_empty());
79 let llm_available =
80 cfg!(feature = "llm") && (crate::env::has_llm_api_key() || universal_key);
81
82 Ok(Self { llm_available })
83 }
84
85 #[cfg(feature = "llm")]
90 fn extract_with_llm(&self, text: &str, entity_types: &[&str]) -> Result<Vec<Entity>> {
91 let (api_key, provider) = crate::env::llm_api_key().ok_or_else(|| {
92 crate::Error::FeatureNotAvailable(
93 "No LLM API key found. Set OPENAI_API_KEY, ANTHROPIC_API_KEY, or similar.".into(),
94 )
95 })?;
96
97 let types_str = entity_types.join(", ");
98 let prompt = format!(
99 r#"Extract named entities from the following text. Return ONLY a JSON array of objects with "text", "type", "start", "end" fields.
100
101Entity types to extract: {types_str}
102
103Text: "{text}"
104
105Example output: [{{"text": "John Smith", "type": "person", "start": 0, "end": 10}}]
106
107Return ONLY the JSON array, no other text:"#
108 );
109
110 let (url, model, auth_header) = match provider {
111 "openai" => (
112 "https://api.openai.com/v1/chat/completions",
113 "gpt-4o-mini",
114 format!("Bearer {}", api_key),
115 ),
116 "anthropic" => (
117 "https://api.anthropic.com/v1/messages",
118 "claude-3-haiku-20240307",
119 api_key.clone(),
120 ),
121 "openrouter" => (
122 "https://openrouter.ai/api/v1/chat/completions",
123 "openai/gpt-4o-mini",
124 format!("Bearer {}", api_key),
125 ),
126 other => {
127 return Err(crate::Error::FeatureNotAvailable(format!(
128 "UniversalNER provider '{}' is not supported by this build. Supported: openai, anthropic, openrouter.",
129 other
130 )));
131 }
132 };
133
134 let response = if provider == "anthropic" {
135 let body = serde_json::json!({
137 "model": model,
138 "max_tokens": 1024,
139 "messages": [{"role": "user", "content": prompt}]
140 });
141 ureq::post(url)
142 .set("x-api-key", &auth_header)
143 .set("anthropic-version", "2023-06-01")
144 .set("content-type", "application/json")
145 .send_json(body)
146 } else {
147 let body = serde_json::json!({
149 "model": model,
150 "messages": [{"role": "user", "content": prompt}],
151 "temperature": 0.0
152 });
153 ureq::post(url)
154 .set("Authorization", &auth_header)
155 .set("content-type", "application/json")
156 .send_json(body)
157 };
158
159 let response =
160 response.map_err(|e| crate::Error::Inference(format!("LLM API error: {}", e)))?;
161 let json: serde_json::Value = response
162 .into_json()
163 .map_err(|e| crate::Error::Parse(format!("LLM response parse error: {}", e)))?;
164
165 let content = if provider == "anthropic" {
167 json["content"][0]["text"].as_str().unwrap_or("[]")
168 } else {
169 json["choices"][0]["message"]["content"]
170 .as_str()
171 .unwrap_or("[]")
172 };
173
174 self.parse_llm_response(content, text)
176 }
177
178 #[cfg(not(feature = "llm"))]
180 fn extract_with_llm(&self, _text: &str, _entity_types: &[&str]) -> Result<Vec<Entity>> {
181 Err(crate::Error::FeatureNotAvailable(
182 "UniversalNER requires the 'llm' feature to make HTTP requests (ureq). Rebuild with --features llm and provide an API key via .env."
183 .into(),
184 ))
185 }
186
187 #[allow(dead_code)] fn parse_llm_response(&self, content: &str, original_text: &str) -> Result<Vec<Entity>> {
193 let json_str = content.trim();
196 let json_str = json_str
197 .strip_prefix("```json")
198 .or_else(|| json_str.strip_prefix("```JSON"))
199 .or_else(|| json_str.strip_prefix("```"))
200 .unwrap_or(json_str)
201 .trim();
202 let json_str = json_str.strip_suffix("```").unwrap_or(json_str).trim();
203
204 let json_str = if json_str.starts_with('[') {
205 json_str.to_string()
206 } else if let Some(start) = json_str.find('[') {
207 if let Some(end) = json_str.rfind(']') {
208 json_str[start..=end].to_string()
209 } else {
210 return Err(crate::Error::Parse(format!(
211 "UniversalNER LLM response did not contain a complete JSON array. Response begins: {:?}",
212 json_str.chars().take(200).collect::<String>()
213 )));
214 }
215 } else {
216 return Err(crate::Error::Parse(format!(
217 "UniversalNER LLM response did not contain a JSON array. Response begins: {:?}",
218 json_str.chars().take(200).collect::<String>()
219 )));
220 };
221
222 let items: Vec<serde_json::Value> = serde_json::from_str(&json_str).map_err(|e| {
223 crate::Error::Parse(format!(
224 "UniversalNER failed to parse JSON array from LLM response: {}. Extracted JSON begins: {:?}",
225 e,
226 json_str.chars().take(200).collect::<String>()
227 ))
228 })?;
229
230 let mut entities = Vec::new();
231 for item in items {
232 let text = item["text"].as_str().unwrap_or("");
233 let type_str = item["type"].as_str().unwrap_or("misc");
234 let hint_start = item["start"].as_u64().unwrap_or(0) as usize;
236 let hint_end = item["end"].as_u64().unwrap_or(0) as usize;
237
238 if text.is_empty() || hint_end <= hint_start {
239 continue;
240 }
241
242 let mut occurrences: Vec<(usize, usize)> = Vec::new();
246 for (start_byte, _) in original_text.match_indices(text) {
247 let span = TextSpan::from_bytes(original_text, start_byte, start_byte + text.len());
248 occurrences.push((span.char_start, span.char_end));
249 }
250
251 let (actual_start, actual_end) = if !occurrences.is_empty() {
252 *occurrences
253 .iter()
254 .min_by_key(|(s, e)| {
255 let ds = (*s as isize - hint_start as isize).unsigned_abs();
256 let de = (*e as isize - hint_end as isize).unsigned_abs();
257 (ds + de, *s, *e)
258 })
259 .expect("non-empty occurrences")
260 } else {
261 let char_count = original_text.chars().count();
263 if hint_end <= char_count {
264 let extracted = TextSpan::from_chars(original_text, hint_start, hint_end)
265 .extract(original_text);
266 if extracted == text {
267 (hint_start, hint_end)
268 } else {
269 continue;
270 }
271 } else {
272 continue;
273 }
274 };
275
276 let entity_type = match type_str.to_lowercase().as_str() {
277 "person" | "per" => EntityType::Person,
278 "organization" | "org" => EntityType::Organization,
279 "location" | "loc" | "gpe" => EntityType::Location,
280 "date" | "time" => EntityType::Date,
281 "money" | "currency" => EntityType::Money,
282 _ => EntityType::Other(type_str.to_string()),
283 };
284
285 let mut entity = Entity::new(
286 text.to_string(),
287 entity_type,
288 actual_start,
289 actual_end,
290 0.9, );
292 entity.provenance = Some(crate::Provenance::ml("universal_ner", entity.confidence));
293 entities.push(entity);
294 }
295
296 Ok(entities)
297 }
298}
299
300impl Model for UniversalNER {
301 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
302 if !self.llm_available {
303 return Err(crate::Error::FeatureNotAvailable(
304 "UniversalNER requires an LLM API key. Set one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, or UNIVERSAL_NER_API_KEY (loaded from .env if present)."
305 .into(),
306 ));
307 }
308
309 self.extract_with_llm(text, &["person", "organization", "location"])
310 }
311
312 fn supported_types(&self) -> Vec<EntityType> {
313 vec![
314 EntityType::Person,
315 EntityType::Organization,
316 EntityType::Location,
317 ]
318 }
319
320 fn is_available(&self) -> bool {
321 self.llm_available
322 }
323
324 fn name(&self) -> &'static str {
325 "universal_ner"
326 }
327
328 fn description(&self) -> &'static str {
329 "UniversalNER: LLM-based zero-shot NER (requires `llm` feature + API key)"
330 }
331
332 fn capabilities(&self) -> crate::ModelCapabilities {
333 crate::ModelCapabilities {
334 dynamic_labels: true,
335 ..Default::default()
336 }
337 }
338}
339
// Marker impl: opts UniversalNER into the crate's named-entity capability.
impl crate::NamedEntityCapable for UniversalNER {}
341
342impl crate::DynamicLabels for UniversalNER {
343 fn extract_with_labels(
344 &self,
345 text: &str,
346 labels: &[&str],
347 _language: Option<&str>,
348 ) -> crate::Result<Vec<Entity>> {
349 <Self as ZeroShotNER>::extract_with_types(self, text, labels, 0.3)
350 }
351}
352
353impl ZeroShotNER for UniversalNER {
354 fn default_types(&self) -> &[&'static str] {
355 &["person", "organization", "location"]
356 }
357
358 fn extract_with_types(
359 &self,
360 text: &str,
361 entity_types: &[&str],
362 _threshold: f32,
363 ) -> Result<Vec<Entity>> {
364 if !self.llm_available {
365 return Err(crate::Error::FeatureNotAvailable(
366 "UniversalNER requires an LLM API key. Set one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, or UNIVERSAL_NER_API_KEY (loaded from .env if present)."
367 .into(),
368 ));
369 }
370 self.extract_with_llm(text, entity_types)
371 }
372
373 fn extract_with_descriptions(
374 &self,
375 text: &str,
376 descriptions: &[&str],
377 threshold: f32,
378 ) -> Result<Vec<Entity>> {
379 self.extract_with_types(text, descriptions, threshold)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, OnceLock};

    /// Serializes every test that reads or mutates process-wide environment
    /// variables. `UniversalNER::new()` reads the environment, so *all*
    /// tests constructing a model must hold this guard — previously only the
    /// env-mutating test locked, so it could race with the other tests under
    /// the default parallel test runner and flip their availability results.
    fn env_guard() -> MutexGuard<'static, ()> {
        static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        ENV_LOCK
            .get_or_init(|| Mutex::new(()))
            .lock()
            // A poisoned lock only means another test panicked while holding
            // it; the environment itself is still usable.
            .unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn test_universal_ner_creation() {
        let _guard = env_guard();
        let model = UniversalNER::new().unwrap();
        assert_eq!(model.name(), "universal_ner");
    }

    #[test]
    fn test_universal_ner_availability_reflects_api_key() {
        let _guard = env_guard();

        // Blank out every key the detector consults.
        for k in [
            "OPENAI_API_KEY",
            "ANTHROPIC_API_KEY",
            "OPENROUTER_API_KEY",
            "GEMINI_API_KEY",
            "UNIVERSAL_NER_API_KEY",
        ] {
            std::env::set_var(k, "");
        }

        let model = UniversalNER::new().unwrap();
        assert!(
            !model.is_available(),
            "Empty keys must not count as available"
        );

        // With a non-empty key, availability should track the `llm` feature.
        std::env::set_var("UNIVERSAL_NER_API_KEY", "dummy");
        let model2 = UniversalNER::new().unwrap();
        assert_eq!(model2.is_available(), cfg!(feature = "llm"));

        // Leave the environment clean for other env-sensitive tests.
        std::env::set_var("UNIVERSAL_NER_API_KEY", "");
    }

    #[test]
    fn test_universal_ner_errors_without_llm() {
        let _guard = env_guard();
        let model = UniversalNER::new().unwrap();
        if !model.is_available() {
            let result = model.extract_entities("Steve Jobs founded Apple", None);
            assert!(result.is_err());
        }
    }

    #[test]
    fn test_parse_llm_response_handles_code_fences_and_multiscript() {
        let _guard = env_guard();
        let model = UniversalNER::new().unwrap();
        let text = "李明 met Müller in الرياض. 😀";
        let response = r#"```json
[
  {"text":"李明","type":"person","start":0,"end":2},
  {"text":"Müller","type":"person","start":7,"end":13},
  {"text":"الرياض","type":"location","start":17,"end":23},
  {"text":"😀","type":"misc","start":25,"end":26}
]
```"#;
        let ents = model.parse_llm_response(response, text).expect("parse");
        assert!(!ents.is_empty());

        // Every reported span must round-trip back to the entity's text.
        for e in ents {
            let extracted = TextSpan::from_chars(text, e.start, e.end).extract(text);
            assert_eq!(extracted, e.text, "entity span should round-trip");
        }
    }

    #[test]
    fn test_parse_llm_response_repeated_surface_form_uses_hint_offsets() {
        let _guard = env_guard();
        let model = UniversalNER::new().unwrap();
        let text = "Apple met Apple in Apple Park.";
        let response = r#"[{"text":"Apple","type":"org","start":0,"end":5},{"text":"Apple","type":"org","start":10,"end":15},{"text":"Apple","type":"org","start":19,"end":24}]"#;
        let ents = model.parse_llm_response(response, text).expect("parse");

        let apples: Vec<_> = ents.into_iter().filter(|e| e.text == "Apple").collect();
        assert_eq!(apples.len(), 3);
        let mut starts: Vec<usize> = apples.iter().map(|e| e.start).collect();
        starts.sort_unstable();
        starts.dedup();
        assert_eq!(
            starts.len(),
            3,
            "each Apple should map to a distinct occurrence"
        );
    }
}
476}