azure_core_test_macros 0.1.0

Procedural macros for testing client libraries built on azure_core.
Documentation
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use proc_macro2::TokenStream;
use quote::quote;
use std::sync::LazyLock;
use syn::{
    parse::Parse, spanned::Spanned, FnArg, ItemFn, Meta, PatType, Result, ReturnType, Token,
};

/// Diagnostic emitted when `AZURE_TEST_MODE` holds an unrecognized value.
const INVALID_TEST_MODE_MESSAGE: &str = "expected 'playback', 'record', or 'live'";

/// Diagnostic emitted for malformed `#[recorded::test(...)]` arguments.
const INVALID_RECORDED_ATTRIBUTE_MESSAGE: &str =
    "expected `#[recorded::test]`, `#[recorded::test(live)]`, or `#[recorded::test(playback)]`";

/// Diagnostic emitted when the annotated function's signature cannot be wrapped.
const INVALID_RECORDED_FUNCTION_MESSAGE: &str =
    "expected `async fn(TestContext)` function signature with `Result<T, E>` return";

/// Mode in which recorded tests run, selected by the `AZURE_TEST_MODE` environment variable.
///
/// Variant declaration order matters: the derived ordering is `Playback < Record < Live`,
/// and `parse_test` compares against `TestMode::Live` to decide whether live-only tests
/// are ignored.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum TestMode {
    /// Selected by `playback`; also the default when `AZURE_TEST_MODE` is unset.
    #[default]
    Playback,
    /// Selected by `record`.
    Record,
    /// Selected by `live`.
    Live,
}

impl TestMode {
    /// Reads the mode from `AZURE_TEST_MODE`.
    ///
    /// Any failure to read the variable (unset, or not valid Unicode) yields the default
    /// mode; a present value is handed to [`TestMode::parse`].
    fn current() -> std::result::Result<Self, &'static str> {
        match std::env::var("AZURE_TEST_MODE") {
            Ok(raw) => Self::parse(&raw),
            Err(_) => Ok(Self::default()),
        }
    }

    /// Parses a mode name, ignoring ASCII case.
    ///
    /// Returns `INVALID_TEST_MODE_MESSAGE` for anything other than
    /// `playback`, `record`, or `live`.
    fn parse(value: &str) -> std::result::Result<Self, &'static str> {
        const NAMES: [(&str, TestMode); 3] = [
            ("playback", TestMode::Playback),
            ("record", TestMode::Record),
            ("live", TestMode::Live),
        ];
        NAMES
            .iter()
            .find_map(|&(name, mode)| value.eq_ignore_ascii_case(name).then_some(mode))
            .ok_or(INVALID_TEST_MODE_MESSAGE)
    }
}

