Skip to main content

anno/backends/
demonstration.rs

1//! Demonstration selection for few-shot NER.
2//!
3//! Implements CMAS-inspired demonstration selection (arXiv:2502.18702) that
4//! selects helpful demonstrations based on similarity and quality.
5//!
6//! # Key Concepts
7//!
8//! From the CMAS paper:
9//! 1. **Self-annotation**: Initial entity labeling
10//! 2. **TRF (Type-Related Features)**: Context features around entities
11//! 3. **Demonstration discriminator**: Evaluates helpfulness
12//! 4. **Overall predictor**: Final ensemble
13//!
14//! This module provides the demonstration discriminator component.
15//!
16//! # Example
17//!
18//! ```rust
19//! use anno::backends::demonstration::{DemonstrationBank, HelpfulnessConfig};
20//!
21//! let mut bank = DemonstrationBank::new();
22//!
23//! // Add demonstrations
24//! bank.add("Steve Jobs founded Apple in 1976.", vec![
25//!     ("Steve Jobs", "PER", 0, 10),
26//!     ("Apple", "ORG", 19, 24),
27//! ]);
28//!
29//! bank.add("Microsoft was founded by Bill Gates.", vec![
30//!     ("Microsoft", "ORG", 0, 9),
//!     ("Bill Gates", "PER", 25, 35),
32//! ]);
33//!
34//! // Select helpful demonstrations for a query
35//! let demos = bank.select("Lynn Conway worked at IBM and Xerox PARC.", 2);
36//! assert_eq!(demos.len(), 2);
37//! ```
38
39use std::collections::HashMap;
40
/// Borrowed entity annotation: `(surface text, entity type, start offset, end offset)`.
///
/// The examples in this module use an exclusive `end`, e.g. `("Apple", "ORG", 19, 24)`.
pub type EntityAnnotation<'src> = (&'src str, &'src str, usize, usize);
43
44/// Batch of demonstrations: (text, entities) pairs.
45pub type DemoBatch<'a> = Vec<(&'a str, Vec<EntityAnnotation<'a>>)>;
46
/// Weights and threshold that control how demonstrations are ranked.
///
/// The final helpfulness score is the weighted sum of three component similarities;
/// demonstrations whose score falls below `min_score` are not returned.
#[derive(Clone, Debug)]
pub struct HelpfulnessConfig {
    /// Multiplier for the token (Jaccard) similarity component.
    pub similarity_weight: f64,
    /// Multiplier for the entity-type overlap component.
    pub type_overlap_weight: f64,
    /// Multiplier for the entity-density similarity component.
    pub density_weight: f64,
    /// Inclusion threshold: scores strictly below this are discarded.
    pub min_score: f64,
}
59
60impl Default for HelpfulnessConfig {
61    fn default() -> Self {
62        Self {
63            similarity_weight: 0.4,
64            type_overlap_weight: 0.4,
65            density_weight: 0.2,
66            min_score: 0.1,
67        }
68    }
69}
70
71/// A single demonstration example.
72#[derive(Debug, Clone)]
73pub struct DemonstrationExample {
74    /// Input text
75    pub text: String,
76    /// Annotated entities: (text, type, start, end)
77    pub entities: Vec<(String, String, usize, usize)>,
78    /// Precomputed features
79    features: ExampleFeatures,
80}
81
/// Features computed once per text so that scoring does not re-tokenize.
#[derive(Clone, Debug, Default)]
struct ExampleFeatures {
    /// Lowercased whitespace-separated tokens, in text order (duplicates are kept).
    tokens: Vec<String>,
    /// Entity type of every annotation, in annotation order (duplicates are kept).
    entity_types: Vec<String>,
    /// Number of annotations per 100 tokens (0.0 for text without tokens).
    entity_density: f64,
}
92
93impl DemonstrationExample {
94    /// Create a new demonstration example.
95    pub fn new(text: &str, entities: Vec<(&str, &str, usize, usize)>) -> Self {
96        let entities: Vec<_> = entities
97            .into_iter()
98            .map(|(t, ty, s, e)| (t.to_string(), ty.to_string(), s, e))
99            .collect();
100
101        let features = Self::compute_features(text, &entities);
102
103        Self {
104            text: text.to_string(),
105            entities,
106            features,
107        }
108    }
109
110    fn compute_features(
111        text: &str,
112        entities: &[(String, String, usize, usize)],
113    ) -> ExampleFeatures {
114        let tokens: Vec<String> = text.split_whitespace().map(|w| w.to_lowercase()).collect();
115
116        let entity_types: Vec<String> = entities.iter().map(|(_, ty, _, _)| ty.clone()).collect();
117
118        let entity_density = if tokens.is_empty() {
119            0.0
120        } else {
121            (entities.len() as f64 / tokens.len() as f64) * 100.0
122        };
123
124        ExampleFeatures {
125            tokens,
126            entity_types,
127            entity_density,
128        }
129    }
130}
131
132/// Bank of demonstrations for few-shot NER.
133///
134/// Stores demonstrations and selects the most helpful ones for a query.
135#[derive(Debug, Clone, Default)]
136pub struct DemonstrationBank {
137    examples: Vec<DemonstrationExample>,
138    config: HelpfulnessConfig,
139}
140
141impl DemonstrationBank {
142    /// Create a new empty demonstration bank.
143    #[must_use]
144    pub fn new() -> Self {
145        Self::default()
146    }
147
148    /// Create with custom helpfulness config.
149    #[must_use]
150    pub fn with_config(config: HelpfulnessConfig) -> Self {
151        Self {
152            examples: vec![],
153            config,
154        }
155    }
156
157    /// Add a demonstration to the bank.
158    pub fn add(&mut self, text: &str, entities: Vec<(&str, &str, usize, usize)>) {
159        self.examples
160            .push(DemonstrationExample::new(text, entities));
161    }
162
163    /// Add multiple demonstrations at once.
164    pub fn add_all(&mut self, demos: DemoBatch<'_>) {
165        for (text, entities) in demos {
166            self.add(text, entities);
167        }
168    }
169
170    /// Number of demonstrations in the bank.
171    #[must_use]
172    pub fn len(&self) -> usize {
173        self.examples.len()
174    }
175
176    /// Check if the bank is empty.
177    #[must_use]
178    pub fn is_empty(&self) -> bool {
179        self.examples.is_empty()
180    }
181
182    /// Select the most helpful demonstrations for a query.
183    ///
184    /// # Arguments
185    ///
186    /// * `query` - The input text to find demonstrations for
187    /// * `k` - Maximum number of demonstrations to return
188    ///
189    /// # Returns
190    ///
191    /// Up to `k` demonstrations, sorted by helpfulness score (descending).
192    #[must_use]
193    pub fn select(&self, query: &str, k: usize) -> Vec<&DemonstrationExample> {
194        if self.examples.is_empty() || k == 0 {
195            return vec![];
196        }
197
198        let query_features = DemonstrationExample::compute_features(query, &[]);
199
200        // Performance: Pre-allocate scored vec with estimated capacity
201        // Score all demonstrations
202        let mut scored: Vec<_> = Vec::with_capacity(self.examples.len().min(k * 2));
203        scored.extend(
204            self.examples
205                .iter()
206                .map(|ex| {
207                    let score = self.helpfulness_score(&query_features, ex);
208                    (ex, score)
209                })
210                .filter(|(_, score)| *score >= self.config.min_score),
211        );
212
213        // Performance: Use unstable sort (we don't need stable sort here)
214        // Sort by score descending
215        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
216
217        // Take top k
218        scored.into_iter().take(k).map(|(ex, _)| ex).collect()
219    }
220
221    /// Select demonstrations with their helpfulness scores.
222    #[must_use]
223    pub fn select_with_scores(&self, query: &str, k: usize) -> Vec<(&DemonstrationExample, f64)> {
224        if self.examples.is_empty() || k == 0 {
225            return vec![];
226        }
227
228        let query_features = DemonstrationExample::compute_features(query, &[]);
229
230        // Performance: Pre-allocate scored vec with estimated capacity
231        let mut scored: Vec<_> = Vec::with_capacity(self.examples.len().min(k * 2));
232        scored.extend(
233            self.examples
234                .iter()
235                .map(|ex| {
236                    let score = self.helpfulness_score(&query_features, ex);
237                    (ex, score)
238                })
239                .filter(|(_, score)| *score >= self.config.min_score),
240        );
241
242        // Performance: Use unstable sort (we don't need stable sort here)
243        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244
245        scored.into_iter().take(k).collect()
246    }
247
248    /// Compute helpfulness score for a demonstration.
249    ///
250    /// Based on CMAS demonstration discriminator, combines:
251    /// 1. Text similarity (token overlap)
252    /// 2. Entity type overlap
253    /// 3. Entity density similarity
254    fn helpfulness_score(&self, query: &ExampleFeatures, demo: &DemonstrationExample) -> f64 {
255        let sim = self.token_similarity(&query.tokens, &demo.features.tokens);
256        let type_overlap = self.type_overlap(&query.entity_types, &demo.features.entity_types);
257        let density_sim =
258            self.density_similarity(query.entity_density, demo.features.entity_density);
259
260        self.config.similarity_weight * sim
261            + self.config.type_overlap_weight * type_overlap
262            + self.config.density_weight * density_sim
263    }
264
265    /// Jaccard similarity between token sets.
266    fn token_similarity(&self, a: &[String], b: &[String]) -> f64 {
267        if a.is_empty() && b.is_empty() {
268            return 1.0;
269        }
270        if a.is_empty() || b.is_empty() {
271            return 0.0;
272        }
273
274        let set_a: std::collections::HashSet<_> = a.iter().collect();
275        let set_b: std::collections::HashSet<_> = b.iter().collect();
276
277        let intersection = set_a.intersection(&set_b).count();
278        let union = set_a.union(&set_b).count();
279
280        if union == 0 {
281            0.0
282        } else {
283            intersection as f64 / union as f64
284        }
285    }
286
287    /// Overlap ratio for entity types.
288    fn type_overlap(&self, query_types: &[String], demo_types: &[String]) -> f64 {
289        // For queries without known types, all demonstrations are equally good
290        if query_types.is_empty() {
291            return 1.0;
292        }
293        if demo_types.is_empty() {
294            return 0.0;
295        }
296
297        let query_set: std::collections::HashSet<_> = query_types.iter().collect();
298        let demo_set: std::collections::HashSet<_> = demo_types.iter().collect();
299
300        let overlap = query_set.intersection(&demo_set).count();
301        overlap as f64 / query_set.len() as f64
302    }
303
304    /// Similarity based on entity density.
305    fn density_similarity(&self, query_density: f64, demo_density: f64) -> f64 {
306        // Exponential decay based on density difference
307        let diff = (query_density - demo_density).abs();
308        (-diff / 5.0).exp() // Scale factor of 5 entities per 100 tokens
309    }
310}
311
/// Type-Related Feature (TRF) extractor.
///
/// Extracts context features around entity mentions, as described in CMAS.
#[derive(Debug, Clone)]
pub struct TRFExtractor {
    /// Number of whitespace tokens collected on each side of an entity mention.
    window_size: usize,
}

