use crate::expression::{FieldExpression, SelectedValue};
use crate::predicate::FilterExpression;
use crate::record_ops::FieldResolver;
use crate::{CsvSource, PivotSpec};
use csv::{Reader, StringRecord};
use proc_macro2::TokenStream;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Debug, Formatter};
use std::ops::{Bound, Deref};
use syn::{Ident, Lit, LitStr};
/// Reads and parses the CSV file named by `header.from`, resolving the path
/// relative to the directory of the source file that invoked the macro.
///
/// Returns a fully loaded [`DataSource`] (header row, field map, optional
/// pivot setup, and all records), or a spanned `syn::Error` pointing at the
/// file-name literal.
///
/// NOTE(review): `span().unwrap()` and `local_file()` are `proc_macro`
/// APIs that panic (rather than produce a compile error) when there is no
/// backing local file for the invocation span — confirm that is acceptable
/// in all build/tooling contexts where this macro runs.
pub(crate) fn query_csv(header: &CsvSource) -> syn::Result<DataSource> {
    // Resolve the CSV path relative to the invoking source file's directory.
    let path = header
        .from
        .span()
        .unwrap()
        .local_file()
        .unwrap()
        .parent()
        .unwrap()
        .join(header.from.value());
    // Surface I/O failures as a spanned error that includes the attempted path.
    let reader = csv::Reader::from_path(&path).map_err(|err| {
        let message = format!("{err}\nhint: tried to read from {}", path.display());
        syn::Error::new(header.from.span(), message)
    })?;
    DataSource::new(header.from.clone(), header.pivot.as_ref(), reader)
        .map_err(|err| syn::Error::new(header.from.span(), err))
}
/// A compiled query over a CSV source: which field expressions to select
/// (their values form each group's key) and which `any`-row filters a group
/// must satisfy to be kept.
#[derive(Default, Debug)]
pub struct Query {
    // Expressions whose resolved values form the grouping key of each result group.
    selection: Vec<FieldExpression>,
    // Predicates that must each match at least one record within a group
    // for the group to survive filtering.
    any_row: Vec<FilterExpression>,
}
impl Query {
    /// Builds a query that selects nothing and carries a single `any`-row
    /// filter.
    pub(crate) fn any(filter: FilterExpression) -> Query {
        let mut query = Query::default();
        query.any_row.push(filter);
        query
    }
}
impl Query {
    /// A shared, statically-allocated query that selects nothing and
    /// filters nothing.
    pub const EMPTY: &'static Query = &Query {
        selection: Vec::new(),
        any_row: Vec::new(),
    };

    /// Resolves every selected expression to a field index within
    /// `data_set`, short-circuiting on the first unknown field.
    fn fields(&self, data_set: &impl DataSet) -> Result<Vec<FieldIndex>, QueryError> {
        self.selection
            .iter()
            .map(|expr| data_set.get_field_index(expr.field_ident()))
            .collect()
    }
}
impl FromIterator<Query> for Query {
fn from_iter<T: IntoIterator<Item = Query>>(iter: T) -> Self {
let iter = iter.into_iter();
let mut selection = vec![];
let mut any_row = vec![];
for q in iter {
selection.extend(q.selection);
any_row.extend(q.any_row);
}
Query { selection, any_row }
}
}
/// Read-side abstraction over CSV data, implemented by both the whole
/// [`DataSource`] and a grouped view ([`DataGroup`]). The provided methods
/// implement the shared select / pivot / group / filter pipeline.
pub trait DataSet: Sized {
    /// The underlying file-backed source this set draws records from.
    fn source(&self) -> &DataSource;
    /// Views this whole set as a single [`DataGroup`].
    fn as_group(&self) -> DataGroup<'_>;
    /// Iterates the records visible to this set.
    fn record_iter(&self) -> impl Iterator<Item = &StringRecord>;
    /// Resolves a field identifier to a concrete column index or to a
    /// pivot key/value marker.
    fn get_field_index(&self, field_ident: &Ident) -> Result<FieldIndex, QueryError>;