// cspell:ignore asyncness
/// Expands `#[recorded::test]` applied to `item`, using the attribute arguments in `attr`.
///
/// The annotated function is nested, unchanged, inside a generated parameterless wrapper of
/// the same name that is marked `#[::tokio::test(flavor = "multi_thread")]`. When the
/// function takes a `TestContext`, the wrapper first obtains one from
/// `::azure_core_test::recorded::start` and passes it in; live-only tests may omit the
/// parameter, in which case the function is called directly.
///
/// The mode comes from `AZURE_TEST_MODE`. Tests marked `live` get `#[ignore]` unless the
/// mode is `live`. Tests marked `playback` get `#[ignore]` unless the mode is `playback`.
///
/// # Errors
///
/// Returns an error if `attr` is not a valid argument list, or if `item` is not an
/// `async fn`, has no return type, or does not take exactly one `TestContext` parameter.
/// Zero parameters are accepted only for `live` tests.
pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
    let recorded_attrs: Attributes = syn::parse2(attr)?;
    let ItemFn {
        attrs,
        vis,
        sig: original_sig,
        block,
    } = syn::parse2(item)?;

    // Only `async` functions can be driven by the tokio test runtime.
    if original_sig.asyncness.is_none() {
        return Err(syn::Error::new(
            original_sig.span(),
            INVALID_RECORDED_FUNCTION_MESSAGE,
        ));
    }
    let mut test_attr: TokenStream = quote! { #[::tokio::test(flavor = "multi_thread")] };

    // Assumes the return type is a `Result<T, E>` since that's all `#[test]`s support currently.
    if matches!(original_sig.output, ReturnType::Default) {
        return Err(syn::Error::new(
            original_sig.output.span(),
            INVALID_RECORDED_FUNCTION_MESSAGE,
        ));
    }

    // The mode is cached from `AZURE_TEST_MODE` by `TEST_MODE`.
    let test_mode = *TEST_MODE;

    // Ignore live-only tests if not running live tests.
    if recorded_attrs.live && test_mode < TestMode::Live {
        test_attr.extend(quote! {
            #[ignore = "skipping live tests"]
        });
    }

    // Ignore playback-only tests if not running playback tests.
    if recorded_attrs.playback && test_mode != TestMode::Playback {
        test_attr.extend(quote! {
            #[ignore = "skipping playback-only tests"]
        });
    }

    let fn_name = &original_sig.ident;
    let mut inputs = original_sig.inputs.iter();
    let setup = match inputs.next() {
        // Live-only tests may be written without a `TestContext`.
        None if recorded_attrs.live => quote! {
            #fn_name().await
        },
        Some(FnArg::Typed(PatType { ty, .. })) if is_test_context(ty.as_ref()) => {
            let test_mode = test_mode_to_tokens(test_mode);
            // `ctx` is moved straight into the test function, so the binding needs no `mut`
            // (which would otherwise trip `unused_mut` in generated code).
            quote! {
                let ctx = ::azure_core_test::recorded::start(
                    #test_mode,
                    env!("CARGO_MANIFEST_DIR"),
                    file!(),
                    stringify!(#fn_name),
                    ::std::option::Option::None,
                ).await?;
                #fn_name(ctx).await
            }
        }
        // Point the diagnostic at the unsupported parameter itself.
        Some(arg) => {
            return Err(syn::Error::new(
                arg.span(),
                INVALID_RECORDED_FUNCTION_MESSAGE,
            ))
        }
        // No parameter on a test that is not live-only.
        None => {
            return Err(syn::Error::new(
                original_sig.ident.span(),
                INVALID_RECORDED_FUNCTION_MESSAGE,
            ))
        }
    };

    if let Some(arg) = inputs.next() {
        return Err(syn::Error::new(
            arg.span(),
            format!("too many parameters; {INVALID_RECORDED_FUNCTION_MESSAGE}"),
        ));
    }

    // The wrapper is what the test harness calls, so it must take no parameters.
    let mut outer_sig = original_sig.clone();
    outer_sig.inputs.clear();

    Ok(quote! {
        #test_attr
        #(#attrs)*
        #vis #outer_sig {
            #original_sig {
                #block
            }
            #setup
        }
    })
}

/// Test mode read from `AZURE_TEST_MODE` the first time it is needed, then cached.
///
/// # Panics
///
/// Panics if `AZURE_TEST_MODE` is set to an unsupported value. This is deliberate: a
/// misconfigured mode should stop expansion rather than silently pick a default. The panic
/// message names the variable so the resulting macro error is actionable.
static TEST_MODE: LazyLock<TestMode> = LazyLock::new(|| {
    TestMode::current().unwrap_or_else(|err| panic!("invalid AZURE_TEST_MODE: {err}"))
});

#[derive(Debug, Default)]
struct Attributes {
    live: bool,
    playback: bool,
}

impl Parse for Attributes {
    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
        let mut attrs = Self::default();
        for arg in input.parse_terminated(Meta::parse, Token![,])? {
            match &arg {
                Meta::Path(path) => {
                    let ident = path.get_ident().ok_or_else(|| {
                        syn::Error::new(arg.span(), INVALID_RECORDED_ATTRIBUTE_MESSAGE)
                    })?;
                    match ident.to_string().as_str() {
                        "live" => attrs.live = true,
                        "playback" => attrs.playback = true,
                        _ => {
                            return Err(syn::Error::new(
                                arg.span(),
                                INVALID_RECORDED_ATTRIBUTE_MESSAGE,
                            ))
                        }
                    }
                }
                _ => {
                    return Err(syn::Error::new(
                        arg.span(),
                        INVALID_RECORDED_ATTRIBUTE_MESSAGE,
                    ))
                }
            }
        }
        Ok(attrs)
    }
}

/// Reports whether `arg` names the `TestContext` type.
///
/// Accepts the bare `TestContext` (without a leading `::`), and `azure_core_test::TestContext`
/// with or without a leading `::`. Generic arguments on the segments are not inspected.
fn is_test_context(arg: &syn::Type) -> bool {
    let syn::Type::Path(syn::TypePath { path, .. }) = arg else {
        return false;
    };

    let segments = &path.segments;
    match segments.len() {
        1 => path.leading_colon.is_none() && segments[0].ident == "TestContext",
        2 => segments[0].ident == "azure_core_test" && segments[1].ident == "TestContext",
        _ => false,
    }
}

