cairo-lang-plugins 2.1.2

Cairo core plugin implementations.
Documentation
use std::sync::Arc;

use cairo_lang_defs::plugin::{
    DynGeneratedFileAuxData, MacroPlugin, PluginDiagnostic, PluginGeneratedFile, PluginResult,
};
use cairo_lang_semantic::plugin::{AsDynMacroPlugin, SemanticPlugin, TrivialPluginAuxData};
use cairo_lang_syntax::attribute::structured::{
    Attribute, AttributeArg, AttributeArgVariant, AttributeStructurize,
};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode};
use cairo_lang_utils::try_extract_matches;
use itertools::Itertools;
use smol_str::SmolStr;

/// Macro plugin that generates a panicking wrapper for extern and free functions marked with the
/// `#[panic_with]` attribute.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct PanicablePlugin;

impl MacroPlugin for PanicablePlugin {
    /// Forwards extern and free functions to the `#[panic_with]` code generator; every other
    /// kind of item yields the default (empty) plugin result.
    fn generate_code(&self, db: &dyn SyntaxGroup, item_ast: ast::Item) -> PluginResult {
        match item_ast {
            ast::Item::ExternFunction(func) => {
                generate_panicable_code(db, func.declaration(db), func.attributes(db))
            }
            ast::Item::FreeFunction(func) => {
                generate_panicable_code(db, func.declaration(db), func.attributes(db))
            }
            _ => PluginResult::default(),
        }
    }
}
impl AsDynMacroPlugin for PanicablePlugin {
    /// Converts the shared plugin handle into a `MacroPlugin` trait object.
    fn as_dyn_macro_plugin<'a>(self: Arc<Self>) -> Arc<dyn MacroPlugin + 'a>
    where
        Self: 'a,
    {
        self as Arc<dyn MacroPlugin + 'a>
    }
}
// Empty impl: no `SemanticPlugin` items are overridden for this plugin.
impl SemanticPlugin for PanicablePlugin {}

/// Generate code defining a panicable variant of a function marked with `#[panic_with]` attribute.
fn generate_panicable_code(
    db: &dyn SyntaxGroup,
    declaration: ast::FunctionDeclaration,
    attributes: ast::AttributeList,
) -> PluginResult {
    let mut attrs = attributes.query_attr(db, "panic_with");
    if attrs.is_empty() {
        return PluginResult::default();
    }
    if attrs.len() > 1 {
        let extra_attr = attrs.swap_remove(1);
        return PluginResult {
            code: None,
            diagnostics: vec![PluginDiagnostic {
                stable_ptr: extra_attr.stable_ptr().untyped(),
                message: "`#[panic_with]` cannot be applied multiple times to the same item."
                    .into(),
            }],
            remove_original_item: false,
        };
    }
    let attr = attrs.swap_remove(0);

    let signature = declaration.signature(db);
    let Some((inner_ty_text, success_variant, failure_variant)) =
        extract_success_ty_and_variants(db, &signature)
    else {
        return PluginResult {
            code: None,
            diagnostics: vec![PluginDiagnostic {
                stable_ptr: signature.ret_ty(db).stable_ptr().untyped(),
                message: "Currently only wrapping functions returning an Option<T> or Result<T, E>"
                    .into(),
            }],
            remove_original_item: false,
        };
    };

    let attr = attr.structurize(db);

    let Some((err_value, panicable_name)) = parse_arguments(db, &attr) else {
        return PluginResult {
            code: None,
            diagnostics: vec![PluginDiagnostic {
                stable_ptr: attr.stable_ptr.untyped(),
                message: "Failed to extract panic data attribute".into(),
            }],
            remove_original_item: false,
        };
    };
    let generics_params = declaration.generic_params(db).as_syntax_node().get_text(db);

    let function_name = declaration.name(db).text(db);
    let params = signature.parameters(db).as_syntax_node().get_text(db);
    let args = signature
        .parameters(db)
        .elements(db)
        .into_iter()
        .map(|param| {
            let ref_kw = match &param.modifiers(db).elements(db)[..] {
                [ast::Modifier::Ref(_)] => "ref ",
                _ => "",
            };
            format!("{}{}", ref_kw, param.name(db).as_syntax_node().get_text(db))
        })
        .join(", ");

    PluginResult {
        code: Some(PluginGeneratedFile {
            name: "panicable".into(),
            content: indoc::formatdoc!(
                r#"
                    fn {panicable_name}{generics_params}({params}) -> {inner_ty_text} {{
                        match {function_name}({args}) {{
                            {success_variant} (v) => {{
                                v
                            }},
                            {failure_variant} (v) => {{
                                let mut data = array::array_new::<felt252>();
                                array::array_append::<felt252>(ref data, {err_value});
                                panic(data)
                            }},
                        }}
                    }}
                "#
            ),
            aux_data: DynGeneratedFileAuxData(Arc::new(TrivialPluginAuxData {})),
        }),
        diagnostics: vec![],
        remove_original_item: false,
    }
}

