use std::collections::HashMap;
use sigmd::model::{
BinaryExpression, BinaryOperator, Buffer, BufferDirection, BufferPhase, Expression, Parameter,
UnaryExpression, UnaryOperator,
};
use super::{annotation::Annotation, parse};
use crate::cli::compiler::BuildContext;
/// Derives buffer descriptions from the SAL annotations attached to each parameter.
///
/// `annotations_per_parameter[i]` holds the annotations of the parameter at
/// position `i`; `ctx.sizeofs` supplies the sizes used to resolve `sizeof(...)`.
///
/// Candidate buffers that share the same parameter, direction and phase are
/// grouped together (in first-seen order). A group whose members all carry an
/// equal length expression collapses to its first member; a group with
/// differing lengths is dropped entirely and reported at debug level.
///
/// # Panics
///
/// Panics if a parameter position does not fit in a `u8`.
pub fn analyze(
    parameters: &[Parameter],
    annotations_per_parameter: &[Vec<Annotation<'_>>],
    ctx: &BuildContext,
) -> Vec<Buffer> {
    let resolver = Resolver::new(parameters, &ctx.sizeofs);

    // Phase 1: expand every annotation into its candidate buffers.
    let mut produced = Vec::new();
    for (index, annotations) in annotations_per_parameter.iter().enumerate() {
        assert!(index <= u8::MAX as usize);
        let parameter_index = index as u8;
        produced.extend(
            annotations
                .iter()
                .flat_map(|annotation| buffers_for(parameter_index, annotation, &resolver)),
        );
    }

    // Phase 2: bucket candidates by key, keeping insertion order of the keys.
    let mut groups: Vec<(BufferKey, Vec<Buffer>)> = Vec::new();
    for candidate in produced {
        let key = BufferKey::from(&candidate);
        match groups.iter().position(|(existing, _)| *existing == key) {
            Some(slot) => groups[slot].1.push(candidate),
            None => groups.push((key, vec![candidate])),
        }
    }

    // Phase 3: keep one representative per unambiguous group.
    groups
        .into_iter()
        .filter_map(|(key, group)| {
            let total = group.len();
            let mut members = group.into_iter();
            let representative = members
                .next()
                .expect("group is non-empty by construction");
            let conflicting = members.any(|other| other.length != representative.length);
            if conflicting {
                tracing::debug!(
                    parameter = key.parameter,
                    direction = ?key.direction,
                    phase = ?key.phase,
                    variants = total,
                    "dropping ambiguous buffer group with conflicting lengths"
                );
                return None;
            }
            Some(representative)
        })
        .collect()
}
/// Identity used to group candidate buffers: two buffers with an equal key
/// describe the same parameter in the same direction and phase.
#[derive(Clone, Copy, PartialEq, Eq)]
struct BufferKey {
/// Position of the parameter the buffer belongs to.
parameter: u8,
/// Direction copied from the buffer.
direction: BufferDirection,
/// Phase copied from the buffer.
phase: BufferPhase,
}
impl From<&Buffer> for BufferKey {
fn from(buffer: &Buffer) -> Self {
Self {
parameter: buffer.parameter,
direction: buffer.direction,
phase: buffer.phase,
}
}
}
/// Lookup tables used while lowering parsed SAL expressions.
struct Resolver<'a> {
/// Maps a parameter name to its position in the parameter list.
parameter_by_name: HashMap<&'a str, u8>,
/// Maps a type name to the value substituted for `sizeof(<type>)`.
sizeofs: &'a HashMap<String, u64>,
}
impl<'a> Resolver<'a> {
fn new(parameters: &'a [Parameter], sizeofs: &'a HashMap<String, u64>) -> Self {
let mut parameter_by_name = HashMap::with_capacity(parameters.len());
for (index, parameter) in parameters.iter().enumerate() {
assert!(index <= u8::MAX as usize);
if let Some(name) = parameter.name.as_deref() {
parameter_by_name.insert(name, index as u8);
}
}
Self {
parameter_by_name,
sizeofs,
}
}
fn parameter(&self, name: &str) -> Option<u8> {
let index = self.parameter_by_name.get(name).copied();
if index.is_none() {
tracing::debug!(parameter = name, "SAL referenced unknown parameter");
}
index
}
fn sizeof(&self, name: &str) -> Option<u64> {
let size = self.sizeofs.get(name).copied();
if size.is_none() {
tracing::debug!(ty = name, "sizeof() of unknown type");
}
size
}
}
/// One buffer that a matching annotation produces.
struct Emit {
/// Index into the annotation's argument list of the length expression.
arg: usize,
/// Direction assigned to the produced buffer.
direction: BufferDirection,
/// Phase assigned to the produced buffer.
phase: BufferPhase,
}
/// Associates an annotation name with the buffers it produces.
struct Rule {
/// Annotation name, compared for exact equality against `Annotation::name`.
name: &'static str,
/// Buffers to produce for an annotation with this name.
emits: &'static [Emit],
}
const RULES: &[Rule] = &[
Rule {
name: "_In_reads_bytes_",
emits: &[Emit {
arg: 0,
direction: BufferDirection::Input,
phase: BufferPhase::Pre,
}],
},
Rule {
name: "_In_reads_bytes_opt_",
emits: &[Emit {
arg: 0,
direction: BufferDirection::Input,
phase: BufferPhase::Pre,
}],
},
Rule {
name: "_Out_writes_bytes_",
emits: &[Emit {
arg: 0,
direction: BufferDirection::Output,
phase: BufferPhase::Pre,
}],
},
Rule {
name: "_Out_writes_bytes_opt_",
emits: &[Emit {
arg: 0,
direction: BufferDirection::Output,
phase: BufferPhase::Pre,
}],
},
Rule {
name: "_Inout_updates_bytes_",
emits: &[Emit {
arg: 0,
direction: BufferDirection::Output,
phase: BufferPhase::Pre,
}],
},
Rule {
name: "_Inout_updates_bytes_opt_",
emits: &[Emit {
arg: 0,
direction: BufferDirection::Output,
phase: BufferPhase::Pre,
}],
},
Rule {
name: "_Out_writes_bytes_to_",
emits: &[Emit {
arg: 1,
direction: BufferDirection::Output,
phase: BufferPhase::Post,
}],
},
Rule {
name: "_Out_writes_bytes_to_opt_",
emits: &[Emit {
arg: 1,
direction: BufferDirection::Output,
phase: BufferPhase::Post,
}],
},
Rule {
name: "_Inout_updates_bytes_to_",
emits: &[
Emit {
arg: 0,
direction: BufferDirection::Input,
phase: BufferPhase::Pre,
},
Emit {
arg: 1,
direction: BufferDirection::Output,
phase: BufferPhase::Post,
},
],
},
Rule {
name: "_Inout_updates_bytes_to_opt_",
emits: &[
Emit {
arg: 0,
direction: BufferDirection::Input,
phase: BufferPhase::Pre,
},
Emit {
arg: 1,
direction: BufferDirection::Output,
phase: BufferPhase::Post,
},
],
},
];
/// Expands one annotation into the buffers its matching rule prescribes.
///
/// Returns an empty vector when no rule in `RULES` has the annotation's name.
/// Emits whose buffer cannot be built are skipped.
fn buffers_for(
    parameter_index: u8,
    annotation: &Annotation<'_>,
    resolver: &Resolver<'_>,
) -> Vec<Buffer> {
    let mut buffers = Vec::new();
    if let Some(rule) = RULES
        .iter()
        .find(|candidate| candidate.name == annotation.name)
    {
        for emit in rule.emits {
            if let Some(buffer) = build_buffer(parameter_index, annotation, emit, resolver) {
                buffers.push(buffer);
            }
        }
    }
    buffers
}
/// Builds the buffer described by `emit` for the given annotation.
///
/// The length expression is taken from the annotation argument at `emit.arg`,
/// parsed, and lowered via `resolver`. Returns `None` (after a debug log for
/// the first two cases) when the argument is missing, fails to parse, or
/// cannot be lowered.
fn build_buffer(
    parameter_index: u8,
    annotation: &Annotation<'_>,
    emit: &Emit,
    resolver: &Resolver<'_>,
) -> Option<Buffer> {
    let arg = annotation.args.get(emit.arg).or_else(|| {
        tracing::debug!(
            annotation = annotation.name,
            arg = emit.arg,
            "missing SAL argument"
        );
        None
    })?;

    let parsed = parse::parse(arg)
        .map_err(|err| {
            tracing::debug!(
                %err,
                annotation = annotation.name,
                "SAL arg parse failed"
            );
        })
        .ok()?;

    let length = lower(parsed, resolver)?;
    let buffer = Buffer::builder()
        .parameter(parameter_index)
        .length(length)
        .direction(emit.direction)
        .phase(emit.phase)
        .build();
    Some(buffer)
}
/// Lowers a parsed SAL expression into a model `Expression`.
///
/// Identifiers are resolved to parameter positions; compound expressions are
/// delegated to `lower_unary` / `lower_binary`. Returns `None` as soon as any
/// part of the expression cannot be resolved.
fn lower(expr: parse::Expression<'_>, resolver: &Resolver<'_>) -> Option<Expression> {
    let lowered = match expr {
        parse::Expression::UnaryExpression(unary) => return lower_unary(unary, resolver),
        parse::Expression::BinaryExpression(binary) => return lower_binary(binary, resolver),
        parse::Expression::Identifier(name) => {
            return resolver.parameter(name).map(Expression::Parameter);
        }
        parse::Expression::Constant(value) => Expression::Constant(value),
        parse::Expression::Return => Expression::Return,
    };
    Some(lowered)
}
/// Lowers a parsed unary expression.
///
/// * `sizeof(<identifier>)` becomes a constant looked up through `resolver`;
///   any operand other than a bare identifier is rejected with a debug log.
/// * A dereference wraps the lowered operand in a model unary expression.
fn lower_unary(unary: parse::UnaryExpression<'_>, resolver: &Resolver<'_>) -> Option<Expression> {
    match unary.operator {
        parse::UnaryOperator::SizeOf => match *unary.expression {
            parse::Expression::Identifier(name) => {
                resolver.sizeof(name).map(Expression::Constant)
            }
            _ => {
                tracing::debug!("sizeof() requires a bare identifier");
                None
            }
        },
        parse::UnaryOperator::Dereference => {
            lower(*unary.expression, resolver).map(|operand| {
                let node = UnaryExpression::builder()
                    .operator(UnaryOperator::Dereference)
                    .expression(operand)
                    .build();
                Expression::UnaryExpression(Box::new(node))
            })
        }
    }
}
/// Lowers a parsed binary expression by lowering both operands (left first)
/// and translating the operator one-to-one.
///
/// Returns `None` if either operand cannot be lowered.
fn lower_binary(
    binary: parse::BinaryExpression<'_>,
    resolver: &Resolver<'_>,
) -> Option<Expression> {
    // The operator translation is infallible and has no side effects, so it
    // can be done up front.
    let operator = match binary.operator {
        parse::BinaryOperator::Add => BinaryOperator::Add,
        parse::BinaryOperator::Subtract => BinaryOperator::Subtract,
        parse::BinaryOperator::Multiply => BinaryOperator::Multiply,
        parse::BinaryOperator::Divide => BinaryOperator::Divide,
    };
    let left = lower(*binary.lhs, resolver)?;
    let right = lower(*binary.rhs, resolver)?;
    let node = BinaryExpression::builder()
        .operator(operator)
        .lhs(left)
        .rhs(right)
        .build();
    Some(Expression::BinaryExpression(Box::new(node)))
}
#[cfg(test)]
mod tests {
    use sigmd::model::{Type, TypeKind};

