use anyhow::Result;
use crate::abi::{
ATSFunction, ATSModule, ATSParam, MemorySafetyProof, OwnershipPattern, Viewtype,
};
use crate::codegen::parser::CFunctionSignature;
use crate::manifest::{CSource, OwnershipRule};
pub fn generate_module(
project_name: &str,
signatures: &[CFunctionSignature],
rules: &[OwnershipRule],
c_sources: &[CSource],
) -> Result<ATSModule> {
let mut module = ATSModule::new(project_name);
for source in c_sources {
module.includes.push(source.path.clone());
}
let mut seen_viewtypes: std::collections::HashSet<String> = std::collections::HashSet::new();
for rule in rules {
let pattern = parse_pattern(&rule.pattern)?;
if let Some(ref res_type) = rule.resource_type {
let vt_name = viewtype_name(res_type);
if seen_viewtypes.insert(vt_name.clone()) {
let vt = match pattern {
OwnershipPattern::Alloc => {
Viewtype::nullable(&vt_name, res_type)
}
_ => Viewtype::new(&vt_name, res_type),
};
module.viewtypes.push(vt);
}
}
let maybe_sig = signatures.iter().find(|s| s.name == rule.function);
let func = generate_wrapper_function(rule, &pattern, maybe_sig)?;
module.functions.push(func);
let proof = generate_proof(&rule.function, &pattern, rule.resource_type.as_deref());
module.proofs.push(proof);
}
Ok(module)
}
/// Render `module` to its pair of ATS2 source texts: `(sats, dats)`.
///
/// The static half comes from `render_sats` and the dynamic half from
/// `render_dats`. Rendering cannot currently fail, so this always returns
/// `Ok`; the `Result` keeps room for fallible rendering later.
pub fn render_module(module: &ATSModule) -> Result<(String, String)> {
    let rendered = (render_sats(module), render_dats(module));
    Ok(rendered)
}
/// Render the `.sats` (static signatures) half of a generated module.
///
/// The output contains, in order:
/// 1. a license/provenance header comment naming the module;
/// 2. one `%{# ... %}` C-code block per entry in `module.includes`, each
///    holding a `#include "<path>"` directive;
/// 3. the viewtype definitions, if there are any;
/// 4. one `extern fun <c_function>_c` declaration per function, with the
///    parameter types in list order and `void` when there is no return type;
/// 5. one `fun <name>(...)` signature per safe wrapper. A parameter that is
///    not consumed is written `name: !type` (ATS2 puts the `!` on the type);
///    a consumed parameter is written `name: type`.
fn render_sats(module: &ATSModule) -> String {
    let mut out = String::new();
    out.push_str(&format!(
        "(*\n** SPDX-License-Identifier: PMPL-1.0-or-later\n\
         ** Generated by atsiser — do not edit manually.\n\
         ** Module: {} (static signatures)\n*)\n\n",
        module.name
    ));

    for inc in &module.includes {
        // `%{#` only opens the embedded C block; the C preprocessor
        // directive, including its own `#`, must be written inside it.
        out.push_str(&format!("%{{#\n#include \"{}\"\n%}}\n", inc));
    }
    if !module.includes.is_empty() {
        out.push('\n');
    }

    if !module.viewtypes.is_empty() {
        out.push_str("(* === Viewtype definitions === *)\n\n");
        for vt in &module.viewtypes {
            out.push_str(&vt.to_ats2_definition());
            out.push_str("\n\n");
        }
    }

    out.push_str("(* === External C function declarations === *)\n\n");
    for func in &module.functions {
        out.push_str(&format!(
            "extern fun {}_c: ({}) -> {}\n",
            func.c_function,
            func.params
                .iter()
                .map(|p| p.ats_type.as_str())
                .collect::<Vec<_>>()
                .join(", "),
            func.return_type.as_deref().unwrap_or("void")
        ));
    }
    out.push('\n');

    out.push_str("(* === Safe wrapper function signatures === *)\n\n");
    for func in &module.functions {
        let params_str: Vec<String> = func
            .params
            .iter()
            .map(|p| {
                // Consumed arguments are passed by value. Types that already
                // carry the `!` marker (e.g. `!ptr(char)`) must not get a
                // second one.
                if p.consumed || p.ats_type.starts_with('!') {
                    format!("{}: {}", p.name, p.ats_type)
                } else {
                    format!("{}: !{}", p.name, p.ats_type)
                }
            })
            .collect();
        let ret = func.return_type.as_deref().unwrap_or("void");
        out.push_str(&format!(
            "fun {}({}): {}\n",
            func.name,
            params_str.join(", "),
            ret
        ));
    }
    out
}
fn render_dats(module: &ATSModule) -> String {
module.to_ats2_source()
}
/// Parse the textual ownership pattern of a manifest rule.
///
/// Accepted spellings are the exact lowercase words `alloc`, `free`,
/// `borrow` and `transfer`.
///
/// # Errors
///
/// Any other string produces an "Unknown ownership pattern" error that
/// quotes the offending text.
fn parse_pattern(pattern_str: &str) -> Result<OwnershipPattern> {
    let pattern = match pattern_str {
        "alloc" => OwnershipPattern::Alloc,
        "free" => OwnershipPattern::Free,
        "borrow" => OwnershipPattern::Borrow,
        "transfer" => OwnershipPattern::Transfer,
        unknown => anyhow::bail!("Unknown ownership pattern: '{}'", unknown),
    };
    Ok(pattern)
}
/// Derive the ATS2 viewtype identifier for a C resource type.
///
/// All trailing pointer stars and surrounding whitespace are removed, so
/// `"FILE*"`, `"FILE *"`, `"FILE* "` and `"FILE **"` all map to `FILE_vt`.
/// Remaining internal whitespace runs become single underscores, so
/// `"struct foo *"` yields the valid identifier `struct_foo_vt` rather than
/// one containing a space.
fn viewtype_name(c_type: &str) -> String {
    // Trim before stripping stars: a trailing space would otherwise shield
    // the `*` from `trim_end_matches`.
    let base = c_type
        .trim()
        .trim_end_matches(|c: char| c == '*' || c.is_whitespace());
    let ident = base.split_whitespace().collect::<Vec<_>>().join("_");
    format!("{}_vt", ident)
}
fn generate_wrapper_function(
rule: &OwnershipRule,
pattern: &OwnershipPattern,
maybe_sig: Option<&CFunctionSignature>,
) -> Result<ATSFunction> {
let safe_name = format!("safe_{}", rule.function);
let resource_vt = rule
.resource_type
.as_deref()
.map(|t| viewtype_name(t))
.unwrap_or_else(|| "ptr".to_string());
let (params, return_type) = match pattern {
OwnershipPattern::Alloc => {
let mut params = Vec::new();
if let Some(sig) = maybe_sig {
for (i, cp) in sig.params.iter().enumerate() {
params.push(ATSParam {
name: if cp.name.is_empty() {
format!("arg{}", i)
} else {
cp.name.clone()
},
ats_type: c_type_to_ats(&cp.c_type),
consumed: false,
});
}
}
(params, Some(resource_vt.clone()))
}
OwnershipPattern::Free => {
let params = vec![ATSParam {
name: "resource".to_string(),
ats_type: resource_vt.clone(),
consumed: true,
}];
(params, None)
}
OwnershipPattern::Borrow => {
let mut params = vec![ATSParam {
name: "resource".to_string(),
ats_type: resource_vt.clone(),
consumed: false,
}];
if let Some(sig) = maybe_sig {
for (i, cp) in sig.params.iter().enumerate() {
if i == rule.param_index.unwrap_or(0) {
continue; }
params.push(ATSParam {
name: if cp.name.is_empty() {
format!("arg{}", i)
} else {
cp.name.clone()
},
ats_type: c_type_to_ats(&cp.c_type),
consumed: false,
});
}
}
let ret = maybe_sig
.map(|s| c_type_to_ats(&s.return_type))
.unwrap_or_else(|| "int".to_string());
(params, Some(ret))
}
OwnershipPattern::Transfer => {
let params = vec![ATSParam {
name: "resource".to_string(),
ats_type: resource_vt.clone(),
consumed: true,
}];
(params, None)
}
};
Ok(ATSFunction {
name: safe_name,
c_function: rule.function.clone(),
params,
return_type,
pattern: pattern.clone(),
})
}
/// Map a C type spelling to an ATS2 type.
///
/// * Pointer types (any `*`) become `ptr(<base>)`. If a `const` qualifier
///   appears anywhere in the spelling, they become the read-only
///   `!ptr(<base>)`.
/// * Non-pointer types drop any `const` qualifier. `long` maps to `lint`,
///   `unsigned long` to `ulint`, and `unsigned int`/`uint` to `uint`.
///   Every other spelling (`int`, `size_t`, `char`, `void`, user types, ...)
///   is passed through unchanged.
///
/// Whitespace is normalised to single spaces. `const` is recognised only as a
/// whole token, so identifiers that merely contain the substring (e.g.
/// `my_constant_t*`) are not mangled or mistaken for const pointers.
fn c_type_to_ats(c_type: &str) -> String {
    let is_pointer = c_type.contains('*');

    // Split on whitespace and stars; keep the type words, note any `const`.
    let mut is_const = false;
    let mut base_tokens: Vec<&str> = Vec::new();
    for token in c_type.split(|c: char| c == '*' || c.is_whitespace()) {
        match token {
            "" => {}
            "const" => is_const = true,
            word => base_tokens.push(word),
        }
    }
    let base = base_tokens.join(" ");

    if is_pointer {
        return if is_const {
            format!("!ptr({})", base)
        } else {
            format!("ptr({})", base)
        };
    }

    let mapped = match base.as_str() {
        "long" => "lint",
        "unsigned long" => "ulint",
        "unsigned int" | "uint" => "uint",
        // `void`, `int`, `size_t`, `char`, `double`, `float` and any unknown
        // type keep their C spelling.
        _ => return base,
    };
    mapped.to_string()
}
/// Build the memory-safety proof obligation recorded for one rule.
///
/// * `Alloc`: an allocation proof for the resource's viewtype. When no
///   resource type is given, `"ptr"` is used, giving `ptr_vt`.
/// * `Borrow`: an immutable borrow proof keyed by the C function name.
/// * `Free` and `Transfer`: a free proof keyed by the C function name, since
///   both hand ownership away from the ATS side.
fn generate_proof(
    func_name: &str,
    pattern: &OwnershipPattern,
    resource_type: Option<&str>,
) -> MemorySafetyProof {
    match pattern {
        OwnershipPattern::Alloc => {
            let resource = resource_type.unwrap_or("ptr");
            MemorySafetyProof::AllocProof {
                viewtype: viewtype_name(resource),
            }
        }
        OwnershipPattern::Borrow => MemorySafetyProof::BorrowProof {
            ptr_id: func_name.to_owned(),
            mutable: false,
        },
        OwnershipPattern::Free | OwnershipPattern::Transfer => MemorySafetyProof::FreeProof {
            ptr_id: func_name.to_owned(),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct an `OwnershipRule` with an empty description.
    fn rule(
        function: &str,
        pattern: &str,
        param_index: Option<usize>,
        resource_type: Option<&str>,
    ) -> OwnershipRule {
        OwnershipRule {
            function: function.to_string(),
            pattern: pattern.to_string(),
            param_index,
            resource_type: resource_type.map(|t| t.to_string()),
            description: String::new(),
        }
    }

    /// Trailing `*` is dropped and `_vt` appended.
    #[test]
    fn test_viewtype_name_generation() {
        let cases = [
            ("FILE*", "FILE_vt"),
            ("mylib_t", "mylib_t_vt"),
            ("void*", "void_vt"),
        ];
        for &(c_type, expected) in cases.iter() {
            assert_eq!(viewtype_name(c_type), expected);
        }
    }

    /// Primitive scalar spellings map to themselves.
    #[test]
    fn test_c_type_to_ats_basic() {
        for &name in ["int", "size_t", "void"].iter() {
            assert_eq!(c_type_to_ats(name), name);
        }
    }

    /// Mutable pointers become some `ptr(...)` form.
    #[test]
    fn test_c_type_to_ats_pointers() {
        let mapped = c_type_to_ats("char*");
        assert!(mapped.contains("ptr"));
    }

    /// Const pointers become the read-only `!ptr(...)` form.
    #[test]
    fn test_c_type_to_ats_const_pointers() {
        let mapped = c_type_to_ats("const char*");
        assert!(mapped.contains("!ptr"));
    }

    /// Each known pattern parses; anything else is rejected.
    #[test]
    fn test_parse_pattern() {
        let cases = [
            ("alloc", OwnershipPattern::Alloc),
            ("free", OwnershipPattern::Free),
            ("borrow", OwnershipPattern::Borrow),
            ("transfer", OwnershipPattern::Transfer),
        ];
        for (text, expected) in cases.iter() {
            assert_eq!(&parse_pattern(text).unwrap(), expected);
        }
        assert!(parse_pattern("invalid").is_err());
    }

    /// No rules means no viewtypes and no functions.
    #[test]
    fn test_generate_module_empty() {
        let module = generate_module("test", &[], &[], &[]).unwrap();
        assert_eq!(module.name, "test");
        assert!(module.viewtypes.is_empty());
        assert!(module.functions.is_empty());
    }

    /// Two rules sharing one resource give two wrappers and one viewtype.
    #[test]
    fn test_generate_module_with_rules() {
        let rules = [
            rule("my_alloc", "alloc", None, Some("my_resource")),
            rule("my_free", "free", Some(0), Some("my_resource")),
        ];
        let module = generate_module("test", &[], &rules, &[]).unwrap();

        let wrapper_names: Vec<&str> = module.functions.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(wrapper_names.len(), 2);
        assert_eq!(wrapper_names[0], "safe_my_alloc");
        assert_eq!(wrapper_names[1], "safe_my_free");

        assert_eq!(module.viewtypes.len(), 1);
        assert_eq!(module.viewtypes[0].name, "my_resource_vt");
    }

    /// Both rendered halves carry the license header; the static half is labelled.
    #[test]
    fn test_render_module_produces_sats_and_dats() {
        let (sats, dats) = render_module(&ATSModule::new("test")).unwrap();
        for rendered in [&sats, &dats].iter() {
            assert!(rendered.contains("PMPL-1.0-or-later"));
        }
        assert!(sats.contains("static signatures"));
    }
}