use crate::{Error, Result};
/// Raw single-byte tags identifying native (unmanaged) marshalling types.
///
/// NOTE(review): the values appear to match the `NATIVE_TYPE_*` constants of ECMA-335
/// Partition II §23.4 (CoreCLR `CorNativeType`); re-check against the spec when adding
/// entries. Not every byte in `0x00..=0x30` has a constant here (gaps are marked below).
// The module keeps an upper-case, constant-group style name (so uses read like
// `NATIVE_TYPE::I4`), which would otherwise trip the `non_snake_case` lint.
#[allow(non_snake_case)]
pub mod NATIVE_TYPE {
pub const END: u8 = 0x00;
pub const VOID: u8 = 0x01;
pub const BOOLEAN: u8 = 0x02;
pub const I1: u8 = 0x03;
pub const U1: u8 = 0x04;
pub const I2: u8 = 0x05;
pub const U2: u8 = 0x06;
pub const I4: u8 = 0x07;
pub const U4: u8 = 0x08;
pub const I8: u8 = 0x09;
pub const U8: u8 = 0x0a;
pub const R4: u8 = 0x0b;
pub const R8: u8 = 0x0c;
pub const SYSCHAR: u8 = 0x0d;
pub const VARIANT: u8 = 0x0e;
pub const CURRENCY: u8 = 0x0f;
pub const PTR: u8 = 0x10;
pub const DECIMAL: u8 = 0x11;
pub const DATE: u8 = 0x12;
pub const BSTR: u8 = 0x13;
pub const LPSTR: u8 = 0x14;
pub const LPWSTR: u8 = 0x15;
pub const LPTSTR: u8 = 0x16;
pub const FIXEDSYSSTRING: u8 = 0x17;
pub const OBJECTREF: u8 = 0x18;
pub const IUNKNOWN: u8 = 0x19;
pub const IDISPATCH: u8 = 0x1a;
pub const STRUCT: u8 = 0x1b;
pub const INTERFACE: u8 = 0x1c;
pub const SAFEARRAY: u8 = 0x1d;
pub const FIXEDARRAY: u8 = 0x1e;
pub const INT: u8 = 0x1f;
pub const UINT: u8 = 0x20;
pub const NESTEDSTRUCT: u8 = 0x21;
pub const BYVALSTR: u8 = 0x22;
pub const ANSIBSTR: u8 = 0x23;
pub const TBSTR: u8 = 0x24;
pub const VARIANTBOOL: u8 = 0x25;
pub const FUNC: u8 = 0x26;
// 0x27 has no constant in this table.
pub const ASANY: u8 = 0x28;
// 0x29 has no constant in this table.
pub const ARRAY: u8 = 0x2a;
pub const LPSTRUCT: u8 = 0x2b;
pub const CUSTOMMARSHALER: u8 = 0x2c;
pub const ERROR: u8 = 0x2d;
pub const IINSPECTABLE: u8 = 0x2e;
pub const HSTRING: u8 = 0x2f;
pub const LPUTF8STR: u8 = 0x30;
/// Upper bound value 0x50, well above the largest tag defined here (`LPUTF8STR`, 0x30).
///
/// NOTE(review): presumably a sentinel / range limit checked by the descriptor parser —
/// confirm where it is used.
pub const MAX: u8 = 0x50;
}
/// Variant-type (`VT_*`-style) codes as 16-bit values.
///
/// NOTE(review): the numbering appears to follow the COM `VARENUM` enumeration; confirm
/// against it when adding entries. Codes 15, 32–35 and 39–63 have no constant here.
///
/// `VECTOR` (0x1000), `ARRAY` (0x2000) and `BYREF` (0x4000) are single bits above
/// `TYPEMASK` (0xfff), so `vt & TYPEMASK` strips them off; presumably they are modifier
/// flags OR-ed onto one of the base codes.
// Upper-case, constant-group style module name; the lint would otherwise reject it.
#[allow(non_snake_case)]
pub mod VARIANT_TYPE {
pub const EMPTY: u16 = 0;
pub const NULL: u16 = 1;
pub const I2: u16 = 2;
pub const I4: u16 = 3;
pub const R4: u16 = 4;
pub const R8: u16 = 5;
pub const CY: u16 = 6;
pub const DATE: u16 = 7;
pub const BSTR: u16 = 8;
pub const DISPATCH: u16 = 9;
pub const ERROR: u16 = 10;
pub const BOOL: u16 = 11;
pub const VARIANT: u16 = 12;
pub const UNKNOWN: u16 = 13;
pub const DECIMAL: u16 = 14;
// 15 has no constant in this table.
pub const I1: u16 = 16;
pub const UI1: u16 = 17;
pub const UI2: u16 = 18;
pub const UI4: u16 = 19;
pub const I8: u16 = 20;
pub const UI8: u16 = 21;
pub const INT: u16 = 22;
pub const UINT: u16 = 23;
pub const VOID: u16 = 24;
pub const HRESULT: u16 = 25;
pub const PTR: u16 = 26;
pub const SAFEARRAY: u16 = 27;
pub const CARRAY: u16 = 28;
pub const USERDEFINED: u16 = 29;
pub const LPSTR: u16 = 30;
pub const LPWSTR: u16 = 31;
// 32..=35 have no constants in this table.
pub const RECORD: u16 = 36;
pub const INT_PTR: u16 = 37;
pub const UINT_PTR: u16 = 38;
// 39..=63 have no constants in this table.
pub const FILETIME: u16 = 64;
pub const BLOB: u16 = 65;
pub const STREAM: u16 = 66;
pub const STORAGE: u16 = 67;
pub const STREAMED_OBJECT: u16 = 68;
pub const STORED_OBJECT: u16 = 69;
pub const BLOB_OBJECT: u16 = 70;
pub const CF: u16 = 71;
pub const CLSID: u16 = 72;
// Single-bit values above TYPEMASK (see module docs).
pub const VECTOR: u16 = 0x1000;
pub const ARRAY: u16 = 0x2000;
pub const BYREF: u16 = 0x4000;
/// Mask covering the low 12 bits, i.e. everything below the `VECTOR`/`ARRAY`/`BYREF` bits.
pub const TYPEMASK: u16 = 0xfff;
}
/// Marshalling information: one primary [`NativeType`] plus zero or more additional ones.
///
/// `Display` renders it as `primary` or `primary + [a, b, ...]`, and
/// [`MarshallingInfo::validate`] checks every contained type.
#[derive(Debug, PartialEq, Clone)]
pub struct MarshallingInfo {
/// The main native type of this descriptor.
pub primary_type: NativeType,
/// Extra native types rendered after the primary one; may be empty.
///
/// NOTE(review): what populates this (e.g. trailing data after the primary descriptor)
/// is not visible here — confirm with the parser.
pub additional_types: Vec<NativeType>,
}
impl std::fmt::Display for MarshallingInfo {
    /// Writes the primary type and, only when there are any, the additional types as a
    /// bracketed, comma-separated suffix: `primary + [a, b, ...]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.primary_type)?;
        match self.additional_types.split_first() {
            // No extras: the primary type alone is the whole rendering.
            None => Ok(()),
            Some((head, tail)) => {
                write!(f, " + [{head}")?;
                tail.iter().try_for_each(|extra| write!(f, ", {extra}"))?;
                f.write_str("]")
            }
        }
    }
}
impl MarshallingInfo {
    /// Validates the primary type first, then every additional type in order, stopping at
    /// the first failure.
    ///
    /// # Errors
    ///
    /// Returns `Error::MarshallingError` for the first descriptor that sets a later
    /// optional parameter without the earlier one it depends on. That is
    /// `Struct { class_size }` without `packing_size`, or `Array { num_element }` without
    /// `num_param`. The check applies at any nesting depth inside `Array`/`FixedArray`
    /// element types or `Ptr` pointee types.
    pub fn validate(&self) -> Result<()> {
        std::iter::once(&self.primary_type)
            .chain(self.additional_types.iter())
            .try_for_each(Self::validate_native_type)
    }

