use std::collections::HashSet;
use std::fmt::Write;
use crate::analysis::registry::TypeRegistry;
use crate::analysis::types::RustNames;
use crate::ast::stmt::{Stmt, TypeExpr};
use crate::intern::{Interner, Symbol};
use super::context::{RefinementContext, analyze_variable_capabilities, replace_word};
use super::detection::{is_result_type, collect_mutable_vars};
use super::types::codegen_type_expr;
use super::{
CAbiClass, classify_type_for_c_abi,
codegen_assertion, codegen_stmt,
try_emit_vec_fill_pattern, try_emit_for_range_pattern, try_emit_swap_pattern,
};
pub(super) fn is_text_type(ty: &TypeExpr, interner: &Interner) -> bool {
match ty {
TypeExpr::Primitive(sym) | TypeExpr::Named(sym) => {
matches!(interner.resolve(*sym), "Text" | "String")
}
TypeExpr::Refinement { base, .. } => is_text_type(base, interner),
_ => false,
}
}
pub(super) fn map_type_to_c_abi(ty: &TypeExpr, interner: &Interner, is_return: bool) -> String {
if is_text_type(ty, interner) {
if is_return {
"*mut std::os::raw::c_char".to_string()
} else {
"*const std::os::raw::c_char".to_string()
}
} else {
codegen_type_expr(ty, interner)
}
}
pub(super) fn codegen_c_export_with_marshaling(
name: Symbol,
params: &[(Symbol, &TypeExpr)],
body: &[Stmt],
return_type: Option<&TypeExpr>,
interner: &Interner,
lww_fields: &HashSet<(String, String)>,
mv_fields: &HashSet<(String, String)>,
async_functions: &HashSet<Symbol>,
boxed_fields: &HashSet<(String, String, String)>,
registry: &crate::analysis::registry::TypeRegistry,
type_env: &crate::analysis::types::TypeEnv,
) -> String {
let mut output = String::new();
let names = RustNames::new(interner);
let raw_name = names.raw(name);
let func_name = format!("logos_{}", raw_name);
let inner_name = names.ident(name);
let has_ref_return = return_type.map_or(false, |ty| {
classify_type_for_c_abi(ty, interner, registry) == CAbiClass::ReferenceType
});
let has_result_return = return_type.map_or(false, |ty| is_result_type(ty, interner));
let has_text_return = return_type.map_or(false, |t| is_text_type(t, interner));
let uses_status_code = has_ref_return || has_result_return || has_text_return
|| params.iter().any(|(_, ty)| matches!(ty, TypeExpr::Refinement { .. }));
let inner_params: Vec<String> = params.iter()
.map(|(pname, ptype)| {
format!("{}: {}", interner.resolve(*pname), codegen_type_expr(ptype, interner))
})
.collect();
let inner_ret = return_type.map(|t| codegen_type_expr(t, interner));
let inner_sig = if let Some(ref ret) = inner_ret {
if ret != "()" {
format!("fn {}({}) -> {}", inner_name, inner_params.join(", "), ret)
} else {
format!("fn {}({})", inner_name, inner_params.join(", "))
}
} else {
format!("fn {}({})", inner_name, inner_params.join(", "))
};
writeln!(output, "{} {{", inner_sig).unwrap();
let func_mutable_vars = collect_mutable_vars(body);
let mut func_ctx = RefinementContext::new();
let mut func_synced_vars = HashSet::new();
let func_var_caps = analyze_variable_capabilities(body, interner);
for (param_name, param_type) in params {
let type_name = codegen_type_expr(param_type, interner);
func_ctx.register_variable_type(*param_name, type_name);
}
let func_pipe_vars = HashSet::new();
{
let stmt_refs: Vec<&Stmt> = body.iter().collect();
let mut si = 0;
while si < stmt_refs.len() {
if let Some((code, skip)) = try_emit_vec_fill_pattern(&stmt_refs, si, interner, 1, &mut func_ctx) {
output.push_str(&code);
si += 1 + skip;
continue;
}
if let Some((code, skip)) = try_emit_for_range_pattern(&stmt_refs, si, interner, 1, &func_mutable_vars, &mut func_ctx, lww_fields, mv_fields, &mut func_synced_vars, &func_var_caps, async_functions, &func_pipe_vars, boxed_fields, registry, type_env) {
output.push_str(&code);
si += 1 + skip;
continue;
}
if let Some((code, skip)) = try_emit_swap_pattern(&stmt_refs, si, interner, 1, func_ctx.get_variable_types()) {
output.push_str(&code);
si += 1 + skip;
continue;
}
output.push_str(&codegen_stmt(stmt_refs[si], interner, 1, &func_mutable_vars, &mut func_ctx, lww_fields, mv_fields, &mut func_synced_vars, &func_var_caps, async_functions, &func_pipe_vars, boxed_fields, registry, type_env));
si += 1;
}
}
writeln!(output, "}}\n").unwrap();
let mut c_params: Vec<String> = Vec::new();
for (pname, ptype) in params.iter() {
let pn = names.ident(*pname);
if classify_type_for_c_abi(ptype, interner, registry) == CAbiClass::ReferenceType {
c_params.push(format!("{}: LogosHandle", pn));
} else if is_text_type(ptype, interner) {
c_params.push(format!("{}: *const std::os::raw::c_char", pn));
} else {
c_params.push(format!("{}: {}", pn, codegen_type_expr(ptype, interner)));
}
}
if uses_status_code {
if let Some(ret_ty) = return_type {
if has_result_return {
if let TypeExpr::Generic { params: ref rparams, .. } = ret_ty {
if !rparams.is_empty() {
let ok_ty = &rparams[0];
if classify_type_for_c_abi(ok_ty, interner, registry) == CAbiClass::ReferenceType {
c_params.push("out: *mut LogosHandle".to_string());
} else if is_text_type(ok_ty, interner) {
c_params.push("out: *mut *mut std::os::raw::c_char".to_string());
} else {
let ty_str = codegen_type_expr(ok_ty, interner);
c_params.push(format!("out: *mut {}", ty_str));
}
}
}
} else if has_ref_return {
c_params.push("out: *mut LogosHandle".to_string());
} else if has_text_return {
c_params.push("out: *mut *mut std::os::raw::c_char".to_string());
}
}
}
let c_sig = if uses_status_code {
format!("pub extern \"C\" fn {}({}) -> LogosStatus", func_name, c_params.join(", "))
} else if has_text_return {
format!("pub extern \"C\" fn {}({}) -> *mut std::os::raw::c_char", func_name, c_params.join(", "))
} else if let Some(ret_ty) = return_type {
let ret_str = codegen_type_expr(ret_ty, interner);
if ret_str != "()" {
format!("pub extern \"C\" fn {}({}) -> {}", func_name, c_params.join(", "), ret_str)
} else {
format!("pub extern \"C\" fn {}({})", func_name, c_params.join(", "))
}
} else {
format!("pub extern \"C\" fn {}({})", func_name, c_params.join(", "))
};
writeln!(output, "#[no_mangle]").unwrap();
writeln!(output, "{} {{", c_sig).unwrap();
let call_args: Vec<String> = params.iter()
.map(|(pname, ptype)| {
let pname_str = names.ident(*pname);
if classify_type_for_c_abi(ptype, interner, registry) == CAbiClass::ReferenceType {
let rust_ty = codegen_type_expr(ptype, interner);
writeln!(output, " let {pn} = {{", pn = pname_str).unwrap();
writeln!(output, " let __id = {pn} as u64;", pn = pname_str).unwrap();
writeln!(output, " let __reg = logos_handle_registry().lock().unwrap_or_else(|e| e.into_inner());").unwrap();
writeln!(output, " let __ptr = __reg.deref(__id).expect(\"InvalidHandle: handle not found in registry\");").unwrap();
writeln!(output, " drop(__reg);").unwrap();
writeln!(output, " unsafe {{ &*(__ptr as *const {ty}) }}.clone()", ty = rust_ty).unwrap();
writeln!(output, " }};").unwrap();
} else if is_text_type(ptype, interner) {
if uses_status_code {
writeln!(output, " if {pn}.is_null() {{ logos_set_last_error(\"NullPointer: text parameter '{pn}' is null\".to_string()); return LogosStatus::NullPointer; }}",
pn = pname_str).unwrap();
writeln!(output, " let {pn} = unsafe {{ std::ffi::CStr::from_ptr({pn}).to_string_lossy().into_owned() }};",
pn = pname_str).unwrap();
} else {
writeln!(output, " let {pn} = if {pn}.is_null() {{ String::new() }} else {{ unsafe {{ std::ffi::CStr::from_ptr({pn}).to_string_lossy().into_owned() }} }};",
pn = pname_str).unwrap();
}
}
pname_str.to_string()
})
.collect();
for (pname, ptype) in params.iter() {
if let TypeExpr::Refinement { base: _, var, predicate } = ptype {
let pname_str = interner.resolve(*pname);
let bound = interner.resolve(*var);
let assertion = codegen_assertion(predicate, interner);
let check = if bound == pname_str {
assertion
} else {
replace_word(&assertion, bound, pname_str)
};
writeln!(output, " if !({}) {{", check).unwrap();
writeln!(output, " logos_set_last_error(format!(\"Refinement violation: expected {check}, got {pn} = {{}}\", {pn}));",
check = check, pn = pname_str).unwrap();
writeln!(output, " return LogosStatus::RefinementViolation;").unwrap();
writeln!(output, " }}").unwrap();
}
}
if uses_status_code && (has_ref_return || has_text_return || has_result_return) {
writeln!(output, " if out.is_null() {{ logos_set_last_error(\"NullPointer: output parameter is null\".to_string()); return LogosStatus::NullPointer; }}").unwrap();
}
let panic_default = if uses_status_code {
"LogosStatus::ThreadPanic"
} else if has_text_return {
"std::ptr::null_mut()"
} else if return_type.map_or(false, |t| codegen_type_expr(t, interner) != "()") {
"Default::default()"
} else {
"" };
writeln!(output, " match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {{").unwrap();
if uses_status_code {
if has_result_return {
writeln!(output, " match {}({}) {{", inner_name, call_args.join(", ")).unwrap();
writeln!(output, " Ok(val) => {{").unwrap();
if let Some(TypeExpr::Generic { params: ref rparams, .. }) = return_type {
if !rparams.is_empty() {
let ok_ty = &rparams[0];
if classify_type_for_c_abi(ok_ty, interner, registry) == CAbiClass::ReferenceType {
writeln!(output, " let __ptr = Box::into_raw(Box::new(val)) as usize;").unwrap();
writeln!(output, " let mut __reg = logos_handle_registry().lock().unwrap_or_else(|e| e.into_inner());").unwrap();
writeln!(output, " let (__id, _) = __reg.register(__ptr);").unwrap();
writeln!(output, " unsafe {{ *out = __id as LogosHandle; }}").unwrap();
} else if is_text_type(ok_ty, interner) {
writeln!(output, " match std::ffi::CString::new(val) {{").unwrap();
writeln!(output, " Ok(cstr) => unsafe {{ *out = cstr.into_raw(); }},").unwrap();
writeln!(output, " Err(_) => {{").unwrap();
writeln!(output, " logos_set_last_error(\"Return value contains null byte\".to_string());").unwrap();
writeln!(output, " return LogosStatus::ContainsNullByte;").unwrap();
writeln!(output, " }}").unwrap();
writeln!(output, " }}").unwrap();
} else {
writeln!(output, " unsafe {{ *out = val; }}").unwrap();
}
}
}
writeln!(output, " LogosStatus::Ok").unwrap();
writeln!(output, " }}").unwrap();
writeln!(output, " Err(e) => {{").unwrap();
writeln!(output, " logos_set_last_error(format!(\"{{}}\", e));").unwrap();
writeln!(output, " LogosStatus::Error").unwrap();
writeln!(output, " }}").unwrap();
writeln!(output, " }}").unwrap();
} else if has_ref_return {
writeln!(output, " let result = {}({});", inner_name, call_args.join(", ")).unwrap();
writeln!(output, " let __ptr = Box::into_raw(Box::new(result)) as usize;").unwrap();
writeln!(output, " let mut __reg = logos_handle_registry().lock().unwrap_or_else(|e| e.into_inner());").unwrap();
writeln!(output, " let (__id, _) = __reg.register(__ptr);").unwrap();
writeln!(output, " unsafe {{ *out = __id as LogosHandle; }}").unwrap();
writeln!(output, " LogosStatus::Ok").unwrap();
} else if has_text_return {
writeln!(output, " let result = {}({});", inner_name, call_args.join(", ")).unwrap();
writeln!(output, " match std::ffi::CString::new(result) {{").unwrap();
writeln!(output, " Ok(cstr) => {{").unwrap();
writeln!(output, " unsafe {{ *out = cstr.into_raw(); }}").unwrap();
writeln!(output, " LogosStatus::Ok").unwrap();
writeln!(output, " }}").unwrap();
writeln!(output, " Err(_) => {{").unwrap();
writeln!(output, " logos_set_last_error(\"Return value contains null byte\".to_string());").unwrap();
writeln!(output, " LogosStatus::ContainsNullByte").unwrap();
writeln!(output, " }}").unwrap();
writeln!(output, " }}").unwrap();
} else {
writeln!(output, " {}({});", inner_name, call_args.join(", ")).unwrap();
writeln!(output, " LogosStatus::Ok").unwrap();
}
} else if has_text_return {
writeln!(output, " let result = {}({});", inner_name, call_args.join(", ")).unwrap();
writeln!(output, " match std::ffi::CString::new(result) {{").unwrap();
writeln!(output, " Ok(cstr) => cstr.into_raw(),").unwrap();
writeln!(output, " Err(_) => {{ logos_set_last_error(\"Return value contains null byte\".to_string()); std::ptr::null_mut() }}").unwrap();
writeln!(output, " }}").unwrap();
} else if return_type.is_some() {
writeln!(output, " {}({})", inner_name, call_args.join(", ")).unwrap();
} else {
writeln!(output, " {}({})", inner_name, call_args.join(", ")).unwrap();
}
writeln!(output, " }})) {{").unwrap();
writeln!(output, " Ok(__v) => __v,").unwrap();
writeln!(output, " Err(__panic) => {{").unwrap();
writeln!(output, " let __msg = if let Some(s) = __panic.downcast_ref::<String>() {{ s.clone() }} else if let Some(s) = __panic.downcast_ref::<&str>() {{ s.to_string() }} else {{ \"Unknown panic\".to_string() }};").unwrap();
writeln!(output, " logos_set_last_error(__msg);").unwrap();
if !panic_default.is_empty() {
writeln!(output, " {}", panic_default).unwrap();
}
writeln!(output, " }}").unwrap();
writeln!(output, " }}").unwrap();
writeln!(output, "}}\n").unwrap();
output
}