use std::ops::Not;
use polars_core::datatypes::unpack_dtypes;
use polars_core::prelude::*;
use polars_ops::series::is_close;
/// Options controlling how two `Series` are compared by the assertion helpers.
///
/// All fields are public so the struct can be built literally, but the
/// `with_*` builder methods are the more convenient way to tweak a default.
#[derive(Debug, Clone, PartialEq)]
pub struct SeriesEqualOptions {
    /// Require both series to have the same dtype.
    pub check_dtypes: bool,
    /// Require both series to have the same name.
    pub check_names: bool,
    /// Require elements to appear in the same order; when `false`, both
    /// series are sorted before their values are compared.
    pub check_order: bool,
    /// Require values to match exactly; when `false`, float values may differ
    /// as long as they stay within `rel_tol` / `abs_tol`.
    pub check_exact: bool,
    /// Relative tolerance used for inexact float comparison.
    pub rel_tol: f64,
    /// Absolute tolerance used for inexact float comparison.
    pub abs_tol: f64,
    /// Cast categorical values (including nested ones) to strings before
    /// comparing.
    pub categorical_as_str: bool,
}
impl Default for SeriesEqualOptions {
    /// Returns the strictest configuration: every check is enabled, values
    /// must match exactly, and categoricals are compared as-is.
    ///
    /// The tolerances default to `rel_tol = 1e-5` and `abs_tol = 1e-8`; they
    /// only come into play once `check_exact` is turned off.
    fn default() -> Self {
        Self {
            // Structural checks.
            check_names: true,
            check_dtypes: true,
            check_order: true,
            // Value comparison.
            check_exact: true,
            abs_tol: 1e-8,
            rel_tol: 1e-5,
            categorical_as_str: false,
        }
    }
}
impl SeriesEqualOptions {
    /// Creates options populated with the [`Default`] values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets whether the dtypes of both series must match.
    pub fn with_check_dtypes(self, value: bool) -> Self {
        Self {
            check_dtypes: value,
            ..self
        }
    }

    /// Sets whether the names of both series must match.
    pub fn with_check_names(self, value: bool) -> Self {
        Self {
            check_names: value,
            ..self
        }
    }

    /// Sets whether element order is significant.
    pub fn with_check_order(self, value: bool) -> Self {
        Self {
            check_order: value,
            ..self
        }
    }

    /// Sets whether values must match exactly (as opposed to within tolerance).
    pub fn with_check_exact(self, value: bool) -> Self {
        Self {
            check_exact: value,
            ..self
        }
    }

    /// Sets the relative tolerance for inexact float comparison.
    pub fn with_rel_tol(self, value: f64) -> Self {
        Self {
            rel_tol: value,
            ..self
        }
    }

    /// Sets the absolute tolerance for inexact float comparison.
    pub fn with_abs_tol(self, value: f64) -> Self {
        Self {
            abs_tol: value,
            ..self
        }
    }

