pub use aprender::prelude::{
mae,
mse,
r_squared,
rmse,
Adam,
DecisionTreeClassifier,
DecisionTreeRegressor,
ElasticNet,
Estimator,
KMeans,
Lasso,
LinearRegression,
LogisticRegression,
Matrix,
MinMaxScaler,
RandomForestRegressor,
Ridge,
StandardScaler,
Transformer,
UnsupervisedEstimator,
Vector,
DBSCAN,
SGD,
};
pub use aprender::preprocessing::PCA;
pub use aprender::serialization::SafeTensorsMetadata;
pub use aprender::format::quantize::{
dequantize, quantization_mse, quantize, Q4_0Quantizer, Q8_0Quantizer, QuantType,
QuantizationInfo, QuantizedBlock, Quantizer, BLOCK_SIZE,
};
pub use aprender::format::gguf::{
export_tensors_to_gguf, GgmlType, GgufHeader, GgufTensor, GgufTensorInfo, GgufValue,
GgufValueType, GGUF_DEFAULT_ALIGNMENT, GGUF_MAGIC, GGUF_VERSION,
};
pub use aprender::hf_hub::{HfHubClient, HfHubError, PushOptions};
pub use aprender::format::{
load_verified, save_signed, SigningKey, VerifyingKey, PUBLIC_KEY_SIZE, SIGNATURE_SIZE,
};
pub fn compute_r2(y_true: &[f64], y_pred: &[f64]) -> f64 {
if y_true.len() != y_pred.len() || y_true.is_empty() {
return 0.0;
}
let mean_true: f64 = y_true.iter().sum::<f64>() / y_true.len() as f64;
let ss_tot: f64 = y_true.iter().map(|&y| (y - mean_true).powi(2)).sum();
let ss_res: f64 = y_true
.iter()
.zip(y_pred.iter())
.map(|(&yt, &yp)| (yt - yp).powi(2))
.sum();
if ss_tot == 0.0 {
return 1.0;
}
1.0 - (ss_res / ss_tot)
}
pub fn compute_mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
if y_true.len() != y_pred.len() || y_true.is_empty() {
return 0.0;
}
y_true
.iter()
.zip(y_pred.iter())
.map(|(&yt, &yp)| (yt - yp).powi(2))
.sum::<f64>()
/ y_true.len() as f64
}
pub fn compute_mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
if y_true.len() != y_pred.len() || y_true.is_empty() {
return 0.0;
}
y_true
.iter()
.zip(y_pred.iter())
.map(|(&yt, &yp)| (yt - yp).abs())
.sum::<f64>()
/ y_true.len() as f64
}
pub fn compute_rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
compute_mse(y_true, y_pred).sqrt()
}
pub fn compute_accuracy(y_true: &[usize], y_pred: &[usize]) -> f64 {
if y_true.len() != y_pred.len() || y_true.is_empty() {
return 0.0;
}
let correct = y_true
.iter()
.zip(y_pred.iter())
.filter(|(a, b)| a == b)
.count();
correct as f64 / y_true.len() as f64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mse_perfect() {
let y_true = vec![1.0, 2.0, 3.0];
let y_pred = vec![1.0, 2.0, 3.0];
let mse_val = compute_mse(&y_true, &y_pred);
assert!((mse_val - 0.0).abs() < 1e-10);
}
#[test]
fn test_mae_basic() {
let y_true = vec![1.0, 2.0, 3.0];
let y_pred = vec![1.5, 2.5, 3.5];
let mae_val = compute_mae(&y_true, &y_pred);
assert!((mae_val - 0.5).abs() < 1e-10);
}
#[test]
fn test_r2_perfect() {
let y_true = vec![1.0, 2.0, 3.0];
let y_pred = vec![1.0, 2.0, 3.0];
let r2 = compute_r2(&y_true, &y_pred);
assert!((r2 - 1.0).abs() < 1e-10);
}
#[test]
fn test_accuracy_perfect() {
let y_true = vec![0, 1, 0, 1];
let y_pred = vec![0, 1, 0, 1];
let acc = compute_accuracy(&y_true, &y_pred);
assert!((acc - 1.0).abs() < 1e-10);
}
#[test]
fn test_accuracy_half() {
let y_true = vec![0, 1, 0, 1];
let y_pred = vec![0, 0, 1, 1];
let acc = compute_accuracy(&y_true, &y_pred);
assert!((acc - 0.5).abs() < 1e-10);
}
#[test]
fn test_quantize_reexport_q8() {
let weights: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) / 10.0).collect();
let shape = vec![32];
let quantized = quantize(&weights, &shape, QuantType::Q8_0).expect("Q8_0 quantize");
let restored = dequantize(&quantized).expect("Q8_0 dequantize");
assert_eq!(restored.len(), 32);
let mse = quantization_mse(&weights, &restored);
assert!(mse < 0.01, "Q8_0 MSE too high: {mse}");
}
#[test]
fn test_quantize_reexport_q4() {
let weights: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) / 10.0).collect();
let shape = vec![32];
let quantized = quantize(&weights, &shape, QuantType::Q4_0).expect("Q4_0 quantize");
let restored = dequantize(&quantized).expect("Q4_0 dequantize");
assert_eq!(restored.len(), 32);
let mse = quantization_mse(&weights, &restored);
assert!(mse < 0.5, "Q4_0 MSE too high: {mse}");
}
#[test]
fn test_quant_type_reexport() {
assert_eq!(QuantType::Q8_0 as u8, 0x01);
assert_eq!(QuantType::Q4_0 as u8, 0x02);
}
#[test]
fn test_block_size_constant() {
assert_eq!(BLOCK_SIZE, 32);
}
#[test]
fn test_gguf_reexport() {
assert_eq!(GGUF_MAGIC, 0x4655_4747); assert_eq!(GGUF_VERSION, 3);
assert_eq!(GGUF_DEFAULT_ALIGNMENT, 32);
}
#[test]
fn test_hf_hub_client_reexport() {
let client = HfHubClient::with_token("test_token");
assert!(client.is_authenticated());
}
#[test]
fn test_push_options_reexport() {
let opts = PushOptions::default()
.with_commit_message("Test commit")
.with_filename("model.apr");
assert_eq!(opts.filename, "model.apr");
}
#[test]
fn test_hf_hub_error_reexport() {
let err = HfHubError::MissingToken;
assert!(err.to_string().contains("HF_TOKEN"));
}
#[test]
fn test_signing_constants_reexport() {
assert_eq!(SIGNATURE_SIZE, 64); assert_eq!(PUBLIC_KEY_SIZE, 32); }
#[test]
fn test_signing_key_reexport() {
let signing_key = SigningKey::generate(&mut rand::rngs::OsRng);
let verifying_key = signing_key.verifying_key();
assert_eq!(verifying_key.as_bytes().len(), PUBLIC_KEY_SIZE);
}
}
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn prop_r2_bounded(
y_true in prop::collection::vec(0.0..100.0f64, 5..20),
y_pred in prop::collection::vec(0.0..100.0f64, 5..20),
) {
if y_true.len() == y_pred.len() {
let r2 = compute_r2(&y_true, &y_pred);
prop_assert!(r2 <= 1.0 + 1e-10);
}
}
#[test]
fn prop_mse_non_negative(
y_true in prop::collection::vec(-100.0..100.0f64, 1..50),
y_pred in prop::collection::vec(-100.0..100.0f64, 1..50),
) {
if y_true.len() == y_pred.len() {
let mse_val = compute_mse(&y_true, &y_pred);
prop_assert!(mse_val >= 0.0);
}
}
#[test]
fn prop_mae_non_negative(
y_true in prop::collection::vec(-100.0..100.0f64, 1..50),
y_pred in prop::collection::vec(-100.0..100.0f64, 1..50),
) {
if y_true.len() == y_pred.len() {
let mae_val = compute_mae(&y_true, &y_pred);
prop_assert!(mae_val >= 0.0);
}
}
#[test]
fn prop_accuracy_bounded(
y_true in prop::collection::vec(0usize..2, 1..50),
y_pred in prop::collection::vec(0usize..2, 1..50),
) {
if y_true.len() == y_pred.len() {
let acc = compute_accuracy(&y_true, &y_pred);
prop_assert!((0.0..=1.0).contains(&acc));
}
}
}
}