use std::collections::HashMap;
use sigmd::model::ParameterFlags;
/// Annotation-name prefixes that mark a parameter as carrying data *into* the callee.
///
/// Matching is a plain `starts_with`, so `"_In"` also accepts names such as `"_In_reads_"`
/// and `"_Inout_"`; the `…inout` spellings are listed explicitly anyway to document intent.
const IN_PREFIXES: &[&str] = &[
    // Underscore-capitalised spellings.
    "_In",
    "_Inout",
    // Double-underscore lowercase spellings.
    "__in",
    "__inout",
    // `__RPC__` spellings.
    "__RPC__in",
    "__RPC__inout",
    "__RPC__deref_in",
    "__RPC__deref_inout",
];

/// Annotation-name prefixes that mark a parameter as carrying data *out of* the callee.
///
/// `"_COM_Out"` and `"__RPC__deref_out"` also appear in [`COM_PREFIXES`], so a COM-style
/// output parameter is reported as both an output and a COM parameter.
const OUT_PREFIXES: &[&str] = &[
    // Underscore-capitalised spellings.
    "_Out",
    "_Inout",
    "_COM_Out",
    // Double-underscore lowercase spellings.
    "__out",
    "__inout",
    // `__RPC__` spellings.
    "__RPC__out",
    "__RPC__inout",
    "__RPC__deref_out",
    "__RPC__deref_inout",
];

/// Annotation-name prefixes that mark a parameter with the COM attribute flag.
const COM_PREFIXES: &[&str] = &[
    "_COM_Out",
    "__RPC__deref_out",
];
/// One SAL annotation reassembled from `__SAL:` attribute strings by [`decode`].
///
/// Both fields borrow from the raw attribute strings that were decoded. The type derives
/// `Clone`, `PartialEq` and `Eq` so callers can copy whole annotations and compare them
/// directly instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Annotation<'a> {
    /// The annotation's name, e.g. `_In_reads_bytes_`.
    pub name: &'a str,
    /// Positional arguments in index order (0, 1, 2, …), kept verbatim as strings.
    pub args: Vec<&'a str>,
}
/// One `__SAL:` attribute split into its parts.
enum SalLine<'a> {
    /// `__SAL:<counter>:<name>`: opens the annotation identified by `counter`.
    Name { counter: usize, name: &'a str },
    /// `__SAL:<counter>:<index>:<value>`: argument `index` of the annotation `counter`.
    Arg {
        counter: usize,
        index: usize,
        value: &'a str,
    },
}

/// Splits one raw attribute string into a [`SalLine`].
///
/// Returns `None` in any of these cases:
/// - the text lacks the `__SAL:` prefix;
/// - there is no `:` after the counter;
/// - the counter is not a valid `usize`;
/// - on an argument line, the index is not a valid `usize`.
///
/// Only the first colon after the index is a separator, so an argument value may itself
/// contain `:`.
fn parse_sal_line(text: &str) -> Option<SalLine<'_>> {
    let body = text.strip_prefix("__SAL:")?;
    let (counter_text, payload) = body.split_once(':')?;
    let counter: usize = counter_text.parse().ok()?;
    Some(match payload.split_once(':') {
        None => SalLine::Name {
            counter,
            name: payload,
        },
        Some((index_text, value)) => SalLine::Arg {
            counter,
            index: index_text.parse().ok()?,
            value,
        },
    })
}

