trustfall_stubgen 0.4.1

Generate a Trustfall adapter stub for a given schema.
Documentation
use std::{collections::BTreeMap, sync::Arc};

use maplit::btreemap;
use quote::quote;
use trustfall::{Schema, SchemaAdapter, TryIntoStruct};

use crate::util::upper_case_variant_name;

use super::util::escaped_rust_name;

use super::{
    root::RustFile,
    util::{
        field_value_to_rust_type, parse_import, to_lower_snake_case, trustfall_type_to_rust_type,
        type_edge_resolver_fn_name,
    },
};

/// Fills `edges_file` with edge-resolver stubs for every vertex type that has edges.
///
/// Runs a Trustfall query through `adapter` against `querying_schema` to list each
/// vertex type that has at least one edge, along with the names of those edges and
/// the name and type of each edge parameter. For every such type, a dispatch
/// function and a module of per-edge resolver stubs are appended to
/// `edges_file.top_level_items`, and the imports the generated code relies on are
/// registered on the file.
///
/// # Panics
///
/// Panics with `"invalid query"` if the query cannot be executed, or with
/// `"invalid conversion"` if a result row does not deserialize into the expected shape.
pub(super) fn make_edges_file(
    querying_schema: &Schema,
    adapter: Arc<SchemaAdapter<'_>>,
    edges_file: &mut RustFile,
) {
    // Field names must match the output names produced by the query below.
    // The `Ord` derive lets whole rows be sorted, starting with the type name.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize)]
    struct ResultRow {
        name: String,
        edge_name: Vec<String>,
        parameter_name: Vec<Vec<String>>,
        parameter_type: Vec<Vec<String>>,
    }

    let query = r#"
{
    VertexType {
        name @output

        edge @fold @transform(op: "count") @filter(op: ">", value: ["$zero"]) {
            edge_name: name @output

            parameter_: parameter @fold {
                name @output
                type @output
            }
        }
    }
}"#;
    let variables: BTreeMap<Arc<str>, i64> = btreemap! {
        "zero".into() => 0,
    };

    let mut rows = trustfall::execute_query(querying_schema, adapter, query, variables)
        .expect("invalid query")
        .map(|result| result.try_into_struct::<ResultRow>().expect("invalid conversion"))
        .collect::<Vec<_>>();
    // Sort the rows so that items are emitted in a stable order.
    rows.sort_unstable();

    for ResultRow { name, edge_name, parameter_name, parameter_type } in rows {
        // The three folded lists are parallel: entry `i` of each describes edge `i`.
        let mut edges: Vec<(String, Vec<(String, String)>)> = edge_name
            .into_iter()
            .zip(parameter_name)
            .zip(parameter_type)
            .map(|((edge, param_names), param_types)| {
                (edge, param_names.into_iter().zip(param_types).collect())
            })
            .collect();
        edges.sort_unstable();

        let (dispatch_fn, resolvers_mod) = make_type_edge_resolver(&name, edges);
        edges_file.top_level_items.push(dispatch_fn);
        edges_file.top_level_items.push(resolvers_mod);
    }

    // Names referenced by the generated dispatch functions.
    for path in [
        "trustfall::provider::AsVertex",
        "trustfall::provider::ContextIterator",
        "trustfall::provider::ContextOutcomeIterator",
        "trustfall::provider::EdgeParameters",
        "trustfall::provider::ResolveEdgeInfo",
        "trustfall::provider::VertexIterator",
    ] {
        edges_file.external_imports.insert(parse_import(path));
    }

    edges_file.internal_imports.insert(parse_import("super::vertex::Vertex"));
}

/// Generates the edge-resolution code for a single vertex type.
///
/// Returns a pair of token streams:
/// 1. a `pub(super)` dispatch function that matches on the edge name and forwards to
///    the matching per-edge resolver, hitting `unreachable!` for any other name;
/// 2. a module, named after the snake-cased and escaped type name, holding one
///    resolver stub per edge.
///
/// `edges` pairs each edge name with its `(parameter name, parameter type)` list.
fn make_type_edge_resolver(
    type_name: &str,
    edges: Vec<(String, Vec<(String, String)>)>,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let span = proc_macro2::Span::call_site();
    let snake_type_name = to_lower_snake_case(type_name);

    // Derive the dispatcher's identifier first, while the snake-cased name is only
    // borrowed; the name can then be moved into the module identifier below.
    let dispatch_ident = syn::Ident::new(&type_edge_resolver_fn_name(&snake_type_name), span);
    let mod_name = syn::Ident::new(&escaped_rust_name(snake_type_name), span);

    // Each edge contributes one match arm to the dispatcher and one resolver to the module.
    let (arms, edge_resolvers): (proc_macro2::TokenStream, proc_macro2::TokenStream) = edges
        .into_iter()
        .map(|(edge_name, params)| {
            make_edge_resolver_and_call(type_name, &edge_name, &params, &mod_name)
        })
        .unzip();

    // `{edge_name}` is deliberately left as a placeholder: it is interpolated by the
    // generated `unreachable!` call at runtime, not by this `format!`.
    let unreachable_msg =
        format!("attempted to resolve unexpected edge '{{edge_name}}' on type '{type_name}'");

    let type_edge_resolver = quote! {
        pub(super) fn #dispatch_ident<'a, V: AsVertex<Vertex> + 'a>(
            contexts: ContextIterator<'a, V>,
            edge_name: &str,
            parameters: &EdgeParameters,
            resolve_info: &ResolveEdgeInfo,
        ) -> ContextOutcomeIterator<'a, V, VertexIterator<'a, Vertex>> {
            match edge_name {
                #arms
                _ => unreachable!(#unreachable_msg),
            }
        }
    };

    let type_edge_mod = quote! {
        mod #mod_name {
            use trustfall::provider::{
                resolve_neighbors_with, AsVertex, ContextIterator, ContextOutcomeIterator, ResolveEdgeInfo,
                VertexIterator,
            };

            use super::super::vertex::Vertex;

            #edge_resolvers
        }
    };

    (type_edge_resolver, type_edge_mod)
}

