1use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use bytes::Bytes;
6use oxigdal_core::error::OxiGdalError;
7use std::sync::Arc;
8
/// An in-memory batch of rows stored column-by-column.
///
/// Invariants (enforced by [`RecordBatch::new`]): `columns.len()` equals
/// `schema.fields.len()`, and every column holds exactly `num_rows` values.
#[derive(Debug, Clone)]
pub struct RecordBatch {
    /// Schema describing the name and type of each column; shared cheaply via `Arc`.
    pub schema: Arc<Schema>,
    /// Column values, one entry per schema field, in schema order.
    pub columns: Vec<ColumnData>,
    /// Number of rows held by every column in this batch.
    pub num_rows: usize,
}
19
20impl RecordBatch {
21 pub fn new(schema: Arc<Schema>, columns: Vec<ColumnData>, num_rows: usize) -> Result<Self> {
23 if columns.len() != schema.fields.len() {
24 return Err(QueryError::execution(
25 OxiGdalError::invalid_state_builder("Column count does not match schema")
26 .with_operation("record_batch_creation")
27 .with_parameter("schema_fields", schema.fields.len().to_string())
28 .with_parameter("actual_columns", columns.len().to_string())
29 .with_suggestion("Ensure all schema fields have corresponding column data")
30 .build()
31 .to_string(),
32 ));
33 }
34
35 for (idx, column) in columns.iter().enumerate() {
36 if column.len() != num_rows {
37 return Err(QueryError::execution(
38 OxiGdalError::invalid_state_builder("Column length mismatch in batch")
39 .with_operation("record_batch_creation")
40 .with_parameter("expected_rows", num_rows.to_string())
41 .with_parameter("actual_rows", column.len().to_string())
42 .with_parameter("column_index", idx.to_string())
43 .with_suggestion("Ensure all columns have the same number of rows")
44 .build()
45 .to_string(),
46 ));
47 }
48 }
49
50 Ok(Self {
51 schema,
52 columns,
53 num_rows,
54 })
55 }
56
57 pub fn column(&self, index: usize) -> Option<&ColumnData> {
59 self.columns.get(index)
60 }
61
62 pub fn column_by_name(&self, name: &str) -> Option<&ColumnData> {
64 self.schema
65 .fields
66 .iter()
67 .position(|f| f.name == name)
68 .and_then(|idx| self.columns.get(idx))
69 }
70}
71
/// Ordered description of a record batch's columns.
#[derive(Debug, Clone)]
pub struct Schema {
    /// Fields in column order; positions here index into `RecordBatch::columns`.
    pub fields: Vec<Field>,
}
78
79impl Schema {
80 pub fn new(fields: Vec<Field>) -> Self {
82 Self { fields }
83 }
84
85 pub fn field_with_name(&self, name: &str) -> Option<&Field> {
87 self.fields.iter().find(|f| f.name == name)
88 }
89
90 pub fn index_of(&self, name: &str) -> Option<usize> {
92 self.fields.iter().position(|f| f.name == name)
93 }
94}
95
/// A single named, typed column description within a [`Schema`].
#[derive(Debug, Clone)]
pub struct Field {
    /// Column name used for lookups (e.g. `Schema::index_of`).
    pub name: String,
    /// Logical type of the column's values.
    pub data_type: DataType,
    /// Whether values in this column may be null (`None`).
    pub nullable: bool,
}
106
107impl Field {
108 pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
110 Self {
111 name,
112 data_type,
113 nullable,
114 }
115 }
116}
117
/// Logical value types a column may carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    /// true / false values.
    Boolean,
    /// 32-bit signed integers.
    Int32,
    /// 64-bit signed integers.
    Int64,
    /// 32-bit IEEE-754 floats.
    Float32,
    /// 64-bit IEEE-754 floats.
    Float64,
    /// UTF-8 strings.
    String,
    /// Arbitrary byte sequences.
    Binary,
    /// Geometry values. NOTE(review): `ColumnData` has no dedicated geometry
    /// variant; presumably geometries travel as `Binary` (e.g. WKB) — confirm.
    Geometry,
}
138
/// Typed column storage; each variant holds one `Option` per row, where
/// `None` represents a null value.
#[derive(Debug, Clone)]
pub enum ColumnData {
    Boolean(Vec<Option<bool>>),
    Int32(Vec<Option<i32>>),
    Int64(Vec<Option<i64>>),
    Float32(Vec<Option<f32>>),
    Float64(Vec<Option<f64>>),
    String(Vec<Option<String>>),
    // NOTE(review): also appears to serve `DataType::Geometry` columns, since
    // no geometry variant exists — confirm against producers of this type.
    Binary(Vec<Option<Bytes>>),
}
157
158impl ColumnData {
159 pub fn len(&self) -> usize {
161 match self {
162 ColumnData::Boolean(v) => v.len(),
163 ColumnData::Int32(v) => v.len(),
164 ColumnData::Int64(v) => v.len(),
165 ColumnData::Float32(v) => v.len(),
166 ColumnData::Float64(v) => v.len(),
167 ColumnData::String(v) => v.len(),
168 ColumnData::Binary(v) => v.len(),
169 }
170 }
171
172 pub fn is_empty(&self) -> bool {
174 self.len() == 0
175 }
176}
177
/// A producer of record batches that a scan can read from.
#[async_trait]
pub trait DataSource: Send + Sync {
    /// Returns the schema shared by every batch this source yields.
    async fn schema(&self) -> Result<Arc<Schema>>;

