Skip to main content

anno/backends/
llm_prompt.rs

1//! Code-based prompt generation for LLM NER.
2//!
3//! Implements CodeNER-style prompting (arXiv:2507.20423) that frames NER
4//! as a coding task, exploiting LLMs' superior code understanding.
5//!
6//! # Key Insight
7//!
8//! LLMs trained on code understand:
9//! - Structured scope boundaries (like entity spans)
10//! - Type annotations (like entity types)
11//! - Sequential processing (like BIO tagging)
12//!
13//! By embedding NER instructions as code, we get better results than
14//! natural language prompts.
15//!
16//! # Example
17//!
18//! ```rust
19//! use anno::backends::llm_prompt::{CodeNERPrompt, BIOSchema};
20//! use anno::EntityType;
21//!
22//! let schema = BIOSchema::new(&[
23//!     EntityType::Person,
24//!     EntityType::Organization,
25//!     EntityType::Location,
26//! ]);
27//!
28//! let prompt = CodeNERPrompt::new(schema)
29//!     .with_demonstrations(vec![
30//!         ("Steve Jobs founded Apple.", vec![
31//!             ("Steve Jobs", "PER", 0, 10),
32//!             ("Apple", "ORG", 19, 24),
33//!         ]),
34//!     ])
35//!     .with_chain_of_thought(true);
36//!
37//! let rendered = prompt.render("Lynn Conway worked at IBM.");
38//! // Send `rendered` to your LLM API
39//! ```
40//!
41//! # References
42//!
43//! - CodeNER: Code Prompting for Named Entity Recognition (arXiv:2507.20423)
44
45use anno_core::EntityType;
46use std::collections::HashMap;
47
/// Entity annotation for demonstrations: `(text, entity_type, start, end)`.
///
/// - `text`: the entity surface form as it appears in the demonstration text.
/// - `entity_type`: the type label as a string (e.g. `"PER"`, `"ORG"`).
/// - `start` / `end`: offsets of the entity within the demonstration text.
///   The module-level example uses an end-exclusive span (`"Steve Jobs"` is
///   `0..10`). NOTE(review): whether offsets are bytes or characters is not
///   established in this file — confirm against the consumer.
pub type DemoEntity<'a> = (&'a str, &'a str, usize, usize);
50
/// Full demonstration: `(text, list of entity annotations)`.
///
/// The first element is the example input sentence; the second is the list of
/// entities expected to be extracted from it, each a `(text, entity_type,
/// start, end)` tuple. An empty list is a valid demonstration of "no entities".
pub type DemoExample<'a> = (&'a str, Vec<DemoEntity<'a>>);
53
54/// BIO tagging schema for NER.
55///
56/// Defines the entity types and their descriptions for prompting.
57#[derive(Debug, Clone)]
58pub struct BIOSchema {
59    /// Entity types to extract
60    pub entity_types: Vec<EntityType>,
61    /// Human-readable descriptions for each type
62    pub descriptions: HashMap<EntityType, String>,
63}
64
65impl BIOSchema {
66    /// Create a new BIO schema with default descriptions.
67    #[must_use]
68    pub fn new(entity_types: &[EntityType]) -> Self {
69        let mut descriptions = HashMap::new();
70
71        for et in entity_types {
72            let desc = match et {
73                EntityType::Person => "Person names (individuals, fictional characters)",
74                EntityType::Organization => "Organizations (companies, institutions, groups)",
75                EntityType::Location => "Locations (cities, countries, addresses, landmarks)",
76                EntityType::Date => "Temporal expressions (dates, times, durations)",
77                EntityType::Time => "Time expressions (clock times, periods)",
78                EntityType::Money => "Monetary values (prices, amounts, currencies)",
79                EntityType::Percent => "Percentage values",
80                EntityType::Email => "Email addresses",
81                EntityType::Phone => "Phone numbers",
82                EntityType::Url => "Web URLs",
83                EntityType::Quantity => "Quantities (measurements, counts)",
84                EntityType::Cardinal => "Cardinal numbers",
85                EntityType::Ordinal => "Ordinal numbers (1st, 2nd, etc.)",
86                EntityType::Other(_) => "Miscellaneous named entities",
87                EntityType::Custom { name, .. } => name.as_str(),
88                // `EntityType` is non-exhaustive; keep prompts resilient to future variants.
89                _ => "Named entities",
90            };
91            descriptions.insert(et.clone(), desc.to_string());
92        }
93
94        Self {
95            entity_types: entity_types.to_vec(),
96            descriptions,
97        }
98    }
99
100    /// Set a custom description for an entity type.
101    #[must_use]
102    pub fn with_description(mut self, entity_type: EntityType, description: &str) -> Self {
103        self.descriptions
104            .insert(entity_type, description.to_string());
105        self
106    }
107
108    /// Render the schema as a code docstring.
109    fn render_docstring(&self) -> String {
110        let mut lines = vec![
111            "    \"\"\"".to_string(),
112            "    Extract named entities from text using BIO tagging.".to_string(),
113            "    ".to_string(),
114            "    BIO Schema:".to_string(),
115            "    - B-{TYPE}: Beginning of entity of TYPE".to_string(),
116            "    - I-{TYPE}: Inside (continuation) of entity".to_string(),
117            "    - O: Outside any entity".to_string(),
118            "    ".to_string(),
119            "    Entity Types:".to_string(),
120        ];
121
122        for et in &self.entity_types {
123            let label = et.as_label();
124            let desc = self.descriptions.get(et).map_or("", |s| s.as_str());
125            lines.push(format!("    - {}: {}", label, desc));
126        }
127
128        lines.push("    ".to_string());
129        lines.push(
130            "    Returns: List of entities with text, type, start, end positions.".to_string(),
131        );
132        lines.push("    \"\"\"".to_string());
133
134        lines.join("\n")
135    }
136}
137
/// Demonstration example for few-shot prompting.
///
/// Holds one input sentence together with the entities that should be
/// extracted from it, in owned form so the prompt can outlive its inputs.
#[derive(Debug, Clone)]
pub struct Demonstration {
    /// Input text
    pub text: String,
    /// Extracted entities: (text, type_label, start, end)
    pub entities: Vec<(String, String, usize, usize)>,
}