    /// Returns the `(header text, column index)` pairs of the pivoted
    /// columns — but only when a pivot is configured *and* `fields`
    /// actually references the pivot key or value. Otherwise `Ok(None)`.
    fn get_pivot_fields(
        &self,
        fields: &[FieldIndex],
    ) -> Result<Option<Vec<(&str, usize)>>, QueryError> {
        if let Some(pivot) = &self.source().pivot_setup {
            if fields
                .iter()
                .any(|field| matches!(field, FieldIndex::PivotKey | FieldIndex::PivotValue))
            {
                let pivot_fields = pivot
                    .pivot_fields
                    .iter()
                    .map(|&idx| {
                        // Pivot indices were resolved against this same
                        // heading, so the lookup cannot fail.
                        let pivot_heading = self
                            .source()
                            .heading
                            .get(idx)
                            .expect("pivot field in heading");
                        (pivot_heading, idx)
                    })
                    .collect();
                Ok(Some(pivot_fields))
            } else {
                Ok(None)
            }
        } else {
            Ok(None)
        }
    }

    /// Groups the (optionally pivoted) records by their selection key.
    ///
    /// Pipeline:
    /// 1. filter records with the non-pivot part of `where_predicate`
    ///    (pivot-key comparisons are stripped here and pushed down to the
    ///    pivot iterator instead);
    /// 2. expand each record through the pivot iterator when
    ///    `pivot_fields` is `Some`, or pass it through unchanged;
    /// 3. group by the key built from `fields` — a `BTreeMap`, so groups
    ///    come out ordered by key;
    /// 4. drop any group for which some `query.any_row` predicate matches
    ///    none of its records.
    fn process_records_directly<'a>(
        &'a self,
        fields: &[FieldIndex],
        query: &Query,
        where_predicate: Option<&'a FilterExpression>,
        pivot_fields: Option<Vec<(&'a str, usize)>>,
    ) -> Result<std::collections::BTreeMap<Vec<SelectedValue>, Vec<&'a StringRecord>>, QueryError>
    where
        Self: FieldResolver,
    {
        use crate::pivot_iterator::UnifiedRecordIterator;
        use crate::record_ops::RecordOps;
        use std::collections::BTreeMap;
        // A comparison on the pivot key (if any) is handled by the pivot
        // iterator rather than by per-record filtering.
        let pivot_key_filter = where_predicate.and_then(|pred| self.extract_pivot_key_filter(pred));
        let filtered_records = self.record_iter().filter(|record| {
            if let Some(where_predicate) = where_predicate {
                let record_ops = RecordOps::new(record);
                // Evaluate only the predicate with pivot-key comparisons
                // removed; if nothing remains, every record passes.
                let filtered_predicate = self.exclude_pivot_predicates(where_predicate);
                if let Some(filtered_pred) = filtered_predicate {
                    record_ops.matches_predicate(&filtered_pred, self)
                } else {
                    true
                }
            } else {
                true
            }
        });
        let unified_iter = if let Some(pivot_fields) = pivot_fields {
            UnifiedRecordIterator::new_pivoted(filtered_records, pivot_fields, pivot_key_filter)
        } else {
            UnifiedRecordIterator::new_pass_through(filtered_records)
        };
        let mut subgroups = BTreeMap::new();
        for (record, pivot_info) in unified_iter {
            let key = self.build_selection_key(record, pivot_info.as_ref(), fields)?;
            let entry = subgroups.entry(key);
            let group: &mut Vec<&StringRecord> = entry.or_default();
            group.push(record);
        }
        // Keep a group only if every `any_row` predicate is satisfied by at
        // least one of its records.
        if !query.any_row.is_empty() {
            subgroups.retain(|_key, records| {
                query.any_row.iter().all(|predicate| {
                    records.iter().any(|record| {
                        let record_ops = RecordOps::new(record);
                        record_ops.matches_predicate(predicate, self)
                    })
                })
            });
        }
        Ok(subgroups)
    }

