use std::fmt;
use boltffi_ffi_rules::naming::{LibraryName, Name};
/// Top-level description of one C# module: its naming data plus the records
/// and functions it contains.
#[derive(Debug, Clone)]
pub struct CSharpModule {
/// NOTE(review): presumably the C# namespace the generated code lives in —
/// confirm against the code that consumes this model.
pub namespace: String,
/// NOTE(review): presumably the name of the C# class that hosts the
/// module's functions — confirm against the consumer.
pub class_name: String,
/// Library name, typed through `boltffi_ffi_rules::naming`.
pub lib_name: Name<LibraryName>,
/// NOTE(review): presumably the prefix of the exported native symbols —
/// not read anywhere in this file, confirm against the consumer.
pub prefix: String,
/// Records declared by the module. A non-empty list makes `needs_ffi_buf`,
/// `needs_wire_reader` and `needs_wire_writer` report `true`.
pub records: Vec<CSharpRecord>,
/// Functions exposed by the module.
pub functions: Vec<CSharpFunction>,
}
impl CSharpModule {
pub fn has_functions(&self) -> bool {
!self.functions.is_empty()
}
pub fn needs_system_text(&self) -> bool {
self.functions
.iter()
.any(|f| f.params.iter().any(|p| p.csharp_type.is_string()))
|| self.records.iter().any(CSharpRecord::has_string_fields)
}
pub fn has_wire_params(&self) -> bool {
self.functions.iter().any(|f| !f.wire_writers.is_empty())
}
pub fn has_ffi_buf_returns(&self) -> bool {
self.functions
.iter()
.any(|f| f.return_kind.native_returns_ffi_buf())
}
pub fn needs_ffi_buf(&self) -> bool {
self.has_ffi_buf_returns() || !self.records.is_empty()
}
pub fn needs_wire_reader(&self) -> bool {
self.has_ffi_buf_returns() || !self.records.is_empty()
}
pub fn needs_wire_writer(&self) -> bool {
self.has_wire_params() || !self.records.is_empty()
}
}
/// A type as it is spelled in C# source.
///
/// `display_name` (also used by the `Display` impl) maps every variant to its
/// C# spelling.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CSharpType {
/// C# `void`.
Void,
/// C# `bool`.
Bool,
/// C# `sbyte`.
SByte,
/// C# `byte`.
Byte,
/// C# `short`.
Short,
/// C# `ushort`.
UShort,
/// C# `int`.
Int,
/// C# `uint`.
UInt,
/// C# `long`.
Long,
/// C# `ulong`.
ULong,
/// C# `nint`.
NInt,
/// C# `nuint`.
NUInt,
/// C# `float`.
Float,
/// C# `double`.
Double,
/// C# `string`.
String,
/// A record type; the payload is the name `display_name` returns verbatim.
Record(String),
}
impl CSharpType {
    /// Returns the spelling of this type in C# source.
    ///
    /// Built-in types map to their C# keyword; a record maps to the name it
    /// carries.
    pub fn display_name(&self) -> &str {
        match self {
            // Named types carry their own spelling.
            Self::Record(class_name) => class_name.as_str(),

            // Non-numeric built-ins.
            Self::Void => "void",
            Self::Bool => "bool",
            Self::String => "string",

            // Signed integers.
            Self::SByte => "sbyte",
            Self::Short => "short",
            Self::Int => "int",
            Self::Long => "long",
            Self::NInt => "nint",

            // Unsigned integers.
            Self::Byte => "byte",
            Self::UShort => "ushort",
            Self::UInt => "uint",
            Self::ULong => "ulong",
            Self::NUInt => "nuint",

            // Floating point.
            Self::Float => "float",
            Self::Double => "double",
        }
    }

    /// Returns `true` for [`CSharpType::Void`].
    pub fn is_void(&self) -> bool {
        matches!(self, Self::Void)
    }

    /// Returns `true` for [`CSharpType::Bool`].
    pub fn is_bool(&self) -> bool {
        matches!(self, Self::Bool)
    }

    /// Returns `true` for [`CSharpType::String`].
    pub fn is_string(&self) -> bool {
        matches!(self, Self::String)
    }

