use std::collections::HashMap;
use bon::Builder;
use rkyv::{Archive, Deserialize, Serialize};
use uuid::Uuid;
/// Top-level metadata database: one [`Metadata`] set per architecture
/// (`x86` and `x64`).
///
/// Field order is significant: it determines the rkyv archive layout and the
/// order of fields in `serde` output (when the `serde` feature is enabled).
#[derive(Debug, Default, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
pub struct Database {
/// Metadata for the `x86` architecture.
pub x86: Metadata,
/// Metadata for the `x64` architecture.
pub x64: Metadata,
}
/// Functions and interfaces for one architecture, plus lookup tables into them.
///
/// Build it with [`Metadata::builder`]. The builder's custom `build` sorts
/// `functions` and `interfaces` by name. It then fills the three lookup maps
/// with indices into those sorted vectors.
///
/// The lookup maps have no builder setters and are skipped by `serde`.
/// They carry no rkyv skip attribute, so they are part of the rkyv archive.
/// `Metadata::default()` leaves every vector and map empty.
#[derive(Debug, Default, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
#[builder(finish_fn(name = build_impl, vis = ""))]
pub struct Metadata {
/// Free functions; sorted by name when built through the builder.
#[builder(default, with = FromIterator::from_iter)]
pub functions: Vec<Function>,
/// Interfaces; sorted by name when built through the builder.
#[builder(default, with = FromIterator::from_iter)]
pub interfaces: Vec<Interface>,
/// Function name -> index into `functions`.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub functions_by_name: HashMap<String, usize>,
/// Interface name -> index into `interfaces`.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub interfaces_by_name: HashMap<String, usize>,
/// Interface UUID -> index into `interfaces`.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub interfaces_by_uuid: HashMap<Uuid, usize>,
}
impl<S: metadata_builder::State> MetadataBuilder<S>
where
S: metadata_builder::IsComplete,
{
pub fn build(self) -> Metadata {
let mut metadata = self.build_impl();
metadata.functions.sort_by(|a, b| a.name.cmp(&b.name));
metadata.interfaces.sort_by(|a, b| a.name.cmp(&b.name));
metadata.functions_by_name = metadata
.functions
.iter()
.enumerate()
.map(|(index, function)| (function.name.clone(), index))
.collect();
metadata.interfaces_by_name = metadata
.interfaces
.iter()
.enumerate()
.map(|(index, interface)| (interface.name.clone(), index))
.collect();
metadata.interfaces_by_uuid = metadata
.interfaces
.iter()
.enumerate()
.map(|(index, interface)| (interface.uuid, index))
.collect();
metadata
}
}
/// A function signature: name, parameters, buffer descriptions and return type.
///
/// Build it with [`Function::builder`]. The builder's custom `build` derives the
/// three `*_indices` fields and every [`Buffer::position`] from `parameters`
/// and `buffers`. The derived fields have no builder setters and are skipped
/// by `serde`, but they are stored in the rkyv archive.
#[derive(Debug, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
#[builder(finish_fn(name = build_impl, vis = ""))]
pub struct Function {
/// Function name.
#[builder(into)]
pub name: String,
/// Parameters; buffers and the derived index tables refer to them by their
/// position in this vector.
#[builder(default)]
pub parameters: Vec<Parameter>,
/// Buffer descriptions attached to parameters of this function.
#[builder(default)]
pub buffers: Vec<Buffer>,
/// Return type.
pub return_ty: Type,
/// Derived: indices into `parameters` of every parameter flagged
/// [`ParameterFlags::HAS_OUT_ATTRIBUTE`], in ascending order.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub output_parameter_indices: Vec<u8>,
/// Derived: indices into `buffers` of the [`BufferDirection::Input`] buffers.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub input_buffer_indices: Vec<u8>,
/// Derived: indices into `buffers` of the [`BufferDirection::Output`] buffers.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub output_buffer_indices: Vec<u8>,
}
impl<S: function_builder::State> FunctionBuilder<S>
where
S: function_builder::IsComplete,
{
pub fn build(self) -> Function {
let mut function = self.build_impl();
function.output_parameter_indices = function
.parameters
.iter()
.enumerate()
.filter_map(|(index, parameter)| {
parameter
.flags
.contains(ParameterFlags::HAS_OUT_ATTRIBUTE)
.then_some(index as u8)
})
.collect();
for buffer in &mut function.buffers {
buffer.position = match buffer.direction {
BufferDirection::Input => buffer.parameter,
BufferDirection::Output => function
.output_parameter_indices
.iter()
.copied()
.position(|index| index == buffer.parameter)
.expect("output buffer references parameter without HAS_OUT_ATTRIBUTE")
as u8,
};
}
function.input_buffer_indices = function
.buffers
.iter()
.enumerate()
.filter_map(|(index, buffer)| {
matches!(buffer.direction, BufferDirection::Input).then_some(index as u8)
})
.collect();
function.output_buffer_indices = function
.buffers
.iter()
.enumerate()
.filter_map(|(index, buffer)| {
matches!(buffer.direction, BufferDirection::Output).then_some(index as u8)
})
.collect();
function
}
}
/// An interface, identified by name and UUID, together with its methods.
#[derive(Debug, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
pub struct Interface {
/// Interface name; key of `Metadata::interfaces_by_name`.
#[builder(into)]
pub name: String,
/// Interface UUID; key of `Metadata::interfaces_by_uuid`.
pub uuid: Uuid,
/// Name of the base interface, if any.
/// NOTE(review): presumably the interface this one derives from — confirm.
#[builder(into)]
pub base: Option<String>,
/// Methods of this interface. The `Metadata` builder sorts only the top-level
/// `functions` and `interfaces`, so this vector keeps the order it was given.
#[builder(default)]
pub methods: Vec<Function>,
}
/// A single function parameter.
#[derive(Debug, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
pub struct Parameter {
/// Parameter name, if known.
#[builder(into)]
pub name: Option<String>,
/// Attribute flags; empty unless set through the builder.
/// Archived via `crate::asbits::AsBits` (presumably as the raw bits — confirm).
#[rkyv(with = crate::asbits::AsBits)]
#[builder(default = ParameterFlags::empty())]
pub flags: ParameterFlags,
/// Parameter type.
pub ty: Type,
}
bitflags::bitflags! {
/// Attribute flags attached to a [`Parameter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ParameterFlags: u8 {
/// The parameter carries an "in" attribute.
const HAS_IN_ATTRIBUTE = 0x01;
/// The parameter carries an "out" attribute. Output buffers may only
/// reference parameters with this flag (enforced by the `Function` builder).
const HAS_OUT_ATTRIBUTE = 0x02;
/// The parameter carries a "COM" attribute.
/// NOTE(review): exact meaning is not established in this file — confirm.
const HAS_COM_ATTRIBUTE = 0x04;
}
}
/// A type reference: a named type, its primitive classification and its
/// pointer depth.
#[derive(Debug, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
pub struct Type {
/// Number of pointer indirections (`1` for `void*`, see
/// [`Type::void_pointer`]); defaults to `0` in the builder.
#[builder(default)]
pub indirections: u8,
/// Type name, e.g. `"DWORD"`.
#[builder(into)]
pub name: String,
/// Primitive classification of the pointee/base type.
pub kind: TypeKind,
}
impl Type {
    /// Returns `void*`: the `"void"` type of kind [`TypeKind::Void`] with a
    /// single level of indirection.
    pub fn void_pointer() -> Self {
        Self::builder()
            .indirections(1)
            .name("void")
            .kind(TypeKind::Void)
            .build()
    }
}
/// Primitive classification of a [`Type`].
///
/// NOTE(review): the `u8` payload of `Custom` is not interpreted in this file —
/// presumably an index or tag defined elsewhere; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
#[expect(missing_docs, reason = "self-explanatory")]
pub enum TypeKind {
Unknown,
Void,
Bool,
Char8,
Char16,
I8,
I16,
I32,
I64,
U8,
U16,
U32,
U64,
F32,
F64,
Custom(u8),
}
/// Describes a buffer passed through one of a function's parameters.
#[derive(Debug, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug))]
pub struct Buffer {
/// Index into `Function::parameters` of the parameter carrying the buffer.
pub parameter: u8,
/// Derived by the `Function` builder. For input buffers it equals
/// `parameter`. For output buffers it is the index of `parameter` within
/// `Function::output_parameter_indices`.
/// Not settable through the builder and skipped by `serde`.
#[cfg_attr(feature = "serde", serde(skip))]
#[builder(skip)]
pub position: u8,
/// Expression giving the buffer length.
/// NOTE(review): unit (elements vs. bytes) is not established here — confirm.
pub length: Expression,
/// Whether this is an input or an output buffer.
pub direction: BufferDirection,
/// When the buffer is handled relative to the call.
pub phase: BufferPhase,
}
/// Direction of a [`Buffer`].
///
/// An `Output` buffer must reference a parameter flagged
/// [`ParameterFlags::HAS_OUT_ATTRIBUTE`]; the `Function` builder panics otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug, PartialEq, Eq))]
pub enum BufferDirection {
/// Input buffer.
Input,
/// Output buffer.
Output,
}
/// Phase in which a [`Buffer`] is handled.
///
/// NOTE(review): presumably `Pre` = before the call and `Post` = after it — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug, PartialEq, Eq))]
pub enum BufferPhase {
/// Pre phase.
Pre,
/// Post phase.
Post,
}
/// A small expression tree, used e.g. as a [`Buffer::length`].
///
/// The type is recursive through [`UnaryExpression`] and [`BinaryExpression`].
/// For that reason, the boxed fields carry `#[rkyv(omit_bounds)]`, and the
/// explicit serialize/deserialize/bytecheck bounds in the `#[rkyv(...)]`
/// attribute replace the bounds rkyv would otherwise derive from the
/// (recursive) field types.
#[derive(Debug, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(
derive(Debug, PartialEq, Eq),
serialize_bounds(
__S: rkyv::ser::Writer + rkyv::ser::Allocator,
__S::Error: rkyv::rancor::Source,
),
deserialize_bounds(
__D::Error: rkyv::rancor::Source
),
bytecheck(bounds(
__C: rkyv::validation::ArchiveContext,
__C::Error: rkyv::rancor::Source,
)
))]
pub enum Expression {
/// The function's return value (inferred from the name — confirm).
Return,
/// A literal constant.
Constant(u64),
/// A parameter's value; presumably an index into `Function::parameters` — confirm.
Parameter(u8),
/// A unary operation.
UnaryExpression(#[rkyv(omit_bounds)] Box<UnaryExpression>),
/// A binary operation.
BinaryExpression(#[rkyv(omit_bounds)] Box<BinaryExpression>),
}
/// A [`UnaryOperator`] applied to a sub-expression.
///
/// Part of the recursive [`Expression`] tree; see that type for why the rkyv
/// bounds are spelled out explicitly.
#[derive(Debug, PartialEq, Eq, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(
derive(Debug, PartialEq, Eq),
serialize_bounds(
__S: rkyv::ser::Writer + rkyv::ser::Allocator,
__S::Error: rkyv::rancor::Source,
),
deserialize_bounds(
__D::Error: rkyv::rancor::Source
),
bytecheck(bounds(
__C: rkyv::validation::ArchiveContext,
__C::Error: rkyv::rancor::Source,
))
)]
pub struct UnaryExpression {
/// Operator to apply.
pub operator: UnaryOperator,
/// Operand.
#[rkyv(omit_bounds)]
pub expression: Expression,
}
/// A [`BinaryOperator`] applied to two sub-expressions.
///
/// Part of the recursive [`Expression`] tree; see that type for why the rkyv
/// bounds are spelled out explicitly.
#[derive(Debug, PartialEq, Eq, Builder, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(
derive(Debug, PartialEq, Eq),
serialize_bounds(
__S: rkyv::ser::Writer + rkyv::ser::Allocator,
__S::Error: rkyv::rancor::Source,
),
deserialize_bounds(
__D::Error: rkyv::rancor::Source
),
bytecheck(bounds(
__C: rkyv::validation::ArchiveContext,
__C::Error: rkyv::rancor::Source,
))
)]
pub struct BinaryExpression {
/// Operator to apply.
pub operator: BinaryOperator,
/// Left operand.
#[rkyv(omit_bounds)]
pub lhs: Expression,
/// Right operand.
#[rkyv(omit_bounds)]
pub rhs: Expression,
}
/// Operators usable in a [`UnaryExpression`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug, PartialEq, Eq))]
pub enum UnaryOperator {
/// Pointer dereference of the operand.
Dereference,
}
/// Arithmetic operators usable in a [`BinaryExpression`].
///
/// NOTE(review): overflow and division-by-zero semantics are not defined in
/// this file — confirm with the evaluator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[rkyv(derive(Debug, PartialEq, Eq))]
pub enum BinaryOperator {
/// `lhs + rhs`.
Add,
/// `lhs - rhs`.
Subtract,
/// `lhs * rhs`.
Multiply,
/// `lhs / rhs`.
Divide,
}
#[cfg(test)]
mod tests {
    use rkyv::rancor::Error as RkyvError;

