use std::{collections::BTreeMap, sync::Arc};
use maplit::btreemap;
use quote::quote;
use trustfall::{Schema, SchemaAdapter, TryIntoStruct};
use crate::util::upper_case_variant_name;
use super::util::escaped_rust_name;
use super::{
root::RustFile,
util::{
field_value_to_rust_type, parse_import, to_lower_snake_case, trustfall_type_to_rust_type,
type_edge_resolver_fn_name,
},
};
/// Populates `edges_file` with the generated edge-resolution code.
///
/// The schema is introspected by running a Trustfall query through `adapter`.
/// The query selects every vertex type that has at least one edge, along with
/// each edge's name and its parameters' names and types.
///
/// For each such type, `make_type_edge_resolver` produces two items: a
/// dispatch function and a module of per-edge stubs. Both are appended to
/// `top_level_items`. Rows are sorted (type name first, via the derived `Ord`)
/// and each type's edges are sorted too, so the output order is deterministic.
/// Finally, the imports that the generated dispatch functions rely on are
/// registered.
///
/// Panics if the query fails to execute or a result row cannot be converted.
pub(super) fn make_edges_file(
    querying_schema: &Schema,
    adapter: Arc<SchemaAdapter<'_>>,
    edges_file: &mut RustFile,
) {
    // Field order matters: the derived `Ord` compares `name` first, which is
    // what `sort_unstable` below relies on for a stable, type-ordered output.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize)]
    struct ResultRow {
        name: String,
        edge_name: Vec<String>,
        parameter_name: Vec<Vec<String>>,
        parameter_type: Vec<Vec<String>>,
    }

    let query = r#"
{
VertexType {
name @output
edge @fold @transform(op: "count") @filter(op: ">", value: ["$zero"]) {
edge_name: name @output
parameter_: parameter @fold {
name @output
type @output
}
}
}
}"#;
    let variables: BTreeMap<Arc<str>, i64> = btreemap! {
        "zero".into() => 0,
    };

    let mut rows = trustfall::execute_query(querying_schema, adapter, query, variables)
        .expect("invalid query")
        .map(|result| {
            result
                .try_into_struct::<ResultRow>()
                .expect("invalid conversion")
        })
        .collect::<Vec<_>>();
    rows.sort_unstable();

    for ResultRow { name, edge_name, parameter_name, parameter_type } in rows {
        // Pair each edge with its (parameter name, parameter type) list.
        let mut edges: Vec<(String, Vec<(String, String)>)> = edge_name
            .into_iter()
            .zip(parameter_name)
            .zip(parameter_type)
            .map(|((edge, names), types)| (edge, names.into_iter().zip(types).collect()))
            .collect();
        edges.sort_unstable();

        let (dispatch_fn, resolvers_mod) = make_type_edge_resolver(&name, edges);
        edges_file.top_level_items.push(dispatch_fn);
        edges_file.top_level_items.push(resolvers_mod);
    }

    for path in [
        "trustfall::provider::AsVertex",
        "trustfall::provider::ContextIterator",
        "trustfall::provider::ContextOutcomeIterator",
        "trustfall::provider::EdgeParameters",
        "trustfall::provider::ResolveEdgeInfo",
        "trustfall::provider::VertexIterator",
    ] {
        edges_file.external_imports.insert(parse_import(path));
    }
    edges_file.internal_imports.insert(parse_import("super::vertex::Vertex"));
}
/// Generates the edge-resolution code for a single vertex type.
///
/// `edges` lists `(edge_name, [(parameter_name, parameter_type)])` pairs.
/// A match arm and a resolver stub are emitted for each edge, in the order given.
///
/// Returns a pair of token streams:
/// 1. A `pub(super)` dispatch function, named by `type_edge_resolver_fn_name`.
///    It `match`es on `edge_name`, forwards to the corresponding per-edge
///    resolver, and hits `unreachable!` for any other edge name.
/// 2. A private module that holds the per-edge resolver stubs. Its name is
///    derived from the type name via `to_lower_snake_case` and `escaped_rust_name`.
fn make_type_edge_resolver(
    type_name: &str,
    edges: Vec<(String, Vec<(String, String)>)>,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let lower_type_name = to_lower_snake_case(type_name);
    let mod_name = syn::Ident::new(
        &escaped_rust_name(lower_type_name.clone()),
        proc_macro2::Span::call_site(),
    );

    let mut arms = proc_macro2::TokenStream::new();
    let mut edge_resolvers = proc_macro2::TokenStream::new();
    for (edge_name, params) in edges {
        let (arm, resolver) =
            make_edge_resolver_and_call(type_name, &edge_name, &params, &mod_name);
        arms.extend(arm);
        edge_resolvers.extend(resolver);
    }

    let type_edge_resolver_fn = type_edge_resolver_fn_name(&lower_type_name);
    let ident = syn::Ident::new(&type_edge_resolver_fn, proc_macro2::Span::call_site());
    // `{{edge_name}}` renders as a literal `{edge_name}`, so the *generated*
    // `unreachable!` interpolates the runtime edge name via inline format args.
    let unreachable_msg =
        format!("attempted to resolve unexpected edge '{{edge_name}}' on type '{type_name}'");
    let type_edge_resolver = quote! {
        pub(super) fn #ident<'a, V: AsVertex<Vertex> + 'a>(
            contexts: ContextIterator<'a, V>,
            edge_name: &str,
            parameters: &EdgeParameters,
            resolve_info: &ResolveEdgeInfo,
        ) -> ContextOutcomeIterator<'a, V, VertexIterator<'a, Vertex>> {
            match edge_name {
                #arms
                _ => unreachable!(#unreachable_msg),
            }
        }
    };

    let type_edge_mod = quote! {
        mod #mod_name {
            use trustfall::provider::{
                resolve_neighbors_with, AsVertex, ContextIterator, ContextOutcomeIterator, ResolveEdgeInfo,
                VertexIterator,
            };

            use super::super::vertex::Vertex;

            #edge_resolvers
        }
    };

    (type_edge_resolver, type_edge_mod)
}
/// Generates the resolver stub for one edge and the dispatch `match` arm that calls it.
///
/// The generated resolver, emitted as `<mod_name>::<edge fn>`, runs
/// `resolve_neighbors_with` over `contexts`. Each vertex is converted through
/// its `as_<variant>()` accessor (panicking if the conversion fails), and the
/// body is left as a `todo!` for the user to fill in.
///
/// If the edge takes no parameters, the arm calls the resolver directly.
/// Otherwise the arm first binds each parameter out of `parameters` (the
/// bindings come from `prepare_call_parameters`) and then passes them along.
///
/// Returns `(match_arm, resolver)`.
fn make_edge_resolver_and_call(
    type_name: &str,
    edge_name: &str,
    parameters: &[(String, String)],
    mod_name: &proc_macro2::Ident,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let missing_param_msg = |parameter_name: &str| {
        format!("failed to find parameter '{parameter_name}' for edge '{edge_name}' on type '{type_name}'")
    };
    let FnCall { fn_params, fn_args, fn_arg_prep } =
        prepare_call_parameters(parameters, missing_param_msg);

    let span = proc_macro2::Span::call_site();
    let variant_name = escaped_rust_name(upper_case_variant_name(type_name));
    let resolver_fn_ident =
        syn::Ident::new(&escaped_rust_name(to_lower_snake_case(edge_name)), span);
    let conversion_fn_ident =
        syn::Ident::new(&format!("as_{}", to_lower_snake_case(&variant_name)), span);

    let expect_msg = format!("conversion failed, vertex was not a {type_name}");
    let todo_msg = format!("get neighbors along edge '{edge_name}' for type '{type_name}'");

    let match_arm = match parameters {
        [] => quote! {
            #edge_name => #mod_name::#resolver_fn_ident(contexts, resolve_info),
        },
        _ => quote! {
            #edge_name => {
                #fn_arg_prep
                #mod_name::#resolver_fn_ident(contexts, #fn_args resolve_info)
            }
        },
    };

    let resolver = quote! {
        pub(super) fn #resolver_fn_ident<'a, V: AsVertex<Vertex> + 'a>(
            contexts: ContextIterator<'a, V>,
            #fn_params
            _resolve_info: &ResolveEdgeInfo,
        ) -> ContextOutcomeIterator<'a, V, VertexIterator<'a, Vertex>> {
            resolve_neighbors_with(contexts, move |vertex| {
                let vertex = vertex.#conversion_fn_ident().expect(#expect_msg);
                todo!(#todo_msg)
            })
        }
    };

    (match_arm, resolver)
}
/// Token fragments for forwarding a list of edge parameters from a generated
/// dispatch function into a generated resolver function.
///
/// Produced by `prepare_call_parameters`. Each fragment holds one entry per
/// parameter, in input order.
pub(super) struct FnCall {
    /// `let ident: Type = ...;` statements that pull each parameter's value
    /// out of `parameters` and convert it to its Rust type.
    pub(super) fn_arg_prep: proc_macro2::TokenStream,
    /// `ident: Type,` entries, to be spliced into a function signature.
    pub(super) fn_params: proc_macro2::TokenStream,
    /// `ident,` entries, to be spliced into a call's argument list.
    pub(super) fn_args: proc_macro2::TokenStream,
}
pub(super) fn prepare_call_parameters(
parameters: &[(String, String)],
expect_msg_fn: impl Fn(&str) -> String,
) -> FnCall {
let mut fn_params: proc_macro2::TokenStream = proc_macro2::TokenStream::new();
let mut fn_args: proc_macro2::TokenStream = proc_macro2::TokenStream::new();
let mut fn_arg_prep: proc_macro2::TokenStream = proc_macro2::TokenStream::new();
for (parameter_name, parameter_type) in parameters {
let ident = syn::Ident::new(parameter_name, proc_macro2::Span::call_site());
let ty = trustfall_type_to_rust_type(parameter_type);
fn_params.extend(quote! {
#ident: #ty,
});
fn_args.extend(quote! {
#ident,
});
let expect_msg = expect_msg_fn(parameter_name);
let parameter_get = quote! {
parameters.get(#parameter_name).expect(#expect_msg)
};
let parameter_expr = field_value_to_rust_type(parameter_type, parameter_get);
fn_arg_prep.extend(quote! {
let #ident: #ty = #parameter_expr;
});
}
FnCall { fn_params, fn_args, fn_arg_prep }
}