    /// Returns `true` for any [`CSharpType::Record`], whatever its name.
    pub fn is_record(&self) -> bool {
        matches!(self, Self::Record(_))
    }
}
/// Formats a [`CSharpType`] as its C# spelling.
impl fmt::Display for CSharpType {
    /// Writes `display_name()` verbatim; width, fill and precision flags of
    /// the formatter are not applied.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let spelling = self.display_name();
        formatter.write_str(spelling)
    }
}
/// Description of one record type together with its fields.
#[derive(Debug, Clone)]
pub struct CSharpRecord {
/// NOTE(review): presumably the name of the C# type generated for the
/// record — confirm against the consumer.
pub class_name: String,
/// Fields of the record; `is_empty` reports whether this list is empty.
pub fields: Vec<CSharpRecordField>,
/// When `false`, `needs_wire_helpers` reports `true`.
/// NOTE(review): presumably marks records that can cross the native
/// boundary by value without wire encoding — confirm.
pub is_blittable: bool,
}
impl CSharpRecord {
    /// Returns `true` when the record has no fields.
    pub fn is_empty(&self) -> bool {
        self.fields.is_empty()
    }

    /// Returns `true` when the record is not blittable.
    pub fn needs_wire_helpers(&self) -> bool {
        let blittable = self.is_blittable;
        !blittable
    }

    /// Returns `true` when at least one field has the C# `string` type.
    pub fn has_string_fields(&self) -> bool {
        self.fields
            .iter()
            .map(|field| &field.csharp_type)
            .any(|field_type| field_type.is_string())
    }
}
/// Description of one field of a [`CSharpRecord`].
#[derive(Debug, Clone)]
pub struct CSharpRecordField {
/// Field name.
pub name: String,
/// C# type of the field; `CSharpRecord::has_string_fields` inspects it.
pub csharp_type: CSharpType,
/// NOTE(review): presumably a C# expression that decodes the field from the
/// wire format — not read in this file, confirm against the consumer.
pub wire_decode_expr: String,
/// NOTE(review): presumably a C# expression giving the field's encoded
/// size — not read in this file, confirm against the consumer.
pub wire_size_expr: String,
/// NOTE(review): presumably a C# expression that encodes the field to the
/// wire format — not read in this file, confirm against the consumer.
pub wire_encode_expr: String,
}
/// Description of one function: its parameters, its return shape, and the
/// data needed to render its wrapper and native signatures.
#[derive(Debug, Clone)]
pub struct CSharpFunction {
/// NOTE(review): presumably the name of the C# wrapper method — confirm.
pub name: String,
/// Parameters, in declaration order; rendered by `wrapper_param_list`,
/// `native_param_list` and `native_call_args`.
pub params: Vec<CSharpParam>,
/// C# type of the result. `native_return_type` uses it verbatim unless the
/// return kind goes through an `FfiBuf`.
pub return_type: CSharpType,
/// How the native result is turned into `return_type`.
pub return_kind: CSharpReturnKind,
/// NOTE(review): presumably the native symbol the wrapper calls — not read
/// in this file, confirm against the consumer.
pub ffi_name: String,
/// Wire writers of the function; a non-empty list makes
/// `CSharpModule::has_wire_params` report `true`.
pub wire_writers: Vec<CSharpWireWriter>,
}
impl CSharpFunction {
pub fn is_void(&self) -> bool {
matches!(self.return_kind, CSharpReturnKind::Void)
}
pub fn wrapper_param_list(&self) -> String {
self.params
.iter()
.map(CSharpParam::wrapper_declaration)
.collect::<Vec<_>>()
.join(", ")
}
pub fn native_param_list(&self) -> String {
self.params
.iter()
.map(CSharpParam::native_declaration)
.collect::<Vec<_>>()
.join(", ")
}
pub fn native_call_args(&self) -> String {
self.params
.iter()
.map(CSharpParam::native_call_arg)
.collect::<Vec<_>>()
.join(", ")
}
pub fn native_return_type(&self) -> String {
if self.return_kind.native_returns_ffi_buf() {
"FfiBuf".to_string()
} else {
self.return_type.to_string()
}
}
}
/// How the result of a native call reaches the C# caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CSharpReturnKind {
/// The function returns nothing; `CSharpFunction::is_void` keys off this
/// variant.
Void,
/// The native signature uses the C# return type unchanged (see
/// `CSharpFunction::native_return_type`).
/// NOTE(review): presumably the value is handed back without decoding —
/// confirm against the templates.
Direct,
/// The native call returns an `FfiBuf` that is decoded with
/// `new WireReader(buf).ReadString()`.
WireDecodeString,
/// The native call returns an `FfiBuf` that is decoded with
/// `<class_name>.Decode(new WireReader(buf))`.
WireDecodeRecord { class_name: String },
}
impl CSharpReturnKind {
    /// Returns `true` for [`CSharpReturnKind::Void`].
    pub fn is_void(&self) -> bool {
        matches!(self, Self::Void)
    }

