use crate::hyperql::HyperQLCompiler;
use audb::model::query::{Query, QueryLanguage};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
/// Generates async Rust wrapper functions (as `TokenStream`s) for declared queries.
pub struct QueryGenerator {
/// Name of the database type taken by reference as the first (`db`) argument of every
/// generated function. Defaults to `"Database"`. It is turned into an identifier with
/// `format_ident!`, so it must be a single valid identifier (not a path).
pub database_type: String,
/// Compiler used to turn HyperQL queries into Rust code; when it fails, generation falls
/// back to executing the query text at runtime via `db.execute_hyperql`.
hyperql_compiler: HyperQLCompiler,
}
impl QueryGenerator {
    /// Creates a generator whose functions take `db: &Database`.
    pub fn new() -> Self {
        Self {
            database_type: "Database".to_string(),
            hyperql_compiler: HyperQLCompiler::new(),
        }
    }

    /// Emits the header for a generated queries module.
    ///
    /// The `#![allow(...)]` attributes are inner attributes, which are only legal at the
    /// very start of a module. This must therefore be emitted first, as
    /// [`Self::generate_all`] does.
    pub fn generate_prelude(&self) -> TokenStream {
        quote! {
            #![allow(dead_code)]
            #![allow(unused_assignments)]
            // `QueryError` (and `Database`, when a custom `database_type` is configured)
            // is not necessarily referenced by every set of generated queries.
            #![allow(unused_imports)]
            use audb_runtime::{Database, Result, QueryError};
            use crate::generated::schemas::*;
        }
    }