    /// Builds the grouping key for one record: plain columns are read from
    /// the record, pivot key/value come from the pivot iterator's context.
    ///
    /// # Panics
    /// Panics if `fields` contains a pivot marker but `pivot_info` is
    /// `None` — callers must pass pivot context whenever pivot fields are
    /// selected.
    fn build_selection_key(
        &self,
        record: &StringRecord,
        pivot_info: Option<&crate::pivot_iterator::PivotField>,
        fields: &[FieldIndex],
    ) -> Result<Vec<SelectedValue>, QueryError> {
        use crate::record_ops::RecordOps;
        let record_ops = RecordOps::new(record);
        let mut key = Vec::with_capacity(fields.len());
        for field_index in fields {
            let value = match field_index {
                FieldIndex::Index(_) => record_ops.get_selected_value(field_index)?,
                FieldIndex::PivotKey => {
                    let pivot = pivot_info.expect("PivotKey field requires pivot context");
                    SelectedValue(pivot.key.clone())
                }
                FieldIndex::PivotValue => {
                    let pivot = pivot_info.expect("PivotValue field requires pivot context");
                    SelectedValue(pivot.value.clone())
                }
            };
            key.push(value);
        }
        Ok(key)
    }

    /// Finds the first comparison in `predicate` whose field resolves to
    /// the pivot key, if any. The result is pushed down to the pivot
    /// iterator so only matching pivot columns are expanded.
    fn extract_pivot_key_filter<'a>(
        &self,
        predicate: &'a FilterExpression,
    ) -> Option<&'a crate::predicate::PredicateComparison> {
        self.extract_pivot_key_filter_inner(&predicate.expr)
    }

    /// Recursive worker for [`Self::extract_pivot_key_filter`]: depth-first
    /// search returning the first pivot-key comparison found.
    ///
    /// NOTE(review): extracting a pivot-key comparison out of an `Any` (OR)
    /// branch — while `exclude_pivot_predicates` also removes it from the
    /// OR — tightens the OR when it mixes pivot and non-pivot terms;
    /// confirm this asymmetry is intended.
    fn extract_pivot_key_filter_inner<'a>(
        &self,
        predicate: &'a crate::predicate::FilterExpressionInner,
    ) -> Option<&'a crate::predicate::PredicateComparison> {
        use crate::predicate::FilterExpressionInner;
        match predicate {
            FilterExpressionInner::Comparison(comparison) => {
                match self.get_field_index(&comparison.field) {
                    Ok(FieldIndex::PivotKey) => Some(comparison),
                    _ => None,
                }
            }
            FilterExpressionInner::All(all) => all
                .iter()
                .find_map(|pred| self.extract_pivot_key_filter_inner(pred)),
            FilterExpressionInner::Any(any) => any
                .iter()
                .find_map(|pred| self.extract_pivot_key_filter_inner(pred)),
        }
    }

    /// Returns `predicate` with every pivot-key comparison removed, or
    /// `None` if nothing remains. Used to produce the per-record part of
    /// the WHERE clause (the pivot-key part is handled by the iterator).
    fn exclude_pivot_predicates(&self, predicate: &FilterExpression) -> Option<FilterExpression> {
        self.exclude_pivot_predicates_inner(&predicate.expr)
            .map(|expr| FilterExpression {
                _parens: predicate._parens,
                expr,
            })
    }

    /// Recursive worker for [`Self::exclude_pivot_predicates`]. Empty
    /// `All`/`Any` lists collapse to `None`; single-element lists collapse
    /// to the bare inner expression.
    fn exclude_pivot_predicates_inner(
        &self,
        predicate: &crate::predicate::FilterExpressionInner,
    ) -> Option<crate::predicate::FilterExpressionInner> {
        use crate::predicate::FilterExpressionInner;
        use syn::Token;
        use syn::punctuated::Punctuated;
        match predicate {
            FilterExpressionInner::Comparison(comparison) => {
                match self.get_field_index(&comparison.field) {
                    // Pivot-key comparisons are dropped; everything else
                    // (including unresolvable fields) is kept as-is.
                    Ok(FieldIndex::PivotKey) => None,
                    _ => Some(FilterExpressionInner::Comparison(comparison.clone())),
                }
            }
            FilterExpressionInner::All(all) => {
                let mut filtered = Punctuated::<FilterExpressionInner, Token![&&]>::new();
                for pred in all.iter() {
                    if let Some(filtered_pred) = self.exclude_pivot_predicates_inner(pred) {
                        filtered.push(filtered_pred);
                    }
                }
                if filtered.is_empty() {
                    None
                } else if filtered.len() == 1 {
                    Some(filtered.into_iter().next().unwrap())
                } else {
                    Some(FilterExpressionInner::All(filtered))
                }
            }
            FilterExpressionInner::Any(any) => {
                let mut filtered = Punctuated::<FilterExpressionInner, Token![||]>::new();
                for pred in any.iter() {
                    if let Some(filtered_pred) = self.exclude_pivot_predicates_inner(pred) {
                        filtered.push(filtered_pred);
                    }
                }
                if filtered.is_empty() {
                    None
                } else if filtered.len() == 1 {
                    Some(filtered.into_iter().next().unwrap())
                } else {
                    Some(FilterExpressionInner::Any(filtered))
                }
            }
        }
    }

    /// Runs `query` against this set: resolves the selected fields, groups
    /// records by their selected values (via
    /// [`Self::process_records_directly`]), and renders each group's key
    /// values into tokens through the field expressions.
    fn select<'a>(
        &'a self,
        where_predicate: Option<&'a FilterExpression>,
        query: &'a Query,
    ) -> QueryResult<'a>
    where
        Self: FieldResolver,
    {
        let fields = query.fields(self)?;
        let pivot_fields = self.get_pivot_fields(&fields)?;
        let subgroups =
            self.process_records_directly(&fields, query, where_predicate, pivot_fields)?;
        let data = subgroups
            .into_iter()
            .map(|(k, v)| {
                // Render each key value through its matching selection
                // expression; a parse failure is reported with both the
                // offending value and the expression.
                let fields_tokens = k
                    .iter()
                    .zip(&query.selection)
                    .map(|(value, expr)| {
                        expr.write_output(value).map_err(|error| {
                            Box::new(QueryErrorInner::CouldntParseField {
                                error,
                                value: value.clone(),
                                expr: expr.clone(),
                            })
                        })
                    })
                    .collect::<Result<_, _>>()?;
                Ok(DataGroup {
                    source: self.source(),
                    query,
                    fields: fields_tokens,
                    records: v,
                })
            })
            .collect::<Result<_, QueryError>>()?;
        Ok(QueryData { data })
    }
}
/// A fully loaded CSV file: the heading row, a name-to-column lookup, all
/// data records, and an optional pivot configuration.
pub struct DataSource {
    // The file-name literal from the macro invocation; used for error spans.
    file_name: LitStr,
    // Column name -> 0-based column index, built from the heading row.
    field_map: HashMap<String, usize>,
    // The heading (first) row of the CSV file.
    heading: StringRecord,
    // Present when the macro invocation declared a pivot.
    pivot_setup: Option<PivotSetup>,
    // All data records (every row after the heading).
    records: Vec<StringRecord>,
}
/// A resolved pivot configuration: the synthetic key/value field names and
/// the concrete column indices to unpivot.
struct PivotSetup {
    // Identifier that selects the pivoted column's header text.
    key_field: syn::Ident,
    // Identifier that selects the pivoted column's cell value.
    value_field: syn::Ident,
    // 0-based indices of the columns covered by the pivot range.
    pivot_fields: Vec<usize>,
}
impl PivotSetup {
    /// Resolves a [`PivotSpec`] against the CSV heading row, turning the
    /// column-range bounds and field-name literals into concrete column
    /// indices.
    ///
    /// # Errors
    /// Returns a spanned `syn::Error` when a column spec names an unknown
    /// header, uses an out-of-range index, or is neither a string nor an
    /// int literal.
    fn new(
        PivotSpec {
            column_from,
            column_to,
            key_field_name,
            value_field_name,
            ..
        }: &PivotSpec,
        headers: &StringRecord,
    ) -> syn::Result<Self> {
        // Resolves a single column spec: a string literal is matched against
        // the header names; an int literal is a 0-based column index.
        let resolve_field = |field_spec: &Lit| {
            Ok(match field_spec {
                Lit::Str(lit) => headers
                    .iter()
                    .enumerate()
                    .find_map(|(idx, header)| (header == lit.value()).then_some(idx))
                    .ok_or_else(|| syn::Error::new(lit.span(), "field not found"))?,
                Lit::Int(lit) => {
                    let int = lit.base10_parse()?;
                    if int < headers.len() {
                        int
                    } else {
                        return Err(syn::Error::new(lit.span(), "field not found"));
                    }
                }
                _ => {
                    return Err(syn::Error::new(
                        field_spec.span(),
                        "Field spec can either be the index as an int or name as a string",
                    ));
                }
            })
        };
        let resolve_bound = |field_bound: &Bound<Lit>| -> Result<Bound<usize>, syn::Error> {
            Ok(match field_bound {
                Bound::Included(field) => Bound::Included(resolve_field(field)?),
                Bound::Excluded(field) => Bound::Excluded(resolve_field(field)?),
                Bound::Unbounded => Bound::Unbounded,
            })
        };
        let start_index = column_from
            .as_ref()
            .map(resolve_field)
            .transpose()?
            .unwrap_or(0);
        // Build the pivoted column range directly from the bound kind.
        // The previous normalization to an inclusive end index computed
        // `end - 1` (for excluded bounds) and `headers.len() - 1` (for
        // unbounded), which underflowed and panicked for an excluded bound
        // of column 0 or an empty heading; these now yield an empty range.
        let pivot_fields: Vec<usize> = match resolve_bound(column_to)? {
            Bound::Excluded(end) => (start_index..end).collect(),
            Bound::Included(end) => (start_index..=end).collect(),
            Bound::Unbounded => (start_index..headers.len()).collect(),
        };
        Ok(Self {
            key_field: key_field_name.clone(),
            value_field: value_field_name.clone(),
            pivot_fields,
        })
    }
}
impl DataSource {
    /// Builds a `DataSource` by reading the heading row and all records
    /// from `reader`, resolving the optional pivot spec against the
    /// heading.
    pub(crate) fn new<R: std::io::Read>(
        file_name: LitStr,
        pivot_spec: Option<&PivotSpec>,
        mut reader: Reader<R>,
    ) -> Result<DataSource, Box<dyn Error>> {
        let heading = reader.headers()?.clone();
        // Map each header name to its 0-based column index.
        let mut field_map = HashMap::with_capacity(heading.len());
        for (index, header) in heading.iter().enumerate() {
            field_map.insert(header.to_string(), index);
        }
        let pivot_setup = match pivot_spec {
            Some(pivot) => Some(PivotSetup::new(pivot, &heading)?),
            None => None,
        };
        let records = reader.into_records().collect::<Result<_, _>>()?;
        Ok(Self {
            file_name,
            field_map,
            heading,
            pivot_setup,
            records,
        })
    }
}
/// The resolution of a field identifier against a data source.
pub enum FieldIndex {
    // A real CSV column, by 0-based index.
    Index(usize),
    // The synthetic pivot-key field (the pivoted column's header text).
    PivotKey,
    // The synthetic pivot-value field (the pivoted column's cell value).
    PivotValue,
}
impl FieldResolver for DataSource {
    /// Resolves `field_ident` to a real column, or to the pivot key/value
    /// markers when a pivot is configured; otherwise reports `NoField`.
    fn get_field_index(&self, field_ident: &Ident) -> Result<FieldIndex, QueryError> {
        if let Some(&index) = self.field_map.get(&field_ident.to_string()) {
            return Ok(FieldIndex::Index(index));
        }
        if let Some(pivot_setup) = self.pivot_setup.as_ref() {
            if pivot_setup.key_field == *field_ident {
                return Ok(FieldIndex::PivotKey);
            }
            if pivot_setup.value_field == *field_ident {
                return Ok(FieldIndex::PivotValue);
            }
        }
        Err(QueryErrorInner::NoField {
            field: field_ident.clone(),
            file: self.file_name.clone(),
        }
        .into())
    }
}
impl DataSet for DataSource {
    /// A source is its own backing source.
    fn source(&self) -> &DataSource {
        self
    }

