// oxigdal_query/executor/scan.rs
1//! Table scan executor.
2
3use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use bytes::Bytes;
6use oxigdal_core::error::OxiGdalError;
7use std::sync::Arc;
8
/// A record batch of data.
///
/// A columnar slice of table data: `columns[i]` holds the values for
/// `schema.fields[i]`, and every column contains exactly `num_rows`
/// entries. Both invariants are enforced by [`RecordBatch::new`].
#[derive(Debug, Clone)]
pub struct RecordBatch {
    /// Schema describing each column, in order.
    pub schema: Arc<Schema>,
    /// Column data, parallel to `schema.fields`.
    pub columns: Vec<ColumnData>,
    /// Number of rows; every column has this length.
    pub num_rows: usize,
}
19
20impl RecordBatch {
21    /// Create a new record batch.
22    pub fn new(schema: Arc<Schema>, columns: Vec<ColumnData>, num_rows: usize) -> Result<Self> {
23        if columns.len() != schema.fields.len() {
24            return Err(QueryError::execution(
25                OxiGdalError::invalid_state_builder("Column count does not match schema")
26                    .with_operation("record_batch_creation")
27                    .with_parameter("schema_fields", schema.fields.len().to_string())
28                    .with_parameter("actual_columns", columns.len().to_string())
29                    .with_suggestion("Ensure all schema fields have corresponding column data")
30                    .build()
31                    .to_string(),
32            ));
33        }
34
35        for (idx, column) in columns.iter().enumerate() {
36            if column.len() != num_rows {
37                return Err(QueryError::execution(
38                    OxiGdalError::invalid_state_builder("Column length mismatch in batch")
39                        .with_operation("record_batch_creation")
40                        .with_parameter("expected_rows", num_rows.to_string())
41                        .with_parameter("actual_rows", column.len().to_string())
42                        .with_parameter("column_index", idx.to_string())
43                        .with_suggestion("Ensure all columns have the same number of rows")
44                        .build()
45                        .to_string(),
46                ));
47            }
48        }
49
50        Ok(Self {
51            schema,
52            columns,
53            num_rows,
54        })
55    }
56
57    /// Get a column by index.
58    pub fn column(&self, index: usize) -> Option<&ColumnData> {
59        self.columns.get(index)
60    }
61
62    /// Get a column by name.
63    pub fn column_by_name(&self, name: &str) -> Option<&ColumnData> {
64        self.schema
65            .fields
66            .iter()
67            .position(|f| f.name == name)
68            .and_then(|idx| self.columns.get(idx))
69    }
70}
71
/// Schema definition.
///
/// An ordered list of [`Field`]s; column positions in a [`RecordBatch`]
/// correspond to field positions here.
#[derive(Debug, Clone)]
pub struct Schema {
    /// Fields in the schema, in column order.
    pub fields: Vec<Field>,
}
78
79impl Schema {
80    /// Create a new schema.
81    pub fn new(fields: Vec<Field>) -> Self {
82        Self { fields }
83    }
84
85    /// Find field by name.
86    pub fn field_with_name(&self, name: &str) -> Option<&Field> {
87        self.fields.iter().find(|f| f.name == name)
88    }
89
90    /// Get field index by name.
91    pub fn index_of(&self, name: &str) -> Option<usize> {
92        self.fields.iter().position(|f| f.name == name)
93    }
94}
95
/// Field definition.
///
/// Describes one column of a [`Schema`]: its name, value type, and
/// whether NULLs are permitted.
#[derive(Debug, Clone)]
pub struct Field {
    /// Field name, used for lookups in [`Schema::index_of`].
    pub name: String,
    /// Data type of the column's values.
    pub data_type: DataType,
    /// Whether the column may contain NULL values.
    pub nullable: bool,
}
106
107impl Field {
108    /// Create a new field.
109    pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
110        Self {
111            name,
112            data_type,
113            nullable,
114        }
115    }
116}
117
/// Data type.
///
/// The logical type of a column's values, as declared by a [`Field`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    /// Boolean.
    Boolean,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit signed integer.
    Int64,
    /// 32-bit float.
    Float32,
    /// 64-bit float.
    Float64,
    /// UTF-8 string.
    String,
    /// Binary data.
    Binary,
    /// Geometry.
    // NOTE(review): ColumnData has no Geometry variant — confirm how
    // geometry columns are materialized (possibly as Binary).
    Geometry,
}
138
/// Column data.
///
/// One vector of optional values per supported type; `None` encodes NULL.
/// NOTE(review): there is no variant for `DataType::Geometry` — confirm
/// whether geometry columns are stored via `Binary` or are unsupported here.
#[derive(Debug, Clone)]
pub enum ColumnData {
    /// Boolean column.
    Boolean(Vec<Option<bool>>),
    /// 32-bit integer column.
    Int32(Vec<Option<i32>>),
    /// 64-bit integer column.
    Int64(Vec<Option<i64>>),
    /// 32-bit float column.
    Float32(Vec<Option<f32>>),
    /// 64-bit float column.
    Float64(Vec<Option<f64>>),
    /// String column.
    String(Vec<Option<String>>),
    /// Binary column.
    Binary(Vec<Option<Bytes>>),
}
157
158impl ColumnData {
159    /// Get the length of the column.
160    pub fn len(&self) -> usize {
161        match self {
162            ColumnData::Boolean(v) => v.len(),
163            ColumnData::Int32(v) => v.len(),
164            ColumnData::Int64(v) => v.len(),
165            ColumnData::Float32(v) => v.len(),
166            ColumnData::Float64(v) => v.len(),
167            ColumnData::String(v) => v.len(),
168            ColumnData::Binary(v) => v.len(),
169        }
170    }
171
172    /// Check if the column is empty.
173    pub fn is_empty(&self) -> bool {
174        self.len() == 0
175    }
176}
177
/// Data source trait.
///
/// Abstraction over anything that can report a schema and produce record
/// batches (see [`MemoryDataSource`] for an in-memory implementation).
#[async_trait]
pub trait DataSource: Send + Sync {
    /// Get the schema of the data source.
    async fn schema(&self) -> Result<Arc<Schema>>;

