use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction};
use std::fmt::{Debug, Display, Formatter};
/// Renders an optional value with its `Display` implementation.
///
/// Returns the formatted inner value for `Some`, and the literal text
/// `"None"` when no value is present.
pub fn print_option<T: Display>(x: Option<T>) -> String {
    match x {
        Some(value) => value.to_string(),
        None => String::from("None"),
    }
}
/// Renders an optional value with its pretty-printed (`{:#?}`) `Debug` form.
///
/// Returns the formatted inner value for `Some`, and the literal text
/// `"None"` when no value is present.
pub fn debug_option<T: Debug>(x: Option<T>) -> String {
    match x {
        Some(inner) => format!("{inner:#?}"),
        None => String::from("None"),
    }
}
/// Returns a display name for a KNN weight function.
///
/// Maps `Uniform` to `"Uniform"` and `Distance` to `"Distance"`.
pub fn print_knn_weight_function(f: &KNNWeightFunction) -> String {
    let label = match f {
        KNNWeightFunction::Uniform => "Uniform",
        KNNWeightFunction::Distance => "Distance",
    };
    label.to_owned()
}
/// Returns a display name for a KNN search algorithm.
///
/// Maps `LinearSearch` to `"Linear Search"` and `CoverTree` to `"Cover Tree"`.
pub fn print_knn_search_algorithm(a: &KNNAlgorithmName) -> String {
    let label = match a {
        KNNAlgorithmName::LinearSearch => "Linear Search",
        KNNAlgorithmName::CoverTree => "Cover Tree",
    };
    label.to_owned()
}
/// Kernel variants together with their numeric parameters.
///
/// The field names given on each variant are the labels this type's `Display`
/// implementation prints for them.
#[derive(serde::Serialize, serde::Deserialize)]
pub enum Kernel {
/// Linear kernel; carries no parameters.
Linear,
/// Polynomial kernel; fields are `(degree, gamma, coef)` in that order.
Polynomial(f32, f32, f32),
/// RBF kernel; the single field is `gamma`.
RBF(f32),
/// Sigmoid kernel; fields are `(gamma, coef)` in that order.
Sigmoid(f32, f32),
}
impl Display for Kernel {
    /// Writes the kernel name, followed by one `name = value` line per
    /// parameter for the variants that carry parameters.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            // Parameterless variant: a plain string write is enough.
            Self::Linear => f.write_str("Linear"),
            Self::RBF(gamma) => {
                write!(f, "RBF\n gamma = {gamma}")
            }
            Self::Sigmoid(gamma, coef) => {
                write!(f, "Sigmoid\n gamma = {gamma}\n coef = {coef}")
            }
            Self::Polynomial(degree, gamma, coef) => {
                write!(
                    f,
                    "Polynomial\n degree = {degree}\n gamma = {gamma}\n coef = {coef}"
                )
            }
        }
    }
}
/// Distance metric variants.
///
/// Only `Minkowski` carries a parameter; this type's `Display` implementation
/// prints that parameter under the label `p`.
#[derive(serde::Serialize, serde::Deserialize)]
pub enum Distance {
/// Euclidean distance; carries no parameters.
Euclidean,
/// Manhattan distance; carries no parameters.
Manhattan,
/// Minkowski distance; the field is the value shown as `p`.
Minkowski(u16),
/// Mahalanobis distance; carries no parameters.
Mahalanobis,
/// Hamming distance; carries no parameters.
Hamming,
}
impl Display for Distance {
    /// Writes the metric name; `Minkowski` additionally prints its `p` value
    /// on a second line.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            // The only variant with a payload is formatted and returned directly.
            Self::Minkowski(n) => return write!(f, "Minkowski\n p = {n}"),
            Self::Euclidean => "Euclidean",
            Self::Manhattan => "Manhattan",
            Self::Mahalanobis => "Mahalanobis",
            Self::Hamming => "Hamming",
        };
        f.write_str(name)
    }
}
/// Multiplies two slices element by element.
///
/// The result has as many elements as the shorter of the two inputs; any
/// surplus elements in the longer slice are ignored.
pub fn elementwise_multiply(v1: &[f32], v2: &[f32]) -> Vec<f32> {
    let mut products = Vec::with_capacity(v1.len().min(v2.len()));
    for (lhs, rhs) in v1.iter().zip(v2.iter()) {
        products.push(lhs * rhs);
    }
    products
}
#[cfg(any(feature = "csv"))]
use polars::prelude::{CsvReader, DataFrame, PolarsError, SerReader};
#[cfg(any(feature = "csv"))]
/// Loads a CSV file from a local path or a URL into a [`DataFrame`].
///
/// The argument is first opened as a local file. If that fails and the string
/// parses as a URL, the resource is downloaded, written to a temporary file,
/// and that file is read instead.
///
/// Whether the file has a header row is decided by `csv_sniffer`. After
/// reading, null values are dropped and the frame is passed through
/// `Cleanup::convert_to_float`.
///
/// # Panics
///
/// Panics when:
/// - the path is not valid UTF-8;
/// - the argument is neither an openable file nor a parseable URL;
/// - the URL cannot be fetched, or the server answers with a non-2xx status;
/// - the file cannot be sniffed or read as CSV;
/// - dropping nulls or the final conversion fails.
pub fn validate_and_read<P>(file_path: P) -> DataFrame
where
    P: AsRef<std::path::Path>,
{
    let file_path_as_str = file_path
        .as_ref()
        .to_str()
        .expect("File path is not valid UTF-8");

    let csv = match CsvReader::from_path(file_path_as_str) {
        Ok(csv) => csv,
        Err(_) => {
            // Not an openable local file: the only remaining option is a URL.
            if url::Url::parse(file_path_as_str).is_err() {
                panic!("The string {file_path_as_str} is not a valid URL or file path.")
            }
            let response = minreq::get(file_path_as_str)
                .send()
                .expect("Could not open URL");
            // Without this check an error page (e.g. a 404 body) would be
            // written to disk and parsed as if it were the requested CSV.
            assert!(
                (200..300).contains(&response.status_code),
                "Could not download {file_path_as_str}: HTTP status {}",
                response.status_code
            );
            // `temp` must stay alive until the recursive read has finished;
            // it is dropped only after the call below returns.
            let temp = temp_file::with_contents(response.as_bytes());
            return validate_and_read(temp.path());
        }
    };

    let has_header = csv_sniffer::Sniffer::new()
        .sniff_path(file_path_as_str)
        .expect("Cannot sniff file")
        .dialect
        .header
        .has_header_row;

    csv.infer_schema(Some(10))
        .has_header(has_header)
        .finish()
        .expect("Cannot read file as CSV")
        .drop_nulls(None)
        .expect("Cannot remove null values")
        .convert_to_float()
        .expect("Cannot convert types")
}
#[cfg(any(feature = "csv"))]
/// Post-load cleanup step that `validate_and_read` applies to a `DataFrame`.
trait Cleanup {
/// Consumes the value and returns a `DataFrame`, or a `PolarsError` on failure.
///
/// NOTE(review): the name suggests columns are cast to floating point, but the
/// only implementation visible in this file returns the frame unchanged —
/// confirm the intended behaviour.
fn convert_to_float(self) -> Result<DataFrame, PolarsError>;
}
#[cfg(any(feature = "csv"))]
impl Cleanup for DataFrame {
    /// Returns the frame unchanged.
    ///
    /// This is currently a pass-through: no column is modified and the result
    /// is always `Ok`. The receiver is taken by plain value because nothing
    /// here mutates it, so no lint suppression is needed.
    fn convert_to_float(self) -> Result<DataFrame, PolarsError> {
        Ok(self)
    }
}