polars_ffi/
polars_ffi.rs

1//! ---------------------------------------------------------
2//! Minarrow ↔️ Polars (via polars_arrow/arrow2) FFI roundtrip
3//! 
4//! Run with:
5//!    cargo run --example polars_ffi --features cast_polars
6//! 
7//! This example performs the FFI roundtrip by hand. Alternatively, convert
8//! directly to Polars with `to_polars()` on the `Array`, `FieldArray` or
9//! `Table` types when the `cast_polars` feature is enabled.
10//! ---------------------------------------------------------
11
12#[cfg(feature = "cast_polars")]
13use crate::polars_roundtrip::run_example;
14
/// Minarrow -> Arrow C Data Interface -> polars_arrow -> Polars `Series`, and back again.
#[cfg(feature = "cast_polars")]
mod polars_roundtrip {
    use std::sync::Arc;

    use minarrow::ffi::arrow_c_ffi::{export_to_c, import_from_c};
    use minarrow::ffi::arrow_dtype::CategoricalIndexType;
    use minarrow::ffi::schema::Schema;
    use minarrow::{Array, ArrowType, Field, FieldArray, NumericArray, Table, TextArray};
    #[cfg(feature = "datetime")]
    use minarrow::{TemporalArray, TimeUnit};
    use polars::prelude::*;
    use polars_arrow as pa;

    // -------------------------------------------------------------------------
    // Build test table with full type coverage
    // -------------------------------------------------------------------------

    /// Builds the test table: one non-nullable, three-row column per covered Minarrow type.
    ///
    /// Columns gated behind the `extended_numeric_types` and `datetime` features are only
    /// included when those features are enabled. Column order is fixed:
    /// int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64,
    /// bool, string32, categorical32, dt32, dt64.
    fn build_minarrow_table() -> Table {
        /// Pairs a non-nullable, metadata-free `Field` with its array.
        fn field_array(name: &'static str, dtype: ArrowType, array: Array) -> FieldArray {
            FieldArray::new(Field::new(name, dtype, false, None), array)
        }

        let mut cols = Vec::new();

        #[cfg(feature = "extended_numeric_types")]
        {
            cols.push(field_array(
                "int8",
                ArrowType::Int8,
                Array::NumericArray(NumericArray::Int8(Arc::new(
                    minarrow::IntegerArray::<i8>::from_slice(&[1, 2, -1]),
                ))),
            ));
            cols.push(field_array(
                "int16",
                ArrowType::Int16,
                Array::NumericArray(NumericArray::Int16(Arc::new(
                    minarrow::IntegerArray::<i16>::from_slice(&[10, 20, -10]),
                ))),
            ));
        }
        cols.push(field_array(
            "int32",
            ArrowType::Int32,
            Array::NumericArray(NumericArray::Int32(Arc::new(
                minarrow::IntegerArray::<i32>::from_slice(&[100, 200, -100]),
            ))),
        ));
        cols.push(field_array(
            "int64",
            ArrowType::Int64,
            Array::NumericArray(NumericArray::Int64(Arc::new(
                minarrow::IntegerArray::<i64>::from_slice(&[1000, 2000, -1000]),
            ))),
        ));

        // Unsigned columns include each type's maximum to check no sign/width confusion.
        #[cfg(feature = "extended_numeric_types")]
        {
            cols.push(field_array(
                "uint8",
                ArrowType::UInt8,
                Array::NumericArray(NumericArray::UInt8(Arc::new(
                    minarrow::IntegerArray::<u8>::from_slice(&[1, 2, u8::MAX]),
                ))),
            ));
            cols.push(field_array(
                "uint16",
                ArrowType::UInt16,
                Array::NumericArray(NumericArray::UInt16(Arc::new(
                    minarrow::IntegerArray::<u16>::from_slice(&[1, 2, u16::MAX]),
                ))),
            ));
        }
        cols.push(field_array(
            "uint32",
            ArrowType::UInt32,
            Array::NumericArray(NumericArray::UInt32(Arc::new(
                minarrow::IntegerArray::<u32>::from_slice(&[1, 2, u32::MAX]),
            ))),
        ));
        cols.push(field_array(
            "uint64",
            ArrowType::UInt64,
            Array::NumericArray(NumericArray::UInt64(Arc::new(
                minarrow::IntegerArray::<u64>::from_slice(&[1, 2, u64::MAX]),
            ))),
        ));

        cols.push(field_array(
            "float32",
            ArrowType::Float32,
            Array::NumericArray(NumericArray::Float32(Arc::new(
                minarrow::FloatArray::<f32>::from_slice(&[1.5, -0.5, 0.0]),
            ))),
        ));
        cols.push(field_array(
            "float64",
            ArrowType::Float64,
            Array::NumericArray(NumericArray::Float64(Arc::new(
                minarrow::FloatArray::<f64>::from_slice(&[1.0, -2.0, 0.0]),
            ))),
        ));

        cols.push(field_array(
            "bool",
            ArrowType::Boolean,
            Array::BooleanArray(Arc::new(minarrow::BooleanArray::<()>::from_slice(&[
                true, false, true,
            ]))),
        ));

        cols.push(field_array(
            "string32",
            ArrowType::String,
            Array::TextArray(TextArray::String32(Arc::new(
                minarrow::StringArray::<u32>::from_slice(&["abc", "def", ""]),
            ))),
        ));
        cols.push(field_array(
            "categorical32",
            ArrowType::Dictionary(CategoricalIndexType::UInt32),
            Array::TextArray(TextArray::Categorical32(Arc::new(
                minarrow::CategoricalArray::<u32>::from_slices(
                    &[0, 1, 2],
                    &["A".to_string(), "B".to_string(), "C".to_string()],
                ),
            ))),
        ));

        #[cfg(feature = "datetime")]
        {
            // Whole days since the Unix epoch. Consecutive days (rather than
            // `1_600_000_00{0,1,2} / 86_400`, which all truncate to the same day) so that
            // a reordering of elements during the roundtrip would actually be detected.
            let day0 = 1_600_000_000 / 86_400;
            cols.push(field_array(
                "dt32",
                ArrowType::Date32,
                Array::TemporalArray(TemporalArray::Datetime32(Arc::new(
                    minarrow::DatetimeArray::<i32> {
                        data: minarrow::Buffer::<i32>::from_slice(&[day0, day0 + 1, day0 + 2]),
                        null_mask: None,
                        time_unit: TimeUnit::Days,
                    },
                ))),
            ));
            cols.push(field_array(
                "dt64",
                ArrowType::Date64,
                Array::TemporalArray(TemporalArray::Datetime64(Arc::new(
                    minarrow::DatetimeArray::<i64> {
                        data: minarrow::Buffer::<i64>::from_slice(&[
                            1_600_000_000_000,
                            1_600_000_000_001,
                            1_600_000_000_002,
                        ]),
                        null_mask: None,
                        time_unit: TimeUnit::Milliseconds,
                    },
                ))),
            ));
        }

        Table::new("polars_ffi_test".to_string(), Some(cols))
    }

