use crate::{
cilassembly::{ChangeRefRc, CilAssembly},
metadata::{
signatures::{encode_field_signature, SignatureField, TypeSignature},
tables::{
CodedIndex, CodedIndexType, FieldBuilder, InterfaceImplBuilder, TypeAttributes,
TypeDefBuilder,
},
token::Token,
},
Error, Result,
};
use super::method::MethodBuilder;
/// A field queued on a [`ClassBuilder`] and emitted as a Field row during `build`.
struct FieldDefinition {
    /// Raw `FieldAttributes` flag bits passed to `FieldBuilder::flags`.
    attributes: u32,
    /// Field name passed to `FieldBuilder::name`; also used to resolve property backing fields.
    name: String,
    /// Field type, wrapped in a `SignatureField` and encoded at build time.
    field_type: TypeSignature,
}
/// A property queued by [`ClassBuilder::auto_property`] or [`ClassBuilder::readonly_property`].
struct PropertyDefinition {
    /// When `true`, a getter that loads the backing field is emitted.
    has_getter: bool,
    /// When `true`, a setter that stores its argument into the backing field is emitted.
    has_setter: bool,
    /// Property name, forwarded to the accessor method builders.
    name: String,
    /// Property type; the backing field is declared with the same type.
    property_type: TypeSignature,
    /// Name of the field the accessors read/write; `None` means no accessors are emitted.
    backing_field_name: Option<String>,
}
pub struct ClassBuilder {
name: String,
namespace: Option<String>,
flags: u32,
extends: Option<CodedIndex>,
implements: Vec<CodedIndex>,
fields: Vec<FieldDefinition>,
methods: Vec<MethodBuilder>,
properties: Vec<PropertyDefinition>,
generate_default_ctor: bool,
nested_types: Vec<ClassBuilder>,
}
impl ClassBuilder {
#[must_use]
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
namespace: None,
flags: 0x0010_0001, extends: None,
implements: Vec::new(),
fields: Vec::new(),
methods: Vec::new(),
properties: Vec::new(),
generate_default_ctor: false,
nested_types: Vec::new(),
}
}
#[must_use]
pub fn namespace(mut self, namespace: &str) -> Self {
self.namespace = Some(namespace.to_string());
self
}
#[must_use]
pub fn public(mut self) -> Self {
self.flags = (self.flags & !0x0000_0007) | 0x0000_0001; self
}
#[must_use]
pub fn internal(mut self) -> Self {
self.flags &= !0x0000_0007; self
}
#[must_use]
pub fn sealed(mut self) -> Self {
self.flags |= TypeAttributes::SEALED;
self
}
#[must_use]
pub fn abstract_class(mut self) -> Self {
self.flags |= TypeAttributes::ABSTRACT;
self
}
#[must_use]
pub fn inherits(mut self, base_class: CodedIndex) -> Self {
self.extends = Some(base_class);
self
}
#[must_use]
pub fn implements(mut self, interface: CodedIndex) -> Self {
self.implements.push(interface);
self
}
#[must_use]
pub fn field(mut self, name: &str, field_type: TypeSignature) -> Self {
self.fields.push(FieldDefinition {
name: name.to_string(),
field_type,
attributes: 0x0001, });
self
}
#[must_use]
pub fn public_field(mut self, name: &str, field_type: TypeSignature) -> Self {
self.fields.push(FieldDefinition {
name: name.to_string(),
field_type,
attributes: 0x0006, });
self
}
#[must_use]
pub fn static_field(mut self, name: &str, field_type: TypeSignature) -> Self {
self.fields.push(FieldDefinition {
name: name.to_string(),
field_type,
attributes: 0x0001 | 0x0010, });
self
}
#[must_use]
pub fn method<F>(mut self, builder_fn: F) -> Self
where
F: FnOnce(MethodBuilder) -> MethodBuilder,
{
let method_builder = builder_fn(MethodBuilder::new("method"));
self.methods.push(method_builder);
self
}
#[must_use]
pub fn auto_property(mut self, name: &str, property_type: TypeSignature) -> Self {
let backing_field_name = format!("<{name}>k__BackingField");
self.properties.push(PropertyDefinition {
name: name.to_string(),
property_type: property_type.clone(),
has_getter: true,
has_setter: true,
backing_field_name: Some(backing_field_name.clone()),
});
self.fields.push(FieldDefinition {
name: backing_field_name,
field_type: property_type,
attributes: 0x0001, });
self
}
#[must_use]
pub fn readonly_property(mut self, name: &str, property_type: TypeSignature) -> Self {
let backing_field_name = format!("<{name}>k__BackingField");
self.properties.push(PropertyDefinition {
name: name.to_string(),
property_type: property_type.clone(),
has_getter: true,
has_setter: false,
backing_field_name: Some(backing_field_name.clone()),
});
self.fields.push(FieldDefinition {
name: backing_field_name,
field_type: property_type,
attributes: 0x0001 | 0x0020, });
self
}
#[must_use]
pub fn default_constructor(mut self) -> Self {
self.generate_default_ctor = true;
self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<ChangeRefRc> {
if (self.flags & TypeAttributes::SEALED) != 0
&& (self.flags & TypeAttributes::ABSTRACT) != 0
{
return Err(Error::ModificationInvalid(
"Class cannot be both sealed and abstract (mutually exclusive flags per ECMA-335)"
.to_string(),
));
}
let typedef_ref = TypeDefBuilder::new()
.name(&self.name)
.namespace(self.namespace.as_deref().unwrap_or(""))
.flags(self.flags)
.extends(
self.extends
.unwrap_or_else(|| CodedIndex::null(CodedIndexType::TypeDefOrRef)),
) .build(assembly)?;
let mut field_refs: Vec<(String, ChangeRefRc)> = Vec::new();
for field_def in &self.fields {
let field_sig = SignatureField {
modifiers: Vec::new(),
base: field_def.field_type.clone(),
};
let sig_bytes = encode_field_signature(&field_sig)?;
let field_ref = FieldBuilder::new()
.name(&field_def.name)
.flags(field_def.attributes)
.signature(&sig_bytes)
.build(assembly)?;
field_refs.push((field_def.name.clone(), field_ref));
}
if self.generate_default_ctor {
let base_ctor_token = Token::new(0x0A00_0001);
MethodBuilder::constructor()
.implementation(move |body| {
body.implementation(move |asm| {
asm.ldarg_0()? .call(base_ctor_token)? .ret()?;
Ok(())
})
})
.build(assembly)?;
}
for prop_def in &self.properties {
if let Some(backing_field_name) = &prop_def.backing_field_name {
let backing_field_ref = field_refs
.iter()
.find(|(name, _)| name == backing_field_name)
.map(|(_, change_ref)| change_ref.clone())
.ok_or_else(|| {
Error::ModificationInvalid(format!(
"Backing field {backing_field_name} not found"
))
})?;
let backing_field_token =
backing_field_ref.placeholder_token().ok_or_else(|| {
Error::ModificationInvalid(
"Failed to get placeholder token for backing field".to_string(),
)
})?;
if prop_def.has_getter {
let getter_field_token = backing_field_token; MethodBuilder::property_getter(&prop_def.name, prop_def.property_type.clone())
.implementation(move |body| {
body.implementation(move |asm| {
asm.ldarg_0()? .ldfld(getter_field_token)? .ret()?;
Ok(())
})
})
.build(assembly)?;
}
if prop_def.has_setter {
let setter_field_token = backing_field_token; MethodBuilder::property_setter(&prop_def.name, prop_def.property_type.clone())
.implementation(move |body| {
body.implementation(move |asm| {
asm.ldarg_0()? .ldarg_1()? .stfld(setter_field_token)? .ret()?;
Ok(())
})
})
.build(assembly)?;
}
}
}
for method_builder in self.methods {
method_builder.build(assembly)?;
}
for interface_index in self.implements {
InterfaceImplBuilder::new()
.class(typedef_ref.placeholder())
.interface(interface_index)
.build(assembly)?;
}
Ok(typedef_ref)
}
}
impl Default for ClassBuilder {
fn default() -> Self {
Self::new("DefaultClass")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::{changes::ChangeRefKind, CilAssembly},
        metadata::{cilassemblyview::CilAssemblyView, signatures::TypeSignature, tables::TableId},
    };
    use std::path::PathBuf;

