1use crate::alloc::Deallocation;
19use crate::buffer::Buffer;
20use crate::native::ArrowNativeType;
21use crate::{BufferBuilder, MutableBuffer, OffsetBuffer};
22use std::fmt::Formatter;
23use std::marker::PhantomData;
24use std::ops::Deref;
25
26#[derive(Clone, Default)]
45pub struct ScalarBuffer<T: ArrowNativeType> {
46    buffer: Buffer,
48    phantom: PhantomData<T>,
49}
50
51impl<T: ArrowNativeType> std::fmt::Debug for ScalarBuffer<T> {
52    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53        f.debug_tuple("ScalarBuffer").field(&self.as_ref()).finish()
54    }
55}
56
57impl<T: ArrowNativeType> ScalarBuffer<T> {
58    pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self {
69        let size = std::mem::size_of::<T>();
70        let byte_offset = offset.checked_mul(size).expect("offset overflow");
71        let byte_len = len.checked_mul(size).expect("length overflow");
72        buffer.slice_with_length(byte_offset, byte_len).into()
73    }
74
75    pub unsafe fn new_unchecked(buffer: Buffer) -> Self {
82        Self {
83            buffer,
84            phantom: Default::default(),
85        }
86    }
87
88    pub fn shrink_to_fit(&mut self) {
90        self.buffer.shrink_to_fit();
91    }
92
93    pub fn slice(&self, offset: usize, len: usize) -> Self {
95        Self::new(self.buffer.clone(), offset, len)
96    }
97
98    pub fn inner(&self) -> &Buffer {
100        &self.buffer
101    }
102
103    pub fn into_inner(self) -> Buffer {
105        self.buffer
106    }
107
108    #[inline]
112    pub fn ptr_eq(&self, other: &Self) -> bool {
113        self.buffer.ptr_eq(&other.buffer)
114    }
115
116    pub fn len(&self) -> usize {
118        self.buffer.len() / std::mem::size_of::<T>()
119    }
120
121    pub fn is_empty(&self) -> bool {
123        self.len() == 0
124    }
125}
126
127impl<T: ArrowNativeType> Deref for ScalarBuffer<T> {
128    type Target = [T];
129
130    #[inline]
131    fn deref(&self) -> &Self::Target {
132        unsafe {
134            std::slice::from_raw_parts(
135                self.buffer.as_ptr() as *const T,
136                self.buffer.len() / std::mem::size_of::<T>(),
137            )
138        }
139    }
140}
141
142impl<T: ArrowNativeType> AsRef<[T]> for ScalarBuffer<T> {
143    #[inline]
144    fn as_ref(&self) -> &[T] {
145        self
146    }
147}
148
149impl<T: ArrowNativeType> From<MutableBuffer> for ScalarBuffer<T> {
150    fn from(value: MutableBuffer) -> Self {
151        Buffer::from(value).into()
152    }
153}
154
155impl<T: ArrowNativeType> From<Buffer> for ScalarBuffer<T> {
156    fn from(buffer: Buffer) -> Self {
157        let align = std::mem::align_of::<T>();
158        let is_aligned = buffer.as_ptr().align_offset(align) == 0;
159
160        match buffer.deallocation() {
161            Deallocation::Standard(_) => assert!(
162                is_aligned,
163                "Memory pointer is not aligned with the specified scalar type"
164            ),
165            Deallocation::Custom(_, _) => assert!(
166                is_aligned,
167                "Memory pointer from external source (e.g, FFI) is not aligned with the specified scalar type. Before importing buffer through FFI, please make sure the allocation is aligned."
168            ),
169        }
170
171        Self {
172            buffer,
173            phantom: Default::default(),
174        }
175    }
176}
177
178impl<T: ArrowNativeType> From<OffsetBuffer<T>> for ScalarBuffer<T> {
179    fn from(value: OffsetBuffer<T>) -> Self {
180        value.into_inner()
181    }
182}
183
184impl<T: ArrowNativeType> From<Vec<T>> for ScalarBuffer<T> {
185    fn from(value: Vec<T>) -> Self {
186        Self {
187            buffer: Buffer::from_vec(value),
188            phantom: Default::default(),
189        }
190    }
191}
192
193impl<T: ArrowNativeType> From<ScalarBuffer<T>> for Vec<T> {
194    fn from(value: ScalarBuffer<T>) -> Self {
195        value
196            .buffer
197            .into_vec()
198            .unwrap_or_else(|buffer| buffer.typed_data::<T>().into())
199    }
200}
201
202impl<T: ArrowNativeType> From<BufferBuilder<T>> for ScalarBuffer<T> {
203    fn from(mut value: BufferBuilder<T>) -> Self {
204        let len = value.len();
205        Self::new(value.finish(), 0, len)
206    }
207}
208
209impl<T: ArrowNativeType> FromIterator<T> for ScalarBuffer<T> {
210    #[inline]
211    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
212        iter.into_iter().collect::<Vec<_>>().into()
213    }
214}
215
216impl<'a, T: ArrowNativeType> IntoIterator for &'a ScalarBuffer<T> {
217    type Item = &'a T;
218    type IntoIter = std::slice::Iter<'a, T>;
219
220    fn into_iter(self) -> Self::IntoIter {
221        self.as_ref().iter()
222    }
223}
224
225impl<T: ArrowNativeType, S: AsRef<[T]> + ?Sized> PartialEq<S> for ScalarBuffer<T> {
226    fn eq(&self, other: &S) -> bool {
227        self.as_ref().eq(other.as_ref())
228    }
229}
230
231impl<T: ArrowNativeType, const N: usize> PartialEq<ScalarBuffer<T>> for [T; N] {
232    fn eq(&self, other: &ScalarBuffer<T>) -> bool {
233        self.as_ref().eq(other.as_ref())
234    }
235}
236
237impl<T: ArrowNativeType> PartialEq<ScalarBuffer<T>> for [T] {
238    fn eq(&self, other: &ScalarBuffer<T>) -> bool {
239        self.as_ref().eq(other.as_ref())
240    }
241}
242
243impl<T: ArrowNativeType> PartialEq<ScalarBuffer<T>> for Vec<T> {
244    fn eq(&self, other: &ScalarBuffer<T>) -> bool {
245        self.as_slice().eq(other.as_ref())
246    }
247}
248
249impl<T: ArrowNativeType + Eq> Eq for ScalarBuffer<T> {}
251
#[cfg(test)]
mod tests {
    use std::{ptr::NonNull, sync::Arc};

