1use anno_core::EntityType;
46use std::collections::HashMap;
47
/// One gold entity in a few-shot demonstration:
/// `(entity text, type label, start offset, end offset)`.
///
/// NOTE(review): the unit of the offsets (bytes vs. chars) is not established
/// here; presumably it should match how the model is expected to count — confirm.
pub type DemoEntity<'a> = (&'a str, &'a str, usize, usize);

/// A few-shot demonstration as accepted by `CodeNERPrompt::with_demonstrations`:
/// the example input text paired with its gold entities.
pub type DemoExample<'a> = (&'a str, Vec<DemoEntity<'a>>);
53
54#[derive(Debug, Clone)]
58pub struct BIOSchema {
59 pub entity_types: Vec<EntityType>,
61 pub descriptions: HashMap<EntityType, String>,
63}
64
65impl BIOSchema {
66 #[must_use]
68 pub fn new(entity_types: &[EntityType]) -> Self {
69 let mut descriptions = HashMap::new();
70
71 for et in entity_types {
72 let desc = match et {
73 EntityType::Person => "Person names (individuals, fictional characters)",
74 EntityType::Organization => "Organizations (companies, institutions, groups)",
75 EntityType::Location => "Locations (cities, countries, addresses, landmarks)",
76 EntityType::Date => "Temporal expressions (dates, times, durations)",
77 EntityType::Time => "Time expressions (clock times, periods)",
78 EntityType::Money => "Monetary values (prices, amounts, currencies)",
79 EntityType::Percent => "Percentage values",
80 EntityType::Email => "Email addresses",
81 EntityType::Phone => "Phone numbers",
82 EntityType::Url => "Web URLs",
83 EntityType::Quantity => "Quantities (measurements, counts)",
84 EntityType::Cardinal => "Cardinal numbers",
85 EntityType::Ordinal => "Ordinal numbers (1st, 2nd, etc.)",
86 EntityType::Other(_) => "Miscellaneous named entities",
87 EntityType::Custom { name, .. } => name.as_str(),
88 _ => "Named entities",
90 };
91 descriptions.insert(et.clone(), desc.to_string());
92 }
93
94 Self {
95 entity_types: entity_types.to_vec(),
96 descriptions,
97 }
98 }
99
100 #[must_use]
102 pub fn with_description(mut self, entity_type: EntityType, description: &str) -> Self {
103 self.descriptions
104 .insert(entity_type, description.to_string());
105 self
106 }
107
108 fn render_docstring(&self) -> String {
110 let mut lines = vec![
111 " \"\"\"".to_string(),
112 " Extract named entities from text using BIO tagging.".to_string(),
113 " ".to_string(),
114 " BIO Schema:".to_string(),
115 " - B-{TYPE}: Beginning of entity of TYPE".to_string(),
116 " - I-{TYPE}: Inside (continuation) of entity".to_string(),
117 " - O: Outside any entity".to_string(),
118 " ".to_string(),
119 " Entity Types:".to_string(),
120 ];
121
122 for et in &self.entity_types {
123 let label = et.as_label();
124 let desc = self.descriptions.get(et).map_or("", |s| s.as_str());
125 lines.push(format!(" - {}: {}", label, desc));
126 }
127
128 lines.push(" ".to_string());
129 lines.push(
130 " Returns: List of entities with text, type, start, end positions.".to_string(),
131 );
132 lines.push(" \"\"\"".to_string());
133
134 lines.join("\n")
135 }
136}
137
/// A single few-shot example: input text plus its gold entity annotations.
#[derive(Debug, Clone)]
pub struct Demonstration {
    /// Example input text, shown to the model after `Input:`.
    pub text: String,
    /// Gold entities as `(entity text, type label, start, end)` tuples.
    pub entities: Vec<(String, String, usize, usize)>,
}

impl Demonstration {
    /// Create a demonstration, taking owned copies of all borrowed strings.
    #[must_use]
    pub fn new(text: &str, entities: Vec<(&str, &str, usize, usize)>) -> Self {
        Self {
            text: text.to_string(),
            entities: entities
                .into_iter()
                .map(|(t, ty, s, e)| (t.to_string(), ty.to_string(), s, e))
                .collect(),
        }
    }