    /// Checks one descriptor together with the chain of types nested inside it.
    ///
    /// `Array`, `FixedArray` and `Ptr` each nest at most one inner type, so the nesting is
    /// a simple chain. It is followed with a loop rather than recursive calls.
    fn validate_native_type(native_type: &NativeType) -> Result<()> {
        let mut current = native_type;
        loop {
            let inner: Option<&NativeType> = match current {
                // Per the error text, these optional parameters are encoded sequentially,
                // so a later one cannot be present while an earlier one is missing.
                NativeType::Struct {
                    packing_size: None,
                    class_size: Some(_),
                } => {
                    return Err(Error::MarshallingError(
                        "Struct: class_size cannot be set without packing_size \
                         (sequential encoding constraint)"
                            .to_string(),
                    ));
                }
                NativeType::Array {
                    num_param: None,
                    num_element: Some(_),
                    ..
                } => {
                    return Err(Error::MarshallingError(
                        "Array: num_element cannot be set without num_param \
                         (sequential encoding constraint)"
                            .to_string(),
                    ));
                }
                NativeType::Array { element_type, .. } => Some(&**element_type),
                NativeType::FixedArray { element_type, .. } => element_type.as_deref(),
                NativeType::Ptr { ref_type } => ref_type.as_deref(),
                // Every other variant nests nothing and carries no ordering constraint.
                _ => None,
            };
            match inner {
                Some(next) => current = next,
                None => return Ok(()),
            }
        }
    }
}
/// A native (unmanaged) type descriptor used when marshalling values.
///
/// Field-less variants are plain tags. The remaining variants carry the optional
/// parameters that their `Display` output shows, and [`NativeType::has_parameters`] tells
/// the two groups apart. Each variant presumably mirrors the same-named tag in
/// [`NATIVE_TYPE`]; the encoder/decoder is not in this file, so confirm the mapping there.
#[derive(Debug, PartialEq, Clone)]
pub enum NativeType {
    Void,
    Boolean,
    I1,
    U1,
    I2,
    U2,
    I4,
    U4,
    I8,
    U8,
    R4,
    R8,
    SysChar,
    Variant,
    Currency,
    Decimal,
    Date,
    Int,
    UInt,
    Error,
    BStr,
    LPStr {
        /// Optional size-parameter index, rendered as `(size_param=N)`.
        size_param_index: Option<u32>,
    },
    LPWStr {
        /// Optional size-parameter index, rendered as `(size_param=N)`.
        size_param_index: Option<u32>,
    },
    LPTStr {
        /// Optional size-parameter index, rendered as `(size_param=N)`.
        size_param_index: Option<u32>,
    },
    LPUtf8Str {
        /// Optional size-parameter index, rendered as `(size_param=N)`.
        size_param_index: Option<u32>,
    },
    FixedSysString {
        /// Fixed size, rendered as `[N]`.
        size: u32,
    },
    AnsiBStr,
    TBStr,
    ByValStr {
        /// Fixed size, rendered as `[N]`.
        size: u32,
    },
    VariantBool,
    FixedArray {
        /// Fixed size, rendered as `[N]`.
        size: u32,
        /// Optional element type, rendered as ` of <type>`.
        element_type: Option<Box<NativeType>>,
    },
    Array {
        /// Element type (always present); validated along with the array.
        element_type: Box<NativeType>,
        /// Optional parameter index, rendered as `param=N`.
        num_param: Option<u32>,
        /// Optional element count, rendered as `count=N`. It must not be set without
        /// `num_param` (checked by [`MarshallingInfo::validate`]).
        num_element: Option<u32>,
    },
    SafeArray {
        /// Variant type code, rendered as `vt=0x` plus four uppercase hex digits.
        /// Presumably a [`VARIANT_TYPE`] value (possibly with its flag bits), which is not
        /// checked here.
        variant_type: u16,
        /// Optional user-defined type name, rendered as ` of <name>`.
        user_defined_name: Option<String>,
    },
    Ptr {
        /// Optional pointee type, rendered as ` to <type>`.
        ref_type: Option<Box<NativeType>>,
    },
    IUnknown,
    IDispatch,
    IInspectable,
    Interface {
        /// Optional IID-parameter index, rendered as `(iid_param=N)`.
        iid_param_index: Option<u32>,
    },
    Struct {
        /// Optional packing size, rendered as `pack=N`.
        packing_size: Option<u8>,
        /// Optional class size, rendered as `size=N`. It must not be set without
        /// `packing_size` (checked by [`MarshallingInfo::validate`]).
        class_size: Option<u32>,
    },
    NestedStruct,
    LPStruct,
    CustomMarshaler {
        /// GUID text; omitted from `Display` output when empty.
        guid: String,
        /// Native type name; omitted from `Display` output when empty.
        native_type_name: String,
        /// Cookie string; omitted from `Display` output when empty.
        cookie: String,
        /// Marshaler type reference; always rendered first.
        type_reference: String,
    },
    ObjectRef,
    Func,
    AsAny,
    HString,
    End,
}

