cratestack-macros 0.2.1

Rust-native schema-first framework for typed HTTP APIs, generated clients, and backend services.
Documentation
use cratestack_core::{Procedure, TypeArity, TypeDecl};
use quote::quote;

use super::ast::{generate_policy_ast_tokens, parse_policy_ast};
use super::auth::{find_auth_field, parse_builtin_policy_call, parse_string_literal};

/// Type information for an input field referenced from a procedure policy
/// expression, as produced by `resolve_procedure_field`.
#[derive(Clone)]
struct ProcedurePolicyField {
    /// Declared type of the referenced field; policy checks in this module
    /// inspect its `name` and `arity`.
    ty: cratestack_core::TypeRef,
}

/// Extracts the expression inside an `@allow(...)` procedure attribute.
///
/// Yields `None` when `raw` is not an `@allow(...)` directive.
pub(crate) fn parse_procedure_allow_expression(raw: &str) -> Option<Result<&str, String>> {
    const DIRECTIVE: &str = "@allow";
    parse_procedure_policy_expression(raw, DIRECTIVE)
}

/// Extracts the expression inside a `@deny(...)` procedure attribute.
///
/// Yields `None` when `raw` is not a `@deny(...)` directive.
pub(crate) fn parse_procedure_deny_expression(raw: &str) -> Option<Result<&str, String>> {
    const DIRECTIVE: &str = "@deny";
    parse_procedure_policy_expression(raw, DIRECTIVE)
}