impl Demonstration {
    /// Create a new demonstration from borrowed data.
    ///
    /// `entities` are `(text, type_label, start, end)` tuples; every string is
    /// copied into an owned `String`.
    #[must_use]
    pub fn new(text: &str, entities: Vec<(&str, &str, usize, usize)>) -> Self {
        Self {
            text: text.to_string(),
            entities: entities
                .into_iter()
                .map(|(t, ty, s, e)| (t.to_string(), ty.to_string(), s, e))
                .collect(),
        }
    }

    /// Render the expected output for this demonstration as a JSON array.
    ///
    /// Returns `[]` when there are no entities; otherwise one object per
    /// entity on its own indented line. Entity text and type are escaped so
    /// the result stays valid JSON even when they contain quotes, backslashes
    /// or control characters — an invalid example would teach the model to
    /// emit invalid output.
    fn render_output(&self) -> String {
        if self.entities.is_empty() {
            return "[]".to_string();
        }

        let items: Vec<String> = self
            .entities
            .iter()
            .map(|(text, ty, start, end)| {
                format!(
                    r#"    {{"text": "{}", "type": "{}", "start": {}, "end": {}}}"#,
                    Self::escape_json(text),
                    Self::escape_json(ty),
                    start,
                    end
                )
            })
            .collect();

        format!("[\n{}\n]", items.join(",\n"))
    }