impl std::fmt::Display for NativeType {
    /// Renders a compact lowercase description, e.g. `i4`, `lpstr(size_param=2)`,
    /// `array of i4(param=1, count=8)`, `ptr to u1` or `safearray(vt=0x0003)`.
    ///
    /// Optional parameters that are `None` are left out rather than printed as
    /// placeholders. For custom marshalers, empty strings other than the type reference
    /// are left out as well.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Appends `(key=value, ...)` listing only the arguments that are present. It writes
        // nothing at all when every argument is `None`.
        fn write_args(
            f: &mut std::fmt::Formatter<'_>,
            args: &[(&str, Option<u32>)],
        ) -> std::fmt::Result {
            let mut present = args
                .iter()
                .filter_map(|&(key, value)| value.map(|v| (key, v)));
            if let Some((key, value)) = present.next() {
                write!(f, "({key}={value}")?;
                for (key, value) in present {
                    write!(f, ", {key}={value}")?;
                }
                f.write_str(")")?;
            }
            Ok(())
        }

        match self {
            NativeType::Void => f.write_str("void"),
            NativeType::Boolean => f.write_str("bool"),
            NativeType::I1 => f.write_str("i1"),
            NativeType::U1 => f.write_str("u1"),
            NativeType::I2 => f.write_str("i2"),
            NativeType::U2 => f.write_str("u2"),
            NativeType::I4 => f.write_str("i4"),
            NativeType::U4 => f.write_str("u4"),
            NativeType::I8 => f.write_str("i8"),
            NativeType::U8 => f.write_str("u8"),
            NativeType::R4 => f.write_str("r4"),
            NativeType::R8 => f.write_str("r8"),
            NativeType::SysChar => f.write_str("syschar"),
            NativeType::Variant => f.write_str("variant"),
            NativeType::Currency => f.write_str("currency"),
            NativeType::Decimal => f.write_str("decimal"),
            NativeType::Date => f.write_str("date"),
            NativeType::Int => f.write_str("int"),
            NativeType::UInt => f.write_str("uint"),
            NativeType::Error => f.write_str("error"),
            NativeType::BStr => f.write_str("bstr"),
            NativeType::LPStr { size_param_index } => {
                f.write_str("lpstr")?;
                write_args(f, &[("size_param", *size_param_index)])
            }
            NativeType::LPWStr { size_param_index } => {
                f.write_str("lpwstr")?;
                write_args(f, &[("size_param", *size_param_index)])
            }
            NativeType::LPTStr { size_param_index } => {
                f.write_str("lptstr")?;
                write_args(f, &[("size_param", *size_param_index)])
            }
            NativeType::LPUtf8Str { size_param_index } => {
                f.write_str("lputf8str")?;
                write_args(f, &[("size_param", *size_param_index)])
            }
            NativeType::FixedSysString { size } => write!(f, "fixed sysstring[{size}]"),
            NativeType::AnsiBStr => f.write_str("ansi bstr"),
            NativeType::TBStr => f.write_str("tbstr"),
            NativeType::ByValStr { size } => write!(f, "byvalstr[{size}]"),
            NativeType::VariantBool => f.write_str("variant bool"),
            NativeType::FixedArray { size, element_type } => {
                write!(f, "fixed array[{size}]")?;
                element_type
                    .as_ref()
                    .map_or(Ok(()), |inner| write!(f, " of {inner}"))
            }
            NativeType::Array {
                element_type,
                num_param,
                num_element,
            } => {
                write!(f, "array of {element_type}")?;
                write_args(f, &[("param", *num_param), ("count", *num_element)])
            }
            NativeType::SafeArray {
                variant_type,
                user_defined_name,
            } => {
                write!(f, "safearray(vt=0x{variant_type:04X})")?;
                user_defined_name
                    .as_ref()
                    .map_or(Ok(()), |name| write!(f, " of {name}"))
            }
            NativeType::Ptr { ref_type } => {
                f.write_str("ptr")?;
                ref_type
                    .as_ref()
                    .map_or(Ok(()), |pointee| write!(f, " to {pointee}"))
            }
            NativeType::IUnknown => f.write_str("iunknown"),
            NativeType::IDispatch => f.write_str("idispatch"),
            NativeType::IInspectable => f.write_str("iinspectable"),
            NativeType::Interface { iid_param_index } => {
                f.write_str("interface")?;
                write_args(f, &[("iid_param", *iid_param_index)])
            }
            NativeType::Struct {
                packing_size,
                class_size,
            } => {
                f.write_str("struct")?;
                write_args(
                    f,
                    &[("pack", packing_size.map(u32::from)), ("size", *class_size)],
                )
            }
            NativeType::NestedStruct => f.write_str("nested struct"),
            NativeType::LPStruct => f.write_str("lpstruct"),
            NativeType::CustomMarshaler {
                guid,
                native_type_name,
                cookie,
                type_reference,
            } => {
                write!(f, "custom marshaler({type_reference}")?;
                // Optional fields are listed in a fixed order and skipped when empty.
                for (key, value) in [("guid", guid), ("native", native_type_name), ("cookie", cookie)] {
                    if !value.is_empty() {
                        write!(f, ", {key}={value}")?;
                    }
                }
                f.write_str(")")
            }
            NativeType::ObjectRef => f.write_str("objectref"),
            NativeType::Func => f.write_str("func"),
            NativeType::AsAny => f.write_str("asany"),
            NativeType::HString => f.write_str("hstring"),
            NativeType::End => f.write_str("end"),
        }
    }
}

