use crate::hyperql::{CompilerError, CompilerResult, QueryPattern};
use audb::model::Query;
use proc_macro2::TokenStream;
use quote::quote;
pub struct CodeGenerator;
impl CodeGenerator {
pub fn new() -> Self {
Self
}
pub fn generate(&self, pattern: &QueryPattern, query: &Query) -> CompilerResult<TokenStream> {
match pattern {
QueryPattern::PointQuery { table, id_param } => {
self.generate_point_query(table, id_param, query)
}
QueryPattern::FilterQuery { table, filters } => {
self.generate_filter_query(table, filters, query)
}
QueryPattern::ProjectionQuery {
table,
fields,
filters,
} => self.generate_projection_query(table, fields, filters, query),
QueryPattern::OrderedQuery {
table,
filters,
order_by,
limit,
offset,
} => self.generate_ordered_query(table, filters, order_by, *limit, *offset, query),
QueryPattern::RelationshipQuery {
table,
traverse,
filters,
} => self.generate_relationship_query(table, traverse, filters, query),
QueryPattern::AggregationQuery {
table,
aggregates,
filters,
group_by,
} => self.generate_aggregation_query(table, aggregates, filters, group_by, query),
QueryPattern::ComplexQuery { ast: _ } => Err(CompilerError::UnsupportedPattern(
"Complex queries not yet supported".to_string(),
)),
}
}
fn generate_point_query(
&self,
table: &str,
id_param: &str,
query: &Query,
) -> CompilerResult<TokenStream> {
use crate::hyperql::TypeMapper;
use quote::format_ident;
let type_mapper = TypeMapper::new();
let schema_type_str = type_mapper.table_to_schema_type(table);
let schema_type_str = if schema_type_str.ends_with('s') {
&schema_type_str[..schema_type_str.len() - 1]
} else {
&schema_type_str
};
let schema_type = format_ident!("{}", schema_type_str);
let param_ident = format_ident!("{}", id_param);
let is_optional = matches!(query.return_type, audb::schema::Type::Option(_));
if is_optional {
Ok(quote! {
#schema_type::get(db, #param_ident)
})
} else {
Ok(quote! {
#schema_type::get(db, #param_ident)?
.ok_or_else(|| QueryError::RowNotFound)
})
}
}
fn generate_filter_query(
&self,
table: &str,
filters: &[crate::hyperql::FilterCondition],
query: &Query,
) -> CompilerResult<TokenStream> {
use crate::hyperql::{FilterCompiler, TypeMapper};
use quote::format_ident;
let type_mapper = TypeMapper::new();
let schema_type_str = type_mapper.table_to_schema_type(table);
let schema_type_str = if schema_type_str.ends_with('s') {
&schema_type_str[..schema_type_str.len() - 1]
} else {
&schema_type_str
};
let schema_type = format_ident!("{}", schema_type_str);
let filter_compiler = FilterCompiler::new();
let filter_expr = filter_compiler.compile(filters)?;
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut results = Vec::new();
for entity in all_entities {
if #filter_expr {
results.push(entity);
}
}
Ok(results)
})
}
fn generate_projection_query(
&self,
table: &str,
fields: &[String],
filters: &[crate::hyperql::FilterCondition],
query: &Query,
) -> CompilerResult<TokenStream> {
use crate::hyperql::{FilterCompiler, TypeMapper};
use quote::format_ident;
let type_mapper = TypeMapper::new();
let schema_type_str = type_mapper.table_to_schema_type(table);
let schema_type_str = if schema_type_str.ends_with('s') {
&schema_type_str[..schema_type_str.len() - 1]
} else {
&schema_type_str
};
let schema_type = format_ident!("{}", schema_type_str);
let projection_type = format_ident!(
"{}Result",
query
.name
.chars()
.next()
.unwrap()
.to_uppercase()
.to_string()
+ &query.name[1..]
);
let field_idents: Vec<_> = fields.iter().map(|f| format_ident!("{}", f)).collect();
let has_filters = !filters.is_empty();
let filter_expr = if has_filters {
let filter_compiler = FilterCompiler::new();
filter_compiler.compile(filters)?
} else {
quote! { true }
};
let projection_struct = quote! {
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct #projection_type {
#(pub #field_idents: String),*
}
};
Ok(quote! {
#projection_struct
let all_entities = #schema_type::list_all(db)?;
let mut results = Vec::new();
for entity in all_entities {
if #filter_expr {
results.push(#projection_type {
#(#field_idents: entity.#field_idents.to_string()),*
});
}
}
Ok(results)
})
}
fn generate_ordered_query(
&self,
table: &str,
filters: &[crate::hyperql::FilterCondition],
order_by: &[crate::hyperql::OrderByClause],
limit: Option<u64>,
offset: Option<u64>,
_query: &Query,
) -> CompilerResult<TokenStream> {
use crate::hyperql::{FilterCompiler, TypeMapper};
use quote::format_ident;
let type_mapper = TypeMapper::new();
let schema_type_str = type_mapper.table_to_schema_type(table);
let schema_type_str = if schema_type_str.ends_with('s') {
&schema_type_str[..schema_type_str.len() - 1]
} else {
&schema_type_str
};
let schema_type = format_ident!("{}", schema_type_str);
let has_filters = !filters.is_empty();
let filter_expr = if has_filters {
let filter_compiler = FilterCompiler::new();
filter_compiler.compile(filters)?
} else {
quote! { true }
};
let filter_code = if has_filters {
quote! {
for entity in all_entities {
if #filter_expr {
results.push(entity);
}
}
}
} else {
quote! {
results = all_entities;
}
};
let sort_code = if !order_by.is_empty() {
let order = &order_by[0];
let field_ident = format_ident!("{}", order.field);
if order.descending {
quote! {
results.sort_by(|a, b| b.#field_ident.cmp(&a.#field_ident));
}
} else {
quote! {
results.sort_by(|a, b| a.#field_ident.cmp(&b.#field_ident));
}
}
} else {
quote! {}
};
let limit_code = if let Some(lim) = limit {
if let Some(off) = offset {
quote! {
let start = #off as usize;
let end = (start + #lim as usize).min(results.len());
results = results[start..end].to_vec();
}
} else {
quote! {
results.truncate(#lim as usize);
}
}
} else if let Some(off) = offset {
quote! {
let start = #off as usize;
if start < results.len() {
results = results[start..].to_vec();
} else {
results.clear();
}
}
} else {
quote! {}
};
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut results = Vec::new();
#filter_code
#sort_code
#limit_code
Ok(results)
})
}
fn generate_relationship_query(
&self,
table: &str,
traverse: &crate::hyperql::TraverseInfo,
filters: &[crate::hyperql::FilterCondition],
query: &Query,
) -> CompilerResult<TokenStream> {
use crate::hyperql::{FilterCompiler, TypeMapper};
use quote::format_ident;
let type_mapper = TypeMapper::new();
let schema_type_str = type_mapper.table_to_schema_type(table);
let schema_type_str = if schema_type_str.ends_with('s') {
&schema_type_str[..schema_type_str.len() - 1]
} else {
&schema_type_str
};
let schema_type = format_ident!("{}", schema_type_str);
let target_type_str = type_mapper.table_to_schema_type(&traverse.target_table);
let target_type_str = if target_type_str.ends_with('s') {
&target_type_str[..target_type_str.len() - 1]
} else {
&target_type_str
};
let target_type = format_ident!("{}", target_type_str);
let has_filters = !filters.is_empty();
let filter_expr = if has_filters {
let filter_compiler = FilterCompiler::new();
filter_compiler.compile(filters)?
} else {
quote! { true }
};
let relationship_method = format_ident!("{}", traverse.relationship);
if has_filters {
Ok(quote! {
let all_source = #schema_type::list_all(db)?;
let mut results = Vec::new();
for source in all_source {
let related = source.#relationship_method(db)?;
for entity in related {
if #filter_expr {
results.push(entity);
}
}
}
Ok(results)
})
} else {
Ok(quote! {
let all_source = #schema_type::list_all(db)?;
let mut results = Vec::new();
for source in all_source {
let mut related = source.#relationship_method(db)?;
results.append(&mut related);
}
Ok(results)
})
}
}
fn generate_aggregation_query(
&self,
table: &str,
aggregates: &[crate::hyperql::AggregateFunction],
filters: &[crate::hyperql::FilterCondition],
_group_by: &[String],
_query: &Query,
) -> CompilerResult<TokenStream> {
use crate::hyperql::{AggregateFunction, FilterCompiler, TypeMapper};
use quote::format_ident;
let type_mapper = TypeMapper::new();
let schema_type_str = type_mapper.table_to_schema_type(table);
let schema_type_str = if schema_type_str.ends_with('s') {
&schema_type_str[..schema_type_str.len() - 1]
} else {
&schema_type_str
};
let schema_type = format_ident!("{}", schema_type_str);
let has_filters = !filters.is_empty();
let filter_expr = if has_filters {
let filter_compiler = FilterCompiler::new();
filter_compiler.compile(filters)?
} else {
quote! { true }
};
if aggregates.len() != 1 {
return Err(CompilerError::UnsupportedPattern(
"Multiple aggregate functions not yet supported".to_string(),
));
}
let agg = &aggregates[0];
match agg {
AggregateFunction::Count => {
if has_filters {
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut count = 0i64;
for entity in all_entities {
if #filter_expr {
count += 1;
}
}
Ok(count)
})
} else {
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
Ok(all_entities.len() as i64)
})
}
}
AggregateFunction::Sum(field) => {
let field_ident = format_ident!("{}", field);
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut sum = 0i64;
for entity in all_entities {
if #filter_expr {
sum += entity.#field_ident as i64;
}
}
Ok(sum)
})
}
AggregateFunction::Avg(field) => {
let field_ident = format_ident!("{}", field);
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut sum = 0f64;
let mut count = 0i64;
for entity in all_entities {
if #filter_expr {
sum += entity.#field_ident as f64;
count += 1;
}
}
let avg = if count > 0 { sum / count as f64 } else { 0.0 };
Ok(avg)
})
}
AggregateFunction::Min(field) => {
let field_ident = format_ident!("{}", field);
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut min_val: Option<i64> = None;
for entity in all_entities {
if #filter_expr {
let val = entity.#field_ident as i64;
min_val = Some(match min_val {
None => val,
Some(current) => current.min(val),
});
}
}
Ok(min_val.unwrap_or(0))
})
}
AggregateFunction::Max(field) => {
let field_ident = format_ident!("{}", field);
Ok(quote! {
let all_entities = #schema_type::list_all(db)?;
let mut max_val: Option<i64> = None;
for entity in all_entities {
if #filter_expr {
let val = entity.#field_ident as i64;
max_val = Some(match max_val {
None => val,
Some(current) => current.max(val),
});
}
}
Ok(max_val.unwrap_or(0))
})
}
}
}
}
impl Default for CodeGenerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both constructors must be callable and produce a generator without panicking.
    #[test]
    fn test_code_generator_creation() {
        let _from_new = CodeGenerator::new();
        let _from_default = CodeGenerator::default();
    }
}