    /// Returns `true` for [`CSharpReturnKind::Direct`].
    pub fn is_direct(&self) -> bool {
        matches!(self, Self::Direct)
    }

    /// Returns `true` for [`CSharpReturnKind::WireDecodeString`].
    pub fn is_wire_decode_string(&self) -> bool {
        matches!(self, Self::WireDecodeString)
    }

    /// Returns `true` for any [`CSharpReturnKind::WireDecodeRecord`].
    pub fn is_wire_decode_record(&self) -> bool {
        matches!(self, Self::WireDecodeRecord { .. })
    }

    /// Returns `true` for the two wire-decoded kinds, whose native signature
    /// returns an `FfiBuf`.
    pub fn native_returns_ffi_buf(&self) -> bool {
        match self {
            Self::WireDecodeString | Self::WireDecodeRecord { .. } => true,
            Self::Void | Self::Direct => false,
        }
    }

    /// Returns the record class name for
    /// [`CSharpReturnKind::WireDecodeRecord`], and `None` for every other kind.
    pub fn decode_class_name(&self) -> Option<&str> {
        if let Self::WireDecodeRecord { class_name } = self {
            Some(class_name.as_str())
        } else {
            None
        }
    }

    /// Renders the C# `return` statement that decodes the buffer held in
    /// `buf_var`, or `None` for kinds that do not decode from a buffer.
    pub fn wire_decode_return(&self, buf_var: &str) -> Option<String> {
        let statement = match self {
            Self::WireDecodeString => {
                format!("return new WireReader({buf_var}).ReadString();")
            }
            Self::WireDecodeRecord { class_name } => {
                format!("return {class_name}.Decode(new WireReader({buf_var}));")
            }
            Self::Void | Self::Direct => return None,
        };
        Some(statement)
    }
}
/// Description of one function parameter.
#[derive(Debug, Clone)]
pub struct CSharpParam {
/// Parameter name, used in both the wrapper and the native declaration.
pub name: String,
/// C# type shown in the wrapper declaration.
pub csharp_type: CSharpType,
/// How the argument is passed to the native function; selects the shape of
/// the native declaration, the call argument and any setup statement.
pub kind: CSharpParamKind,
}
impl CSharpParam {
    /// Renders the parameter as it appears in the wrapper signature:
    /// `<type> <name>`.
    pub fn wrapper_declaration(&self) -> String {
        format!("{ty} {name}", ty = self.csharp_type, name = self.name)
    }

    /// Renders the parameter as it appears in the native signature.
    ///
    /// Byte-backed kinds expand to a `byte[]` plus a `UIntPtr` length slot;
    /// a direct `bool` gets an explicit one-byte marshalling attribute; any
    /// other direct parameter is declared like in the wrapper.
    pub fn native_declaration(&self) -> String {
        let name = &self.name;
        match &self.kind {
            CSharpParamKind::Direct => {
                if self.csharp_type.is_bool() {
                    format!("[MarshalAs(UnmanagedType.I1)] bool {name}")
                } else {
                    format!("{} {name}", self.csharp_type)
                }
            }
            CSharpParamKind::Utf8Bytes | CSharpParamKind::WireEncoded { .. } => {
                format!("byte[] {name}, UIntPtr {name}Len")
            }
        }
    }

    /// Renders the argument(s) passed to the native call: the bare name for
    /// direct parameters, or a buffer followed by its length otherwise.
    pub fn native_call_arg(&self) -> String {
        match &self.kind {
            CSharpParamKind::Direct => self.name.clone(),
            CSharpParamKind::Utf8Bytes => Self::buffer_with_length(&self.utf8_buffer_name()),
            CSharpParamKind::WireEncoded { binding_name } => {
                Self::buffer_with_length(binding_name)
            }
        }
    }