    /// Views the whole source as one group with an empty query.
    fn as_group(&self) -> DataGroup<'_> {
        DataGroup {
            source: self,
            query: Query::EMPTY,
            fields: Vec::new(),
            records: Vec::new(),
        }
    }

    /// Iterates every record in the file.
    fn record_iter(&self) -> impl Iterator<Item = &StringRecord> {
        self.records.iter()
    }

    /// Delegates to the `FieldResolver` implementation.
    fn get_field_index(&self, field_ident: &Ident) -> Result<FieldIndex, QueryError> {
        FieldResolver::get_field_index(self, field_ident)
    }
}
/// One result group of a query: the rendered key tokens plus the records
/// that share that key, borrowing from the originating source and query.
#[derive(Clone)]
pub struct DataGroup<'a> {
    // The source the records were drawn from.
    source: &'a DataSource,
    // The query whose selection produced this group's key.
    query: &'a Query,
    // Rendered tokens, one per selected expression, in selection order.
    fields: Vec<TokenStream>,
    // The records belonging to this group.
    records: Vec<&'a StringRecord>,
}
impl<'a> DataGroup<'a> {
    /// Returns the rendered tokens for `expr`, which must be one of the
    /// query's selected expressions.
    ///
    /// # Panics
    /// Panics if `expr` is not part of the query's selection.
    pub(crate) fn get_field(&self, expr: &FieldExpression) -> TokenStream {
        match self.query.selection.iter().position(|x| x == expr) {
            Some(index) => self.fields[index].clone(),
            None => panic!("expression {expr:?} should be in query : something went wrong"),
        }
    }
}
impl<'a> FieldResolver for DataGroup<'a> {
    /// Delegates field resolution to the backing source (qualified call to
    /// disambiguate from `DataSet::get_field_index`).
    fn get_field_index(&self, field_ident: &Ident) -> Result<FieldIndex, QueryError> {
        FieldResolver::get_field_index(self.source, field_ident)
    }
}
impl<'a> DataSet for DataGroup<'a> {
    /// The source the group's records were drawn from.
    fn source(&self) -> &DataSource {
        self.source
    }