    use super::*;

    /// Builds a named parameter whose type reuses the same name with the given kind.
    fn param(name: &str, kind: TypeKind) -> Parameter {
        Parameter::builder()
            .name(name)
            .ty(Type::builder().name(name).kind(kind).build())
            .build()
    }

    /// `_Out_writes_bytes_to_(s, c)` yields a single post-phase output buffer
    /// whose length is the second argument.
    #[test]
    fn out_writes_bytes_to_emits_post_output_buffer() {
        let parameters = vec![
            param("hFile", TypeKind::U64),
            param("lpBuffer", TypeKind::U8),
            param("nNumberOfBytesToRead", TypeKind::U32),
            param("lpNumberOfBytesRead", TypeKind::U32),
        ];
        let annotations = vec![
            vec![],
            vec![Annotation {
                name: "_Out_writes_bytes_to_",
                args: vec!["nNumberOfBytesToRead", "*lpNumberOfBytesRead"],
            }],
            vec![],
            vec![],
        ];
        let buffers = analyze(&parameters, &annotations, &BuildContext::default());
        assert_eq!(buffers.len(), 1);
        assert_eq!(buffers[0].parameter, 1);
        assert_eq!(buffers[0].direction, BufferDirection::Output);
        assert_eq!(buffers[0].phase, BufferPhase::Post);
        match &buffers[0].length {
            Expression::UnaryExpression(unary) => {
                assert!(matches!(unary.operator, UnaryOperator::Dereference));
                assert!(matches!(unary.expression, Expression::Parameter(3)));
            }
            other => panic!("unexpected length expression: {other:?}"),
        }
    }