    /// Renders the statement that must run before the native call, if any.
    ///
    /// Only UTF-8 string parameters need one: the string is encoded into the
    /// local buffer that `native_call_arg` then passes along.
    pub fn setup_statement(&self) -> Option<String> {
        match &self.kind {
            CSharpParamKind::Utf8Bytes => Some(format!(
                "byte[] {buffer} = Encoding.UTF8.GetBytes({name});",
                buffer = self.utf8_buffer_name(),
                name = self.name
            )),
            CSharpParamKind::Direct | CSharpParamKind::WireEncoded { .. } => None,
        }
    }

    /// Name of the local that holds the UTF-8 bytes of a string parameter.
    fn utf8_buffer_name(&self) -> String {
        format!("_{}Bytes", self.name)
    }

    /// Renders `buffer` followed by its length cast to `UIntPtr`.
    fn buffer_with_length(buffer: &str) -> String {
        format!("{buffer}, (UIntPtr){buffer}.Length")
    }
}
/// How a parameter is handed to the native function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CSharpParamKind {
/// The parameter is passed under its own name; a `bool` additionally gets a
/// `[MarshalAs(UnmanagedType.I1)]` attribute in the native declaration.
Direct,
/// The parameter is first encoded with `Encoding.UTF8.GetBytes` into a
/// local named `_<name>Bytes`, then passed as that buffer plus its length.
Utf8Bytes,
/// The parameter is passed as the buffer named `binding_name` plus its
/// length.
/// NOTE(review): the buffer is presumably produced by a matching
/// `CSharpWireWriter` — confirm against the code that builds this model.
WireEncoded { binding_name: String },
}
/// Description of one wire-writer binding attached to a function through
/// `CSharpFunction::wire_writers`.
///
/// NOTE(review): none of these fields are read in this file; the per-field
/// notes below are inferred from the names — confirm against the consumer.
#[derive(Debug, Clone)]
pub struct CSharpWireWriter {
/// Presumably the name of the writer local in the generated code.
pub binding_name: String,
/// Presumably the name of the local holding the encoded bytes.
pub bytes_binding_name: String,
/// Presumably the name of the parameter being encoded.
pub param_name: String,
/// Presumably a C# expression giving the encoded size.
pub size_expr: String,
/// Presumably a C# expression or statement that performs the encoding.
pub encode_expr: String,
}
#[cfg(test)]
mod tests {
    //! Unit tests for the C# model types and the text fragments they render.

    use super::*;
    use rstest::rstest;

    /// Builds a function fixture with the given parameters and return shape.
    fn function_with_params(
        params: Vec<CSharpParam>,
        return_type: CSharpType,
        return_kind: CSharpReturnKind,
    ) -> CSharpFunction {
        CSharpFunction {
            name: "Test".to_string(),
            params,
            return_type,
            return_kind,
            ffi_name: "boltffi_test".to_string(),
            wire_writers: vec![],
        }
    }

    /// Builds a parameterless function fixture with the given return shape.
    fn function_with_return(
        return_type: CSharpType,
        return_kind: CSharpReturnKind,
    ) -> CSharpFunction {
        function_with_params(Vec::new(), return_type, return_kind)
    }

    /// Builds a parameter fixture.
    fn param(name: &str, csharp_type: CSharpType, kind: CSharpParamKind) -> CSharpParam {
        CSharpParam {
            name: name.to_string(),
            csharp_type,
            kind,
        }
    }

    // ---- CSharpFunction::is_void / CSharpType display ----

    #[rstest]
    #[case::void(CSharpType::Void, CSharpReturnKind::Void, true)]
    #[case::int(CSharpType::Int, CSharpReturnKind::Direct, false)]
    #[case::bool(CSharpType::Bool, CSharpReturnKind::Direct, false)]
    #[case::double(CSharpType::Double, CSharpReturnKind::Direct, false)]
    fn is_void(
        #[case] return_type: CSharpType,
        #[case] return_kind: CSharpReturnKind,
        #[case] expected: bool,
    ) {
        assert_eq!(
            function_with_return(return_type, return_kind).is_void(),
            expected
        );
    }

    #[test]
    fn record_type_display_uses_class_name() {
        let ty = CSharpType::Record("Point".to_string());
        assert_eq!(ty.to_string(), "Point");
        assert!(ty.is_record());
    }