    /// A group is already a group; hand back a clone of itself.
    fn as_group(&self) -> DataGroup<'_> {
        self.clone()
    }

    /// Iterates only the records belonging to this group.
    fn record_iter(&self) -> impl Iterator<Item = &StringRecord> {
        self.records.iter().copied()
    }

    /// Delegates to the `FieldResolver` implementation.
    fn get_field_index(&self, field_ident: &Ident) -> Result<FieldIndex, QueryError> {
        FieldResolver::get_field_index(self, field_ident)
    }
}
/// The full result of a `select`: every group, ordered by key.
pub struct QueryData<'a> {
    pub data: Vec<DataGroup<'a>>,
}
// Shorthand for the result of `DataSet::select`.
type QueryResult<'a> = Result<QueryData<'a>, QueryError>;
// Boxed to keep `Result` payloads small; `QueryErrorInner` carries spans.
pub type QueryError = Box<QueryErrorInner>;
/// The ways a query against a CSV source can fail.
pub enum QueryErrorInner {
    /// A selected or filtered field name does not exist in the file.
    NoField {
        field: syn::Ident,
        file: syn::LitStr,
    },
    /// The CSV file itself could not be used.
    InvalidSource {
        file: LitStr,
        message: String,
    },
    /// A cell value could not be parsed into the expression's target type.
    CouldntParseField {
        error: syn::Error,
        value: SelectedValue,
        expr: FieldExpression,
    },
}
impl QueryErrorInner {
    /// Renders this error as a `compile_error!` token stream anchored at
    /// the most relevant span (the field, the file literal, or the
    /// expression's field).
    pub(crate) fn to_compile_error(&self) -> TokenStream {
        let error = match self {
            QueryErrorInner::NoField { field, file } => syn::Error::new(
                field.span(),
                format_args!(
                    "Couldn't find the field named {} in file {}",
                    field,
                    file.value()
                ),
            ),
            QueryErrorInner::InvalidSource { file, message } => syn::Error::new(
                file.span(),
                format_args!("Source file {} invalid: {}", file.value(), message),
            ),
            QueryErrorInner::CouldntParseField { error, value, expr } => syn::Error::new(
                expr.field.span(),
                format_args!(
                    "Unable to parse value '{}' into type {:?}: {error}",
                    value.0, expr.syntax_type
                ),
            ),
        };
        error.into_compile_error()
    }
}
impl Debug for QueryError {
    /// Debug-formats the boxed error by variant, mirroring the messages of
    /// `to_compile_error` without the spans.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let inner = self.deref();
        match inner {
            QueryErrorInner::CouldntParseField { error, value, expr } => {
                write!(
                    f,
                    "Unable to parse value '{}' into type {:?} : {error}",
                    value.0, expr.syntax_type
                )
            }
            QueryErrorInner::NoField { field, file } => {
                write!(f, "NoField {} in file {}", field, file.value())
            }
            QueryErrorInner::InvalidSource { file, message } => {
                write!(f, "InvalidSource {}: {}", file.value(), message)
            }
        }
    }
}
impl FieldExpression {
pub(crate) fn to_query(&self) -> Query {
Query {
selection: vec![self.clone()],
any_row: vec![],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{RangeLimits, parse_quote};

    /// Selecting one column groups rows by that column's value; groups are
    /// returned in key order (BTreeMap), not file order.
    #[test]
    fn test_single_column() {
        let data = "compound
co2
n2o
ch4
";
        let source = DataSource::new(
            parse_quote!("test.csv"),
            None,
            Reader::from_reader(data.as_bytes()),
        )
        .unwrap();
        let expression = FieldExpression::new(parse_quote!(compound));
        let query = expression.to_query();
        let result = source.select(None, &query).unwrap().data;
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].fields.len(), 1);
        // Sorted by key: ch4 < co2 < n2o.
        assert_eq!(result[0].fields[0].to_string().as_str(), "ch4");
        assert_eq!(result[1].fields[0].to_string().as_str(), "co2");
        assert_eq!(result[2].fields[0].to_string().as_str(), "n2o");
    }