    /// Minarrow -> C -> polars_arrow: exports one column through the C Data Interface and
    /// imports it as a polars_arrow array together with its imported field.
    ///
    /// # Panics
    /// If polars_arrow rejects the exported schema or array.
    fn minarrow_col_to_arrow2(
        array: &Array,
        field: &Field,
    ) -> (Box<dyn pa::array::Array>, pa::datatypes::Field) {
        let schema = Schema::from(vec![field.clone()]);
        let (c_arr, c_schema) = export_to_c(Arc::new(array.clone()), schema);
        // Reinterpret minarrow's C structs as polars_arrow's equivalents.
        let arr_ptr = c_arr as *mut pa::ffi::ArrowArray;
        let sch_ptr = c_schema as *mut pa::ffi::ArrowSchema;

        // SAFETY: assumes both crates' `ArrowArray`/`ArrowSchema` are layout-identical
        // mirrors of the Arrow C Data Interface (the casts above rely on this), and that the
        // freshly exported structs are not used by anyone else afterwards. Under that
        // assumption, moving them out with `ptr::read` hands ownership of their `release`
        // callbacks to polars_arrow: the schema is released when `sch_val` drops, and the
        // array is consumed by `import_array_from_c`.
        // NOTE(review): the heap slots behind `c_arr`/`c_schema` are never freed here. If
        // `export_to_c` allocates them with `Box`, `*Box::from_raw(..)` would avoid the leak.
        unsafe {
            let arr_val = std::ptr::read(arr_ptr);
            let sch_val = std::ptr::read(sch_ptr);
            let fld = pa::ffi::import_field_from_c(&sch_val)
                .expect("polars_arrow import_field_from_c failed");
            let dtype = fld.dtype().clone();
            let arr = pa::ffi::import_array_from_c(arr_val, dtype)
                .expect("polars_arrow import_array_from_c failed");
            (arr, fld)
        }
    }

    /// polars_arrow -> Polars: wraps an arrow array in a `Series` named `name`.
    ///
    /// # Panics
    /// If Polars cannot build a `Series` from the array's data type.
    fn series_from_arrow(name: &str, a: Box<dyn pa::array::Array>) -> Series {
        Series::from_arrow(name.into(), a).expect("Polars Series::from_arrow failed")
    }