impl NativeType {
    /// Returns `true` for variants that carry fields (size or parameter indices, element or
    /// pointee types, layout values, custom-marshaler strings). Returns `false` for plain
    /// tag variants.
    ///
    /// The match is exhaustive, so adding a variant forces an explicit decision here.
    #[must_use]
    pub fn has_parameters(&self) -> bool {
        match self {
            NativeType::LPStr { .. }
            | NativeType::LPWStr { .. }
            | NativeType::LPTStr { .. }
            | NativeType::LPUtf8Str { .. }
            | NativeType::FixedSysString { .. }
            | NativeType::ByValStr { .. }
            | NativeType::FixedArray { .. }
            | NativeType::Array { .. }
            | NativeType::SafeArray { .. }
            | NativeType::Ptr { .. }
            | NativeType::Interface { .. }
            | NativeType::Struct { .. }
            | NativeType::CustomMarshaler { .. } => true,
            NativeType::Void
            | NativeType::Boolean
            | NativeType::I1
            | NativeType::U1
            | NativeType::I2
            | NativeType::U2
            | NativeType::I4
            | NativeType::U4
            | NativeType::I8
            | NativeType::U8
            | NativeType::R4
            | NativeType::R8
            | NativeType::SysChar
            | NativeType::Variant
            | NativeType::Currency
            | NativeType::Decimal
            | NativeType::Date
            | NativeType::Int
            | NativeType::UInt
            | NativeType::Error
            | NativeType::BStr
            | NativeType::AnsiBStr
            | NativeType::TBStr
            | NativeType::VariantBool
            | NativeType::IUnknown
            | NativeType::IDispatch
            | NativeType::IInspectable
            | NativeType::NestedStruct
            | NativeType::LPStruct
            | NativeType::ObjectRef
            | NativeType::Func
            | NativeType::AsAny
            | NativeType::HString
            | NativeType::End => false,
        }
    }
}
/// Judging by its name, presumably the maximum nesting depth allowed for native-type
/// descriptors (e.g. chains of `Ptr`/`Array` element types).
///
/// NOTE(review): none of the code shown alongside this constant reads it. It is presumably
/// enforced by the descriptor parser/encoder elsewhere in the crate; confirm there.
pub const MAX_RECURSION_DEPTH: usize = 50;
#[cfg(test)]
mod tests {
    use super::*;

