// override_macro 0.1.1
//
// An attribute-like macro for Rust programs to override trait methods with
// default methods of other traits.
//
// Copyright (C) 2024-2025 Takayuki Sato. All Rights Reserved.
// This program is free software under MIT License.
// See the file LICENSE in this distribution for more details.

use crate::syn_data_acc;

use std::collections::HashMap;
use std::sync::{LazyLock, RwLock};

/// Every trait registered so far, in registration order. A trait's position in
/// this vector is the index stored in `FoundTrait::Found`.
static TRAIT_VEC: RwLock<Vec<OverridableTrait>> = RwLock::new(Vec::new());
/// Lookup table from a trait search key to the registered trait it designates
/// (or to a conflict marker when several traits were registered under that key).
static TRAIT_MAP: LazyLock<RwLock<HashMap<String, FoundTrait>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));

/// Result stored in `TRAIT_MAP` for one search key.
enum FoundTrait {
    /// Exactly one trait was registered under the key; the payload is its
    /// index in `TRAIT_VEC`.
    Found(usize),
    /// The key was registered more than once, so it cannot identify a single
    /// trait. Looking it up is reported as an error by `find_trait`.
    Conflict,
}

/// A registered trait together with the methods collected from it.
struct OverridableTrait {
    // Stored at registration time but not read by the code in this file,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    mod_path: Option<String>,

    // Stored at registration time but not read by the code in this file,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    name: String,

    /// Methods of the trait, keyed by the method key supplied during
    /// registration. The same key is used to match methods across traits.
    method_map: HashMap<String, OverridableMethod>,
}

/// Information about one trait method, as supplied during registration.
struct OverridableMethod {
    /// Method name; used in the "implemented in multiple traits" error message.
    name: String,
    /// Text emitted in front of the generated body (`<sig> { ... }`).
    /// Presumably the full method signature; confirm against the data accessor.
    sig: String,
    /// Text emitted after `<trait path>::` inside the generated body.
    /// Presumably the call expression with its arguments; confirm against the
    /// data accessor.
    call: String,
    /// Only methods with this flag set are used as a source for overriding, and
    /// only methods of the impl trait with this flag cleared get overridden.
    /// Presumably true when the trait provides a default body; confirm.
    has_impl: bool,
    /// When set, `.await` is appended to the forwarded call.
    is_async: bool,
}

/// Registers the trait described by `data` and publishes its search keys.
///
/// The trait is first appended to `TRAIT_VEC` (via `register_trait`), then each
/// of its search keys is written to `TRAIT_MAP`: a key that is new maps to the
/// trait's index, a key that already exists is downgraded to
/// `FoundTrait::Conflict`.
///
/// # Panics
///
/// Panics if one of the registry locks is poisoned.
pub fn collect_trait_info(data: &syn_data_acc::OverridableDataAcc) {
    let name = data.get_trait_name();
    let module = data.get_module_path();
    let index = register_trait(data, &name, &module);

    let mut registry = TRAIT_MAP.write().unwrap();

    for search_key in data.list_trait_search_keys() {
        // An already-present key becomes ambiguous; a fresh key points at this trait.
        registry
            .entry(search_key)
            .and_modify(|slot| *slot = FoundTrait::Conflict)
            .or_insert(FoundTrait::Found(index));
    }
}

/// Appends a new `OverridableTrait` to `TRAIT_VEC` and returns its index.
///
/// The trait's methods are gathered through
/// `data.for_each_method_registration`, each one stored under the key the
/// callback receives. The write lock on `TRAIT_VEC` is held for the whole
/// call, so the returned index is the position the trait was pushed at.
///
/// # Panics
///
/// Panics if the `TRAIT_VEC` lock is poisoned.
fn register_trait(
    data: &syn_data_acc::OverridableDataAcc,
    trait_name: &str,
    mod_path: &Option<String>,
) -> usize {
    let mut registry = TRAIT_VEC.write().unwrap();
    let next_index = registry.len();

    let mut methods: HashMap<String, OverridableMethod> = HashMap::new();
    data.for_each_method_registration(|name, sig, call, has_impl, key, is_async| {
        methods.insert(
            key,
            OverridableMethod {
                name,
                sig,
                call,
                has_impl,
                is_async,
            },
        );
    });

    registry.push(OverridableTrait {
        mod_path: mod_path.clone(),
        name: trait_name.to_string(),
        method_map: methods,
    });

    next_index
}

/// A trait named in the macro arguments, resolved against the registry.
struct ArgumentTrait {
    /// Path text of the trait; emitted in generated calls (`<path>::<call>`)
    /// and quoted in error messages.
    path: String,
    /// Index of the resolved trait in `TRAIT_VEC`.
    index: usize,
}

pub fn override_trait_methods(data: &mut syn_data_acc::OverrideWithDataAcc) {
    let trait_map = TRAIT_MAP.read().unwrap();

    let impl_trait_index = data
        .get_impl_trait(|keys| Some(find_trait(&trait_map, keys)))
        .unwrap();

    let trait_vec = TRAIT_VEC.read().unwrap();
    let impl_trait = &trait_vec[impl_trait_index];

    let arg_traits = data.list_argument_traits(|keys, path| {
        let index = find_trait(&trait_map, keys);
        if index == impl_trait_index {
            panic!(
                "The same trait as the impl trait was found in the arguments: {}",
                path
            );
        }
        Some(ArgumentTrait { path, index })
    });

    let impl_method_keys = data.list_impl_method_keys();

    let mut overriding_methods = Vec::<String>::new();

    for (key, method) in &impl_trait.method_map {
        if method.has_impl {
            continue;
        }

        if impl_method_keys.contains(&key) {
            continue;
        }

        let mut conflicting_traits = Vec::<String>::new();

        for arg_trait in &arg_traits {
            let t = &trait_vec[arg_trait.index];

            if let Some(m) = t.method_map.get(key) {
                if !m.has_impl {
                    continue;
                }

                let t_path = &arg_trait.path;

                conflicting_traits.push(t_path.clone());
                if conflicting_traits.len() > 1 {
                    continue;
                }

                if m.is_async {
                    overriding_methods
                        .push(format!("{} {{ {}::{}.await }}", m.sig, t_path, m.call));
                } else {
                    overriding_methods.push(format!("{} {{ {}::{} }}", m.sig, t_path, m.call));
                }
            }
        }

        if conflicting_traits.len() > 1 {
            panic!(
                "The method `{}` is implemented in multiple traits: {}",
                method.name,
                &conflicting_traits.join(", "),
            );
        }
    }

    data.set_overriding_method_impls(overriding_methods);
}

/// Resolves a trait by trying each of `search_keys` in order against
/// `trait_map` and returns the `TRAIT_VEC` index of the first key that is
/// registered.
///
/// # Panics
///
/// - If the first registered key encountered is marked as
///   `FoundTrait::Conflict` (several traits share that key).
/// - If none of the keys is registered. The message quotes the first search
///   key, or a placeholder when `search_keys` is empty, so that an empty slice
///   produces the same diagnostic instead of an index-out-of-bounds panic.
fn find_trait(trait_map: &HashMap<String, FoundTrait>, search_keys: &[String]) -> usize {
    for key in search_keys {
        match trait_map.get(key) {
            Some(FoundTrait::Found(i)) => return *i,
            Some(FoundTrait::Conflict) => {
                panic!("There are multiple traits matching with: {}", key);
            }
            None => {}
        }
    }

    let wanted = search_keys
        .first()
        .map_or("<no search keys>", String::as_str);

    panic!(
        "There is no traits with the same path or sub path: {}",
        wanted
    );
}