use convert_case::{Case, Casing};
use darling::ToTokens;
use proc_macro2::{Span, TokenStream};
use quote::{TokenStreamExt, quote};
use super::options::*;
/// Describes the pagination cursor type generated for one `list_by_<column>` query.
///
/// All fields are borrowed; rendering the value through `ToTokens` emits the cursor
/// struct together with its `From<&Entity>` conversion.
pub struct CursorStruct<'a> {
/// Identifier of the entity id type; used as the type of the cursor's `id` field.
pub id: &'a syn::Ident,
/// Identifier of the entity type the cursor is converted from.
pub entity: &'a syn::Ident,
/// Column the listing is ordered by; contributes a second cursor field unless it is the id column.
pub column: &'a Column,
/// Identifier of the module through which generated code refers to the cursor type.
pub cursor_mod: &'a syn::Ident,
}
impl CursorStruct<'_> {
fn name(&self) -> String {
let entity_name = format!("{}", self.entity);
format!("{}_by_{}_cursor", entity_name, self.column.name()).to_case(Case::UpperCamel)
}
pub fn ident(&self) -> syn::Ident {
syn::Ident::new(&self.name(), Span::call_site())
}
pub fn cursor_mod(&self) -> &syn::Ident {
self.cursor_mod
}
pub fn select_columns(&self, for_column: Option<&syn::Ident>) -> String {
let mut for_column_str = String::new();
if let Some(for_column) = for_column
&& self.column.name() != for_column
{
for_column_str = format!("{for_column}, ");
}
if self.column.is_id() {
format!("{for_column_str}id")
} else {
format!("{}{}, id", for_column_str, self.column.name())
}
}
pub fn order_by(&self, ascending: bool) -> String {
let dir = if ascending { "ASC" } else { "DESC" };
let nulls = if ascending { "FIRST" } else { "LAST" };
if self.column.is_id() {
format!("id {dir}")
} else if self.column.is_optional() {
format!("{0} {dir} NULLS {nulls}, id {dir}", self.column.name())
} else {
format!("{} {dir}, id {dir}", self.column.name())
}
}
pub fn condition(&self, offset: u32, ascending: bool) -> String {
let comp = if ascending { ">" } else { "<" };
let id_offset = offset + 2;
let column_offset = offset + 3;
if self.column.is_id() {
format!("COALESCE(id {comp} ${id_offset}, true)")
} else if self.column.is_optional() {
format!(
"({0} IS NOT DISTINCT FROM ${column_offset}) AND COALESCE(id {comp} ${id_offset}, true) OR COALESCE({0} {comp} ${column_offset}, {0} IS NOT NULL)",
self.column.name(),
)
} else {
format!(
"COALESCE(({0}, id) {comp} (${column_offset}, ${id_offset}), ${id_offset} IS NULL)",
self.column.name(),
)
}
}
pub fn query_arg_tokens(&self) -> TokenStream {
let id = self.id;
if self.column.is_id() {
quote! {
(first + 1) as i64,
id as Option<#id>,
}
} else if self.column.is_optional() {
let column_name = self.column.name();
let column_type = self.column.ty();
quote! {
(first + 1) as i64,
id as Option<#id>,
#column_name as #column_type,
}
} else {
let column_name = self.column.name();
let column_type = self.column.ty();
quote! {
(first + 1) as i64,
id as Option<#id>,
#column_name as Option<#column_type>,
}
}
}
pub fn destructure_tokens(&self) -> TokenStream {
let column_name = self.column.name();
let mut after_args = quote! {
(id, #column_name)
};
let mut after_destruction = quote! {
(Some(after.id), Some(after.#column_name))
};
let mut after_default = quote! {
(None, None)
};
if self.column.is_id() {
after_args = quote! {
id
};
after_destruction = quote! {
Some(after.id)
};
after_default = quote! {
None
};
} else if self.column.is_optional() {
after_destruction = quote! {
(Some(after.id), after.#column_name)
};
}
quote! {
let es_entity::PaginatedQueryArgs { first, after } = cursor;
let #after_args = if let Some(after) = after {
#after_destruction
} else {
#after_default
};
}
}
#[cfg(feature = "graphql")]
pub fn gql_cursor(&self) -> TokenStream {
let ident = self.ident();
quote! {
impl es_entity::graphql::async_graphql::connection::CursorType for #ident {
type Error = String;
fn encode_cursor(&self) -> String {
use es_entity::graphql::base64::{engine::general_purpose, Engine as _};
let json = es_entity::prelude::serde_json::to_string(&self).expect("could not serialize token");
general_purpose::STANDARD_NO_PAD.encode(json.as_bytes())
}
fn decode_cursor(s: &str) -> Result<Self, Self::Error> {
use es_entity::graphql::base64::{engine::general_purpose, Engine as _};
let bytes = general_purpose::STANDARD_NO_PAD
.decode(s.as_bytes())
.map_err(|e| e.to_string())?;
let json = String::from_utf8(bytes).map_err(|e| e.to_string())?;
es_entity::prelude::serde_json::from_str(&json).map_err(|e| e.to_string())
}
}
}
}
}
impl ToTokens for CursorStruct<'_> {
fn to_tokens(&self, tokens: &mut TokenStream) {
let entity = self.entity;
let accessor = &self.column.accessor();
let ident = self.ident();
let id = &self.id;
let (field, from_impl) = if self.column.is_id() {
(quote! {}, quote! {})
} else {
let column_name = self.column.name();
let column_type = self.column.ty();
(
quote! {
pub #column_name: #column_type,
},
quote! {
#column_name: entity.#accessor.clone(),
},
)
};
tokens.append_all(quote! {
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct #ident {
pub id: #id,
#field
}
impl From<&#entity> for #ident {
fn from(entity: &#entity) -> Self {
Self {
id: entity.id.clone(),
#from_impl
}
}
}
});
}
}
/// Generator for the paginated `list_by_<column>` repository functions of one column.
pub struct ListByFn<'a> {
/// Optional table prefix; when set, the generated query macro call receives
/// `tbl_prefix = ..` instead of `entity = ..`.
ignore_prefix: Option<&'a syn::LitStr>,
/// Identifier of the entity id type.
id: &'a syn::Ident,
/// Identifier of the entity type returned by the generated functions.
entity: &'a syn::Ident,
/// Column the listing is ordered and paginated by.
column: &'a Column,
/// Name of the table the generated SQL selects from.
table_name: &'a str,
/// Identifier of the error type returned by the generated functions.
query_error: syn::Ident,
/// Deletion mode of the repository; controls the not-deleted filter and which
/// function variants are generated.
delete: DeleteOption,
/// Identifier of the module the cursor type is referenced through.
cursor_mod: syn::Ident,
/// Forwarded to the `RepositoryOptions::query_fn_*` helpers that shape the `_in_op` signature.
any_nested: bool,
/// When present, generated code runs `execute_post_hydrate_hook` on every fetched entity.
post_hydrate_error: Option<&'a syn::Type>,
/// Snake-case repository name used as the prefix of the tracing span name.
#[cfg(feature = "instrument")]
repo_name_snake: String,
}
impl<'a> ListByFn<'a> {
    /// Creates the generator for `column` from the repository-wide options.
    pub fn new(column: &'a Column, opts: &'a RepositoryOptions) -> Self {
        // Only the hook's error type is needed here, borrowed from the options.
        let post_hydrate_error = opts.post_hydrate_hook.as_ref().map(|hook| &hook.error);
        Self {
            ignore_prefix: opts.table_prefix(),
            id: opts.id(),
            entity: opts.entity(),
            column,
            table_name: opts.table_name(),
            query_error: opts.query_error(),
            delete: opts.delete,
            cursor_mod: opts.cursor_mod(),
            any_nested: opts.any_nested(),
            post_hydrate_error,
            #[cfg(feature = "instrument")]
            repo_name_snake: opts.repo_name_snake_case(),
        }
    }

