reductionml_core/metrics/
ips.rs1use crate::{metrics::Metric, utils::AsInner, ActionProbsPrediction, CBLabel, Features};
2
3use super::MetricValue;
4
/// Running state for the inverse propensity score (IPS) estimate of a policy's average reward.
///
/// Each accumulated example contributes `-cost * p_pred / p_log`. The reported estimate is the
/// mean of those contributions.
#[derive(Debug, Clone)]
pub struct IpsMetric {
    /// Number of examples accumulated so far.
    pub examples_count: u64,
    /// Sum of importance-weighted rewards (negated costs) over all accumulated examples.
    pub weighted_reward: f32,
}
9
10impl IpsMetric {
11 pub fn new() -> IpsMetric {
12 IpsMetric {
13 examples_count: 0,
14 weighted_reward: 0.0,
15 }
16 }
17}
18
19impl Default for IpsMetric {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl Metric for IpsMetric {
26 fn add_point(
27 &mut self,
28 _features: &Features,
29 label: &crate::types::Label,
30 prediction: &crate::types::Prediction,
31 ) {
32 let label: &CBLabel = label.as_inner().unwrap();
33 let pred: &ActionProbsPrediction = prediction.as_inner().unwrap();
34
35 let p_log = label.probability;
36 let p_pred = pred
37 .0
38 .iter()
39 .find(|(action, _)| action == &label.action)
40 .unwrap()
41 .1;
42
43 let w = p_pred / p_log;
44
45 self.weighted_reward += (-1.0 * label.cost) * w;
46 self.examples_count += 1;
47 }
48
49 fn get_value(&self) -> MetricValue {
50 MetricValue::Float(self.weighted_reward / (self.examples_count as f32))
51 }
52
53 fn get_name(&self) -> String {
54 "Estimated reward (IPS)".to_owned()
55 }
56}