    /// Render the gold entities as the JSON array the model is asked to produce.
    ///
    /// An empty entity list renders as `[]`; otherwise there is one object per line.
    /// Entity text and type labels are JSON-escaped, so values containing quotes,
    /// backslashes or control characters still yield a valid JSON example.
    fn render_output(&self) -> String {
        if self.entities.is_empty() {
            return "[]".to_string();
        }

        let items: Vec<String> = self
            .entities
            .iter()
            .map(|(text, ty, start, end)| {
                format!(
                    r#" {{"text": {}, "type": {}, "start": {}, "end": {}}}"#,
                    Self::json_string(text),
                    Self::json_string(ty),
                    start,
                    end
                )
            })
            .collect();

        format!("[\n{}\n]", items.join(",\n"))
    }

    /// Quote and escape `raw` as a JSON string literal (RFC 8259, section 7).
    ///
    /// `"` and `\` are backslash-escaped, and common control characters use their
    /// short escapes. Any other control character below U+0020 becomes `\uXXXX`.
    fn json_string(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len() + 2);
        out.push('"');
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }
}
180
181#[derive(Debug, Clone)]
186pub struct CodeNERPrompt {
187 schema: BIOSchema,
189 demonstrations: Vec<Demonstration>,
191 use_cot: bool,
193 system_prefix: Option<String>,
195}
196
197impl CodeNERPrompt {
198 #[must_use]
200 pub fn new(schema: BIOSchema) -> Self {
201 Self {
202 schema,
203 demonstrations: vec![],
204 use_cot: false,
205 system_prefix: None,
206 }
207 }
208
209 #[must_use]
211 pub fn with_demonstrations(mut self, demos: Vec<DemoExample<'_>>) -> Self {
212 self.demonstrations = demos
213 .into_iter()
214 .map(|(text, entities)| Demonstration::new(text, entities))
215 .collect();
216 self
217 }
218
219 #[must_use]
221 pub fn with_chain_of_thought(mut self, enabled: bool) -> Self {
222 self.use_cot = enabled;
223 self
224 }
225
226 #[must_use]
228 pub fn with_system_prefix(mut self, prefix: &str) -> Self {
229 self.system_prefix = Some(prefix.to_string());
230 self
231 }
232
233 #[must_use]
235 pub fn render_system(&self) -> String {
236 let prefix = self.system_prefix.as_deref().unwrap_or(
237 "You are an expert NER system. Extract entities precisely using BIO tagging.",
238 );
239
240 format!(
241 "{}\n\nRespond ONLY with valid JSON array of entities. No explanation.",
242 prefix
243 )
244 }
245
246 #[must_use]
248 pub fn render(&self, input_text: &str) -> String {
249 let mut parts = vec![
251 "```python".to_string(),
252 "def extract_entities(text: str) -> list[dict]:".to_string(),
253 self.schema.render_docstring(),
254 " pass".to_string(),
255 "```".to_string(),
256 String::new(),
257 ];
258
259 if !self.demonstrations.is_empty() {
261 parts.push("# Examples:".to_string());
262 for (i, demo) in self.demonstrations.iter().enumerate() {
263 parts.push(format!("\n## Example {}:", i + 1));
264 parts.push(format!("Input: \"{}\"", demo.text));
265 parts.push(format!("Output: {}", demo.render_output()));
266 }
267 parts.push("".to_string());
268 }
269
270 if self.use_cot {
272 parts.push("# Instructions:".to_string());
273 parts.push("1. First, identify potential entity spans in the text".to_string());
274 parts.push("2. For each span, determine the most appropriate entity type".to_string());
275 parts.push("3. Verify the start and end positions are correct".to_string());
276 parts.push("4. Return the final JSON array".to_string());
277 parts.push("".to_string());
278 }
279
280 parts.push("# Task:".to_string());
282 parts.push(format!("Input: \"{}\"", input_text));
283 parts.push("Output:".to_string());
284
285 parts.join("\n")
286 }
287
288 #[must_use]
290 pub fn output_format(&self) -> &'static str {
291 r#"[{"text": "entity_text", "type": "TYPE", "start": 0, "end": 10}, ...]"#
292 }
293}
294
295pub fn parse_llm_response(response: &str) -> Result<Vec<ParsedEntity>, ParseError> {
300 let json_str = extract_json_array(response)?;
302
303 let parsed: Vec<serde_json::Value> =
305 serde_json::from_str(&json_str).map_err(|e| ParseError::InvalidJson(e.to_string()))?;
306
307 let mut entities = Vec::new();
309 for (i, item) in parsed.iter().enumerate() {
310 let text = item
311 .get("text")
312 .and_then(|v| v.as_str())
313 .ok_or(ParseError::MissingField(i, "text"))?
314 .to_string();
315
316 let entity_type = item
317 .get("type")
318 .and_then(|v| v.as_str())
319 .ok_or(ParseError::MissingField(i, "type"))?
320 .to_string();
321
322 let start = item
323 .get("start")
324 .and_then(|v| v.as_u64())
325 .ok_or(ParseError::MissingField(i, "start"))? as usize;
326
327 let end = item
328 .get("end")
329 .and_then(|v| v.as_u64())
330 .ok_or(ParseError::MissingField(i, "end"))? as usize;
331
332 let confidence = item.get("confidence").and_then(|v| v.as_f64());
333
334 entities.push(ParsedEntity {
335 text,
336 entity_type,
337 start,
338 end,
339 confidence,
340 });
341 }
342
343 Ok(entities)
344}
345
346fn extract_json_array(text: &str) -> Result<String, ParseError> {
348 if let (Some(start), Some(end)) = (text.find('['), text.rfind(']')) {
350 if end > start {
351 return Ok(text[start..=end].to_string());
352 }
353 }
354
355 if let Some(start) = text.find("```json") {
357 let start = start + 7;
358 if let Some(end) = text[start..].find("```") {
359 let json = text[start..start + end].trim();
360 if json.starts_with('[') {
361 return Ok(json.to_string());
362 }
363 }
364 }
365
366 if let Some(start) = text.find('[') {
368 if let Some(end) = text.rfind(']') {
369 if end > start {
370 return Ok(text[start..=end].to_string());
371 }
372 }
373 }
374
375 Err(ParseError::NoJsonFound)
376}
377
378#[derive(Debug, Clone)]
380pub struct ParsedEntity {
381 pub text: String,
383 pub entity_type: String,
385 pub start: usize,
387 pub end: usize,
389 pub confidence: Option<f64>,
391}
392
393impl ParsedEntity {
394 pub fn to_entity(&self, type_map: &HashMap<String, EntityType>) -> Option<anno_core::Entity> {
396 let entity_type = type_map
397 .get(&self.entity_type)
398 .or_else(|| type_map.get(&self.entity_type.to_uppercase()))
399 .cloned()?;
400
401 Some(anno_core::Entity::new(
402 &self.text,
403 entity_type,
404 self.start,
405 self.end,
406 self.confidence.unwrap_or(0.8),
407 ))
408 }
409}
410
/// Errors produced while turning an LLM response into entities.
#[derive(Debug)]
pub enum ParseError {
    /// No JSON array could be located in the response text.
    NoJsonFound,
    /// The located text could not be used as the expected JSON; carries a human-readable reason.
    InvalidJson(String),
    /// The entity at this index lacked the named required field, or held it with the wrong JSON type.
    MissingField(usize, &'static str),
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ParseError::NoJsonFound => f.write_str("No JSON array found in LLM response"),
            ParseError::InvalidJson(ref detail) => write!(f, "Invalid JSON: {}", detail),
            ParseError::MissingField(index, name) => {
                write!(f, "Entity {} missing required field: {}", index, name)
            }
        }
    }
}

