1use std::fmt::{Display, Formatter};
23
24use std::iter::FromIterator;
25use std::sync::Arc;
26
27use crate::enums::{error::MinarrowError, shape_dim::ShapeDim};
28use crate::structs::field::Field;
29use crate::structs::field_array::FieldArray;
30use crate::structs::table::Table;
31use crate::traits::{concatenate::Concatenate, shape::Shape};
32#[cfg(feature = "views")]
33use crate::{SuperTableV, TableV};
34
35#[derive(Clone, Debug, PartialEq)]
59pub struct SuperTable {
60 pub batches: Vec<Arc<Table>>,
61 pub schema: Vec<Arc<Field>>,
62 pub n_rows: usize,
63 pub name: String,
64}
65
66impl SuperTable {
67 pub fn new(name: String) -> Self {
69 Self {
70 batches: Vec::new(),
71 schema: Vec::new(),
72 n_rows: 0,
73 name,
74 }
75 }
76
77 pub fn from_batches(batches: Vec<Arc<Table>>, name_override: Option<String>) -> Self {
81 if batches.is_empty() {
82 return Self::new("Unnamed".into());
83 }
84
85 let name = name_override.unwrap_or_else(|| batches[0].name.clone());
86 let schema: Vec<Arc<Field>> = batches[0].cols.iter().map(|fa| fa.field.clone()).collect();
87 let n_cols = schema.len();
88 let mut total_rows = 0usize;
89
90 for (b_idx, batch) in batches.iter().enumerate() {
92 assert_eq!(
93 batch.n_cols(),
94 n_cols,
95 "Batch {b_idx} column-count mismatch"
96 );
97 for col_idx in 0..n_cols {
98 let field = &schema[col_idx];
99 let fa = &batch.cols[col_idx];
100 assert_eq!(
101 &fa.field, field,
102 "Batch {b_idx} col {col_idx} schema mismatch"
103 );
104 }
105 total_rows += batch.n_rows;
106 }
107
108 Self {
109 batches,
110 schema,
111 n_rows: total_rows,
112 name,
113 }
114 }
115
116 pub fn push(&mut self, batch: Arc<Table>) {
120 if self.batches.is_empty() {
121 self.schema = batch.cols.iter().map(|fa| fa.field.clone()).collect();
122 }
123 let n_cols = self.schema.len();
124 assert_eq!(batch.n_cols(), n_cols, "Pushed batch column-count mismatch");
125 for col_idx in 0..n_cols {
126 let field = &self.schema[col_idx];
127 let fa = &batch.cols[col_idx];
128 assert_eq!(
129 &fa.field, field,
130 "Pushed batch col {col_idx} schema mismatch"
131 );
132 }
133 self.n_rows += batch.n_rows;
134 self.batches.push(batch);
135 }
136
137 pub fn to_table(&self, name: Option<&str>) -> Table {
141 assert!(!self.batches.is_empty(), "to_table() on empty BatchedTable");
142 let n_cols = self.schema.len();
143 let mut unified_cols = Vec::with_capacity(n_cols);
144
145 for col_idx in 0..n_cols {
146 let field = self.schema[col_idx].clone();
147 let mut arr = self.batches[0].cols[col_idx].array.clone();
149 for batch in self.batches.iter().skip(1) {
150 arr.concat_array(&batch.cols[col_idx].array);
151 }
152 let null_count = arr.null_count();
153 unified_cols.push(FieldArray {
154 field,
155 array: arr.clone(),
156 null_count,
157 });
158 }
159
160 Table {
161 cols: unified_cols,
162 n_rows: self.n_rows,
163 name: name
164 .map(str::to_owned)
165 .unwrap_or_else(|| "unified_table".to_string()),
166 }
167 }
168
169 #[inline]
172 pub fn n_cols(&self) -> usize {
173 self.schema.len()
174 }
175
176 #[inline]
182 pub fn cols(&self) -> Vec<Arc<Field>> {
183 self.batches[0]
184 .cols()
185 .iter()
186 .map(|x| x.field.clone())
187 .collect()
188 }
189
190 #[inline]
191 pub fn n_rows(&self) -> usize {
192 self.n_rows
193 }
194
195 #[inline]
196 pub fn n_batches(&self) -> usize {
197 self.batches.len()
198 }
199 #[inline]
200 pub fn len(&self) -> usize {
201 self.n_rows
202 }
203 #[inline]
204 pub fn is_empty(&self) -> bool {
205 self.n_rows == 0
206 }
207 #[inline]
208 pub fn schema(&self) -> &[Arc<Field>] {
209 &self.schema
210 }
211 #[inline]
212 pub fn batches(&self) -> &[Arc<Table>] {
213 &self.batches
214 }
215 #[inline]
216 pub fn batch(&self, idx: usize) -> Option<&Arc<Table>> {
217 self.batches.get(idx)
218 }
219
220 #[cfg(feature = "views")]
222 pub fn view(&self, offset: usize, len: usize) -> SuperTableV {
223 assert!(offset + len <= self.n_rows, "slice out of bounds");
224 let mut slices = Vec::<TableV>::new();
225 let mut remaining = len;
226 let mut global_row = offset;
227
228 for batch in &self.batches {
229 if global_row >= batch.n_rows {
230 global_row -= batch.n_rows;
231 continue;
232 }
233 let take = (batch.n_rows - global_row).min(remaining);
234 slices.push(TableV::from_arc_table(batch.clone(), global_row, take));
235 global_row = 0;
236 remaining -= take;
237 if remaining == 0 {
238 break;
239 }
240 }
241 SuperTableV { slices, len }
242 }
243
244 #[cfg(feature = "views")]
245 pub fn from_views(slices: &[TableV], name: String) -> Self {
246 assert!(!slices.is_empty(), "from_slices: no slices provided");
247 let n_cols = slices[0].n_cols();
248 let mut batches = Vec::with_capacity(slices.len());
249 let mut total_rows = 0usize;
250 for slice in slices {
251 let table = slice.to_table();
252 assert_eq!(table.n_cols(), n_cols, "Batch column-count mismatch");
253 total_rows += table.n_rows;
254 batches.push(table.into());
255 }
256 let schema = slices[0].fields.iter().cloned().collect();
257 Self {
258 batches,
259 schema,
260 n_rows: total_rows,
261 name,
262 }
263 }
264}
265
266impl Default for SuperTable {
267 fn default() -> Self {
268 Self::new("Unnamed".into())
269 }
270}
271
272impl FromIterator<Table> for SuperTable {
273 fn from_iter<T: IntoIterator<Item = Table>>(iter: T) -> Self {
274 let batches: Vec<Arc<Table>> = iter.into_iter().map(|x| x.into()).collect();
275 SuperTable::from_batches(batches, None)
276 }
277}
278
279impl Shape for SuperTable {
280 fn shape(&self) -> ShapeDim {
281 ShapeDim::Rank2 {
282 rows: self.n_rows(),
283 cols: self.n_cols(),
284 }
285 }
286}
287
288impl Concatenate for SuperTable {
289 fn concat(self, other: Self) -> Result<Self, MinarrowError> {
300 if self.batches.is_empty() && other.batches.is_empty() {
302 return Ok(SuperTable::new(format!("{}+{}", self.name, other.name)));
303 }
304
305 if self.batches.is_empty() {
307 let mut result = other;
308 result.name = format!("{}+{}", self.name, result.name);
309 return Ok(result);
310 }
311 if other.batches.is_empty() {
312 let mut result = self;
313 result.name = format!("{}+{}", result.name, other.name);
314 return Ok(result);
315 }
316
317 if self.schema.len() != other.schema.len() {
319 return Err(MinarrowError::IncompatibleTypeError {
320 from: "SuperTable",
321 to: "SuperTable",
322 message: Some(format!(
323 "Cannot concatenate SuperTables with different column counts: {} vs {}",
324 self.schema.len(),
325 other.schema.len()
326 )),
327 });
328 }
329
330 for (col_idx, (self_field, other_field)) in
332 self.schema.iter().zip(other.schema.iter()).enumerate()
333 {
334 if self_field.name != other_field.name {
335 return Err(MinarrowError::IncompatibleTypeError {
336 from: "SuperTable",
337 to: "SuperTable",
338 message: Some(format!(
339 "Column {} name mismatch: '{}' vs '{}'",
340 col_idx, self_field.name, other_field.name
341 )),
342 });
343 }
344
345 if self_field.dtype != other_field.dtype {
346 return Err(MinarrowError::IncompatibleTypeError {
347 from: "SuperTable",
348 to: "SuperTable",
349 message: Some(format!(
350 "Column '{}' type mismatch: {:?} vs {:?}",
351 self_field.name, self_field.dtype, other_field.dtype
352 )),
353 });
354 }
355
356 if self_field.nullable != other_field.nullable {
357 return Err(MinarrowError::IncompatibleTypeError {
358 from: "SuperTable",
359 to: "SuperTable",
360 message: Some(format!(
361 "Column '{}' nullable mismatch: {} vs {}",
362 self_field.name, self_field.nullable, other_field.nullable
363 )),
364 });
365 }
366 }
367
368 let mut result_batches = self.batches;
370 result_batches.extend(other.batches);
371 let total_rows = self.n_rows + other.n_rows;
372
373 Ok(SuperTable {
374 batches: result_batches,
375 schema: self.schema,
376 n_rows: total_rows,
377 name: format!("{}+{}", self.name, other.name),
378 })
379 }
380}
381
382impl Display for SuperTable {
383 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
384 writeln!(
385 f,
386 "SuperTable \"{}\" [{} rows, {} columns, {} batches]",
387 self.name,
388 self.n_rows,
389 self.schema.len(),
390 self.batches.len()
391 )?;
392
393 for (batch_idx, batch) in self.batches.iter().enumerate() {
394 writeln!(
395 f,
396 " ├─ Batch {batch_idx}: {} rows, {} columns",
397 batch.n_rows,
398 batch.n_cols()
399 )?;
400 for (col_idx, col) in batch.cols.iter().enumerate() {
401 let indent = " │ ";
402 writeln!(
403 f,
404 "{indent}Col {col_idx}: \"{}\" (dtype: {}, nulls: {})",
405 col.field.name, col.field.dtype, col.null_count
406 )?;
407 for line in format!("{}", col.array).lines() {
408 writeln!(f, "{indent} {line}")?;
409 }
410 }
411 }
412
413 Ok(())
414 }
415}
416
#[cfg(feature = "views")]
impl From<SuperTableV> for SuperTable {
    /// Converts a view into an owned table named `"SuperTable"` via
    /// `SuperTable::from_views`, which turns each slice into its own batch.
    ///
    /// An empty view yields an empty table with an empty name.
    fn from(super_table_v: SuperTableV) -> Self {
        if !super_table_v.is_empty() {
            SuperTable::from_views(&super_table_v.slices, String::from("SuperTable"))
        } else {
            SuperTable::new(String::new())
        }
    }
}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::arrow_dtype::ArrowType;
    use crate::{Array, Field, FieldArray, NumericArray, Table};

    /// Builds a non-nullable Int32 column named `name` holding `vals`.
    fn fa(name: &str, vals: &[i32]) -> FieldArray {
        let arr = Array::from_int32(crate::IntegerArray::<i32>::from_slice(vals));
        let field = Field::new(name.to_string(), ArrowType::Int32, false, None);
        FieldArray::new(field, arr)
    }

    /// Builds a table named "batch" from columns that must all have the same length.
    fn table(cols: Vec<FieldArray>) -> Table {
        let n_rows = cols[0].len();
        for c in &cols {
            assert_eq!(c.len(), n_rows, "all columns must have same len for Table");
        }
        Table {
            cols,
            n_rows,
            name: "batch".to_string(),
        }
    }

    /// Returns the Int32 values of column `col_idx`, panicking on any other array type.
    fn int32_values(tab: &Table, col_idx: usize) -> Vec<i32> {
        match &tab.cols[col_idx].array {
            Array::NumericArray(NumericArray::Int32(ints)) => ints.data.as_slice().to_vec(),
            _ => panic!("expected Int32 array at col {col_idx}"),
        }
    }

    #[test]
    fn test_empty_and_default() {
        let t = SuperTable::default();
        assert!(t.is_empty());
        assert_eq!(t.n_cols(), 0);
        assert_eq!(t.n_batches(), 0);
        assert_eq!(t.len(), 0);
        assert_eq!(t.name, "Unnamed");
    }

    #[test]
    fn test_from_batches_basic() {
        let col1 = fa("a", &[1, 2, 3]);
        let col2 = fa("b", &[10, 11, 12]);
        let col3 = fa("a", &[4, 5]);
        let col4 = fa("b", &[13, 14]);
        let batch1 = Arc::new(table(vec![col1.clone(), col2]));
        let batch2 = Arc::new(table(vec![col3, col4.clone()]));

        let t = SuperTable::from_batches(vec![batch1, batch2], None);
        assert_eq!(t.n_cols(), 2);
        assert_eq!(t.n_batches(), 2);
        assert_eq!(t.len(), 5);
        assert_eq!(t.name, "batch");
        assert_eq!(t.schema()[0].name, "a");
        assert_eq!(t.schema()[1].name, "b");
        assert_eq!(t.batches()[0].cols[0], col1);
        assert_eq!(t.batches()[1].cols[1], col4);
    }

    #[test]
    #[should_panic(expected = "column-count mismatch")]
    fn test_from_batches_col_count_mismatch() {
        let batch1 = Arc::new(table(vec![fa("a", &[1, 2])]));
        let batch2 = Arc::new(table(vec![fa("a", &[3, 4]), fa("b", &[5, 6])]));
        SuperTable::from_batches(vec![batch1, batch2], None);
    }

    #[test]
    #[should_panic(expected = "schema mismatch")]
    fn test_from_batches_schema_mismatch() {
        // Same column count and dtype, but a different column name.
        let batch1 = Arc::new(table(vec![fa("a", &[1, 2])]));
        let batch2 = Arc::new(table(vec![fa("z", &[3, 4])]));
        SuperTable::from_batches(vec![batch1, batch2], None);
    }

    #[test]
    fn test_push_and_to_table() {
        let mut t = SuperTable::default();
        t.push(Arc::new(table(vec![fa("x", &[1, 2]), fa("y", &[3, 4])])));
        t.push(Arc::new(table(vec![fa("x", &[5]), fa("y", &[6])])));
        assert_eq!(t.n_cols(), 2);
        assert_eq!(t.n_batches(), 2);
        assert_eq!(t.len(), 3);
        let tab = t.to_table(Some("joined"));
        assert_eq!(tab.name, "joined");
        assert_eq!(tab.n_rows, 3);
        assert_eq!(tab.cols[0].field.name, "x");
        assert_eq!(tab.cols[1].field.name, "y");
        assert_eq!(int32_values(&tab, 0), vec![1, 2, 5]);
        assert_eq!(int32_values(&tab, 1), vec![3, 4, 6]);
    }

    #[test]
    fn test_to_table_default_name() {
        let t = SuperTable::from_batches(vec![Arc::new(table(vec![fa("x", &[1])]))], None);
        assert_eq!(t.to_table(None).name, "unified_table");
    }

    #[test]
    #[should_panic(expected = "column-count mismatch")]
    fn test_push_col_count_mismatch() {
        let mut t = SuperTable::default();
        t.push(Arc::new(table(vec![fa("a", &[1, 2])])));
        t.push(Arc::new(table(vec![fa("a", &[3, 4]), fa("b", &[5, 6])])));
    }

    #[test]
    #[should_panic(expected = "schema mismatch")]
    fn test_push_schema_mismatch() {
        let mut t = SuperTable::default();
        t.push(Arc::new(table(vec![fa("a", &[1, 2])])));
        t.push(Arc::new(table(vec![fa("b", &[3, 4])])));
    }

    #[test]
    fn test_concat_appends_batches() {
        let left = SuperTable::from_batches(
            vec![Arc::new(table(vec![fa("a", &[1, 2])]))],
            Some("l".to_string()),
        );
        let right = SuperTable::from_batches(
            vec![Arc::new(table(vec![fa("a", &[3])]))],
            Some("r".to_string()),
        );
        let joined = match left.concat(right) {
            Ok(t) => t,
            Err(_) => panic!("concat of compatible tables failed"),
        };
        assert_eq!(joined.name, "l+r");
        assert_eq!(joined.n_batches(), 2);
        assert_eq!(joined.len(), 3);
        assert_eq!(int32_values(&joined.to_table(None), 0), vec![1, 2, 3]);
    }

    #[test]
    fn test_concat_with_empty_side() {
        let empty = SuperTable::new("e".to_string());
        let right = SuperTable::from_batches(
            vec![Arc::new(table(vec![fa("a", &[1, 2])]))],
            Some("r".to_string()),
        );
        let joined = match empty.concat(right) {
            Ok(t) => t,
            Err(_) => panic!("concat with an empty side failed"),
        };
        assert_eq!(joined.name, "e+r");
        assert_eq!(joined.len(), 2);
        assert_eq!(joined.n_batches(), 1);
    }

    #[test]
    fn test_concat_rejects_incompatible_schemas() {
        let one_col = || SuperTable::from_batches(vec![Arc::new(table(vec![fa("a", &[1])]))], None);

        let renamed = SuperTable::from_batches(vec![Arc::new(table(vec![fa("b", &[2])]))], None);
        assert!(one_col().concat(renamed).is_err());

        let two_cols = SuperTable::from_batches(
            vec![Arc::new(table(vec![fa("a", &[2]), fa("b", &[3])]))],
            None,
        );
        assert!(one_col().concat(two_cols).is_err());
    }

    #[test]
    fn test_shape() {
        let t = SuperTable::from_batches(
            vec![
                Arc::new(table(vec![fa("a", &[1, 2]), fa("b", &[3, 4])])),
                Arc::new(table(vec![fa("a", &[5]), fa("b", &[6])])),
            ],
            None,
        );
        assert!(matches!(t.shape(), ShapeDim::Rank2 { rows: 3, cols: 2 }));
    }

    #[test]
    fn test_display_header() {
        let t = SuperTable::from_batches(vec![Arc::new(table(vec![fa("x", &[1, 2])]))], None);
        let rendered = t.to_string();
        assert!(rendered.starts_with("SuperTable \"batch\" [2 rows, 1 columns, 1 batches]"));
        assert!(rendered.contains("Col 0: \"x\""));
    }

    #[cfg(feature = "views")]
    #[test]
    fn test_slice_and_owned_table() {
        let batch1 = Arc::new(table(vec![fa("q", &[1, 2, 3]), fa("w", &[4, 5, 6])]));
        let batch2 = Arc::new(table(vec![fa("q", &[7, 8]), fa("w", &[9, 10])]));
        let t = SuperTable::from_batches(vec![batch1, batch2], None);

        // Rows 2..5 span the tail of batch1 and all of batch2.
        let slice = t.view(2, 3);
        assert_eq!(slice.len, 3);
        assert_eq!(slice.slices.len(), 2);

        let owned = slice.to_table(Some("part"));
        assert_eq!(owned.name, "part");
        assert_eq!(owned.n_rows, 3);
        assert_eq!(owned.cols[0].field.name, "q");
        assert_eq!(owned.cols[1].field.name, "w");
        assert_eq!(int32_values(&owned, 0), vec![3, 7, 8]);
        assert_eq!(int32_values(&owned, 1), vec![6, 9, 10]);
    }

    #[test]
    fn test_schema_and_batch_access() {
        let t = SuperTable::from_batches(vec![Arc::new(table(vec![fa("alpha", &[1, 2])]))], None);
        assert_eq!(t.n_cols(), 1);
        assert_eq!(t.schema()[0].name, "alpha");
        assert_eq!(t.cols()[0].name, "alpha");
        assert!(t.batch(0).is_some());
        assert!(t.batch(5).is_none());
        assert_eq!(t.batches().len(), 1);
    }

    #[cfg(feature = "views")]
    #[test]
    fn test_from_slices() {
        let batch1 = Arc::new(table(vec![fa("x", &[1, 2]), fa("y", &[3, 4])]));
        let batch2 = Arc::new(table(vec![fa("x", &[5, 6]), fa("y", &[7, 8])]));
        let t = SuperTable::from_batches(vec![batch1, batch2], None);

        // Split the table into single-row views, then rebuild it from them.
        let mut table_slices = Vec::new();
        for i in 0..t.len() {
            table_slices.extend(t.view(i, 1).slices);
        }

        let rebuilt = SuperTable::from_views(&table_slices, "rebuilt".to_string());

        assert_eq!(rebuilt.n_cols(), t.n_cols());
        assert_eq!(rebuilt.len(), t.len());

        for (left, right) in rebuilt.schema().iter().zip(t.schema()) {
            assert_eq!(left.name, right.name);
            assert_eq!(left.dtype, right.dtype);
        }

        let unified = rebuilt.to_table(None);
        assert_eq!(int32_values(&unified, 0), vec![1, 2, 5, 6]);
        assert_eq!(int32_values(&unified, 1), vec![3, 4, 7, 8]);
    }
}