    /// Scan the data source, producing all of its record batches.
    async fn scan(&self) -> Result<Vec<RecordBatch>>;
}
187
/// Table scan operator.
///
/// Reads every batch from a [`DataSource`] and optionally projects each
/// batch down to a subset of columns.
pub struct TableScan {
    /// Table name.
    // NOTE(review): not read anywhere in this file — presumably used by
    // planner/display code elsewhere; confirm before removing.
    pub table_name: String,
    /// Data source to scan.
    pub source: Arc<dyn DataSource>,
    /// Projected column indices (`None` means all columns).
    pub projection: Option<Vec<usize>>,
}
197
198impl TableScan {
199    /// Create a new table scan.
200    pub fn new(table_name: String, source: Arc<dyn DataSource>) -> Self {
201        Self {
202            table_name,
203            source,
204            projection: None,
205        }
206    }
207
208    /// Set projection.
209    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
210        self.projection = Some(projection);
211        self
212    }
213
214    /// Execute the scan.
215    pub async fn execute(&self) -> Result<Vec<RecordBatch>> {
216        let batches = self.source.scan().await?;
217
218        if let Some(ref projection) = self.projection {
219            // Apply projection
220            let mut projected_batches = Vec::new();
221            for batch in batches {
222                projected_batches.push(self.project_batch(batch, projection)?);
223            }
224            Ok(projected_batches)
225        } else {
226            Ok(batches)
227        }
228    }
229
230    /// Project a record batch.
231    fn project_batch(&self, batch: RecordBatch, projection: &[usize]) -> Result<RecordBatch> {
232        let mut projected_columns = Vec::new();
233        let mut projected_fields = Vec::new();
234
235        for &idx in projection {
236            if idx >= batch.columns.len() {
237                return Err(QueryError::execution(format!(
238                    "Column index {} out of bounds",
239                    idx
240                )));
241            }
242            projected_columns.push(batch.columns[idx].clone());
243            projected_fields.push(batch.schema.fields[idx].clone());
244        }
245
246        let projected_schema = Arc::new(Schema::new(projected_fields));
247        RecordBatch::new(projected_schema, projected_columns, batch.num_rows)
248    }
249}
250
/// In-memory data source for testing.
///
/// Holds a fixed schema and pre-materialized batches; `scan` returns
/// clones of the stored batches.
pub struct MemoryDataSource {
    /// Schema reported by [`DataSource::schema`].
    schema: Arc<Schema>,
    /// Batches returned by [`DataSource::scan`].
    batches: Vec<RecordBatch>,
}
258
259impl MemoryDataSource {
260    /// Create a new memory data source.
261    pub fn new(schema: Arc<Schema>, batches: Vec<RecordBatch>) -> Self {
262        Self { schema, batches }
263    }
264}
265
266#[async_trait]
267impl DataSource for MemoryDataSource {
268    async fn schema(&self) -> Result<Arc<Schema>> {
269        Ok(self.schema.clone())
270    }
271
272    async fn scan(&self) -> Result<Vec<RecordBatch>> {
273        Ok(self.batches.clone())
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema construction preserves field order and supports name lookup.
    #[test]
    fn test_schema_creation() {
        let fields = vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("name".to_string(), DataType::String, true),
        ];
        let schema = Schema::new(fields);

        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.index_of("id"), Some(0));
        assert_eq!(schema.index_of("name"), Some(1));
    }

    /// A well-formed batch passes validation and reports its dimensions.
    #[test]
    fn test_record_batch_creation() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Float64, true),
        ]));

        let id_col = ColumnData::Int64(vec![Some(1), Some(2), Some(3)]);
        let value_col = ColumnData::Float64(vec![Some(1.0), Some(2.0), Some(3.0)]);

        let batch = RecordBatch::new(schema, vec![id_col, value_col], 3)?;
        assert_eq!(batch.num_rows, 3);
        assert_eq!(batch.columns.len(), 2);

        Ok(())
    }

    /// The memory source reports its schema and returns the stored batches.
    #[tokio::test]
    async fn test_memory_data_source() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "id".to_string(),
            DataType::Int64,
            false,
        )]));

        let id_col = ColumnData::Int64(vec![Some(1), Some(2), Some(3)]);
        let batch = RecordBatch::new(schema.clone(), vec![id_col], 3)?;
        let source = MemoryDataSource::new(schema, vec![batch]);

        let result_schema = source.schema().await?;
        assert_eq!(result_schema.fields.len(), 1);

        let batches = source.scan().await?;
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows, 3);

        Ok(())
    }
}
333}