    // Shared fixture: HTTP status/client/message rows, including values with
    // braces and markup to exercise token rendering.
    const STATUS_CODE_CLIENT_MESSAGE_CSV: &str = r#"status_code,client_type,message
200,web,Page loaded successfully
200,api,Request completed successfully
200,mobile,Data retrieved
400,web,Please check your input and try again
400,api,Bad Request: {field} is invalid
401,web,Please log in to continue
401,api,Authentication required: Bearer token expired
401,mobile,Session expired - please sign in
404,web,Page not found - <a href="/">Go home</a>
404,api,Resource not found: {endpoint}
404,mobile,Content not available
422,api,Validation failed: {errors}
429,api,Rate limit exceeded: {limit} requests per {window}
500,web,Something went wrong - we're working on it
500,api,Internal server error (ID: {error_id})
500,mobile,Server error - please try again later"#;

    /// Independent single-column queries over the same source produce their
    /// own groupings (7 status codes, 3 client types), each key-sorted.
    #[test]
    fn test_multi_group() {
        let source = DataSource::new(
            parse_quote!("test.csv"),
            None,
            Reader::from_reader(STATUS_CODE_CLIENT_MESSAGE_CSV.as_bytes()),
        )
        .unwrap();
        let expression_status_code = FieldExpression::new(parse_quote!(status_code));
        let expression_client = FieldExpression::new(parse_quote!(client_type));
        let query = expression_status_code.to_query();
        let status_groups = source.select(None, &query).unwrap().data;
        assert_eq!(status_groups.len(), 7);
        assert_eq!(status_groups[0].fields.len(), 1);
        assert_eq!(status_groups[0].fields[0].to_string().as_str(), "200");
        assert_eq!(status_groups[1].fields[0].to_string().as_str(), "400");
        assert_eq!(status_groups[2].fields[0].to_string().as_str(), "401");
        let query1 = expression_client.to_query();
        let group_200_client_groups = source.select(None, &query1).unwrap().data;
        assert_eq!(group_200_client_groups.len(), 3);
        assert_eq!(group_200_client_groups[0].fields.len(), 1);
        assert_eq!(
            group_200_client_groups[0].fields[0].to_string().as_str(),
            "api"
        );
        assert_eq!(
            group_200_client_groups[1].fields[0].to_string().as_str(),
            "mobile"
        );
        assert_eq!(
            group_200_client_groups[2].fields[0].to_string().as_str(),
            "web"
        );
    }

