use proc_macro2::TokenStream;
use quote::quote;
use super::{
context::Context,
helpers::{generate_query_bindings, generate_where_conditions}
};
impl Context<'_> {
    /// Generate the `query` repository method for entities that declare filters.
    ///
    /// Returns an empty [`TokenStream`] when the entity has no filter fields.
    /// Otherwise the emitted method has the signature
    /// `async fn query(&self, query: <Entity>Query) -> Result<Vec<Entity>, Self::Error>`
    /// and:
    ///
    /// - builds a dynamic `WHERE` clause from the filter fields,
    /// - orders by the id column descending,
    /// - applies `LIMIT`/`OFFSET`, defaulting to `100` and `0` when the query
    ///   leaves them unset,
    /// - converts every fetched row into the entity type via `From`.
    pub fn query_method(&self) -> TokenStream {
        if !self.entity.has_filters() {
            return TokenStream::new();
        }

        let entity_name = &self.entity_name;
        let query_type = self.entity.ident_with("", "Query");
        let fetch_rows = self.filtered_fetch_tokens(quote!(100));

        quote! {
            async fn query(&self, query: #query_type) -> Result<Vec<#entity_name>, Self::Error> {
                #fetch_rows
                Ok(rows.into_iter().map(#entity_name::from).collect())
            }
        }
    }

    /// Generate the `stream_filtered` repository method.
    ///
    /// Returns an empty [`TokenStream`] unless streams are enabled on the
    /// context *and* the entity declares filter fields. The emitted method
    /// takes an `<Entity>Filter` and returns a pinned, boxed `futures::Stream`
    /// of `Result<Entity, Self::Error>`.
    ///
    /// Note that the generated code loads the whole result set with
    /// `fetch_all` and then wraps the collected entities in
    /// `futures::stream::iter`; rows are not pulled from the database
    /// incrementally. The default limit is `10000` and the default offset `0`.
    pub fn stream_filtered_method(&self) -> TokenStream {
        if !self.streams || !self.entity.has_filters() {
            return TokenStream::new();
        }

        let entity_name = &self.entity_name;
        let filter_type = self.entity.ident_with("", "Filter");
        let fetch_rows = self.filtered_fetch_tokens(quote!(10000));

        quote! {
            async fn stream_filtered(
                &self,
                filter: #filter_type,
            ) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = Result<#entity_name, Self::Error>> + Send + '_>>, Self::Error> {
                use futures::StreamExt;
                // The shared fetch statements read the filter through a
                // binding named `query`, so alias the parameter first.
                let query = filter;
                #fetch_rows
                let entities: Vec<#entity_name> = rows.into_iter().map(#entity_name::from).collect();
                let stream = futures::stream::iter(entities.into_iter().map(Ok));
                Ok(Box::pin(stream))
            }
        }
    }

    /// Emit the statements shared by every filtered, paginated `SELECT`.
    ///
    /// The returned tokens are a statement sequence (not a full item) meant to
    /// be spliced into a generated method body. They expect two names in scope
    /// at the splice site:
    ///
    /// - `query`: a value exposing `limit` and `offset` options, plus whatever
    ///   fields the filter snippets read,
    /// - `self`: the value handed to `fetch_all` as the executor.
    ///
    /// After the statements run, `rows` holds the fetched row structs.
    ///
    /// `default_limit` is spliced verbatim as the fallback passed to
    /// `query.limit.unwrap_or(..)`; pass an unsuffixed literal such as
    /// `quote!(100)` so its integer type is inferred from the `limit` field.
    fn filtered_fetch_tokens(&self, default_limit: TokenStream) -> TokenStream {
        let row_name = &self.row_name;
        let table = &self.table;
        let columns_str = &self.columns_str;
        let id_name = &self.id_name;

        let filter_fields = self.entity.filter_fields();
        // Both snippets operate on the `conditions`, `param_idx`, `query` and
        // `q` locals declared below; presumably the condition snippet pushes
        // into `conditions` and advances `param_idx` per emitted placeholder
        // (see `helpers` to confirm).
        let where_conditions = generate_where_conditions(&filter_fields, self.soft_delete);
        let bindings = generate_query_bindings(&filter_fields);

        quote! {
            let mut conditions: Vec<String> = Vec::new();
            let mut param_idx: usize = 1;
            #where_conditions
            let where_clause = if conditions.is_empty() {
                String::new()
            } else {
                format!("WHERE {}", conditions.join(" AND "))
            };
            // LIMIT and OFFSET take the next two positional placeholders
            // after the ones consumed by the filter conditions.
            let limit_idx = param_idx;
            param_idx += 1;
            let offset_idx = param_idx;
            let sql = format!(
                "SELECT {} FROM {} {} ORDER BY {} DESC LIMIT ${} OFFSET ${}",
                #columns_str, #table, where_clause, stringify!(#id_name), limit_idx, offset_idx
            );
            let mut q = sqlx::query_as::<_, #row_name>(&sql);
            #bindings
            q = q.bind(query.limit.unwrap_or(#default_limit)).bind(query.offset.unwrap_or(0));
            let rows = q.fetch_all(self).await?;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entity::parse::EntityDef;

    /// Parse a derive input into an [`EntityDef`], panicking if parsing fails.
    fn entity_from(input: syn::DeriveInput) -> EntityDef {
        EntityDef::from_derive_input(&input).unwrap()
    }

    /// Assert that each of `needles` occurs in the rendered token text.
    fn assert_contains_all(rendered: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                rendered.contains(needle),
                "expected `{needle}` in generated code"
            );
        }
    }

    /// An entity without any `#[filter]` field gets no `query` method.
    #[test]
    fn query_method_no_filters_returns_empty() {
        let entity = entity_from(syn::parse_quote! {
            #[entity(table = "users")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                pub name: String,
            }
        });
        let generated = Context::new(&entity).query_method();
        assert!(generated.is_empty());
    }

    /// A filtered entity gets a `query` method that takes `UserQuery` and
    /// assembles a dynamic WHERE clause.
    #[test]
    fn query_method_with_filter() {
        let entity = entity_from(syn::parse_quote! {
            #[entity(table = "users")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[filter]
                pub name: String,
            }
        });
        let rendered = Context::new(&entity).query_method().to_string();
        assert_contains_all(
            &rendered,
            &["async fn query", "UserQuery", "conditions", "where_clause"],
        );
    }

    /// With `soft_delete` enabled the generated query mentions `deleted_at`.
    #[test]
    fn query_method_with_soft_delete() {
        let entity = entity_from(syn::parse_quote! {
            #[entity(table = "users", soft_delete)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[filter]
                pub name: String,
                #[field(response)]
                #[auto]
                pub deleted_at: Option<chrono::DateTime<chrono::Utc>>,
            }
        });
        let rendered = Context::new(&entity).query_method().to_string();
        assert!(rendered.contains("deleted_at"));
    }

    /// Filters alone are not enough: without `streams` nothing is generated.
    #[test]
    fn stream_filtered_no_streams_returns_empty() {
        let entity = entity_from(syn::parse_quote! {
            #[entity(table = "users")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[filter]
                pub name: String,
            }
        });
        let generated = Context::new(&entity).stream_filtered_method();
        assert!(generated.is_empty());
    }

    /// `streams` alone is not enough: without filters nothing is generated.
    #[test]
    fn stream_filtered_no_filters_returns_empty() {
        let entity = entity_from(syn::parse_quote! {
            #[entity(table = "users", streams)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                pub name: String,
            }
        });
        let generated = Context::new(&entity).stream_filtered_method();
        assert!(generated.is_empty());
    }

    /// With both `streams` and a filter field, `stream_filtered` is emitted
    /// and takes the `UserFilter` type.
    #[test]
    fn stream_filtered_with_streams_and_filters() {
        let entity = entity_from(syn::parse_quote! {
            #[entity(table = "users", streams)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[filter]
                pub name: String,
            }
        });
        let rendered = Context::new(&entity).stream_filtered_method().to_string();
        assert_contains_all(&rendered, &["stream_filtered", "UserFilter", "futures"]);
    }
}