    // ---- CSharpParam::wrapper_declaration ----

    #[test]
    fn wrapper_declaration_puts_type_before_name() {
        let p = param("value", CSharpType::Int, CSharpParamKind::Direct);
        assert_eq!(p.wrapper_declaration(), "int value");
    }

    #[test]
    fn wrapper_declaration_uses_record_class_name() {
        let p = param(
            "point",
            CSharpType::Record("Point".to_string()),
            CSharpParamKind::Direct,
        );
        assert_eq!(p.wrapper_declaration(), "Point point");
    }

    // ---- CSharpParam::native_declaration ----

    #[test]
    fn native_declaration_direct_primitive_matches_wrapper() {
        let p = param("value", CSharpType::Int, CSharpParamKind::Direct);
        assert_eq!(p.native_declaration(), "int value");
    }

    #[test]
    fn native_declaration_bool_gets_marshal_attribute() {
        let p = param("flag", CSharpType::Bool, CSharpParamKind::Direct);
        assert_eq!(
            p.native_declaration(),
            "[MarshalAs(UnmanagedType.I1)] bool flag"
        );
    }

    #[test]
    fn native_declaration_blittable_record_passes_by_value() {
        let p = param(
            "point",
            CSharpType::Record("Point".to_string()),
            CSharpParamKind::Direct,
        );
        assert_eq!(p.native_declaration(), "Point point");
    }

    #[test]
    fn native_declaration_string_splits_into_bytes_and_length() {
        let p = param("v", CSharpType::String, CSharpParamKind::Utf8Bytes);
        assert_eq!(p.native_declaration(), "byte[] v, UIntPtr vLen");
    }

    #[test]
    fn native_declaration_wire_encoded_record_splits_into_bytes_and_length() {
        let p = param(
            "person",
            CSharpType::Record("Person".to_string()),
            CSharpParamKind::WireEncoded {
                binding_name: "_personBytes".to_string(),
            },
        );
        assert_eq!(p.native_declaration(), "byte[] person, UIntPtr personLen");
    }

    // ---- CSharpParam::native_call_arg ----

    #[test]
    fn native_call_arg_direct_passes_name() {
        let p = param("value", CSharpType::Int, CSharpParamKind::Direct);
        assert_eq!(p.native_call_arg(), "value");
    }

    #[test]
    fn native_call_arg_utf8_bytes_passes_buffer_and_length() {
        let p = param("v", CSharpType::String, CSharpParamKind::Utf8Bytes);
        assert_eq!(p.native_call_arg(), "_vBytes, (UIntPtr)_vBytes.Length");
    }

    #[test]
    fn native_call_arg_wire_encoded_uses_binding_name() {
        let p = param(
            "person",
            CSharpType::Record("Person".to_string()),
            CSharpParamKind::WireEncoded {
                binding_name: "_personBytes".to_string(),
            },
        );
        assert_eq!(
            p.native_call_arg(),
            "_personBytes, (UIntPtr)_personBytes.Length"
        );
    }

    // ---- CSharpParam::setup_statement ----