/// Compiles a procedure policy expression into a `::cratestack::ProcedurePolicy`
/// constructor token stream.
///
/// The expression is parsed into a boolean AST. Each leaf term is lowered by
/// `parse_procedure_policy_term` and wrapped in `ProcedurePolicyExpr::Predicate`.
/// Conjunctions and disjunctions use `ProcedurePolicyExpr::And` / `Or`.
///
/// # Errors
/// Propagates any parse error from the AST parser or from term lowering.
pub(crate) fn generate_procedure_policy(
    expression: &str,
    procedure: &Procedure,
    types: &[TypeDecl],
    auth: Option<&cratestack_core::AuthBlock>,
) -> Result<proc_macro2::TokenStream, String> {
    let ast = parse_policy_ast(expression)?;

    let and_combinator = quote! { ::cratestack::ProcedurePolicyExpr::And };
    let or_combinator = quote! { ::cratestack::ProcedurePolicyExpr::Or };

    // The term closure stays inline so its signature is inferred from the callee.
    let expr = generate_policy_ast_tokens(
        &ast,
        &|term| {
            parse_procedure_policy_term(term, procedure, types, auth)
                .map(|predicate| quote! { ::cratestack::ProcedurePolicyExpr::Predicate(#predicate) })
        },
        and_combinator,
        or_combinator,
    )?;

    let policy = quote! {
        ::cratestack::ProcedurePolicy {
            expr: #expr,
        }
    };
    Ok(policy)
}

/// Extracts the inner expression of a procedure policy directive such as
/// `@allow(<expr>)` or `@deny(<expr>)`.
///
/// Whitespace is tolerated in three places: around the whole attribute,
/// between the directive name and `(`, and inside the parentheses.
///
/// Returns:
/// - `None` when `raw` is not a `directive(...)` attribute. This covers a
///   different directive, missing parentheses, or text after the closing `)`.
/// - `Some(Err(_))` when the directive is present but its expression is empty.
/// - `Some(Ok(expr))` with the trimmed inner expression otherwise.
fn parse_procedure_policy_expression<'a>(
    raw: &'a str,
    directive: &str,
) -> Option<Result<&'a str, String>> {
    let inner = raw
        .trim()
        .strip_prefix(directive)?
        // Allow `@allow (expr)` as well as `@allow(expr)`. Names such as `@allowed(...)`
        // still fail, because the next non-space character must be `(`.
        .trim_start()
        .strip_prefix('(')?
        .strip_suffix(')')?
        .trim();

    // An empty policy such as `@allow()` is a schema error; report it here
    // instead of handing an empty string to the expression parser.
    if inner.is_empty() {
        return Some(Err(format!(
            "`{directive}` policy expression must not be empty"
        )));
    }

    Some(Ok(inner))
}

/// Lowers one leaf term of a procedure policy expression into a
/// `::cratestack::ProcedurePredicate` token stream.
///
/// Recognised forms, checked in order:
/// - `auth() != null` / `auth() == null` (exact spelling),
/// - built-in calls recognised by `parse_builtin_policy_call`,
/// - `<lhs> == <rhs>` / `<lhs> != <rhs>` comparisons (see `parse_procedure_comparison`),
/// - a bare input field name, which must resolve to a required `Boolean`.
///
/// # Errors
/// Returns a message when the term references an unknown field, uses an
/// unsupported built-in, or has an unsupported shape.
fn parse_procedure_policy_term(
    term: &str,
    procedure: &Procedure,
    types: &[TypeDecl],
    auth: Option<&cratestack_core::AuthBlock>,
) -> Result<proc_macro2::TokenStream, String> {
    if term == "auth() != null" {
        return Ok(quote! { ::cratestack::ProcedurePredicate::AuthNotNull });
    }

    if term == "auth() == null" {
        return Ok(quote! { ::cratestack::ProcedurePredicate::AuthIsNull });
    }

    if let Some(function) = parse_builtin_policy_call(term) {
        return parse_builtin_procedure_policy_term(function?);
    }

    if let Some((lhs, rhs, negate)) = split_procedure_comparison(term) {
        return parse_procedure_comparison(lhs, rhs, procedure, types, auth, negate);
    }

    // No operator: the term must name a boolean input field that is checked for `true`.
    let field_decl = resolve_procedure_field(procedure, types, term)?;
    if field_decl.ty.name != "Boolean" || field_decl.ty.arity != TypeArity::Required {
        return Err(format!(
            "boolean procedure policy check `{term}` is only supported for required Boolean input fields"
        ));
    }

    Ok(quote! {
        ::cratestack::ProcedurePredicate::InputFieldIsTrue {
            field: #term,
        }
    })
}

/// Splits a comparison term at its first `==` or `!=` operator.
///
/// Returns the trimmed left- and right-hand sides, plus `true` when the
/// operator is `!=`. Choosing the earliest operator, rather than always
/// preferring `==`, stops an operator inside a right-hand string literal from
/// being taken as the comparison. For example, `name != "a==b"` splits at `!=`.
fn split_procedure_comparison(term: &str) -> Option<(&str, &str, bool)> {
    let eq_at = term.find("==");
    let ne_at = term.find("!=");
    let (index, negate) = match (eq_at, ne_at) {
        (Some(eq), Some(ne)) if ne < eq => (ne, true),
        (Some(eq), _) => (eq, false),
        (None, Some(ne)) => (ne, true),
        (None, None) => return None,
    };
    // Both operators are two ASCII bytes, so `index + 2` is a char boundary.
    let (lhs, rest) = term.split_at(index);
    Some((lhs.trim(), rest[2..].trim(), negate))
}

/// Maps a built-in policy call `(name, argument)` onto its procedure predicate.
///
/// Only `hasRole` and `inTenant` are supported. The argument is emitted as a
/// string literal in the generated code.
///
/// # Errors
/// Returns an "unsupported policy function" message for any other name.
fn parse_builtin_procedure_policy_term(
    (name, value): (&str, &str),
) -> Result<proc_macro2::TokenStream, String> {
    let predicate = match name {
        "hasRole" => quote! {
            ::cratestack::ProcedurePredicate::HasRole {
                role: #value,
            }
        },
        "inTenant" => quote! {
            ::cratestack::ProcedurePredicate::InTenant {
                tenant_id: #value,
            }
        },
        other => return Err(format!("unsupported policy function `{other}`")),
    };
    Ok(predicate)
}

/// Lowers a binary comparison `lhs <op> rhs` into a `ProcedurePredicate`.
///
/// `negate` selects the `Ne` variant (for `!=`) instead of the `Eq` variant.
/// Shapes are tried in this order:
/// 1. `auth().<f> <op> <input field>` → `InputField{Eq,Ne}Auth`
/// 2. `auth().<f> <op> <literal>` → `AuthField{Eq,Ne}Literal`
/// 3. `<input field> <op> auth().<f>` → `InputField{Eq,Ne}Auth`
/// 4. `<input field> <op> <input field>` → `InputField{Eq,Ne}Input` (types must match)
/// 5. `<input field> <op> <literal>` → `InputField{Eq,Ne}Literal`
///
/// # Errors
/// Returns a message for unknown auth or input fields, mismatched field types,
/// or literals that do not fit the compared field.
fn parse_procedure_comparison(
    lhs: &str,
    rhs: &str,
    procedure: &Procedure,
    types: &[TypeDecl],
    auth: Option<&cratestack_core::AuthBlock>,
    negate: bool,
) -> Result<proc_macro2::TokenStream, String> {
    // Chooses the `Eq` or `Ne` spelling of a predicate variant name.
    let pick = |eq: proc_macro2::TokenStream, ne: proc_macro2::TokenStream| {
        if negate { ne } else { eq }
    };

    if let Some(raw_auth_field) = lhs.strip_prefix("auth().") {
        let auth_field = raw_auth_field.trim();
        ensure_auth_field(auth, types, auth_field)?;

        if resolve_procedure_field(procedure, types, rhs).is_ok() {
            let variant = pick(quote! { InputFieldEqAuth }, quote! { InputFieldNeAuth });
            return Ok(quote! {
                ::cratestack::ProcedurePredicate::#variant {
                    field: #rhs,
                    auth_field: #auth_field,
                }
            });
        }

        let literal = parse_procedure_literal(rhs, None, auth_field)?;
        let variant = pick(quote! { AuthFieldEqLiteral }, quote! { AuthFieldNeLiteral });
        return Ok(quote! {
            ::cratestack::ProcedurePredicate::#variant {
                auth_field: #auth_field,
                value: #literal,
            }
        });
    }

    let field_decl = resolve_procedure_field(procedure, types, lhs)?;

    if let Some(raw_auth_field) = rhs.strip_prefix("auth().") {
        let auth_field = raw_auth_field.trim();
        ensure_auth_field(auth, types, auth_field)?;
        let variant = pick(quote! { InputFieldEqAuth }, quote! { InputFieldNeAuth });
        return Ok(quote! {
            ::cratestack::ProcedurePredicate::#variant {
                field: #lhs,
                auth_field: #auth_field,
            }
        });
    }

    if let Ok(other_field_decl) = resolve_procedure_field(procedure, types, rhs) {
        validate_procedure_field_type_match(&field_decl, &other_field_decl, lhs, rhs)?;
        let variant = pick(quote! { InputFieldEqInput }, quote! { InputFieldNeInput });
        return Ok(quote! {
            ::cratestack::ProcedurePredicate::#variant {
                field: #lhs,
                other_field: #rhs,
            }
        });
    }

    let literal = parse_procedure_literal(rhs, Some(&field_decl), lhs)?;
    let variant = pick(quote! { InputFieldEqLiteral }, quote! { InputFieldNeLiteral });
    Ok(quote! {
        ::cratestack::ProcedurePredicate::#variant {
            field: #lhs,
            value: #literal,
        }
    })
}

