Skip to main content

molrs/block/
column.rs

1//! Internal column representation for heterogeneous data.
2
3use ndarray::ArrayD;
4
5use super::dtype::DType;
6use crate::types::{F, I, U};
7
8/// Internal enum representing a column of data in a Block.
9///
10/// This type is exposed in the public API but users typically don't need to
11/// interact with it directly. Instead, use the type-specific getters like
12/// `get_float()`, `get_int()`, etc.
13#[derive(Clone)]
14pub enum Column {
15    /// Floating point column using the compile-time scalar type [`F`].
16    Float(ArrayD<F>),
17    /// Signed integer column using the compile-time scalar type [`I`].
18    Int(ArrayD<I>),
19    /// Boolean column
20    Bool(ArrayD<bool>),
21    /// Unsigned integer column using the compile-time scalar type [`U`].
22    UInt(ArrayD<U>),
23    /// 8-bit unsigned integer column
24    U8(ArrayD<u8>),
25    /// String column
26    String(ArrayD<String>),
27}
28
29impl Column {
30    /// Returns the number of rows (axis-0 length) of this column.
31    ///
32    /// Returns `None` if the array has rank 0 (which should never happen
33    /// in a valid Block, as rank-0 arrays are rejected during insertion).
34    pub fn nrows(&self) -> Option<usize> {
35        match self {
36            Column::Float(a) => a.shape().first().copied(),
37            Column::Int(a) => a.shape().first().copied(),
38            Column::Bool(a) => a.shape().first().copied(),
39            Column::UInt(a) => a.shape().first().copied(),
40            Column::U8(a) => a.shape().first().copied(),
41            Column::String(a) => a.shape().first().copied(),
42        }
43    }
44
45    /// Returns the data type of this column.
46    pub fn dtype(&self) -> DType {
47        match self {
48            Column::Float(_) => DType::Float,
49            Column::Int(_) => DType::Int,
50            Column::Bool(_) => DType::Bool,
51            Column::UInt(_) => DType::UInt,
52            Column::U8(_) => DType::U8,
53            Column::String(_) => DType::String,
54        }
55    }
56
57    /// Returns the shape of the underlying array.
58    pub fn shape(&self) -> &[usize] {
59        match self {
60            Column::Float(a) => a.shape(),
61            Column::Int(a) => a.shape(),
62            Column::Bool(a) => a.shape(),
63            Column::UInt(a) => a.shape(),
64            Column::U8(a) => a.shape(),
65            Column::String(a) => a.shape(),
66        }
67    }
68
69    /// Returns a reference to the float data, or `None` if this column is not `Float`.
70    pub fn as_float(&self) -> Option<&ArrayD<F>> {
71        match self {
72            Column::Float(a) => Some(a),
73            _ => None,
74        }
75    }
76
77    /// Returns a mutable reference to the float data, or `None` if not `Float`.
78    pub fn as_float_mut(&mut self) -> Option<&mut ArrayD<F>> {
79        match self {
80            Column::Float(a) => Some(a),
81            _ => None,
82        }
83    }
84
85    /// Returns a reference to the integer data, or `None` if not `Int`.
86    pub fn as_int(&self) -> Option<&ArrayD<I>> {
87        match self {
88            Column::Int(a) => Some(a),
89            _ => None,
90        }
91    }
92
93    /// Returns a mutable reference to the integer data, or `None` if not `Int`.
94    pub fn as_int_mut(&mut self) -> Option<&mut ArrayD<I>> {
95        match self {
96            Column::Int(a) => Some(a),
97            _ => None,
98        }
99    }
100
101    /// Returns a reference to the boolean data, or `None` if not `Bool`.
102    pub fn as_bool(&self) -> Option<&ArrayD<bool>> {
103        match self {
104            Column::Bool(a) => Some(a),
105            _ => None,
106        }
107    }
108
109    /// Returns a mutable reference to the boolean data, or `None` if not `Bool`.
110    pub fn as_bool_mut(&mut self) -> Option<&mut ArrayD<bool>> {
111        match self {
112            Column::Bool(a) => Some(a),
113            _ => None,
114        }
115    }
116
117    /// Returns a reference to the unsigned integer data, or `None` if not `UInt`.
118    pub fn as_uint(&self) -> Option<&ArrayD<U>> {
119        match self {
120            Column::UInt(a) => Some(a),
121            _ => None,
122        }
123    }
124
125    /// Returns a mutable reference to the unsigned integer data, or `None` if not `UInt`.
126    pub fn as_uint_mut(&mut self) -> Option<&mut ArrayD<U>> {
127        match self {
128            Column::UInt(a) => Some(a),
129            _ => None,
130        }
131    }
132
133    /// Returns a reference to the u8 data, or `None` if not `U8`.
134    pub fn as_u8(&self) -> Option<&ArrayD<u8>> {
135        match self {
136            Column::U8(a) => Some(a),
137            _ => None,
138        }
139    }
140
141    /// Returns a mutable reference to the u8 data, or `None` if not `U8`.
142    pub fn as_u8_mut(&mut self) -> Option<&mut ArrayD<u8>> {
143        match self {
144            Column::U8(a) => Some(a),
145            _ => None,
146        }
147    }
148
149    /// Returns a reference to the string data, or `None` if not `String`.
150    pub fn as_string(&self) -> Option<&ArrayD<String>> {
151        match self {
152            Column::String(a) => Some(a),
153            _ => None,
154        }
155    }
156
157    /// Returns a mutable reference to the string data, or `None` if not `String`.
158    pub fn as_string_mut(&mut self) -> Option<&mut ArrayD<String>> {
159        match self {
160            Column::String(a) => Some(a),
161            _ => None,
162        }
163    }
164
165    /// Resize this column along axis 0 to `new_nrows`.
166    ///
167    /// - If `new_nrows` < current nrows, slices to keep only the first `new_nrows` rows.
168    /// - If `new_nrows` > current nrows, extends with default values
169    ///   (0.0 for Float, 0 for Int/UInt/U8, false for Bool, empty string for String).
170    /// - If `new_nrows` == current nrows, this is a no-op.
171    ///
172    /// Only axis 0 is modified; trailing dimensions are preserved.
173    pub fn resize(&mut self, new_nrows: usize) {
174        use ndarray::{Axis, IxDyn, concatenate};
175
176        let current = self.shape()[0];
177        if new_nrows == current {
178            return;
179        }
180
181        match self {
182            Column::Float(a) => {
183                if new_nrows < current {
184                    *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
185                } else {
186                    let mut pad_shape = a.shape().to_vec();
187                    pad_shape[0] = new_nrows - current;
188                    let pad = ArrayD::<F>::zeros(IxDyn(&pad_shape));
189                    *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
190                }
191            }
192            Column::Int(a) => {
193                if new_nrows < current {
194                    *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
195                } else {
196                    let mut pad_shape = a.shape().to_vec();
197                    pad_shape[0] = new_nrows - current;
198                    let pad = ArrayD::<I>::zeros(IxDyn(&pad_shape));
199                    *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
200                }
201            }
202            Column::UInt(a) => {
203                if new_nrows < current {
204                    *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
205                } else {
206                    let mut pad_shape = a.shape().to_vec();
207                    pad_shape[0] = new_nrows - current;
208                    let pad = ArrayD::<U>::zeros(IxDyn(&pad_shape));
209                    *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
210                }
211            }
212            Column::U8(a) => {
213                if new_nrows < current {
214                    *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
215                } else {
216                    let mut pad_shape = a.shape().to_vec();
217                    pad_shape[0] = new_nrows - current;
218                    let pad = ArrayD::<u8>::zeros(IxDyn(&pad_shape));
219                    *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
220                }
221            }
222            Column::Bool(a) => {
223                if new_nrows < current {
224                    *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
225                } else {
226                    let mut pad_shape = a.shape().to_vec();
227                    pad_shape[0] = new_nrows - current;
228                    let pad = ArrayD::<bool>::default(IxDyn(&pad_shape));
229                    *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
230                }
231            }
232            Column::String(a) => {
233                if new_nrows < current {
234                    *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
235                } else {
236                    let mut pad_shape = a.shape().to_vec();
237                    pad_shape[0] = new_nrows - current;
238                    let pad = ArrayD::<String>::default(IxDyn(&pad_shape));
239                    *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
240                }
241            }
242        }
243    }
244}
245
246impl std::fmt::Debug for Column {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        match self {
249            Column::Float(a) => write!(f, "Column::Float(shape={:?})", a.shape()),
250            Column::Int(a) => write!(f, "Column::Int(shape={:?})", a.shape()),
251            Column::Bool(a) => write!(f, "Column::Bool(shape={:?})", a.shape()),
252            Column::UInt(a) => write!(f, "Column::UInt(shape={:?})", a.shape()),
253            Column::U8(a) => write!(f, "Column::U8(shape={:?})", a.shape()),
254            Column::String(a) => write!(f, "Column::String(shape={:?})", a.shape()),
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{F, I, U};
    use ndarray::{Array1, ArrayD};

    // ---- helpers ----

    fn float_col(n: usize) -> Column {
        Column::Float(Array1::from_vec(vec![0.0 as F; n]).into_dyn())
    }

    fn int_col(n: usize) -> Column {
        Column::Int(Array1::from_vec(vec![0 as I; n]).into_dyn())
    }

    fn bool_col(n: usize) -> Column {
        Column::Bool(Array1::from_vec(vec![false; n]).into_dyn())
    }

    fn uint_col(n: usize) -> Column {
        Column::UInt(Array1::from_vec(vec![0 as U; n]).into_dyn())
    }

    fn u8_col(n: usize) -> Column {
        Column::U8(Array1::from_vec(vec![0u8; n]).into_dyn())
    }

    fn string_col(n: usize) -> Column {
        Column::String(Array1::from_vec(vec![String::new(); n]).into_dyn())
    }

    // ---- 1. nrows ----

    #[test]
    fn test_nrows() {
        assert_eq!(float_col(5).nrows(), Some(5));
        assert_eq!(int_col(3).nrows(), Some(3));
        assert_eq!(bool_col(7).nrows(), Some(7));
        assert_eq!(uint_col(2).nrows(), Some(2));
        assert_eq!(u8_col(4).nrows(), Some(4));
        assert_eq!(string_col(1).nrows(), Some(1));

        // rank-0 array has no axis-0 dimension
        let rank0 = Column::Float(ArrayD::<F>::from_elem(vec![], 1.0));
        assert_eq!(rank0.nrows(), None);
    }

    // ---- 2. dtype ----

    #[test]
    fn test_dtype() {
        assert_eq!(float_col(1).dtype(), DType::Float);
        assert_eq!(int_col(1).dtype(), DType::Int);
        assert_eq!(bool_col(1).dtype(), DType::Bool);
        assert_eq!(uint_col(1).dtype(), DType::UInt);
        assert_eq!(u8_col(1).dtype(), DType::U8);
        assert_eq!(string_col(1).dtype(), DType::String);
    }

    // ---- 3. shape ----

    #[test]
    fn test_shape() {
        // 1-D
        assert_eq!(float_col(4).shape(), &[4]);
        // 2-D
        let col2d = Column::Int(ArrayD::<I>::from_elem(vec![3, 2], 0));
        assert_eq!(col2d.shape(), &[3, 2]);
    }

    // ---- 4. as_float on Float ----

    #[test]
    fn test_as_float_on_float() {
        let col = float_col(3);
        assert!(col.as_float().is_some());
        assert_eq!(col.as_float().unwrap().len(), 3);
    }

    // ---- 5. as_float on wrong type ----

    #[test]
    fn test_as_float_on_wrong_type() {
        assert!(int_col(2).as_float().is_none());
        assert!(bool_col(2).as_float().is_none());
        assert!(uint_col(2).as_float().is_none());
        assert!(u8_col(2).as_float().is_none());
        assert!(string_col(2).as_float().is_none());
    }

    // ---- 6. as_int ----

    #[test]
    fn test_as_int() {
        let col = int_col(4);
        assert!(col.as_int().is_some());
        assert_eq!(col.as_int().unwrap().len(), 4);
        // wrong types return None
        assert!(float_col(1).as_int().is_none());
        assert!(bool_col(1).as_int().is_none());
    }

    // ---- 7. as_bool ----

    #[test]
    fn test_as_bool() {
        let col = bool_col(2);
        assert!(col.as_bool().is_some());
        assert_eq!(col.as_bool().unwrap().len(), 2);
        // wrong types return None
        assert!(float_col(1).as_bool().is_none());
        assert!(int_col(1).as_bool().is_none());
    }

    // ---- 8. as_uint ----

    #[test]
    fn test_as_uint() {
        let col = uint_col(6);
        assert!(col.as_uint().is_some());
        assert_eq!(col.as_uint().unwrap().len(), 6);
        // wrong types return None
        assert!(float_col(1).as_uint().is_none());
        assert!(int_col(1).as_uint().is_none());
    }

    // ---- 9. as_u8 ----

    #[test]
    fn test_as_u8() {
        let col = u8_col(3);
        assert!(col.as_u8().is_some());
        assert_eq!(col.as_u8().unwrap().len(), 3);
        // wrong types return None
        assert!(float_col(1).as_u8().is_none());
        assert!(uint_col(1).as_u8().is_none());
    }

    // ---- 10. as_string ----

    #[test]
    fn test_as_string() {
        let col = string_col(2);
        assert!(col.as_string().is_some());
        assert_eq!(col.as_string().unwrap().len(), 2);
        // wrong types return None
        assert!(float_col(1).as_string().is_none());
        assert!(int_col(1).as_string().is_none());
    }

    // ---- 11. as_float_mut ----

    #[test]
    fn test_as_float_mut() {
        let mut col =
            Column::Float(Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn());
        {
            let arr = col.as_float_mut().unwrap();
            arr[0] = 99.0;
        }
        let arr = col.as_float().unwrap();
        assert!((arr[0] - 99.0).abs() < F::EPSILON);
        // wrong variant returns None
        let mut int = int_col(1);
        assert!(int.as_float_mut().is_none());
    }

    // ---- 12. Debug format ----

    #[test]
    fn test_debug_format() {
        let dbg = format!("{:?}", float_col(3));
        assert!(dbg.contains("Column::Float"));
        assert!(dbg.contains("shape="));

        let dbg = format!("{:?}", int_col(2));
        assert!(dbg.contains("Column::Int"));

        let dbg = format!("{:?}", bool_col(1));
        assert!(dbg.contains("Column::Bool"));

        let dbg = format!("{:?}", uint_col(4));
        assert!(dbg.contains("Column::UInt"));

        let dbg = format!("{:?}", u8_col(5));
        assert!(dbg.contains("Column::U8"));

        let dbg = format!("{:?}", string_col(1));
        assert!(dbg.contains("Column::String"));
    }

    // ---- 13. remaining *_mut accessors ----

    #[test]
    fn test_other_mut_accessors() {
        let mut col = int_col(2);
        col.as_int_mut().unwrap()[1] = 7;
        assert_eq!(col.as_int().unwrap()[1], 7);

        let mut col = bool_col(2);
        col.as_bool_mut().unwrap()[0] = true;
        assert!(col.as_bool().unwrap()[0]);

        let mut col = uint_col(1);
        col.as_uint_mut().unwrap()[0] = 5;
        assert_eq!(col.as_uint().unwrap()[0], 5);

        let mut col = u8_col(1);
        col.as_u8_mut().unwrap()[0] = 255;
        assert_eq!(col.as_u8().unwrap()[0], 255);

        let mut col = string_col(1);
        col.as_string_mut().unwrap()[0] = "x".to_string();
        assert_eq!(col.as_string().unwrap()[0], "x");

        // wrong variants return None
        assert!(float_col(1).as_int_mut().is_none());
        assert!(int_col(1).as_bool_mut().is_none());
        assert!(bool_col(1).as_uint_mut().is_none());
        assert!(uint_col(1).as_u8_mut().is_none());
        assert!(u8_col(1).as_string_mut().is_none());
    }

    // ---- 14. resize ----

    #[test]
    fn test_resize_same_len_is_noop() {
        let mut col = Column::U8(Array1::from_vec(vec![1u8, 2, 3]).into_dyn());
        col.resize(3);
        let vals: Vec<u8> = col.as_u8().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
    }

    #[test]
    fn test_resize_truncate_then_grow() {
        let mut col = Column::Int(Array1::from_vec(vec![1 as I, 2, 3, 4]).into_dyn());
        col.resize(2);
        assert_eq!(col.nrows(), Some(2));
        let vals: Vec<I> = col.as_int().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2]);

        // Truncating to zero rows and growing back pads with zeros.
        col.resize(0);
        assert_eq!(col.nrows(), Some(0));
        col.resize(2);
        let vals: Vec<I> = col.as_int().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![0, 0]);
    }

    #[test]
    fn test_resize_extend_pads_with_defaults() {
        let mut col = Column::Float(Array1::from_vec(vec![1.5 as F]).into_dyn());
        col.resize(3);
        let arr = col.as_float().unwrap();
        assert_eq!(arr.len(), 3);
        assert!((arr[0] - 1.5).abs() < F::EPSILON);
        assert!(arr.iter().skip(1).all(|v| v.abs() < F::EPSILON));

        let mut col = Column::UInt(Array1::from_vec(vec![9 as U]).into_dyn());
        col.resize(2);
        let vals: Vec<U> = col.as_uint().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![9, 0]);

        let mut col = Column::Bool(Array1::from_vec(vec![true]).into_dyn());
        col.resize(3);
        let vals: Vec<bool> = col.as_bool().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![true, false, false]);

        let mut col = Column::String(Array1::from_vec(vec!["a".to_string()]).into_dyn());
        col.resize(3);
        let vals: Vec<String> = col.as_string().unwrap().iter().cloned().collect();
        assert_eq!(vals, vec!["a".to_string(), String::new(), String::new()]);
    }

    #[test]
    fn test_resize_preserves_trailing_dims() {
        let mut col = Column::Float(ArrayD::<F>::from_elem(vec![2, 3], 1.0));
        col.resize(4);
        assert_eq!(col.shape(), &[4, 3]);
        let arr = col.as_float().unwrap();
        // Row-major iteration: the first two rows are the originals, the rest padding.
        assert!(arr.iter().take(6).all(|v| (v - 1.0).abs() < F::EPSILON));
        assert!(arr.iter().skip(6).all(|v| v.abs() < F::EPSILON));

        col.resize(1);
        assert_eq!(col.shape(), &[1, 3]);
    }

    #[test]
    #[should_panic]
    fn test_resize_rank0_panics() {
        let mut col = Column::Float(ArrayD::<F>::from_elem(vec![], 1.0));
        col.resize(1);
    }
}