use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use crate::analysis::registry::{FieldDef, FieldType, TypeDef, TypeRegistry, VariantDef};
use crate::ast::stmt::{Expr, Stmt, TypeExpr};
use crate::intern::{Interner, Symbol};
/// Lowers a source-level [`TypeExpr`] into the Rust type text emitted by codegen.
///
/// * Primitive and named types go through [`map_type_to_rust`].
/// * Built-in generic names are mapped to their Rust counterparts:
///   - `Seq`/`List`/`Vec` become `Vec`.
///   - `Map`/`HashMap` become `FxHashMap`.
///   - `Set`/`HashSet` become `FxHashSet`.
///   - `Maybe`/`Option` become `Option`.
///   - `Result` keeps its name.
/// * Missing type arguments are filled with placeholders (`()` or `String`).
/// * Any other generic base is emitted as `Base<args...>`, or just `Base` if it has no arguments.
/// * Function types become `impl Fn(inputs) -> output`.
/// * Refinement predicates are dropped; only the refined base type is emitted.
/// * Persistent types are wrapped in `logicaffeine_system::storage::Persistent<...>`.
pub(super) fn codegen_type_expr(ty: &TypeExpr, interner: &Interner) -> String {
    match ty {
        TypeExpr::Primitive(sym) | TypeExpr::Named(sym) => map_type_to_rust(interner.resolve(*sym)),
        TypeExpr::Generic { base, params } => {
            let args: Vec<String> = params
                .iter()
                .map(|param| codegen_type_expr(param, interner))
                .collect();
            match interner.resolve(*base) {
                "Result" => match args.as_slice() {
                    [ok, err] => format!("Result<{}, {}>", ok, err),
                    [ok] => format!("Result<{}, String>", ok),
                    // Zero or three-plus arguments collapse to the unit/String default.
                    _ => "Result<(), String>".to_string(),
                },
                "Option" | "Maybe" => match args.first() {
                    Some(inner) => format!("Option<{}>", inner),
                    None => "Option<()>".to_string(),
                },
                "Seq" | "List" | "Vec" => match args.first() {
                    Some(elem) => format!("Vec<{}>", elem),
                    None => "Vec<()>".to_string(),
                },
                "Map" | "HashMap" => match args.as_slice() {
                    [key, value, ..] => format!("FxHashMap<{}, {}>", key, value),
                    _ => "FxHashMap<String, String>".to_string(),
                },
                "Set" | "HashSet" => match args.first() {
                    Some(elem) => format!("FxHashSet<{}>", elem),
                    None => "FxHashSet<()>".to_string(),
                },
                other if args.is_empty() => other.to_string(),
                other => format!("{}<{}>", other, args.join(", ")),
            }
        }
        TypeExpr::Function { inputs, output } => {
            let arg_list = inputs
                .iter()
                .map(|input| codegen_type_expr(input, interner))
                .collect::<Vec<String>>()
                .join(", ");
            let ret = codegen_type_expr(output, interner);
            format!("impl Fn({}) -> {}", arg_list, ret)
        }
        TypeExpr::Refinement { base, .. } => codegen_type_expr(base, interner),
        TypeExpr::Persistent { inner } => format!(
            "logicaffeine_system::storage::Persistent<{}>",
            codegen_type_expr(inner, interner)
        ),
    }
}
/// Guesses a function's Rust return type from its body.
///
/// Only the top-level statements of `body` are inspected; nested blocks are not searched.
/// If any of them is a `Return` carrying a value, the return type is assumed to be `i64`.
/// Otherwise `None` is returned. The returned expression's actual type is not examined.
pub(super) fn infer_return_type_from_body(body: &[Stmt], _interner: &Interner) -> Option<String> {
    let returns_value = body
        .iter()
        .any(|stmt| matches!(stmt, Stmt::Return { value: Some(_) }));
    if returns_value {
        Some("i64".to_string())
    } else {
        None
    }
}
/// Translates a built-in source type name into its Rust spelling.
///
/// Recognised names are mapped as follows:
/// * `Int` to `i64` and `Nat` to `u64`.
/// * `Text` to `String`.
/// * `Bool`/`Boolean` to `bool`.
/// * `Real`/`Float` to `f64`.
/// * `Char` to `char` and `Byte` to `u8`.
/// * `Unit`/`()` to `()`.
/// * `Duration` to `std::time::Duration`.
///
/// Any other name, such as a user-defined type, is returned unchanged.
pub(super) fn map_type_to_rust(ty: &str) -> String {
    let rust_name = match ty {
        "Int" => "i64",
        "Nat" => "u64",
        "Text" => "String",
        "Bool" | "Boolean" => "bool",
        "Real" | "Float" => "f64",
        "Char" => "char",
        "Byte" => "u8",
        "Unit" | "()" => "()",
        "Duration" => "std::time::Duration",
        passthrough => passthrough,
    };
    rust_name.to_owned()
}
/// Emits a Rust `struct` definition for a user-declared record type.
///
/// * `#[repr(C)]` is added when `name` is in `c_abi_value_structs`.
/// * `Default, Debug, Clone, PartialEq` are always derived.
/// * serde `Serialize`/`Deserialize` derives are also added when the type is portable, shared,
///   or listed in `c_abi_ref_structs`.
/// * Public fields get a `pub` prefix. Field types are lowered with [`codegen_field_type`].
/// * Shared types are followed by a CRDT `Merge` impl from [`codegen_merge_impl`].
///
/// Every emitted line starts with `indent` repetitions of the indentation string.
pub(super) fn codegen_struct_def(
    name: Symbol,
    fields: &[FieldDef],
    generics: &[Symbol],
    is_portable: bool,
    is_shared: bool,
    interner: &Interner,
    indent: usize,
    c_abi_value_structs: &HashSet<Symbol>,
    c_abi_ref_structs: &HashSet<Symbol>,
) -> String {
    let pad = " ".repeat(indent);
    let type_params = match generics {
        [] => String::new(),
        _ => format!(
            "<{}>",
            generics
                .iter()
                .map(|g| interner.resolve(*g))
                .collect::<Vec<&str>>()
                .join(", ")
        ),
    };
    let wants_serde = is_portable || is_shared || c_abi_ref_structs.contains(&name);

    let mut code = String::new();
    if c_abi_value_structs.contains(&name) {
        writeln!(code, "{}#[repr(C)]", pad).unwrap();
    }
    if wants_serde {
        writeln!(code, "{}#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]", pad).unwrap();
    } else {
        writeln!(code, "{}#[derive(Default, Debug, Clone, PartialEq)]", pad).unwrap();
    }

    writeln!(code, "{}pub struct {}{} {{", pad, interner.resolve(name), type_params).unwrap();
    for field in fields {
        let field_ty = codegen_field_type(&field.ty, interner);
        let visibility = if field.is_public { "pub " } else { "" };
        writeln!(code, "{} {}{}: {},", pad, visibility, interner.resolve(field.name), field_ty).unwrap();
    }
    writeln!(code, "{}}}\n", pad).unwrap();

    if is_shared {
        let merge_impl = codegen_merge_impl(name, fields, generics, interner, indent);
        code.push_str(&merge_impl);
    }
    code
}
/// Emits an `impl logicaffeine_data::crdt::Merge` block for a shared struct.
///
/// The generated `merge(&mut self, other: &Self)` calls `.merge(&other.<field>)` on each field
/// whose type [`is_crdt_field_type`] recognises. Other fields get no statement in the body.
/// Generic parameters are repeated on both `impl<...>` and the type, with no added bounds.
pub(super) fn codegen_merge_impl(name: Symbol, fields: &[FieldDef], generics: &[Symbol], interner: &Interner, indent: usize) -> String {
    let pad = " ".repeat(indent);
    let type_name = interner.resolve(name);
    let type_params = match generics {
        [] => String::new(),
        _ => format!(
            "<{}>",
            generics
                .iter()
                .map(|g| interner.resolve(*g))
                .collect::<Vec<&str>>()
                .join(", ")
        ),
    };

    let mut code = String::new();
    writeln!(code, "{}impl{} logicaffeine_data::crdt::Merge for {}{} {{", pad, type_params, type_name, type_params).unwrap();
    writeln!(code, "{} fn merge(&mut self, other: &Self) {{", pad).unwrap();
    let mergeable = fields.iter().filter(|f| is_crdt_field_type(&f.ty, interner));
    for field in mergeable {
        let member = interner.resolve(field.name);
        writeln!(code, "{} self.{}.merge(&other.{});", pad, member, member).unwrap();
    }
    writeln!(code, "{} }}", pad).unwrap();
    writeln!(code, "{}}}\n", pad).unwrap();
    code
}
/// Reports whether a field's declared type is one of the known CRDT types.
///
/// * Counter CRDTs are matched as non-generic named types.
/// * Register, set, sequence and map CRDTs are matched by the base name of a generic type.
/// * Primitives and type parameters are never CRDTs.
pub(super) fn is_crdt_field_type(ty: &FieldType, interner: &Interner) -> bool {
    const COUNTER_NAMES: [&str; 4] = ["ConvergentCount", "GCounter", "Tally", "PNCounter"];
    const GENERIC_CRDT_NAMES: [&str; 14] = [
        "LastWriteWins", "LWWRegister",
        "SharedSet", "ORSet", "SharedSet_AddWins", "SharedSet_RemoveWins",
        "SharedSequence", "RGA", "SharedSequence_YATA", "CollaborativeSequence",
        "SharedMap", "ORMap",
        "Divergent", "MVRegister",
    ];
    match ty {
        FieldType::Named(sym) => COUNTER_NAMES.contains(&interner.resolve(*sym)),
        FieldType::Generic { base, .. } => GENERIC_CRDT_NAMES.contains(&interner.resolve(*base)),
        _ => false,
    }
}
pub(super) fn codegen_enum_def(name: Symbol, variants: &[VariantDef], generics: &[Symbol], is_portable: bool, _is_shared: bool, interner: &Interner, indent: usize) -> String {
let ind = " ".repeat(indent);
let mut output = String::new();
let generic_str = if generics.is_empty() {
String::new()
} else {
let params: Vec<&str> = generics.iter()
.map(|g| interner.resolve(*g))
.collect();
format!("<{}>", params.join(", "))
};
if is_portable {
writeln!(output, "{}#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]", ind).unwrap();
} else {
writeln!(output, "{}#[derive(Debug, Clone, PartialEq)]", ind).unwrap();
}
writeln!(output, "{}pub enum {}{} {{", ind, interner.resolve(name), generic_str).unwrap();
for variant in variants {
let variant_name = interner.resolve(variant.name);
if variant.fields.is_empty() {
writeln!(output, "{} {},", ind, variant_name).unwrap();
} else {
let enum_name_str = interner.resolve(name);
let fields_str: Vec<String> = variant.fields.iter()
.map(|f| {
let rust_type = codegen_field_type(&f.ty, interner);
let field_name = interner.resolve(f.name);
if is_recursive_field(&f.ty, enum_name_str, interner) {
format!("{}: Box<{}>", field_name, rust_type)
} else {
format!("{}: {}", field_name, rust_type)
}
})
.collect();
writeln!(output, "{} {} {{ {} }},", ind, variant_name, fields_str.join(", ")).unwrap();
}
}
writeln!(output, "{}}}\n", ind).unwrap();
if generics.is_empty() {
if let Some(first_variant) = variants.first() {
let enum_name_str = interner.resolve(name);
let first_variant_name = interner.resolve(first_variant.name);
writeln!(output, "{}impl{} Default for {}{} {{", ind, generic_str, enum_name_str, generic_str).unwrap();
writeln!(output, "{} fn default() -> Self {{", ind).unwrap();
if first_variant.fields.is_empty() {
writeln!(output, "{} {}::{}", ind, enum_name_str, first_variant_name).unwrap();
} else {
let default_fields: Vec<String> = first_variant.fields.iter()
.map(|f| {
let field_name = interner.resolve(f.name);
let enum_name_check = interner.resolve(name);
if is_recursive_field(&f.ty, enum_name_check, interner) {
format!("{}: Box::new(Default::default())", field_name)
} else {
format!("{}: Default::default()", field_name)
}
})
.collect();
writeln!(output, "{} {}::{} {{ {} }}", ind, enum_name_str, first_variant_name, default_fields.join(", ")).unwrap();
}
writeln!(output, "{} }}", ind).unwrap();
writeln!(output, "{}}}\n", ind).unwrap();
}
}
output
}
pub(super) fn codegen_field_type(ty: &FieldType, interner: &Interner) -> String {
match ty {
FieldType::Primitive(sym) => {
match interner.resolve(*sym) {
"Int" => "i64".to_string(),
"Nat" => "u64".to_string(),
"Text" => "String".to_string(),
"Bool" | "Boolean" => "bool".to_string(),
"Real" | "Float" => "f64".to_string(),
"Char" => "char".to_string(),
"Byte" => "u8".to_string(),
"Unit" => "()".to_string(),
"Duration" => "std::time::Duration".to_string(),
other => other.to_string(),
}
}
FieldType::Named(sym) => {
let name = interner.resolve(*sym);
match name {
"ConvergentCount" => "logicaffeine_data::crdt::GCounter".to_string(),
"Tally" => "logicaffeine_data::crdt::PNCounter".to_string(),
_ => name.to_string(),
}
}
FieldType::Generic { base, params } => {
let base_name = interner.resolve(*base);
let param_strs: Vec<String> = params.iter()
.map(|p| codegen_field_type(p, interner))
.collect();
match base_name {
"SharedSet_RemoveWins" => {
return format!("logicaffeine_data::crdt::ORSet<{}, logicaffeine_data::crdt::RemoveWins>", param_strs.join(", "));
}
"SharedSet_AddWins" => {
return format!("logicaffeine_data::crdt::ORSet<{}, logicaffeine_data::crdt::AddWins>", param_strs.join(", "));
}
"SharedSequence_YATA" | "CollaborativeSequence" => {
return format!("logicaffeine_data::crdt::YATA<{}>", param_strs.join(", "));
}
_ => {}
}
let base_str = match base_name {
"List" | "Seq" => "Vec",
"Set" => "FxHashSet",
"Map" => "FxHashMap",
"Option" | "Maybe" => "Option",
"Result" => "Result",
"LastWriteWins" => "logicaffeine_data::crdt::LWWRegister",
"SharedSet" | "ORSet" => "logicaffeine_data::crdt::ORSet",
"SharedSequence" | "RGA" => "logicaffeine_data::crdt::RGA",
"SharedMap" | "ORMap" => "logicaffeine_data::crdt::ORMap",
"Divergent" | "MVRegister" => "logicaffeine_data::crdt::MVRegister",
other => other,
};
format!("{}<{}>", base_str, param_strs.join(", "))
}
FieldType::TypeParam(sym) => interner.resolve(*sym).to_string(),
}
}
pub(crate) fn is_recursive_field(ty: &FieldType, enum_name: &str, interner: &Interner) -> bool {
match ty {
FieldType::Primitive(sym) => interner.resolve(*sym) == enum_name,
FieldType::Named(sym) => interner.resolve(*sym) == enum_name,
FieldType::TypeParam(_) => false,
FieldType::Generic { base, params } => {
interner.resolve(*base) == enum_name ||
params.iter().any(|p| is_recursive_field(p, enum_name, interner))
}
}
}
/// Builds an explicit `Enum<T1, T2, ...>` annotation for a variant-construction expression.
///
/// Returns `None` in any of these cases:
/// * `expr` is not `Expr::NewVariant`.
/// * The enum is not registered, or is not an enum.
/// * The enum has fewer than two type parameters.
/// * The variant is unknown.
///
/// Each type parameter that a supplied field is declared with gets the type that
/// [`infer_rust_type_from_expr`] infers from that field's value. Parameters that no supplied
/// field binds are filled with `()`.
pub(super) fn infer_variant_type_annotation(
    expr: &Expr,
    registry: &TypeRegistry,
    interner: &Interner,
) -> Option<String> {
    let (enum_sym, variant_sym, supplied) = if let Expr::NewVariant { enum_name, variant, fields } = expr {
        (*enum_name, *variant, fields)
    } else {
        return None;
    };

    let (type_params, variant_defs) = match registry.get(enum_sym)? {
        TypeDef::Enum { generics, variants, .. } => (generics, variants),
        _ => return None,
    };
    if type_params.len() < 2 {
        return None;
    }
    let chosen = variant_defs.iter().find(|v| v.name == variant_sym)?;

    // Map each type parameter to the Rust type inferred from the field value that binds it.
    let mut bindings: HashMap<Symbol, String> = HashMap::new();
    for (field_name, field_value) in supplied {
        let declared = chosen.fields.iter().find(|f| f.name == *field_name);
        if let Some(def) = declared {
            if let FieldType::TypeParam(param) = &def.ty {
                bindings.insert(*param, infer_rust_type_from_expr(field_value, interner));
            }
        }
    }

    let rendered: Vec<String> = type_params
        .iter()
        .map(|g| match bindings.get(g) {
            Some(found) => found.clone(),
            None => "()".to_string(),
        })
        .collect();
    Some(format!("{}<{}>", interner.resolve(enum_sym), rendered.join(", ")))
}
/// Infers the Rust type of an expression, but only for literals.
///
/// A literal's type is taken from `LogosType::from_literal` and rendered with `to_rust_type`.
/// Any other expression yields `"_"`, leaving the type to Rust's inference.
pub(super) fn infer_rust_type_from_expr(expr: &Expr, _interner: &Interner) -> String {
    if let Expr::Literal(lit) = expr {
        crate::analysis::types::LogosType::from_literal(lit).to_rust_type()
    } else {
        "_".to_string()
    }
}
/// Classifies an expression as integer or floating-point for numeric codegen.
///
/// A fresh `TypeEnv` is seeded from `variable_types`. Each Rust type string is converted with
/// `LogosType::from_rust_type_str`. The expression is then inferred in that environment.
/// Returns `"i64"` for `Int`, `"f64"` for `Float`, and `"unknown"` for anything else.
pub(super) fn infer_numeric_type(
    expr: &Expr,
    interner: &Interner,
    variable_types: &HashMap<Symbol, String>,
) -> &'static str {
    use crate::analysis::types::{LogosType, TypeEnv};

    let mut env = TypeEnv::new();
    variable_types.iter().for_each(|(sym, rust_ty)| {
        env.register(*sym, LogosType::from_rust_type_str(rust_ty));
    });

    match env.infer_expr(expr, interner) {
        LogosType::Int => "i64",
        LogosType::Float => "f64",
        _ => "unknown",
    }
}