    /// Generates a `pub async fn` wrapper for a single query.
    ///
    /// The function is named after `query.name`. It takes `db: &<database_type>` followed
    /// by the query's parameters in declaration order and returns `Result<return_type>`.
    /// Each line of `query.doc_comment` becomes a `#[doc = "..."]` attribute.
    ///
    /// # Panics
    ///
    /// Panics if the query name, a parameter name, a named type or `database_type` is
    /// not a valid identifier. It also panics if a native query's source is malformed
    /// (see [`Self::generate_query_body`]).
    pub fn generate(&self, query: &Query) -> TokenStream {
        let fn_name = format_ident!("{}", query.name);
        // `None` yields no attributes; `Some` yields one attribute per line.
        let doc_attrs: Vec<TokenStream> = query
            .doc_comment
            .iter()
            .flat_map(|doc| doc.lines())
            .map(|line| quote! { #[doc = #line] })
            .collect();
        let params: Vec<TokenStream> = query
            .params
            .iter()
            .map(|p| {
                let param_name = format_ident!("{}", p.name);
                let param_type = self.convert_type(&p.param_type);
                quote! { #param_name: #param_type }
            })
            .collect();
        let db_type = format_ident!("{}", self.database_type);
        let return_type = self.convert_type(&query.return_type);
        let query_body = self.generate_query_body(query);
        // `#(, #params)*` emits `, name: Type` per parameter, or nothing when there are none.
        quote! {
            #(#doc_attrs)*
            pub async fn #fn_name(db: &#db_type #(, #params)*) -> Result<#return_type> {
                #query_body
            }
        }
    }

    /// Generates the body of a query function, dispatching on the query language.
    ///
    /// * `Native`: `source` must be `Schema::method`; emits `Ok(Schema::method(db, ..)?)`.
    /// * `HyperQL`: compiled with the HyperQL compiler. On failure, a warning goes to
    ///   stderr and the body falls back to `db.execute_hyperql(..)` with the
    ///   parameters bound by name.
    /// * `SQL`: runs `db.execute_sql(..)` with parameters bound to `$1`, `$2`, ...
    /// * `Cypher`: runs `db.execute_cypher(..)` with parameters passed by name.
    /// * `Custom`: runs `db.execute_custom(lang, ..)`.
    ///
    /// # Panics
    ///
    /// Panics if a native query source is not of the form `Schema::method`.
    fn generate_query_body(&self, query: &Query) -> TokenStream {
        let source = &query.source;
        match query.language {
            QueryLanguage::Native => self.generate_native_call(query),
            QueryLanguage::HyperQL => match self.hyperql_compiler.compile(query) {
                Ok(compiled_code) => compiled_code,
                Err(e) => {
                    eprintln!(
                        "Warning: HyperQL compilation failed for '{}': {:?}",
                        query.name, e
                    );
                    eprintln!("Falling back to runtime execution");
                    let param_bindings = self.generate_hyperql_bindings(query);
                    quote! {
                        let query = #source;
                        #param_bindings
                        db.execute_hyperql(query).await
                    }
                }
            },
            QueryLanguage::SQL => {
                let param_bindings = self.generate_sql_bindings(query);
                quote! {
                    let query = #source;
                    #param_bindings
                    db.execute_sql(query).await
                }
            }
            QueryLanguage::Cypher => {
                let param_bindings = self.generate_cypher_bindings(query);
                quote! {
                    let query = #source;
                    #param_bindings
                    db.execute_cypher(query).await
                }
            }
            QueryLanguage::Custom(ref lang) => {
                // NOTE(review): parameters are not bound for custom languages, so they go
                // unused in the generated function.
                let lang_str = lang.as_str();
                quote! {
                    let query = #source;
                    db.execute_custom(#lang_str, query).await
                }
            }
        }
    }

    /// Emits `Ok(Schema::method(db, p1, p2, ..)?)` for a native query.
    ///
    /// # Panics
    ///
    /// Panics unless `query.source` is exactly two non-empty `::`-separated segments.
    fn generate_native_call(&self, query: &Query) -> TokenStream {
        let source = &query.source;
        let (schema, method) = match source.split_once("::") {
            Some((schema, method))
                if !schema.is_empty() && !method.is_empty() && !method.contains("::") =>
            {
                (schema, method)
            }
            _ => panic!(
                "Invalid native query source: {}. Expected format: 'Schema::method'",
                source
            ),
        };
        let schema_name = format_ident!("{}", schema);
        let method_name = format_ident!("{}", method);
        let param_names: Vec<_> = query
            .params
            .iter()
            .map(|p| format_ident!("{}", p.name))
            .collect();
        quote! {
            Ok(#schema_name::#method_name(db #(, #param_names)*)?)
        }
    }

    /// Emits `let query = query.bind("name", name)...;` for a HyperQL fallback query.
    fn generate_hyperql_bindings(&self, query: &Query) -> TokenStream {
        Self::named_bindings(query, "bind")
    }

    /// Emits `let query = query.bind("$1", p1).bind("$2", p2)...;` for a SQL query.
    ///
    /// Placeholders are 1-based and follow the declaration order of `query.params`. The
    /// placeholder is emitted as a string literal. A bare `$1` is not valid Rust
    /// expression syntax, and interpolating a `usize` through `quote!` yields a suffixed
    /// literal such as `1usize`.
    fn generate_sql_bindings(&self, query: &Query) -> TokenStream {
        Self::chain_bindings(query.params.iter().enumerate().map(|(idx, p)| {
            let placeholder = format!("${}", idx + 1);
            let param_name = format_ident!("{}", p.name);
            quote! { .bind(#placeholder, #param_name) }
        }))
    }

    /// Emits `let query = query.param("name", name)...;` for a Cypher query.
    fn generate_cypher_bindings(&self, query: &Query) -> TokenStream {
        Self::named_bindings(query, "param")
    }

    /// Emits one `.<method>("name", name)` call per parameter, keyed by parameter name.
    fn named_bindings(query: &Query, method: &str) -> TokenStream {
        let method = format_ident!("{}", method);
        Self::chain_bindings(query.params.iter().map(|p| {
            let key = &p.name;
            let param_name = format_ident!("{}", p.name);
            quote! { .#method(#key, #param_name) }
        }))
    }

    /// Wraps builder calls as `let query = query <calls>;`. Emits nothing when there are
    /// no calls, i.e. when the query has no parameters.
    fn chain_bindings(bindings: impl Iterator<Item = TokenStream>) -> TokenStream {
        let bindings: Vec<TokenStream> = bindings.collect();
        if bindings.is_empty() {
            return quote! {};
        }
        quote! {
            let query = query #(#bindings)*;
        }
    }

    /// Maps a schema type to the Rust type used in generated signatures.
    ///
    /// Named and custom types become bare identifiers, which must be in scope in the
    /// generated module. Inline enum types are currently represented as `String`; their
    /// variant list is not used.
    fn convert_type(&self, typ: &audb::schema::Type) -> TokenStream {
        use audb::schema::Type;
        match typ {
            Type::String => quote! { String },
            Type::Integer => quote! { i64 },
            Type::Float => quote! { f64 },
            Type::Boolean => quote! { bool },
            Type::EntityId => quote! { uuid::Uuid },
            Type::Timestamp => quote! { chrono::DateTime<chrono::Utc> },
            Type::Vector => quote! { Vec<f32> },
            Type::JsonValue => quote! { serde_json::Value },
            Type::Unit => quote! { () },
            Type::Named(name) | Type::Custom(name) => {
                let ident = format_ident!("{}", name);
                quote! { #ident }
            }
            Type::Option(inner) => {
                let inner_type = self.convert_type(inner);
                quote! { Option<#inner_type> }
            }
            Type::Vec(inner) => {
                let inner_type = self.convert_type(inner);
                quote! { Vec<#inner_type> }
            }
            Type::Enum(_) => quote! { String },
        }
    }

    /// Generates a complete module body: the prelude followed by one function per query,
    /// in the order given.
    pub fn generate_all(&self, queries: &[&Query]) -> TokenStream {
        let prelude = self.generate_prelude();
        let query_code: Vec<_> = queries.iter().map(|q| self.generate(q)).collect();
        quote! {
            #prelude
            #(#query_code)*
        }
    }
}
impl Default for QueryGenerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use audb::model::query::Parameter;
    use audb::schema::Type;

    /// Shorthand for a named schema type such as `User`.
    fn named(name: &str) -> Type {
        Type::Named(name.to_string())
    }

    /// Shorthand for `Vec<inner>`.
    fn vec_of(inner: Type) -> Type {
        Type::Vec(Box::new(inner))
    }

    #[test]
    fn test_generate_simple_query() {
        let mut query = Query::new(
            "get_user".to_string(),
            QueryLanguage::HyperQL,
            "SELECT * FROM users WHERE id = :id".to_string(),
            named("User"),
        );
        query.add_param(Parameter::new("id".to_string(), Type::EntityId));
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("pub async fn get_user"));
        assert!(code_str.contains("db : & Database"));
        assert!(code_str.contains("id :"));
    }

    #[test]
    fn test_generate_query_with_no_params() {
        let query = Query::new(
            "list_all".to_string(),
            QueryLanguage::HyperQL,
            "SELECT * FROM users".to_string(),
            vec_of(named("User")),
        );
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("pub async fn list_all"));
        assert!(code_str.contains("db : & Database"));
        assert!(!code_str.contains(", id"));
    }

    #[test]
    fn test_generate_sql_query() {
        let mut query = Query::new(
            "get_user_sql".to_string(),
            QueryLanguage::SQL,
            "SELECT * FROM users WHERE id = $1".to_string(),
            named("User"),
        );
        query.add_param(Parameter::new("id".to_string(), Type::EntityId));
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("execute_sql"));
    }

    #[test]
    fn test_generate_cypher_query() {
        let mut query = Query::new(
            "get_user_cypher".to_string(),
            QueryLanguage::Cypher,
            "MATCH (u:User {id: $id}) RETURN u".to_string(),
            named("User"),
        );
        query.add_param(Parameter::new("id".to_string(), Type::EntityId));
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("execute_cypher"));
    }

    #[test]
    fn test_generate_native_query() {
        let mut query = Query::new(
            "find_user".to_string(),
            QueryLanguage::Native,
            "User::find".to_string(),
            named("User"),
        );
        query.add_param(Parameter::new("id".to_string(), Type::EntityId));
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("User :: find"));
        assert!(code_str.contains("db , id"));
    }

    #[test]
    fn test_return_type_conversion() {
        let generator = QueryGenerator::new();
        let simple = generator.convert_type(&named("User"));
        assert_eq!(simple.to_string(), "User");
        let vec_type = generator.convert_type(&vec_of(named("User"))).to_string();
        assert!(vec_type.contains("Vec"));
        assert!(vec_type.contains("User"));
        let option_type = generator
            .convert_type(&Type::Option(Box::new(named("User"))))
            .to_string();
        assert!(option_type.contains("Option"));
        assert!(option_type.contains("User"));
    }

    #[test]
    fn test_param_type_conversion() {
        let generator = QueryGenerator::new();
        assert_eq!(generator.convert_type(&Type::String).to_string(), "String");
        assert_eq!(generator.convert_type(&Type::Integer).to_string(), "i64");
        let entity_type = generator.convert_type(&Type::EntityId);
        assert!(entity_type.to_string().contains("uuid :: Uuid"));
    }

    #[test]
    fn test_generate_with_doc_comment() {
        let mut query = Query::new(
            "get_user".to_string(),
            QueryLanguage::HyperQL,
            "SELECT * FROM users".to_string(),
            named("User"),
        );
        query.doc_comment = Some("Get a user by ID".to_string());
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("Get a user by ID"));
    }

    #[test]
    fn test_generate_multiple_params() {
        let mut query = Query::new(
            "find_users".to_string(),
            QueryLanguage::HyperQL,
            "SELECT * FROM users".to_string(),
            vec_of(named("User")),
        );
        query.add_param(Parameter::new("name".to_string(), Type::String));
        query.add_param(Parameter::new("age".to_string(), Type::Integer));
        let generator = QueryGenerator::new();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("name :"));
        assert!(code_str.contains("age :"));
    }

    #[test]
    fn test_generate_all() {
        let query1 = Query::new(
            "get_user".to_string(),
            QueryLanguage::HyperQL,
            "SELECT *".to_string(),
            named("User"),
        );
        let query2 = Query::new(
            "list_users".to_string(),
            QueryLanguage::HyperQL,
            "SELECT *".to_string(),
            vec_of(named("User")),
        );
        let generator = QueryGenerator::new();
        let code_str = generator.generate_all(&[&query1, &query2]).to_string();
        assert!(code_str.contains("pub async fn get_user"));
        assert!(code_str.contains("pub async fn list_users"));
    }

    #[test]
    fn test_custom_database_type() {
        let query = Query::new(
            "test".to_string(),
            QueryLanguage::HyperQL,
            "SELECT *".to_string(),
            named("User"),
        );
        let mut generator = QueryGenerator::new();
        generator.database_type = "MyDatabase".to_string();
        let code_str = generator.generate(&query).to_string();
        assert!(code_str.contains("MyDatabase"));
    }
}