    /// Selecting two columns groups by the composite key; each distinct
    /// (status_code, client_type) pair is its own group.
    #[test]
    fn test_multi_select() {
        let source = DataSource::new(
            parse_quote!("test.csv"),
            None,
            Reader::from_reader(STATUS_CODE_CLIENT_MESSAGE_CSV.as_bytes()),
        )
        .unwrap();
        let expression_status_code = FieldExpression::new(parse_quote!(status_code));
        let expression_client = FieldExpression::new(parse_quote!(client_type));
        let query = Query {
            selection: vec![expression_status_code, expression_client],
            any_row: vec![],
        };
        let groups = source.select(None, &query).unwrap().data;
        assert_eq!(groups.len(), 16);
        assert_eq!(groups[0].fields.len(), 2);
        assert_eq!(groups[0].fields[0].to_string().as_str(), "200");
        assert_eq!(groups[0].fields[1].to_string().as_str(), "api");
        assert_eq!(groups[1].fields[0].to_string().as_str(), "200");
        assert_eq!(groups[1].fields[1].to_string().as_str(), "mobile");
    }

    /// An unbounded pivot over a 3-column, 2-row file yields one
    /// (key, value) group per cell: 3 columns x 2 rows = 6 groups, ordered
    /// by column header then by value.
    #[test]
    fn test_pivot() {
        let source = DataSource::new(
            parse_quote!("test.csv"),
            Some(&PivotSpec {
                _kw: Default::default(),
                _parens: Default::default(),
                column_from: None,
                _range_limits: RangeLimits::HalfOpen(Default::default()),
                column_to: Bound::Unbounded,
                key_field_name: parse_quote!(key),
                value_field_name: parse_quote!(value),
            }),
            Reader::from_reader("a,b,c\n1,2,3\nx,y,z".as_bytes()),
        )
        .unwrap();
        let expression_key = FieldExpression::new(parse_quote!(key));
        let expression_value = FieldExpression::new(parse_quote!(value));
        let query = Query {
            selection: vec![expression_key, expression_value],
            any_row: vec![],
        };
        let groups = source.select(None, &query).unwrap().data;
        assert_eq!(groups.len(), 6);
        assert_eq!(groups[0].fields.len(), 2);
        assert_eq!(groups[0].fields[0].to_string().as_str(), "a");
        assert_eq!(groups[0].fields[1].to_string().as_str(), "1");
        assert_eq!(groups[1].fields[0].to_string().as_str(), "a");
        assert_eq!(groups[1].fields[1].to_string().as_str(), "x");
        assert_eq!(groups[2].fields[0].to_string().as_str(), "b");
        assert_eq!(groups[2].fields[1].to_string().as_str(), "2");
        assert_eq!(groups[3].fields[0].to_string().as_str(), "b");
        assert_eq!(groups[3].fields[1].to_string().as_str(), "y");
        assert_eq!(groups[4].fields[0].to_string().as_str(), "c");
        assert_eq!(groups[4].fields[1].to_string().as_str(), "3");
        assert_eq!(groups[5].fields[0].to_string().as_str(), "c");
        assert_eq!(groups[5].fields[1].to_string().as_str(), "z");
    }

