use std::collections::HashSet;
use boltffi_ffi_rules::naming;
use crate::ir::abi::{AbiCall, AbiParam, AbiRecord, CallId, ParamRole};
use crate::ir::definitions::{FieldDef, FunctionDef, ParamDef, ParamPassing, RecordDef, ReturnDef};
use crate::ir::ids::{FieldName, RecordId};
use crate::ir::ops::{ReadOp, ReadSeq, WriteOp, WriteSeq};
use crate::ir::types::TypeExpr;
use crate::ir::{AbiContract, FfiContract};
use super::emit;
use super::mappings;
use super::plan::{
CSharpFunction, CSharpModule, CSharpParam, CSharpParamKind, CSharpRecord, CSharpRecordField,
CSharpReturnKind, CSharpType, CSharpWireWriter,
};
use super::{CSharpOptions, NamingConvention};
/// Lowers an FFI contract and its ABI contract into a [`CSharpModule`] plan.
///
/// All inputs are borrowed for `'a`. The set of record ids that can be
/// lowered is computed once in `new` and reused by the lowering methods.
pub struct CSharpLowerer<'a> {
/// Source-level contract: package name, record catalog and functions.
ffi: &'a FfiContract,
/// ABI-level contract: per-record and per-call ABI descriptions.
abi: &'a AbiContract,
/// Backend options; `library_name`, when set, overrides the derived library name.
options: &'a CSharpOptions,
/// Ids of the records whose fields all have supported types.
supported_records: HashSet<String>,
}
impl<'a> CSharpLowerer<'a> {
pub fn new(ffi: &'a FfiContract, abi: &'a AbiContract, options: &'a CSharpOptions) -> Self {
let supported_records = Self::compute_supported_records(ffi);
Self {
ffi,
abi,
options,
supported_records,
}
}
/// Computes the ids of all records this backend can lower.
///
/// A record is accepted when every field is a primitive, a string, void, or
/// a reference to a record that has already been accepted. The catalog is
/// rescanned until a full pass accepts nothing new, so records that only
/// depend on other acceptable records are picked up in later passes.
/// Records that reference themselves (directly or through a cycle) are
/// never accepted.
fn compute_supported_records(ffi: &FfiContract) -> HashSet<String> {
    let mut accepted: HashSet<String> = HashSet::new();
    loop {
        let accepted_before = accepted.len();
        for record in ffi.catalog.all_records() {
            let key = record.id.as_str();
            if accepted.contains(key) {
                continue;
            }
            let lowerable = record.fields.iter().all(|field| match &field.type_expr {
                TypeExpr::Record(dependency) => accepted.contains(dependency.as_str()),
                TypeExpr::Primitive(_) | TypeExpr::String | TypeExpr::Void => true,
                _ => false,
            });
            if lowerable {
                accepted.insert(key.to_string());
            }
        }
        // The set only ever grows, so an unchanged size means a fixed point.
        if accepted.len() == accepted_before {
            break accepted;
        }
    }
}
/// Lowers the whole contract into a [`CSharpModule`].
///
/// The library name comes from the options when set, otherwise it is
/// derived from the package name. Only supported records are emitted, and
/// functions that `lower_function` rejects are silently left out.
pub fn lower(&self) -> CSharpModule {
    let package_name = &self.ffi.package.name;
    let lib_name = match &self.options.library_name {
        Some(configured) => configured.clone(),
        None => naming::library_name(package_name),
    };
    let class_name = NamingConvention::class_name(package_name);
    let namespace = NamingConvention::namespace(package_name);
    let prefix = naming::ffi_prefix().to_string();

    let records: Vec<CSharpRecord> = self
        .ffi
        .catalog
        .all_records()
        .filter_map(|record| {
            self.supported_records
                .contains(record.id.as_str())
                .then(|| self.lower_record(record))
        })
        .collect();

    let mut functions: Vec<CSharpFunction> = Vec::new();
    functions.extend(
        self.ffi
            .functions
            .iter()
            .filter_map(|function| self.lower_function(function)),
    );

    CSharpModule {
        namespace,
        class_name,
        lib_name,
        prefix,
        records,
        functions,
    }
}
/// Lowers one function definition, or returns `None` if it cannot be lowered.
///
/// A function is rejected when it is async, when any parameter is
/// unsupported, when its return cannot be lowered, when no ABI call is
/// found for it, or when any parameter fails to lower.
fn lower_function(&self, function: &FunctionDef) -> Option<CSharpFunction> {
    let rejected = function.is_async()
        || function
            .params
            .iter()
            .any(|param| !self.is_supported_param(param));
    if rejected {
        return None;
    }

    let return_type = self.lower_return(&function.returns)?;
    let return_kind = self.return_kind(&function.returns, &return_type);
    let wire_writers = self.wire_writers_for_params(function)?;

    // Stop at the first parameter that fails to lower.
    let mut params: Vec<CSharpParam> = Vec::new();
    for param in function.params.iter() {
        params.push(self.lower_param(param, &wire_writers)?);
    }

    Some(CSharpFunction {
        name: NamingConvention::method_name(function.id.as_str()),
        ffi_name: naming::function_ffi_name(function.id.as_str()).into_string(),
        params,
        return_type,
        return_kind,
        wire_writers,
    })
}
/// Decides how a function's return value reaches the C# caller.
///
/// A void `return_type` always yields `Void`. Strings are wire-decoded,
/// records not marked blittable are wire-decoded into their C# class, and
/// every other return is passed through directly.
fn return_kind(&self, return_def: &ReturnDef, return_type: &CSharpType) -> CSharpReturnKind {
    if return_type.is_void() {
        return CSharpReturnKind::Void;
    }
    let ReturnDef::Value(value_type) = return_def else {
        return CSharpReturnKind::Direct;
    };
    match value_type {
        TypeExpr::String => CSharpReturnKind::WireDecodeString,
        TypeExpr::Record(id) if !self.is_blittable_record(id) => {
            let class_name = NamingConvention::class_name(id.as_str());
            CSharpReturnKind::WireDecodeRecord { class_name }
        }
        _ => CSharpReturnKind::Direct,
    }
}
/// Returns `true` only when an ABI record exists for `id` and is flagged
/// `is_blittable`; a missing ABI record counts as not blittable.
fn is_blittable_record(&self, id: &RecordId) -> bool {
    matches!(self.abi_record_for(id), Some(record) if record.is_blittable)
}
/// Returns `true` when the parameter is passed by value and its type is
/// one this backend supports.
fn is_supported_param(&self, param: &ParamDef) -> bool {
    param.passing == ParamPassing::Value && self.is_supported_type(&param.type_expr)
}
/// Returns `true` for primitives, strings, void, and records that are in
/// the supported-record set; every other type expression is unsupported.
fn is_supported_type(&self, ty: &TypeExpr) -> bool {
    match ty {
        TypeExpr::Record(id) => self.supported_records.contains(id.as_str()),
        other => matches!(
            other,
            TypeExpr::Primitive(_) | TypeExpr::String | TypeExpr::Void
        ),
    }
}
/// Lowers one parameter definition into its C# plan entry.
///
/// Returns `None` when the parameter is not passed by value, when its type
/// cannot be lowered, or when a non-blittable record parameter has no
/// matching entry in `wire_writers`.
///
/// `wire_writers` are the writers already built for the owning function;
/// they are matched against the parameter by its original, unconverted name.
fn lower_param(
    &self,
    param: &ParamDef,
    wire_writers: &[CSharpWireWriter],
) -> Option<CSharpParam> {
    if param.passing != ParamPassing::Value {
        return None;
    }
    let csharp_type = self.lower_type(&param.type_expr)?;
    let kind = match &param.type_expr {
        TypeExpr::String => CSharpParamKind::Utf8Bytes,
        TypeExpr::Record(id) if !self.is_blittable_record(id) => {
            // Non-blittable records are passed as the bytes produced by
            // the wire writer built for this parameter.
            let writer = wire_writers
                .iter()
                .find(|w| w.param_name == param.name.as_str())?;
            CSharpParamKind::WireEncoded {
                binding_name: writer.bytes_binding_name.clone(),
            }
        }
        _ => CSharpParamKind::Direct,
    };
    Some(CSharpParam {
        name: NamingConvention::field_name(param.name.as_str()),
        csharp_type,
        kind,
    })
}
/// Lowers a return definition to a C# type.
///
/// `Result` returns are not handled here and yield `None`, as does a value
/// return whose type cannot be lowered.
fn lower_return(&self, return_def: &ReturnDef) -> Option<CSharpType> {
    match return_def {
        ReturnDef::Result { .. } => None,
        ReturnDef::Void => Some(CSharpType::Void),
        ReturnDef::Value(value_type) => self.lower_type(value_type),
    }
}
/// Maps a type expression to its C# type.
///
/// Returns `None` for records outside the supported set and for every type
/// expression other than void, primitive, string or record.
fn lower_type(&self, type_expr: &TypeExpr) -> Option<CSharpType> {
    let lowered = match type_expr {
        TypeExpr::Void => CSharpType::Void,
        TypeExpr::String => CSharpType::String,
        TypeExpr::Primitive(primitive) => mappings::csharp_type(*primitive),
        TypeExpr::Record(id) if self.supported_records.contains(id.as_str()) => {
            CSharpType::Record(NamingConvention::class_name(id.as_str()))
        }
        _ => return None,
    };
    Some(lowered)
}
/// Lowers a record definition into its C# plan: class name, lowered
/// fields, and the blittable flag taken from the ABI record.
///
/// Panics if any field fails to lower (see `lower_record_field`).
fn lower_record(&self, record: &RecordDef) -> CSharpRecord {
    let record_id = &record.id;
    CSharpRecord {
        class_name: NamingConvention::class_name(record_id.as_str()),
        fields: record
            .fields
            .iter()
            .map(|field_def| self.lower_record_field(record_id, field_def))
            .collect(),
        is_blittable: self.is_blittable_record(record_id),
    }
}
/// Lowers one record field, producing its C# property name and type plus
/// the emitted decode, size and encode expressions.
///
/// # Panics
///
/// Panics when the ABI contract has no decode or encode sequence for the
/// field, or when the field type cannot be lowered.
fn lower_record_field(&self, record_id: &RecordId, field: &FieldDef) -> CSharpRecordField {
    let field_name = &field.name;
    let read_seq = self
        .record_field_read_seq(record_id, field_name)
        .expect("record field decode ops");
    let write_seq = self
        .record_field_write_seq(record_id, field_name)
        .expect("record field encode ops");
    let field_type = self
        .lower_type(&field.type_expr)
        .expect("record field type must be supported");

    let wire_decode_expr = emit::emit_reader_read(&read_seq);
    let wire_size_expr = emit::emit_size_expr(&write_seq.size);
    let wire_encode_expr = emit::emit_write_expr(&write_seq, "wire");

    CSharpRecordField {
        name: NamingConvention::property_name(field_name.as_str()),
        csharp_type: field_type,
        wire_decode_expr,
        wire_size_expr,
        wire_encode_expr,
    }
}
/// Looks up the decode sequence for one field of a record.
///
/// Returns `None` when the record has no ABI entry, when the first decode
/// op is missing or is not a `ReadOp::Record`, or when that op has no
/// field with the given name.
fn record_field_read_seq(
    &self,
    record_id: &RecordId,
    field_name: &FieldName,
) -> Option<ReadSeq> {
    let abi_record = self.abi_record_for(record_id)?;
    let ReadOp::Record { fields, .. } = abi_record.decode_ops.ops.first()? else {
        return None;
    };
    let matching = fields.iter().find(|entry| entry.name == *field_name)?;
    Some(matching.seq.clone())
}
/// Looks up the encode sequence for one field of a record.
///
/// Returns `None` when the record has no ABI entry, when the first encode
/// op is missing or is not a `WriteOp::Record`, or when that op has no
/// field with the given name.
fn record_field_write_seq(
    &self,
    record_id: &RecordId,
    field_name: &FieldName,
) -> Option<WriteSeq> {
    let abi_record = self.abi_record_for(record_id)?;
    let WriteOp::Record { fields, .. } = abi_record.encode_ops.ops.first()? else {
        return None;
    };
    let matching = fields.iter().find(|entry| entry.name == *field_name)?;
    Some(matching.seq.clone())
}
/// Finds the first ABI record whose id equals `record_id`, scanning the
/// ABI contract's records in order.
fn abi_record_for(&self, record_id: &RecordId) -> Option<&AbiRecord> {
    for candidate in self.abi.records.iter() {
        if candidate.id == *record_id {
            return Some(candidate);
        }
    }
    None
}
/// Builds the wire writers needed by a function's parameters.
///
/// Returns `None` when the function has no ABI call. Parameters that need
/// no writer are skipped, so the result may be empty.
fn wire_writers_for_params(&self, function: &FunctionDef) -> Option<Vec<CSharpWireWriter>> {
    self.abi_call_for_function(function).map(|call| {
        call.params
            .iter()
            .filter_map(|abi_param| self.wire_writer_for_param(abi_param))
            .collect()
    })
}
/// Builds the wire writer for one ABI parameter, if it needs one.
///
/// A writer is produced only for an input parameter that carries encode
/// ops whose first op is a `WriteOp::Record` for a record that is not
/// blittable. Every other parameter yields `None`.
fn wire_writer_for_param(&self, param: &AbiParam) -> Option<CSharpWireWriter> {
    // Borrow the encode ops rather than cloning them: they are only read,
    // and most parameters return early below.
    let ParamRole::Input {
        encode_ops: Some(encode_ops),
        ..
    } = &param.role
    else {
        return None;
    };
    let record_id = match encode_ops.ops.first()? {
        WriteOp::Record { id, .. } => id,
        _ => return None,
    };
    if self.is_blittable_record(record_id) {
        return None;
    }
    let param_name = param.name.as_str().to_string();
    let binding_name = format!("_wire_{}", param_name);
    let bytes_binding_name = format!("_{}Bytes", NamingConvention::field_name(&param_name));
    let encode_expr = emit::emit_write_expr(encode_ops, &binding_name);
    Some(CSharpWireWriter {
        binding_name,
        bytes_binding_name,
        param_name,
        size_expr: emit::emit_size_expr(&encode_ops.size),
        encode_expr,
    })
}
/// Finds the first ABI call whose id is `CallId::Function` with the same
/// id as `function`; calls with any other kind of id never match.
fn abi_call_for_function(&self, function: &FunctionDef) -> Option<&AbiCall> {
    self.abi
        .calls
        .iter()
        .find(|call| matches!(&call.id, CallId::Function(id) if id == &function.id))
}
}