/// Renders `test_mode` as a path to the matching `::azure_core_test::TestMode` variant,
/// for splicing into generated code.
fn test_mode_to_tokens(test_mode: TestMode) -> TokenStream {
    let variant = match test_mode {
        TestMode::Playback => quote!(Playback),
        TestMode::Record => quote!(Record),
        TestMode::Live => quote!(Live),
    };
    quote! { ::azure_core_test::TestMode::#variant }
}

#[cfg(test)]
mod tests {
    use super::*;
    use syn::Attribute;

    #[test]
    fn test_mode_parse() {
        assert_eq!(TestMode::parse("playback").unwrap(), TestMode::Playback);
        assert_eq!(TestMode::parse("Record").unwrap(), TestMode::Record);
        assert_eq!(TestMode::parse("LIVE").unwrap(), TestMode::Live);
        assert_eq!(
            TestMode::parse("invalid").unwrap_err(),
            INVALID_TEST_MODE_MESSAGE
        );
    }

    #[test]
    fn attributes_parse_empty() {
        let attrs: Attributes = syn::parse2(TokenStream::new()).unwrap();
        assert!(!attrs.live);
        assert!(!attrs.playback);
    }

    #[test]
    fn attributes_parse_live() {
        let attr: Attribute = syn::parse_quote! {
            #[recorded(live)]
        };
        let attrs: Attributes = attr.parse_args().unwrap();
        assert!(attrs.live);
    }

    #[test]
    fn attributes_parse_other() {
        let attr: Attribute = syn::parse_quote! {
            #[recorded(other)]
        };
        attr.parse_args::<Attributes>().unwrap_err();
    }

    #[test]
    fn attributes_parse_multiple() {
        let attr: Attribute = syn::parse_quote! {
            #[recorded(live, other)]
        };
        attr.parse_args::<Attributes>().unwrap_err();
    }

    #[test]
    fn attributes_parse_live_value() {
        let attr: Attribute = syn::parse_quote! {
            #[recorded(live = true)]
        };
        attr.parse_args::<Attributes>().unwrap_err();
    }

    #[test]
    fn is_test_context() {
        let types: Vec<syn::Type> = vec![
            syn::parse_quote! { ::azure_core_test::TestContext },
            syn::parse_quote! { azure_core_test::TestContext },
            syn::parse_quote! { TestContext },
        ];
        for ty in types {
            assert!(super::is_test_context(&ty));
        }
    }

    #[test]
    fn is_not_test_context() {
        let types: Vec<syn::Type> = vec![
            syn::parse_quote! { &TestContext },
            syn::parse_quote! { ::TestContext },
            syn::parse_quote! { other::TestContext },
            syn::parse_quote! { azure_core_test::recorded::TestContext },
            syn::parse_quote! { Option<TestContext> },
        ];
        for ty in types {
            assert!(!super::is_test_context(&ty));
        }
    }

    #[test]
    fn parse_recorded_playback() {
        let attr = TokenStream::new();
        let item = quote! {
            async fn recorded() -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap_err();
    }

    #[test]
    fn parse_recorded_playback_with_context() {
        let attr = TokenStream::new();
        let item = quote! {
            async fn recorded(ctx: TestContext) -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap();
    }

    #[test]
    fn parse_recorded_playback_with_multiple() {
        let attr = TokenStream::new();
        // The signature must be syntactically valid so that the parameter-count check,
        // not the `ItemFn` parser, is what rejects it.
        let item = quote! {
            async fn recorded(ctx: TestContext, name: &'static str) -> azure_core::Result<()> {
                todo!()
            }
        };
        let err = parse_test(attr, item).unwrap_err();
        assert!(err.to_string().starts_with("too many parameters"));
    }

    #[test]
    fn parse_recorded_not_async() {
        let attr = TokenStream::new();
        let item = quote! {
            fn recorded(ctx: TestContext) -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap_err();
    }

    #[test]
    fn parse_recorded_without_return_type() {
        let attr = TokenStream::new();
        let item = quote! {
            async fn recorded(ctx: TestContext) {}
        };
        parse_test(attr, item).unwrap_err();
    }

    #[test]
    fn parse_recorded_live() {
        let attr = quote! { live };
        let item = quote! {
            async fn live_only() -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap();
    }

    #[test]
    fn parse_recorded_live_with_context() {
        let attr = quote! { live };
        let item = quote! {
            async fn live_only(ctx: TestContext) -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap();
    }

    #[test]
    fn attributes_parse_playback() {
        let attr: Attribute = syn::parse_quote! {
            #[recorded(playback)]
        };
        let attrs: Attributes = attr.parse_args().unwrap();
        assert!(attrs.playback);
        assert!(!attrs.live);
    }

    #[test]
    fn parse_recorded_playback_only() {
        let attr = quote! { playback };
        let item = quote! {
            async fn playback_only(ctx: TestContext) -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap();
    }

    #[test]
    fn parse_recorded_playback_only_without_context() {
        let attr = quote! { playback };
        let item = quote! {
            async fn playback_only() -> azure_core::Result<()> {
                todo!()
            }
        };
        parse_test(attr, item).unwrap_err();
    }
}