    /// Returns the cursor description matching this listing's column and entity.
    pub fn cursor(&'a self) -> CursorStruct<'a> {
        let Self {
            id,
            entity,
            column,
            cursor_mod,
            ..
        } = self;
        CursorStruct {
            id: *id,
            entity: *entity,
            column: *column,
            cursor_mod,
        }
    }
}
impl ToTokens for ListByFn<'_> {
    /// Appends the `list_by_<column>` / `list_by_<column>_in_op` function pair to `tokens`.
    ///
    /// The pair is always emitted for the variant that applies the repository's not-deleted
    /// filter. When the repository's deletion mode is `DeleteOption::Soft`, a second pair
    /// carrying the include-deletion postfix and no filter is emitted as well.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let entity = self.entity;
        let column_name = self.column.name();
        let cursor = self.cursor();
        let cursor_ident = cursor.ident();
        let cursor_mod = cursor.cursor_mod();
        let query_error = &self.query_error;
        let query_fn_generics = RepositoryOptions::query_fn_generics(self.any_nested);
        let query_fn_op_arg = RepositoryOptions::query_fn_op_arg(self.any_nested);
        let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested);
        let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested);
        // Reuse the cursor built above rather than constructing a second one.
        let destructure_tokens = cursor.destructure_tokens();
        let select_columns = cursor.select_columns(None);
        let arg_tokens = cursor.query_arg_tokens();

        // First argument of the query macro: either the table prefix or the entity type.
        let query_target = if let Some(prefix) = self.ignore_prefix {
            quote! { tbl_prefix = #prefix }
        } else {
            quote! { entity = #entity }
        };

        // The instrumentation fragments do not depend on the deletion variant, so they are
        // built once outside the loop.
        #[cfg(feature = "instrument")]
        let (
            instrument_attr,
            extract_has_cursor,
            record_fields,
            record_results,
            error_recording,
        ) = {
            let entity_name = entity.to_string();
            let repo_name = &self.repo_name_snake;
            let span_name = format!("{}.list_by_{}", repo_name, column_name);
            (
                quote! {
                    #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, first, has_cursor, direction = tracing::field::debug(&direction), count = tracing::field::Empty, has_next_page = tracing::field::Empty, ids = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
                },
                quote! {
                    let has_cursor = cursor.after.is_some();
                },
                quote! {
                    tracing::Span::current().record("first", first);
                    tracing::Span::current().record("has_cursor", has_cursor);
                },
                quote! {
                    let result_ids: Vec<_> = entities.iter().map(|e| &e.id).collect();
                    tracing::Span::current().record("count", result_ids.len());
                    tracing::Span::current().record("has_next_page", has_next_page);
                    tracing::Span::current().record("ids", tracing::field::debug(&result_ids));
                },
                quote! {
                    if let Err(ref e) = __result {
                        tracing::Span::current().record("error", true);
                        tracing::Span::current().record("exception.message", tracing::field::display(e));
                        tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
                    }
                },
            )
        };
        #[cfg(not(feature = "instrument"))]
        let (
            instrument_attr,
            extract_has_cursor,
            record_fields,
            record_results,
            error_recording,
        ) = (quote! {}, quote! {}, quote! {}, quote! {}, quote! {});

        let post_hydrate_check = if self.post_hydrate_error.is_some() {
            quote! {
                for __entity in &entities {
                    self.execute_post_hydrate_hook(__entity).map_err(#query_error::PostHydrateError)?;
                }
            }
        } else {
            quote! {}
        };

        for delete in [DeleteOption::No, DeleteOption::Soft] {
            let postfix = delete.include_deletion_fn_postfix();
            let fn_name = syn::Ident::new(
                &format!("list_by_{column_name}{postfix}"),
                Span::call_site(),
            );
            let fn_in_op = syn::Ident::new(
                &format!("list_by_{column_name}{postfix}_in_op"),
                Span::call_site(),
            );

            // Only the non-deleted variant applies the repository's not-deleted filter.
            let deleted_filter = if delete == DeleteOption::No {
                self.delete.not_deleted_condition()
            } else {
                ""
            };
            // Ascending and descending queries differ only in the cursor predicate and ordering.
            let build_query = |ascending: bool| {
                format!(
                    r#"SELECT {} FROM {} WHERE ({}){} ORDER BY {} LIMIT $1"#,
                    select_columns,
                    self.table_name,
                    cursor.condition(0, ascending),
                    deleted_filter,
                    cursor.order_by(ascending),
                )
            };
            let asc_query = build_query(true);
            let desc_query = build_query(false);

            let es_query_asc_call = quote! {
                es_entity::es_query!(
                    #query_target,
                    #asc_query,
                    #arg_tokens
                )
            };
            let es_query_desc_call = quote! {
                es_entity::es_query!(
                    #query_target,
                    #desc_query,
                    #arg_tokens
                )
            };

            tokens.append_all(quote! {
                pub async fn #fn_name(
                    &self,
                    cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>,
                    direction: es_entity::ListDirection,
                ) -> Result<es_entity::PaginatedQueryRet<#entity, #cursor_mod::#cursor_ident>, #query_error> {
                    self.#fn_in_op(#query_fn_get_op, cursor, direction).await
                }
                #instrument_attr
                pub async fn #fn_in_op #query_fn_generics(
                    &self,
                    #query_fn_op_arg,
                    cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>,
                    direction: es_entity::ListDirection,
                ) -> Result<es_entity::PaginatedQueryRet<#entity, #cursor_mod::#cursor_ident>, #query_error>
                where
                    OP: #query_fn_op_traits
                {
                    let __result: Result<es_entity::PaginatedQueryRet<#entity, #cursor_mod::#cursor_ident>, #query_error> = async {
                        #extract_has_cursor
                        #destructure_tokens
                        #record_fields
                        let (entities, has_next_page) = match direction {
                            es_entity::ListDirection::Ascending => {
                                #es_query_asc_call.fetch_n(op, first).await?
                            },
                            es_entity::ListDirection::Descending => {
                                #es_query_desc_call.fetch_n(op, first).await?
                            },
                        };
                        #post_hydrate_check
                        #record_results
                        let end_cursor = entities.last().map(#cursor_mod::#cursor_ident::from);
                        Ok(es_entity::PaginatedQueryRet {
                            entities,
                            has_next_page,
                            end_cursor,
                        })
                    }.await;
                    #error_recording
                    __result
                }
            });

            // Stop after the first pair unless the repository is `Soft`, in which case the
            // second iteration emits the include-deleted pair and then stops here as well.
            if delete == self.delete || self.delete == DeleteOption::SoftWithoutQueries {
                break;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::Span;
    use syn::Ident;

    /// Creates an identifier spanned at the call site.
    fn ident(name: &str) -> Ident {
        Ident::new(name, Span::call_site())
    }

    /// Builds the `ListByFn` fixture shared by the tests: table `entities`, error type
    /// `EntityQueryError`, cursor module `cursor_mod`, no table prefix, no nesting and no
    /// post-hydrate hook. Only the column, identifiers and deletion mode vary per test.
    fn list_by_fixture<'a>(
        column: &'a Column,
        id: &'a Ident,
        entity: &'a Ident,
        delete: DeleteOption,
    ) -> ListByFn<'a> {
        ListByFn {
            ignore_prefix: None,
            column,
            id,
            entity,
            table_name: "entities",
            query_error: ident("EntityQueryError"),
            delete,
            cursor_mod: ident("cursor_mod"),
            any_nested: false,
            post_hydrate_error: None,
            #[cfg(feature = "instrument")]
            repo_name_snake: "test_repo".to_string(),
        }
    }

    /// Renders a `ToTokens` value into its string form for comparison.
    fn render(value: &impl ToTokens) -> String {
        let mut tokens = TokenStream::new();
        value.to_tokens(&mut tokens);
        tokens.to_string()
    }

    #[test]
    fn cursor_struct_by_id() {
        let id_type = ident("EntityId");
        let entity = ident("Entity");
        let by_column = Column::for_id(syn::parse_str("EntityId").unwrap());
        let cursor_mod = ident("cursor_mod");
        let cursor = CursorStruct {
            column: &by_column,
            id: &id_type,
            entity: &entity,
            cursor_mod: &cursor_mod,
        };
        let expected = quote! {
            #[derive(Debug, serde::Serialize, serde::Deserialize)]
            pub struct EntityByIdCursor {
                pub id: EntityId,
            }
            impl From<&Entity> for EntityByIdCursor {
                fn from(entity: &Entity) -> Self {
                    Self {
                        id: entity.id.clone(),
                    }
                }
            }
        };
        assert_eq!(render(&cursor), expected.to_string());
    }

    #[test]
    fn cursor_struct_by_created_at() {
        let id_type = ident("EntityId");
        let entity = ident("Entity");
        let by_column = Column::for_created_at();
        let cursor_mod = ident("cursor_mod");
        let cursor = CursorStruct {
            column: &by_column,
            id: &id_type,
            entity: &entity,
            cursor_mod: &cursor_mod,
        };
        let expected = quote! {
            #[derive(Debug, serde::Serialize, serde::Deserialize)]
            pub struct EntityByCreatedAtCursor {
                pub id: EntityId,
                pub created_at: es_entity::prelude::chrono::DateTime<es_entity::prelude::chrono::Utc>,
            }
            impl From<&Entity> for EntityByCreatedAtCursor {
                fn from(entity: &Entity) -> Self {
                    Self {
                        id: entity.id.clone(),
                        created_at: entity.events()
                            .entity_first_persisted_at()
                            .expect("entity not persisted")
                            .clone(),
                    }
                }
            }
        };
        assert_eq!(render(&cursor), expected.to_string());
    }

    #[test]
    fn list_by_fn() {
        let id_type = ident("EntityId");
        let entity = ident("Entity");
        let column = Column::for_id(syn::parse_str("EntityId").unwrap());
        let list_by = list_by_fixture(
            &column,
            &id_type,
            &entity,
            DeleteOption::SoftWithoutQueries,
        );
        let expected = quote! {
            pub async fn list_by_id(
                &self,
                cursor: es_entity::PaginatedQueryArgs<cursor_mod::EntityByIdCursor>,
                direction: es_entity::ListDirection,
            ) -> Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByIdCursor>, EntityQueryError> {
                self.list_by_id_in_op(self.pool(), cursor, direction).await
            }
            pub async fn list_by_id_in_op<'a, OP>(
                &self,
                op: OP,
                cursor: es_entity::PaginatedQueryArgs<cursor_mod::EntityByIdCursor>,
                direction: es_entity::ListDirection,
            ) -> Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByIdCursor>, EntityQueryError>
            where
                OP: es_entity::IntoOneTimeExecutor<'a>
            {
                let __result: Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByIdCursor>, EntityQueryError> = async {
                    let es_entity::PaginatedQueryArgs { first, after } = cursor;
                    let id = if let Some(after) = after {
                        Some(after.id)
                    } else {
                        None
                    };
                    let (entities, has_next_page) = match direction {
                        es_entity::ListDirection::Ascending => {
                            es_entity::es_query!(
                                entity = Entity,
                                "SELECT id FROM entities WHERE (COALESCE(id > $2, true)) AND deleted = FALSE ORDER BY id ASC LIMIT $1",
                                (first + 1) as i64,
                                id as Option<EntityId>,
                            )
                            .fetch_n(op, first)
                            .await?
                        },
                        es_entity::ListDirection::Descending => {
                            es_entity::es_query!(
                                entity = Entity,
                                "SELECT id FROM entities WHERE (COALESCE(id < $2, true)) AND deleted = FALSE ORDER BY id DESC LIMIT $1",
                                (first + 1) as i64,
                                id as Option<EntityId>,
                            )
                            .fetch_n(op, first)
                            .await?
                        },
                    };
                    let end_cursor = entities.last().map(cursor_mod::EntityByIdCursor::from);
                    Ok(es_entity::PaginatedQueryRet {
                        entities,
                        has_next_page,
                        end_cursor,
                    })
                }.await;
                __result
            }
        };
        assert_eq!(render(&list_by), expected.to_string());
    }

    #[test]
    fn list_by_fn_with_soft_delete_include_deleted() {
        let id_type = ident("EntityId");
        let entity = ident("Entity");
        let column = Column::for_id(syn::parse_str("EntityId").unwrap());
        let list_by = list_by_fixture(&column, &id_type, &entity, DeleteOption::Soft);
        let token_str = render(&list_by);
        assert!(token_str.contains("list_by_id_include_deleted"));
    }

    #[test]
    fn list_by_fn_name() {
        let id_type = ident("EntityId");
        let entity = ident("Entity");
        let column = Column::new(ident("name"), syn::parse_str("String").unwrap());
        let list_by = list_by_fixture(&column, &id_type, &entity, DeleteOption::No);
        let expected = quote! {
            pub async fn list_by_name(
                &self,
                cursor: es_entity::PaginatedQueryArgs<cursor_mod::EntityByNameCursor>,
                direction: es_entity::ListDirection,
            ) -> Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByNameCursor>, EntityQueryError> {
                self.list_by_name_in_op(self.pool(), cursor, direction).await
            }
            pub async fn list_by_name_in_op<'a, OP>(
                &self,
                op: OP,
                cursor: es_entity::PaginatedQueryArgs<cursor_mod::EntityByNameCursor>,
                direction: es_entity::ListDirection,
            ) -> Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByNameCursor>, EntityQueryError>
            where
                OP: es_entity::IntoOneTimeExecutor<'a>
            {
                let __result: Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByNameCursor>, EntityQueryError> = async {
                    let es_entity::PaginatedQueryArgs { first, after } = cursor;
                    let (id, name) = if let Some(after) = after {
                        (Some(after.id), Some(after.name))
                    } else {
                        (None, None)
                    };
                    let (entities, has_next_page) = match direction {
                        es_entity::ListDirection::Ascending => {
                            es_entity::es_query!(
                                entity = Entity,
                                "SELECT name, id FROM entities WHERE (COALESCE((name, id) > ($3, $2), $2 IS NULL)) ORDER BY name ASC, id ASC LIMIT $1",
                                (first + 1) as i64,
                                id as Option<EntityId>,
                                name as Option<String>,
                            )
                            .fetch_n(op, first)
                            .await?
                        },
                        es_entity::ListDirection::Descending => {
                            es_entity::es_query!(
                                entity = Entity,
                                "SELECT name, id FROM entities WHERE (COALESCE((name, id) < ($3, $2), $2 IS NULL)) ORDER BY name DESC, id DESC LIMIT $1",
                                (first + 1) as i64,
                                id as Option<EntityId>,
                                name as Option<String>,
                            )
                            .fetch_n(op, first)
                            .await?
                        },
                    };
                    let end_cursor = entities.last().map(cursor_mod::EntityByNameCursor::from);
                    Ok(es_entity::PaginatedQueryRet {
                        entities,
                        has_next_page,
                        end_cursor,
                    })
                }.await;
                __result
            }
        };
        assert_eq!(render(&list_by), expected.to_string());
    }

    #[test]
    fn list_by_fn_optional_column() {
        let id_type = ident("EntityId");
        let entity = ident("Entity");
        let column = Column::new(
            ident("value"),
            syn::parse_str("Option<rust_decimal::Decimal>").unwrap(),
        );
        let list_by = list_by_fixture(&column, &id_type, &entity, DeleteOption::No);
        let expected = quote! {
            pub async fn list_by_value(
                &self,
                cursor: es_entity::PaginatedQueryArgs<cursor_mod::EntityByValueCursor>,
                direction: es_entity::ListDirection,
            ) -> Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByValueCursor>, EntityQueryError> {
                self.list_by_value_in_op(self.pool(), cursor, direction).await
            }
            pub async fn list_by_value_in_op<'a, OP>(
                &self,
                op: OP,
                cursor: es_entity::PaginatedQueryArgs<cursor_mod::EntityByValueCursor>,
                direction: es_entity::ListDirection,
            ) -> Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByValueCursor>, EntityQueryError>
            where
                OP: es_entity::IntoOneTimeExecutor<'a>
            {
                let __result: Result<es_entity::PaginatedQueryRet<Entity, cursor_mod::EntityByValueCursor>, EntityQueryError> = async {
                    let es_entity::PaginatedQueryArgs { first, after } = cursor;
                    let (id, value) = if let Some(after) = after {
                        (Some(after.id), after.value)
                    } else {
                        (None, None)
                    };
                    let (entities, has_next_page) = match direction {
                        es_entity::ListDirection::Ascending => {
                            es_entity::es_query!(
                                entity = Entity,
                                "SELECT value, id FROM entities WHERE ((value IS NOT DISTINCT FROM $3) AND COALESCE(id > $2, true) OR COALESCE(value > $3, value IS NOT NULL)) ORDER BY value ASC NULLS FIRST, id ASC LIMIT $1",
                                (first + 1) as i64,
                                id as Option<EntityId>,
                                value as Option<rust_decimal::Decimal>,
                            )
                            .fetch_n(op, first)
                            .await?
                        },
                        es_entity::ListDirection::Descending => {
                            es_entity::es_query!(
                                entity = Entity,
                                "SELECT value, id FROM entities WHERE ((value IS NOT DISTINCT FROM $3) AND COALESCE(id < $2, true) OR COALESCE(value < $3, value IS NOT NULL)) ORDER BY value DESC NULLS LAST, id DESC LIMIT $1",
                                (first + 1) as i64,
                                id as Option<EntityId>,
                                value as Option<rust_decimal::Decimal>,
                            )
                            .fetch_n(op, first)
                            .await?
                        },
                    };
                    let end_cursor = entities.last().map(cursor_mod::EntityByValueCursor::from);
                    Ok(es_entity::PaginatedQueryRet {
                        entities,
                        has_next_page,
                        end_cursor,
                    })
                }.await;
                __result
            }
        };
        assert_eq!(render(&list_by), expected.to_string());
    }
}