1use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9
10use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13use scirs2_core::ndarray::{Array1, Array2};
15
16use scirs2_metrics::classification::advanced::{
18 balanced_accuracy_score, cohen_kappa_score, matthews_corrcoef,
19};
20use scirs2_metrics::classification::curves::roc_curve;
21use scirs2_metrics::classification::{
22 accuracy_score, binary_log_loss, confusion_matrix, f1_score, fbeta_score, precision_score,
23 recall_score, roc_auc_score,
24};
25use scirs2_metrics::clustering::{
26 adjusted_rand_index, calinski_harabasz_score, davies_bouldin_score,
27 normalized_mutual_info_score, silhouette_score,
28};
29use scirs2_metrics::ranking::{mean_reciprocal_rank, ndcg_score};
30use scirs2_metrics::regression::{
31 explained_variance_score, mean_absolute_error, mean_absolute_percentage_error,
32 mean_squared_error, r2_score,
33};
34
35#[pyfunction]
41fn accuracy_score_py(
42 y_true: &Bound<'_, PyArray1<i64>>,
43 y_pred: &Bound<'_, PyArray1<i64>>,
44) -> PyResult<f64> {
45 let y_true_binding = y_true.readonly();
46 let y_pred_binding = y_pred.readonly();
47 let y_true_data = y_true_binding.as_array();
48 let y_pred_data = y_pred_binding.as_array();
49
50 accuracy_score(&y_true_data, &y_pred_data)
51 .map_err(|e| PyRuntimeError::new_err(format!("Accuracy score failed: {}", e)))
52}
53
54#[pyfunction]
56#[pyo3(signature = (y_true, y_pred, pos_label=1))]
57fn precision_score_py(
58 y_true: &Bound<'_, PyArray1<i64>>,
59 y_pred: &Bound<'_, PyArray1<i64>>,
60 pos_label: i64,
61) -> PyResult<f64> {
62 let y_true_binding = y_true.readonly();
63 let y_pred_binding = y_pred.readonly();
64 let y_true_data = y_true_binding.as_array();
65 let y_pred_data = y_pred_binding.as_array();
66
67 precision_score(&y_true_data, &y_pred_data, pos_label)
68 .map_err(|e| PyRuntimeError::new_err(format!("Precision score failed: {}", e)))
69}
70
71#[pyfunction]
73#[pyo3(signature = (y_true, y_pred, pos_label=1))]
74fn recall_score_py(
75 y_true: &Bound<'_, PyArray1<i64>>,
76 y_pred: &Bound<'_, PyArray1<i64>>,
77 pos_label: i64,
78) -> PyResult<f64> {
79 let y_true_binding = y_true.readonly();
80 let y_pred_binding = y_pred.readonly();
81 let y_true_data = y_true_binding.as_array();
82 let y_pred_data = y_pred_binding.as_array();
83
84 recall_score(&y_true_data, &y_pred_data, pos_label)
85 .map_err(|e| PyRuntimeError::new_err(format!("Recall score failed: {}", e)))
86}
87
88#[pyfunction]
90#[pyo3(signature = (y_true, y_pred, pos_label=1))]
91fn f1_score_py(
92 y_true: &Bound<'_, PyArray1<i64>>,
93 y_pred: &Bound<'_, PyArray1<i64>>,
94 pos_label: i64,
95) -> PyResult<f64> {
96 let y_true_binding = y_true.readonly();
97 let y_pred_binding = y_pred.readonly();
98 let y_true_data = y_true_binding.as_array();
99 let y_pred_data = y_pred_binding.as_array();
100
101 f1_score(&y_true_data, &y_pred_data, pos_label)
102 .map_err(|e| PyRuntimeError::new_err(format!("F1 score failed: {}", e)))
103}
104
105#[pyfunction]
107#[pyo3(signature = (y_true, y_pred, beta, pos_label=1))]
108fn fbeta_score_py(
109 y_true: &Bound<'_, PyArray1<i64>>,
110 y_pred: &Bound<'_, PyArray1<i64>>,
111 beta: f64,
112 pos_label: i64,
113) -> PyResult<f64> {
114 let y_true_binding = y_true.readonly();
115 let y_pred_binding = y_pred.readonly();
116 let y_true_data = y_true_binding.as_array();
117 let y_pred_data = y_pred_binding.as_array();
118
119 fbeta_score(&y_true_data, &y_pred_data, pos_label, beta)
120 .map_err(|e| PyRuntimeError::new_err(format!("F-beta score failed: {}", e)))
121}
122
123#[pyfunction]
125fn confusion_matrix_py(
126 py: Python,
127 y_true: &Bound<'_, PyArray1<i64>>,
128 y_pred: &Bound<'_, PyArray1<i64>>,
129) -> PyResult<Py<PyAny>> {
130 let y_true_binding = y_true.readonly();
131 let y_pred_binding = y_pred.readonly();
132 let y_true_data = y_true_binding.as_array();
133 let y_pred_data = y_pred_binding.as_array();
134
135 let (matrix, classes): (Array2<u64>, Array1<i64>) =
136 confusion_matrix(&y_true_data, &y_pred_data, None)
137 .map_err(|e| PyRuntimeError::new_err(format!("Confusion matrix failed: {}", e)))?;
138
139 let dict = PyDict::new(py);
140 dict.set_item("matrix", matrix.into_pyarray(py).unbind())?;
141 dict.set_item("classes", classes.into_pyarray(py).unbind())?;
142
143 Ok(dict.into())
144}
145
146#[pyfunction]
148fn roc_curve_py(
149 py: Python,
150 y_true: &Bound<'_, PyArray1<i32>>,
151 y_score: &Bound<'_, PyArray1<f64>>,
152) -> PyResult<Py<PyAny>> {
153 let y_true_binding = y_true.readonly();
154 let y_score_binding = y_score.readonly();
155 let y_true_data = y_true_binding.as_array();
156 let y_score_data = y_score_binding.as_array();
157
158 let result = roc_curve(&y_true_data, &y_score_data)
161 .map_err(|e| PyRuntimeError::new_err(format!("ROC curve failed: {}", e)))?;
162
163 let dict = PyDict::new(py);
164 dict.set_item("fpr", result.0.into_pyarray(py).unbind())?;
165 dict.set_item("tpr", result.1.into_pyarray(py).unbind())?;
166 dict.set_item("thresholds", result.2.into_pyarray(py).unbind())?;
167
168 Ok(dict.into())
169}
170
171#[pyfunction]
173fn roc_auc_score_py(
174 y_true: &Bound<'_, PyArray1<u32>>,
175 y_score: &Bound<'_, PyArray1<f64>>,
176) -> PyResult<f64> {
177 let y_true_binding = y_true.readonly();
178 let y_score_binding = y_score.readonly();
179 let y_true_data = y_true_binding.as_array();
180 let y_score_data = y_score_binding.as_array();
181
182 roc_auc_score(&y_true_data, &y_score_data)
184 .map_err(|e| PyRuntimeError::new_err(format!("ROC AUC score failed: {}", e)))
185}
186
187#[pyfunction]
189#[pyo3(signature = (y_true, y_prob, eps=1e-15))]
190fn log_loss_py(
191 y_true: &Bound<'_, PyArray1<u32>>,
192 y_prob: &Bound<'_, PyArray1<f64>>,
193 eps: f64,
194) -> PyResult<f64> {
195 let y_true_binding = y_true.readonly();
196 let y_prob_binding = y_prob.readonly();
197 let y_true_data = y_true_binding.as_array();
198 let y_prob_data = y_prob_binding.as_array();
199
200 binary_log_loss(&y_true_data, &y_prob_data, eps)
202 .map_err(|e| PyRuntimeError::new_err(format!("Log loss failed: {}", e)))
203}
204
205#[pyfunction]
207fn matthews_corrcoef_py(
208 y_true: &Bound<'_, PyArray1<i64>>,
209 y_pred: &Bound<'_, PyArray1<i64>>,
210) -> PyResult<f64> {
211 let y_true_binding = y_true.readonly();
212 let y_pred_binding = y_pred.readonly();
213 let y_true_data = y_true_binding.as_array();
214 let y_pred_data = y_pred_binding.as_array();
215
216 matthews_corrcoef(&y_true_data, &y_pred_data)
217 .map_err(|e| PyRuntimeError::new_err(format!("Matthews correlation failed: {}", e)))
218}
219
220#[pyfunction]
222fn balanced_accuracy_score_py(
223 y_true: &Bound<'_, PyArray1<i64>>,
224 y_pred: &Bound<'_, PyArray1<i64>>,
225) -> PyResult<f64> {
226 let y_true_binding = y_true.readonly();
227 let y_pred_binding = y_pred.readonly();
228 let y_true_data = y_true_binding.as_array();
229 let y_pred_data = y_pred_binding.as_array();
230
231 balanced_accuracy_score(&y_true_data, &y_pred_data)
232 .map_err(|e| PyRuntimeError::new_err(format!("Balanced accuracy failed: {}", e)))
233}
234
235#[pyfunction]
237fn cohen_kappa_score_py(
238 y_true: &Bound<'_, PyArray1<i64>>,
239 y_pred: &Bound<'_, PyArray1<i64>>,
240) -> PyResult<f64> {
241 let y_true_binding = y_true.readonly();
242 let y_pred_binding = y_pred.readonly();
243 let y_true_data = y_true_binding.as_array();
244 let y_pred_data = y_pred_binding.as_array();
245
246 cohen_kappa_score(&y_true_data, &y_pred_data)
247 .map_err(|e| PyRuntimeError::new_err(format!("Cohen's kappa failed: {}", e)))
248}
249
250#[pyfunction]
256fn mean_squared_error_py(
257 y_true: &Bound<'_, PyArray1<f64>>,
258 y_pred: &Bound<'_, PyArray1<f64>>,
259) -> PyResult<f64> {
260 let y_true_binding = y_true.readonly();
261 let y_pred_binding = y_pred.readonly();
262 let y_true_data = y_true_binding.as_array();
263 let y_pred_data = y_pred_binding.as_array();
264
265 mean_squared_error(&y_true_data, &y_pred_data)
266 .map_err(|e| PyRuntimeError::new_err(format!("MSE failed: {}", e)))
267}
268
269#[pyfunction]
271fn mean_absolute_error_py(
272 y_true: &Bound<'_, PyArray1<f64>>,
273 y_pred: &Bound<'_, PyArray1<f64>>,
274) -> PyResult<f64> {
275 let y_true_binding = y_true.readonly();
276 let y_pred_binding = y_pred.readonly();
277 let y_true_data = y_true_binding.as_array();
278 let y_pred_data = y_pred_binding.as_array();
279
280 mean_absolute_error(&y_true_data, &y_pred_data)
281 .map_err(|e| PyRuntimeError::new_err(format!("MAE failed: {}", e)))
282}
283
284#[pyfunction]
286fn r2_score_py(
287 y_true: &Bound<'_, PyArray1<f64>>,
288 y_pred: &Bound<'_, PyArray1<f64>>,
289) -> PyResult<f64> {
290 let y_true_binding = y_true.readonly();
291 let y_pred_binding = y_pred.readonly();
292 let y_true_data = y_true_binding.as_array();
293 let y_pred_data = y_pred_binding.as_array();
294
295 r2_score(&y_true_data, &y_pred_data)
296 .map_err(|e| PyRuntimeError::new_err(format!("R² score failed: {}", e)))
297}
298
299#[pyfunction]
301fn mape_py(y_true: &Bound<'_, PyArray1<f64>>, y_pred: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
302 let y_true_binding = y_true.readonly();
303 let y_pred_binding = y_pred.readonly();
304 let y_true_data = y_true_binding.as_array();
305 let y_pred_data = y_pred_binding.as_array();
306
307 mean_absolute_percentage_error(&y_true_data, &y_pred_data)
308 .map_err(|e| PyRuntimeError::new_err(format!("MAPE failed: {}", e)))
309}
310
311#[pyfunction]
313fn explained_variance_score_py(
314 y_true: &Bound<'_, PyArray1<f64>>,
315 y_pred: &Bound<'_, PyArray1<f64>>,
316) -> PyResult<f64> {
317 let y_true_binding = y_true.readonly();
318 let y_pred_binding = y_pred.readonly();
319 let y_true_data = y_true_binding.as_array();
320 let y_pred_data = y_pred_binding.as_array();
321
322 explained_variance_score(&y_true_data, &y_pred_data)
323 .map_err(|e| PyRuntimeError::new_err(format!("Explained variance failed: {}", e)))
324}
325
326#[pyfunction]
332#[pyo3(signature = (x, labels, metric="euclidean"))]
333fn silhouette_score_py(
334 x: &Bound<'_, PyArray2<f64>>,
335 labels: &Bound<'_, PyArray1<usize>>,
336 metric: &str,
337) -> PyResult<f64> {
338 let x_binding = x.readonly();
339 let labels_binding = labels.readonly();
340 let x_data = x_binding.as_array();
341 let labels_data = labels_binding.as_array();
342
343 silhouette_score(&x_data, &labels_data, metric)
344 .map_err(|e| PyRuntimeError::new_err(format!("Silhouette score failed: {}", e)))
345}
346
347#[pyfunction]
349fn davies_bouldin_score_py(
350 x: &Bound<'_, PyArray2<f64>>,
351 labels: &Bound<'_, PyArray1<usize>>,
352) -> PyResult<f64> {
353 let x_binding = x.readonly();
354 let labels_binding = labels.readonly();
355 let x_data = x_binding.as_array();
356 let labels_data = labels_binding.as_array();
357
358 davies_bouldin_score(&x_data, &labels_data)
359 .map_err(|e| PyRuntimeError::new_err(format!("Davies-Bouldin score failed: {}", e)))
360}
361
362#[pyfunction]
364fn calinski_harabasz_score_py(
365 x: &Bound<'_, PyArray2<f64>>,
366 labels: &Bound<'_, PyArray1<usize>>,
367) -> PyResult<f64> {
368 let x_binding = x.readonly();
369 let labels_binding = labels.readonly();
370 let x_data = x_binding.as_array();
371 let labels_data = labels_binding.as_array();
372
373 calinski_harabasz_score(&x_data, &labels_data)
374 .map_err(|e| PyRuntimeError::new_err(format!("Calinski-Harabasz score failed: {}", e)))
375}
376
377#[pyfunction]
379fn adjusted_rand_index_py(
380 labels_true: &Bound<'_, PyArray1<i64>>,
381 labels_pred: &Bound<'_, PyArray1<i64>>,
382) -> PyResult<f64> {
383 let labels_true_binding = labels_true.readonly();
384 let labels_pred_binding = labels_pred.readonly();
385 let labels_true_data = labels_true_binding.as_array();
386 let labels_pred_data = labels_pred_binding.as_array();
387
388 adjusted_rand_index(&labels_true_data, &labels_pred_data)
389 .map_err(|e| PyRuntimeError::new_err(format!("Adjusted Rand index failed: {}", e)))
390}
391
392#[pyfunction]
394#[pyo3(signature = (labels_true, labels_pred, average_method="arithmetic"))]
395fn nmi_score_py(
396 labels_true: &Bound<'_, PyArray1<i64>>,
397 labels_pred: &Bound<'_, PyArray1<i64>>,
398 average_method: &str,
399) -> PyResult<f64> {
400 let labels_true_binding = labels_true.readonly();
401 let labels_pred_binding = labels_pred.readonly();
402 let labels_true_data = labels_true_binding.as_array();
403 let labels_pred_data = labels_pred_binding.as_array();
404
405 normalized_mutual_info_score(&labels_true_data, &labels_pred_data, average_method)
406 .map_err(|e| PyRuntimeError::new_err(format!("NMI score failed: {}", e)))
407}
408
409#[pyfunction]
416#[pyo3(signature = (y_true, y_score, k=None))]
417fn ndcg_score_py(
418 y_true: &Bound<'_, PyArray2<f64>>,
419 y_score: &Bound<'_, PyArray2<f64>>,
420 k: Option<usize>,
421) -> PyResult<f64> {
422 let y_true_binding = y_true.readonly();
423 let y_score_binding = y_score.readonly();
424 let y_true_data = y_true_binding.as_array();
425 let y_score_data = y_score_binding.as_array();
426
427 let y_true_vec: Vec<Array1<f64>> = y_true_data
429 .rows()
430 .into_iter()
431 .map(|row| row.to_owned())
432 .collect();
433 let y_score_vec: Vec<Array1<f64>> = y_score_data
434 .rows()
435 .into_iter()
436 .map(|row| row.to_owned())
437 .collect();
438
439 ndcg_score(&y_true_vec, &y_score_vec, k)
440 .map_err(|e| PyRuntimeError::new_err(format!("NDCG score failed: {}", e)))
441}
442
443#[pyfunction]
446fn mrr_py(y_true: &Bound<'_, PyArray2<f64>>, y_score: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
447 let y_true_binding = y_true.readonly();
448 let y_score_binding = y_score.readonly();
449 let y_true_data = y_true_binding.as_array();
450 let y_score_data = y_score_binding.as_array();
451
452 let y_true_vec: Vec<Array1<f64>> = y_true_data
454 .rows()
455 .into_iter()
456 .map(|row| row.to_owned())
457 .collect();
458 let y_score_vec: Vec<Array1<f64>> = y_score_data
459 .rows()
460 .into_iter()
461 .map(|row| row.to_owned())
462 .collect();
463
464 mean_reciprocal_rank(&y_true_vec, &y_score_vec)
465 .map_err(|e| PyRuntimeError::new_err(format!("MRR failed: {}", e)))
466}
467
468pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
470 m.add_function(wrap_pyfunction!(accuracy_score_py, m)?)?;
472 m.add_function(wrap_pyfunction!(precision_score_py, m)?)?;
473 m.add_function(wrap_pyfunction!(recall_score_py, m)?)?;
474 m.add_function(wrap_pyfunction!(f1_score_py, m)?)?;
475 m.add_function(wrap_pyfunction!(fbeta_score_py, m)?)?;
476 m.add_function(wrap_pyfunction!(confusion_matrix_py, m)?)?;
477 m.add_function(wrap_pyfunction!(roc_curve_py, m)?)?;
478 m.add_function(wrap_pyfunction!(roc_auc_score_py, m)?)?;
479 m.add_function(wrap_pyfunction!(log_loss_py, m)?)?;
480 m.add_function(wrap_pyfunction!(matthews_corrcoef_py, m)?)?;
481 m.add_function(wrap_pyfunction!(balanced_accuracy_score_py, m)?)?;
482 m.add_function(wrap_pyfunction!(cohen_kappa_score_py, m)?)?;
483
484 m.add_function(wrap_pyfunction!(mean_squared_error_py, m)?)?;
486 m.add_function(wrap_pyfunction!(mean_absolute_error_py, m)?)?;
487 m.add_function(wrap_pyfunction!(r2_score_py, m)?)?;
488 m.add_function(wrap_pyfunction!(mape_py, m)?)?;
489 m.add_function(wrap_pyfunction!(explained_variance_score_py, m)?)?;
490
491 m.add_function(wrap_pyfunction!(silhouette_score_py, m)?)?;
493 m.add_function(wrap_pyfunction!(davies_bouldin_score_py, m)?)?;
494 m.add_function(wrap_pyfunction!(calinski_harabasz_score_py, m)?)?;
495 m.add_function(wrap_pyfunction!(adjusted_rand_index_py, m)?)?;
496 m.add_function(wrap_pyfunction!(nmi_score_py, m)?)?;
497
498 m.add_function(wrap_pyfunction!(ndcg_score_py, m)?)?;
500 m.add_function(wrap_pyfunction!(mrr_py, m)?)?;
501
502 Ok(())
503}