1use ndarray::ArrayD;
4
5use super::dtype::DType;
6use crate::types::{F, I, U};
7
/// A single column of homogeneously typed data, stored as a dynamic-rank
/// `ndarray` array.
///
/// Axis 0 is treated as the row axis by [`Column::nrows`] and
/// [`Column::resize`]; any further axes are carried along unchanged.
/// `Debug` is implemented by hand and prints only the variant and the shape,
/// never the element data.
#[derive(Clone)]
pub enum Column {
    /// Elements of the crate-level float alias `F`.
    Float(ArrayD<F>),
    /// Elements of the crate-level signed integer alias `I`.
    Int(ArrayD<I>),
    /// Boolean elements.
    Bool(ArrayD<bool>),
    /// Elements of the crate-level unsigned integer alias `U`.
    UInt(ArrayD<U>),
    /// Raw byte elements.
    U8(ArrayD<u8>),
    /// Owned string elements.
    String(ArrayD<String>),
}
28
29impl Column {
30 pub fn nrows(&self) -> Option<usize> {
35 match self {
36 Column::Float(a) => a.shape().first().copied(),
37 Column::Int(a) => a.shape().first().copied(),
38 Column::Bool(a) => a.shape().first().copied(),
39 Column::UInt(a) => a.shape().first().copied(),
40 Column::U8(a) => a.shape().first().copied(),
41 Column::String(a) => a.shape().first().copied(),
42 }
43 }
44
45 pub fn dtype(&self) -> DType {
47 match self {
48 Column::Float(_) => DType::Float,
49 Column::Int(_) => DType::Int,
50 Column::Bool(_) => DType::Bool,
51 Column::UInt(_) => DType::UInt,
52 Column::U8(_) => DType::U8,
53 Column::String(_) => DType::String,
54 }
55 }
56
57 pub fn shape(&self) -> &[usize] {
59 match self {
60 Column::Float(a) => a.shape(),
61 Column::Int(a) => a.shape(),
62 Column::Bool(a) => a.shape(),
63 Column::UInt(a) => a.shape(),
64 Column::U8(a) => a.shape(),
65 Column::String(a) => a.shape(),
66 }
67 }
68
69 pub fn as_float(&self) -> Option<&ArrayD<F>> {
71 match self {
72 Column::Float(a) => Some(a),
73 _ => None,
74 }
75 }
76
77 pub fn as_float_mut(&mut self) -> Option<&mut ArrayD<F>> {
79 match self {
80 Column::Float(a) => Some(a),
81 _ => None,
82 }
83 }
84
85 pub fn as_int(&self) -> Option<&ArrayD<I>> {
87 match self {
88 Column::Int(a) => Some(a),
89 _ => None,
90 }
91 }
92
93 pub fn as_int_mut(&mut self) -> Option<&mut ArrayD<I>> {
95 match self {
96 Column::Int(a) => Some(a),
97 _ => None,
98 }
99 }
100
101 pub fn as_bool(&self) -> Option<&ArrayD<bool>> {
103 match self {
104 Column::Bool(a) => Some(a),
105 _ => None,
106 }
107 }
108
109 pub fn as_bool_mut(&mut self) -> Option<&mut ArrayD<bool>> {
111 match self {
112 Column::Bool(a) => Some(a),
113 _ => None,
114 }
115 }
116
117 pub fn as_uint(&self) -> Option<&ArrayD<U>> {
119 match self {
120 Column::UInt(a) => Some(a),
121 _ => None,
122 }
123 }
124
125 pub fn as_uint_mut(&mut self) -> Option<&mut ArrayD<U>> {
127 match self {
128 Column::UInt(a) => Some(a),
129 _ => None,
130 }
131 }
132
133 pub fn as_u8(&self) -> Option<&ArrayD<u8>> {
135 match self {
136 Column::U8(a) => Some(a),
137 _ => None,
138 }
139 }
140
141 pub fn as_u8_mut(&mut self) -> Option<&mut ArrayD<u8>> {
143 match self {
144 Column::U8(a) => Some(a),
145 _ => None,
146 }
147 }
148
149 pub fn as_string(&self) -> Option<&ArrayD<String>> {
151 match self {
152 Column::String(a) => Some(a),
153 _ => None,
154 }
155 }
156
157 pub fn as_string_mut(&mut self) -> Option<&mut ArrayD<String>> {
159 match self {
160 Column::String(a) => Some(a),
161 _ => None,
162 }
163 }
164
165 pub fn resize(&mut self, new_nrows: usize) {
174 use ndarray::{Axis, IxDyn, concatenate};
175
176 let current = self.shape()[0];
177 if new_nrows == current {
178 return;
179 }
180
181 match self {
182 Column::Float(a) => {
183 if new_nrows < current {
184 *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
185 } else {
186 let mut pad_shape = a.shape().to_vec();
187 pad_shape[0] = new_nrows - current;
188 let pad = ArrayD::<F>::zeros(IxDyn(&pad_shape));
189 *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
190 }
191 }
192 Column::Int(a) => {
193 if new_nrows < current {
194 *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
195 } else {
196 let mut pad_shape = a.shape().to_vec();
197 pad_shape[0] = new_nrows - current;
198 let pad = ArrayD::<I>::zeros(IxDyn(&pad_shape));
199 *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
200 }
201 }
202 Column::UInt(a) => {
203 if new_nrows < current {
204 *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
205 } else {
206 let mut pad_shape = a.shape().to_vec();
207 pad_shape[0] = new_nrows - current;
208 let pad = ArrayD::<U>::zeros(IxDyn(&pad_shape));
209 *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
210 }
211 }
212 Column::U8(a) => {
213 if new_nrows < current {
214 *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
215 } else {
216 let mut pad_shape = a.shape().to_vec();
217 pad_shape[0] = new_nrows - current;
218 let pad = ArrayD::<u8>::zeros(IxDyn(&pad_shape));
219 *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
220 }
221 }
222 Column::Bool(a) => {
223 if new_nrows < current {
224 *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
225 } else {
226 let mut pad_shape = a.shape().to_vec();
227 pad_shape[0] = new_nrows - current;
228 let pad = ArrayD::<bool>::default(IxDyn(&pad_shape));
229 *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
230 }
231 }
232 Column::String(a) => {
233 if new_nrows < current {
234 *a = a.slice_axis(Axis(0), (..new_nrows).into()).to_owned();
235 } else {
236 let mut pad_shape = a.shape().to_vec();
237 pad_shape[0] = new_nrows - current;
238 let pad = ArrayD::<String>::default(IxDyn(&pad_shape));
239 *a = concatenate(Axis(0), &[a.view(), pad.view()]).unwrap();
240 }
241 }
242 }
243 }
244}
245
246impl std::fmt::Debug for Column {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 match self {
249 Column::Float(a) => write!(f, "Column::Float(shape={:?})", a.shape()),
250 Column::Int(a) => write!(f, "Column::Int(shape={:?})", a.shape()),
251 Column::Bool(a) => write!(f, "Column::Bool(shape={:?})", a.shape()),
252 Column::UInt(a) => write!(f, "Column::UInt(shape={:?})", a.shape()),
253 Column::U8(a) => write!(f, "Column::U8(shape={:?})", a.shape()),
254 Column::String(a) => write!(f, "Column::String(shape={:?})", a.shape()),
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{F, I, U};
    use ndarray::{Array1, ArrayD};

    // Fixture builders: a 1-D column of `n` default-valued elements per variant.

    fn float_col(n: usize) -> Column {
        Column::Float(Array1::from_vec(vec![0.0 as F; n]).into_dyn())
    }

    fn int_col(n: usize) -> Column {
        Column::Int(Array1::from_vec(vec![0 as I; n]).into_dyn())
    }

    fn bool_col(n: usize) -> Column {
        Column::Bool(Array1::from_vec(vec![false; n]).into_dyn())
    }

    fn uint_col(n: usize) -> Column {
        Column::UInt(Array1::from_vec(vec![0 as U; n]).into_dyn())
    }

    fn u8_col(n: usize) -> Column {
        Column::U8(Array1::from_vec(vec![0u8; n]).into_dyn())
    }

    fn string_col(n: usize) -> Column {
        Column::String(Array1::from_vec(vec![String::new(); n]).into_dyn())
    }

    /// `nrows` reports axis-0 length for every variant and `None` for rank-0.
    #[test]
    fn test_nrows() {
        assert_eq!(float_col(5).nrows(), Some(5));
        assert_eq!(int_col(3).nrows(), Some(3));
        assert_eq!(bool_col(7).nrows(), Some(7));
        assert_eq!(uint_col(2).nrows(), Some(2));
        assert_eq!(u8_col(4).nrows(), Some(4));
        assert_eq!(string_col(1).nrows(), Some(1));

        let rank0 = Column::Float(ArrayD::<F>::from_elem(vec![], 1.0));
        assert_eq!(rank0.nrows(), None);
    }

    /// Each variant maps to its matching `DType`.
    #[test]
    fn test_dtype() {
        assert_eq!(float_col(1).dtype(), DType::Float);
        assert_eq!(int_col(1).dtype(), DType::Int);
        assert_eq!(bool_col(1).dtype(), DType::Bool);
        assert_eq!(uint_col(1).dtype(), DType::UInt);
        assert_eq!(u8_col(1).dtype(), DType::U8);
        assert_eq!(string_col(1).dtype(), DType::String);
    }

    /// `shape` returns the full shape, including trailing axes.
    #[test]
    fn test_shape() {
        assert_eq!(float_col(4).shape(), &[4]);
        let col2d = Column::Int(ArrayD::<I>::from_elem(vec![3, 2], 0));
        assert_eq!(col2d.shape(), &[3, 2]);
    }

    #[test]
    fn test_as_float_on_float() {
        let col = float_col(3);
        assert!(col.as_float().is_some());
        assert_eq!(col.as_float().unwrap().len(), 3);
    }

    #[test]
    fn test_as_float_on_wrong_type() {
        assert!(int_col(2).as_float().is_none());
        assert!(bool_col(2).as_float().is_none());
        assert!(uint_col(2).as_float().is_none());
        assert!(u8_col(2).as_float().is_none());
        assert!(string_col(2).as_float().is_none());
    }

    #[test]
    fn test_as_int() {
        let col = int_col(4);
        assert!(col.as_int().is_some());
        assert_eq!(col.as_int().unwrap().len(), 4);
        assert!(float_col(1).as_int().is_none());
        assert!(bool_col(1).as_int().is_none());
    }

    #[test]
    fn test_as_bool() {
        let col = bool_col(2);
        assert!(col.as_bool().is_some());
        assert_eq!(col.as_bool().unwrap().len(), 2);
        assert!(float_col(1).as_bool().is_none());
        assert!(int_col(1).as_bool().is_none());
    }

    #[test]
    fn test_as_uint() {
        let col = uint_col(6);
        assert!(col.as_uint().is_some());
        assert_eq!(col.as_uint().unwrap().len(), 6);
        assert!(float_col(1).as_uint().is_none());
        assert!(int_col(1).as_uint().is_none());
    }

    #[test]
    fn test_as_u8() {
        let col = u8_col(3);
        assert!(col.as_u8().is_some());
        assert_eq!(col.as_u8().unwrap().len(), 3);
        assert!(float_col(1).as_u8().is_none());
        assert!(uint_col(1).as_u8().is_none());
    }

    #[test]
    fn test_as_string() {
        let col = string_col(2);
        assert!(col.as_string().is_some());
        assert_eq!(col.as_string().unwrap().len(), 2);
        assert!(float_col(1).as_string().is_none());
        assert!(int_col(1).as_string().is_none());
    }

    /// Writes through the mutable accessor are visible via the shared one.
    #[test]
    fn test_as_float_mut() {
        let mut col =
            Column::Float(Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn());
        {
            let arr = col.as_float_mut().unwrap();
            arr[0] = 99.0;
        }
        let arr = col.as_float().unwrap();
        assert!((arr[0] - 99.0).abs() < F::EPSILON);
        let mut int = int_col(1);
        assert!(int.as_float_mut().is_none());
    }

    /// Debug output names the variant and includes the shape.
    #[test]
    fn test_debug_format() {
        let dbg = format!("{:?}", float_col(3));
        assert!(dbg.contains("Column::Float"));
        assert!(dbg.contains("shape="));

        let dbg = format!("{:?}", int_col(2));
        assert!(dbg.contains("Column::Int"));

        let dbg = format!("{:?}", bool_col(1));
        assert!(dbg.contains("Column::Bool"));

        let dbg = format!("{:?}", uint_col(4));
        assert!(dbg.contains("Column::UInt"));

        let dbg = format!("{:?}", u8_col(5));
        assert!(dbg.contains("Column::U8"));

        let dbg = format!("{:?}", string_col(1));
        assert!(dbg.contains("Column::String"));
    }

    /// Shrinking keeps the leading rows in order.
    #[test]
    fn test_resize_shrink_keeps_leading_rows() {
        let mut col = Column::Int(Array1::from_vec(vec![1 as I, 2, 3, 4]).into_dyn());
        col.resize(2);
        assert_eq!(col.shape(), &[2]);
        let kept: Vec<I> = col.as_int().unwrap().iter().copied().collect();
        assert_eq!(kept, vec![1, 2]);
    }

    /// Growing a numeric column appends zero rows after the existing data.
    #[test]
    fn test_resize_grow_pads_numeric_with_zero() {
        let mut col = Column::Float(Array1::from_vec(vec![1.0 as F, 2.0]).into_dyn());
        col.resize(4);
        assert_eq!(col.nrows(), Some(4));
        let vals: Vec<F> = col.as_float().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1.0, 2.0, 0.0, 0.0]);
    }

    /// Growing bool and string columns pads with `false` / empty strings.
    #[test]
    fn test_resize_grow_pads_bool_and_string_with_default() {
        let mut col = Column::Bool(Array1::from_vec(vec![true]).into_dyn());
        col.resize(3);
        let vals: Vec<bool> = col.as_bool().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![true, false, false]);

        let mut col = Column::String(Array1::from_vec(vec!["a".to_string()]).into_dyn());
        col.resize(2);
        let vals: Vec<String> = col.as_string().unwrap().iter().cloned().collect();
        assert_eq!(vals, vec!["a".to_string(), String::new()]);
    }

    /// Only axis 0 changes; trailing axes are preserved when padding.
    #[test]
    fn test_resize_multidim_preserves_trailing_axes() {
        let mut col = Column::UInt(ArrayD::<U>::from_elem(vec![2, 3], 7));
        col.resize(3);
        assert_eq!(col.shape(), &[3, 3]);
        let vals: Vec<U> = col.as_uint().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![7, 7, 7, 7, 7, 7, 0, 0, 0]);
    }

    /// Same-size resize is a no-op; resizing to and from zero rows works.
    #[test]
    fn test_resize_noop_and_zero_rows() {
        let mut col = float_col(3);
        col.resize(3);
        assert_eq!(col.shape(), &[3]);

        let mut col = u8_col(3);
        col.resize(0);
        assert_eq!(col.nrows(), Some(0));
        col.resize(2);
        let vals: Vec<u8> = col.as_u8().unwrap().iter().copied().collect();
        assert_eq!(vals, vec![0, 0]);
    }

    /// A rank-0 column has no row axis, so resizing it panics.
    #[test]
    #[should_panic]
    fn test_resize_rank0_panics() {
        let mut rank0 = Column::Float(ArrayD::<F>::from_elem(vec![], 1.0));
        rank0.resize(1);
    }
}