/// Generates the code for one edge of one vertex type.
///
/// Returns `(match_arm, resolver)`:
/// - `match_arm` is the dispatcher arm for `edge_name`; it first extracts any edge
///   parameters from `parameters` and then calls the resolver inside `mod_name`;
/// - `resolver` is the resolver function stub itself, whose body converts the vertex
///   to the expected variant and then hits `todo!`.
fn make_edge_resolver_and_call(
    type_name: &str,
    edge_name: &str,
    parameters: &[(String, String)],
    mod_name: &proc_macro2::Ident,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let span = proc_macro2::Span::call_site();

    // Message embedded in the generated `.expect(...)` used when looking up a parameter.
    let missing_parameter_msg = |parameter_name: &str| {
        format!("failed to find parameter '{parameter_name}' for edge '{edge_name}' on type '{type_name}'")
    };
    let FnCall { fn_params, fn_args, fn_arg_prep } =
        prepare_call_parameters(parameters, missing_parameter_msg);

    let resolver_fn_ident =
        syn::Ident::new(&escaped_rust_name(to_lower_snake_case(edge_name)), span);

    // The vertex conversion method is named `as_<snake-cased variant name>`.
    let variant_name = escaped_rust_name(upper_case_variant_name(type_name));
    let conversion_fn_ident =
        syn::Ident::new(&format!("as_{}", to_lower_snake_case(&variant_name)), span);

    let expect_msg = format!("conversion failed, vertex was not a {type_name}");
    let todo_msg = format!("get neighbors along edge '{edge_name}' for type '{type_name}'");

    let resolver = quote! {
        pub(super) fn #resolver_fn_ident<'a, V: AsVertex<Vertex> + 'a>(
            contexts: ContextIterator<'a, V>,
            #fn_params
            _resolve_info: &ResolveEdgeInfo,
        ) -> ContextOutcomeIterator<'a, V, VertexIterator<'a, Vertex>> {
            resolve_neighbors_with(contexts, move |vertex| {
                let vertex = vertex.#conversion_fn_ident().expect(#expect_msg);
                todo!(#todo_msg)
            })
        }
    };

    let match_arm = match parameters {
        // No parameters: the arm is a bare call expression.
        [] => quote! {
            #edge_name => #mod_name::#resolver_fn_ident(contexts, resolve_info),
        },
        // Otherwise the arm needs a block so the parameters can be bound before the call.
        _ => quote! {
            #edge_name => {
                #fn_arg_prep
                #mod_name::#resolver_fn_ident(contexts, #fn_args resolve_info)
            }
        },
    };

    (match_arm, resolver)
}

/// Token fragments for declaring and calling a generated function that takes
/// schema-defined parameters, as produced by `prepare_call_parameters`.
pub(super) struct FnCall {
    /// Parameter declarations for the function signature: a run of `name: Type,` items.
    pub(super) fn_params: proc_macro2::TokenStream,
    /// Arguments for the call site: a run of `name,` items, in the same order as `fn_params`.
    pub(super) fn_args: proc_macro2::TokenStream,
    /// Statements to place before the call: one `let name: Type = <expr>;` per parameter,
    /// where the expression reads the value out of a `parameters` value in scope.
    pub(super) fn_arg_prep: proc_macro2::TokenStream,
}

pub(super) fn prepare_call_parameters(
    parameters: &[(String, String)],
    expect_msg_fn: impl Fn(&str) -> String,
) -> FnCall {
    let mut fn_params: proc_macro2::TokenStream = proc_macro2::TokenStream::new();
    let mut fn_args: proc_macro2::TokenStream = proc_macro2::TokenStream::new();
    let mut fn_arg_prep: proc_macro2::TokenStream = proc_macro2::TokenStream::new();

    for (parameter_name, parameter_type) in parameters {
        let ident = syn::Ident::new(parameter_name, proc_macro2::Span::call_site());
        let ty = trustfall_type_to_rust_type(parameter_type);
        fn_params.extend(quote! {
            #ident: #ty,
        });
        fn_args.extend(quote! {
            #ident,
        });

        let expect_msg = expect_msg_fn(parameter_name);
        let parameter_get = quote! {
            parameters.get(#parameter_name).expect(#expect_msg)
        };
        let parameter_expr = field_value_to_rust_type(parameter_type, parameter_get);

        fn_arg_prep.extend(quote! {
            let #ident: #ty = #parameter_expr;
        });
    }

    FnCall { fn_params, fn_args, fn_arg_prep }
}