use crate::data::datatable::{DataTable, DataValue};
use crate::sql::recursive_parser::SqlExpression;
use std::collections::HashMap;
use std::sync::Arc;
/// One column of a [`ComputedDataView`].
#[derive(Clone, Debug)]
pub enum ViewColumn {
    /// A column read directly from the source table.
    Original {
        /// Position of this column within each source-table row.
        source_index: usize,
        /// Name shown for this column in the view.
        name: String,
    },
    /// A column whose values are supplied up front rather than read from the source table.
    Derived {
        /// Name shown for this column in the view.
        name: String,
        /// Expression associated with this column; the view stores it but never evaluates it.
        expression: SqlExpression,
        /// One value per view row: entry `i` is returned for view row `i`.
        cached_values: Vec<DataValue>,
    },
}
/// A projection over a shared [`DataTable`]: a chosen list of columns (original or
/// derived) shown for a chosen list of source rows.
#[derive(Clone, Debug)]
pub struct ComputedDataView {
    /// Table that `Original` columns read from; shared through an `Arc`, not copied.
    source_table: Arc<DataTable>,
    /// The view's columns, in display order.
    columns: Vec<ViewColumn>,
    /// Source-table row indices in display order: view row `i` is source row `visible_rows[i]`.
    visible_rows: Vec<usize>,
}
impl ComputedDataView {
#[must_use]
pub fn new(
source_table: Arc<DataTable>,
columns: Vec<ViewColumn>,
visible_rows: Vec<usize>,
) -> Self {
Self {
source_table,
columns,
visible_rows,
}
}
#[must_use]
pub fn row_count(&self) -> usize {
self.visible_rows.len()
}
#[must_use]
pub fn column_count(&self) -> usize {
self.columns.len()
}
#[must_use]
pub fn column_names(&self) -> Vec<String> {
self.columns
.iter()
.map(|col| match col {
ViewColumn::Original { name, .. } => name.clone(),
ViewColumn::Derived { name, .. } => name.clone(),
})
.collect()
}
#[must_use]
pub fn get_value(&self, row_idx: usize, col_idx: usize) -> Option<DataValue> {
if row_idx >= self.visible_rows.len() || col_idx >= self.columns.len() {
return None;
}
match &self.columns[col_idx] {
ViewColumn::Original { source_index, .. } => {
let source_row_idx = self.visible_rows[row_idx];
self.source_table
.get_row(source_row_idx)
.and_then(|row| row.get(*source_index))
.cloned()
}
ViewColumn::Derived { cached_values, .. } => {
cached_values.get(row_idx).cloned()
}
}
}
#[must_use]
pub fn get_row_values(&self, row_idx: usize) -> Option<Vec<DataValue>> {
if row_idx >= self.visible_rows.len() {
return None;
}
let mut values = Vec::new();
for col_idx in 0..self.columns.len() {
values.push(self.get_value(row_idx, col_idx)?);
}
Some(values)
}
#[must_use]
pub fn source_table(&self) -> &Arc<DataTable> {
&self.source_table
}
#[must_use]
pub fn visible_rows(&self) -> &[usize] {
&self.visible_rows
}
#[must_use]
pub fn is_derived_column(&self, col_idx: usize) -> bool {
matches!(self.columns.get(col_idx), Some(ViewColumn::Derived { .. }))
}
#[must_use]
pub fn from_source_all_columns(source: Arc<DataTable>) -> Self {
let columns: Vec<ViewColumn> = source
.column_names()
.into_iter()
.enumerate()
.map(|(idx, name)| ViewColumn::Original {
source_index: idx,
name,
})
.collect();
let visible_rows: Vec<usize> = (0..source.row_count()).collect();
Self::new(source, columns, visible_rows)
}
#[must_use]
pub fn with_filtered_rows(mut self, row_indices: Vec<usize>) -> Self {
self.visible_rows = row_indices;
for col in &mut self.columns {
if let ViewColumn::Derived { cached_values, .. } = col {
let mut new_cache = Vec::new();
for &row_idx in &self.visible_rows {
if row_idx < cached_values.len() {
new_cache.push(cached_values[row_idx].clone());
}
}
*cached_values = new_cache;
}
}
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};
    use crate::sql::parser::ast::ColumnRef;

    /// Two-column fixture: `a` holds integers 10 and 20, `b` holds floats 2.5 and 3.5.
    fn create_test_table() -> Arc<DataTable> {
        let mut fixture = DataTable::new("test");
        for column_name in ["a", "b"] {
            fixture.add_column(DataColumn::new(column_name));
        }
        for (int_cell, float_cell) in [(10, 2.5), (20, 3.5)] {
            let row = DataRow::new(vec![DataValue::Integer(int_cell), DataValue::Float(float_cell)]);
            fixture.add_row(row).unwrap();
        }
        Arc::new(fixture)
    }

    #[test]
    fn test_original_columns_view() {
        let view = ComputedDataView::from_source_all_columns(create_test_table());

        assert_eq!(view.row_count(), 2);
        assert_eq!(view.column_count(), 2);
        assert_eq!(view.column_names(), vec!["a", "b"]);

        let expected_cells = [
            ((0, 0), DataValue::Integer(10)),
            ((0, 1), DataValue::Float(2.5)),
            ((1, 0), DataValue::Integer(20)),
        ];
        for ((row, col), value) in expected_cells {
            assert_eq!(view.get_value(row, col), Some(value));
        }
    }

    #[test]
    fn test_mixed_columns() {
        let passthrough = ViewColumn::Original {
            source_index: 0,
            name: "a".to_string(),
        };
        let doubled = ViewColumn::Derived {
            name: "doubled".to_string(),
            expression: SqlExpression::Column(ColumnRef::unquoted("a".to_string())),
            // Precomputed 2 * a for view rows 0 and 1.
            cached_values: vec![DataValue::Integer(20), DataValue::Integer(40)],
        };
        let view =
            ComputedDataView::new(create_test_table(), vec![passthrough, doubled], vec![0, 1]);

        assert_eq!(view.column_count(), 2);
        assert_eq!(view.column_names(), vec!["a", "doubled"]);
        assert_eq!(view.get_value(0, 0), Some(DataValue::Integer(10)));
        assert_eq!(view.get_value(0, 1), Some(DataValue::Integer(20)));
        assert_eq!(view.get_value(1, 1), Some(DataValue::Integer(40)));
    }

    #[test]
    fn test_filtered_rows() {
        let view = ComputedDataView::from_source_all_columns(create_test_table())
            .with_filtered_rows(vec![1]);

        assert_eq!(view.row_count(), 1);
        // View row 0 now maps to source row 1.
        assert_eq!(view.get_value(0, 0), Some(DataValue::Integer(20)));
    }
}