    /// Sets whether categoricals are cast to strings before comparison.
    pub fn with_categorical_as_str(self, value: bool) -> Self {
        Self {
            categorical_as_str: value,
            ..self
        }
    }
}
/// Returns a copy of `dtype` in which every categorical dtype has been
/// replaced by `String`.
///
/// The replacement recurses through `List`, `Array` and `Struct` dtypes; any
/// other dtype is returned unchanged (cloned).
fn categorical_dtype_to_string_dtype(dtype: &DataType) -> DataType {
    match dtype {
        DataType::Categorical(..) => DataType::String,
        DataType::List(inner) => {
            DataType::List(Box::new(categorical_dtype_to_string_dtype(inner)))
        },
        DataType::Array(inner, size) => {
            DataType::Array(Box::new(categorical_dtype_to_string_dtype(inner)), *size)
        },
        DataType::Struct(fields) => {
            // Rebuild each field with the same name and a converted dtype.
            let mut converted: Vec<Field> = Vec::with_capacity(fields.len());
            for field in fields.iter() {
                let field_dtype = categorical_dtype_to_string_dtype(field.dtype());
                converted.push(Field::new(field.name().clone(), field_dtype));
            }
            DataType::Struct(converted)
        },
        other => other.clone(),
    }
}
/// Casts `s` so that every categorical dtype it contains (also nested)
/// becomes `String`.
///
/// When the dtype contains no categoricals the series is returned as a clone
/// without casting.
///
/// # Errors
/// Propagates any error raised by the cast.
fn categorical_series_to_string(s: &Series) -> PolarsResult<Series> {
    let target_dtype = categorical_dtype_to_string_dtype(s.dtype());
    if *s.dtype() == target_dtype {
        // Nothing to convert.
        return Ok(s.clone());
    }
    s.cast(&target_dtype)
}
/// Returns `true` only when both dtypes are float dtypes.
fn are_both_floats(left: &DataType, right: &DataType) -> bool {
    if !left.is_float() {
        return false;
    }
    right.is_float()
}
/// Returns `true` only when both dtypes are list-like, i.e. each one is
/// either a `List` or an `Array` (the two sides may differ in which).
fn are_both_lists(left: &DataType, right: &DataType) -> bool {
    let is_list_like = |dtype: &DataType| -> bool {
        matches!(dtype, DataType::List(_) | DataType::Array(_, _))
    };
    is_list_like(left) && is_list_like(right)
}
/// Returns `true` only when both dtypes are struct dtypes.
fn are_both_structs(left: &DataType, right: &DataType) -> bool {
    if !left.is_struct() {
        return false;
    }
    right.is_struct()
}
/// Returns `true` when both dtypes are nested (both list-like or both
/// structs) and each of them contains a float dtype among its unpacked
/// inner dtypes.
fn comparing_nested_floats(left: &DataType, right: &DataType) -> bool {
    let both_nested = are_both_lists(left, right) || are_both_structs(left, right);
    if !both_nested {
        return false;
    }

    let contains_float = |dtype: &DataType| -> bool {
        unpack_dtypes(dtype, false)
            .iter()
            .any(|inner| inner.is_float())
    };
    contains_float(left) && contains_float(right)
}
/// Asserts that nulls occur at the same positions in both series.
///
/// # Errors
/// Returns an assertion error reporting the null count of each side when the
/// null masks differ anywhere.
fn assert_series_null_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
    let left_nulls = left.is_null();
    let right_nulls = right.is_null();

    if !left_nulls.not_equal(&right_nulls).any() {
        return Ok(());
    }

    Err(polars_err!(
        assertion_error = "Series",
        "null value mismatch",
        left.null_count(),
        right.null_count()
    ))
}
/// Asserts that NaNs occur at the same positions in both series.
///
/// The check is skipped (and succeeds) unless both series have float dtypes.
///
/// # Errors
/// Returns an assertion error reporting the NaN count of each side when the
/// NaN masks differ anywhere, or propagates an error from computing the masks.
fn assert_series_nan_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
    if are_both_floats(left.dtype(), right.dtype()) {
        let left_nan = left.is_nan()?;
        let right_nan = right.is_nan()?;

        if left_nan.not_equal(&right_nan).any() {
            // The counts are only needed for the error message.
            return Err(polars_err!(
                assertion_error = "Series",
                "nan value mismatch",
                left_nan.sum().unwrap_or(0),
                right_nan.sum().unwrap_or(0)
            ));
        }
    }
    Ok(())
}
/// Asserts that the values flagged by `unequal` are still close enough to be
/// considered equal.
///
/// Both series are first filtered down to the positions where `unequal` is
/// set; those candidates are then compared with `is_close` using `abs_tol`
/// and `rel_tol`.
///
/// # Errors
/// Returns an assertion error listing the left and right values that exceed
/// the tolerance, or propagates an error from filtering / `is_close`.
fn assert_series_values_within_tolerance(
    left: &Series,
    right: &Series,
    unequal: &ChunkedArray<BooleanType>,
    rel_tol: f64,
    abs_tol: f64,
) -> PolarsResult<()> {
    let left_candidates = left.filter(unequal)?;
    let right_candidates = right.filter(unequal)?;

    let close_mask = is_close(&left_candidates, &right_candidates, abs_tol, rel_tol, false)?;
    if close_mask.all() {
        return Ok(());
    }

    // Keep only the pairs that fell outside the tolerance for the report.
    let outside_tolerance = close_mask.not();
    let offending_left = left_candidates.filter(&outside_tolerance)?;
    let offending_right = right_candidates.filter(&outside_tolerance)?;
    Err(polars_err!(
        assertion_error = "Series",
        "values not within tolerance",
        offending_left,
        offending_right
    ))
}
/// Asserts that two series hold equal values, ignoring names and (unless it
/// makes the comparison impossible) dtypes.
///
/// Steps, in order:
/// 1. With `check_dtypes == false`, differing dtypes and both series entirely
///    null, the series are accepted immediately.
/// 2. With `categorical_as_str`, categoricals are cast to strings.
/// 3. With `check_order == false`, both series are sorted.
/// 4. The series are compared element-wise (nulls compare as values).
/// 5. Nested dtypes containing floats are delegated to
///    `assert_series_nested_values_equal` on the differing rows only.
/// 6. Remaining differences fail outright unless `check_exact == false` and
///    both series are floats, in which case null positions, NaN positions and
///    tolerances are checked.
///
/// # Errors
/// Returns an assertion error describing the first failed check, or
/// propagates errors from casting, sorting and filtering.
#[allow(clippy::too_many_arguments)]
fn assert_series_values_equal(
    left: &Series,
    right: &Series,
    check_order: bool,
    check_exact: bool,
    check_dtypes: bool,
    rel_tol: f64,
    abs_tol: f64,
    categorical_as_str: bool,
) -> PolarsResult<()> {
    // Two all-null series of different dtypes count as equal when dtypes are
    // not being checked.
    let both_entirely_null =
        || left.null_count() == left.len() && right.null_count() == right.len();
    if !check_dtypes && left.dtype() != right.dtype() && both_entirely_null() {
        return Ok(());
    }

    let (mut lhs, mut rhs) = if categorical_as_str {
        (
            categorical_series_to_string(left)?,
            categorical_series_to_string(right)?,
        )
    } else {
        (left.clone(), right.clone())
    };

    if !check_order {
        lhs = lhs.sort(SortOptions::default())?;
        rhs = rhs.sort(SortOptions::default())?;
    }

    // Any failure of the comparison itself is reported as a dtype problem.
    let unequal = lhs.not_equal_missing(&rhs).map_err(|_| {
        polars_err!(
            assertion_error = "Series",
            "incompatible data types",
            lhs.dtype(),
            rhs.dtype()
        )
    })?;

    if comparing_nested_floats(lhs.dtype(), rhs.dtype()) {
        // Only rows that differ need the (more expensive) nested comparison.
        let lhs_differing = lhs.filter(&unequal)?;
        let rhs_differing = rhs.filter(&unequal)?;

        return assert_series_nested_values_equal(
            &lhs_differing,
            &rhs_differing,
            check_exact,
            check_dtypes,
            rel_tol,
            abs_tol,
            categorical_as_str,
        )
        .map_err(|_| {
            polars_err!(
                assertion_error = "Series",
                "nested value mismatch",
                lhs,
                rhs
            )
        });
    }

    if !unequal.any() {
        return Ok(());
    }

    let both_float = lhs.dtype().is_float() && rhs.dtype().is_float();
    if check_exact || !both_float {
        return Err(polars_err!(
            assertion_error = "Series",
            "exact value mismatch",
            lhs,
            rhs
        ));
    }

    assert_series_null_values_match(&lhs, &rhs)?;
    assert_series_nan_values_match(&lhs, &rhs)?;
    assert_series_values_within_tolerance(&lhs, &rhs, &unequal, rel_tol, abs_tol)
}
/// Asserts that two nested series (list-like or struct) hold equal values.
///
/// * List-like series are walked row by row: each pair of rows is exploded
///   into a flat series and compared with `assert_series_values_equal`
///   (order-sensitive). A null row on either side is a mismatch.
/// * Otherwise both series are treated as structs, unnested, and compared
///   field by field using the field names of `left`.
///
/// # Errors
/// Returns an assertion error on the first mismatching row or field, or
/// propagates errors from exploding, unnesting or column lookup.
fn assert_series_nested_values_equal(
    left: &Series,
    right: &Series,
    check_exact: bool,
    check_dtypes: bool,
    rel_tol: f64,
    abs_tol: f64,
    categorical_as_str: bool,
) -> PolarsResult<()> {
    if are_both_lists(left.dtype(), right.dtype()) {
        for (s1, s2) in left.iter().zip(right.iter()) {
            if s1.is_null() || s2.is_null() {
                return Err(polars_err!(
                    assertion_error = "Series",
                    "nested value mismatch",
                    s1,
                    s2
                ));
            }

            let s1_series = Series::new("".into(), std::slice::from_ref(&s1));
            let s2_series = Series::new("".into(), std::slice::from_ref(&s2));
            let s1_values = s1_series.explode(ExplodeOptions {
                empty_as_null: true,
                keep_nulls: true,
            })?;
            let s2_values = s2_series.explode(ExplodeOptions {
                empty_as_null: true,
                keep_nulls: true,
            })?;

            // Series comparison broadcasts a length-1 operand, so without
            // this guard a one-element list would compare equal to a longer
            // list that merely repeats the same value.
            if s1_values.len() != s2_values.len() {
                return Err(polars_err!(
                    assertion_error = "Series",
                    "nested value mismatch",
                    s1,
                    s2
                ));
            }

            assert_series_values_equal(
                &s1_values,
                &s2_values,
                true,
                check_exact,
                check_dtypes,
                rel_tol,
                abs_tol,
                categorical_as_str,
            )?;
        }
    } else {
        let ls = left.struct_()?.clone().unnest();
        let rs = right.struct_()?.clone().unnest();

        // Fields are matched by name; a field missing on the right surfaces
        // as the lookup error from `column`.
        for col_name in ls.get_column_names() {
            let s1_series = ls.column(col_name)?.as_materialized_series();
            let s2_series = rs.column(col_name)?.as_materialized_series();

            assert_series_values_equal(
                s1_series,
                s2_series,
                true,
                check_exact,
                check_dtypes,
                rel_tol,
                abs_tol,
                categorical_as_str,
            )?;
        }
    }
    Ok(())
}
pub fn assert_series_equal(
left: &Series,
right: &Series,
options: SeriesEqualOptions,
) -> PolarsResult<()> {
if std::ptr::eq(left, right) {
return Ok(());
}
if left.len() != right.len() {
return Err(polars_err!(
assertion_error = "Series",
"length mismatch",
left.len(),
right.len()
));
}
if options.check_names && left.name() != right.name() {
return Err(polars_err!(
assertion_error = "Series",
"name mismatch",
left.name(),
right.name()
));
}
if options.check_dtypes && left.dtype() != right.dtype() {
return Err(polars_err!(
assertion_error = "Series",
"dtype mismatch",
left.dtype(),
right.dtype()
));
}
assert_series_values_equal(
left,
right,
options.check_order,
options.check_exact,
options.check_dtypes,
options.rel_tol,
options.abs_tol,
options.categorical_as_str,
)
}
/// Options controlling how two `DataFrame`s are compared by the assertion
/// helpers.
///
/// All fields are public so the struct can be built literally, but the
/// `with_*` builder methods are the more convenient way to tweak a default.
#[derive(Debug, Clone, PartialEq)]
pub struct DataFrameEqualOptions {
    /// Require rows to appear in the same order; when `false`, both frames
    /// are sorted by all columns before their values are compared.
    pub check_row_order: bool,
    /// Require columns to appear in the same order.
    pub check_column_order: bool,
    /// Require matching columns to have the same dtype.
    pub check_dtypes: bool,
    /// Require values to match exactly; when `false`, float values may differ
    /// as long as they stay within `rel_tol` / `abs_tol`.
    pub check_exact: bool,
    /// Relative tolerance used for inexact float comparison.
    pub rel_tol: f64,
    /// Absolute tolerance used for inexact float comparison.
    pub abs_tol: f64,
    /// Cast categorical values (including nested ones) to strings before
    /// comparing.
    pub categorical_as_str: bool,
}
impl Default for DataFrameEqualOptions {
    /// Returns the default configuration: row order, column order and dtypes
    /// are checked, while values are compared inexactly (`check_exact` is
    /// `false`) with `rel_tol = 1e-5` and `abs_tol = 1e-8`. Categoricals are
    /// compared as-is.
    fn default() -> Self {
        Self {
            // Structural checks.
            check_column_order: true,
            check_row_order: true,
            check_dtypes: true,
            // Value comparison.
            check_exact: false,
            abs_tol: 1e-8,
            rel_tol: 1e-5,
            categorical_as_str: false,
        }
    }
}
impl DataFrameEqualOptions {
    /// Creates options populated with the [`Default`] values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets whether row order is significant.
    pub fn with_check_row_order(self, value: bool) -> Self {
        Self {
            check_row_order: value,
            ..self
        }
    }