    // Every non-UTF-8 kind yields no setup statement, so the expectation is
    // fixed rather than passed in as a case argument.
    #[rstest]
    #[case::direct(CSharpParamKind::Direct)]
    #[case::wire_encoded(
        CSharpParamKind::WireEncoded { binding_name: "_personBytes".to_string() },
    )]
    fn setup_statement_non_string_has_none(#[case] kind: CSharpParamKind) {
        let p = param("x", CSharpType::Int, kind);
        assert_eq!(p.setup_statement(), None);
    }

    #[test]
    fn setup_statement_utf8_bytes_encodes_string() {
        let p = param("v", CSharpType::String, CSharpParamKind::Utf8Bytes);
        assert_eq!(
            p.setup_statement().as_deref(),
            Some("byte[] _vBytes = Encoding.UTF8.GetBytes(v);"),
        );
    }

    // ---- CSharpFunction parameter lists ----

    #[test]
    fn wrapper_param_list_joins_with_comma_space() {
        let f = function_with_params(
            vec![
                param("a", CSharpType::Int, CSharpParamKind::Direct),
                param("b", CSharpType::String, CSharpParamKind::Utf8Bytes),
            ],
            CSharpType::Void,
            CSharpReturnKind::Void,
        );
        assert_eq!(f.wrapper_param_list(), "int a, string b");
    }

    #[test]
    fn wrapper_param_list_empty_for_no_params() {
        let f = function_with_params(vec![], CSharpType::Void, CSharpReturnKind::Void);
        assert_eq!(f.wrapper_param_list(), "");
    }

    #[test]
    fn native_param_list_expands_each_slot_by_kind() {
        let f = function_with_params(
            vec![
                param("flag", CSharpType::Bool, CSharpParamKind::Direct),
                param("v", CSharpType::String, CSharpParamKind::Utf8Bytes),
                param("count", CSharpType::UInt, CSharpParamKind::Direct),
                param(
                    "person",
                    CSharpType::Record("Person".to_string()),
                    CSharpParamKind::WireEncoded {
                        binding_name: "_personBytes".to_string(),
                    },
                ),
            ],
            CSharpType::Void,
            CSharpReturnKind::Void,
        );
        assert_eq!(
            f.native_param_list(),
            "[MarshalAs(UnmanagedType.I1)] bool flag, byte[] v, UIntPtr vLen, uint count, byte[] person, UIntPtr personLen",
        );
    }

    #[test]
    fn native_call_args_mirror_param_shapes() {
        let f = function_with_params(
            vec![
                param("v", CSharpType::String, CSharpParamKind::Utf8Bytes),
                param("count", CSharpType::UInt, CSharpParamKind::Direct),
            ],
            CSharpType::Void,
            CSharpReturnKind::Void,
        );
        assert_eq!(
            f.native_call_args(),
            "_vBytes, (UIntPtr)_vBytes.Length, count",
        );
    }

    // ---- CSharpFunction::native_return_type ----

    #[rstest]
    #[case::void(CSharpType::Void, CSharpReturnKind::Void, "void")]
    #[case::primitive(CSharpType::Int, CSharpReturnKind::Direct, "int")]
    #[case::blittable_record(
        CSharpType::Record("Point".to_string()),
        CSharpReturnKind::Direct,
        "Point",
    )]
    #[case::string(CSharpType::String, CSharpReturnKind::WireDecodeString, "FfiBuf")]
    #[case::wire_record(
        CSharpType::Record("Person".to_string()),
        CSharpReturnKind::WireDecodeRecord { class_name: "Person".to_string() },
        "FfiBuf",
    )]
    fn native_return_type_reflects_ffi_buf_paths(
        #[case] return_type: CSharpType,
        #[case] return_kind: CSharpReturnKind,
        #[case] expected: &str,
    ) {
        assert_eq!(
            function_with_return(return_type, return_kind).native_return_type(),
            expected
        );
    }

    // ---- CSharpReturnKind ----

    #[test]
    fn wire_decode_return_for_string_uses_read_string() {
        let kind = CSharpReturnKind::WireDecodeString;
        assert_eq!(
            kind.wire_decode_return("_buf").as_deref(),
            Some("return new WireReader(_buf).ReadString();"),
        );
    }

    #[test]
    fn wire_decode_return_for_record_calls_decode() {
        let kind = CSharpReturnKind::WireDecodeRecord {
            class_name: "Person".to_string(),
        };
        assert_eq!(
            kind.wire_decode_return("_buf").as_deref(),
            Some("return Person.Decode(new WireReader(_buf));"),
        );
    }

    #[rstest]
    #[case::void(CSharpReturnKind::Void)]
    #[case::direct(CSharpReturnKind::Direct)]
    fn wire_decode_return_none_for_non_wire_kinds(#[case] kind: CSharpReturnKind) {
        assert_eq!(kind.wire_decode_return("_buf"), None);
    }

    #[test]
    fn decode_class_name_some_only_for_wire_decode_record() {
        assert_eq!(
            CSharpReturnKind::WireDecodeRecord {
                class_name: "Point".to_string()
            }
            .decode_class_name(),
            Some("Point"),
        );
        assert_eq!(CSharpReturnKind::WireDecodeString.decode_class_name(), None);
        assert_eq!(CSharpReturnKind::Void.decode_class_name(), None);
        assert_eq!(CSharpReturnKind::Direct.decode_class_name(), None);
    }
}