    use super::*;

    /// A `ReadFile`-like x86 function with one output and one input buffer.
    fn sample_database() -> Database {
        let func = Function::builder()
            .name("ReadFile")
            .parameters(vec![
                Parameter::builder()
                    .name("hFile")
                    .flags(ParameterFlags::HAS_IN_ATTRIBUTE)
                    .ty(Type::builder().name("HANDLE").kind(TypeKind::U64).build())
                    .build(),
                Parameter::builder()
                    .name("lpBuffer")
                    .flags(ParameterFlags::HAS_OUT_ATTRIBUTE)
                    .ty(Type::builder()
                        .indirections(1)
                        .name("BYTE")
                        .kind(TypeKind::U8)
                        .build())
                    .build(),
                Parameter::builder()
                    .name("nBytes")
                    .flags(ParameterFlags::HAS_IN_ATTRIBUTE)
                    .ty(Type::builder().name("DWORD").kind(TypeKind::U32).build())
                    .build(),
                Parameter::builder()
                    .name("lpRead")
                    .flags(ParameterFlags::HAS_OUT_ATTRIBUTE)
                    .ty(Type::builder()
                        .indirections(1)
                        .name("DWORD")
                        .kind(TypeKind::U32)
                        .build())
                    .build(),
            ])
            .buffers(vec![
                Buffer::builder()
                    .parameter(1)
                    .length(Expression::Parameter(2))
                    .direction(BufferDirection::Output)
                    .phase(BufferPhase::Post)
                    .build(),
                Buffer::builder()
                    .parameter(2)
                    .length(Expression::Constant(4))
                    .direction(BufferDirection::Input)
                    .phase(BufferPhase::Pre)
                    .build(),
            ])
            .return_ty(Type::builder().name("BOOL").kind(TypeKind::I32).build())
            .build();
        Database::builder()
            .x86(Metadata::builder().functions(vec![func]).build())
            .x64(Metadata::default())
            .build()
    }