    use super::*;

    #[test]
    fn test_basic() {
        let expected = [0_i32, 1, 2];
        let buffer = Buffer::from_iter(expected.iter().cloned());
        let typed = ScalarBuffer::<i32>::new(buffer.clone(), 0, 3);
        assert_eq!(*typed, expected);

        let typed = ScalarBuffer::<i32>::new(buffer.clone(), 1, 2);
        assert_eq!(*typed, expected[1..]);

        // Zero-length views are allowed at any in-bounds offset, including the end.
        let typed = ScalarBuffer::<i32>::new(buffer.clone(), 1, 0);
        assert!(typed.is_empty());

        let typed = ScalarBuffer::<i32>::new(buffer, 3, 0);
        assert!(typed.is_empty());
    }

    #[test]
    fn test_debug() {
        let buffer = ScalarBuffer::from(vec![1, 2, 3]);
        assert_eq!(format!("{buffer:?}"), "ScalarBuffer([1, 2, 3])");
    }

    #[test]
    #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")]
    fn test_unaligned() {
        let expected = [0_i32, 1, 2];
        let buffer = Buffer::from_iter(expected.iter().cloned());
        // Skipping one byte leaves a data pointer that is misaligned for `i32`.
        let buffer = buffer.slice(1);
        ScalarBuffer::<i32>::new(buffer, 0, 2);
    }

    #[test]
    #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")]
    fn test_length_out_of_bounds() {
        let buffer = Buffer::from_iter([0_i32, 1, 2]);
        ScalarBuffer::<i32>::new(buffer, 1, 3);
    }

    #[test]
    #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")]
    fn test_offset_out_of_bounds() {
        let buffer = Buffer::from_iter([0_i32, 1, 2]);
        ScalarBuffer::<i32>::new(buffer, 4, 0);
    }

    // `usize::MAX * size_of::<i32>()` overflows, so the offset check fires
    // before the (non-zero) length is even considered.
    #[test]
    #[should_panic(expected = "offset overflow")]
    fn test_offset_overflow_with_len() {
        let buffer = Buffer::from_iter([0_i32, 1, 2]);
        ScalarBuffer::<i32>::new(buffer, usize::MAX, 1);
    }

    #[test]
    #[should_panic(expected = "offset overflow")]
    fn test_start_overflow() {
        let buffer = Buffer::from_iter([0_i32, 1, 2]);
        ScalarBuffer::<i32>::new(buffer, usize::MAX / 4 + 1, 0);
    }

    #[test]
    #[should_panic(expected = "length overflow")]
    fn test_end_overflow() {
        let buffer = Buffer::from_iter([0_i32, 1, 2]);
        ScalarBuffer::<i32>::new(buffer, 0, usize::MAX / 4 + 1);
    }