impl Default for TRFExtractor {
    /// Same as [`TRFExtractor::new`]: a window of 3 tokens on each side.
    fn default() -> Self {
        Self::new()
    }
}

impl TRFExtractor {
    /// Window size used by [`TRFExtractor::new`] and `Default`.
    const DEFAULT_WINDOW_SIZE: usize = 3;

    /// Create a new TRF extractor with default window size (3 tokens per side).
    #[must_use]
    pub fn new() -> Self {
        Self::with_window(Self::DEFAULT_WINDOW_SIZE)
    }

    /// Create with custom window size (tokens collected on each side of an entity).
    #[must_use]
    pub fn with_window(size: usize) -> Self {
        Self { window_size: size }
    }

    /// Extract type-related features from text.
    ///
    /// For each entity `(text, type, start, end)`:
    /// * Under key `type`, appends the lowercased tokens of up to `window_size`
    ///   whitespace tokens before and after the mention (in text order). The tokens
    ///   covered by the mention itself are excluded. This happens only if `start` falls
    ///   inside a token.
    /// * Under key `"{type}_text"`, appends the lowercased entity text.
    ///
    /// Tokens are split on whitespace as in `str::split_whitespace`. `start`/`end` are
    /// matched against the real byte offsets of those tokens, so runs of spaces, tabs or
    /// newlines are handled.
    #[must_use]
    pub fn extract(
        &self,
        text: &str,
        entities: &[(String, String, usize, usize)],
    ) -> HashMap<String, Vec<String>> {
        let mut features: HashMap<String, Vec<String>> = HashMap::new();
        let spans = Self::token_spans(text);

        for (entity_text, entity_type, start, end) in entities {
            if let Some((first, last)) = Self::entity_token_range(&spans, *start, *end) {
                let before = first.saturating_sub(self.window_size)..first;
                let after_start = last + 1;
                // Saturating so huge windows (e.g. `usize::MAX`) cannot overflow.
                let after_end = after_start
                    .saturating_add(self.window_size)
                    .min(spans.len());

                let context = spans[before]
                    .iter()
                    .chain(&spans[after_start..after_end])
                    .map(|&(s, e)| text[s..e].to_lowercase());

                features
                    .entry(entity_type.clone())
                    .or_default()
                    .extend(context);
            }

            // Also add the entity text as a feature (useful for learning patterns)
            features
                .entry(format!("{}_text", entity_type))
                .or_default()
                .push(entity_text.to_lowercase());
        }

        features
    }

