spatialbench_arrow/
trip.rs1use crate::conversions::{decimal128_array_from_iter, to_arrow_timestamp_millis};
19use crate::{DEFAULT_BATCH_SIZE, RecordBatchIterator};
20use arrow::array::{BinaryArray, Int64Array, RecordBatch, TimestampMillisecondArray};
21use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
22use geo::Geometry;
23use geozero::{CoordDimensions, ToWkb};
24use spatialbench::generators::{Trip, TripGenerator, TripGeneratorIterator};
25use std::sync::{Arc, LazyLock, Mutex};
26
27struct ThreadSafeTripGenerator {
29 generator: Mutex<TripGeneratorIterator>,
30}
31
32impl ThreadSafeTripGenerator {
33 fn new(generator: TripGenerator) -> Self {
34 Self {
35 generator: Mutex::new(generator.iter()),
36 }
37 }
38
39 fn next_batch(&self, batch_size: usize) -> Vec<Trip> {
40 let mut generator = self.generator.lock().unwrap();
41 generator.by_ref().take(batch_size).collect()
42 }
43}
44
45unsafe impl Send for ThreadSafeTripGenerator {}
47unsafe impl Sync for ThreadSafeTripGenerator {}
48
49pub struct TripArrow {
50 generator: ThreadSafeTripGenerator,
51 batch_size: usize,
52 schema: SchemaRef,
53}
54
55impl TripArrow {
56 pub fn new(generator: TripGenerator) -> Self {
57 Self {
58 generator: ThreadSafeTripGenerator::new(generator),
59 batch_size: DEFAULT_BATCH_SIZE,
60 schema: TRIP_SCHEMA.clone(),
61 }
62 }
63
64 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65 self.batch_size = batch_size;
66 self
67 }
68}
69
70impl RecordBatchIterator for TripArrow {
71 fn schema(&self) -> &SchemaRef {
72 &self.schema
73 }
74}
75
76impl Iterator for TripArrow {
77 type Item = RecordBatch;
78
79 fn next(&mut self) -> Option<Self::Item> {
80 let rows = self.generator.next_batch(self.batch_size);
82 if rows.is_empty() {
83 return None;
84 }
85
86 let t_tripkey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_tripkey));
88 let t_custkey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_custkey));
89 let t_driverkey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_driverkey));
90 let t_vehiclekey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_vehiclekey));
91 let t_pickuptime = TimestampMillisecondArray::from_iter_values(
92 rows.iter()
93 .map(|row| to_arrow_timestamp_millis(row.t_pickuptime)),
94 );
95 let t_dropofftime = TimestampMillisecondArray::from_iter_values(
96 rows.iter()
97 .map(|row| to_arrow_timestamp_millis(row.t_dropofftime)),
98 );
99 let t_fare = decimal128_array_from_iter(rows.iter().map(|row| row.t_fare));
100 let t_tip = decimal128_array_from_iter(rows.iter().map(|row| row.t_tip));
101 let t_totalamount = decimal128_array_from_iter(rows.iter().map(|row| row.t_totalamount));
102 let t_distance = decimal128_array_from_iter(rows.iter().map(|row| row.t_distance));
103 let t_pickuploc = BinaryArray::from_iter_values(rows.iter().map(|row| {
104 Geometry::Point(row.t_pickuploc)
105 .to_wkb(CoordDimensions::xy())
106 .expect("Failed to convert pickup location to WKB")
107 }));
108 let t_dropoffloc = BinaryArray::from_iter_values(rows.iter().map(|row| {
109 Geometry::Point(row.t_dropoffloc)
110 .to_wkb(CoordDimensions::xy())
111 .expect("Failed to convert dropoff location to WKB")
112 }));
113
114 let batch = RecordBatch::try_new(
115 Arc::clone(&self.schema),
116 vec![
117 Arc::new(t_tripkey),
118 Arc::new(t_custkey),
119 Arc::new(t_driverkey),
120 Arc::new(t_vehiclekey),
121 Arc::new(t_pickuptime),
122 Arc::new(t_dropofftime),
123 Arc::new(t_fare),
124 Arc::new(t_tip),
125 Arc::new(t_totalamount),
126 Arc::new(t_distance),
127 Arc::new(t_pickuploc),
128 Arc::new(t_dropoffloc),
129 ],
130 )
131 .unwrap();
132
133 Some(batch)
134 }
135}
136
137static TRIP_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(make_trip_schema);
139
140fn make_trip_schema() -> SchemaRef {
141 Arc::new(Schema::new(vec![
142 Field::new("t_tripkey", DataType::Int64, false),
143 Field::new("t_custkey", DataType::Int64, false),
144 Field::new("t_driverkey", DataType::Int64, false),
145 Field::new("t_vehiclekey", DataType::Int64, false),
146 Field::new(
147 "t_pickuptime",
148 DataType::Timestamp(TimeUnit::Millisecond, None),
149 false,
150 ),
151 Field::new(
152 "t_dropofftime",
153 DataType::Timestamp(TimeUnit::Millisecond, None),
154 false,
155 ),
156 Field::new("t_fare", DataType::Decimal128(15, 5), false),
157 Field::new("t_tip", DataType::Decimal128(15, 5), false),
158 Field::new("t_totalamount", DataType::Decimal128(15, 5), false),
159 Field::new("t_distance", DataType::Decimal128(15, 5), false),
160 Field::new("t_pickuploc", DataType::Binary, false),
161 Field::new("t_dropoffloc", DataType::Binary, false),
162 ]))
163}