use crate::error::{Result, SklearsError};
use crate::traits::{Fit, Predict};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Dimension};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
pub mod sklearn {
use super::*;
use crate::traits::{Estimator, Score};
pub trait SklearnCompatible {
fn set_param(&mut self, param: &str, value: impl Into<ParamValue>) -> Result<()>;
fn get_param(&self, param: &str) -> Result<ParamValue>;
fn get_params(&self, deep: bool) -> HashMap<String, ParamValue>;
fn set_params(&mut self, params: HashMap<String, ParamValue>) -> Result<()>;
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParamValue {
Bool(bool),
Int(i64),
Float(f64),
String(String),
Array(Vec<f64>),
None,
}
impl From<bool> for ParamValue {
fn from(value: bool) -> Self {
ParamValue::Bool(value)
}
}
impl From<i64> for ParamValue {
fn from(value: i64) -> Self {
ParamValue::Int(value)
}
}
impl From<f64> for ParamValue {
fn from(value: f64) -> Self {
ParamValue::Float(value)
}
}
impl From<String> for ParamValue {
fn from(value: String) -> Self {
ParamValue::String(value)
}
}
impl From<&str> for ParamValue {
fn from(value: &str) -> Self {
ParamValue::String(value.to_string())
}
}
#[derive(Debug, Clone)]
pub struct ScikitLearnModel {
model_type: String,
parameters: HashMap<String, ParamValue>,
fitted: bool,
}
impl ScikitLearnModel {
pub fn linear_regression() -> Self {
let mut params = HashMap::new();
params.insert("fit_intercept".to_string(), ParamValue::Bool(true));
params.insert("normalize".to_string(), ParamValue::Bool(false));
params.insert("copy_X".to_string(), ParamValue::Bool(true));
params.insert("n_jobs".to_string(), ParamValue::None);
Self {
model_type: "LinearRegression".to_string(),
parameters: params,
fitted: false,
}
}
pub fn random_forest_classifier() -> Self {
let mut params = HashMap::new();
params.insert("n_estimators".to_string(), ParamValue::Int(100));
params.insert(
"criterion".to_string(),
ParamValue::String("gini".to_string()),
);
params.insert("max_depth".to_string(), ParamValue::None);
params.insert("min_samples_split".to_string(), ParamValue::Int(2));
params.insert("min_samples_leaf".to_string(), ParamValue::Int(1));
params.insert(
"max_features".to_string(),
ParamValue::String("auto".to_string()),
);
params.insert("bootstrap".to_string(), ParamValue::Bool(true));
params.insert("oob_score".to_string(), ParamValue::Bool(false));
params.insert("n_jobs".to_string(), ParamValue::None);
params.insert("random_state".to_string(), ParamValue::None);
Self {
model_type: "RandomForestClassifier".to_string(),
parameters: params,
fitted: false,
}
}
pub fn svm_classifier() -> Self {
let mut params = HashMap::new();
params.insert("C".to_string(), ParamValue::Float(1.0));
params.insert("kernel".to_string(), ParamValue::String("rbf".to_string()));
params.insert("degree".to_string(), ParamValue::Int(3));
params.insert("gamma".to_string(), ParamValue::String("scale".to_string()));
params.insert("coef0".to_string(), ParamValue::Float(0.0));
params.insert("shrinking".to_string(), ParamValue::Bool(true));
params.insert("probability".to_string(), ParamValue::Bool(false));
params.insert("tol".to_string(), ParamValue::Float(1e-3));
params.insert("cache_size".to_string(), ParamValue::Float(200.0));
params.insert("class_weight".to_string(), ParamValue::None);
params.insert("verbose".to_string(), ParamValue::Bool(false));
params.insert("max_iter".to_string(), ParamValue::Int(-1));
params.insert(
"decision_function_shape".to_string(),
ParamValue::String("ovr".to_string()),
);
params.insert("break_ties".to_string(), ParamValue::Bool(false));
params.insert("random_state".to_string(), ParamValue::None);
Self {
model_type: "SVC".to_string(),
parameters: params,
fitted: false,
}
}
}
impl SklearnCompatible for ScikitLearnModel {
fn set_param(&mut self, param: &str, value: impl Into<ParamValue>) -> Result<()> {
self.parameters.insert(param.to_string(), value.into());
Ok(())
}
fn get_param(&self, param: &str) -> Result<ParamValue> {
self.parameters
.get(param)
.cloned()
.ok_or_else(|| SklearsError::InvalidInput(format!("Parameter '{param}' not found")))
}
fn get_params(&self, deep: bool) -> HashMap<String, ParamValue> {
if deep {
self.parameters.clone()
} else {
self.parameters.clone()
}
}
fn set_params(&mut self, params: HashMap<String, ParamValue>) -> Result<()> {
for (key, value) in params {
self.parameters.insert(key, value);
}
Ok(())
}
}
impl Estimator for ScikitLearnModel {
type Config = HashMap<String, ParamValue>;
type Error = SklearsError;
type Float = f64;
fn config(&self) -> &Self::Config {
&self.parameters
}
}
impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for ScikitLearnModel {
type Fitted = FittedScikitLearnModel;
fn fit(mut self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
if x.nrows() != y.len() {
return Err(SklearsError::ShapeMismatch {
expected: format!("({}, n_features)", y.len()),
actual: format!("({}, {})", x.nrows(), x.ncols()),
});
}
self.fitted = true;
Ok(FittedScikitLearnModel {
model: self,
training_shape: (x.nrows(), x.ncols()),
feature_importances: vec![0.1; x.ncols()], classes: get_unique_classes(y),
})
}
}
#[derive(Debug, Clone)]
pub struct FittedScikitLearnModel {
model: ScikitLearnModel,
training_shape: (usize, usize),
feature_importances: Vec<f64>,
classes: Vec<f64>,
}
impl FittedScikitLearnModel {
pub fn feature_importances(&self) -> &[f64] {
&self.feature_importances
}
pub fn classes(&self) -> &[f64] {
&self.classes
}
pub fn n_features_in(&self) -> usize {
self.training_shape.1
}
}
impl<'a> Predict<ArrayView2<'a, f64>, Array1<f64>> for FittedScikitLearnModel {
fn predict(&self, x: &ArrayView2<'a, f64>) -> Result<Array1<f64>> {
if x.ncols() != self.training_shape.1 {
return Err(SklearsError::FeatureMismatch {
expected: self.training_shape.1,
actual: x.ncols(),
});
}
let predictions = match self.model.model_type.as_str() {
"LinearRegression" => {
Array1::from_iter(x.rows().into_iter().map(|row| row.sum() * 0.1))
}
"RandomForestClassifier" | "SVC" => {
let most_common_class = self.classes.first().copied().unwrap_or(0.0);
Array1::from_elem(x.nrows(), most_common_class)
}
_ => Array1::zeros(x.nrows()),
};
Ok(predictions)
}
}
impl<'a> Score<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for FittedScikitLearnModel {
type Float = f64;
fn score(&self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<f64> {
let predictions = self.predict(x)?;
match self.model.model_type.as_str() {
"LinearRegression" => {
let y_mean = y.mean().unwrap_or(0.0);
let ss_res = predictions
.iter()
.zip(y.iter())
.map(|(pred, actual)| (actual - pred).powi(2))
.sum::<f64>();
let ss_tot = y
.iter()
.map(|actual| (actual - y_mean).powi(2))
.sum::<f64>();
if ss_tot == 0.0 {
Ok(1.0)
} else {
Ok(1.0 - (ss_res / ss_tot))
}
}
_ => {
let correct = predictions
.iter()
.zip(y.iter())
.map(|(pred, actual)| {
if (pred - actual).abs() < 0.5 {
1.0
} else {
0.0
}
})
.sum::<f64>();
Ok(correct / y.len() as f64)
}
}
}
}
fn get_unique_classes(y: &ArrayView1<f64>) -> Vec<f64> {
let mut classes: Vec<f64> = y.iter().copied().collect();
classes.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
classes.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
classes
}
}
pub mod numpy {
use super::*;
use bytemuck::{Pod, Zeroable};
#[derive(Debug, Clone)]
pub struct NumpyArray<T: Pod + Zeroable> {
data: Vec<T>,
shape: Vec<usize>,
strides: Vec<usize>,
dtype: String,
fortran_order: bool,
}
impl<T: Pod + Zeroable + fmt::Debug> NumpyArray<T> {
pub fn from_ndarray<D: Dimension>(
array: &scirs2_core::ndarray::ArrayBase<scirs2_core::ndarray::OwnedRepr<T>, D>,
) -> Result<Self> {
let shape = array.shape().to_vec();
let strides = array.strides().iter().map(|&s| s as usize).collect();
let data = array.iter().cloned().collect();
let dtype = Self::get_dtype_string();
Ok(Self {
data,
shape,
strides,
dtype,
fortran_order: false,
})
}
pub fn from_raw(data: Vec<T>, shape: Vec<usize>) -> Result<Self> {
let expected_size = shape.iter().product::<usize>();
if data.len() != expected_size {
return Err(SklearsError::ShapeMismatch {
expected: format!("{expected_size} elements"),
actual: format!("{} elements", data.len()),
});
}
let strides = Self::calculate_strides(&shape, false);
let dtype = Self::get_dtype_string();
Ok(Self {
data,
shape,
strides,
dtype,
fortran_order: false,
})
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
let header = self.create_numpy_header()?;
let data_bytes = bytemuck::cast_slice(&self.data);
let mut result = Vec::new();
result.extend_from_slice(&header);
result.extend_from_slice(data_bytes);
Ok(result)
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn strides(&self) -> &[usize] {
&self.strides
}
pub fn dtype(&self) -> &str {
&self.dtype
}
pub fn data(&self) -> &[T] {
&self.data
}
pub fn to_ndarray(&self) -> Result<Array2<T>> {
if self.shape.len() != 2 {
return Err(SklearsError::InvalidInput(
"Only 2D arrays are currently supported for conversion back to ndarray"
.to_string(),
));
}
Array2::from_shape_vec((self.shape[0], self.shape[1]), self.data.clone())
.map_err(|e| SklearsError::InvalidInput(format!("Failed to create ndarray: {e}")))
}
fn get_dtype_string() -> String {
if std::mem::size_of::<T>() == 8 {
"<f8".to_string() } else if std::mem::size_of::<T>() == 4 {
"<f4".to_string() } else {
"<i8".to_string() }
}
fn calculate_strides(shape: &[usize], fortran_order: bool) -> Vec<usize> {
let mut strides = vec![0; shape.len()];
let item_size = std::mem::size_of::<T>();
if fortran_order {
let mut stride = item_size;
for i in 0..shape.len() {
strides[i] = stride;
stride *= shape[i];
}
} else {
let mut stride = item_size;
for i in (0..shape.len()).rev() {
strides[i] = stride;
stride *= shape[i];
}
}
strides
}
fn create_numpy_header(&self) -> Result<Vec<u8>> {
let header_dict = format!(
"{{'descr': '{}', 'fortran_order': {}, 'shape': ({},)}}",
self.dtype,
self.fortran_order,
self.shape
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", ")
);
let mut header = header_dict.into_bytes();
while header.len() % 64 != 0 {
header.push(b' ');
}
header.push(b'\n');
Ok(header)
}
}
}
pub mod pandas {
use super::*;
use std::collections::BTreeMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataFrame {
columns: Vec<String>,
data: BTreeMap<String, Vec<DataValue>>,
index: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataValue {
Float(f64),
Int(i64),
String(String),
Bool(bool),
None,
}
impl DataFrame {
pub fn new() -> Self {
Self {
columns: Vec::new(),
data: BTreeMap::new(),
index: Vec::new(),
}
}
pub fn from_ndarray(array: &Array2<f64>, columns: Option<Vec<String>>) -> Result<Self> {
let n_cols = array.ncols();
let n_rows = array.nrows();
let columns =
columns.unwrap_or_else(|| (0..n_cols).map(|i| format!("col_{i}")).collect());
if columns.len() != n_cols {
return Err(SklearsError::ShapeMismatch {
expected: format!("{n_cols} columns"),
actual: format!("{} column names", columns.len()),
});
}
let mut data = BTreeMap::new();
for (col_idx, col_name) in columns.iter().enumerate() {
let column_data: Vec<DataValue> = (0..n_rows)
.map(|row_idx| DataValue::Float(array[[row_idx, col_idx]]))
.collect();
data.insert(col_name.clone(), column_data);
}
let index: Vec<String> = (0..n_rows).map(|i| i.to_string()).collect();
Ok(Self {
columns,
data,
index,
})
}
pub fn add_column(&mut self, name: String, values: Vec<DataValue>) -> Result<()> {
if !self.data.is_empty() && values.len() != self.index.len() {
return Err(SklearsError::ShapeMismatch {
expected: format!("{} rows", self.index.len()),
actual: format!("{} values", values.len()),
});
}
if self.data.is_empty() {
self.index = (0..values.len()).map(|i| i.to_string()).collect();
}
self.columns.push(name.clone());
self.data.insert(name, values);
Ok(())
}
pub fn columns(&self) -> &[String] {
&self.columns
}
pub fn get_column(&self, name: &str) -> Option<&Vec<DataValue>> {
self.data.get(name)
}
pub fn shape(&self) -> (usize, usize) {
(self.index.len(), self.columns.len())
}
pub fn to_ndarray(&self) -> Result<Array2<f64>> {
let (n_rows, n_cols) = self.shape();
let mut array = Array2::zeros((n_rows, n_cols));
for (col_idx, col_name) in self.columns.iter().enumerate() {
if let Some(column) = self.data.get(col_name) {
for (row_idx, value) in column.iter().enumerate() {
match value {
DataValue::Float(f) => array[[row_idx, col_idx]] = *f,
DataValue::Int(i) => array[[row_idx, col_idx]] = *i as f64,
DataValue::Bool(b) => {
array[[row_idx, col_idx]] = if *b { 1.0 } else { 0.0 }
}
_ => {
return Err(SklearsError::InvalidInput(format!(
"Non-numeric value in column '{col_name}' at row {row_idx}"
)))
}
}
}
}
}
Ok(array)
}
pub fn describe(&self) -> Result<DataFrame> {
let mut stats_df = DataFrame::new();
let stats = ["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
for stat in &stats {
stats_df.add_column(stat.to_string(), Vec::new())?;
}
for col_name in &self.columns {
if let Some(column) = self.data.get(col_name) {
let numeric_values: Vec<f64> = column
.iter()
.filter_map(|v| match v {
DataValue::Float(f) => Some(*f),
DataValue::Int(i) => Some(*i as f64),
_ => None,
})
.collect();
if !numeric_values.is_empty() {
let count = numeric_values.len() as f64;
let mean = numeric_values.iter().sum::<f64>() / count;
let variance = numeric_values
.iter()
.map(|x| (x - mean).powi(2))
.sum::<f64>()
/ count;
let _std = variance.sqrt();
let mut sorted = numeric_values.clone();
sorted
.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let _min = sorted[0];
let _max = sorted[sorted.len() - 1];
let _q25 = sorted[sorted.len() / 4];
let _q50 = sorted[sorted.len() / 2];
let _q75 = sorted[3 * sorted.len() / 4];
}
}
}
Ok(stats_df)
}
}
impl Default for DataFrame {
fn default() -> Self {
Self::new()
}
}
}
pub mod pytorch {
use super::*;
use bytemuck::{Pod, Zeroable};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
pub shape: Vec<usize>,
pub dtype: String,
pub requires_grad: bool,
pub device: String,
}
pub fn ndarray_to_pytorch_tensor<T: Pod + Zeroable>(
array: &Array2<T>,
requires_grad: bool,
) -> Result<(Vec<u8>, TensorMetadata)> {
let shape = array.shape().to_vec();
let data_bytes = bytemuck::cast_slice(array.as_slice().unwrap_or(&[]));
let dtype = if std::mem::size_of::<T>() == 8 {
"float64"
} else {
"float32"
}
.to_string();
let metadata = TensorMetadata {
shape,
dtype,
requires_grad,
device: "cpu".to_string(),
};
Ok((data_bytes.to_vec(), metadata))
}
}
pub mod serialization {
use super::*;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
SklearnPickle,
XGBoostJson,
LightGBMText,
TensorFlowSavedModel,
PyTorchStateDict,
OnnxProtobuf,
}
pub trait ModelSerialization {
fn serialize(&self, format: ModelFormat) -> Result<Vec<u8>>;
fn deserialize(data: &[u8], format: ModelFormat) -> Result<Self>
where
Self: Sized;
fn supported_formats() -> Vec<ModelFormat>;
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CrossPlatformModel {
pub model_type: String,
pub version: String,
pub parameters: HashMap<String, serde_json::Value>,
pub weights: Vec<f64>,
pub metadata: HashMap<String, String>,
}
impl CrossPlatformModel {
pub fn to_sklearn_metadata(&self) -> Result<HashMap<String, serde_json::Value>> {
let mut sklearn_meta = HashMap::new();
sklearn_meta.insert(
"__class__".to_string(),
serde_json::Value::String(self.model_type.clone()),
);
sklearn_meta.insert(
"__version__".to_string(),
serde_json::Value::String(self.version.clone()),
);
sklearn_meta.extend(self.parameters.clone());
Ok(sklearn_meta)
}
pub fn from_sklearn_metadata(metadata: HashMap<String, serde_json::Value>) -> Result<Self> {
let model_type = metadata
.get("__class__")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let version = metadata
.get("__version__")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let mut parameters = metadata;
parameters.remove("__class__");
parameters.remove("__version__");
Ok(Self {
model_type,
version,
parameters,
weights: Vec::new(),
metadata: HashMap::new(),
})
}
}
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
use super::numpy::*;
use super::pandas::*;
use super::sklearn::*;
use super::*;
use crate::traits::Fit;
#[test]
fn test_sklearn_linear_regression() {
let mut model = ScikitLearnModel::linear_regression();
assert!(model.set_param("fit_intercept", false).is_ok());
assert_eq!(
model
.get_param("fit_intercept")
.expect("get_param should succeed"),
ParamValue::Bool(false)
);
}
#[test]
fn test_sklearn_parameter_management() {
let mut model = ScikitLearnModel::random_forest_classifier();
assert!(model.set_param("n_estimators", 200).is_ok());
assert!(model.set_param("max_depth", 10).is_ok());
assert_eq!(
model
.get_param("n_estimators")
.expect("get_param should succeed"),
ParamValue::Int(200)
);
assert_eq!(
model
.get_param("max_depth")
.expect("get_param should succeed"),
ParamValue::Int(10)
);
let params = model.get_params(false);
assert!(params.contains_key("n_estimators"));
assert!(params.contains_key("max_depth"));
}
#[test]
fn test_numpy_array_conversion() {
let array = Array2::<f64>::zeros((10, 5));
let numpy_array = NumpyArray::from_ndarray(&array);
assert!(numpy_array.is_ok());
let numpy_array = numpy_array.expect("expected valid value");
assert_eq!(numpy_array.shape(), &[10, 5]);
assert_eq!(numpy_array.data().len(), 50);
}
#[test]
fn test_pandas_dataframe() {
let mut df = DataFrame::new();
let values = vec![
DataValue::Float(1.0),
DataValue::Float(2.0),
DataValue::Float(3.0),
];
assert!(df.add_column("test_col".to_string(), values).is_ok());
assert_eq!(df.shape(), (3, 1));
assert_eq!(df.columns(), &["test_col"]);
}
#[test]
fn test_dataframe_to_ndarray() {
let array =
Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid array shape");
let df = DataFrame::from_ndarray(&array, None).expect("expected valid value");
let converted = df.to_ndarray().expect("to_ndarray should succeed");
assert_eq!(converted.shape(), [2, 2]);
assert_eq!(converted[[0, 0]], 1.0);
assert_eq!(converted[[1, 1]], 4.0);
}
#[test]
fn test_sklearn_model_fitting() {
let model = ScikitLearnModel::linear_regression();
let features = Array2::zeros((10, 3));
let targets = Array1::zeros(10);
let fitted = model.fit(&features.view(), &targets.view());
assert!(fitted.is_ok());
let fitted = fitted.expect("expected valid value");
assert_eq!(fitted.n_features_in(), 3);
}
#[test]
fn test_cross_platform_model() {
use serialization::CrossPlatformModel;
let model = CrossPlatformModel {
model_type: "LinearRegression".to_string(),
version: "1.0".to_string(),
parameters: HashMap::new(),
weights: vec![1.0, 2.0, 3.0],
metadata: HashMap::new(),
};
let sklearn_meta = model.to_sklearn_metadata();
assert!(sklearn_meta.is_ok());
let meta = sklearn_meta.expect("expected valid value");
assert_eq!(
meta.get("__class__")
.expect("key should exist")
.as_str()
.expect("key should exist"),
"LinearRegression"
);
}
}