    /// `_Inout_updates_bytes_to_(s, c)` yields an input buffer followed by a
    /// post-phase output buffer.
    #[test]
    fn inout_updates_bytes_to_emits_two_buffers() {
        let parameters = vec![
            param("buf", TypeKind::U8),
            param("nIn", TypeKind::U32),
            param("pnOut", TypeKind::U32),
        ];
        let annotations = vec![
            vec![Annotation {
                name: "_Inout_updates_bytes_to_",
                args: vec!["nIn", "*pnOut"],
            }],
            vec![],
            vec![],
        ];
        let buffers = analyze(&parameters, &annotations, &BuildContext::default());
        assert_eq!(buffers.len(), 2);
        assert_eq!(buffers[0].direction, BufferDirection::Input);
        assert_eq!(buffers[0].phase, BufferPhase::Pre);
        assert_eq!(buffers[1].direction, BufferDirection::Output);
        assert_eq!(buffers[1].phase, BufferPhase::Post);
    }

    /// `sizeof(T)` is replaced by the constant recorded for `T` in the build context.
    #[test]
    fn sizeof_resolves_to_constant() {
        let parameters = vec![param("buf", TypeKind::U8)];
        let annotations = vec![vec![Annotation {
            name: "_In_reads_bytes_",
            args: vec!["sizeof(WIN32_FIND_DATAA)"],
        }]];
        let buffers = analyze(
            &parameters,
            &annotations,
            &BuildContext {
                sizeofs: HashMap::from([(String::from("WIN32_FIND_DATAA"), 0x0140)]),
                ..Default::default()
            },
        );
        assert_eq!(buffers.len(), 1);
        assert!(matches!(buffers[0].length, Expression::Constant(0x0140)));
    }

