1use arrow_array::{
7 Array, ArrayRef, BooleanArray, Float64Array, Int64Array, RecordBatch, StringArray,
8 TimestampMicrosecondArray,
9};
10use arrow_schema::{DataType, Field, Schema};
11use std::sync::Arc;
12
13
/// Raw buffer addresses describing a single column.
///
/// The struct stores plain pointers and does not own or borrow the memory they refer to,
/// so the addresses are only usable while whatever owns the underlying buffers is alive.
/// The per-field notes below describe the values written by `ColumnPtrs::from_array`.
#[derive(Debug, Clone)]
pub struct ColumnPtrs {
    /// Start of the column's value bytes, or null when no such buffer is available.
    ///
    /// `from_array` fills this with buffer 0 advanced past the array offset for
    /// `Float64`/`Int64`/`Timestamp` (8-byte elements) and `Int32`/`Float32` (4-byte
    /// elements), buffer 0 unadjusted for `Boolean`, buffer 1 for `Utf8`, and null for
    /// every other data type.
    pub values_ptr: *const u8,
    /// For `Utf8` columns, buffer 0 advanced by the array offset in 4-byte steps; null for
    /// all other data types (and when the array has no buffers).
    pub offsets_ptr: *const u8,
    /// Start of the null buffer, or null when the array data reports no nulls.
    pub validity_ptr: *const u8,
    /// Bytes per element: 8 or 4 for the fixed-width types listed on `values_ptr`, and 0
    /// for `Boolean`, `Utf8`, and unsupported types.
    pub stride: usize,
    /// Arrow data type of the column these pointers were taken from.
    pub data_type: DataType,
}
31
// SAFETY / NOTE(review): the raw-pointer fields make `ColumnPtrs` neither `Send` nor `Sync`
// by default; these impls opt back in. Nothing in the type synchronizes access or keeps the
// pointees alive, so soundness presumably relies on the pointed-to buffers being immutable
// and outliving every copy of the struct — confirm against the code that dereferences them.
unsafe impl Send for ColumnPtrs {}
unsafe impl Sync for ColumnPtrs {}
36
37impl ColumnPtrs {
38 fn from_array(array: &ArrayRef) -> Self {
40 let data = array.to_data();
41 let data_type = data.data_type().clone();
42
43 let (values_ptr, stride) = match &data_type {
45 DataType::Float64 => {
46 let ptr = if !data.buffers().is_empty() {
47 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 8)
48 } else {
49 std::ptr::null()
50 };
51 (ptr, 8)
52 }
53 DataType::Int64 | DataType::Timestamp(_, _) => {
54 let ptr = if !data.buffers().is_empty() {
55 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 8)
56 } else {
57 std::ptr::null()
58 };
59 (ptr, 8)
60 }
61 DataType::Int32 | DataType::Float32 => {
62 let ptr = if !data.buffers().is_empty() {
63 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 4)
64 } else {
65 std::ptr::null()
66 };
67 (ptr, 4)
68 }
69 DataType::Boolean => {
70 let ptr = if !data.buffers().is_empty() {
72 data.buffers()[0].as_ptr()
73 } else {
74 std::ptr::null()
75 };
76 (ptr, 0)
77 }
78 DataType::Utf8 => {
79 let ptr = if data.buffers().len() > 1 {
81 data.buffers()[1].as_ptr()
82 } else {
83 std::ptr::null()
84 };
85 (ptr, 0) }
87 _ => (std::ptr::null(), 0),
88 };
89
90 let offsets_ptr = match &data_type {
92 DataType::Utf8 => {
93 if !data.buffers().is_empty() {
94 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 4)
95 } else {
96 std::ptr::null()
97 }
98 }
99 _ => std::ptr::null(),
100 };
101
102 let validity_ptr = data
104 .nulls()
105 .map(|nulls| nulls.buffer().as_ptr())
106 .unwrap_or(std::ptr::null());
107
108 ColumnPtrs {
109 values_ptr,
110 offsets_ptr,
111 validity_ptr,
112 stride,
113 data_type,
114 }
115 }
116}
117
/// A `RecordBatch` bundled with optional metadata and a per-column pointer cache.
#[derive(Debug, Clone)]
pub struct DataTable {
    /// The wrapped Arrow record batch holding the actual column data.
    batch: RecordBatch,
    /// Optional name for the table; used as the prefix in the `Display` output when set.
    type_name: Option<String>,
    /// Optional numeric schema identifier, set through `with_schema_id`.
    schema_id: Option<u32>,
    /// One `ColumnPtrs` entry per column of `batch`, built when the table is constructed
    /// or sliced.
    column_ptrs: Vec<ColumnPtrs>,
    /// Optional column name set through `with_index_col`.
    index_col: Option<String>,
}
134
135impl DataTable {
136 fn build_column_ptrs(batch: &RecordBatch) -> Vec<ColumnPtrs> {
138 (0..batch.num_columns())
139 .map(|i| ColumnPtrs::from_array(batch.column(i)))
140 .collect()
141 }
142
143 pub fn new(batch: RecordBatch) -> Self {
145 let column_ptrs = Self::build_column_ptrs(&batch);
146 Self {
147 batch,
148 type_name: None,
149 schema_id: None,
150 column_ptrs,
151 index_col: None,
152 }
153 }
154
155 pub fn with_type_name(batch: RecordBatch, type_name: String) -> Self {
157 let column_ptrs = Self::build_column_ptrs(&batch);
158 Self {
159 batch,
160 type_name: Some(type_name),
161 schema_id: None,
162 column_ptrs,
163 index_col: None,
164 }
165 }
166
167 pub fn with_schema_id(mut self, schema_id: u32) -> Self {
169 self.schema_id = Some(schema_id);
170 self
171 }
172
173 pub fn with_index_col(mut self, name: String) -> Self {
175 self.index_col = Some(name);
176 self
177 }
178
179 pub fn schema_id(&self) -> Option<u32> {
181 self.schema_id
182 }
183
184 pub fn index_col(&self) -> Option<&str> {
186 self.index_col.as_deref()
187 }
188
189 pub fn column_ptr(&self, index: usize) -> Option<&ColumnPtrs> {
191 self.column_ptrs.get(index)
192 }
193
194 pub fn column_ptrs(&self) -> &[ColumnPtrs] {
196 &self.column_ptrs
197 }
198
199 pub fn row_count(&self) -> usize {
201 self.batch.num_rows()
202 }
203
204 pub fn column_count(&self) -> usize {
206 self.batch.num_columns()
207 }
208
209 pub fn column_names(&self) -> Vec<String> {
211 self.batch
212 .schema()
213 .fields()
214 .iter()
215 .map(|f| f.name().clone())
216 .collect()
217 }
218
219 pub fn schema(&self) -> Arc<Schema> {
221 self.batch.schema()
222 }
223
224 pub fn type_name(&self) -> Option<&str> {
226 self.type_name.as_deref()
227 }
228
229 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
231 let idx = self.batch.schema().index_of(name).ok()?;
232 Some(self.batch.column(idx))
233 }
234
235 pub fn get_f64_column(&self, name: &str) -> Option<&Float64Array> {
237 self.column_by_name(name)?
238 .as_any()
239 .downcast_ref::<Float64Array>()
240 }
241
242 pub fn get_i64_column(&self, name: &str) -> Option<&Int64Array> {
244 self.column_by_name(name)?
245 .as_any()
246 .downcast_ref::<Int64Array>()
247 }
248
249 pub fn get_string_column(&self, name: &str) -> Option<&StringArray> {
251 self.column_by_name(name)?
252 .as_any()
253 .downcast_ref::<StringArray>()
254 }
255
256 pub fn get_bool_column(&self, name: &str) -> Option<&BooleanArray> {
258 self.column_by_name(name)?
259 .as_any()
260 .downcast_ref::<BooleanArray>()
261 }
262
263 pub fn get_timestamp_column(&self, name: &str) -> Option<&TimestampMicrosecondArray> {
265 self.column_by_name(name)?
266 .as_any()
267 .downcast_ref::<TimestampMicrosecondArray>()
268 }
269
270 pub fn slice(&self, offset: usize, length: usize) -> Self {
272 let sliced = self.batch.slice(offset, length);
273 let column_ptrs = Self::build_column_ptrs(&sliced);
274 Self {
275 batch: sliced,
276 type_name: self.type_name.clone(),
277 schema_id: self.schema_id,
278 column_ptrs,
279 index_col: self.index_col.clone(),
280 }
281 }
282
283 pub fn inner(&self) -> &RecordBatch {
285 &self.batch
286 }
287
288 pub fn into_inner(self) -> RecordBatch {
290 self.batch
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.batch.num_rows() == 0
296 }
297}
298
299impl std::fmt::Display for DataTable {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 let name = self.type_name.as_deref().unwrap_or("DataTable");
302 write!(
303 f,
304 "{}({} rows x {} cols: [{}])",
305 name,
306 self.row_count(),
307 self.column_count(),
308 self.column_names().join(", "),
309 )
310 }
311}
312
313impl PartialEq for DataTable {
314 fn eq(&self, other: &Self) -> bool {
315 self.batch == other.batch
316 }
317}
318
/// Incrementally assembles a `DataTable` from a schema and a list of column arrays.
pub struct DataTableBuilder {
    /// Schema handed to `RecordBatch::try_new` when the builder is finished.
    schema: Schema,
    /// Columns in the order they were added.
    columns: Vec<ArrayRef>,
}
326
327impl DataTableBuilder {
328 pub fn new(schema: Schema) -> Self {
330 Self {
331 schema,
332 columns: Vec::new(),
333 }
334 }
335
336 pub fn with_fields(fields: Vec<Field>) -> Self {
338 Self {
339 schema: Schema::new(fields),
340 columns: Vec::new(),
341 }
342 }
343
344 pub fn add_f64_column(&mut self, values: Vec<f64>) -> &mut Self {
346 self.columns
347 .push(Arc::new(Float64Array::from(values)) as ArrayRef);
348 self
349 }
350
351 pub fn add_i64_column(&mut self, values: Vec<i64>) -> &mut Self {
353 self.columns
354 .push(Arc::new(Int64Array::from(values)) as ArrayRef);
355 self
356 }
357
358 pub fn add_string_column(&mut self, values: Vec<&str>) -> &mut Self {
360 self.columns
361 .push(Arc::new(StringArray::from(values)) as ArrayRef);
362 self
363 }
364
365 pub fn add_bool_column(&mut self, values: Vec<bool>) -> &mut Self {
367 self.columns
368 .push(Arc::new(BooleanArray::from(values)) as ArrayRef);
369 self
370 }
371
372 pub fn add_timestamp_column(&mut self, values: Vec<i64>) -> &mut Self {
374 self.columns
375 .push(Arc::new(TimestampMicrosecondArray::from(values)) as ArrayRef);
376 self
377 }
378
379 pub fn add_column(&mut self, array: ArrayRef) -> &mut Self {
381 self.columns.push(array);
382 self
383 }
384
385 pub fn finish(self) -> Result<DataTable, arrow_schema::ArrowError> {
387 let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
388 Ok(DataTable::new(batch))
389 }
390
391 pub fn finish_with_type_name(
393 self,
394 type_name: String,
395 ) -> Result<DataTable, arrow_schema::ArrowError> {
396 let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
397 Ok(DataTable::with_type_name(batch, type_name))
398 }
399
400 pub fn finish_with_schema_id(
402 self,
403 schema_id: u32,
404 ) -> Result<DataTable, arrow_schema::ArrowError> {
405 let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
406 Ok(DataTable::new(batch).with_schema_id(schema_id))
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, TimeUnit};

    /// Three non-nullable columns: `price` (f64), `volume` (i64), `symbol` (utf8).
    fn sample_schema() -> Schema {
        Schema::new(vec![
            Field::new("price", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
            Field::new("symbol", DataType::Utf8, false),
        ])
    }

    /// A three-row table matching `sample_schema`.
    fn sample_datatable() -> DataTable {
        let mut builder = DataTableBuilder::new(sample_schema());
        builder
            .add_f64_column(vec![100.0, 101.5, 99.8])
            .add_i64_column(vec![1000, 2000, 1500])
            .add_string_column(vec!["AAPL", "AAPL", "AAPL"]);
        builder.finish().unwrap()
    }

    #[test]
    fn test_creation_and_basic_accessors() {
        let dt = sample_datatable();
        assert_eq!(dt.row_count(), 3);
        assert_eq!(dt.column_count(), 3);
        assert_eq!(dt.column_names(), vec!["price", "volume", "symbol"]);
        assert!(!dt.is_empty());
    }

    #[test]
    fn test_typed_column_access() {
        let dt = sample_datatable();

        let prices = dt.get_f64_column("price").unwrap();
        assert_eq!(prices.value(0), 100.0);
        assert_eq!(prices.value(2), 99.8);

        let volumes = dt.get_i64_column("volume").unwrap();
        assert_eq!(volumes.value(1), 2000);

        let symbols = dt.get_string_column("symbol").unwrap();
        assert_eq!(symbols.value(0), "AAPL");

        // Wrong type for an existing column.
        assert!(dt.get_f64_column("symbol").is_none());
        // Column that does not exist at all.
        assert!(dt.get_f64_column("nonexistent").is_none());
    }

    #[test]
    fn test_bool_column() {
        let schema = Schema::new(vec![Field::new("flag", DataType::Boolean, false)]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_bool_column(vec![true, false, true]);
        let dt = builder.finish().unwrap();

        let flags = dt.get_bool_column("flag").unwrap();
        assert!(flags.value(0));
        assert!(!flags.value(1));
    }

    #[test]
    fn test_timestamp_column() {
        let schema = Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            false,
        )]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_timestamp_column(vec![1_000_000, 2_000_000, 3_000_000]);
        let dt = builder.finish().unwrap();

        let ts = dt.get_timestamp_column("ts").unwrap();
        assert_eq!(ts.value(0), 1_000_000);
        assert_eq!(ts.value(2), 3_000_000);
    }

    #[test]
    fn test_zero_copy_slice() {
        let dt = sample_datatable();
        let sliced = dt.slice(1, 2);

        assert_eq!(sliced.row_count(), 2);
        assert_eq!(sliced.column_count(), 3);

        let prices = sliced.get_f64_column("price").unwrap();
        assert_eq!(prices.value(0), 101.5);
        assert_eq!(prices.value(1), 99.8);
    }

    #[test]
    fn test_slice_preserves_metadata() {
        let dt = DataTable::with_type_name(sample_datatable().into_inner(), "Candle".to_string())
            .with_schema_id(7)
            .with_index_col("price".to_string());
        let sliced = dt.slice(0, 1);

        assert_eq!(sliced.type_name(), Some("Candle"));
        assert_eq!(sliced.schema_id(), Some(7));
        assert_eq!(sliced.index_col(), Some("price"));
        assert_eq!(sliced.column_ptrs().len(), 3);
    }

    #[test]
    fn test_empty_datatable() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64, false)]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_f64_column(vec![]);
        let dt = builder.finish().unwrap();

        assert!(dt.is_empty());
        assert_eq!(dt.row_count(), 0);
    }

    #[test]
    fn test_display() {
        let dt = sample_datatable();
        let s = format!("{}", dt);
        assert!(s.contains("DataTable"));
        assert!(s.contains("3 rows"));
        assert!(s.contains("price"));
    }

    #[test]
    fn test_type_name() {
        let dt = sample_datatable();
        assert!(dt.type_name().is_none());

        let schema = sample_schema();
        let mut builder = DataTableBuilder::new(schema);
        builder
            .add_f64_column(vec![1.0])
            .add_i64_column(vec![10])
            .add_string_column(vec!["X"]);
        let dt = builder.finish_with_type_name("Candle".to_string()).unwrap();
        assert_eq!(dt.type_name(), Some("Candle"));
        let s = format!("{}", dt);
        assert!(s.starts_with("Candle("));
    }

    #[test]
    fn test_builder_schema_mismatch_errors() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Float64, false),
            Field::new("b", DataType::Int64, false),
        ]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_f64_column(vec![1.0]);
        // Two fields declared, only one column supplied.
        assert!(builder.finish().is_err());
    }

    #[test]
    fn test_finish_with_schema_id() {
        let mut builder = DataTableBuilder::new(sample_schema());
        builder
            .add_f64_column(vec![1.0])
            .add_i64_column(vec![10])
            .add_string_column(vec!["X"]);
        let dt = builder.finish_with_schema_id(42).unwrap();

        assert_eq!(dt.schema_id(), Some(42));
        assert!(dt.type_name().is_none());
        assert_eq!(dt.row_count(), 1);
    }

    #[test]
    fn test_inner_and_into_inner() {
        let dt = sample_datatable();
        let batch_ref = dt.inner();
        assert_eq!(batch_ref.num_rows(), 3);

        let dt2 = sample_datatable();
        let batch = dt2.into_inner();
        assert_eq!(batch.num_rows(), 3);
    }

    #[test]
    fn test_partial_eq() {
        let dt1 = sample_datatable();
        let dt2 = sample_datatable();
        assert_eq!(dt1, dt2);

        let sliced = dt1.slice(0, 2);
        assert_ne!(sliced, dt2);
    }

    #[test]
    fn test_column_by_name() {
        let dt = sample_datatable();
        assert!(dt.column_by_name("price").is_some());
        assert!(dt.column_by_name("missing").is_none());
    }

    #[test]
    fn test_column_ptrs_constructed() {
        let dt = sample_datatable();
        assert_eq!(dt.column_ptrs().len(), 3);

        let price_ptrs = dt.column_ptr(0).unwrap();
        assert_eq!(price_ptrs.stride, 8);
        assert!(matches!(price_ptrs.data_type, DataType::Float64));
        assert!(!price_ptrs.values_ptr.is_null());

        let vol_ptrs = dt.column_ptr(1).unwrap();
        assert_eq!(vol_ptrs.stride, 8);
        assert!(matches!(vol_ptrs.data_type, DataType::Int64));

        let sym_ptrs = dt.column_ptr(2).unwrap();
        assert_eq!(sym_ptrs.stride, 0);
        assert!(matches!(sym_ptrs.data_type, DataType::Utf8));
        assert!(!sym_ptrs.offsets_ptr.is_null());
    }

    #[test]
    fn test_column_ptrs_f64_read() {
        let dt = sample_datatable();
        let ptrs = dt.column_ptr(0).unwrap();

        // SAFETY: `values_ptr` addresses the three f64 values of the `price` column built in
        // `sample_datatable`, and `dt` keeps that buffer alive for the whole block. Reads
        // stay within indices 0..3.
        unsafe {
            let f64_ptr = ptrs.values_ptr as *const f64;
            assert_eq!(*f64_ptr, 100.0);
            assert_eq!(*f64_ptr.add(1), 101.5);
            assert_eq!(*f64_ptr.add(2), 99.8);
        }
    }

    #[test]
    fn test_column_ptrs_i64_read() {
        let dt = sample_datatable();
        let ptrs = dt.column_ptr(1).unwrap();

        // SAFETY: `values_ptr` addresses the three i64 values of the `volume` column built in
        // `sample_datatable`, and `dt` keeps that buffer alive for the whole block. Reads
        // stay within indices 0..3.
        unsafe {
            let i64_ptr = ptrs.values_ptr as *const i64;
            assert_eq!(*i64_ptr, 1000);
            assert_eq!(*i64_ptr.add(1), 2000);
            assert_eq!(*i64_ptr.add(2), 1500);
        }
    }

    #[test]
    fn test_column_ptrs_follow_slice() {
        let dt = sample_datatable();
        let sliced = dt.slice(1, 2);
        let ptrs = sliced.column_ptr(0).unwrap();
        assert_eq!(ptrs.stride, 8);

        // SAFETY: the sliced table shares the original `price` buffer, which `dt` and
        // `sliced` both keep alive; the slice covers two rows, so reads stay within
        // indices 0..2 of the sliced view.
        unsafe {
            let f64_ptr = ptrs.values_ptr as *const f64;
            assert_eq!(*f64_ptr, 101.5);
            assert_eq!(*f64_ptr.add(1), 99.8);
        }
    }

    #[test]
    fn test_schema_id() {
        let dt = sample_datatable();
        assert!(dt.schema_id().is_none());

        let dt_typed = sample_datatable().with_schema_id(42);
        assert_eq!(dt_typed.schema_id(), Some(42));
    }

    #[test]
    fn test_index_col() {
        let dt = sample_datatable();
        assert!(dt.index_col().is_none());

        let indexed = sample_datatable().with_index_col("price".to_string());
        assert_eq!(indexed.index_col(), Some("price"));
    }

    #[test]
    fn test_column_ptr_out_of_bounds() {
        let dt = sample_datatable();
        assert!(dt.column_ptr(99).is_none());
    }
}