impl std::error::Error for ParseError {}
435
/// Unit tests for schema construction, prompt rendering and response parsing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bio_schema_creation() {
        let schema = BIOSchema::new(&[EntityType::Person, EntityType::Organization]);

        assert_eq!(schema.entity_types.len(), 2);
        assert!(schema.descriptions.contains_key(&EntityType::Person));
    }

    #[test]
    fn test_prompt_rendering() {
        let schema = BIOSchema::new(&[EntityType::Person, EntityType::Location]);
        let prompt = CodeNERPrompt::new(schema);

        let rendered = prompt.render("John went to Paris.");

        assert!(rendered.contains("extract_entities"));
        assert!(rendered.contains("BIO Schema"));
        assert!(rendered.contains("PER"));
        assert!(rendered.contains("LOC"));
        assert!(rendered.contains("John went to Paris"));
    }

    #[test]
    fn test_prompt_with_demonstrations() {
        let schema = BIOSchema::new(&[EntityType::Person]);
        let prompt = CodeNERPrompt::new(schema).with_demonstrations(vec![(
            "Steve Jobs worked at Apple.",
            vec![("Steve Jobs", "PER", 0, 10)],
        )]);

        let rendered = prompt.render("Test input.");

        assert!(rendered.contains("Example 1"));
        assert!(rendered.contains("Steve Jobs"));
    }

    #[test]
    fn test_parse_clean_json() {
        let response = r#"[{"text": "John", "type": "PER", "start": 0, "end": 4}]"#;
        let entities = parse_llm_response(response).unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].text, "John");
        assert_eq!(entities[0].entity_type, "PER");
    }

    #[test]
    fn test_parse_json_with_markdown() {
        let response = r#"
Here are the entities:

```json
[{"text": "Paris", "type": "LOC", "start": 10, "end": 15}]
```

That's all!
"#;
        let entities = parse_llm_response(response).unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].text, "Paris");
    }

    #[test]
    fn test_parse_empty_response() {
        let response = "[]";
        let entities = parse_llm_response(response).unwrap();

        assert!(entities.is_empty());
    }

    #[test]
    fn test_parse_no_json() {
        let response = "I couldn't find any entities.";
        let result = parse_llm_response(response);

        assert!(matches!(result, Err(ParseError::NoJsonFound)));
    }

    #[test]
    fn test_chain_of_thought() {
        let schema = BIOSchema::new(&[EntityType::Person]);
        let prompt = CodeNERPrompt::new(schema).with_chain_of_thought(true);

        let rendered = prompt.render("Test.");

        assert!(rendered.contains("Instructions"));
        assert!(rendered.contains("identify potential entity spans"));
    }

    /// A required field that is absent is reported with the item index and field name.
    #[test]
    fn test_parse_missing_field() {
        let response = r#"[{"text": "John", "type": "PER", "start": 0}]"#;
        let result = parse_llm_response(response);

        assert!(matches!(result, Err(ParseError::MissingField(0, "end"))));
    }

    /// Offsets and the optional confidence are carried through unchanged.
    #[test]
    fn test_parse_offsets_and_confidence() {
        let response =
            r#"[{"text": "Acme", "type": "ORG", "start": 5, "end": 9, "confidence": 0.5}]"#;
        let entities = parse_llm_response(response).unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].start, 5);
        assert_eq!(entities[0].end, 9);
        assert_eq!(entities[0].confidence, Some(0.5));
    }

    #[test]
    fn test_parse_error_display() {
        assert_eq!(
            ParseError::NoJsonFound.to_string(),
            "No JSON array found in LLM response"
        );
        assert_eq!(
            ParseError::MissingField(2, "type").to_string(),
            "Entity 2 missing required field: type"
        );
    }

    #[test]
    fn test_render_system_prefix() {
        let schema = BIOSchema::new(&[EntityType::Person]);

        let default_prompt = CodeNERPrompt::new(schema.clone());
        assert!(default_prompt
            .render_system()
            .starts_with("You are an expert NER system."));

        let custom = CodeNERPrompt::new(schema).with_system_prefix("Custom prefix.");
        assert_eq!(
            custom.render_system(),
            "Custom prefix.\n\nRespond ONLY with valid JSON array of entities. No explanation."
        );
    }
}