1use std::cmp::Ordering;
3
4use crate::custom_ops::{CustomOperation, CustomOperationBody};
5use crate::data_types::{Type, INT128, INT64};
6use crate::errors::Result;
7use crate::graphs::{Context, Graph, Node, SliceElement};
8
9use serde::{Deserialize, Serialize};
10
11use super::fixed_precision::fixed_precision_config::FixedPrecisionConfig;
12use super::goldschmidt_division::GoldschmidtDivision;
13use super::integer_key_sort::SortByIntegerKey;
14use super::utils::constant_scalar;
15
/// Base-2 logarithm of the exclusive upper bound on the input array length:
/// `AucScore::instantiate` rejects arrays with `2^MAX_LOG_ARRAY_SIZE` or more elements.
///
/// `compute_naive_auc` also passes twice this value as the bit cap of the Goldschmidt
/// division (`denominator_cap_2k`); its denominator is a product of two counts, each of
/// which is bounded by the array length.
const MAX_LOG_ARRAY_SIZE: u64 = 20;
17
18#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
49pub struct AucScore {
50 pub fp: FixedPrecisionConfig,
52}
53
54#[typetag::serde]
55impl CustomOperationBody for AucScore {
56 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
57 if arguments_types.len() != 2 {
58 return Err(runtime_error!("Invalid number of arguments for AucMetric"));
59 }
60 let t = arguments_types[0].clone();
61 if !t.is_array() {
62 return Err(runtime_error!(
63 "`y_true` in AucMetric must be an array, got {t:?}"
64 ));
65 }
66 if t.get_dimensions().len() != 1 {
67 return Err(runtime_error!(
68 "`y_true` in AucMetric must be 1-dimensional, got {t:?}"
69 ));
70 }
71 let n = t.get_dimensions()[0];
72 if n >= (1 << MAX_LOG_ARRAY_SIZE) {
73 return Err(runtime_error!(
74 "`y_true` in AucMetric must have less than 2^{} elements, got {n:?}",
75 MAX_LOG_ARRAY_SIZE
76 ));
77 }
78 let sc = t.get_scalar_type();
79 if sc != INT64 {
80 return Err(runtime_error!(
81 "`y_true` in AucMetric must consist of INT64's, got {sc:?}"
82 ));
83 }
84 if arguments_types[1] != t {
85 return Err(runtime_error!(
86 "`y_pred` in AucMetric must be of the same type as `y_true`, got {:?} vs {:?}",
87 t,
88 arguments_types[1]
89 ));
90 }
91
92 let g = context.create_graph()?;
93 let y_true = g.input(t.clone())?;
94 let y_pred = g.input(t)?;
95 let auc1 = compute_naive_auc(y_true.clone(), y_pred.clone(), &self.fp)?;
104 let y_true = y_true.get_slice(vec![SliceElement::SubArray(None, None, Some(-1))])?;
105 let y_pred = y_pred.get_slice(vec![SliceElement::SubArray(None, None, Some(-1))])?;
106 let auc2 = compute_naive_auc(y_true, y_pred, &self.fp)?;
107 let auc = auc1.add(auc2)?.truncate(2)?;
108 auc.set_as_output()?;
109 g.finalize()?;
110 Ok(g)
111 }
112
113 fn get_name(&self) -> String {
114 format!("AucScore(fp={:?})", self.fp)
115 }
116}
117
118fn compute_naive_auc(y_true: Node, y_pred: Node, fp: &FixedPrecisionConfig) -> Result<Node> {
119 let g = y_true.get_graph();
123 let joined =
124 g.create_named_tuple(vec![("y_pred".into(), y_pred), ("y_true".into(), y_true)])?;
125 let joined = g.custom_op(
126 CustomOperation::new(SortByIntegerKey {
127 key: "y_pred".into(),
128 }),
129 vec![joined],
130 )?;
131 let y_true = joined.named_tuple_get("y_true".into())?;
132
133 let num_ones = y_true.sum(vec![0])?.truncate(fp.denominator())?;
135 let n = y_true.get_type()?.get_dimensions()[0] as i64;
136 let n = constant_scalar(&g, n, INT64)?;
137 let num_zeros = n.subtract(num_ones.clone())?;
138 let denominator = num_ones.multiply(num_zeros)?;
139
140 let one = constant_scalar(&g, fp.denominator(), INT64)?;
142 let num_zeros_on_prefix = one.subtract(y_true.clone())?.cum_sum(0)?;
143 let num_zeros_before_one = num_zeros_on_prefix
144 .multiply(y_true)?
145 .truncate(fp.denominator())?;
146 let numerator = num_zeros_before_one
147 .sum(vec![0])?
148 .truncate(fp.denominator())?;
149
150 let numerator = i64_to_i128(numerator)?;
152 let denominator = i64_to_i128(denominator)?;
153 let denom_bits = MAX_LOG_ARRAY_SIZE * 2;
154 let result = g.custom_op(
158 CustomOperation::new(GoldschmidtDivision {
159 iterations: 7,
161 denominator_cap_2k: denom_bits,
162 }),
163 vec![numerator, denominator],
164 )?;
165
166 let result = match denom_bits.cmp(&fp.fractional_bits) {
167 Ordering::Less => result.multiply(constant_scalar(
168 &g,
169 1 << (fp.fractional_bits - denom_bits),
170 INT128,
171 )?)?,
172 Ordering::Equal => result,
173 Ordering::Greater => result.truncate(1 << (denom_bits - fp.fractional_bits))?,
174 };
175
176 i128_to_i64(result)
177}
178
179fn i64_to_i128(x: Node) -> Result<Node> {
180 let g = x.get_graph();
181 let bits = x.a2b()?;
182 let zeros = g.zeros(bits.get_type()?)?;
183 let bits = g.concatenate(vec![bits, zeros], 0)?;
184 bits.b2a(INT128)
185}
186
187fn i128_to_i64(x: Node) -> Result<Node> {
188 let bits = x.a2b()?;
189 let bits = bits.get_slice(vec![SliceElement::SubArray(None, Some(64), None)])?;
190 bits.b2a(INT64)
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::array_type;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Evaluates `AucScore` (default fixed-precision configuration) on the given labels and
    /// predictions and returns the score converted to a floating-point number.
    fn test_helper(y_true: Vec<i64>, y_pred: Vec<i64>) -> Result<f64> {
        let input_type = array_type(vec![y_true.len() as u64], INT64);
        let context = simple_context(|g| {
            let labels = g.input(input_type.clone())?;
            let predictions = g.input(input_type)?;
            let auc_op = CustomOperation::new(AucScore {
                fp: FixedPrecisionConfig::default(),
            });
            g.custom_op(auc_op, vec![labels, predictions])
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let main_graph = instantiated.get_context().get_main_graph()?;
        let inputs = vec![
            Value::from_flattened_array(&y_true, INT64)?,
            Value::from_flattened_array(&y_pred, INT64)?,
        ];
        let raw_score = random_evaluate(main_graph, inputs)?.to_i64(INT64)?;
        // Convert from fixed-point back to a real number.
        Ok(raw_score as f64 / FixedPrecisionConfig::default().denominator_f64())
    }

    /// Fixed-point encoding of the label 1 under the default configuration.
    fn fixed_one() -> i64 {
        FixedPrecisionConfig::default().denominator() as i64
    }

    /// Runs the operation and checks that the score is within 1e-3 of `expected`.
    fn check_auc(y_true: Vec<i64>, y_pred: Vec<i64>, expected: f64) -> Result<()> {
        let score = test_helper(y_true, y_pred)?;
        assert!((score - expected).abs() < 1e-3);
        Ok(())
    }

    #[test]
    fn test_auc_simple_case() -> Result<()> {
        let one = fixed_one();
        check_auc(vec![0, one, 0, one], vec![-10, 30, 20, 10], 0.75)
    }

    #[test]
    fn test_auc_equal_predictions() -> Result<()> {
        let one = fixed_one();
        // All predictions tie, so the expected score is 0.5.
        check_auc(vec![0, one, 0, one], vec![42, 42, 42, 42], 0.5)
    }

    #[test]
    fn test_auc_large_array() -> Result<()> {
        let one = fixed_one();
        // Predictions 0..10000; the lower half is labelled 0 and the upper half 1,
        // i.e. a perfect ranking.
        let y_pred: Vec<i64> = (0..10000).collect();
        let y_true: Vec<i64> = y_pred
            .iter()
            .map(|&i| if i < 5000 { 0 } else { one })
            .collect();
        check_auc(y_true, y_pred, 1.0)
    }
}