/// Resolves a field reference from a policy expression to its declared type.
///
/// - `root.rest`: `root` must be a procedure argument, and `rest` is walked
///   through that argument's type.
/// - A bare name: a procedure argument with that name is preferred.
///   Otherwise the name is looked up as a field of the argument called `args`,
///   if there is one.
///
/// # Errors
/// Returns an "unknown procedure input field" message when nothing matches.
/// Errors from walking a dotted path through a declared argument are passed on.
fn resolve_procedure_field(
    procedure: &Procedure,
    types: &[TypeDecl],
    field: &str,
) -> Result<ProcedurePolicyField, String> {
    let unknown = || {
        format!(
            "unknown procedure input field `{field}` on `{}`",
            procedure.name
        )
    };
    let find_arg = |name: &str| procedure.args.iter().find(|arg| arg.name == name);

    match field.split_once('.') {
        Some((root, rest)) => {
            let arg = find_arg(root).ok_or_else(unknown)?;
            resolve_type_field_path(types, &arg.ty.name, rest, &procedure.name, field)
        }
        None => match find_arg(field) {
            Some(arg) => Ok(ProcedurePolicyField { ty: arg.ty.clone() }),
            // Fallback: a field of the `args` input. Failures collapse to `unknown`.
            None => find_arg("args")
                .and_then(|args_arg| {
                    resolve_type_field_path(types, &args_arg.ty.name, field, &procedure.name, field)
                        .ok()
                })
                .ok_or_else(unknown),
        },
    }
}