    #[test]
    fn test_slice() {
        let buffer: ScalarBuffer<i32> = (0..5).collect();
        assert_eq!(buffer.len(), 5);

        let sliced = buffer.slice(1, 3);
        assert_eq!(sliced.len(), 3);
        assert_eq!(sliced, [1_i32, 2, 3]);

        // A zero-length slice at the very end is valid.
        let empty = buffer.slice(5, 0);
        assert!(empty.is_empty());

        let mut total = 0;
        for v in &buffer {
            total += *v;
        }
        assert_eq!(total, 10);
    }

    #[test]
    fn test_partial_eq() {
        let buffer = ScalarBuffer::from(vec![1_i32, 2, 3]);

        // ScalarBuffer on the left-hand side.
        assert_eq!(buffer, [1_i32, 2, 3]);
        assert_eq!(buffer, vec![1_i32, 2, 3]);
        assert_eq!(buffer, ScalarBuffer::from(vec![1_i32, 2, 3]));
        assert_ne!(buffer, [1_i32, 2]);

        // ScalarBuffer on the right-hand side.
        assert_eq!([1_i32, 2, 3], buffer);
        assert_eq!(vec![1_i32, 2, 3], buffer);
        let slice: &[i32] = &[1, 2, 3];
        assert!(*slice == buffer);
    }

    #[test]
    fn test_ptr_eq() {
        let a = ScalarBuffer::from(vec![1_i32, 2, 3]);
        let b = a.clone();
        assert!(a.ptr_eq(&b));

        // Equal contents in a separate allocation are `==` but not `ptr_eq`.
        let c = ScalarBuffer::from(vec![1_i32, 2, 3]);
        assert_eq!(a, c);
        assert!(!a.ptr_eq(&c));
    }

    #[test]
    fn convert_from_buffer_builder() {
        let input = vec![1, 2, 3, 4];
        let buffer_builder = BufferBuilder::from(input.clone());
        let scalar_buffer = ScalarBuffer::from(buffer_builder);
        assert_eq!(scalar_buffer.as_ref(), input);
    }

    #[test]
    fn into_vec() {
        let input = vec![1u8, 2, 3, 4];

        // A whole Vec-backed buffer converts back without copying.
        let input_buffer = Buffer::from_vec(input.clone());
        let input_ptr = input_buffer.as_ptr();
        let input_len = input_buffer.len();
        let scalar_buffer = ScalarBuffer::<u8>::new(input_buffer, 0, input_len);
        let vec = Vec::from(scalar_buffer);
        assert_eq!(vec.as_slice(), input.as_slice());
        assert_eq!(vec.as_ptr(), input_ptr);

        // Memory from a custom allocation is copied into a fresh Vec.
        let mut input_clone = input.clone();
        let input_ptr = NonNull::new(input_clone.as_mut_ptr()).unwrap();
        let dealloc = Arc::new(());
        let buffer =
            unsafe { Buffer::from_custom_allocation(input_ptr, input_clone.len(), dealloc as _) };
        let scalar_buffer = ScalarBuffer::<u8>::new(buffer, 0, input.len());
        let vec = Vec::from(scalar_buffer);
        assert_eq!(vec, input.as_slice());
        assert_ne!(vec.as_ptr(), input_ptr.as_ptr());

        // A buffer sliced at a non-zero offset is copied.
        let input_buffer = Buffer::from_vec(input.clone());
        let input_ptr = input_buffer.as_ptr();
        let input_len = input_buffer.len();
        let scalar_buffer = ScalarBuffer::<u8>::new(input_buffer, 1, input_len - 1);
        let vec = Vec::from(scalar_buffer);
        assert_eq!(vec.as_slice(), &input[1..]);
        assert_ne!(vec.as_ptr(), input_ptr);

        // `from_slice_ref` copies `input`, so the result never aliases it.
        let buffer = Buffer::from_slice_ref(input.as_slice());
        let scalar_buffer = ScalarBuffer::<u8>::new(buffer, 0, input.len());
        let vec = Vec::from(scalar_buffer);
        assert_eq!(vec, input.as_slice());
        assert_ne!(vec.as_ptr(), input.as_ptr());
    }

    #[test]
    fn scalar_buffer_impl_eq() {
        fn are_equal<T: Eq>(a: &T, b: &T) -> bool {
            a.eq(b)
        }

        assert!(
            are_equal(
                &ScalarBuffer::<i16>::from(vec![23]),
                &ScalarBuffer::<i16>::from(vec![23])
            ),
            "ScalarBuffer should implement Eq if the inner type does"
        );
    }
}