    /// Reads all data from the source as a sequence of record batches.
    async fn scan(&self) -> Result<Vec<RecordBatch>>;
}
187
/// A scan over one table, optionally projecting a subset of its columns.
pub struct TableScan {
    /// Name of the table being scanned (informational; not read by `execute`).
    pub table_name: String,
    /// The underlying batch producer.
    pub source: Arc<dyn DataSource>,
    /// Column indices to keep, in output order; `None` means all columns.
    pub projection: Option<Vec<usize>>,
}
197
198impl TableScan {
199 pub fn new(table_name: String, source: Arc<dyn DataSource>) -> Self {
201 Self {
202 table_name,
203 source,
204 projection: None,
205 }
206 }
207
208 pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
210 self.projection = Some(projection);
211 self
212 }
213
214 pub async fn execute(&self) -> Result<Vec<RecordBatch>> {
216 let batches = self.source.scan().await?;
217
218 if let Some(ref projection) = self.projection {
219 let mut projected_batches = Vec::new();
221 for batch in batches {
222 projected_batches.push(self.project_batch(batch, projection)?);
223 }
224 Ok(projected_batches)
225 } else {
226 Ok(batches)
227 }
228 }
229
230 fn project_batch(&self, batch: RecordBatch, projection: &[usize]) -> Result<RecordBatch> {
232 let mut projected_columns = Vec::new();
233 let mut projected_fields = Vec::new();
234
235 for &idx in projection {
236 if idx >= batch.columns.len() {
237 return Err(QueryError::execution(format!(
238 "Column index {} out of bounds",
239 idx
240 )));
241 }
242 projected_columns.push(batch.columns[idx].clone());
243 projected_fields.push(batch.schema.fields[idx].clone());
244 }
245
246 let projected_schema = Arc::new(Schema::new(projected_fields));
247 RecordBatch::new(projected_schema, projected_columns, batch.num_rows)
248 }
249}
250
/// A [`DataSource`] backed by batches already resident in memory.
pub struct MemoryDataSource {
    /// Schema reported for this source.
    schema: Arc<Schema>,
    /// Batches returned (cloned) on every scan.
    batches: Vec<RecordBatch>,
}
258
259impl MemoryDataSource {
260 pub fn new(schema: Arc<Schema>, batches: Vec<RecordBatch>) -> Self {
262 Self { schema, batches }
263 }
264}
265
266#[async_trait]
267impl DataSource for MemoryDataSource {
268 async fn schema(&self) -> Result<Arc<Schema>> {
269 Ok(self.schema.clone())
270 }
271
272 async fn scan(&self) -> Result<Vec<RecordBatch>> {
273 Ok(self.batches.clone())
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_schema_creation() {
        let schema = Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("name".to_string(), DataType::String, true),
        ]);

        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.index_of("id"), Some(0));
        assert_eq!(schema.index_of("name"), Some(1));
        // Unknown names resolve to None rather than panicking.
        assert_eq!(schema.index_of("missing"), None);
        assert!(schema.field_with_name("missing").is_none());
    }

    #[test]
    fn test_record_batch_creation() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Float64, true),
        ]));

        let columns = vec![
            ColumnData::Int64(vec![Some(1), Some(2), Some(3)]),
            ColumnData::Float64(vec![Some(1.0), Some(2.0), Some(3.0)]),
        ];

        let batch = RecordBatch::new(schema, columns, 3)?;
        assert_eq!(batch.num_rows, 3);
        assert_eq!(batch.columns.len(), 2);
        // Named lookup finds existing columns and rejects unknown names.
        assert!(batch.column_by_name("value").is_some());
        assert!(batch.column_by_name("nope").is_none());

        Ok(())
    }

    #[test]
    fn test_record_batch_rejects_column_count_mismatch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Float64, true),
        ]));

        // One column for a two-field schema must be rejected.
        let columns = vec![ColumnData::Int64(vec![Some(1)])];
        assert!(RecordBatch::new(schema, columns, 1).is_err());
    }

    #[test]
    fn test_record_batch_rejects_row_count_mismatch() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "id".to_string(),
            DataType::Int64,
            false,
        )]));

        // Column holds 2 rows but the batch claims 3.
        let columns = vec![ColumnData::Int64(vec![Some(1), Some(2)])];
        assert!(RecordBatch::new(schema, columns, 3).is_err());
    }

    #[tokio::test]
    async fn test_memory_data_source() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "id".to_string(),
            DataType::Int64,
            false,
        )]));

        let columns = vec![ColumnData::Int64(vec![Some(1), Some(2), Some(3)])];
        let batch = RecordBatch::new(schema.clone(), columns, 3)?;

        let source = MemoryDataSource::new(schema, vec![batch]);
        let result_schema = source.schema().await?;
        assert_eq!(result_schema.fields.len(), 1);

        let batches = source.scan().await?;
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows, 3);

        Ok(())
    }
}