/// Walks a dotted field `path` starting at the type named `type_name` and
/// returns the type of the final field.
///
/// At each hop, the owning type is looked up in `types` before its field is
/// resolved. The type of the leaf field itself is never looked up.
/// `procedure_name` and `original_field` are used only in error messages.
///
/// # Errors
/// - An "unsupported input type" message when an owning type is not in `types`.
/// - An "unknown procedure input field" message when a path segment does not
///   name a field of its owner.
fn resolve_type_field_path(
    types: &[TypeDecl],
    type_name: &str,
    path: &str,
    procedure_name: &str,
    original_field: &str,
) -> Result<ProcedurePolicyField, String> {
    let unknown_field = || {
        format!("unknown procedure input field `{original_field}` on `{procedure_name}`")
    };

    let mut owner_name: &str = type_name;
    let mut remaining = path;
    loop {
        let owner = types
            .iter()
            .find(|decl| decl.name == owner_name)
            .ok_or_else(|| {
                format!(
                    "procedure `{procedure_name}` references unsupported input type `{owner_name}` for policy checks"
                )
            })?;

        let (segment, tail) = match remaining.split_once('.') {
            Some((head, tail)) => (head, Some(tail)),
            None => (remaining, None),
        };

        let field = owner
            .fields
            .iter()
            .find(|candidate| candidate.name == segment)
            .ok_or_else(unknown_field)?;

        match tail {
            // Last segment: this is the field the policy refers to.
            None => {
                return Ok(ProcedurePolicyField {
                    ty: field.ty.clone(),
                });
            }
            // Intermediate segment: continue inside the field's declared type.
            Some(tail) => {
                owner_name = field.ty.name.as_str();
                remaining = tail;
            }
        }
    }
}

/// Ensures two input fields compared in a policy share the same type name and
/// arity.
///
/// # Errors
/// Returns a message naming both fields when the types differ.
fn validate_procedure_field_type_match(
    left: &ProcedurePolicyField,
    right: &ProcedurePolicyField,
    left_name: &str,
    right_name: &str,
) -> Result<(), String> {
    if left.ty.name == right.ty.name && left.ty.arity == right.ty.arity {
        Ok(())
    } else {
        Err(format!(
            "procedure fields `{left_name}` and `{right_name}` must share the same type for policy comparisons"
        ))
    }
}

/// Converts the right-hand side of a comparison into a
/// `::cratestack::ProcedurePolicyLiteral` token stream. The literal is checked
/// against the type of the field it is compared with.
///
/// `field` is `None` when the left-hand side is an `auth()` field. The literal
/// may then be `true`/`false` or a string literal. For input fields, only
/// required `Boolean`, `Int`, and `String` types are supported.
///
/// # Errors
/// Returns a message naming `field_name` in either of these cases:
/// - the literal does not match a supported field type (e.g. a non-boolean
///   compared with a required `Boolean`);
/// - the field type or arity is not supported for literal comparison.
fn parse_procedure_literal(
    rhs: &str,
    field: Option<&ProcedurePolicyField>,
    field_name: &str,
) -> Result<proc_macro2::TokenStream, String> {
    let (field_type, arity) = match field {
        Some(field) => (field.ty.name.as_str(), field.ty.arity),
        // Auth fields carry no type here; both boolean and string literals are accepted.
        None => ("auth", TypeArity::Required),
    };

    match field_type {
        "Boolean" | "auth" if arity == TypeArity::Required && matches!(rhs, "true" | "false") => {
            let value = rhs == "true";
            Ok(quote! { ::cratestack::ProcedurePolicyLiteral::Bool(#value) })
        }
        // A required Boolean field with a non-boolean literal is a literal error,
        // not an unsupported field type; report it like the Int/String arms do.
        "Boolean" if arity == TypeArity::Required => Err(format!(
            "expected boolean literal for procedure field `{field_name}`"
        )),
        "Int" if arity == TypeArity::Required => rhs
            .parse::<i64>()
            .map(|value| quote! { ::cratestack::ProcedurePolicyLiteral::Int(#value) })
            .map_err(|_| format!("expected integer literal for procedure field `{field_name}`")),
        "String" | "auth" if arity == TypeArity::Required => {
            let value = parse_string_literal(rhs).ok_or_else(|| {
                format!("expected string literal for procedure field `{field_name}`")
            })?;
            Ok(quote! { ::cratestack::ProcedurePolicyLiteral::String(#value) })
        }
        _ => Err(format!(
            "procedure policy literal support is currently limited to required Boolean, Int, and String fields; `{field_name}` is unsupported"
        )),
    }
}

/// Checks that `field` names an auth field known to `find_auth_field`, and
/// discards the resolved value.
///
/// # Errors
/// Passes on the error from `find_auth_field` unchanged.
fn ensure_auth_field(
    auth: Option<&cratestack_core::AuthBlock>,
    types: &[TypeDecl],
    field: &str,
) -> Result<(), String> {
    find_auth_field(auth, types, field)?;
    Ok(())
}