Skip to main content

oxirs_graphrag/
entity_classifier.rs

1//! Entity type classification for knowledge graph nodes.
2//!
3//! Uses rule-based heuristics (pattern matching, suffix detection, numeric checks)
4//! plus user-defined classification rules to assign entity types and confidence scores.
5
/// Categories the classifier can assign to an entity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EntityType {
    Person,
    Organization,
    Location,
    Event,
    Product,
    Concept,
    Date,
    Number,
    /// Fallback category for text that no heuristic or rule recognised.
    Unknown,
}

impl EntityType {
    /// Human-readable name of the variant.
    ///
    /// The returned string is spelled exactly like the variant identifier.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Person => "Person",
            Self::Organization => "Organization",
            Self::Location => "Location",
            Self::Event => "Event",
            Self::Product => "Product",
            Self::Concept => "Concept",
            Self::Date => "Date",
            Self::Number => "Number",
            Self::Unknown => "Unknown",
        }
    }
}
36
/// A single named feature with a numeric value, used to explain classification.
///
/// The classifier records one feature for every built-in heuristic or
/// user-defined rule that fired on the input text.
#[derive(Debug, Clone)]
pub struct ClassificationFeature {
    /// Identifier of the signal that fired, e.g. `is_numeric`,
    /// `location_suffix:<suffix>` or `rule_match:<pattern>`.
    pub name: String,
    /// Strength of the signal: `1.0` for the built-in heuristics, the rule's
    /// `confidence_boost` for user-defined rules.
    pub value: f64,
}
43
/// The result of classifying a single entity text string.
#[derive(Debug, Clone)]
pub struct ClassificationResult {
    /// The text that was classified, copied unchanged from the input.
    pub entity_text: String,
    /// The winning entity type; `EntityType::Unknown` when nothing matched.
    pub predicted_type: EntityType,
    /// Confidence of the winning candidate, clamped to `[0.0, 1.0]`.
    pub confidence: f64,
    /// Every signal that fired during classification, including those that
    /// supported a type other than the winning one.
    pub features: Vec<ClassificationFeature>,
}
52
/// A user-defined rule: if `pattern` is found in the entity text (case-insensitive),
/// score `entity_type` with an additional `confidence_boost`.
#[derive(Debug, Clone)]
pub struct ClassificationRule {
    /// Substring to search for; both the pattern and the entity text are
    /// lowercased before comparison.
    pub pattern: String,
    /// Entity type proposed when the pattern is found.
    pub entity_type: EntityType,
    /// Amount added to the base confidence for this rule's candidate; the sum
    /// is clamped to `[0.0, 1.0]`.
    pub confidence_boost: f64,
}
61
/// Month name constants used for Date detection.
///
/// All entries are lowercase because they are compared against lowercased
/// text. The twelve full names come first, followed by three-letter
/// abbreviations; "may" appears only once because its abbreviation is the
/// full name.
const MONTH_NAMES: &[&str] = &[
    "january",
    "february",
    "march",
    "april",
    "may",
    "june",
    "july",
    "august",
    "september",
    "october",
    "november",
    "december",
    "jan",
    "feb",
    "mar",
    "apr",
    "jun",
    "jul",
    "aug",
    "sep",
    "oct",
    "nov",
    "dec",
];
88
/// Common location suffixes.
///
/// Lowercase words that mark an entity text as a location when the
/// lowercased text ends with one of them.
const LOCATION_SUFFIXES: &[&str] = &[
    "city", "river", "mountain", "street", "avenue", "lake", "island", "valley",
];
93
/// Common organisation suffixes.
///
/// Lowercase and without trailing periods: they are compared against the
/// final whitespace-separated word of the lowercased text after its trailing
/// `.` characters have been stripped.
const ORG_SUFFIXES: &[&str] = &["inc", "corp", "ltd", "gmbh", "llc", "plc", "ag", "bv", "sa"];
96
/// Base confidence for any match — additional boosts are applied on top.
///
/// This is also the confidence reported alongside `EntityType::Unknown` when
/// no heuristic or rule produced a candidate.
const BASE_CONFIDENCE: f64 = 0.5;
99
100/// Entity classifier using heuristic rules and optional user-defined rules.
101pub struct EntityClassifier {
102    rules: Vec<ClassificationRule>,
103}
104
105impl EntityClassifier {
106    /// Create a new classifier with no user rules.
107    pub fn new() -> Self {
108        Self { rules: Vec::new() }
109    }
110
111    /// Add a user-defined classification rule.
112    pub fn add_rule(&mut self, rule: ClassificationRule) {
113        self.rules.push(rule);
114    }
115
116    /// Return the number of user-defined rules.
117    pub fn rule_count(&self) -> usize {
118        self.rules.len()
119    }
120
121    /// Classify a single entity text string.
122    pub fn classify(&self, text: &str) -> ClassificationResult {
123        let lower = text.to_lowercase();
124        let mut features: Vec<ClassificationFeature> = Vec::new();
125
126        // Collect (EntityType, confidence) candidates from all heuristics
127        let mut candidates: Vec<(EntityType, f64)> = Vec::new();
128
129        // 1. Pure numeric check → Number
130        if text
131            .trim()
132            .chars()
133            .all(|c| c.is_ascii_digit() || c == '.' || c == '-')
134            && !text.trim().is_empty()
135        {
136            features.push(ClassificationFeature {
137                name: "is_numeric".to_string(),
138                value: 1.0,
139            });
140            candidates.push((EntityType::Number, BASE_CONFIDENCE + 0.4));
141        }
142
143        // 2. Date detection: contains digits AND a month name
144        let has_digits = text.chars().any(|c| c.is_ascii_digit());
145        let has_month = MONTH_NAMES.iter().any(|&m| lower.contains(m));
146        if has_digits && has_month {
147            features.push(ClassificationFeature {
148                name: "has_month_name".to_string(),
149                value: 1.0,
150            });
151            candidates.push((EntityType::Date, BASE_CONFIDENCE + 0.35));
152        }
153
154        // 3. Location suffix
155        if let Some(suffix) = LOCATION_SUFFIXES.iter().find(|&&s| lower.ends_with(s)) {
156            features.push(ClassificationFeature {
157                name: format!("location_suffix:{suffix}"),
158                value: 1.0,
159            });
160            candidates.push((EntityType::Location, BASE_CONFIDENCE + 0.3));
161        }
162
163        // 4. Organisation suffix (check last word)
164        let last_word_lower = lower
165            .split_whitespace()
166            .last()
167            .unwrap_or("")
168            .trim_end_matches('.');
169        if ORG_SUFFIXES.contains(&last_word_lower) {
170            features.push(ClassificationFeature {
171                name: format!("org_suffix:{last_word_lower}"),
172                value: 1.0,
173            });
174            candidates.push((EntityType::Organization, BASE_CONFIDENCE + 0.35));
175        }
176
177        // 5. Starts with uppercase + no spaces + ≤ 20 chars → possible Person/Org
178        let starts_upper = text
179            .chars()
180            .next()
181            .map(|c| c.is_uppercase())
182            .unwrap_or(false);
183        let no_spaces = !text.contains(' ');
184        let short = text.len() <= 20;
185        if starts_upper && no_spaces && short {
186            features.push(ClassificationFeature {
187                name: "capitalized_single_token".to_string(),
188                value: 1.0,
189            });
190            // Slight lean towards Person if it looks like a name; otherwise Concept
191            candidates.push((EntityType::Person, BASE_CONFIDENCE + 0.1));
192        }
193
194        // 6. User-defined rules
195        for rule in &self.rules {
196            let pattern_lower = rule.pattern.to_lowercase();
197            if lower.contains(&pattern_lower) {
198                features.push(ClassificationFeature {
199                    name: format!("rule_match:{}", rule.pattern),
200                    value: rule.confidence_boost,
201                });
202                let conf = (BASE_CONFIDENCE + rule.confidence_boost).clamp(0.0, 1.0);
203                candidates.push((rule.entity_type.clone(), conf));
204            }
205        }
206
207        // Choose the candidate with the highest confidence
208        let (predicted_type, confidence) = candidates
209            .into_iter()
210            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
211            .unwrap_or((EntityType::Unknown, BASE_CONFIDENCE));
212
213        ClassificationResult {
214            entity_text: text.to_string(),
215            predicted_type,
216            confidence: confidence.clamp(0.0, 1.0),
217            features,
218        }
219    }
220
221    /// Classify a batch of entity text strings.
222    pub fn classify_batch(&self, texts: &[&str]) -> Vec<ClassificationResult> {
223        texts.iter().map(|&t| self.classify(t)).collect()
224    }
225}
226
227impl Default for EntityClassifier {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
#[cfg(test)]
mod tests {
    //! Unit tests for the entity classifier.
    //!
    //! Every test pins an observable outcome; "does not panic" checks also
    //! assert the type that is returned so regressions cannot pass silently.

    use super::*;

    /// Classifier without any user-defined rules.
    fn classifier() -> EntityClassifier {
        EntityClassifier::new()
    }

    /// Assert that `actual` equals `expected` up to floating-point noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    // --- Entity type label ---

    #[test]
    fn test_entity_type_labels() {
        assert_eq!(EntityType::Person.label(), "Person");
        assert_eq!(EntityType::Organization.label(), "Organization");
        assert_eq!(EntityType::Location.label(), "Location");
        assert_eq!(EntityType::Event.label(), "Event");
        assert_eq!(EntityType::Product.label(), "Product");
        assert_eq!(EntityType::Concept.label(), "Concept");
        assert_eq!(EntityType::Date.label(), "Date");
        assert_eq!(EntityType::Number.label(), "Number");
        assert_eq!(EntityType::Unknown.label(), "Unknown");
    }

    // --- Number detection ---

    #[test]
    fn test_classify_integer() {
        let c = classifier();
        let r = c.classify("42");
        assert_eq!(r.predicted_type, EntityType::Number);
    }

    #[test]
    fn test_classify_float() {
        let c = classifier();
        let r = c.classify("3.14");
        assert_eq!(r.predicted_type, EntityType::Number);
    }

    #[test]
    fn test_classify_negative_number() {
        let c = classifier();
        let r = c.classify("-7");
        assert_eq!(r.predicted_type, EntityType::Number);
    }

    // --- Date detection ---

    #[test]
    fn test_classify_date_with_month_name() {
        let c = classifier();
        let r = c.classify("January 2024");
        assert_eq!(r.predicted_type, EntityType::Date);
    }

    #[test]
    fn test_classify_date_abbreviated_month() {
        let c = classifier();
        let r = c.classify("15 Mar 2025");
        assert_eq!(r.predicted_type, EntityType::Date);
    }

    // --- Location detection ---

    #[test]
    fn test_classify_location_city() {
        let c = classifier();
        let r = c.classify("New York City");
        // ends with "city"
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    #[test]
    fn test_classify_location_river() {
        let c = classifier();
        let r = c.classify("Amazon River");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    #[test]
    fn test_classify_location_mountain() {
        let c = classifier();
        let r = c.classify("Mount Everest Mountain");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    #[test]
    fn test_classify_location_street() {
        let c = classifier();
        let r = c.classify("Baker Street");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    // --- Organisation detection ---

    #[test]
    fn test_classify_org_inc() {
        let c = classifier();
        // Cover the "inc" suffix the test is named after, plus "corp".
        assert_eq!(
            c.classify("Acme Inc").predicted_type,
            EntityType::Organization
        );
        assert_eq!(
            c.classify("Acme Corp").predicted_type,
            EntityType::Organization
        );
    }

    #[test]
    fn test_classify_org_ltd() {
        let c = classifier();
        let r = c.classify("Widgets Ltd");
        assert_eq!(r.predicted_type, EntityType::Organization);
    }

    #[test]
    fn test_classify_org_gmbh() {
        let c = classifier();
        let r = c.classify("Muller GmbH");
        assert_eq!(r.predicted_type, EntityType::Organization);
    }

    // --- Person detection (capitalised single token) ---

    #[test]
    fn test_classify_person_single_capitalized() {
        let c = classifier();
        let r = c.classify("Alice");
        // Should be Person (capitalized single short token)
        assert_eq!(r.predicted_type, EntityType::Person);
    }

    #[test]
    fn test_classify_person_confidence_positive() {
        let c = classifier();
        let r = c.classify("Bob");
        assert!(r.confidence > 0.0);
        // The capitalised-token heuristic adds 0.1 on top of the base.
        assert_close(r.confidence, BASE_CONFIDENCE + 0.1);
    }

    // --- Unknown ---

    #[test]
    fn test_classify_unknown_generic_phrase() {
        let c = classifier();
        let r = c.classify("the semantic web is interesting");
        // No uppercase start, no number, no month+digits, no suffix:
        // nothing fires, so the fallback is reported.
        assert_eq!(r.predicted_type, EntityType::Unknown);
        assert_close(r.confidence, BASE_CONFIDENCE);
        assert!(r.features.is_empty());
    }

    // --- Confidence bounds ---

    #[test]
    fn test_confidence_always_in_range() {
        let c = classifier();
        let texts = [
            "Alice",
            "42",
            "January 2024",
            "Acme Corp",
            "Baker Street",
            "foo",
            "",
        ];
        for text in &texts {
            let r = c.classify(text);
            assert!(
                (0.0..=1.0).contains(&r.confidence),
                "Confidence out of range for '{text}': {}",
                r.confidence
            );
        }
    }

    // --- Features populated ---

    #[test]
    fn test_features_populated_for_number() {
        let c = classifier();
        let r = c.classify("100");
        assert!(!r.features.is_empty());
        assert!(r.features.iter().any(|f| f.name == "is_numeric"));
    }

    // --- Custom rule tests ---

    #[test]
    fn test_add_custom_rule_count() {
        let mut c = classifier();
        assert_eq!(c.rule_count(), 0);
        c.add_rule(ClassificationRule {
            pattern: "summit".to_string(),
            entity_type: EntityType::Event,
            confidence_boost: 0.4,
        });
        assert_eq!(c.rule_count(), 1);
    }

    #[test]
    fn test_custom_rule_fires() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "summit".to_string(),
            entity_type: EntityType::Event,
            confidence_boost: 0.4,
        });
        let r = c.classify("G7 Summit 2025");
        assert_eq!(r.predicted_type, EntityType::Event);
        assert!(r.features.iter().any(|f| f.name == "rule_match:summit"));
    }

    #[test]
    fn test_custom_rule_confidence_boosted() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "widget".to_string(),
            entity_type: EntityType::Product,
            confidence_boost: 0.3,
        });
        let r = c.classify("Super Widget Pro");
        assert!(r.confidence >= BASE_CONFIDENCE + 0.3 - 1e-9);
    }

    #[test]
    fn test_custom_rule_case_insensitive() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "WIDGET".to_string(),
            entity_type: EntityType::Product,
            confidence_boost: 0.2,
        });
        let r = c.classify("widget maker");
        assert_eq!(r.predicted_type, EntityType::Product);
    }

    #[test]
    fn test_multiple_custom_rules_highest_wins() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "demo".to_string(),
            entity_type: EntityType::Event,
            confidence_boost: 0.2,
        });
        c.add_rule(ClassificationRule {
            pattern: "demo".to_string(),
            entity_type: EntityType::Concept,
            confidence_boost: 0.45,
        });
        let r = c.classify("demo system");
        // Concept has higher boost
        assert_eq!(r.predicted_type, EntityType::Concept);
        // Both rules still leave a trace in the feature list.
        assert_eq!(r.features.len(), 2);
    }

    // --- Batch processing ---

    #[test]
    fn test_classify_batch_count() {
        let c = classifier();
        let texts = ["Alice", "Acme Corp", "42", "Baker Street"];
        let results = c.classify_batch(&texts);
        assert_eq!(results.len(), 4);
        // Results come back in input order.
        for (result, text) in results.iter().zip(texts.iter()) {
            assert_eq!(result.entity_text, *text);
        }
    }

    #[test]
    fn test_classify_batch_empty() {
        let c = classifier();
        let results = c.classify_batch(&[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_classify_batch_single() {
        let c = classifier();
        let results = c.classify_batch(&["100"]);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].predicted_type, EntityType::Number);
    }

    // --- Edge cases ---

    #[test]
    fn test_classify_empty_string() {
        let c = classifier();
        let r = c.classify("");
        // Empty string: no heuristics fire → Unknown
        assert_eq!(r.predicted_type, EntityType::Unknown);
        assert!(r.features.is_empty());
    }

    #[test]
    fn test_classify_whitespace_only() {
        let c = classifier();
        let r = c.classify("   ");
        // Whitespace is neither numeric nor capitalised → Unknown
        assert_eq!(r.predicted_type, EntityType::Unknown);
        assert_eq!(r.entity_text, "   ");
    }

    #[test]
    fn test_default_classifier() {
        let c = EntityClassifier::default();
        assert_eq!(c.rule_count(), 0);
    }
}