    /// An invalid struct descriptor (class size without packing size) used to check that
    /// validation reaches nested positions.
    fn bad_struct() -> NativeType {
        NativeType::Struct {
            packing_size: None,
            class_size: Some(8),
        }
    }

    #[test]
    fn test_native_type_display_simple_types() {
        assert_eq!(format!("{}", NativeType::Void), "void");
        assert_eq!(format!("{}", NativeType::Boolean), "bool");
        assert_eq!(format!("{}", NativeType::I4), "i4");
        assert_eq!(format!("{}", NativeType::Int), "int");
        assert_eq!(format!("{}", NativeType::BStr), "bstr");
    }

    #[test]
    fn test_native_type_display_with_params() {
        assert_eq!(
            format!(
                "{}",
                NativeType::LPStr {
                    size_param_index: Some(5)
                }
            ),
            "lpstr(size_param=5)"
        );
        assert_eq!(
            format!(
                "{}",
                NativeType::LPStr {
                    size_param_index: None
                }
            ),
            "lpstr"
        );
    }

    #[test]
    fn test_native_type_display_array() {
        assert_eq!(
            format!(
                "{}",
                NativeType::Array {
                    element_type: Box::new(NativeType::I4),
                    num_param: Some(3),
                    num_element: Some(10),
                }
            ),
            "array of i4(param=3, count=10)"
        );
        assert_eq!(
            NativeType::Array {
                element_type: Box::new(NativeType::I4),
                num_param: None,
                num_element: None,
            }
            .to_string(),
            "array of i4"
        );
    }