    /// Sets whether column order is significant.
    pub fn with_check_column_order(self, value: bool) -> Self {
        Self {
            check_column_order: value,
            ..self
        }
    }

    /// Sets whether the dtypes of matching columns must agree.
    pub fn with_check_dtypes(self, value: bool) -> Self {
        Self {
            check_dtypes: value,
            ..self
        }
    }

    /// Sets whether values must match exactly (as opposed to within tolerance).
    pub fn with_check_exact(self, value: bool) -> Self {
        Self {
            check_exact: value,
            ..self
        }
    }

    /// Sets the relative tolerance for inexact float comparison.
    pub fn with_rel_tol(self, value: f64) -> Self {
        Self {
            rel_tol: value,
            ..self
        }
    }

    /// Sets the absolute tolerance for inexact float comparison.
    pub fn with_abs_tol(self, value: f64) -> Self {
        Self {
            abs_tol: value,
            ..self
        }
    }

    /// Sets whether categoricals are cast to strings before comparison.
    pub fn with_categorical_as_str(self, value: bool) -> Self {
        Self {
            categorical_as_str: value,
            ..self
        }
    }
}
/// Asserts that two schemas are equal.
///
/// Column names must always match; dtypes and column order are only compared
/// when `check_dtypes` / `check_column_order` are set. Failures are reported
/// with the label `"Schemas"`.
///
/// # Errors
/// Returns an assertion error describing the first difference found.
pub fn assert_schema_equal(
    left_schema: &Schema,
    right_schema: &Schema,
    check_dtypes: bool,
    check_column_order: bool,
) -> PolarsResult<()> {
    // Label used in the error message to name the compared objects.
    const CONTEXT: &str = "Schemas";

    assert_schema_equal_impl(
        left_schema,
        right_schema,
        check_dtypes,
        check_column_order,
        CONTEXT,
    )
}
/// Shared schema comparison; `context` names the compared objects in error
/// messages.
///
/// Differences are reported in this priority order:
/// 1. columns present in left but not in right,
/// 2. columns present in right but not in left,
/// 3. column order (only if `check_column_order`),
/// 4. dtypes (only if `check_dtypes`).
///
/// # Errors
/// Returns an assertion error for the highest-priority difference found.
fn assert_schema_equal_impl(
    left_schema: &Schema,
    right_schema: &Schema,
    check_dtypes: bool,
    check_column_order: bool,
    context: &'static str,
) -> PolarsResult<()> {
    let mut one_sided_names: Vec<&PlSmallStr> = Vec::new();
    let mut order_differs = false;
    let mut dtypes_differ = false;

    // Single pass over the left schema: collect names missing on the right
    // and note order / dtype differences for the names present on both sides.
    for (l_idx, (l_name, l_dtype)) in left_schema.iter().enumerate() {
        match right_schema.get_full(l_name) {
            Some((r_idx, _, r_dtype)) => {
                order_differs |= check_column_order && l_idx != r_idx;
                dtypes_differ |= check_dtypes && l_dtype != r_dtype;
            },
            None => {
                // At most the remaining left columns can still be missing.
                one_sided_names.reserve_exact(left_schema.len() - l_idx);
                one_sided_names.push(l_name);
            },
        }
    }

    if !one_sided_names.is_empty() {
        polars_bail!(
            assertion_error = context,
            format!(
                "columns mismatch: {:?} in left, but not in right",
                one_sided_names
            ),
            left_schema.names_display(),
            right_schema.names_display()
        )
    }

    // Every left name exists on the right, so right can only be larger.
    debug_assert!(right_schema.len() >= left_schema.len());

    let surplus_on_right = right_schema.len() > left_schema.len();
    if surplus_on_right {
        one_sided_names.reserve_exact(right_schema.len() - left_schema.len());
        for name in right_schema.iter_names() {
            if !left_schema.contains(name) {
                one_sided_names.push(name);
            }
        }
        polars_bail!(
            assertion_error = context,
            format!(
                "columns mismatch: {:?} in right, but not in left",
                one_sided_names
            ),
            left_schema.names_display(),
            right_schema.names_display()
        )
    }

    debug_assert_eq!(left_schema.len(), right_schema.len());

    if order_differs {
        polars_bail!(
            assertion_error = context,
            "columns are not in the same order",
            left_schema.names_display(),
            right_schema.names_display()
        )
    }

    if dtypes_differ {
        polars_bail!(
            assertion_error = context,
            "dtypes do not match",
            left_schema.values_display(),
            right_schema.values_display()
        )
    }

    Ok(())
}
pub fn assert_dataframe_equal(
left: &DataFrame,
right: &DataFrame,
options: DataFrameEqualOptions,
) -> PolarsResult<()> {
if std::ptr::eq(left, right) {
return Ok(());
}
let left_schema = left.schema();
let right_schema = right.schema();
assert_schema_equal_impl(
left_schema,
right_schema,
options.check_dtypes,
options.check_column_order,
"DataFrames",
)?;
if left.height() != right.height() {
return Err(polars_err!(
assertion_error = "DataFrames",
"height (row count) mismatch",
left.height(),
right.height()
));
}
let left_cols = left.get_column_names_owned();
let (left, right) = if !options.check_row_order {
(
left.sort(left_cols.clone(), SortMultipleOptions::default())?,
right.sort(left_cols.clone(), SortMultipleOptions::default())?,
)
} else {
(left.clone(), right.clone())
};
for col in left_cols.iter() {
let s_left = left.column(col)?;
let s_right = right.column(col)?;
let s_left_series = s_left.as_materialized_series();
let s_right_series = s_right.as_materialized_series();
match assert_series_values_equal(
s_left_series,
s_right_series,
true,
options.check_exact,
options.check_dtypes,
options.rel_tol,
options.abs_tol,
options.categorical_as_str,
) {
Ok(_) => {},
Err(_) => {
return Err(polars_err!(
assertion_error = "DataFrames",
format!("value mismatch for column {:?}", col),
format!("{:?}", s_left_series),
format!("{:?}", s_right_series)
));
},
}
}
Ok(())
}