    /// Escape `raw` for use inside a JSON string literal (without the
    /// surrounding quotes).
    ///
    /// Handles `"` and `\`, the common whitespace escapes, and any remaining
    /// control character below U+0020 via a `\uXXXX` escape. All other
    /// characters are passed through unchanged.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if (c as u32) < 0x20 => {
                    escaped.push_str(&format!("\\u{:04x}", c as u32));
                }
                c => escaped.push(c),
            }
        }
        escaped
    }
}
180
181/// Code-based NER prompt generator.
182///
183/// Implements CodeNER-style prompting where NER is framed as a
184/// coding task with BIO schema instructions.
185#[derive(Debug, Clone)]
186pub struct CodeNERPrompt {
187    /// BIO schema definition
188    schema: BIOSchema,
189    /// Few-shot demonstrations
190    demonstrations: Vec<Demonstration>,
191    /// Enable chain-of-thought reasoning
192    use_cot: bool,
193    /// System message prefix
194    system_prefix: Option<String>,
195}
196
197impl CodeNERPrompt {
198    /// Create a new code NER prompt with the given schema.
199    #[must_use]
200    pub fn new(schema: BIOSchema) -> Self {
201        Self {
202            schema,
203            demonstrations: vec![],
204            use_cot: false,
205            system_prefix: None,
206        }
207    }
208
209    /// Add few-shot demonstrations.
210    #[must_use]
211    pub fn with_demonstrations(mut self, demos: Vec<DemoExample<'_>>) -> Self {
212        self.demonstrations = demos
213            .into_iter()
214            .map(|(text, entities)| Demonstration::new(text, entities))
215            .collect();
216        self
217    }
218
219    /// Enable chain-of-thought reasoning.
220    #[must_use]
221    pub fn with_chain_of_thought(mut self, enabled: bool) -> Self {
222        self.use_cot = enabled;
223        self
224    }
225
226    /// Set a custom system message prefix.
227    #[must_use]
228    pub fn with_system_prefix(mut self, prefix: &str) -> Self {
229        self.system_prefix = Some(prefix.to_string());
230        self
231    }
232
233    /// Render the system message.
234    #[must_use]
235    pub fn render_system(&self) -> String {
236        let prefix = self.system_prefix.as_deref().unwrap_or(
237            "You are an expert NER system. Extract entities precisely using BIO tagging.",
238        );
239
240        format!(
241            "{}\n\nRespond ONLY with valid JSON array of entities. No explanation.",
242            prefix
243        )
244    }
245
246    /// Render the user prompt for the given input text.
247    #[must_use]
248    pub fn render(&self, input_text: &str) -> String {
249        // Function signature with schema
250        let mut parts = vec![
251            "```python".to_string(),
252            "def extract_entities(text: str) -> list[dict]:".to_string(),
253            self.schema.render_docstring(),
254            "    pass".to_string(),
255            "```".to_string(),
256            String::new(),
257        ];
258
259        // Demonstrations
260        if !self.demonstrations.is_empty() {
261            parts.push("# Examples:".to_string());
262            for (i, demo) in self.demonstrations.iter().enumerate() {
263                parts.push(format!("\n## Example {}:", i + 1));
264                parts.push(format!("Input: \"{}\"", demo.text));
265                parts.push(format!("Output: {}", demo.render_output()));
266            }
267            parts.push("".to_string());
268        }
269
270        // Chain-of-thought instruction
271        if self.use_cot {
272            parts.push("# Instructions:".to_string());
273            parts.push("1. First, identify potential entity spans in the text".to_string());
274            parts.push("2. For each span, determine the most appropriate entity type".to_string());
275            parts.push("3. Verify the start and end positions are correct".to_string());
276            parts.push("4. Return the final JSON array".to_string());
277            parts.push("".to_string());
278        }
279
280        // Input
281        parts.push("# Task:".to_string());
282        parts.push(format!("Input: \"{}\"", input_text));
283        parts.push("Output:".to_string());
284
285        parts.join("\n")
286    }
287
288    /// Get the expected JSON output format description.
289    #[must_use]
290    pub fn output_format(&self) -> &'static str {
291        r#"[{"text": "entity_text", "type": "TYPE", "start": 0, "end": 10}, ...]"#
292    }
293}
294
295/// Parse LLM response into entities.
296///
297/// Attempts to extract a JSON array of entities from the LLM output,
298/// handling common formatting issues.
299pub fn parse_llm_response(response: &str) -> Result<Vec<ParsedEntity>, ParseError> {
300    // Try to find JSON array in response
301    let json_str = extract_json_array(response)?;
302
303    // Parse JSON
304    let parsed: Vec<serde_json::Value> =
305        serde_json::from_str(&json_str).map_err(|e| ParseError::InvalidJson(e.to_string()))?;
306
307    // Convert to entities
308    let mut entities = Vec::new();
309    for (i, item) in parsed.iter().enumerate() {
310        let text = item
311            .get("text")
312            .and_then(|v| v.as_str())
313            .ok_or(ParseError::MissingField(i, "text"))?
314            .to_string();
315
316        let entity_type = item
317            .get("type")
318            .and_then(|v| v.as_str())
319            .ok_or(ParseError::MissingField(i, "type"))?
320            .to_string();
321
322        let start = item
323            .get("start")
324            .and_then(|v| v.as_u64())
325            .ok_or(ParseError::MissingField(i, "start"))? as usize;
326
327        let end = item
328            .get("end")
329            .and_then(|v| v.as_u64())
330            .ok_or(ParseError::MissingField(i, "end"))? as usize;
331
332        let confidence = item.get("confidence").and_then(|v| v.as_f64());
333
334        entities.push(ParsedEntity {
335            text,
336            entity_type,
337            start,
338            end,
339            confidence,
340        });
341    }
342
343    Ok(entities)
344}
345
346/// Extract JSON array from potentially messy LLM output.
347fn extract_json_array(text: &str) -> Result<String, ParseError> {
348    // Try direct parse first
349    if let (Some(start), Some(end)) = (text.find('['), text.rfind(']')) {
350        if end > start {
351            return Ok(text[start..=end].to_string());
352        }
353    }
354
355    // Look for ```json block
356    if let Some(start) = text.find("```json") {
357        let start = start + 7;
358        if let Some(end) = text[start..].find("```") {
359            let json = text[start..start + end].trim();
360            if json.starts_with('[') {
361                return Ok(json.to_string());
362            }
363        }
364    }
365
366    // Look for any [ ] pair
367    if let Some(start) = text.find('[') {
368        if let Some(end) = text.rfind(']') {
369            if end > start {
370                return Ok(text[start..=end].to_string());
371            }
372        }
373    }
374
375    Err(ParseError::NoJsonFound)
376}
377
378/// Parsed entity from LLM response.
379#[derive(Debug, Clone)]
380pub struct ParsedEntity {
381    /// Entity text
382    pub text: String,
383    /// Entity type label
384    pub entity_type: String,
385    /// Start position in input
386    pub start: usize,
387    /// End position in input
388    pub end: usize,
389    /// Optional confidence score
390    pub confidence: Option<f64>,
391}
392
393impl ParsedEntity {
394    /// Convert to `Entity` with the given entity type mapping.
395    pub fn to_entity(&self, type_map: &HashMap<String, EntityType>) -> Option<anno_core::Entity> {
396        let entity_type = type_map
397            .get(&self.entity_type)
398            .or_else(|| type_map.get(&self.entity_type.to_uppercase()))
399            .cloned()?;
400
401        Some(anno_core::Entity::new(
402            &self.text,
403            entity_type,
404            self.start,
405            self.end,
406            self.confidence.unwrap_or(0.8),
407        ))
408    }
409}
410
/// Ways in which turning an LLM response into entities can fail.
#[derive(Debug)]
pub enum ParseError {
    /// The response contains nothing that looks like a JSON array.
    NoJsonFound,
    /// The candidate text is not valid JSON; carries the parser's message.
    InvalidJson(String),
    /// A required field is missing: (index of the entity, field name).
    MissingField(usize, &'static str),
}

impl std::fmt::Display for ParseError {
    /// Write the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoJsonFound => f.write_str("No JSON array found in LLM response"),
            Self::InvalidJson(detail) => write!(f, "Invalid JSON: {}", detail),
            Self::MissingField(index, name) => write!(
                f,
                "Entity {} missing required field: {}",
                index, name
            ),
        }
    }
}

