/// Category that the classifier can assign to a piece of entity text.
///
/// The enum is fieldless, so it is `Copy`: values can be passed around and
/// compared without explicit cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// A person.
    Person,
    /// A company or other organization.
    Organization,
    /// A geographic place.
    Location,
    /// An event.
    Event,
    /// A product.
    Product,
    /// An abstract concept.
    Concept,
    /// A calendar date.
    Date,
    /// A numeric value.
    Number,
    /// Fallback category used when nothing more specific applies.
    Unknown,
}
19
20impl EntityType {
21 pub fn label(&self) -> &'static str {
23 match self {
24 EntityType::Person => "Person",
25 EntityType::Organization => "Organization",
26 EntityType::Location => "Location",
27 EntityType::Event => "Event",
28 EntityType::Product => "Product",
29 EntityType::Concept => "Concept",
30 EntityType::Date => "Date",
31 EntityType::Number => "Number",
32 EntityType::Unknown => "Unknown",
33 }
34 }
35}
36
/// A single named signal recorded while classifying a piece of text.
///
/// Derives `PartialEq` so features (and collections of them) can be compared
/// directly, e.g. in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationFeature {
    /// Identifier of the signal, such as `is_numeric` or `rule_match:<pattern>`.
    pub name: String,
    /// Numeric weight attached to the signal.
    pub value: f64,
}
43
44#[derive(Debug, Clone)]
46pub struct ClassificationResult {
47 pub entity_text: String,
48 pub predicted_type: EntityType,
49 pub confidence: f64,
50 pub features: Vec<ClassificationFeature>,
51}
52
53#[derive(Debug, Clone)]
56pub struct ClassificationRule {
57 pub pattern: String,
58 pub entity_type: EntityType,
59 pub confidence_boost: f64,
60}
61
/// Lowercase English month names, followed by their three-letter abbreviations.
///
/// `"may"` is listed only once because its abbreviation equals the full name.
const MONTH_NAMES: &[&str] = &[
    // Full names, January through December.
    "january",
    "february",
    "march",
    "april",
    "may",
    "june",
    "july",
    "august",
    "september",
    "october",
    "november",
    "december",
    // Three-letter abbreviations.
    "jan",
    "feb",
    "mar",
    "apr",
    "jun",
    "jul",
    "aug",
    "sep",
    "oct",
    "nov",
    "dec",
];
88
/// Lowercase trailing words that are treated as evidence of a location.
const LOCATION_SUFFIXES: &[&str] = &[
    "city",
    "river",
    "mountain",
    "street",
    "avenue",
    "lake",
    "island",
    "valley",
];
93
/// Lowercase legal-form abbreviations that are treated as evidence of an
/// organization when they are the final word of the text.
const ORG_SUFFIXES: &[&str] = &[
    "inc", "corp", "ltd", "gmbh", "llc", "plc", "ag", "bv", "sa",
];
96
/// Starting confidence to which every heuristic or rule boost is added.
///
/// It is also the confidence reported when no candidate type is found.
const BASE_CONFIDENCE: f64 = 0.5;
99
100pub struct EntityClassifier {
102 rules: Vec<ClassificationRule>,
103}
104
105impl EntityClassifier {
106 pub fn new() -> Self {
108 Self { rules: Vec::new() }
109 }
110
111 pub fn add_rule(&mut self, rule: ClassificationRule) {
113 self.rules.push(rule);
114 }
115
116 pub fn rule_count(&self) -> usize {
118 self.rules.len()
119 }
120
121 pub fn classify(&self, text: &str) -> ClassificationResult {
123 let lower = text.to_lowercase();
124 let mut features: Vec<ClassificationFeature> = Vec::new();
125
126 let mut candidates: Vec<(EntityType, f64)> = Vec::new();
128
129 if text
131 .trim()
132 .chars()
133 .all(|c| c.is_ascii_digit() || c == '.' || c == '-')
134 && !text.trim().is_empty()
135 {
136 features.push(ClassificationFeature {
137 name: "is_numeric".to_string(),
138 value: 1.0,
139 });
140 candidates.push((EntityType::Number, BASE_CONFIDENCE + 0.4));
141 }
142
143 let has_digits = text.chars().any(|c| c.is_ascii_digit());
145 let has_month = MONTH_NAMES.iter().any(|&m| lower.contains(m));
146 if has_digits && has_month {
147 features.push(ClassificationFeature {
148 name: "has_month_name".to_string(),
149 value: 1.0,
150 });
151 candidates.push((EntityType::Date, BASE_CONFIDENCE + 0.35));
152 }
153
154 if let Some(suffix) = LOCATION_SUFFIXES.iter().find(|&&s| lower.ends_with(s)) {
156 features.push(ClassificationFeature {
157 name: format!("location_suffix:{suffix}"),
158 value: 1.0,
159 });
160 candidates.push((EntityType::Location, BASE_CONFIDENCE + 0.3));
161 }
162
163 let last_word_lower = lower
165 .split_whitespace()
166 .last()
167 .unwrap_or("")
168 .trim_end_matches('.');
169 if ORG_SUFFIXES.contains(&last_word_lower) {
170 features.push(ClassificationFeature {
171 name: format!("org_suffix:{last_word_lower}"),
172 value: 1.0,
173 });
174 candidates.push((EntityType::Organization, BASE_CONFIDENCE + 0.35));
175 }
176
177 let starts_upper = text
179 .chars()
180 .next()
181 .map(|c| c.is_uppercase())
182 .unwrap_or(false);
183 let no_spaces = !text.contains(' ');
184 let short = text.len() <= 20;
185 if starts_upper && no_spaces && short {
186 features.push(ClassificationFeature {
187 name: "capitalized_single_token".to_string(),
188 value: 1.0,
189 });
190 candidates.push((EntityType::Person, BASE_CONFIDENCE + 0.1));
192 }
193
194 for rule in &self.rules {
196 let pattern_lower = rule.pattern.to_lowercase();
197 if lower.contains(&pattern_lower) {
198 features.push(ClassificationFeature {
199 name: format!("rule_match:{}", rule.pattern),
200 value: rule.confidence_boost,
201 });
202 let conf = (BASE_CONFIDENCE + rule.confidence_boost).clamp(0.0, 1.0);
203 candidates.push((rule.entity_type.clone(), conf));
204 }
205 }
206
207 let (predicted_type, confidence) = candidates
209 .into_iter()
210 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
211 .unwrap_or((EntityType::Unknown, BASE_CONFIDENCE));
212
213 ClassificationResult {
214 entity_text: text.to_string(),
215 predicted_type,
216 confidence: confidence.clamp(0.0, 1.0),
217 features,
218 }
219 }
220
221 pub fn classify_batch(&self, texts: &[&str]) -> Vec<ClassificationResult> {
223 texts.iter().map(|&t| self.classify(t)).collect()
224 }
225}
226
227impl Default for EntityClassifier {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a classifier without custom rules.
    fn classifier() -> EntityClassifier {
        EntityClassifier::new()
    }

    // --- EntityType labels -------------------------------------------------

    #[test]
    fn test_entity_type_labels() {
        assert_eq!(EntityType::Person.label(), "Person");
        assert_eq!(EntityType::Organization.label(), "Organization");
        assert_eq!(EntityType::Location.label(), "Location");
        assert_eq!(EntityType::Event.label(), "Event");
        assert_eq!(EntityType::Product.label(), "Product");
        assert_eq!(EntityType::Concept.label(), "Concept");
        assert_eq!(EntityType::Date.label(), "Date");
        assert_eq!(EntityType::Number.label(), "Number");
        assert_eq!(EntityType::Unknown.label(), "Unknown");
    }

    // --- Numbers -----------------------------------------------------------

    #[test]
    fn test_classify_integer() {
        let c = classifier();
        let r = c.classify("42");
        assert_eq!(r.predicted_type, EntityType::Number);
    }

    #[test]
    fn test_classify_float() {
        let c = classifier();
        let r = c.classify("3.14");
        assert_eq!(r.predicted_type, EntityType::Number);
    }

    #[test]
    fn test_classify_negative_number() {
        let c = classifier();
        let r = c.classify("-7");
        assert_eq!(r.predicted_type, EntityType::Number);
    }

    // --- Dates -------------------------------------------------------------

    #[test]
    fn test_classify_date_with_month_name() {
        let c = classifier();
        let r = c.classify("January 2024");
        assert_eq!(r.predicted_type, EntityType::Date);
    }

    #[test]
    fn test_classify_date_abbreviated_month() {
        let c = classifier();
        let r = c.classify("15 Mar 2025");
        assert_eq!(r.predicted_type, EntityType::Date);
    }

    // --- Locations ---------------------------------------------------------

    #[test]
    fn test_classify_location_city() {
        let c = classifier();
        let r = c.classify("New York City");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    #[test]
    fn test_classify_location_river() {
        let c = classifier();
        let r = c.classify("Amazon River");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    #[test]
    fn test_classify_location_mountain() {
        let c = classifier();
        let r = c.classify("Mount Everest Mountain");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    #[test]
    fn test_classify_location_street() {
        let c = classifier();
        let r = c.classify("Baker Street");
        assert_eq!(r.predicted_type, EntityType::Location);
    }

    // --- Organizations -----------------------------------------------------

    #[test]
    fn test_classify_org_inc() {
        let c = classifier();
        let r = c.classify("Acme Inc");
        assert_eq!(r.predicted_type, EntityType::Organization);
    }

    #[test]
    fn test_classify_org_corp() {
        let c = classifier();
        let r = c.classify("Acme Corp");
        assert_eq!(r.predicted_type, EntityType::Organization);
    }

    #[test]
    fn test_classify_org_ltd() {
        let c = classifier();
        let r = c.classify("Widgets Ltd");
        assert_eq!(r.predicted_type, EntityType::Organization);
    }

    #[test]
    fn test_classify_org_gmbh() {
        let c = classifier();
        let r = c.classify("Muller GmbH");
        assert_eq!(r.predicted_type, EntityType::Organization);
    }

    // --- Persons -----------------------------------------------------------

    #[test]
    fn test_classify_person_single_capitalized() {
        let c = classifier();
        let r = c.classify("Alice");
        assert_eq!(r.predicted_type, EntityType::Person);
    }

    #[test]
    fn test_classify_person_confidence_positive() {
        let c = classifier();
        let r = c.classify("Bob");
        assert!(r.confidence > 0.0);
    }

    // --- Unknown fallback --------------------------------------------------

    #[test]
    fn test_classify_unknown_generic_phrase() {
        let c = classifier();
        let r = c.classify("the semantic web is interesting");
        // Nothing fires for a lowercase multi-word phrase without digits.
        assert_eq!(r.predicted_type, EntityType::Unknown);
        assert_eq!(r.confidence, BASE_CONFIDENCE);
        assert!(r.features.is_empty());
    }

    // --- Confidence range --------------------------------------------------

    #[test]
    fn test_confidence_always_in_range() {
        let c = classifier();
        let texts = [
            "Alice",
            "42",
            "January 2024",
            "Acme Corp",
            "Baker Street",
            "foo",
            "",
        ];
        for text in &texts {
            let r = c.classify(text);
            assert!(
                (0.0..=1.0).contains(&r.confidence),
                "Confidence out of range for '{text}': {}",
                r.confidence
            );
        }
    }

    // --- Features ----------------------------------------------------------

    #[test]
    fn test_features_populated_for_number() {
        let c = classifier();
        let r = c.classify("100");
        assert!(!r.features.is_empty());
    }

    // --- Custom rules ------------------------------------------------------

    #[test]
    fn test_add_custom_rule_count() {
        let mut c = classifier();
        assert_eq!(c.rule_count(), 0);
        c.add_rule(ClassificationRule {
            pattern: "summit".to_string(),
            entity_type: EntityType::Event,
            confidence_boost: 0.4,
        });
        assert_eq!(c.rule_count(), 1);
    }

    #[test]
    fn test_custom_rule_fires() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "summit".to_string(),
            entity_type: EntityType::Event,
            confidence_boost: 0.4,
        });
        let r = c.classify("G7 Summit 2025");
        assert_eq!(r.predicted_type, EntityType::Event);
    }

    #[test]
    fn test_custom_rule_confidence_boosted() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "widget".to_string(),
            entity_type: EntityType::Product,
            confidence_boost: 0.3,
        });
        let r = c.classify("Super Widget Pro");
        assert!(r.confidence >= BASE_CONFIDENCE + 0.3 - 1e-9);
    }

    #[test]
    fn test_custom_rule_case_insensitive() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "WIDGET".to_string(),
            entity_type: EntityType::Product,
            confidence_boost: 0.2,
        });
        let r = c.classify("widget maker");
        assert_eq!(r.predicted_type, EntityType::Product);
    }

    #[test]
    fn test_multiple_custom_rules_highest_wins() {
        let mut c = classifier();
        c.add_rule(ClassificationRule {
            pattern: "demo".to_string(),
            entity_type: EntityType::Event,
            confidence_boost: 0.2,
        });
        c.add_rule(ClassificationRule {
            pattern: "demo".to_string(),
            entity_type: EntityType::Concept,
            confidence_boost: 0.45,
        });
        let r = c.classify("demo system");
        assert_eq!(r.predicted_type, EntityType::Concept);
    }

    // --- Batch classification ----------------------------------------------

    #[test]
    fn test_classify_batch_count() {
        let c = classifier();
        let texts = ["Alice", "Acme Corp", "42", "Baker Street"];
        let results = c.classify_batch(&texts);
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_classify_batch_empty() {
        let c = classifier();
        let results = c.classify_batch(&[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_classify_batch_single() {
        let c = classifier();
        let results = c.classify_batch(&["100"]);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].predicted_type, EntityType::Number);
    }

    // --- Edge cases --------------------------------------------------------

    #[test]
    fn test_classify_empty_string() {
        let c = classifier();
        let r = c.classify("");
        assert_eq!(r.predicted_type, EntityType::Unknown);
        assert_eq!(r.entity_text, "");
    }

    #[test]
    fn test_classify_whitespace_only() {
        let c = classifier();
        let r = c.classify("   ");
        assert_eq!(r.predicted_type, EntityType::Unknown);
        assert!(r.features.is_empty());
    }

    // --- Default -----------------------------------------------------------

    #[test]
    fn test_default_classifier() {
        let c = EntityClassifier::default();
        assert_eq!(c.rule_count(), 0);
    }
}