polars_testing/asserts/utils.rs
1use std::ops::Not;
2
3use polars_core::datatypes::unpack_dtypes;
4use polars_core::prelude::*;
5use polars_ops::series::is_close;
6
/// Configuration options for comparing Series equality.
///
/// Controls the behavior of Series equality comparisons by specifying
/// which aspects to check and the tolerance for floating point comparisons.
///
/// Every field is a plain `bool` or `f64`, so the type derives `Clone`/`Copy`
/// (a configuration can be reused after being passed by value), `Debug`
/// (it can be printed in diagnostics) and `PartialEq` (two configurations can
/// be compared in tests).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SeriesEqualOptions {
    /// Whether to check that the data types match.
    pub check_dtypes: bool,
    /// Whether to check that the Series names match.
    pub check_names: bool,
    /// Whether to check that elements appear in the same order.
    pub check_order: bool,
    /// Whether to check for exact equality (true) or approximate equality (false) for floating point values.
    pub check_exact: bool,
    /// Relative tolerance for approximate equality of floating point values.
    pub rel_tol: f64,
    /// Absolute tolerance for approximate equality of floating point values.
    pub abs_tol: f64,
    /// Whether to compare categorical values as strings.
    pub categorical_as_str: bool,
}
27
28impl Default for SeriesEqualOptions {
29 /// Creates a new `SeriesEqualOptions` with default settings.
30 ///
31 /// Default configuration:
32 /// - Checks data types, names, and order
33 /// - Uses exact equality comparisons
34 /// - Sets relative tolerance to 1e-5 and absolute tolerance to 1e-8 for floating point comparisons
35 /// - Does not convert categorical values to strings for comparison
36 fn default() -> Self {
37 Self {
38 check_dtypes: true,
39 check_names: true,
40 check_order: true,
41 check_exact: true,
42 rel_tol: 1e-5,
43 abs_tol: 1e-8,
44 categorical_as_str: false,
45 }
46 }
47}
48
49impl SeriesEqualOptions {
50 /// Creates a new `SeriesEqualOptions` with default settings.
51 pub fn new() -> Self {
52 Self::default()
53 }
54
55 /// Sets whether to check that data types match.
56 pub fn with_check_dtypes(mut self, value: bool) -> Self {
57 self.check_dtypes = value;
58 self
59 }
60
61 /// Sets whether to check that Series names match.
62 pub fn with_check_names(mut self, value: bool) -> Self {
63 self.check_names = value;
64 self
65 }
66
67 /// Sets whether to check that elements appear in the same order.
68 pub fn with_check_order(mut self, value: bool) -> Self {
69 self.check_order = value;
70 self
71 }
72
73 /// Sets whether to check for exact equality (true) or approximate equality (false) for floating point values.
74 pub fn with_check_exact(mut self, value: bool) -> Self {
75 self.check_exact = value;
76 self
77 }
78
79 /// Sets the relative tolerance for approximate equality of floating point values.
80 pub fn with_rel_tol(mut self, value: f64) -> Self {
81 self.rel_tol = value;
82 self
83 }
84
85 /// Sets the absolute tolerance for approximate equality of floating point values.
86 pub fn with_abs_tol(mut self, value: f64) -> Self {
87 self.abs_tol = value;
88 self
89 }
90
91 /// Sets whether to compare categorical values as strings.
92 pub fn with_categorical_as_str(mut self, value: bool) -> Self {
93 self.categorical_as_str = value;
94 self
95 }
96}
97
98/// Change a (possibly nested) Categorical data type to a String data type.
99fn categorical_dtype_to_string_dtype(dtype: &DataType) -> DataType {
100 match dtype {
101 DataType::Categorical(..) => DataType::String,
102 DataType::List(inner) => {
103 let inner_cast = categorical_dtype_to_string_dtype(inner);
104 DataType::List(Box::new(inner_cast))
105 },
106 DataType::Array(inner, size) => {
107 let inner_cast = categorical_dtype_to_string_dtype(inner);
108 DataType::Array(Box::new(inner_cast), *size)
109 },
110 DataType::Struct(fields) => {
111 let transformed_fields = fields
112 .iter()
113 .map(|field| {
114 Field::new(
115 field.name().clone(),
116 categorical_dtype_to_string_dtype(field.dtype()),
117 )
118 })
119 .collect::<Vec<Field>>();
120
121 DataType::Struct(transformed_fields)
122 },
123 _ => dtype.clone(),
124 }
125}
126
127/// Cast a (possibly nested) Categorical Series to a String Series.
128fn categorical_series_to_string(s: &Series) -> PolarsResult<Series> {
129 let dtype = s.dtype();
130 let noncat_dtype = categorical_dtype_to_string_dtype(dtype);
131
132 if *dtype != noncat_dtype {
133 Ok(s.cast(&noncat_dtype)?)
134 } else {
135 Ok(s.clone())
136 }
137}
138
139/// Returns true if both DataTypes are floating point types.
140fn are_both_floats(left: &DataType, right: &DataType) -> bool {
141 left.is_float() && right.is_float()
142}
143
144/// Returns true if both DataTypes are list-like (either List or Array types).
145fn are_both_lists(left: &DataType, right: &DataType) -> bool {
146 matches!(left, DataType::List(_) | DataType::Array(_, _))
147 && matches!(right, DataType::List(_) | DataType::Array(_, _))
148}
149
150/// Returns true if both DataTypes are struct types.
151fn are_both_structs(left: &DataType, right: &DataType) -> bool {
152 left.is_struct() && right.is_struct()
153}
154
155/// Returns true if both DataTypes are nested types (lists or structs) that contain floating point types within them.
156/// First checks if both types are either lists or structs, then unpacks their nested DataTypes to determine if
157/// at least one floating point type exists in each of the nested structures.
158fn comparing_nested_floats(left: &DataType, right: &DataType) -> bool {
159 if !are_both_lists(left, right) && !are_both_structs(left, right) {
160 return false;
161 }
162
163 let left_dtypes = unpack_dtypes(left, false);
164 let right_dtypes = unpack_dtypes(right, false);
165
166 let left_has_floats = left_dtypes.iter().any(|dt| dt.is_float());
167 let right_has_floats = right_dtypes.iter().any(|dt| dt.is_float());
168
169 left_has_floats && right_has_floats
170}
171
172/// Ensures that null values in two Series match exactly and returns an error if any mismatches are found.
173fn assert_series_null_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
174 let null_value_mismatch = left.is_null().not_equal(&right.is_null());
175
176 if null_value_mismatch.any() {
177 return Err(polars_err!(
178 assertion_error = "Series",
179 "null value mismatch",
180 left.null_count(),
181 right.null_count()
182 ));
183 }
184
185 Ok(())
186}
187
188/// Validates that NaN patterns are identical between two float Series, returning error if any mismatches are found.
189fn assert_series_nan_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
190 if !are_both_floats(left.dtype(), right.dtype()) {
191 return Ok(());
192 }
193 let left_nan = left.is_nan()?;
194 let right_nan = right.is_nan()?;
195
196 let nan_value_mismatch = left_nan.not_equal(&right_nan);
197
198 let left_nan_count = left_nan.sum().unwrap_or(0);
199 let right_nan_count = right_nan.sum().unwrap_or(0);
200
201 if nan_value_mismatch.any() {
202 return Err(polars_err!(
203 assertion_error = "Series",
204 "nan value mismatch",
205 left_nan_count,
206 right_nan_count
207 ));
208 }
209
210 Ok(())
211}
212
213/// Verifies that two Series have values within a specified tolerance.
214///
215/// This function checks if the values in `left` and `right` Series that are marked as unequal
216/// in the `unequal` boolean array are within the specified relative and absolute tolerances.
217///
218/// # Arguments
219///
220/// * `left` - The first Series to compare
221/// * `right` - The second Series to compare
222/// * `unequal` - Boolean ChunkedArray indicating which elements to check (true = check this element)
223/// * `rel_tol` - Relative tolerance (relative to the maximum absolute value of the two Series)
224/// * `abs_tol` - Absolute tolerance added to the relative tolerance
225///
226/// # Returns
227///
228/// * `Ok(())` if all values are within tolerance
229/// * `Err` with details about problematic values if any values exceed the tolerance
230///
231/// # Formula
232///
233/// Values are considered within tolerance if:
234/// `|left - right| <= max(rel_tol * max(abs(left), abs(right)), abs_tol)` OR values are exactly equal
235///
236fn assert_series_values_within_tolerance(
237 left: &Series,
238 right: &Series,
239 unequal: &ChunkedArray<BooleanType>,
240 rel_tol: f64,
241 abs_tol: f64,
242) -> PolarsResult<()> {
243 let left_unequal = left.filter(unequal)?;
244 let right_unequal = right.filter(unequal)?;
245
246 let within_tolerance = is_close(&left_unequal, &right_unequal, abs_tol, rel_tol, false)?;
247 if within_tolerance.all() {
248 Ok(())
249 } else {
250 let exceeded_indices = within_tolerance.not();
251 let problematic_left = left_unequal.filter(&exceeded_indices)?;
252 let problematic_right = right_unequal.filter(&exceeded_indices)?;
253
254 Err(polars_err!(
255 assertion_error = "Series",
256 "values not within tolerance",
257 problematic_left,
258 problematic_right
259 ))
260 }
261}
262
263/// Compares two Series for equality with configurable options for ordering, exact matching, and tolerance.
264///
265/// This function verifies that the values in `left` and `right` Series are equal according to
266/// the specified comparison criteria. It handles different types including floats and nested types
267/// with appropriate equality checks.
268///
269/// # Arguments
270///
271/// * `left` - The first Series to compare
272/// * `right` - The second Series to compare
273/// * `check_order` - If true, elements must be in the same order; if false, Series will be sorted before comparison
274/// * `check_exact` - If true, requires exact equality; if false, allows approximate equality for floats within tolerance
275/// * `rel_tol` - Relative tolerance for float comparison (used when `check_exact` is false)
276/// * `abs_tol` - Absolute tolerance for float comparison (used when `check_exact` is false)
277/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
278///
279/// # Returns
280///
281/// * `Ok(())` if Series match according to specified criteria
282/// * `Err` with details about mismatches if Series differ
283///
284/// # Behavior
285///
286/// 1. Handles categorical Series based on `categorical_as_str` flag
287/// 2. Sorts Series if `check_order` is false
288/// 3. For nested float types, delegates to `assert_series_nested_values_equal`
289/// 4. For non-float types or when `check_exact` is true, requires exact match
290/// 5. For float types with approximate matching:
291/// - Verifies null values match using `assert_series_null_values_match`
292/// - Verifies NaN values match using `assert_series_nan_values_match`
293/// - Verifies float values are within tolerance using `assert_series_values_within_tolerance`
294///
295#[allow(clippy::too_many_arguments)]
296fn assert_series_values_equal(
297 left: &Series,
298 right: &Series,
299 check_order: bool,
300 check_exact: bool,
301 check_dtypes: bool,
302 rel_tol: f64,
303 abs_tol: f64,
304 categorical_as_str: bool,
305) -> PolarsResult<()> {
306 // When `check_dtypes` is `false` and both series are entirely null,
307 // consider them equal regardless of their underlying data types
308 if !check_dtypes && left.dtype() != right.dtype() {
309 if left.null_count() == left.len() && right.null_count() == right.len() {
310 return Ok(());
311 }
312 }
313
314 let (left, right) = if categorical_as_str {
315 (
316 categorical_series_to_string(left)?,
317 categorical_series_to_string(right)?,
318 )
319 } else {
320 (left.clone(), right.clone())
321 };
322
323 let (left, right) = if !check_order {
324 (
325 left.sort(SortOptions::default())?,
326 right.sort(SortOptions::default())?,
327 )
328 } else {
329 (left, right)
330 };
331
332 let unequal = match left.not_equal_missing(&right) {
333 Ok(result) => result,
334 Err(_) => {
335 return Err(polars_err!(
336 assertion_error = "Series",
337 "incompatible data types",
338 left.dtype(),
339 right.dtype()
340 ));
341 },
342 };
343
344 if comparing_nested_floats(left.dtype(), right.dtype()) {
345 let filtered_left = left.filter(&unequal)?;
346 let filtered_right = right.filter(&unequal)?;
347
348 match assert_series_nested_values_equal(
349 &filtered_left,
350 &filtered_right,
351 check_exact,
352 check_dtypes,
353 rel_tol,
354 abs_tol,
355 categorical_as_str,
356 ) {
357 Ok(_) => return Ok(()),
358 Err(_) => {
359 return Err(polars_err!(
360 assertion_error = "Series",
361 "nested value mismatch",
362 left,
363 right
364 ));
365 },
366 }
367 }
368
369 if !unequal.any() {
370 return Ok(());
371 }
372
373 if check_exact || !left.dtype().is_float() || !right.dtype().is_float() {
374 return Err(polars_err!(
375 assertion_error = "Series",
376 "exact value mismatch",
377 left,
378 right
379 ));
380 }
381
382 assert_series_null_values_match(&left, &right)?;
383 assert_series_nan_values_match(&left, &right)?;
384 assert_series_values_within_tolerance(&left, &right, &unequal, rel_tol, abs_tol)?;
385
386 Ok(())
387}
388
389/// Recursively compares nested Series structures (lists or structs) for equality.
390///
391/// This function handles the comparison of complex nested data structures by recursively
392/// applying appropriate equality checks based on the nested data type.
393///
394/// # Arguments
395///
396/// * `left` - The first nested Series to compare
397/// * `right` - The second nested Series to compare
398/// * `check_exact` - If true, requires exact equality; if false, allows approximate equality for floats
399/// * `rel_tol` - Relative tolerance for float comparison (used when `check_exact` is false)
400/// * `abs_tol` - Absolute tolerance for float comparison (used when `check_exact` is false)
401/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
402///
403/// # Returns
404///
405/// * `Ok(())` if nested Series match according to specified criteria
406/// * `Err` with details about mismatches if Series differ
407///
408/// # Behavior
409///
410/// For List types:
411/// 1. Iterates through corresponding elements in both Series
412/// 2. Returns error if null values are encountered
413/// 3. Creates single-element Series for each value and explodes them
414/// 4. Recursively calls `assert_series_values_equal` on the exploded Series
415///
416/// For Struct types:
417/// 1. Unnests both struct Series to access their columns
418/// 2. Iterates through corresponding columns
419/// 3. Recursively calls `assert_series_values_equal` on each column pair
420///
421fn assert_series_nested_values_equal(
422 left: &Series,
423 right: &Series,
424 check_exact: bool,
425 check_dtypes: bool,
426 rel_tol: f64,
427 abs_tol: f64,
428 categorical_as_str: bool,
429) -> PolarsResult<()> {
430 if are_both_lists(left.dtype(), right.dtype()) {
431 let zipped = left.iter().zip(right.iter());
432
433 for (s1, s2) in zipped {
434 if s1.is_null() || s2.is_null() {
435 return Err(polars_err!(
436 assertion_error = "Series",
437 "nested value mismatch",
438 s1,
439 s2
440 ));
441 } else {
442 let s1_series = Series::new("".into(), std::slice::from_ref(&s1));
443 let s2_series = Series::new("".into(), std::slice::from_ref(&s2));
444
445 assert_series_values_equal(
446 &s1_series.explode(ExplodeOptions {
447 empty_as_null: true,
448 keep_nulls: true,
449 })?,
450 &s2_series.explode(ExplodeOptions {
451 empty_as_null: true,
452 keep_nulls: true,
453 })?,
454 true,
455 check_exact,
456 check_dtypes,
457 rel_tol,
458 abs_tol,
459 categorical_as_str,
460 )?
461 }
462 }
463 } else {
464 let ls = left.struct_()?.clone().unnest();
465 let rs = right.struct_()?.clone().unnest();
466
467 for col_name in ls.get_column_names() {
468 let s1_column = ls.column(col_name)?;
469 let s2_column = rs.column(col_name)?;
470
471 let s1_series = s1_column.as_materialized_series();
472 let s2_series = s2_column.as_materialized_series();
473
474 assert_series_values_equal(
475 s1_series,
476 s2_series,
477 true,
478 check_exact,
479 check_dtypes,
480 rel_tol,
481 abs_tol,
482 categorical_as_str,
483 )?
484 }
485 }
486
487 Ok(())
488}
489
490/// Verifies that two Series are equal according to a set of configurable criteria.
491///
492/// This function serves as the main entry point for comparing Series, checking various
493/// metadata properties before comparing the actual values.
494///
495/// # Arguments
496///
497/// * `left` - The first Series to compare
498/// * `right` - The second Series to compare
499/// * `options` - A `SeriesEqualOptions` struct containing configuration parameters:
500/// * `check_names` - If true, verifies Series names match
501/// * `check_dtypes` - If true, verifies data types match
502/// * `check_order` - If true, elements must be in the same order
503/// * `check_exact` - If true, requires exact equality for float values
504/// * `rel_tol` - Relative tolerance for float comparison
505/// * `abs_tol` - Absolute tolerance for float comparison
506/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
507///
508/// # Returns
509///
510/// * `Ok(())` if Series match according to all specified criteria
511/// * `Err` with details about the first mismatch encountered:
512/// * Length mismatch
513/// * Name mismatch (if checking names)
514/// * Data type mismatch (if checking dtypes)
515/// * Value mismatches (via `assert_series_values_equal`)
516///
517/// # Order of Checks
518///
519/// 1. Series length
520/// 2. Series names (if `check_names` is true)
521/// 3. Data types (if `check_dtypes` is true)
522/// 4. Series values (delegated to `assert_series_values_equal`)
523///
524pub fn assert_series_equal(
525 left: &Series,
526 right: &Series,
527 options: SeriesEqualOptions,
528) -> PolarsResult<()> {
529 // Short-circuit if they're the same series object
530 if std::ptr::eq(left, right) {
531 return Ok(());
532 }
533
534 if left.len() != right.len() {
535 return Err(polars_err!(
536 assertion_error = "Series",
537 "length mismatch",
538 left.len(),
539 right.len()
540 ));
541 }
542
543 if options.check_names && left.name() != right.name() {
544 return Err(polars_err!(
545 assertion_error = "Series",
546 "name mismatch",
547 left.name(),
548 right.name()
549 ));
550 }
551
552 if options.check_dtypes && left.dtype() != right.dtype() {
553 return Err(polars_err!(
554 assertion_error = "Series",
555 "dtype mismatch",
556 left.dtype(),
557 right.dtype()
558 ));
559 }
560
561 assert_series_values_equal(
562 left,
563 right,
564 options.check_order,
565 options.check_exact,
566 options.check_dtypes,
567 options.rel_tol,
568 options.abs_tol,
569 options.categorical_as_str,
570 )
571}
572
/// Configuration options for comparing DataFrame equality.
///
/// Controls the behavior of DataFrame equality comparisons by specifying
/// which aspects to check and the tolerance for floating point comparisons.
///
/// Every field is a plain `bool` or `f64`, so the type derives `Clone`/`Copy`
/// (a configuration can be reused after being passed by value), `Debug`
/// (it can be printed in diagnostics) and `PartialEq` (two configurations can
/// be compared in tests).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DataFrameEqualOptions {
    /// Whether to check that rows appear in the same order.
    pub check_row_order: bool,
    /// Whether to check that columns appear in the same order.
    pub check_column_order: bool,
    /// Whether to check that the data types match for corresponding columns.
    pub check_dtypes: bool,
    /// Whether to check for exact equality (true) or approximate equality (false) for floating point values.
    pub check_exact: bool,
    /// Relative tolerance for approximate equality of floating point values.
    pub rel_tol: f64,
    /// Absolute tolerance for approximate equality of floating point values.
    pub abs_tol: f64,
    /// Whether to compare categorical values as strings.
    pub categorical_as_str: bool,
}
593
594impl Default for DataFrameEqualOptions {
595 /// Creates a new `DataFrameEqualOptions` with default settings.
596 ///
597 /// Default configuration:
598 /// - Checks row order, column order, and data types
599 /// - Uses approximate equality comparisons for floating point values
600 /// - Sets relative tolerance to 1e-5 and absolute tolerance to 1e-8 for floating point comparisons
601 /// - Does not convert categorical values to strings for comparison
602 fn default() -> Self {
603 Self {
604 check_row_order: true,
605 check_column_order: true,
606 check_dtypes: true,
607 check_exact: false,
608 rel_tol: 1e-5,
609 abs_tol: 1e-8,
610 categorical_as_str: false,
611 }
612 }
613}
614
615impl DataFrameEqualOptions {
616 /// Creates a new `DataFrameEqualOptions` with default settings.
617 pub fn new() -> Self {
618 Self::default()
619 }
620
621 /// Sets whether to check that rows appear in the same order.
622 pub fn with_check_row_order(mut self, value: bool) -> Self {
623 self.check_row_order = value;
624 self
625 }
626
627 /// Sets whether to check that columns appear in the same order.
628 pub fn with_check_column_order(mut self, value: bool) -> Self {
629 self.check_column_order = value;
630 self
631 }
632
633 /// Sets whether to check that data types match for corresponding columns.
634 pub fn with_check_dtypes(mut self, value: bool) -> Self {
635 self.check_dtypes = value;
636 self
637 }
638
639 /// Sets whether to check for exact equality (true) or approximate equality (false) for floating point values.
640 pub fn with_check_exact(mut self, value: bool) -> Self {
641 self.check_exact = value;
642 self
643 }
644
645 /// Sets the relative tolerance for approximate equality of floating point values.
646 pub fn with_rel_tol(mut self, value: f64) -> Self {
647 self.rel_tol = value;
648 self
649 }
650
651 /// Sets the absolute tolerance for approximate equality of floating point values.
652 pub fn with_abs_tol(mut self, value: f64) -> Self {
653 self.abs_tol = value;
654 self
655 }
656
657 /// Sets whether to compare categorical values as strings.
658 pub fn with_categorical_as_str(mut self, value: bool) -> Self {
659 self.categorical_as_str = value;
660 self
661 }
662}
663
664/// Compares DataFrame schemas for equality based on specified criteria.
665///
666/// This function validates that two DataFrames have compatible schemas by checking
667/// column names, their order, and optionally their data types according to the
668/// provided configuration parameters.
669///
670/// # Arguments
671///
672/// * `left` - The first DataFrame to compare
673/// * `right` - The second DataFrame to compare
674/// * `check_dtypes` - If true, requires data types to match for corresponding columns
675/// * `check_column_order` - If true, requires columns to appear in the same order
676///
677/// # Returns
678///
679/// * `Ok(())` if DataFrame schemas match according to specified criteria
680/// * `Err` with details about schema mismatches if DataFrames differ
681///
682/// # Behavior
683///
684/// The function performs schema validation in the following order:
685///
686/// 1. **Fast path**: Returns immediately if schemas are identical
687/// 2. **Column name validation**: Ensures both DataFrames have the same set of column names
688/// - Reports columns present in left but missing in right
689/// - Reports columns present in right but missing in left
690/// 3. **Column order validation**: If `check_column_order` is true, verifies columns appear in the same sequence
691/// 4. **Data type validation**: If `check_dtypes` is true, ensures corresponding columns have matching data types
692/// - When `check_column_order` is false, compares data type sets for equality
693/// - When `check_column_order` is true, performs more precise type checking
694///
695fn assert_dataframe_schema_equal(
696 left: &DataFrame,
697 right: &DataFrame,
698 check_dtypes: bool,
699 check_column_order: bool,
700) -> PolarsResult<()> {
701 let left_schema = left.schema();
702 let right_schema = right.schema();
703
704 let ordered_left_cols = left.get_column_names();
705 let ordered_right_cols = right.get_column_names();
706
707 let left_set: PlHashSet<&PlSmallStr> = ordered_left_cols.iter().copied().collect();
708 let right_set: PlHashSet<&PlSmallStr> = ordered_right_cols.iter().copied().collect();
709
710 // Fast path for equal DataFrames
711 if left_schema == right_schema {
712 return Ok(());
713 }
714
715 if left_set != right_set {
716 let left_not_right: Vec<_> = left_set
717 .iter()
718 .filter(|col| !right_set.contains(*col))
719 .collect();
720
721 if !left_not_right.is_empty() {
722 return Err(polars_err!(
723 assertion_error = "DataFrames",
724 format!(
725 "columns mismatch: {:?} in left, but not in right",
726 left_not_right
727 ),
728 format!("{:?}", left_set),
729 format!("{:?}", right_set)
730 ));
731 } else {
732 let right_not_left: Vec<_> = right_set
733 .iter()
734 .filter(|col| !left_set.contains(*col))
735 .collect();
736
737 return Err(polars_err!(
738 assertion_error = "DataFrames",
739 format!(
740 "columns mismatch: {:?} in right, but not in left",
741 right_not_left
742 ),
743 format!("{:?}", left_set),
744 format!("{:?}", right_set)
745 ));
746 }
747 }
748
749 if check_column_order && ordered_left_cols != ordered_right_cols {
750 return Err(polars_err!(
751 assertion_error = "DataFrames",
752 "columns are not in the same order",
753 format!("{:?}", ordered_left_cols),
754 format!("{:?}", ordered_right_cols)
755 ));
756 }
757
758 if check_dtypes {
759 if check_column_order {
760 let left_dtypes_ordered = left.dtypes();
761 let right_dtypes_ordered = right.dtypes();
762 if left_dtypes_ordered != right_dtypes_ordered {
763 return Err(polars_err!(
764 assertion_error = "DataFrames",
765 "dtypes do not match",
766 format!("{:?}", left_dtypes_ordered),
767 format!("{:?}", right_dtypes_ordered)
768 ));
769 }
770 } else {
771 let left_dtypes: PlHashSet<DataType> = left.dtypes().into_iter().collect();
772 let right_dtypes: PlHashSet<DataType> = right.dtypes().into_iter().collect();
773 if left_dtypes != right_dtypes {
774 return Err(polars_err!(
775 assertion_error = "DataFrames",
776 "dtypes do not match",
777 format!("{:?}", left_dtypes),
778 format!("{:?}", right_dtypes)
779 ));
780 }
781 }
782 }
783
784 Ok(())
785}
786
787/// Verifies that two DataFrames are equal according to a set of configurable criteria.
788///
789/// This function serves as the main entry point for comparing DataFrames, first validating
790/// schema compatibility and then comparing the actual data values column by column.
791///
792/// # Arguments
793///
794/// * `left` - The first DataFrame to compare
795/// * `right` - The second DataFrame to compare
796/// * `options` - A `DataFrameEqualOptions` struct containing configuration parameters:
797/// * `check_row_order` - If true, rows must be in the same order
798/// * `check_column_order` - If true, columns must be in the same order
799/// * `check_dtypes` - If true, verifies data types match for corresponding columns
800/// * `check_exact` - If true, requires exact equality for float values
801/// * `rel_tol` - Relative tolerance for float comparison
802/// * `abs_tol` - Absolute tolerance for float comparison
803/// * `categorical_as_str` - If true, converts categorical values to strings before comparison
804///
805/// # Returns
806///
807/// * `Ok(())` if DataFrames match according to all specified criteria
808/// * `Err` with details about the first mismatch encountered:
809/// * Schema mismatches (column names, order, or data types)
810/// * Height (row count) mismatch
811/// * Value mismatches in specific columns
812///
813/// # Order of Checks
814///
815/// 1. Schema validation (column names, order, and data types via `assert_dataframe_schema_equal`)
816/// 2. DataFrame height (row count)
817/// 3. Row ordering (sorts both DataFrames if `check_row_order` is false)
818/// 4. Column-by-column value comparison (delegated to `assert_series_values_equal`)
819///
820/// # Behavior
821///
822/// When `check_row_order` is false, both DataFrames are sorted using all columns to ensure
823/// consistent ordering before value comparison. This allows for row-order-independent equality
824/// checking while maintaining deterministic results.
825///
826pub fn assert_dataframe_equal(
827 left: &DataFrame,
828 right: &DataFrame,
829 options: DataFrameEqualOptions,
830) -> PolarsResult<()> {
831 // Short-circuit if they're the same DataFrame object
832 if std::ptr::eq(left, right) {
833 return Ok(());
834 }
835
836 assert_dataframe_schema_equal(
837 left,
838 right,
839 options.check_dtypes,
840 options.check_column_order,
841 )?;
842
843 if left.height() != right.height() {
844 return Err(polars_err!(
845 assertion_error = "DataFrames",
846 "height (row count) mismatch",
847 left.height(),
848 right.height()
849 ));
850 }
851
852 let left_cols = left.get_column_names_owned();
853
854 let (left, right) = if !options.check_row_order {
855 (
856 left.sort(left_cols.clone(), SortMultipleOptions::default())?,
857 right.sort(left_cols.clone(), SortMultipleOptions::default())?,
858 )
859 } else {
860 (left.clone(), right.clone())
861 };
862
863 for col in left_cols.iter() {
864 let s_left = left.column(col)?;
865 let s_right = right.column(col)?;
866
867 let s_left_series = s_left.as_materialized_series();
868 let s_right_series = s_right.as_materialized_series();
869
870 match assert_series_values_equal(
871 s_left_series,
872 s_right_series,
873 true,
874 options.check_exact,
875 options.check_dtypes,
876 options.rel_tol,
877 options.abs_tol,
878 options.categorical_as_str,
879 ) {
880 Ok(_) => {},
881 Err(_) => {
882 return Err(polars_err!(
883 assertion_error = "DataFrames",
884 format!("value mismatch for column {:?}", col),
885 format!("{:?}", s_left_series),
886 format!("{:?}", s_right_series)
887 ));
888 },
889 }
890 }
891
892 Ok(())
893}