    /// Serializes `db` with rkyv and deserializes it back into an owned value.
    fn round_trip(db: &Database) -> Database {
        let bytes = rkyv::to_bytes::<RkyvError>(db).expect("serialize");
        rkyv::from_bytes::<Database, RkyvError>(&bytes).expect("deserialize")
    }

    /// A parameterless `void` function identified only by `name`.
    fn named_function(name: &str) -> Function {
        Function::builder()
            .name(name)
            .return_ty(Type::builder().name("void").kind(TypeKind::Void).build())
            .build()
    }

    #[test]
    fn rkyv_round_trip() {
        let restored = round_trip(&sample_database());
        assert_eq!(restored.x86.functions.len(), 1);
        assert_eq!(restored.x86.functions[0].name, "ReadFile");
        assert_eq!(
            restored.x86.functions[0].parameters[0].name.as_deref(),
            Some("hFile")
        );
        assert!(
            restored.x86.functions[0].parameters[0]
                .flags
                .contains(ParameterFlags::HAS_IN_ATTRIBUTE)
        );
    }

    #[test]
    fn function_builder_bakes_indices() {
        let db = sample_database();
        let func = &db.x86.functions[0];
        assert_eq!(func.output_parameter_indices, vec![1, 3]);
        assert_eq!(func.input_buffer_indices, vec![1]);
        assert_eq!(func.output_buffer_indices, vec![0]);
    }