    #[test]
    fn test_native_type_display_struct() {
        assert_eq!(
            format!(
                "{}",
                NativeType::Struct {
                    packing_size: Some(4),
                    class_size: Some(128),
                }
            ),
            "struct(pack=4, size=128)"
        );
        assert_eq!(
            NativeType::Struct {
                packing_size: None,
                class_size: None,
            }
            .to_string(),
            "struct"
        );
    }

    #[test]
    fn test_native_type_display_optional_suffixes() {
        assert_eq!(NativeType::Ptr { ref_type: None }.to_string(), "ptr");
        assert_eq!(
            NativeType::Ptr {
                ref_type: Some(Box::new(NativeType::R8))
            }
            .to_string(),
            "ptr to r8"
        );
        assert_eq!(
            NativeType::FixedArray {
                size: 4,
                element_type: Some(Box::new(NativeType::U1)),
            }
            .to_string(),
            "fixed array[4] of u1"
        );
        assert_eq!(
            NativeType::SafeArray {
                variant_type: VARIANT_TYPE::ARRAY | VARIANT_TYPE::VARIANT,
                user_defined_name: Some("Foo".to_string()),
            }
            .to_string(),
            "safearray(vt=0x200C) of Foo"
        );
        assert_eq!(
            NativeType::Interface {
                iid_param_index: Some(2)
            }
            .to_string(),
            "interface(iid_param=2)"
        );
        assert_eq!(
            NativeType::CustomMarshaler {
                guid: String::new(),
                native_type_name: "N".to_string(),
                cookie: String::new(),
                type_reference: "T".to_string(),
            }
            .to_string(),
            "custom marshaler(T, native=N)"
        );
    }