    /// Annotations without a matching rule produce no buffers.
    #[test]
    fn unrecognized_annotation_emits_nothing() {
        let parameters = vec![param("p", TypeKind::U32)];
        let annotations = vec![vec![Annotation {
            name: "_Reserved_",
            args: vec![],
        }]];
        let buffers = analyze(&parameters, &annotations, &BuildContext::default());
        assert!(buffers.is_empty());
    }

    /// Two identical annotations on the same parameter collapse into one buffer.
    #[test]
    fn fully_identical_buffers_from_redeclaration_collapse_to_one() {
        let parameters = vec![param("name", TypeKind::U8), param("namelen", TypeKind::U32)];
        let annotations = vec![
            vec![
                Annotation {
                    name: "_Out_writes_bytes_",
                    args: vec!["namelen"],
                },
                Annotation {
                    name: "_Out_writes_bytes_",
                    args: vec!["namelen"],
                },
            ],
            vec![],
        ];
        let buffers = analyze(&parameters, &annotations, &BuildContext::default());
        assert_eq!(buffers.len(), 1);
        assert_eq!(buffers[0].parameter, 0);
        assert_eq!(buffers[0].direction, BufferDirection::Output);
        assert_eq!(buffers[0].phase, BufferPhase::Pre);
    }

    /// Same key but different lengths is ambiguous, so the whole group is dropped.
    #[test]
    fn ambiguous_buffers_with_conflicting_lengths_are_dropped() {
        let parameters = vec![param("buf", TypeKind::U8)];
        let annotations = vec![vec![
            Annotation {
                name: "_In_reads_bytes_",
                args: vec!["sizeof(SMALL)"],
            },
            Annotation {
                name: "_In_reads_bytes_",
                args: vec!["sizeof(BIG)"],
            },
        ]];
        let buffers = analyze(
            &parameters,
            &annotations,
            &BuildContext {
                sizeofs: HashMap::from([(String::from("SMALL"), 32), (String::from("BIG"), 128)]),
                ..Default::default()
            },
        );
        assert!(buffers.is_empty(), "expected drop, got {buffers:?}");
    }

    /// A length that names a non-existent parameter cannot be lowered, so no
    /// buffer is produced.
    #[test]
    fn unknown_identifier_drops_the_buffer() {
        let parameters = vec![param("p", TypeKind::U32)];
        let annotations = vec![vec![Annotation {
            name: "_In_reads_bytes_",
            args: vec!["nGhost"],
        }]];
        let buffers = analyze(&parameters, &annotations, &BuildContext::default());
        assert!(buffers.is_empty());
    }
}