/// Builds a dataset via `$crate::dataset::Dataset::builder()` in one expression.
///
/// * `quick_dataset!(data: d)` calls `.data(d)` and then `.build()`.
/// * `quick_dataset!(data: d, target: t)` calls `.data(d)`, `.target(t)`, then `.build()`.
///
/// The expansion evaluates to whatever the builder's `build()` returns.
#[macro_export]
macro_rules! quick_dataset {
    (data: $features:expr) => {
        $crate::dataset::Dataset::builder().data($features).build()
    };
    (data: $features:expr, target: $labels:expr) => {
        $crate::dataset::Dataset::builder()
            .data($features)
            .target($labels)
            .build()
    };
}
/// Declares a marker trait `$name` that bundles the numeric bounds used by ML code:
/// `Float + NumCast + Copy + Clone + Send + Sync + std::fmt::Debug`.
///
/// It also emits a blanket impl, so every type satisfying those bounds implements
/// `$name` automatically.
///
/// An optional visibility may precede the name, e.g. `define_ml_float_bounds!(pub MlFloat)`.
/// Without one, the trait is private to the invoking module, as before. This keeps
/// `define_ml_float_bounds!(MlFloat)` working unchanged.
///
/// `Float` and `NumCast` are written unqualified, so they resolve at the call site and
/// must be in scope there. They presumably come from `num_traits`; confirm against the
/// crate's dependencies.
#[macro_export]
macro_rules! define_ml_float_bounds {
    ($vis:vis $name:ident) => {
        $vis trait $name:
            Float + NumCast + Copy + Clone + Send + Sync + std::fmt::Debug
        {
        }

        impl<T> $name for T where
            T: Float
                + NumCast
                + Copy
                + Clone
                + Send
                + Sync
                + std::fmt::Debug
        {
        }
    };
}
/// Builds a `HashMap<String, V>` from `name: value` pairs.
///
/// Each key is the stringified parameter name. A trailing comma is accepted. The
/// empty form `parameter_map!()` needs the map type to come from context.
#[macro_export]
macro_rules! parameter_map {
    ($($key:ident: $val:expr),* $(,)?) => {{
        // `mut` is unused when no pairs are given; silence that lint for the empty form.
        #[allow(unused_mut)]
        let mut map = ::std::collections::HashMap::new();
        $(
            map.insert(stringify!($key).to_string(), $val);
        )*
        map
    }};
}
/// Implements `Default` for a struct by listing a default expression per field.
///
/// Usage: `impl_default_config!(MyConfig { a: 1, b: String::new() })`.
/// Every field of the struct must be listed, or the struct literal will not compile.
#[macro_export]
macro_rules! impl_default_config {
    ($ty:ident { $($member:ident: $value:expr),* $(,)? }) => {
        impl Default for $ty {
            fn default() -> $ty {
                $ty {
                    $($member: $value),*
                }
            }
        }
    };
}
/// Implements `$crate::traits::Estimator` for `$estimator`.
///
/// * `impl_ml_traits!(Foo, config: Cfg)` sets `type Config = Cfg`.
/// * `impl_ml_traits!(Foo)` is shorthand for `config: ()`.
///
/// In both forms `name()` returns the type's identifier as written at the call site.
#[macro_export]
macro_rules! impl_ml_traits {
    ($estimator:ident, config: $config:ty) => {
        impl $crate::traits::Estimator for $estimator {
            type Config = $config;
            fn name(&self) -> &str {
                stringify!($estimator)
            }
        }
    };
    ($estimator:ident) => {
        $crate::impl_ml_traits!($estimator, config: ());
    };
}
/// Emits a `#[cfg(test)] mod tests` with smoke tests for `$estimator`.
///
/// The type must provide an inherent or trait `new()`, a `name()` returning the type's
/// own identifier, and `Clone`. The module also glob-imports `$crate::test_utilities`.
/// Because the module is always named `tests`, use at most one invocation per module.
#[macro_export]
macro_rules! estimator_test_suite {
    ($estimator:ident) => {
        #[allow(non_snake_case)]
        #[cfg(test)]
        mod tests {
            use super::*;
            use $crate::test_utilities::*;

            #[test]
            fn test_estimator_creation() {
                let expected = stringify!($estimator);
                let instance = $estimator::new();
                assert_eq!(instance.name(), expected);
            }

            #[test]
            fn test_estimator_clone() {
                let original = $estimator::new();
                let duplicate = original.clone();
                let original_name = original.name();
                assert_eq!(original_name, duplicate.name());
            }
        }
    };
}
/// Declares a complete estimator from a compact description.
///
/// Generates:
/// * `$config`: a `pub` struct deriving `Debug`, `Clone` and `PartialEq`, with one `pub`
///   field per `field: Type = default` entry and a `Default` impl that uses those defaults.
/// * `$name<State>`: the estimator. It holds a `$config` and a `PhantomData<State>`;
///   `State` defaults to `$crate::types::const_generic::FixedFeatures<f64, 1>`.
/// * `new()` and `Default`, which build the default-state estimator, and `builder()`,
///   which builds a `$name<()>`.
/// * One setter per config field, plus `validate()` and `name()`. These are available in
///   every `State`, including the `$name<()>` returned by `builder()`.
/// * An impl of `$crate::traits::Estimator` for the default state.
/// * A `#[cfg(test)] mod tests` with smoke tests. Because of that fixed module name,
///   use at most one invocation per module.
///
/// `validate()` evaluates each `validation` expression in order. For the first one that
/// is `false`, it returns `SklearsError::InvalidInput("Validation failed: <expr>")`.
///
/// NOTE(review): the expressions are pasted into `validate(&self)`. Macro hygiene
/// presumably prevents them from naming that method's `self`; confirm before relying on
/// per-field checks here.
/// NOTE(review): the `features` list is matched but not used in the expansion.
#[macro_export]
macro_rules! define_estimator {
    (
        name: $name:ident,
        config: $config:ident {
            $(
                $field:ident: $type:ty = $default:expr
            ),* $(,)?
        },
        features: [$($trait:ident),* $(,)?],
        validation: {
            $(
                $validation:expr
            ),* $(,)?
        }
    ) => {
        #[derive(Debug, Clone, PartialEq)]
        pub struct $config {
            $(
                pub $field: $type,
            )*
        }

        impl Default for $config {
            fn default() -> Self {
                Self {
                    $(
                        $field: $default,
                    )*
                }
            }
        }

        #[derive(Debug, Clone)]
        pub struct $name<State = $crate::types::const_generic::FixedFeatures<f64, 1>> {
            config: $config,
            _state: std::marker::PhantomData<State>,
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }

        // Constructors stay on the default state so `$name::new()` / `$name::builder()`
        // can infer `State` without annotations.
        impl $name {
            /// Creates an estimator in the default state with the default configuration.
            pub fn new() -> Self {
                Self {
                    config: $config::default(),
                    _state: std::marker::PhantomData,
                }
            }

            /// Creates an estimator in the `()` builder state with the default configuration.
            pub fn builder() -> $name<()> {
                $name {
                    config: $config::default(),
                    _state: std::marker::PhantomData,
                }
            }
        }

        // State-independent methods are generic over `State`. Otherwise the `$name<()>`
        // returned by `builder()` would have no setters, `validate`, or `name`.
        impl<State> $name<State> {
            $(
                /// Sets this configuration field and returns the estimator (builder style).
                pub fn $field(mut self, value: $type) -> Self {
                    self.config.$field = value;
                    self
                }
            )*

            /// Runs the declared validation expressions in order.
            ///
            /// # Errors
            /// Returns `SklearsError::InvalidInput` naming the first expression that is `false`.
            pub fn validate(&self) -> $crate::error::Result<()> {
                $(
                    if !($validation) {
                        return Err($crate::error::SklearsError::InvalidInput(
                            format!("Validation failed: {}", stringify!($validation))
                        ));
                    }
                )*
                Ok(())
            }

            /// Returns the estimator's type name as written in the macro invocation.
            pub fn name(&self) -> &'static str {
                stringify!($name)
            }
        }

        impl $crate::traits::Estimator for $name {
            type Config = $config;
            fn name(&self) -> &'static str {
                stringify!($name)
            }
            fn config(&self) -> &Self::Config {
                &self.config
            }
        }

        #[allow(non_snake_case)]
        #[cfg(test)]
        mod tests {
            use super::*;

            #[test]
            fn test_default_creation() {
                let estimator = $name::default();
                assert_eq!(estimator.name(), stringify!($name));
                estimator.validate().expect("validate should succeed");
            }

            #[test]
            fn test_builder_pattern() {
                let estimator = $name::builder();
                estimator.validate().expect("validate should succeed");
            }
        }
    };
}
/// Declares one `pub fn name(param: Type) -> bool` per rule.
///
/// Each rule is written closure-style, e.g. `is_positive: |x: f64| x > 0.0`. The body
/// expression becomes the function's boolean result.
#[macro_export]
macro_rules! validation_rules {
    ($(
        $rule:ident: |$arg:ident: $arg_ty:ty| $check:expr
    ),* $(,)?) => {
        $(
            pub fn $rule($arg: $arg_ty) -> bool {
                let passed: bool = $check;
                passed
            }
        )*
    };
}
/// Intended to generate a `<Name>Config` struct (deriving `Debug` and `Clone`) with one
/// `pub` field per config entry, a `Default` impl built from the given defaults, and a
/// `validate` method. For each field that has an optional `=> validate(f)` clause,
/// `validate` calls `f(self.field)` and returns `SklearsError::InvalidInput` naming the
/// field and its `{:?}` value when the call returns `false`.
///
/// NOTE(review): `[<$name Config>]` is identifier-pasting syntax from the `paste` crate.
/// The expansion is not wrapped in `paste::paste! { ... }`, so as written any invocation
/// fails to expand. This file imports nothing, so confirm whether the crate depends on
/// (and re-exports) `paste` before wrapping the body.
///
/// NOTE(review): `fit_fn`, `predict_fn` and `algorithm_type` are matched but never used
/// in the expansion.
#[macro_export]
macro_rules! define_ml_algorithm {
    (
        name: $name:ident,
        config: {
            $(
                $field:ident: $field_type:ty = $default:expr
                $(=> validate($validator:expr))?
            ),* $(,)?
        },
        fit_fn: $fit_fn:ident,
        predict_fn: $predict_fn:ident,
        algorithm_type: $algorithm_type:ident
    ) => {
        #[derive(Debug, Clone)]
        pub struct [<$name Config>] {
            $(
                pub $field: $field_type,
            )*
        }
        impl Default for [<$name Config>] {
            fn default() -> Self {
                Self {
                    $(
                        $field: $default,
                    )*
                }
            }
        }
        impl [<$name Config>] {
            pub fn validate(&self) -> $crate::error::Result<()> {
                // Outer repetition walks every field. The inner `?` repetition only
                // emits a check for fields that declared a validator.
                $(
                    $(
                        // The field is passed by value out of `&self`, so fields with a
                        // validator must be `Copy`. The error message formats it with
                        // `{:?}`, so they must also be `Debug`.
                        if !($validator)(self.$field) {
                            return Err($crate::error::SklearsError::InvalidInput(
                                format!("Validation failed for {}: {:?}", stringify!($field), self.$field)
                            ));
                        }
                    )?
                )*
                Ok(())
            }
        }
    };
}
/// Generates smoke benchmarks for `$algo` under `#[cfg(test)]`.
///
/// Expands to `mod benchmarks`, which contains one sub-module per dataset function. Each
/// sub-module is named after the dataset and contains a `#[test] fn bench`. The test:
/// * calls the dataset function once,
/// * constructs `$algo::default()` `iterations` times, accumulating the elapsed time,
/// * prints a summary line.
///
/// The previous expansion used `paste`-crate syntax (`[<bench_ ...>]`) without a
/// `paste!` wrapper, so it never expanded. Per-dataset modules give each test a unique
/// path without any identifier concatenation.
///
/// Currently only construction is timed. Fit/predict timing and the `metrics` list are
/// accepted but not implemented yet.
#[macro_export]
macro_rules! benchmark_suite {
    (
        algorithm: $algo:ident,
        datasets: [$($dataset:ident),* $(,)?],
        metrics: [$($metric:ident),* $(,)?],
        iterations: $iters:expr
    ) => {
        #[allow(non_snake_case)]
        #[cfg(test)]
        mod benchmarks {
            $(
                // A module in the type namespace may share the dataset function's name
                // without clashing with it.
                #[allow(non_snake_case)]
                mod $dataset {
                    // Reach the invoking module directly so `$dataset`, `$algo` and any
                    // names used by `$iters` resolve as they would at the call site.
                    #[allow(unused_imports)]
                    use super::super::*;

                    #[test]
                    fn bench() {
                        let _dataset = $dataset();
                        // Evaluate the iteration count once.
                        let iterations = $iters;
                        let mut total = ::std::time::Duration::new(0, 0);
                        for _ in 0..iterations {
                            let start = ::std::time::Instant::now();
                            let _algo = $algo::default();
                            total += start.elapsed();
                        }
                        println!(
                            "{} benchmark on {} completed: {} iterations, {:?} total",
                            stringify!($algo),
                            stringify!($dataset),
                            iterations,
                            total
                        );
                    }
                }
            )*
        }
    };
}
/// Declares free functions, each with a `simd` body and a `fallback` body.
///
/// Syntax per operation:
/// `name: (a: T, b: U) -> R { simd: |a, b| expr, fallback: |a, b| expr }`.
/// The `|...|` lists are decorative. Both bodies refer directly to the declared
/// parameter names.
///
/// NOTE(review): the `simd` body is parsed but never emitted. The `avx2` and non-`avx2`
/// configurations both compile the `fallback` body, so every target runs the
/// fallback path.
#[macro_export]
macro_rules! simd_operations {
    ($(
        $op:ident: ($($arg:ident: $arg_ty:ty),*) -> $ret:ty {
            simd: |$($simd_arg:ident),*| $simd_body:expr,
            fallback: |$($scalar_arg:ident),*| $scalar_body:expr
        }
    ),* $(,)?) => {
        $(
            pub fn $op($($arg: $arg_ty),*) -> $ret {
                // Exactly one of the two blocks survives cfg-stripping and becomes
                // the function's tail expression.
                #[cfg(target_feature = "avx2")]
                {
                    $scalar_body
                }
                #[cfg(not(target_feature = "avx2"))]
                {
                    $scalar_body
                }
            }
        )*
    };
}