    #[test]
    fn test_marshalling_info_display() {
        let info = MarshallingInfo {
            primary_type: NativeType::I4,
            additional_types: vec![],
        };
        assert_eq!(format!("{info}"), "i4");
        let info_with_additional = MarshallingInfo {
            primary_type: NativeType::LPStr {
                size_param_index: None,
            },
            additional_types: vec![NativeType::Boolean, NativeType::I4],
        };
        assert_eq!(format!("{info_with_additional}"), "lpstr + [bool, i4]");
    }

    #[test]
    fn test_marshalling_info_validate_valid() {
        let valid = MarshallingInfo {
            primary_type: NativeType::Struct {
                packing_size: Some(4),
                class_size: Some(128),
            },
            additional_types: vec![],
        };
        assert!(valid.validate().is_ok());
        let valid2 = MarshallingInfo {
            primary_type: NativeType::Struct {
                packing_size: Some(4),
                class_size: None,
            },
            additional_types: vec![],
        };
        assert!(valid2.validate().is_ok());
        let valid3 = MarshallingInfo {
            primary_type: NativeType::Struct {
                packing_size: None,
                class_size: None,
            },
            additional_types: vec![],
        };
        assert!(valid3.validate().is_ok());
    }

    #[test]
    fn test_marshalling_info_validate_invalid_struct() {
        let invalid = MarshallingInfo {
            primary_type: bad_struct(),
            additional_types: vec![],
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_marshalling_info_validate_invalid_array() {
        let invalid = MarshallingInfo {
            primary_type: NativeType::Array {
                element_type: Box::new(NativeType::I4),
                num_param: None,
                num_element: Some(10),
            },
            additional_types: vec![],
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_marshalling_info_validate_nested_and_additional() {
        let wrap = |primary_type: NativeType| MarshallingInfo {
            primary_type,
            additional_types: vec![],
        };
        // Validation must look through pointee and element types.
        assert!(wrap(NativeType::Ptr {
            ref_type: Some(Box::new(bad_struct()))
        })
        .validate()
        .is_err());
        assert!(wrap(NativeType::FixedArray {
            size: 2,
            element_type: Some(Box::new(bad_struct())),
        })
        .validate()
        .is_err());
        assert!(wrap(NativeType::Array {
            element_type: Box::new(bad_struct()),
            num_param: Some(1),
            num_element: None,
        })
        .validate()
        .is_err());
        // ...and through every additional type, not just the primary one.
        let via_additional = MarshallingInfo {
            primary_type: NativeType::I4,
            additional_types: vec![NativeType::Boolean, bad_struct()],
        };
        assert!(via_additional.validate().is_err());
        // A well-formed nested chain passes.
        assert!(wrap(NativeType::Array {
            element_type: Box::new(NativeType::Ptr {
                ref_type: Some(Box::new(NativeType::I4))
            }),
            num_param: Some(1),
            num_element: None,
        })
        .validate()
        .is_ok());
    }

    #[test]
    fn test_has_parameters() {
        assert!(!NativeType::I4.has_parameters());
        assert!(!NativeType::End.has_parameters());
        assert!(!NativeType::HString.has_parameters());
        assert!(NativeType::LPStr {
            size_param_index: None
        }
        .has_parameters());
        assert!(NativeType::Struct {
            packing_size: None,
            class_size: None,
        }
        .has_parameters());
        assert!(NativeType::Ptr { ref_type: None }.has_parameters());
    }
}