polars_ffi/
polars_ffi.rs

//! ---------------------------------------------------------
//! Minarrow <-> Polars (via polars_arrow/arrow2) FFI roundtrip
//!
//! Run with:
//!    cargo run --example polars_ffi --features cast_polars
//!
//! This example performs the FFI roundtrip manually. If you only need a
//! conversion, enable the *cast_polars* feature and call `to_polars()`
//! directly on an `Array`, `FieldArray` or `Table`.
//! ---------------------------------------------------------
11
12#[cfg(feature = "cast_polars")]
13use crate::polars_roundtrip::run_example;
14
15#[cfg(feature = "cast_polars")]
16mod polars_roundtrip {
17    use std::sync::Arc;
18
19    use minarrow::ffi::arrow_c_ffi::{export_to_c, import_from_c};
20    use minarrow::ffi::arrow_dtype::CategoricalIndexType;
21    use minarrow::ffi::schema::Schema;
22    use minarrow::{Array, ArrowType, Field, FieldArray, NumericArray, Table, TextArray};
23    #[cfg(feature = "datetime")]
24    use minarrow::{TemporalArray, TimeUnit};
25    use polars::prelude::*;
26    use polars_arrow as pa;
27
28    // -------------------------------------------------------------------------
29    // Build test table with full type coverage
30    // -------------------------------------------------------------------------
31    fn build_minarrow_table() -> Table {
32        // Arrays
33        #[cfg(feature = "extended_numeric_types")]
34        let arr_int8 = Arc::new(minarrow::IntegerArray::<i8>::from_slice(&[1, 2, -1])) as Arc<_>;
35        #[cfg(feature = "extended_numeric_types")]
36        let arr_int16 =
37            Arc::new(minarrow::IntegerArray::<i16>::from_slice(&[10, 20, -10])) as Arc<_>;
38        let arr_int32 =
39            Arc::new(minarrow::IntegerArray::<i32>::from_slice(&[100, 200, -100])) as Arc<_>;
40        let arr_int64 = Arc::new(minarrow::IntegerArray::<i64>::from_slice(&[
41            1000, 2000, -1000,
42        ])) as Arc<_>;
43
44        #[cfg(feature = "extended_numeric_types")]
45        let arr_uint8 = Arc::new(minarrow::IntegerArray::<u8>::from_slice(&[1, 2, 255]))
46            as Arc<minarrow::IntegerArray<u8>>;
47        #[cfg(feature = "extended_numeric_types")]
48        let arr_uint16 = Arc::new(minarrow::IntegerArray::<u16>::from_slice(&[1, 2, 65535]))
49            as Arc<minarrow::IntegerArray<u16>>;
50        let arr_uint32 = Arc::new(minarrow::IntegerArray::<u32>::from_slice(&[
51            1, 2, 4294967295,
52        ])) as Arc<minarrow::IntegerArray<u32>>;
53        let arr_uint64 = Arc::new(minarrow::IntegerArray::<u64>::from_slice(&[
54            1,
55            2,
56            18446744073709551615,
57        ])) as Arc<minarrow::IntegerArray<u64>>;
58
59        let arr_float32 = Arc::new(minarrow::FloatArray::<f32>::from_slice(&[1.5, -0.5, 0.0]))
60            as Arc<minarrow::FloatArray<f32>>;
61        let arr_float64 = Arc::new(minarrow::FloatArray::<f64>::from_slice(&[1.0, -2.0, 0.0]))
62            as Arc<minarrow::FloatArray<f64>>;
63
64        let arr_bool = Arc::new(minarrow::BooleanArray::<()>::from_slice(&[
65            true, false, true,
66        ])) as Arc<minarrow::BooleanArray<()>>;
67
68        let arr_string32 = Arc::new(minarrow::StringArray::<u32>::from_slice(&[
69            "abc", "def", "",
70        ])) as Arc<minarrow::StringArray<u32>>;
71        let arr_categorical32 = Arc::new(minarrow::CategoricalArray::<u32>::from_slices(
72            &[0, 1, 2],
73            &["A".to_string(), "B".to_string(), "C".to_string()],
74        )) as Arc<minarrow::CategoricalArray<u32>>;
75
76        #[cfg(feature = "datetime")]
77        let arr_datetime32 = Arc::new(minarrow::DatetimeArray::<i32> {
78            data: minarrow::Buffer::<i32>::from_slice(&[
79                1_600_000_000 / 86_400,
80                1_600_000_001 / 86_400,
81                1_600_000_002 / 86_400,
82            ]),
83            null_mask: None,
84            time_unit: TimeUnit::Days,
85        });
86        #[cfg(feature = "datetime")]
87        let arr_datetime64 = Arc::new(minarrow::DatetimeArray::<i64> {
88            data: minarrow::Buffer::<i64>::from_slice(&[
89                1_600_000_000_000,
90                1_600_000_000_001,
91                1_600_000_000_002,
92            ]),
93            null_mask: None,
94            time_unit: TimeUnit::Milliseconds,
95        }) as Arc<_>;
96
97        // Wrap in Array enums
98        #[cfg(feature = "extended_numeric_types")]
99        let minarr_int8 = Array::NumericArray(NumericArray::Int8(arr_int8));
100        #[cfg(feature = "extended_numeric_types")]
101        let minarr_int16 = Array::NumericArray(NumericArray::Int16(arr_int16));
102        let minarr_int32 = Array::NumericArray(NumericArray::Int32(arr_int32));
103        let minarr_int64 = Array::NumericArray(NumericArray::Int64(arr_int64));
104        #[cfg(feature = "extended_numeric_types")]
105        let minarr_uint8 = Array::NumericArray(NumericArray::UInt8(arr_uint8));
106        #[cfg(feature = "extended_numeric_types")]
107        let minarr_uint16 = Array::NumericArray(NumericArray::UInt16(arr_uint16));
108        let minarr_uint32 = Array::NumericArray(NumericArray::UInt32(arr_uint32));
109        let minarr_uint64 = Array::NumericArray(NumericArray::UInt64(arr_uint64));
110        let minarr_float32 = Array::NumericArray(NumericArray::Float32(arr_float32));
111        let minarr_float64 = Array::NumericArray(NumericArray::Float64(arr_float64));
112        let minarr_bool = Array::BooleanArray(arr_bool);
113        let minarr_string32 = Array::TextArray(TextArray::String32(arr_string32));
114        let minarr_categorical32 = Array::TextArray(TextArray::Categorical32(arr_categorical32));
115        #[cfg(feature = "datetime")]
116        let minarr_datetime32 = Array::TemporalArray(TemporalArray::Datetime32(arr_datetime32));
117        #[cfg(feature = "datetime")]
118        let minarr_datetime64 = Array::TemporalArray(TemporalArray::Datetime64(arr_datetime64));
119
120        // Fields
121        #[cfg(feature = "extended_numeric_types")]
122        let field_int8 = Field::new("int8", ArrowType::Int8, false, None);
123        #[cfg(feature = "extended_numeric_types")]
124        let field_int16 = Field::new("int16", ArrowType::Int16, false, None);
125        let field_int32 = Field::new("int32", ArrowType::Int32, false, None);
126        let field_int64 = Field::new("int64", ArrowType::Int64, false, None);
127        #[cfg(feature = "extended_numeric_types")]
128        let field_uint8 = Field::new("uint8", ArrowType::UInt8, false, None);
129        #[cfg(feature = "extended_numeric_types")]
130        let field_uint16 = Field::new("uint16", ArrowType::UInt16, false, None);
131        let field_uint32 = Field::new("uint32", ArrowType::UInt32, false, None);
132        let field_uint64 = Field::new("uint64", ArrowType::UInt64, false, None);
133        let field_float32 = Field::new("float32", ArrowType::Float32, false, None);
134        let field_float64 = Field::new("float64", ArrowType::Float64, false, None);
135        let field_bool = Field::new("bool", ArrowType::Boolean, false, None);
136        let field_string32 = Field::new("string32", ArrowType::String, false, None);
137        let field_categorical32 = Field::new(
138            "categorical32",
139            ArrowType::Dictionary(CategoricalIndexType::UInt32),
140            false,
141            None,
142        );
143        #[cfg(feature = "datetime")]
144        let field_datetime32 = Field::new("dt32", ArrowType::Date32, false, None);
145        #[cfg(feature = "datetime")]
146        let field_datetime64 = Field::new("dt64", ArrowType::Date64, false, None);
147
148        // FieldArrays
149        #[cfg(feature = "extended_numeric_types")]
150        let fa_int8 = FieldArray::new(field_int8, minarr_int8);
151        #[cfg(feature = "extended_numeric_types")]
152        let fa_int16 = FieldArray::new(field_int16, minarr_int16);
153        let fa_int32 = FieldArray::new(field_int32, minarr_int32);
154        let fa_int64 = FieldArray::new(field_int64, minarr_int64);
155        #[cfg(feature = "extended_numeric_types")]
156        let fa_uint8 = FieldArray::new(field_uint8, minarr_uint8);
157        #[cfg(feature = "extended_numeric_types")]
158        let fa_uint16 = FieldArray::new(field_uint16, minarr_uint16);
159        let fa_uint32 = FieldArray::new(field_uint32, minarr_uint32);
160        let fa_uint64 = FieldArray::new(field_uint64, minarr_uint64);
161        let fa_float32 = FieldArray::new(field_float32, minarr_float32);
162        let fa_float64 = FieldArray::new(field_float64, minarr_float64);
163        let fa_bool = FieldArray::new(field_bool, minarr_bool);
164        let fa_string32 = FieldArray::new(field_string32, minarr_string32);
165        let fa_categorical32 = FieldArray::new(field_categorical32, minarr_categorical32);
166        #[cfg(feature = "datetime")]
167        let fa_datetime32 = FieldArray::new(field_datetime32, minarr_datetime32);
168        #[cfg(feature = "datetime")]
169        let fa_datetime64 = FieldArray::new(field_datetime64, minarr_datetime64);
170
171        // Build table
172        let mut cols = Vec::new();
173        #[cfg(feature = "extended_numeric_types")]
174        {
175            cols.push(fa_int8);
176            cols.push(fa_int16);
177        }
178        cols.push(fa_int32);
179        cols.push(fa_int64);
180        #[cfg(feature = "extended_numeric_types")]
181        {
182            cols.push(fa_uint8);
183            cols.push(fa_uint16);
184        }
185        cols.push(fa_uint32);
186        cols.push(fa_uint64);
187        cols.push(fa_float32);
188        cols.push(fa_float64);
189        cols.push(fa_bool);
190        cols.push(fa_string32);
191        cols.push(fa_categorical32);
192        #[cfg(feature = "datetime")]
193        {
194            cols.push(fa_datetime32);
195            cols.push(fa_datetime64);
196        }
197        Table::new("polars_ffi_test".to_string(), Some(cols))
198    }
199
200    // Minarrow -> C -> arrow2
201    fn minarrow_col_to_arrow2(
202        array: &Array,
203        field: &Field,
204    ) -> (Box<dyn pa::array::Array>, pa::datatypes::Field) {
205        let schema = Schema::from(vec![field.clone()]);
206        let (c_arr, c_schema) = export_to_c(Arc::new(array.clone()), schema);
207        let arr_ptr = c_arr as *mut pa::ffi::ArrowArray;
208        let sch_ptr = c_schema as *mut pa::ffi::ArrowSchema;
209        unsafe {
210            let arr_val = std::ptr::read(arr_ptr);
211            let sch_val = std::ptr::read(sch_ptr);
212            let fld = pa::ffi::import_field_from_c(&sch_val)
213                .expect("polars_arrow import_field_from_c failed");
214            let dtype = fld.dtype().clone();
215            let arr = pa::ffi::import_array_from_c(arr_val, dtype)
216                .expect("polars_arrow import_array_from_c failed");
217            (arr, fld)
218        }
219    }
220
221    // arrow2 -> Polars
222    fn series_from_arrow(name: &str, a: Box<dyn pa::array::Array>) -> Series {
223        Series::from_arrow(name.into(), a).expect("Polars Series::from_arrow failed")
224    }
225
226    // Polars -> C
227    fn export_series_to_c(name: &str, s: &Series) -> (pa::ffi::ArrowArray, pa::ffi::ArrowSchema) {
228        let arr2 = s.to_arrow(0, CompatLevel::oldest());
229        let out_arr: pa::ffi::ArrowArray = pa::ffi::export_array_to_c(arr2.clone());
230        let fld = pa::datatypes::Field::new(name.into(), arr2.dtype().clone(), false);
231        let out_sch: pa::ffi::ArrowSchema = pa::ffi::export_field_to_c(&fld);
232        (out_arr, out_sch)
233    }
234
235    // C -> Minarrow
236    fn import_back_minarrow(
237        out_arr: pa::ffi::ArrowArray,
238        out_sch: pa::ffi::ArrowSchema,
239    ) -> Arc<Array> {
240        let back_arr_ptr =
241            Box::into_raw(Box::new(out_arr)) as *const minarrow::ffi::arrow_c_ffi::ArrowArray;
242        let back_sch_ptr =
243            Box::into_raw(Box::new(out_sch)) as *const minarrow::ffi::arrow_c_ffi::ArrowSchema;
244        unsafe { import_from_c(back_arr_ptr, back_sch_ptr) }
245    }
246
247    // Equality with String32 <-> String64 relaxed match
248    fn arrays_equal_allow_utf8_width(left: &Array, right: &Array) -> bool {
249        if left == right {
250            return true;
251        }
252        match (left, right) {
253            (
254                Array::TextArray(TextArray::String32(a)),
255                Array::TextArray(TextArray::String64(b)),
256            )
257            | (
258                Array::TextArray(TextArray::String64(b)),
259                Array::TextArray(TextArray::String32(a)),
260            ) => {
261                let a = a.as_ref();
262                let b = b.as_ref();
263                a.len() == b.len()
264                    && a.null_mask == b.null_mask
265                    && (0..a.len()).all(|i| a.get(i) == b.get(i))
266            }
267            _ => false,
268        }
269    }
270
271    pub(crate) fn run_example() {
272        let minarrow_table = build_minarrow_table();
273        for col in &minarrow_table.cols {
274            let field_name = &col.field.name;
275
276            let (arrow2_array, _) = minarrow_col_to_arrow2(&col.array, &col.field);
277            let s = series_from_arrow(field_name, arrow2_array);
278
279            let c = Column::new("TestCol".into(), s.clone());
280            let df = DataFrame::new(vec![c]).expect("build DataFrame");
281            println!("{df}");
282
283            let (out_arr, out_sch) = export_series_to_c(field_name, &s);
284            let minarr_back = import_back_minarrow(out_arr, out_sch);
285
286            assert!(
287                arrays_equal_allow_utf8_width(&col.array, minarr_back.as_ref()),
288                "Roundtrip mismatch for field {field_name}"
289            );
290        }
291    }
292}
293
/// Runs the Polars FFI roundtrip, or explains which feature enables it.
fn main() {
    // The code path is chosen at compile time: the roundtrip module only exists
    // with `cast_polars`, so a runtime `cfg!` check would be redundant.
    #[cfg(feature = "cast_polars")]
    run_example();

    #[cfg(not(feature = "cast_polars"))]
    println!("The polars-FFI example requires enabling the `cast_polars` feature.");
}
301}