impl std::error::Error for ParseError {}
435
436// =============================================================================
437// Tests
438// =============================================================================
439
/// Unit tests for schema construction, prompt rendering and response parsing.
#[cfg(test)]
mod tests {
    use super::*;

    /// A schema built from two types keeps both and has a description for them.
    #[test]
    fn test_bio_schema_creation() {
        let schema = BIOSchema::new(&[EntityType::Person, EntityType::Organization]);

        assert_eq!(schema.entity_types.len(), 2);
        assert!(schema.descriptions.contains_key(&EntityType::Person));
    }

    /// The basic prompt contains the function stub, the BIO section, the type
    /// labels and the input text.
    #[test]
    fn test_prompt_rendering() {
        let schema = BIOSchema::new(&[EntityType::Person, EntityType::Location]);
        let prompt = CodeNERPrompt::new(schema);

        let rendered = prompt.render("John went to Paris.");

        assert!(rendered.contains("extract_entities"));
        assert!(rendered.contains("BIO Schema"));
        assert!(rendered.contains("PER"));
        assert!(rendered.contains("LOC"));
        assert!(rendered.contains("John went to Paris"));
    }

    /// Demonstrations show up as numbered examples in the rendered prompt.
    #[test]
    fn test_prompt_with_demonstrations() {
        let schema = BIOSchema::new(&[EntityType::Person]);
        let prompt = CodeNERPrompt::new(schema).with_demonstrations(vec![(
            "Steve Jobs worked at Apple.",
            vec![("Steve Jobs", "PER", 0, 10)],
        )]);

        let rendered = prompt.render("Test input.");

        assert!(rendered.contains("Example 1"));
        assert!(rendered.contains("Steve Jobs"));
    }