    /// `any_row` keeps only groups where at least one record matches the
    /// predicate; non-matching records in a kept group are retained.
    #[test]
    fn test_any_row_filtering() {
        let data = "department,employee_id,name,has_security_clearance,salary
security,emp001,Alice,true,75000
security,emp002,Bob,false,65000
marketing,emp003,Carol,false,60000
marketing,emp004,Dave,false,58000
engineering,emp005,Eve,true,80000
engineering,emp006,Frank,true,82000
hr,emp007,Grace,false,55000";
        let source = DataSource::new(
            parse_quote!("departments.csv"),
            None,
            Reader::from_reader(data.as_bytes()),
        )
        .unwrap();
        let department_expr = FieldExpression::new(parse_quote!(department));
        let security_filter: crate::predicate::FilterExpression =
            parse_quote!((has_security_clearance == true));
        let query = Query {
            selection: vec![department_expr],
            any_row: vec![security_filter],
        };
        let groups = source.select(None, &query).unwrap().data;
        assert_eq!(
            groups.len(),
            2,
            "Should only have 2 departments with cleared employees"
        );
        assert_eq!(groups[0].fields[0].to_string().as_str(), "engineering");
        assert_eq!(groups[1].fields[0].to_string().as_str(), "security");
        // Whole groups are kept, including their non-matching members (Bob).
        assert_eq!(groups[0].records.len(), 2);
        assert_eq!(groups[1].records.len(), 2);
    }
}