    /// Opens the `WindowsBase.dll` sample fixture as a mutable [`CilAssembly`].
    fn get_test_assembly() -> Result<CilAssembly> {
        let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let sample = manifest_dir.join("tests/samples/WindowsBase.dll");
        let view = CilAssemblyView::from_path(&sample)?;
        Ok(CilAssembly::new(view))
    }

    /// Asserts that `class_ref` refers to a newly added TypeDef row.
    fn assert_is_typedef(class_ref: &ChangeRefRc) {
        assert_eq!(class_ref.kind(), ChangeRefKind::TableRow(TableId::TypeDef));
    }

    #[test]
    fn test_simple_class() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("SimpleClass")
            .public()
            .field("value", TypeSignature::I4)
            .default_constructor()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_namespace() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("MyClass")
            .namespace("MyCompany.MyProduct")
            .public()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_auto_properties() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("Person")
            .public()
            .auto_property("Name", TypeSignature::String)
            .auto_property("Age", TypeSignature::I4)
            .default_constructor()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_methods() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("Calculator")
            .public()
            .field("lastResult", TypeSignature::I4)
            .method(|_placeholder| {
                MethodBuilder::new("Add")
                    .public()
                    .static_method()
                    .parameter("a", TypeSignature::I4)
                    .parameter("b", TypeSignature::I4)
                    .returns(TypeSignature::I4)
                    .implementation(|body| {
                        body.implementation(|asm| {
                            asm.ldarg_0()?.ldarg_1()?.add()?.ret()?;
                            Ok(())
                        })
                    })
            })
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_sealed_class() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("SealedClass")
            .public()
            .sealed()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_abstract_class() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("AbstractBase")
            .public()
            .abstract_class()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_static_fields() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("Configuration")
            .public()
            .static_field("instance", TypeSignature::Object)
            .public_field("settings", TypeSignature::String)
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_readonly_property() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let built = ClassBuilder::new("Circle")
            .public()
            .field("radius", TypeSignature::R8)
            .readonly_property("Diameter", TypeSignature::R8)
            .default_constructor()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_inheritance() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let base_type = CodedIndex::new(TableId::TypeRef, 1, CodedIndexType::TypeDefOrRef);
        let built = ClassBuilder::new("DerivedClass")
            .public()
            .inherits(base_type)
            .default_constructor()
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_class_with_interfaces() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let first = CodedIndex::new(TableId::TypeRef, 2, CodedIndexType::TypeDefOrRef);
        let second = CodedIndex::new(TableId::TypeRef, 3, CodedIndexType::TypeDefOrRef);
        let built = ClassBuilder::new("Implementation")
            .public()
            .implements(first)
            .implements(second)
            .build(&mut assembly)?;
        assert_is_typedef(&built);
        Ok(())
    }

    #[test]
    fn test_sealed_and_abstract_mutually_exclusive() {
        let mut assembly = get_test_assembly().unwrap();
        let outcome = ClassBuilder::new("InvalidClass")
            .public()
            .sealed()
            .abstract_class()
            .build(&mut assembly);
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("sealed and abstract"),
            "Error should mention sealed and abstract conflict: {err}"
        );
    }
}