    /// Byte ranges `(start, end)` of the whitespace-separated tokens of `text`.
    ///
    /// Uses the same whitespace definition as `str::split_whitespace`
    /// (`char::is_whitespace`) but records real offsets.
    fn token_spans(text: &str) -> Vec<(usize, usize)> {
        let mut spans = Vec::new();
        let mut token_start: Option<usize> = None;

        for (i, c) in text.char_indices() {
            match (c.is_whitespace(), token_start) {
                (true, Some(s)) => {
                    spans.push((s, i));
                    token_start = None;
                }
                (false, None) => token_start = Some(i),
                _ => {}
            }
        }
        if let Some(s) = token_start {
            spans.push((s, text.len()));
        }

        spans
    }

    /// Indices of the first and last tokens covered by the mention `[start, end)`.
    ///
    /// The first token is the one containing `start`; `None` if `start` lies on
    /// whitespace or outside the text. Every following token that begins before `end`
    /// extends the range. An empty or inverted span still covers the first token.
    fn entity_token_range(
        spans: &[(usize, usize)],
        start: usize,
        end: usize,
    ) -> Option<(usize, usize)> {
        let first = spans.iter().position(|&(s, e)| s <= start && start < e)?;
        let covered = spans[first..]
            .iter()
            .take_while(|&&(s, _)| s < end)
            .count()
            .max(1);
        Some((first, first + covered - 1))
    }
}
386
387// =============================================================================
388// Tests
389// =============================================================================
390
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_demonstration_example_creation() {
        let demo = DemonstrationExample::new(
            "Steve Jobs founded Apple.",
            vec![("Steve Jobs", "PER", 0, 10), ("Apple", "ORG", 19, 24)],
        );

        assert_eq!(demo.entities.len(), 2);
        assert!(demo.features.entity_types.contains(&"PER".to_string()));
        assert!(demo.features.entity_types.contains(&"ORG".to_string()));
    }

    #[test]
    fn test_bank_add_and_len() {
        let mut bank = DemonstrationBank::new();
        assert!(bank.is_empty());

        bank.add("Test text.", vec![("Test", "MISC", 0, 4)]);
        assert_eq!(bank.len(), 1);
    }

    #[test]
    fn test_add_all() {
        let mut bank = DemonstrationBank::new();
        bank.add_all(vec![
            ("Apple is in Cupertino.", vec![("Apple", "ORG", 0, 5)]),
            ("Google is in Mountain View.", vec![("Google", "ORG", 0, 6)]),
        ]);

        assert_eq!(bank.len(), 2);
        assert!(!bank.is_empty());
    }

    #[test]
    fn test_select_demonstrations() {
        let mut bank = DemonstrationBank::new();

        bank.add(
            "Steve Jobs founded Apple in California.",
            vec![
                ("Steve Jobs", "PER", 0, 10),
                ("Apple", "ORG", 19, 24),
                ("California", "LOC", 28, 38),
            ],
        );

        bank.add(
            "The weather in New York is nice today.",
            vec![("New York", "LOC", 15, 23)],
        );

        bank.add(
            "Bill Gates started Microsoft in Seattle.",
            vec![
                ("Bill Gates", "PER", 0, 10),
                ("Microsoft", "ORG", 19, 28),
                ("Seattle", "LOC", 32, 39),
            ],
        );

        // Query about companies (same domain as demos)
        let demos = bank.select("Steve Jobs founded Apple in Silicon Valley.", 3);

        // Should return all 3 demos
        assert_eq!(demos.len(), 3);

        // All demos should be returned - verify we have all three
        let demo_texts: Vec<_> = demos.iter().map(|d| d.text.as_str()).collect();
        assert!(demo_texts.contains(&"Steve Jobs founded Apple in California."));
        assert!(demo_texts.contains(&"Bill Gates started Microsoft in Seattle."));
        assert!(demo_texts.contains(&"The weather in New York is nice today."));
    }

    #[test]
    fn test_select_with_scores() {
        let mut bank = DemonstrationBank::new();

        bank.add("Apple is in Cupertino.", vec![("Apple", "ORG", 0, 5)]);
        bank.add("Google is in Mountain View.", vec![("Google", "ORG", 0, 6)]);

        let demos = bank.select_with_scores("Microsoft is in Redmond.", 2);

        assert_eq!(demos.len(), 2);
        // Both should have positive scores
        for (_, score) in &demos {
            assert!(*score > 0.0);
        }
        // Results are ordered by descending score
        for pair in demos.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    #[test]
    fn test_select_zero_k() {
        let mut bank = DemonstrationBank::new();
        bank.add("Apple is in Cupertino.", vec![("Apple", "ORG", 0, 5)]);

        assert!(bank.select("Apple is in Cupertino.", 0).is_empty());
        assert!(bank.select_with_scores("Apple is in Cupertino.", 0).is_empty());
    }

    #[test]
    fn test_min_score_filters_everything_when_unreachable() {
        // The three default weights sum to 1.0, so no score can reach 10.0.
        let config = HelpfulnessConfig {
            min_score: 10.0,
            ..HelpfulnessConfig::default()
        };
        let mut bank = DemonstrationBank::with_config(config);
        bank.add("Apple is in Cupertino.", vec![("Apple", "ORG", 0, 5)]);

        assert!(bank.select("Apple is in Cupertino.", 5).is_empty());
    }

    #[test]
    fn test_select_empty_bank() {
        let bank = DemonstrationBank::new();
        let demos = bank.select("Test query.", 5);
        assert!(demos.is_empty());
    }

    #[test]
    fn test_trf_extractor() {
        let extractor = TRFExtractor::new();

        let features = extractor.extract(
            "The CEO Steve Jobs announced the new iPhone.",
            &[("Steve Jobs".to_string(), "PER".to_string(), 8, 18)],
        );

        assert!(features.contains_key("PER"));
        let per_context = features.get("PER").unwrap();
        // Should contain context words around "Steve Jobs"
        assert!(per_context.iter().any(|w| w == "ceo" || w == "announced"));
    }

    #[test]
    fn test_helpfulness_config() {
        let config = HelpfulnessConfig {
            similarity_weight: 0.5,
            type_overlap_weight: 0.3,
            density_weight: 0.2,
            min_score: 0.2,
        };

        let bank = DemonstrationBank::with_config(config);
        // The custom config must be stored as given, and the bank starts empty.
        assert!(bank.is_empty());
        assert_eq!(bank.config.similarity_weight, 0.5);
        assert_eq!(bank.config.type_overlap_weight, 0.3);
        assert_eq!(bank.config.density_weight, 0.2);
        assert_eq!(bank.config.min_score, 0.2);
    }
}