use std::fmt::Display;
use std::sync::mpsc::Sender;
use fxhash::FxHashMap;
use kind_derive::getters::derive_getters;
use kind_derive::matching::derive_match;
use kind_derive::mutters::derive_mutters;
use kind_derive::open::derive_match_rec;
use kind_derive::setters::derive_setters;
use kind_report::data::Diagnostic;
use kind_span::Locatable;
use kind_span::Range;
use kind_tree::concrete::Entry;
use kind_tree::concrete::EntryMeta;
use kind_tree::concrete::Module;
use kind_tree::concrete::RecordDecl;
use kind_tree::concrete::SumTypeDecl;
use kind_tree::concrete::{Attribute, TopLevel};
use crate::diagnostic::PassDiagnostic;
pub mod uses;
/// Requested derivations, each mapped to the source range where it was first named.
type Derivations = FxHashMap<Derive, Range>;
/// Sending half of the channel through which this pass reports diagnostics.
type Channel = Sender<Box<dyn Diagnostic>>;
/// A derivation that can be requested through a `#[derive(...)]` attribute.
///
/// This is a field-less enum, so it is also `Clone` and `Copy`: values can be
/// passed around and reused without being moved.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum Derive {
    /// `match`: expanded by `derive_match` (sum types) or `derive_match_rec` (records).
    Match,
    /// `getters`: expanded by `derive_getters`; records only.
    Getters,
    /// `setters`: expanded by `derive_setters`; records only.
    Setters,
    /// `mutters`: expanded by `derive_mutters`; records only.
    Mutters,
}
impl Display for Derive {
    /// Writes the name by which this derivation is requested in `#[derive(...)]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Derive::Match => "match",
            Derive::Getters => "getters",
            Derive::Setters => "setters",
            Derive::Mutters => "mutters",
        };
        f.write_str(name)
    }
}
/// Records `key` as requested at `range`, or reports it as a duplicate.
///
/// If `key` is already present in `hashmap`, the first occurrence is kept and a
/// `PassDiagnostic::DuplicatedAttributeArgument` carrying both the earlier and
/// the new range is sent through `channel`. Otherwise `key` is inserted with
/// `range`.
///
/// # Panics
///
/// Panics if the receiving end of `channel` has been dropped.
pub fn insert_or_report(channel: Channel, hashmap: &mut Derivations, key: Derive, range: Range) {
    // Aliased because `Entry` already names the concrete-tree entry in this module.
    use std::collections::hash_map::Entry as MapEntry;

    // A single `entry` lookup replaces a `get` followed by `insert`, so `key`
    // is hashed only once.
    match hashmap.entry(key) {
        MapEntry::Occupied(previous) => {
            let err = Box::new(PassDiagnostic::DuplicatedAttributeArgument(
                *previous.get(),
                range,
            ));
            channel.send(err).unwrap();
        }
        MapEntry::Vacant(slot) => {
            slot.insert(range);
        }
    }
}
/// Parses the name of a `#[derive(...)]` argument.
///
/// Returns `None` for any name that is not a known derivation.
fn string_to_derive(name: &str) -> Option<Derive> {
    let derive = match name {
        "match" => Derive::Match,
        "getters" => Derive::Getters,
        "setters" => Derive::Setters,
        "mutters" => Derive::Mutters,
        _ => return None,
    };
    Some(derive)
}
/// Collects the arguments of every `derive` attribute in `attrs`.
///
/// Attributes with any other name are ignored. Diagnostics sent through
/// `error_channel`:
/// - A `derive` attribute that carries a value is reported as
///   `PassDiagnostic::AttributeDoesNotExpectEqual`.
/// - Any argument that is not an identifier naming a known derivation (see
///   `string_to_derive`) is reported as `PassDiagnostic::InvalidAttributeArgument`.
/// - Repeated derivations are reported by `insert_or_report`. Unlike the two
///   cases above, they do not make this function return `None`.
///
/// Returns `None` if a value or an invalid argument was reported. Otherwise it
/// returns the map from each requested derivation to the range where it was
/// first named.
pub fn expand_derive(error_channel: Channel, attrs: &[Attribute]) -> Option<Derivations> {
    use kind_tree::concrete::AttributeStyle::*;

    let mut failed = false;
    let mut defs = FxHashMap::default();

    for attr in attrs {
        if attr.name.to_str() != "derive" {
            continue;
        }

        // `derive` only takes an argument list, never an `= value`.
        if let Some(value) = &attr.value {
            let err = Box::new(PassDiagnostic::AttributeDoesNotExpectEqual(value.locate()));
            error_channel.send(err).unwrap();
            failed = true;
        }

        for arg in &attr.args {
            // Parse the identifier once, instead of testing it in a match guard
            // and then re-parsing it with `unwrap`.
            let parsed = match arg {
                Ident(range, ident) => string_to_derive(ident.to_str()).map(|key| (key, *range)),
                _ => None,
            };

            match parsed {
                Some((key, range)) => {
                    insert_or_report(error_channel.clone(), &mut defs, key, range);
                }
                None => {
                    let err = Box::new(PassDiagnostic::InvalidAttributeArgument(arg.locate()));
                    error_channel.send(err).unwrap();
                    failed = true;
                }
            }
        }
    }

    if failed {
        None
    } else {
        Some(defs)
    }
}
/// Generates the derived definitions requested for the sum type `sum`.
///
/// Only `Derive::Match` is supported for sum types. The entry produced by
/// `derive_match` is inserted into `entries` under its name, and every
/// diagnostic `derive_match` returns is forwarded to `error_channel`. Any
/// other requested derivation is reported as `PassDiagnostic::CannotDerive`
/// at the range where it was requested.
///
/// Returns `true` if any diagnostic was sent.
pub fn expand_sum_type(
    error_channel: Channel,
    entries: &mut FxHashMap<String, (Entry, EntryMeta)>,
    sum: &SumTypeDecl,
    derivations: Derivations,
) -> bool {
    let mut failed = false;
    for (derive, requested_at) in derivations {
        match derive {
            Derive::Match => {
                let (res, errs) = derive_match(sum.name.range, sum);
                let info = res.extract_book_info();
                entries.insert(res.name.to_string(), (res, info));
                for err in errs {
                    error_channel.send(err).unwrap();
                    failed = true;
                }
            }
            // Listed explicitly (no catch-all) so that adding a `Derive` variant
            // forces a decision about whether sum types support it.
            Derive::Getters | Derive::Setters | Derive::Mutters => {
                let err = Box::new(PassDiagnostic::CannotDerive(
                    derive.to_string(),
                    requested_at,
                ));
                error_channel.send(err).unwrap();
                failed = true;
            }
        }
    }
    failed
}
/// Generates every derived definition requested for the record `rec`.
///
/// All four derivations are supported for records, so no diagnostic is ever
/// sent and `_error_channel` is unused. Each generated entry is stored in
/// `entries` under its own name, together with its `extract_book_info()`
/// metadata. It replaces any entry already stored under that name.
pub fn expand_record_type(
    _error_channel: Channel,
    entries: &mut FxHashMap<String, (Entry, EntryMeta)>,
    rec: &RecordDecl,
    derivations: Derivations,
) {
    /// Stores each generated entry under its name, paired with its metadata.
    fn register(
        table: &mut FxHashMap<String, (Entry, EntryMeta)>,
        generated: impl IntoIterator<Item = Entry>,
    ) {
        for entry in generated {
            let meta = entry.extract_book_info();
            table.insert(entry.name.to_string(), (entry, meta));
        }
    }

    let origin = rec.name.range;
    for (derive, _) in derivations {
        match derive {
            Derive::Match => register(entries, std::iter::once(derive_match_rec(origin, rec))),
            Derive::Getters => register(entries, derive_getters(origin, rec)),
            Derive::Setters => register(entries, derive_setters(origin, rec)),
            Derive::Mutters => register(entries, derive_mutters(origin, rec)),
        }
    }
}
/// Expands the `derive` attributes of every sum and record type in `module`.
///
/// Definitions generated for all types are first gathered into a map keyed by
/// name, so a later definition with the same name replaces an earlier one.
/// They are then appended after the module's existing entries as
/// `TopLevel::Entry` items. Their `EntryMeta` is discarded here.
///
/// Returns `true` if attribute validation or sum-type expansion failed.
pub fn expand_module(error_channel: Channel, module: &mut Module) -> bool {
    let mut generated = FxHashMap::default();
    let mut failed = false;

    for top_level in &module.entries {
        match top_level {
            TopLevel::SumType(sum) => match expand_derive(error_channel.clone(), &sum.attrs) {
                Some(derivations) => {
                    failed |=
                        expand_sum_type(error_channel.clone(), &mut generated, sum, derivations);
                }
                None => failed = true,
            },
            TopLevel::RecordType(rec) => match expand_derive(error_channel.clone(), &rec.attrs) {
                Some(derivations) => {
                    expand_record_type(error_channel.clone(), &mut generated, rec, derivations);
                }
                None => failed = true,
            },
            TopLevel::Entry(_) => {}
        }
    }

    module
        .entries
        .extend(generated.into_iter().map(|(_, (entry, _))| TopLevel::Entry(entry)));

    failed
}