Skip to main content

lance_arrow_scalar/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! A scalar type backed by a single-element Arrow array with [`Ord`], [`Hash`],
5//! and [`Eq`] support.
6//!
7//! Comparisons and hashing are delegated to [`arrow_row::OwnedRow`], which
8//! provides a correct total ordering for all Arrow types (including proper NaN
9//! handling for floats and null ordering).
10
11mod convert;
12pub mod serde;
13
14use std::cmp::Ordering;
15use std::fmt;
16use std::hash::{Hash, Hasher};
17use std::sync::Arc;
18
19use arrow_array::{ArrayRef, make_array, new_null_array};
20use arrow_cast::display::ArrayFormatter;
21use arrow_data::transform::MutableArrayData;
22use arrow_row::{OwnedRow, RowConverter, SortField};
23use arrow_schema::{ArrowError, DataType};
24
25type Result<T> = std::result::Result<T, ArrowError>;
26
/// A scalar value backed by a length-1 Arrow array.
///
/// `ArrowScalar` provides [`Eq`], [`Ord`], and [`Hash`] by caching an
/// [`OwnedRow`] at construction time. Comparisons and hashing then work
/// directly on that cached row instead of dispatching on the Arrow data type
/// each time they are called.
///
/// # Cross-type comparison
///
/// Comparing scalars of different data types produces an arbitrary but
/// consistent ordering based on the underlying row bytes. This is intentional
/// — it allows scalars to be used as keys in sorted collections regardless of
/// type, but the ordering across types is not semantically meaningful.
///
/// Equality and hashing likewise consult only the cached row and never the
/// data type, so two scalars of different types are treated as equal whenever
/// their rows are equal.
///
/// # Examples
///
/// ```
/// use lance_arrow_scalar::ArrowScalar;
///
/// let a = ArrowScalar::from(1i32);
/// let b = ArrowScalar::from(2i32);
/// assert!(a < b);
///
/// let c = ArrowScalar::from("hello");
/// assert_eq!(c, ArrowScalar::from("hello"));
/// ```
pub struct ArrowScalar {
    /// The value itself, stored as an array holding exactly one element.
    array: ArrayRef,
    /// Row encoding of `array`, computed once at construction and used by the
    /// `Eq`, `Ord`, and `Hash` implementations.
    row: OwnedRow,
}
56
57impl ArrowScalar {
58    /// Create a scalar by extracting the element at `offset` from `array`.
59    pub fn try_new(array: &ArrayRef, offset: usize) -> Result<Self> {
60        if offset >= array.len() {
61            return Err(ArrowError::InvalidArgumentError(
62                "Scalar index out of bounds".to_string(),
63            ));
64        }
65
66        let data = array.to_data();
67        let mut mutable = MutableArrayData::new(vec![&data], true, 1);
68        mutable.extend(0, offset, offset + 1);
69        let single = make_array(mutable.freeze());
70        Self::try_from_array(single)
71    }
72
73    /// Create a scalar from a length-1 array.
74    pub fn try_from_array(array: ArrayRef) -> Result<Self> {
75        if array.len() != 1 {
76            return Err(ArrowError::InvalidArgumentError(format!(
77                "ArrowScalar requires a length-1 array, got length {}",
78                array.len()
79            )));
80        }
81
82        let row = Self::compute_row(&array)?;
83        Ok(Self { array, row })
84    }
85
86    /// Create a null scalar of the given data type.
87    pub fn new_null(data_type: &DataType) -> Result<Self> {
88        Self::try_from_array(new_null_array(data_type, 1))
89    }
90
91    fn compute_row(array: &ArrayRef) -> Result<OwnedRow> {
92        let sort_field = SortField::new(array.data_type().clone());
93        let converter = RowConverter::new(vec![sort_field])?;
94        let rows = converter.convert_columns(&[Arc::clone(array)])?;
95        Ok(rows.row(0).owned())
96    }
97
98    /// Returns a reference to the underlying length-1 array.
99    pub fn as_array(&self) -> &ArrayRef {
100        &self.array
101    }
102
103    /// Returns the data type of this scalar.
104    pub fn data_type(&self) -> &DataType {
105        self.array.data_type()
106    }
107
108    /// Returns `true` if this scalar is null.
109    pub fn is_null(&self) -> bool {
110        self.array.null_count() == 1
111    }
112}
113
114impl PartialEq for ArrowScalar {
115    fn eq(&self, other: &Self) -> bool {
116        self.row == other.row
117    }
118}
119
// Marker impl: equality is defined by `PartialEq`, which compares the cached rows.
impl Eq for ArrowScalar {}
121
122impl PartialOrd for ArrowScalar {
123    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
124        Some(self.cmp(other))
125    }
126}
127
128impl Ord for ArrowScalar {
129    fn cmp(&self, other: &Self) -> Ordering {
130        self.row.cmp(&other.row)
131    }
132}
133
134impl Hash for ArrowScalar {
135    fn hash<H: Hasher>(&self, state: &mut H) {
136        self.row.hash(state);
137    }
138}
139
140impl fmt::Display for ArrowScalar {
141    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        if self.is_null() {
143            return write!(f, "null");
144        }
145        let formatter =
146            ArrayFormatter::try_new(&self.array, &Default::default()).map_err(|_| fmt::Error)?;
147        write!(f, "{}", formatter.value(0))
148    }
149}
150
151impl fmt::Debug for ArrowScalar {
152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        write!(f, "ArrowScalar({}: {})", self.data_type(), self)
154    }
155}
156
157impl Clone for ArrowScalar {
158    fn clone(&self) -> Self {
159        Self {
160            array: Arc::clone(&self.array),
161            row: self.row.clone(),
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use std::collections::{BTreeSet, HashSet};
    use std::sync::Arc;

    use arrow_array::*;
    use rstest::rstest;

    use super::*;

    /// Wraps a one-element array in an `ArrowScalar`, panicking on failure.
    fn scalar_of<A: Array + 'static>(array: A) -> ArrowScalar {
        let array: ArrayRef = Arc::new(array);
        ArrowScalar::try_from_array(array).unwrap()
    }

    /// Extracts the element at `index` from `array`, panicking on failure.
    fn element_of<A: Array + 'static>(array: A, index: usize) -> ArrowScalar {
        let array: ArrayRef = Arc::new(array);
        ArrowScalar::try_new(&array, index).unwrap()
    }

    /// Builds a `Utf8View` scalar holding `text`.
    fn string_view(text: &str) -> ArrowScalar {
        scalar_of(StringViewArray::from(vec![text]))
    }

    /// Builds a `BinaryView` scalar holding `bytes`.
    fn binary_view(bytes: &[u8]) -> ArrowScalar {
        let values: Vec<&[u8]> = vec![bytes];
        scalar_of(BinaryViewArray::from(values))
    }

    /// Hashes a scalar with the standard library's default hasher.
    fn hash_of(scalar: &ArrowScalar) -> u64 {
        let mut hasher = std::hash::DefaultHasher::new();
        scalar.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_try_new_extracts_element() {
        let middle = element_of(Int32Array::from(vec![10, 20, 30]), 1);
        assert_eq!(middle.to_string(), "20");
    }

    #[test]
    fn test_try_new_out_of_bounds() {
        let array: ArrayRef = Arc::new(Int32Array::from(vec![1]));
        let outcome = ArrowScalar::try_new(&array, 5);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_try_from_array_wrong_length() {
        let two_elements: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let outcome = ArrowScalar::try_from_array(two_elements);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_equality() {
        let first = ArrowScalar::from(42i32);
        let same = ArrowScalar::from(42i32);
        let different = ArrowScalar::from(99i32);
        assert_eq!(first, same);
        assert_ne!(first, different);
    }

    #[test]
    fn test_ordering() {
        let one = ArrowScalar::from(1i32);
        let two = ArrowScalar::from(2i32);
        let three = ArrowScalar::from(3i32);
        assert!(one < two);
        assert!(two < three);
        assert_eq!(one.cmp(&one), Ordering::Equal);
    }

    #[test]
    fn test_hash_consistent_with_eq() {
        let left = ArrowScalar::from(42i32);
        let right = ArrowScalar::from(42i32);
        assert_eq!(hash_of(&left), hash_of(&right));
    }

    #[test]
    fn test_in_hashset() {
        let mut unique = HashSet::new();
        for value in [1i32, 2, 1] {
            unique.insert(ArrowScalar::from(value));
        }
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn test_in_btreeset() {
        let mut ordered = BTreeSet::new();
        for value in [3i32, 1, 2] {
            ordered.insert(ArrowScalar::from(value));
        }
        let rendered: Vec<_> = ordered.iter().map(|s| s.to_string()).collect();
        assert_eq!(rendered, vec!["1", "2", "3"]);
    }

    #[test]
    fn test_null_scalar() {
        let null_int = scalar_of(Int32Array::from(vec![None]));
        assert!(null_int.is_null());
        assert_eq!(null_int.to_string(), "null");
    }

    #[test]
    fn test_null_sorts_first() {
        let null_int = scalar_of(Int32Array::from(vec![None]));
        let zero = ArrowScalar::from(0i32);
        assert!(null_int < zero);
    }

    #[rstest]
    #[case::float_nan(
        ArrowScalar::from(f64::NAN),
        ArrowScalar::from(f64::INFINITY),
        Ordering::Greater
    )]
    #[case::float_normal(ArrowScalar::from(1.0f64), ArrowScalar::from(2.0f64), Ordering::Less)]
    fn test_float_ordering(
        #[case] a: ArrowScalar,
        #[case] b: ArrowScalar,
        #[case] expected: Ordering,
    ) {
        let actual = a.cmp(&b);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_display_string() {
        let greeting = ArrowScalar::from("hello world");
        assert_eq!(greeting.to_string(), "hello world");
    }

    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", ArrowScalar::from(42i32));
        for needle in ["ArrowScalar", "42"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_clone() {
        let original = ArrowScalar::from(42i32);
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_data_type() {
        let scalar = ArrowScalar::from(42i32);
        assert_eq!(scalar.data_type(), &DataType::Int32);
    }

    #[test]
    fn test_boolean_roundtrip() {
        let yes = ArrowScalar::from(true);
        let no = ArrowScalar::from(false);
        assert_eq!(yes.data_type(), &DataType::Boolean);
        assert!(!yes.is_null());
        assert_eq!(yes.to_string(), "true");
        assert_eq!(no.to_string(), "false");

        // Extract from multi-element array
        let picked = element_of(BooleanArray::from(vec![true, false, true]), 1);
        assert_eq!(picked.to_string(), "false");
        assert_eq!(picked.data_type(), &DataType::Boolean);
    }

    #[test]
    fn test_boolean_equality_and_ordering() {
        let yes = ArrowScalar::from(true);
        let also_yes = ArrowScalar::from(true);
        let no = ArrowScalar::from(false);
        assert_eq!(yes, also_yes);
        assert_ne!(yes, no);
        // false < true in arrow row encoding
        assert!(no < yes);
    }

    #[test]
    fn test_boolean_null() {
        let null_bool = scalar_of(BooleanArray::from(vec![None]));
        assert!(null_bool.is_null());
        assert_eq!(null_bool.data_type(), &DataType::Boolean);
        assert_eq!(null_bool.to_string(), "null");

        // null sorts before false
        let no = ArrowScalar::from(false);
        assert!(null_bool < no);
    }

    #[test]
    fn test_string_view_roundtrip() {
        let long_text = "hello world, this is a long string view";
        let scalar = string_view(long_text);
        assert_eq!(scalar.data_type(), &DataType::Utf8View);
        assert!(!scalar.is_null());
        assert_eq!(scalar.to_string(), long_text);

        // Extract from multi-element array
        let picked = element_of(StringViewArray::from(vec!["alpha", "beta", "gamma"]), 1);
        assert_eq!(picked.to_string(), "beta");
        assert_eq!(picked.data_type(), &DataType::Utf8View);
    }

    #[test]
    fn test_binary_view_roundtrip() {
        let scalar = binary_view(b"\xDE\xAD\xBE\xEF");
        assert_eq!(scalar.data_type(), &DataType::BinaryView);
        assert!(!scalar.is_null());

        // Extract from multi-element array
        let values: Vec<&[u8]> = vec![b"aaa", b"bbb", b"ccc"];
        let picked = element_of(BinaryViewArray::from(values), 2);
        assert_eq!(picked.data_type(), &DataType::BinaryView);
    }

    #[test]
    fn test_string_view_equality_and_ordering() {
        let apple = string_view("apple");
        let apple_again = string_view("apple");
        let banana = string_view("banana");
        assert_eq!(apple, apple_again);
        assert_ne!(apple, banana);
        assert!(apple < banana);
    }

    #[test]
    fn test_binary_view_equality_and_ordering() {
        let low = binary_view(b"\x01\x02");
        let low_again = binary_view(b"\x01\x02");
        let high = binary_view(b"\x01\x03");
        assert_eq!(low, low_again);
        assert_ne!(low, high);
        assert!(low < high);
    }

    #[test]
    fn test_string_view_in_collections() {
        let mut unique = HashSet::new();
        for text in ["foo", "bar", "foo"] {
            unique.insert(string_view(text));
        }
        assert_eq!(unique.len(), 2);

        let mut ordered = BTreeSet::new();
        for text in ["cherry", "apple", "banana"] {
            ordered.insert(string_view(text));
        }
        let rendered: Vec<_> = ordered.iter().map(|s| s.to_string()).collect();
        assert_eq!(rendered, vec!["apple", "banana", "cherry"]);
    }

    #[test]
    fn test_string_view_null() {
        let null_text = scalar_of(StringViewArray::from(vec![Option::<&str>::None]));
        assert!(null_text.is_null());
        assert_eq!(null_text.data_type(), &DataType::Utf8View);
        assert_eq!(null_text.to_string(), "null");
    }

    #[test]
    fn test_binary_view_null() {
        let null_bytes = scalar_of(BinaryViewArray::from(vec![Option::<&[u8]>::None]));
        assert!(null_bytes.is_null());
        assert_eq!(null_bytes.data_type(), &DataType::BinaryView);
    }

    #[test]
    fn test_cross_type_comparison_is_consistent() {
        let number = ArrowScalar::from(42i32);
        let text = ArrowScalar::from("hello");
        // The ordering is arbitrary but must be consistent
        let forward = number.cmp(&text);
        let forward_again = number.cmp(&text);
        assert_eq!(forward, forward_again);
        // And the reverse should be opposite
        let backward = text.cmp(&number);
        assert_eq!(backward, forward.reverse());
    }
}
466
#[cfg(test)]
mod prop_tests {
    use std::sync::Arc;

    use arrow_array::*;
    use arrow_ord::sort::sort;
    use arrow_schema::SortOptions;
    use proptest::prelude::*;

    use super::ArrowScalar;

    /// Strategy for a vector of 0 to 100 optional elements drawn from the
    /// given element strategy.
    macro_rules! nullable_vec {
        ($element:expr) => {
            proptest::collection::vec(proptest::option::of($element), 0..=100usize)
        };
    }

    /// Strategy for a primitive or boolean array built from optional native
    /// values.
    macro_rules! native_array {
        ($native:ty => $array:ty) => {
            nullable_vec!(any::<$native>())
                .prop_map(|values| Arc::new(<$array>::from(values)) as ArrayRef)
        };
    }

    /// Strategy for a string-like array built from optional `String`s.
    macro_rules! text_array {
        ($array:ty) => {
            nullable_vec!(any::<String>()).prop_map(|owned| {
                let borrowed: Vec<Option<&str>> =
                    owned.iter().map(|item| item.as_deref()).collect();
                Arc::new(<$array>::from(borrowed)) as ArrayRef
            })
        };
    }

    /// Strategy for a binary-like array built from optional byte vectors of
    /// fewer than 50 bytes each.
    macro_rules! bytes_array {
        ($array:ty) => {
            nullable_vec!(proptest::collection::vec(any::<u8>(), 0..50)).prop_map(|owned| {
                let borrowed: Vec<Option<&[u8]>> =
                    owned.iter().map(|item| item.as_deref()).collect();
                Arc::new(<$array>::from(borrowed)) as ArrayRef
            })
        };
    }

    /// Generate an arbitrary Arrow array of a randomly chosen type, including
    /// nulls. Covers primitives, booleans, string/binary types and their view
    /// variants.
    fn arbitrary_array() -> BoxedStrategy<ArrayRef> {
        prop_oneof![
            // --- integer types ---
            native_array!(i8 => Int8Array),
            native_array!(i16 => Int16Array),
            native_array!(i32 => Int32Array),
            native_array!(i64 => Int64Array),
            native_array!(u8 => UInt8Array),
            native_array!(u16 => UInt16Array),
            native_array!(u32 => UInt32Array),
            native_array!(u64 => UInt64Array),
            // --- float types ---
            native_array!(f32 => Float32Array),
            native_array!(f64 => Float64Array),
            // --- boolean ---
            native_array!(bool => BooleanArray),
            // --- string types ---
            text_array!(StringArray),
            text_array!(LargeStringArray),
            text_array!(StringViewArray),
            // --- binary types ---
            bytes_array!(BinaryArray),
            bytes_array!(LargeBinaryArray),
            bytes_array!(BinaryViewArray),
        ]
        .boxed()
    }

    proptest::proptest! {
        #[test]
        fn sorted_array_produces_sorted_scalars(array in arbitrary_array()) {
            let ascending_nulls_first = SortOptions { descending: false, nulls_first: true };
            let sorted = sort(&array, Some(ascending_nulls_first)).unwrap();

            let mut scalars: Vec<ArrowScalar> = Vec::with_capacity(sorted.len());
            for index in 0..sorted.len() {
                scalars.push(ArrowScalar::try_new(&sorted, index).unwrap());
            }

            // Every adjacent pair must already be in non-decreasing order.
            for (left, pair) in scalars.windows(2).enumerate() {
                let right = left + 1;
                prop_assert!(
                    pair[0] <= pair[1],
                    "scalar[{}] ({:?}) should be <= scalar[{}] ({:?})",
                    left, pair[0], right, pair[1],
                );
            }
        }
    }
}