/// Inspects the return type of `signature`. When it is written as a single path segment with
/// generic arguments named `Option` (one argument) or `Result` (two arguments), returns the text
/// of the first generic argument together with the success and failure variant paths to match
/// on. Returns `None` for any other return type, or when there is no return type clause.
fn extract_success_ty_and_variants(
    db: &dyn SyntaxGroup,
    signature: &ast::FunctionSignature,
) -> Option<(String, String, String)> {
    let ret_clause =
        try_extract_matches!(signature.ret_ty(db), ast::OptionReturnTypeClause::ReturnTypeClause)?;
    let ret_ty_path = try_extract_matches!(ret_clause.ty(db), ast::Expr::Path)?;

    // Only a lone generic segment (e.g. `Option<T>`) is accepted; multi-segment paths are not.
    let segments = ret_ty_path.elements(db);
    let [ast::PathSegment::WithGenericArgs(segment)] = &segments[..] else {
        return None;
    };

    // Per supported wrapper: expected number of generic arguments and the variants to match on.
    let wrapper_name = segment.ident(db).text(db);
    let (expected_arity, success_variant, failure_variant) = match wrapper_name.as_str() {
        "Option" => (1, "Option::Some", "Option::None"),
        "Result" => (2, "Result::Ok", "Result::Err"),
        _ => return None,
    };

    let generic_args = segment.generic_args(db).generic_args(db).elements(db);
    if generic_args.len() != expected_arity {
        return None;
    }
    // The first generic argument is the success payload type in both supported wrappers.
    let inner = generic_args.first()?;
    Some((
        inner.as_syntax_node().get_text(db),
        success_variant.to_owned(),
        failure_variant.to_owned(),
    ))
}

/// Reads the arguments of a structured `#[panic_with(...)]` attribute.
///
/// Exactly two unnamed arguments are accepted: a short-string literal (the error value) followed
/// by a single-segment path without generic arguments (the name of the panicable function).
/// Returns `(error value, panicable function name)`, or `None` for any other argument shape.
fn parse_arguments(db: &dyn SyntaxGroup, attr: &Attribute) -> Option<(SmolStr, SmolStr)> {
    match &attr.args[..] {
        [
            AttributeArg {
                variant:
                    AttributeArgVariant::Unnamed { value: ast::Expr::ShortString(err_value), .. },
                ..
            },
            AttributeArg {
                variant: AttributeArgVariant::Unnamed { value: ast::Expr::Path(name), .. },
                ..
            },
        ] => match &name.elements(db)[..] {
            // The function name must be a bare identifier.
            [ast::PathSegment::Simple(segment)] => {
                Some((err_value.text(db), segment.ident(db).text(db)))
            }
            _ => None,
        },
        _ => None,
    }
}