    #[test]
    fn function_builder_bakes_buffer_positions() {
        let db = sample_database();
        let buffers = &db.x86.functions[0].buffers;
        // Output buffer on parameter 1: first entry of output_parameter_indices.
        assert_eq!(buffers[0].position, 0);
        // Input buffer on parameter 2: the position is the parameter index itself.
        assert_eq!(buffers[1].position, 2);
    }

    #[test]
    #[should_panic(expected = "output buffer references parameter without HAS_OUT_ATTRIBUTE")]
    fn output_buffer_on_non_out_parameter_panics() {
        let _ = Function::builder()
            .name("Bad")
            .parameters(vec![
                Parameter::builder()
                    .name("p")
                    .flags(ParameterFlags::HAS_IN_ATTRIBUTE)
                    .ty(Type::void_pointer())
                    .build(),
            ])
            .buffers(vec![
                Buffer::builder()
                    .parameter(0)
                    .length(Expression::Constant(1))
                    .direction(BufferDirection::Output)
                    .phase(BufferPhase::Post)
                    .build(),
            ])
            .return_ty(Type::builder().name("void").kind(TypeKind::Void).build())
            .build();
    }

    #[test]
    fn metadata_builder_sorts_and_indexes() {
        let first = uuid::Uuid::from_u128(1);
        let second = uuid::Uuid::from_u128(2);
        let metadata = Metadata::builder()
            .functions(vec![named_function("WriteFile"), named_function("CloseHandle")])
            .interfaces(vec![
                Interface::builder().name("IUnknown").uuid(second).build(),
                Interface::builder()
                    .name("IDispatch")
                    .uuid(first)
                    .base("IUnknown")
                    .build(),
            ])
            .build();

        let names: Vec<_> = metadata.functions.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["CloseHandle", "WriteFile"]);
        assert_eq!(metadata.functions_by_name["CloseHandle"], 0);
        assert_eq!(metadata.functions_by_name["WriteFile"], 1);

        assert_eq!(metadata.interfaces[0].name, "IDispatch");
        assert_eq!(metadata.interfaces_by_name["IDispatch"], 0);
        assert_eq!(metadata.interfaces_by_name["IUnknown"], 1);
        assert_eq!(metadata.interfaces_by_uuid[&first], 0);
        assert_eq!(metadata.interfaces_by_uuid[&second], 1);
    }

    #[test]
    fn baked_indices_survive_rkyv_round_trip() {
        let restored = round_trip(&sample_database());
        let func = &restored.x86.functions[0];
        assert_eq!(func.output_parameter_indices, vec![1, 3]);
        assert_eq!(func.input_buffer_indices, vec![1]);
        assert_eq!(func.output_buffer_indices, vec![0]);
    }
}