    /// A response that is nothing but a JSON array parses directly.
    #[test]
    fn test_parse_clean_json() {
        let response = r#"[{"text": "John", "type": "PER", "start": 0, "end": 4}]"#;
        let entities = parse_llm_response(response).unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].text, "John");
        assert_eq!(entities[0].entity_type, "PER");
    }

    /// A JSON array wrapped in prose and a Markdown code fence is still found.
    #[test]
    fn test_parse_json_with_markdown() {
        let response = r#"
Here are the entities:

```json
[{"text": "Paris", "type": "LOC", "start": 10, "end": 15}]
```

That's all!
"#;
        let entities = parse_llm_response(response).unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].text, "Paris");
    }

    /// An empty array is a valid answer meaning "no entities".
    #[test]
    fn test_parse_empty_response() {
        let response = "[]";
        let entities = parse_llm_response(response).unwrap();

        assert!(entities.is_empty());
    }

    /// A response without any brackets is reported as `NoJsonFound`.
    #[test]
    fn test_parse_no_json() {
        let response = "I couldn't find any entities.";
        let result = parse_llm_response(response);

        assert!(matches!(result, Err(ParseError::NoJsonFound)));
    }

    /// Enabling chain-of-thought adds the instruction section to the prompt.
    #[test]
    fn test_chain_of_thought() {
        let schema = BIOSchema::new(&[EntityType::Person]);
        let prompt = CodeNERPrompt::new(schema).with_chain_of_thought(true);

        let rendered = prompt.render("Test.");

        assert!(rendered.contains("Instructions"));
        assert!(rendered.contains("identify potential entity spans"));
    }
}
533}