    /// Polars -> C: exports `s` as a single C array plus a C schema for a field named `name`.
    ///
    /// The series is rechunked first because only chunk 0 is exported; without that a
    /// multi-chunk series would silently lose rows. `CompatLevel::oldest()` requests the
    /// oldest Arrow representation from Polars. Presumably this is why strings may come back
    /// with a different offset width; see `arrays_equal_allow_utf8_width`.
    fn export_series_to_c(
        name: &str,
        s: &Series,
    ) -> (pa::ffi::ArrowArray, pa::ffi::ArrowSchema) {
        let arr2 = s.rechunk().to_arrow(0, CompatLevel::oldest());
        // Only declare the field non-nullable when the data really contains no nulls.
        let nullable = arr2.null_count() > 0;
        let fld = pa::datatypes::Field::new(name.into(), arr2.dtype().clone(), nullable);
        let out_sch: pa::ffi::ArrowSchema = pa::ffi::export_field_to_c(&fld);
        // The array is no longer needed after export, so move it instead of cloning.
        let out_arr: pa::ffi::ArrowArray = pa::ffi::export_array_to_c(arr2);
        (out_arr, out_sch)
    }

    /// C -> Minarrow: imports an exported array/schema pair back into a Minarrow `Array`.
    ///
    /// Both structs are moved to the heap and handed over as raw pointers, so their Rust
    /// `Drop` impls do not run here.
    /// NOTE(review): whether `import_from_c` frees these two allocations is not visible from
    /// this file. If it only reads them, the boxes leak.
    fn import_back_minarrow(
        out_arr: pa::ffi::ArrowArray,
        out_sch: pa::ffi::ArrowSchema,
    ) -> Arc<Array> {
        let back_arr_ptr =
            Box::into_raw(Box::new(out_arr)) as *const minarrow::ffi::arrow_c_ffi::ArrowArray;
        let back_sch_ptr =
            Box::into_raw(Box::new(out_sch)) as *const minarrow::ffi::arrow_c_ffi::ArrowSchema;
        // SAFETY: both pointers come straight from `Box::into_raw`, so they are non-null,
        // aligned and point to initialised structs. The casts assume polars_arrow and
        // minarrow share the Arrow C Data Interface layout.
        unsafe { import_from_c(back_arr_ptr, back_sch_ptr) }
    }

    /// Compares two arrays, additionally treating a `String32` and a `String64` array as
    /// equal when they have the same length, the same null mask and the same values.
    ///
    /// A plain `==` would report a false mismatch when only the string offset width changed
    /// during the roundtrip.
    fn arrays_equal_allow_utf8_width(left: &Array, right: &Array) -> bool {
        if left == right {
            return true;
        }
        match (left, right) {
            (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String64(b)))
            | (Array::TextArray(TextArray::String64(b)), Array::TextArray(TextArray::String32(a))) => {
                let a = a.as_ref();
                let b = b.as_ref();
                a.len() == b.len()
                    && a.null_mask == b.null_mask
                    && (0..a.len()).all(|i| a.get(i) == b.get(i))
            }
            _ => false,
        }
    }

    /// Runs the full roundtrip for every column of the test table, printing each column as a
    /// Polars `DataFrame`.
    ///
    /// # Panics
    /// If any column does not survive Minarrow -> Polars -> Minarrow unchanged, or if any
    /// FFI step fails.
    pub(crate) fn run_example() {
        let minarrow_table = build_minarrow_table();
        for col in &minarrow_table.cols {
            let field_name = &col.field.name;

            // Minarrow -> C -> polars_arrow -> Polars.
            let (arrow2_array, _) = minarrow_col_to_arrow2(&col.array, &col.field);
            let s = series_from_arrow(field_name, arrow2_array);

            // Show the column as Polars sees it, labelled with its real field name.
            let c = Column::new(s.name().clone(), s.clone());
            let df = DataFrame::new(vec![c]).expect("build DataFrame");
            println!("{df}");

            // Polars -> C -> Minarrow, then check nothing was lost on the way.
            let (out_arr, out_sch) = export_series_to_c(field_name, &s);
            let minarr_back = import_back_minarrow(out_arr, out_sch);

            assert!(
                arrays_equal_allow_utf8_width(&col.array, minarr_back.as_ref()),
                "Roundtrip mismatch for field {field_name}"
            );
        }
    }
}
284
/// Entry point: runs the Polars FFI roundtrip when built with `cast_polars`, otherwise
/// explains which feature is missing.
fn main() {
    // Selected at compile time; exactly one of these two statements survives cfg expansion.
    #[cfg(feature = "cast_polars")]
    run_example();

    #[cfg(not(feature = "cast_polars"))]
    println!("The polars-FFI example requires enabling the `cast_polars` feature.");
}