//! Python bindings for scirs2-metrics
//!
//! This module provides Python bindings for machine learning evaluation metrics,
//! including classification, regression, and clustering metrics.

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};

// NumPy types for Python array interface
use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};

// ndarray types from scirs2-core
use scirs2_core::ndarray::{Array1, Array2};

// Direct imports from scirs2-metrics submodules
use scirs2_metrics::classification::advanced::{
    balanced_accuracy_score, cohen_kappa_score, matthews_corrcoef,
};
use scirs2_metrics::classification::curves::roc_curve;
use scirs2_metrics::classification::{
    accuracy_score, binary_log_loss, confusion_matrix, f1_score, fbeta_score, precision_score,
    recall_score, roc_auc_score,
};
use scirs2_metrics::clustering::{
    adjusted_rand_index, calinski_harabasz_score, davies_bouldin_score,
    normalized_mutual_info_score, silhouette_score,
};
use scirs2_metrics::ranking::{mean_reciprocal_rank, ndcg_score};
use scirs2_metrics::regression::{
    explained_variance_score, mean_absolute_error, mean_absolute_percentage_error,
    mean_squared_error, r2_score,
};

// ========================================
// CLASSIFICATION METRICS
// ========================================

/// Calculate accuracy score
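///
/// # Example (Python)
///
/// A minimal usage sketch; the exposed module path (`scirs2.metrics` below) is
/// an assumption and may differ per build. The other label-based classification
/// metrics in this file follow the same call pattern.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// y_true = np.array([0, 1, 1, 0], dtype=np.int64)
/// y_pred = np.array([0, 1, 0, 0], dtype=np.int64)
/// acc = metrics.accuracy_score_py(y_true, y_pred)  # 0.75
/// ```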
#[pyfunction]
fn accuracy_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    accuracy_score(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Accuracy score failed: {}", e)))
}

/// Calculate precision score (binary classification)
#[pyfunction]
#[pyo3(signature = (y_true, y_pred, pos_label=1))]
fn precision_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
    pos_label: i64,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    precision_score(&y_true_data, &y_pred_data, pos_label)
        .map_err(|e| PyRuntimeError::new_err(format!("Precision score failed: {}", e)))
}

/// Calculate recall score (binary classification)
#[pyfunction]
#[pyo3(signature = (y_true, y_pred, pos_label=1))]
fn recall_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
    pos_label: i64,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    recall_score(&y_true_data, &y_pred_data, pos_label)
        .map_err(|e| PyRuntimeError::new_err(format!("Recall score failed: {}", e)))
}

/// Calculate F1 score (binary classification)
#[pyfunction]
#[pyo3(signature = (y_true, y_pred, pos_label=1))]
fn f1_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
    pos_label: i64,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    f1_score(&y_true_data, &y_pred_data, pos_label)
        .map_err(|e| PyRuntimeError::new_err(format!("F1 score failed: {}", e)))
}

/// Calculate F-beta score
#[pyfunction]
#[pyo3(signature = (y_true, y_pred, beta, pos_label=1))]
fn fbeta_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
    beta: f64,
    pos_label: i64,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

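    // Note: the underlying fbeta_score takes pos_label before beta, so the
    // arguments are reordered here relative to this wrapper's Python signature.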
    fbeta_score(&y_true_data, &y_pred_data, pos_label, beta)
        .map_err(|e| PyRuntimeError::new_err(format!("F-beta score failed: {}", e)))
}

/// Calculate confusion matrix
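///
/// Returns a dict with keys `"matrix"` (the class-by-class counts) and
/// `"classes"` (the labels corresponding to the matrix rows/columns).
///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption:
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// cm = metrics.confusion_matrix_py(
///     np.array([0, 1, 1, 0], dtype=np.int64),
///     np.array([0, 1, 0, 0], dtype=np.int64),
/// )
/// print(cm["matrix"], cm["classes"])
/// ```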
#[pyfunction]
fn confusion_matrix_py(
    py: Python,
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
) -> PyResult<Py<PyAny>> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    let (matrix, classes): (Array2<u64>, Array1<i64>) =
        confusion_matrix(&y_true_data, &y_pred_data, None)
            .map_err(|e| PyRuntimeError::new_err(format!("Confusion matrix failed: {}", e)))?;

    let dict = PyDict::new(py);
    dict.set_item("matrix", matrix.into_pyarray(py).unbind())?;
    dict.set_item("classes", classes.into_pyarray(py).unbind())?;

    Ok(dict.into())
}

/// Calculate ROC curve
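///
/// Returns a dict with keys `"fpr"`, `"tpr"`, and `"thresholds"`.
///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. Note that `y_true`
/// must be an int32 array here, unlike the int64 label metrics above.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// roc = metrics.roc_curve_py(
///     np.array([0, 0, 1, 1], dtype=np.int32),
///     np.array([0.1, 0.4, 0.35, 0.8]),
/// )
/// fpr, tpr = roc["fpr"], roc["tpr"]
/// ```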
#[pyfunction]
fn roc_curve_py(
    py: Python,
    y_true: &Bound<'_, PyArray1<i32>>,
    y_score: &Bound<'_, PyArray1<f64>>,
) -> PyResult<Py<PyAny>> {
    let y_true_binding = y_true.readonly();
    let y_score_binding = y_score.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_score_data = y_score_binding.as_array();

    // roc_curve requires S1::Elem: Into<f64>, which i32 satisfies.
    // The result carries the fpr, tpr, and threshold arrays, accessed positionally below.
    let result = roc_curve(&y_true_data, &y_score_data)
        .map_err(|e| PyRuntimeError::new_err(format!("ROC curve failed: {}", e)))?;

    let dict = PyDict::new(py);
    dict.set_item("fpr", result.0.into_pyarray(py).unbind())?;
    dict.set_item("tpr", result.1.into_pyarray(py).unbind())?;
    dict.set_item("thresholds", result.2.into_pyarray(py).unbind())?;

    Ok(dict.into())
}

/// Calculate ROC AUC score
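///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. `y_true` must be an
/// unsigned 32-bit array (`np.uint32`) for this function.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// auc = metrics.roc_auc_score_py(
///     np.array([0, 0, 1, 1], dtype=np.uint32),
///     np.array([0.1, 0.4, 0.35, 0.8]),
/// )
/// ```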
#[pyfunction]
fn roc_auc_score_py(
    y_true: &Bound<'_, PyArray1<u32>>,
    y_score: &Bound<'_, PyArray1<f64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_score_binding = y_score.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_score_data = y_score_binding.as_array();

    // roc_auc_score requires S1: Data<Elem = u32>
    roc_auc_score(&y_true_data, &y_score_data)
        .map_err(|e| PyRuntimeError::new_err(format!("ROC AUC score failed: {}", e)))
}

/// Calculate log loss
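///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. `eps` (default `1e-15`)
/// guards against taking the logarithm of probabilities at exactly 0 or 1.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// loss = metrics.log_loss_py(
///     np.array([0, 1, 1], dtype=np.uint32),
///     np.array([0.1, 0.9, 0.8]),
/// )
/// ```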
#[pyfunction]
#[pyo3(signature = (y_true, y_prob, eps=1e-15))]
fn log_loss_py(
    y_true: &Bound<'_, PyArray1<u32>>,
    y_prob: &Bound<'_, PyArray1<f64>>,
    eps: f64,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_prob_binding = y_prob.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_prob_data = y_prob_binding.as_array();

    // binary_log_loss requires S1: Data<Elem = u32> and takes eps parameter
    binary_log_loss(&y_true_data, &y_prob_data, eps)
        .map_err(|e| PyRuntimeError::new_err(format!("Log loss failed: {}", e)))
}

/// Calculate Matthews correlation coefficient
#[pyfunction]
fn matthews_corrcoef_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    matthews_corrcoef(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Matthews correlation failed: {}", e)))
}

/// Calculate balanced accuracy score
#[pyfunction]
fn balanced_accuracy_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    balanced_accuracy_score(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Balanced accuracy failed: {}", e)))
}

/// Calculate Cohen's kappa score
#[pyfunction]
fn cohen_kappa_score_py(
    y_true: &Bound<'_, PyArray1<i64>>,
    y_pred: &Bound<'_, PyArray1<i64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    cohen_kappa_score(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Cohen's kappa failed: {}", e)))
}

// ========================================
// REGRESSION METRICS
// ========================================

/// Calculate mean squared error
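///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. The other regression
/// metrics below (MAE, R², MAPE, explained variance) follow the same pattern.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// y_true = np.array([3.0, -0.5, 2.0, 7.0])
/// y_pred = np.array([2.5, 0.0, 2.0, 8.0])
/// mse = metrics.mean_squared_error_py(y_true, y_pred)  # 0.375
/// ```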
#[pyfunction]
fn mean_squared_error_py(
    y_true: &Bound<'_, PyArray1<f64>>,
    y_pred: &Bound<'_, PyArray1<f64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    mean_squared_error(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("MSE failed: {}", e)))
}

/// Calculate mean absolute error
#[pyfunction]
fn mean_absolute_error_py(
    y_true: &Bound<'_, PyArray1<f64>>,
    y_pred: &Bound<'_, PyArray1<f64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    mean_absolute_error(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("MAE failed: {}", e)))
}

/// Calculate R² score
#[pyfunction]
fn r2_score_py(
    y_true: &Bound<'_, PyArray1<f64>>,
    y_pred: &Bound<'_, PyArray1<f64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    r2_score(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("R² score failed: {}", e)))
}

/// Calculate mean absolute percentage error
#[pyfunction]
fn mape_py(y_true: &Bound<'_, PyArray1<f64>>, y_pred: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    mean_absolute_percentage_error(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("MAPE failed: {}", e)))
}

/// Calculate explained variance score
#[pyfunction]
fn explained_variance_score_py(
    y_true: &Bound<'_, PyArray1<f64>>,
    y_pred: &Bound<'_, PyArray1<f64>>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_pred_binding = y_pred.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_pred_data = y_pred_binding.as_array();

    explained_variance_score(&y_true_data, &y_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Explained variance failed: {}", e)))
}

// ========================================
// CLUSTERING METRICS
// ========================================

/// Calculate silhouette score
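///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. `x` is an
/// (n_samples, n_features) float array; `labels` holds cluster ids as
/// platform-sized unsigned integers (`np.uintp` maps to Rust's `usize`).
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// x = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.8]])
/// labels = np.array([0, 0, 1, 1], dtype=np.uintp)
/// s = metrics.silhouette_score_py(x, labels, metric="euclidean")
/// ```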
#[pyfunction]
#[pyo3(signature = (x, labels, metric="euclidean"))]
fn silhouette_score_py(
    x: &Bound<'_, PyArray2<f64>>,
    labels: &Bound<'_, PyArray1<usize>>,
    metric: &str,
) -> PyResult<f64> {
    let x_binding = x.readonly();
    let labels_binding = labels.readonly();
    let x_data = x_binding.as_array();
    let labels_data = labels_binding.as_array();

    silhouette_score(&x_data, &labels_data, metric)
        .map_err(|e| PyRuntimeError::new_err(format!("Silhouette score failed: {}", e)))
}

/// Calculate Davies-Bouldin score
#[pyfunction]
fn davies_bouldin_score_py(
    x: &Bound<'_, PyArray2<f64>>,
    labels: &Bound<'_, PyArray1<usize>>,
) -> PyResult<f64> {
    let x_binding = x.readonly();
    let labels_binding = labels.readonly();
    let x_data = x_binding.as_array();
    let labels_data = labels_binding.as_array();

    davies_bouldin_score(&x_data, &labels_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Davies-Bouldin score failed: {}", e)))
}

/// Calculate Calinski-Harabasz score
#[pyfunction]
fn calinski_harabasz_score_py(
    x: &Bound<'_, PyArray2<f64>>,
    labels: &Bound<'_, PyArray1<usize>>,
) -> PyResult<f64> {
    let x_binding = x.readonly();
    let labels_binding = labels.readonly();
    let x_data = x_binding.as_array();
    let labels_data = labels_binding.as_array();

    calinski_harabasz_score(&x_data, &labels_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Calinski-Harabasz score failed: {}", e)))
}

/// Calculate adjusted Rand index
#[pyfunction]
fn adjusted_rand_index_py(
    labels_true: &Bound<'_, PyArray1<i64>>,
    labels_pred: &Bound<'_, PyArray1<i64>>,
) -> PyResult<f64> {
    let labels_true_binding = labels_true.readonly();
    let labels_pred_binding = labels_pred.readonly();
    let labels_true_data = labels_true_binding.as_array();
    let labels_pred_data = labels_pred_binding.as_array();

    adjusted_rand_index(&labels_true_data, &labels_pred_data)
        .map_err(|e| PyRuntimeError::new_err(format!("Adjusted Rand index failed: {}", e)))
}

/// Calculate normalized mutual information score
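///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. `average_method`
/// defaults to `"arithmetic"`.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// nmi = metrics.nmi_score_py(
///     np.array([0, 0, 1, 1], dtype=np.int64),
///     np.array([1, 1, 0, 0], dtype=np.int64),
/// )  # 1.0, since NMI is invariant to label permutation
/// ```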
#[pyfunction]
#[pyo3(signature = (labels_true, labels_pred, average_method="arithmetic"))]
fn nmi_score_py(
    labels_true: &Bound<'_, PyArray1<i64>>,
    labels_pred: &Bound<'_, PyArray1<i64>>,
    average_method: &str,
) -> PyResult<f64> {
    let labels_true_binding = labels_true.readonly();
    let labels_pred_binding = labels_pred.readonly();
    let labels_true_data = labels_true_binding.as_array();
    let labels_pred_data = labels_pred_binding.as_array();

    normalized_mutual_info_score(&labels_true_data, &labels_pred_data, average_method)
        .map_err(|e| PyRuntimeError::new_err(format!("NMI score failed: {}", e)))
}

// ========================================
// RANKING METRICS
// ========================================

/// Calculate NDCG score
/// y_true and y_score are 2D arrays where each row represents a query
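///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption. Each row holds one
/// query's relevance grades (`y_true`) and predicted scores (`y_score`).
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// y_true = np.array([[3.0, 2.0, 1.0, 0.0]])
/// y_score = np.array([[0.9, 0.8, 0.1, 0.2]])
/// ndcg = metrics.ndcg_score_py(y_true, y_score, k=3)
/// ```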
#[pyfunction]
#[pyo3(signature = (y_true, y_score, k=None))]
fn ndcg_score_py(
    y_true: &Bound<'_, PyArray2<f64>>,
    y_score: &Bound<'_, PyArray2<f64>>,
    k: Option<usize>,
) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_score_binding = y_score.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_score_data = y_score_binding.as_array();

    // Convert 2D arrays to Vec of 1D arrays (each row is a query)
    let y_true_vec: Vec<Array1<f64>> = y_true_data
        .rows()
        .into_iter()
        .map(|row| row.to_owned())
        .collect();
    let y_score_vec: Vec<Array1<f64>> = y_score_data
        .rows()
        .into_iter()
        .map(|row| row.to_owned())
        .collect();

    ndcg_score(&y_true_vec, &y_score_vec, k)
        .map_err(|e| PyRuntimeError::new_err(format!("NDCG score failed: {}", e)))
}

/// Calculate mean reciprocal rank
/// y_true and y_score are 2D arrays where each row represents a query
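///
/// # Example (Python)
///
/// A minimal sketch; the module path is an assumption, as is the convention
/// (common for MRR) that nonzero `y_true` entries mark relevant items.
///
/// ```python
/// import numpy as np
/// from scirs2 import metrics
///
/// y_true = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
/// y_score = np.array([[0.2, 0.9, 0.1], [0.3, 0.8, 0.1]])
/// mrr = metrics.mrr_py(y_true, y_score)
/// ```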
#[pyfunction]
fn mrr_py(y_true: &Bound<'_, PyArray2<f64>>, y_score: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
    let y_true_binding = y_true.readonly();
    let y_score_binding = y_score.readonly();
    let y_true_data = y_true_binding.as_array();
    let y_score_data = y_score_binding.as_array();

    // Convert 2D arrays to Vec of 1D arrays (each row is a query)
    let y_true_vec: Vec<Array1<f64>> = y_true_data
        .rows()
        .into_iter()
        .map(|row| row.to_owned())
        .collect();
    let y_score_vec: Vec<Array1<f64>> = y_score_data
        .rows()
        .into_iter()
        .map(|row| row.to_owned())
        .collect();

    mean_reciprocal_rank(&y_true_vec, &y_score_vec)
        .map_err(|e| PyRuntimeError::new_err(format!("MRR failed: {}", e)))
}

/// Python module registration
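///
/// A hypothetical call site, assuming a parent `#[pymodule]` named `scirs2`:
///
/// ```ignore
/// #[pymodule]
/// fn scirs2(m: &Bound<'_, PyModule>) -> PyResult<()> {
///     metrics::register_module(m)
/// }
/// ```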
pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Classification metrics
    m.add_function(wrap_pyfunction!(accuracy_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(precision_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(recall_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(f1_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(fbeta_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(confusion_matrix_py, m)?)?;
    m.add_function(wrap_pyfunction!(roc_curve_py, m)?)?;
    m.add_function(wrap_pyfunction!(roc_auc_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(log_loss_py, m)?)?;
    m.add_function(wrap_pyfunction!(matthews_corrcoef_py, m)?)?;
    m.add_function(wrap_pyfunction!(balanced_accuracy_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(cohen_kappa_score_py, m)?)?;

    // Regression metrics
    m.add_function(wrap_pyfunction!(mean_squared_error_py, m)?)?;
    m.add_function(wrap_pyfunction!(mean_absolute_error_py, m)?)?;
    m.add_function(wrap_pyfunction!(r2_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(mape_py, m)?)?;
    m.add_function(wrap_pyfunction!(explained_variance_score_py, m)?)?;

    // Clustering metrics
    m.add_function(wrap_pyfunction!(silhouette_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(davies_bouldin_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(calinski_harabasz_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(adjusted_rand_index_py, m)?)?;
    m.add_function(wrap_pyfunction!(nmi_score_py, m)?)?;

    // Ranking metrics
    m.add_function(wrap_pyfunction!(ndcg_score_py, m)?)?;
    m.add_function(wrap_pyfunction!(mrr_py, m)?)?;

    Ok(())
}