/// Reassembles SAL annotations from raw `__SAL:` attribute strings.
///
/// A relevant attribute is one of two kinds:
/// - `__SAL:<counter>:<name>` opens a new annotation.
/// - `__SAL:<counter>:<index>:<value>` appends an argument to the annotation most recently
///   opened with the same `<counter>`.
///
/// Some attributes are skipped silently:
/// - anything without the `__SAL:` prefix;
/// - attributes with a non-numeric counter or index;
/// - arguments whose counter has not been opened yet.
///
/// The annotations come back in the order their opening attributes appeared. Their names
/// and arguments borrow from `raw`.
///
/// # Panics
///
/// Panics if an argument's index differs from the number of arguments already collected for
/// its annotation. Indices must arrive as 0, 1, 2, … in order.
pub fn decode(raw: &[impl AsRef<str>]) -> Vec<Annotation<'_>> {
    let mut decoded: Vec<Annotation<'_>> = Vec::new();
    // Maps each counter to the position of its annotation in `decoded`. A later name line
    // with the same counter overwrites the mapping.
    let mut slot_of: HashMap<usize, usize> = HashMap::new();

    for item in raw {
        let text: &str = item.as_ref();
        match parse_sal_line(text) {
            Some(SalLine::Name { counter, name }) => {
                slot_of.insert(counter, decoded.len());
                decoded.push(Annotation {
                    name,
                    args: Vec::new(),
                });
            }
            Some(SalLine::Arg {
                counter,
                index,
                value,
            }) => {
                // Arguments for a counter that was never opened are dropped.
                if let Some(&slot) = slot_of.get(&counter) {
                    let owner: &mut Annotation = &mut decoded[slot];
                    assert_eq!(
                        index,
                        owner.args.len(),
                        "argument indices must arrive in order within one annotation"
                    );
                    owner.args.push(value);
                }
            }
            None => {}
        }
    }

    decoded
}
/// Derives parameter flags from a set of decoded annotations.
///
/// Each flag is set when at least one annotation's name starts with one of its prefixes:
/// - `HAS_IN_ATTRIBUTE` for [`IN_PREFIXES`];
/// - `HAS_OUT_ATTRIBUTE` for [`OUT_PREFIXES`];
/// - `HAS_COM_ATTRIBUTE` for [`COM_PREFIXES`].
///
/// With no annotations, the result is `ParameterFlags::empty()`.
pub fn flags(annotations: &[Annotation]) -> ParameterFlags {
    let rules = [
        (IN_PREFIXES, ParameterFlags::HAS_IN_ATTRIBUTE),
        (OUT_PREFIXES, ParameterFlags::HAS_OUT_ATTRIBUTE),
        (COM_PREFIXES, ParameterFlags::HAS_COM_ATTRIBUTE),
    ];

    let mut result = ParameterFlags::empty();
    for (prefixes, flag) in rules {
        if any_match(annotations, prefixes) {
            result |= flag;
        }
    }
    result
}
/// Reports whether any annotation's name begins with any of `prefixes`.
///
/// This is a plain `starts_with` test, so the prefix `"_In"` also matches `"_Inout_"`.
/// It returns `false` if either slice is empty.
fn any_match(annotations: &[Annotation], prefixes: &[&str]) -> bool {
    for annotation in annotations {
        for &prefix in prefixes {
            if annotation.name.starts_with(prefix) {
                return true;
            }
        }
    }
    false
}
/// Unit tests for SAL attribute decoding and flag derivation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decodes_one_annotation_with_one_arg() {
        let raw = [
            String::from("__SAL:0:_In_reads_bytes_"),
            String::from("__SAL:0:0:nSize"),
        ];
        let out = decode(&raw);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].name, "_In_reads_bytes_");
        assert_eq!(out[0].args, vec!["nSize"]);
    }

    #[test]
    fn decodes_two_annotations_back_to_back() {
        let raw = [
            "__SAL:0:_In_",
            "__SAL:1:_Out_writes_bytes_to_",
            "__SAL:1:0:cb",
            "__SAL:1:1:*pcbWritten",
        ];
        let out = decode(&raw);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].name, "_In_");
        assert!(out[0].args.is_empty());
        assert_eq!(out[1].name, "_Out_writes_bytes_to_");
        assert_eq!(out[1].args, vec!["cb", "*pcbWritten"]);
    }

    #[test]
    fn skips_unrelated_annotations() {
        let raw = ["__OVERRIDE", "noise", "__SAL:0:_In_", "more noise"];
        let out = decode(&raw);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].name, "_In_");
    }

    /// Attributes with a malformed shape are dropped. This covers an argument whose counter
    /// was never opened, a non-numeric counter or index, and a missing separator after the
    /// counter. Well-formed attributes around them still decode.
    #[test]
    fn skips_orphan_and_malformed_attributes() {
        let raw = [
            "__SAL:0:0:orphan",
            "__SAL:x:_In_",
            "__SAL:1:_In_reads_",
            "__SAL:1:y:n",
            "__SAL:1:0:n",
            "__SAL:2",
        ];
        let out = decode(&raw);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].name, "_In_reads_");
        assert_eq!(out[0].args, vec!["n"]);
    }

    /// When a counter is reopened, later arguments attach to the newest annotation.
    #[test]
    fn reused_counter_targets_latest_annotation() {
        let raw = ["__SAL:0:_A_", "__SAL:0:_B_", "__SAL:0:0:x"];
        let out = decode(&raw);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].name, "_A_");
        assert!(out[0].args.is_empty());
        assert_eq!(out[1].name, "_B_");
        assert_eq!(out[1].args, vec!["x"]);
    }

    #[test]
    #[should_panic(expected = "argument indices must arrive in order")]
    fn out_of_order_arg_index_panics() {
        let raw = [
            "__SAL:0:_Out_writes_bytes_to_",
            "__SAL:0:1:*pcbWritten",
            "__SAL:0:0:cb",
        ];
        let _ = decode(&raw);
    }

    #[test]
    fn decodes_after_clang_dedup_pathology() {
        let raw = [
            "__SAL:1:_Out_writes_bytes_to_",
            "__SAL:1:0:namelen",
            "__SAL:1:1:return",
            "__SAL:2:_Out_writes_bytes_",
            "__SAL:2:0:namelen",
        ];
        let out = decode(&raw);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].name, "_Out_writes_bytes_to_");
        assert_eq!(out[0].args, vec!["namelen", "return"]);
        assert_eq!(out[1].name, "_Out_writes_bytes_");
        assert_eq!(out[1].args, vec!["namelen"]);
    }

    #[test]
    fn arg_value_with_colons_survives() {
        let raw = ["__SAL:0:_In_reads_", "__SAL:0:0:cond ? a : b"];
        let out = decode(&raw);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].name, "_In_reads_");
        assert_eq!(out[0].args, vec!["cond ? a : b"]);
    }

    #[test]
    fn flags_reflects_in_out_com() {
        let raw = [
            String::from("__SAL:0:_In_reads_bytes_"),
            String::from("__SAL:0:0:n"),
            String::from("__SAL:1:_COM_Outptr_"),
        ];
        let out = decode(&raw);
        let flags = flags(&out);
        assert!(flags.contains(ParameterFlags::HAS_IN_ATTRIBUTE));
        assert!(flags.contains(ParameterFlags::HAS_OUT_ATTRIBUTE));
        assert!(flags.contains(ParameterFlags::HAS_COM_ATTRIBUTE));
    }

    /// With no annotations, no flag is set.
    #[test]
    fn flags_empty_without_annotations() {
        let f = flags(&[]);
        assert!(!f.contains(ParameterFlags::HAS_IN_ATTRIBUTE));
        assert!(!f.contains(ParameterFlags::HAS_OUT_ATTRIBUTE));
        assert!(!f.contains(ParameterFlags::HAS_COM_ATTRIBUTE));
    }

    /// A plain `_Out_` sets only the output flag.
    #[test]
    fn flags_out_only() {
        let raw = ["__SAL:0:_Out_"];
        let out = decode(&raw);
        let f = flags(&out);
        assert!(!f.contains(ParameterFlags::HAS_IN_ATTRIBUTE));
        assert!(f.contains(ParameterFlags::HAS_OUT_ATTRIBUTE));
        assert!(!f.contains(ParameterFlags::HAS_COM_ATTRIBUTE));
    }
}