//! NER evaluation trait and implementations.
//!
//! Provides trait-based evaluation matching the RetrievalEvaluator pattern
//! for consistency and extensibility.
use super::datasets::GoldEntity;
use super::types::{GoalCheckResult, MetricValue};
use super::TypeMetrics;
use anno::{Error, Model, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Per-test-case NER evaluation metrics.
///
/// Metric fields use `MetricValue`, which enforces the [0.0, 1.0] bound at construction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERQueryMetrics {
/// Test case text
pub text: String,
/// Optional test case ID
pub test_case_id: Option<String>,
/// Precision (type-safe, bounded 0.0-1.0)
pub precision: MetricValue,
/// Recall (type-safe, bounded 0.0-1.0)
pub recall: MetricValue,
/// F1 score (type-safe, bounded 0.0-1.0)
pub f1: MetricValue,
/// Per-entity-type metrics
pub per_type: HashMap<String, TypeMetrics>,
/// Number of entities found
pub found: usize,
/// Number of entities expected
pub expected: usize,
/// Number of correct predictions
pub correct: usize,
/// Processing speed (tokens per second)
pub tokens_per_second: f64,
}
/// Averaging mode for NER metrics.
///
/// Following seqeval conventions:
/// - Micro: Calculate metrics globally by counting total TP, FP, FN (default, recommended)
/// - Macro: Calculate metrics per test case, then average (gives equal weight to each case)
/// - Weighted: Like macro, but weighted by support (number of expected entities per case)
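///
/// # Example
///
/// A small arithmetic sketch (plain numbers, no API) of how the modes can
/// diverge. Case A has 1 expected entity with 1 correct; case B has 99
/// expected with 49 correct:
///
/// ```rust
/// // Micro: pool counts across cases, then divide.
/// let micro_recall = (1.0 + 49.0) / (1.0 + 99.0); // 0.50
/// // Macro: average the per-case recalls (tiny case A counts as much as B).
/// let macro_recall = (1.0 / 1.0 + 49.0 / 99.0) / 2.0; // ~0.75
/// // Weighted: per-case recalls weighted by support (expected count);
/// // for recall, weighting by support recovers the micro value.
/// let weighted_recall = (1.0 * 1.0 + (49.0 / 99.0) * 99.0) / 100.0; // 0.50
/// assert!(macro_recall > micro_recall);
/// assert!((weighted_recall - micro_recall).abs() < 1e-12);
/// ```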
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum AveragingMode {
/// Calculate globally: total_correct / total_predicted (default, standard for NER)
#[default]
Micro,
/// Average per-case metrics (equal weight to each test case, regardless of size)
Macro,
/// Average per-case metrics weighted by support (number of expected entities)
Weighted,
}
/// Aggregated NER evaluation metrics with statistical measures.
///
/// Provides mean, standard deviation, and confidence intervals
/// for comprehensive analysis.
///
/// # Micro vs Macro Averaging
///
/// By default, we compute **micro-averaged** metrics (total_correct / total_found),
/// which is the standard for NER evaluation and matches seqeval's default.
///
/// Macro-averaging (average of per-case metrics) can inflate scores when test
/// cases have different sizes. A test case with 1 entity getting 100% F1 shouldn't
/// boost overall metrics as much as a case with 100 entities getting 50% F1.
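///
/// Concretely, working those two cases (1/1 correct vs. 50 correct out of
/// 100 found and 100 expected):
///
/// ```rust
/// // Micro: pool counts first. p = r = 51/101, and when p == r, F1 == p.
/// let micro_f1 = 51.0 / 101.0; // ~0.505
/// // Macro: average per-case F1 (case A: 1.0, case B: 0.5).
/// let macro_f1 = (1.0 + 0.5) / 2.0; // 0.75
/// assert!(macro_f1 > micro_f1);
/// ```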
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERAggregateMetrics {
/// Micro-averaged precision (total_correct / total_found)
pub precision: MetricValue,
/// Micro-averaged recall (total_correct / total_expected)
pub recall: MetricValue,
/// Micro-averaged F1 score
pub f1: MetricValue,
/// Macro-averaged precision (for comparison)
pub macro_precision: MetricValue,
/// Macro-averaged recall (for comparison)
pub macro_recall: MetricValue,
/// Macro-averaged F1 (for comparison)
pub macro_f1: MetricValue,
/// Precision standard deviation (of per-case metrics)
pub precision_std: f64,
/// Recall standard deviation (of per-case metrics)
pub recall_std: f64,
/// F1 standard deviation (of per-case metrics)
pub f1_std: f64,
/// Precision 95% confidence interval (lower, upper)
pub precision_ci_95: Option<(f64, f64)>,
/// Recall 95% confidence interval (lower, upper)
pub recall_ci_95: Option<(f64, f64)>,
/// F1 95% confidence interval (lower, upper)
pub f1_ci_95: Option<(f64, f64)>,
/// Per-entity-type aggregated metrics (micro-averaged)
pub per_type: HashMap<String, TypeMetrics>,
/// Mean tokens per second
pub tokens_per_second: f64,
/// Number of test cases evaluated
pub num_test_cases: usize,
/// Total entities found
pub total_found: usize,
/// Total entities expected
pub total_expected: usize,
/// Total correct predictions
pub total_correct: usize,
}
/// Type-safe NER evaluation goals.
///
/// Allows setting minimum thresholds for metrics; thresholds are `MetricValue`s,
/// so out-of-range values are rejected at construction.
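///
/// # Example
///
/// A minimal sketch, assuming these goal types are re-exported from
/// `anno_eval::eval` alongside the evaluator types:
///
/// ```rust
/// use anno_eval::eval::{NERMetricGoals, TypeMetricGoals};
///
/// // Overall floors, plus a stricter precision floor for one entity type.
/// // The type key must match `entity_type_to_string`'s output ("Money" is
/// // assumed here for illustration).
/// let goals = NERMetricGoals::new()
///     .with_min_f1(0.8).unwrap()
///     .with_min_recall(0.75).unwrap()
///     .with_type_goal(
///         "Money".to_string(),
///         TypeMetricGoals::new().with_min_precision(0.9).unwrap(),
///     );
/// ```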
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERMetricGoals {
/// Minimum precision threshold
pub min_precision: Option<MetricValue>,
/// Minimum recall threshold
pub min_recall: Option<MetricValue>,
/// Minimum F1 threshold
pub min_f1: Option<MetricValue>,
/// Per-entity-type goals
pub per_type_goals: HashMap<String, TypeMetricGoals>,
}
impl NERMetricGoals {
/// Create new empty goals.
#[must_use]
pub fn new() -> Self {
Self {
min_precision: None,
min_recall: None,
min_f1: None,
per_type_goals: HashMap::new(),
}
}
/// Set minimum precision goal.
pub fn with_min_precision(mut self, value: f64) -> Result<Self> {
self.min_precision = Some(MetricValue::try_new(value)?);
Ok(self)
}
/// Set minimum recall goal.
pub fn with_min_recall(mut self, value: f64) -> Result<Self> {
self.min_recall = Some(MetricValue::try_new(value)?);
Ok(self)
}
/// Set minimum F1 goal.
pub fn with_min_f1(mut self, value: f64) -> Result<Self> {
self.min_f1 = Some(MetricValue::try_new(value)?);
Ok(self)
}
/// Add per-type goal.
#[must_use]
pub fn with_type_goal(mut self, entity_type: String, goal: TypeMetricGoals) -> Self {
self.per_type_goals.insert(entity_type, goal);
self
}
}
impl Default for NERMetricGoals {
fn default() -> Self {
Self::new()
}
}
/// Per-entity-type metric goals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeMetricGoals {
/// Minimum precision for this entity type
pub min_precision: Option<MetricValue>,
/// Minimum recall for this entity type
pub min_recall: Option<MetricValue>,
/// Minimum F1 for this entity type
pub min_f1: Option<MetricValue>,
}
impl TypeMetricGoals {
/// Create new type goals.
#[must_use]
pub fn new() -> Self {
Self {
min_precision: None,
min_recall: None,
min_f1: None,
}
}
/// Set minimum precision.
pub fn with_min_precision(mut self, value: f64) -> Result<Self> {
self.min_precision = Some(MetricValue::try_new(value)?);
Ok(self)
}
/// Set minimum recall.
pub fn with_min_recall(mut self, value: f64) -> Result<Self> {
self.min_recall = Some(MetricValue::try_new(value)?);
Ok(self)
}
/// Set minimum F1.
pub fn with_min_f1(mut self, value: f64) -> Result<Self> {
self.min_f1 = Some(MetricValue::try_new(value)?);
Ok(self)
}
}
impl Default for TypeMetricGoals {
fn default() -> Self {
Self::new()
}
}
/// Trait for NER evaluation strategies.
///
/// Allows plugging in different evaluation implementations:
/// - Standard evaluator (exact match)
/// - Partial match evaluator (overlap-based)
/// - Custom evaluators (for research, special metrics)
///
/// # Example
///
/// ```rust
/// use anno_eval::eval::{GoldEntity, StandardNEREvaluator, NEREvaluator};
/// use anno::{RegexNER, Model, EntityType};
///
/// let evaluator = StandardNEREvaluator::new();
/// let model = RegexNER::new();
/// let ground_truth = vec![
/// GoldEntity::new("$100", EntityType::Money, 6),
/// ];
///
/// let metrics = evaluator.evaluate_test_case(
/// &model,
/// "Cost: $100",
/// &ground_truth,
/// Some("test-1"),
/// ).unwrap();
///
/// assert!(metrics.precision.get() > 0.0);
/// ```
pub trait NEREvaluator: Send + Sync {
/// Evaluate a single test case.
///
/// # Arguments
/// * `model` - NER model to evaluate
/// * `text` - Text to extract entities from (must not be empty)
/// * `ground_truth` - Expected entities
/// * `test_case_id` - Optional test case identifier
///
/// # Returns
/// Per-test-case metrics with precision, recall, F1, and per-type breakdowns
///
/// # Errors
/// Returns `Error::InvalidInput` if:
/// - Text is empty
/// - Ground truth entities are invalid (overlapping, out of bounds)
/// - Metrics are invalid (NaN or Inf)
fn evaluate_test_case(
&self,
model: &dyn Model,
text: &str,
ground_truth: &[GoldEntity],
test_case_id: Option<&str>,
) -> Result<NERQueryMetrics>;
/// Aggregate metrics across multiple test cases.
///
/// # Arguments
/// * `query_metrics` - Per-test-case metrics
///
/// # Returns
/// Aggregate metrics with statistical measures
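    ///
    /// # Example
    ///
    /// A minimal sketch reusing the setup from the trait-level example:
    ///
    /// ```rust
    /// # use anno_eval::eval::{GoldEntity, StandardNEREvaluator, NEREvaluator};
    /// # use anno::{RegexNER, Model, EntityType};
    /// # let evaluator = StandardNEREvaluator::new();
    /// # let model = RegexNER::new();
    /// # let gt = vec![GoldEntity::new("$100", EntityType::Money, 6)];
    /// let m = evaluator.evaluate_test_case(&model, "Cost: $100", &gt, None).unwrap();
    /// let agg = evaluator.aggregate(&[m]).unwrap();
    /// assert_eq!(agg.num_test_cases, 1);
    /// ```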
fn aggregate(&self, query_metrics: &[NERQueryMetrics]) -> Result<NERAggregateMetrics>;
/// Check if metrics meet goals.
///
/// # Arguments
/// * `metrics` - Aggregate metrics to check
/// * `goals` - Goals to check against
///
/// # Returns
/// Goal check result with pass/fail status
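    ///
    /// # Example
    ///
    /// A minimal sketch (same re-export assumptions as the goals example
    /// above); `GoalCheckResult`'s accessors live in `super::types`, so only
    /// the call shape is shown:
    ///
    /// ```rust
    /// # use anno_eval::eval::{GoldEntity, StandardNEREvaluator, NEREvaluator, NERMetricGoals};
    /// # use anno::{RegexNER, Model, EntityType};
    /// # let evaluator = StandardNEREvaluator::new();
    /// # let model = RegexNER::new();
    /// # let gt = vec![GoldEntity::new("$100", EntityType::Money, 6)];
    /// # let m = evaluator.evaluate_test_case(&model, "Cost: $100", &gt, None).unwrap();
    /// # let agg = evaluator.aggregate(&[m]).unwrap();
    /// let goals = NERMetricGoals::new().with_min_f1(0.5).unwrap();
    /// let _check = evaluator.check_goals(&agg, &goals).unwrap();
    /// ```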
fn check_goals(
&self,
metrics: &NERAggregateMetrics,
goals: &NERMetricGoals,
) -> Result<GoalCheckResult>;
}
/// Standard NER evaluator implementation.
///
/// Computes standard NER metrics: Precision, Recall, F1 (exact match).
pub struct StandardNEREvaluator;
impl StandardNEREvaluator {
/// Create a new standard evaluator.
#[must_use]
pub fn new() -> Self {
Self
}
}
impl Default for StandardNEREvaluator {
fn default() -> Self {
Self::new()
}
}
impl NEREvaluator for StandardNEREvaluator {
fn evaluate_test_case(
&self,
model: &dyn Model,
text: &str,
ground_truth: &[GoldEntity],
test_case_id: Option<&str>,
) -> Result<NERQueryMetrics> {
// Validate input
if text.is_empty() {
return Err(Error::InvalidInput(
"Text cannot be empty for NER evaluation".to_string(),
));
}
// Validate ground truth entities
let validation = crate::eval::validation::validate_ground_truth_entities(
text,
ground_truth,
false, // Warnings for overlaps, not errors
);
if !validation.is_valid {
return Err(Error::InvalidInput(format!(
"Invalid ground truth entities: {}",
validation.errors.join("; ")
)));
}
// Log warnings if any (using eprintln! for now, can be upgraded to proper logging)
if !validation.warnings.is_empty() {
eprintln!(
"WARNING: Ground truth validation warnings: {}",
validation.warnings.join("; ")
);
}
let start_time = std::time::Instant::now();
// Extract entities using model
let predicted = model.extract_entities(text, None)?;
let elapsed = start_time.elapsed().as_secs_f64();
let tokens = text.split_whitespace().count();
let tokens_per_second = if elapsed > 0.0 {
tokens as f64 / elapsed
} else {
0.0
};
// Count correct predictions (exact match: same span and type)
// Track which gold entities have been matched to prevent double-counting
// This ensures each gold entity can only be matched once, even if multiple
// predictions match it (duplicate predictions should not inflate precision)
let mut gold_matched = vec![false; ground_truth.len()];
let mut correct = 0;
for pred in &predicted {
for (gt_idx, gt) in ground_truth.iter().enumerate() {
if gold_matched[gt_idx] {
continue; // This gold entity already matched
}
if pred.start() == gt.start
&& pred.end() == gt.end
                    && super::entity_type_matches(&pred.entity_type, &gt.entity_type)
{
gold_matched[gt_idx] = true;
correct += 1;
break;
}
}
}
// Calculate per-type statistics
let mut per_type_stats: HashMap<String, (usize, usize, usize)> = HashMap::new(); // (found, expected, correct)
        // Count expected and correct per type.
        // Reuse the `gold_matched` flags from the global pass so the per-type
        // `correct` counts stay consistent with the overall count; rescanning
        // predictions here could overcount when duplicate gold spans exist
        // (one prediction would be credited to several identical gold spans).
        for (gt_idx, gt) in ground_truth.iter().enumerate() {
            let type_key = super::entity_type_to_string(&gt.entity_type);
            let stats = per_type_stats.entry(type_key).or_insert((0, 0, 0));
            stats.1 += 1; // expected
            if gold_matched[gt_idx] {
                stats.2 += 1; // correct
            }
        }
// Count found per type
for pred in &predicted {
let type_key = super::entity_type_to_string(&pred.entity_type);
let stats = per_type_stats.entry(type_key).or_insert((0, 0, 0));
stats.0 += 1; // found
}
// Calculate overall metrics
let found = predicted.len();
let expected = ground_truth.len();
let precision = if found > 0 {
correct as f64 / found as f64
} else {
0.0
};
let recall = if expected > 0 {
correct as f64 / expected as f64
} else {
0.0
};
let f1 = if precision + recall > 0.0 {
2.0 * precision * recall / (precision + recall)
} else {
0.0
};
// Validate metrics are finite (not NaN or Inf)
if !precision.is_finite() || !recall.is_finite() || !f1.is_finite() {
return Err(Error::InvalidInput(format!(
"Invalid metric values: precision={}, recall={}, f1={}",
precision, recall, f1
)));
}
// Calculate per-type metrics
let mut per_type = HashMap::new();
for (type_name, (found_count, expected_count, correct_count)) in per_type_stats {
let type_precision = if found_count > 0 {
correct_count as f64 / found_count as f64
} else {
0.0
};
let type_recall = if expected_count > 0 {
correct_count as f64 / expected_count as f64
} else {
0.0
};
let type_f1 = if type_precision + type_recall > 0.0 {
2.0 * type_precision * type_recall / (type_precision + type_recall)
} else {
0.0
};
per_type.insert(
type_name,
TypeMetrics {
precision: type_precision,
recall: type_recall,
f1: type_f1,
found: found_count,
expected: expected_count,
correct: correct_count,
},
);
}
Ok(NERQueryMetrics {
text: text.to_string(),
test_case_id: test_case_id.map(|s| s.to_string()),
precision: MetricValue::new(precision),
recall: MetricValue::new(recall),
f1: MetricValue::new(f1),
per_type,
found,
expected,
correct,
tokens_per_second,
})
}
fn aggregate(&self, query_metrics: &[NERQueryMetrics]) -> Result<NERAggregateMetrics> {
if query_metrics.is_empty() {
return Err(Error::InvalidInput(
"Cannot aggregate empty metrics".to_string(),
));
}
// Calculate totals for micro-averaging
let total_found: usize = query_metrics.iter().map(|m| m.found).sum();
let total_expected: usize = query_metrics.iter().map(|m| m.expected).sum();
let total_correct: usize = query_metrics.iter().map(|m| m.correct).sum();
// MICRO-averaged metrics (standard for NER, matches seqeval default)
// precision = total_correct / total_found
// recall = total_correct / total_expected
let micro_precision = if total_found > 0 {
total_correct as f64 / total_found as f64
} else {
0.0 // Note: seqeval would optionally warn here
};
let micro_recall = if total_expected > 0 {
total_correct as f64 / total_expected as f64
} else {
0.0
};
let micro_f1 = if micro_precision + micro_recall > 0.0 {
2.0 * micro_precision * micro_recall / (micro_precision + micro_recall)
} else {
0.0
};
// MACRO-averaged metrics (for comparison, equal weight per test case)
let precisions: Vec<f64> = query_metrics.iter().map(|m| m.precision.get()).collect();
let recalls: Vec<f64> = query_metrics.iter().map(|m| m.recall.get()).collect();
let f1s: Vec<f64> = query_metrics.iter().map(|m| m.f1.get()).collect();
let tokens_per_second: Vec<f64> =
query_metrics.iter().map(|m| m.tokens_per_second).collect();
// Defensive checks for division by zero (shouldn't happen due to earlier check, but be safe)
let macro_precision = if precisions.is_empty() {
0.0
} else {
precisions.iter().sum::<f64>() / precisions.len() as f64
};
let macro_recall = if recalls.is_empty() {
0.0
} else {
recalls.iter().sum::<f64>() / recalls.len() as f64
};
let macro_f1 = if f1s.is_empty() {
0.0
} else {
f1s.iter().sum::<f64>() / f1s.len() as f64
};
let mean_tokens_per_second = if tokens_per_second.is_empty() {
0.0
} else {
tokens_per_second.iter().sum::<f64>() / tokens_per_second.len() as f64
};
// Validate metrics are finite
if !micro_precision.is_finite()
|| !micro_recall.is_finite()
|| !micro_f1.is_finite()
|| !mean_tokens_per_second.is_finite()
{
return Err(Error::InvalidInput(format!(
"Invalid aggregate metric values: precision={}, recall={}, f1={}, tps={}",
micro_precision, micro_recall, micro_f1, mean_tokens_per_second
)));
}
// Calculate standard deviations (of per-case metrics, for variability analysis)
let precision_std = calculate_std_dev(&precisions, macro_precision);
let recall_std = calculate_std_dev(&recalls, macro_recall);
let f1_std = calculate_std_dev(&f1s, macro_f1);
// Calculate 95% confidence intervals (based on per-case variability)
let precision_ci_95 = calculate_ci_95(&precisions, macro_precision, precision_std);
let recall_ci_95 = calculate_ci_95(&recalls, macro_recall, recall_std);
let f1_ci_95 = calculate_ci_95(&f1s, macro_f1, f1_std);
// Aggregate per-type metrics using MICRO-averaging
let mut per_type_totals: HashMap<String, (usize, usize, usize)> = HashMap::new();
for metric in query_metrics {
for (type_name, type_metric) in &metric.per_type {
let entry = per_type_totals
.entry(type_name.clone())
.or_insert((0, 0, 0));
entry.0 += type_metric.found;
entry.1 += type_metric.expected;
entry.2 += type_metric.correct;
}
}
let mut per_type = HashMap::new();
for (type_name, (type_found, type_expected, type_correct)) in per_type_totals {
// Micro-averaged per-type metrics
let type_precision = if type_found > 0 {
type_correct as f64 / type_found as f64
} else {
0.0
};
let type_recall = if type_expected > 0 {
type_correct as f64 / type_expected as f64
} else {
0.0
};
let type_f1 = if type_precision + type_recall > 0.0 {
2.0 * type_precision * type_recall / (type_precision + type_recall)
} else {
0.0
};
per_type.insert(
type_name,
TypeMetrics {
precision: type_precision,
recall: type_recall,
f1: type_f1,
found: type_found,
expected: type_expected,
correct: type_correct,
},
);
}
Ok(NERAggregateMetrics {
// Primary metrics: micro-averaged (standard for NER)
precision: MetricValue::new(micro_precision),
recall: MetricValue::new(micro_recall),
f1: MetricValue::new(micro_f1),
// Secondary metrics: macro-averaged (for comparison)
macro_precision: MetricValue::new(macro_precision),
macro_recall: MetricValue::new(macro_recall),
macro_f1: MetricValue::new(macro_f1),
precision_std,
recall_std,
f1_std,
precision_ci_95,
recall_ci_95,
f1_ci_95,
per_type,
tokens_per_second: mean_tokens_per_second,
num_test_cases: query_metrics.len(),
total_found,
total_expected,
total_correct,
})
}
fn check_goals(
&self,
metrics: &NERAggregateMetrics,
goals: &NERMetricGoals,
) -> Result<GoalCheckResult> {
let mut result = GoalCheckResult::new();
// Check overall goals
if let Some(min_precision) = goals.min_precision {
let actual = metrics.precision.get();
let goal = min_precision.get();
if actual < goal {
result.add_failure("precision".to_string(), actual, goal);
}
}
if let Some(min_recall) = goals.min_recall {
let actual = metrics.recall.get();
let goal = min_recall.get();
if actual < goal {
result.add_failure("recall".to_string(), actual, goal);
}
}
if let Some(min_f1) = goals.min_f1 {
let actual = metrics.f1.get();
let goal = min_f1.get();
if actual < goal {
result.add_failure("f1".to_string(), actual, goal);
}
}
// Check per-type goals
for (type_name, type_goals) in &goals.per_type_goals {
if let Some(type_metrics) = metrics.per_type.get(type_name) {
if let Some(min_precision) = type_goals.min_precision {
let actual = type_metrics.precision;
let goal = min_precision.get();
if actual < goal {
result.add_failure(format!("{}.precision", type_name), actual, goal);
}
}
if let Some(min_recall) = type_goals.min_recall {
let actual = type_metrics.recall;
let goal = min_recall.get();
if actual < goal {
result.add_failure(format!("{}.recall", type_name), actual, goal);
}
}
if let Some(min_f1) = type_goals.min_f1 {
let actual = type_metrics.f1;
let goal = min_f1.get();
if actual < goal {
result.add_failure(format!("{}.f1", type_name), actual, goal);
}
}
}
}
Ok(result)
}
}
/// Calculate standard deviation.
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
if values.len() < 2 {
return 0.0;
}
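    // Sample variance with Bessel's correction (n - 1 denominator).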
let variance =
values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
variance.sqrt()
}
/// Calculate 95% confidence interval.
///
/// Uses the normal (z = 1.96) approximation; a t-distribution would give
/// slightly wider intervals for small samples.
///
/// # Note
/// For small samples or high variance, the raw interval can extend beyond
/// [0.0, 1.0]. Since precision, recall, and F1 are bounded metrics, the
/// returned bounds are clamped to [0.0, 1.0], trading a little statistical
/// fidelity for always-valid metric bounds.
fn calculate_ci_95(values: &[f64], mean: f64, std_dev: f64) -> Option<(f64, f64)> {
if values.len() < 2 {
return None;
}
// Use z-score for 95% CI (1.96 for large samples)
// For small samples, should use t-distribution, but z-score is acceptable approximation
let z_score = 1.96;
let margin = z_score * std_dev / (values.len() as f64).sqrt();
// Clamp CI bounds to [0.0, 1.0] for metrics (precision, recall, F1)
// Note: For very small samples, CI may extend beyond [0, 1], but we clamp
// to maintain valid metric bounds. This is a reasonable approximation.
let lower = (mean - margin).clamp(0.0, 1.0);
let upper = (mean + margin).clamp(0.0, 1.0);
Some((lower, upper))
}
// Tests moved to tests/ directory