use std::collections::BTreeMap;
use thiserror::Error;
use crate::{
ProgramFunctionSignature, ProgramMetadata, ProgramParameter, ProgramType, ProgramTypeComponent,
ProgramTypeDetails,
};
/// Errors produced while building or traversing a [`ProgramTypeIndex`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ProgramTypeIndexError {
    /// Two entries of the type table share the same `id`.
    #[error("program type id `{id}` appears more than once")]
    DuplicateId {
        /// The repeated type id.
        id: String,
    },
    /// A lookup or a nested reference named an id the index does not contain.
    #[error("program type id `{id}` is not present in the type table")]
    MissingType {
        /// The id that could not be resolved.
        id: String,
    },
    /// A kind-specific traversal (pointer, array, typedef, bitfield) was
    /// applied to a type of a different kind.
    #[error("program type id `{id}` has kind `{actual}`, expected {expected}")]
    UnexpectedKind {
        /// The id of the inspected type.
        id: String,
        /// Kind label the traversal required.
        expected: &'static str,
        /// Kind label the type actually has.
        actual: &'static str,
    },
}
/// Read-only lookup table over a borrowed slice of [`ProgramType`]s, keyed by type id.
///
/// Besides plain id lookup, it offers kind-specific helpers that follow a
/// type's nested reference (pointee, array element, typedef base, bitfield base)
/// and helpers that resolve the type ids carried by components, signatures and
/// parameters.
#[derive(Debug, Clone)]
pub struct ProgramTypeIndex<'a> {
    /// Types keyed by their `id`; a `BTreeMap`, so iteration is ordered by id.
    by_id: BTreeMap<&'a str, &'a ProgramType>,
}

impl<'a> ProgramTypeIndex<'a> {
    /// Builds an index over `types`, keyed by each type's `id`.
    ///
    /// # Errors
    ///
    /// Returns [`ProgramTypeIndexError::DuplicateId`] for the first id that
    /// appears a second time in `types`.
    pub fn new(types: &'a [ProgramType]) -> Result<Self, ProgramTypeIndexError> {
        let mut by_id = BTreeMap::new();
        for type_ref in types {
            if by_id.insert(type_ref.id.as_str(), type_ref).is_some() {
                return Err(ProgramTypeIndexError::DuplicateId {
                    id: type_ref.id.clone(),
                });
            }
        }
        Ok(Self { by_id })
    }

    /// Builds an index over the `types` table of `metadata`.
    ///
    /// # Errors
    ///
    /// Same as [`ProgramTypeIndex::new`].
    pub fn from_metadata(metadata: &'a ProgramMetadata) -> Result<Self, ProgramTypeIndexError> {
        Self::new(&metadata.types)
    }

    /// Number of indexed types.
    pub fn len(&self) -> usize {
        self.by_id.len()
    }

    /// Returns `true` when the index holds no types.
    pub fn is_empty(&self) -> bool {
        self.by_id.is_empty()
    }

    /// Iterates over the indexed types in ascending id order.
    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &'a ProgramType> + '_ {
        self.by_id.values().copied()
    }

    /// Looks up a type by id, returning `None` when it is not indexed.
    pub fn get(&self, id: &str) -> Option<&'a ProgramType> {
        self.by_id.get(id).copied()
    }

    /// Looks up a type by id.
    ///
    /// # Errors
    ///
    /// Returns [`ProgramTypeIndexError::MissingType`] when `id` is not indexed.
    pub fn require(&self, id: &str) -> Result<&'a ProgramType, ProgramTypeIndexError> {
        self.get(id)
            .ok_or_else(|| ProgramTypeIndexError::MissingType { id: id.to_string() })
    }

    /// Resolves the pointee of the pointer type `id`.
    ///
    /// Returns `Ok(None)` when the pointer carries no pointee id.
    ///
    /// # Errors
    ///
    /// `MissingType` if `id` or the pointee id is not indexed;
    /// `UnexpectedKind` if `id` is not a pointer.
    pub fn pointee(&self, id: &str) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        let type_ref = self.require(id)?;
        match &type_ref.details {
            ProgramTypeDetails::Pointer { pointee_type_id } => self.optional_type(pointee_type_id),
            details => Err(unexpected_kind(type_ref, "pointer", details)),
        }
    }

    /// Resolves the element type of the array type `id`.
    ///
    /// Returns `Ok(None)` when the array carries no element type id.
    ///
    /// # Errors
    ///
    /// `MissingType` if `id` or the element id is not indexed;
    /// `UnexpectedKind` if `id` is not an array.
    pub fn array_element(
        &self,
        id: &str,
    ) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        let type_ref = self.require(id)?;
        match &type_ref.details {
            ProgramTypeDetails::Array {
                element_type_id, ..
            } => self.optional_type(element_type_id),
            details => Err(unexpected_kind(type_ref, "array", details)),
        }
    }

    /// Resolves the base type of the typedef `id`.
    ///
    /// Returns `Ok(None)` when the typedef carries no base type id.
    ///
    /// # Errors
    ///
    /// `MissingType` if `id` or the base id is not indexed;
    /// `UnexpectedKind` if `id` is not a typedef.
    pub fn typedef_base(&self, id: &str) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        let type_ref = self.require(id)?;
        match &type_ref.details {
            ProgramTypeDetails::Typedef { base_type_id } => self.optional_type(base_type_id),
            details => Err(unexpected_kind(type_ref, "typedef", details)),
        }
    }

    /// Resolves the base type of the bitfield `id`.
    ///
    /// Returns `Ok(None)` when the bitfield carries no base type id.
    ///
    /// # Errors
    ///
    /// `MissingType` if `id` or the base id is not indexed;
    /// `UnexpectedKind` if `id` is not a bitfield.
    pub fn bitfield_base(
        &self,
        id: &str,
    ) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        let type_ref = self.require(id)?;
        match &type_ref.details {
            ProgramTypeDetails::Bitfield { base_type_id, .. } => self.optional_type(base_type_id),
            details => Err(unexpected_kind(type_ref, "bitfield", details)),
        }
    }

    /// Resolves the type of a structure/union component (`Ok(None)` if it has no type id).
    ///
    /// # Errors
    ///
    /// `MissingType` if the component's type id is not indexed.
    pub fn component_type(
        &self,
        component: &ProgramTypeComponent,
    ) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        self.optional_type(&component.type_id)
    }

    /// Resolves a signature's return type (`Ok(None)` if it has no return type id).
    ///
    /// # Errors
    ///
    /// `MissingType` if the return type id is not indexed.
    pub fn signature_return_type(
        &self,
        signature: &ProgramFunctionSignature,
    ) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        self.optional_type(&signature.return_type_id)
    }

    /// Resolves a parameter's type (`Ok(None)` if it has no type id).
    ///
    /// # Errors
    ///
    /// `MissingType` if the parameter's type id is not indexed.
    pub fn parameter_type(
        &self,
        parameter: &ProgramParameter,
    ) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        // Previously `¶meter` (a mis-decoded `&para` entity), which did not compile.
        self.optional_type(&parameter.type_id)
    }

    /// Checks that every type id referenced by any indexed type resolves.
    ///
    /// Types are visited in ascending id order; absent (`None`) references are accepted.
    ///
    /// # Errors
    ///
    /// Returns the first [`ProgramTypeIndexError::MissingType`] encountered.
    pub fn validate_references(&self) -> Result<(), ProgramTypeIndexError> {
        for type_ref in self.iter() {
            match &type_ref.details {
                // Kinds that carry no type ids to follow.
                ProgramTypeDetails::Builtin
                | ProgramTypeDetails::Unknown
                | ProgramTypeDetails::Enum { .. } => {}
                ProgramTypeDetails::Pointer { pointee_type_id } => {
                    self.require_optional(pointee_type_id)?;
                }
                ProgramTypeDetails::Array {
                    element_type_id, ..
                } => {
                    self.require_optional(element_type_id)?;
                }
                ProgramTypeDetails::Structure { components }
                | ProgramTypeDetails::Union { components } => {
                    for component in components {
                        self.component_type(component)?;
                    }
                }
                ProgramTypeDetails::Typedef { base_type_id } => {
                    self.require_optional(base_type_id)?;
                }
                ProgramTypeDetails::FunctionDefinition { signature } => {
                    self.signature_return_type(signature)?;
                    for parameter in &signature.parameters {
                        self.parameter_type(parameter)?;
                    }
                }
                ProgramTypeDetails::Bitfield { base_type_id, .. } => {
                    self.require_optional(base_type_id)?;
                }
            }
        }
        Ok(())
    }

    /// Resolves an optional id: `None` maps to `Ok(None)`, `Some(id)` must be indexed.
    fn optional_type(
        &self,
        id: &Option<String>,
    ) -> Result<Option<&'a ProgramType>, ProgramTypeIndexError> {
        id.as_deref().map(|id| self.require(id)).transpose()
    }

    /// Like [`Self::optional_type`] but discards the resolved type.
    fn require_optional(&self, id: &Option<String>) -> Result<(), ProgramTypeIndexError> {
        self.optional_type(id).map(|_| ())
    }
}
impl ProgramMetadata {
pub fn type_index(&self) -> Result<ProgramTypeIndex<'_>, ProgramTypeIndexError> {
ProgramTypeIndex::from_metadata(self)
}
}
/// Builds a [`ProgramTypeIndexError::UnexpectedKind`] for `type_ref`, whose
/// `actual` details did not match the `expected` kind label.
fn unexpected_kind(
    type_ref: &ProgramType,
    expected: &'static str,
    actual: &ProgramTypeDetails,
) -> ProgramTypeIndexError {
    let actual = kind_name(actual);
    let id = type_ref.id.clone();
    ProgramTypeIndexError::UnexpectedKind {
        id,
        expected,
        actual,
    }
}
/// Returns the lowercase kind label reported in `UnexpectedKind` errors.
fn kind_name(details: &ProgramTypeDetails) -> &'static str {
    match *details {
        // Kinds with no nested type ids to follow.
        ProgramTypeDetails::Builtin => "builtin",
        ProgramTypeDetails::Unknown => "unknown",
        ProgramTypeDetails::Enum { .. } => "enum",
        // Kinds that wrap a single referenced type.
        ProgramTypeDetails::Pointer { .. } => "pointer",
        ProgramTypeDetails::Array { .. } => "array",
        ProgramTypeDetails::Typedef { .. } => "typedef",
        ProgramTypeDetails::Bitfield { .. } => "bitfield",
        // Kinds whose members reference other types.
        ProgramTypeDetails::Structure { .. } => "structure",
        ProgramTypeDetails::Union { .. } => "union",
        ProgramTypeDetails::FunctionDefinition { .. } => "function_definition",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ProgramFunctionSignature, ProgramParameter};

    #[test]
    fn rejects_duplicate_type_ids() {
        let types = vec![builtin("type:int"), builtin("type:int")];
        assert_eq!(
            ProgramTypeIndex::new(&types).expect_err("duplicate id is rejected"),
            ProgramTypeIndexError::DuplicateId {
                id: "type:int".to_string()
            }
        );
    }

    #[test]
    fn resolves_common_nested_type_references() {
        let types = vec![
            builtin("type:int"),
            ProgramType {
                id: "type:int_ptr".to_string(),
                name: "int *".to_string(),
                display_name: "int *".to_string(),
                size: 8,
                alignment: 8,
                category_path: None,
                details: ProgramTypeDetails::Pointer {
                    pointee_type_id: Some("type:int".to_string()),
                },
            },
            ProgramType {
                id: "type:int_array".to_string(),
                name: "int[4]".to_string(),
                display_name: "int[4]".to_string(),
                size: 16,
                alignment: 4,
                category_path: None,
                details: ProgramTypeDetails::Array {
                    element_type_id: Some("type:int".to_string()),
                    element_count: 4,
                    element_size: 4,
                },
            },
            ProgramType {
                id: "type:int_alias".to_string(),
                name: "int_alias".to_string(),
                display_name: "int_alias".to_string(),
                size: 4,
                alignment: 4,
                category_path: None,
                details: ProgramTypeDetails::Typedef {
                    base_type_id: Some("type:int".to_string()),
                },
            },
            ProgramType {
                id: "type:int_bit".to_string(),
                name: "int:3".to_string(),
                display_name: "int:3".to_string(),
                size: 1,
                alignment: 1,
                category_path: None,
                details: ProgramTypeDetails::Bitfield {
                    base_type_id: Some("type:int".to_string()),
                    bit_size: 3,
                    bit_offset: 0,
                    storage_size: 1,
                },
            },
            ProgramType {
                id: "type:record".to_string(),
                name: "Record".to_string(),
                display_name: "Record".to_string(),
                size: 4,
                alignment: 4,
                category_path: None,
                details: ProgramTypeDetails::Structure {
                    components: vec![component("value", "type:int")],
                },
            },
            ProgramType {
                id: "type:callback".to_string(),
                name: "callback".to_string(),
                display_name: "callback".to_string(),
                size: 1,
                alignment: 1,
                category_path: None,
                details: ProgramTypeDetails::FunctionDefinition {
                    signature: signature(),
                },
            },
        ];
        let index = ProgramTypeIndex::new(&types).expect("index builds");
        assert_eq!(index.len(), types.len());
        assert_eq!(
            index.pointee("type:int_ptr").unwrap().unwrap().id,
            "type:int"
        );
        assert_eq!(
            index.array_element("type:int_array").unwrap().unwrap().id,
            "type:int"
        );
        assert_eq!(
            index.typedef_base("type:int_alias").unwrap().unwrap().id,
            "type:int"
        );
        assert_eq!(
            index.bitfield_base("type:int_bit").unwrap().unwrap().id,
            "type:int"
        );
        let components = match &index.require("type:record").unwrap().details {
            ProgramTypeDetails::Structure { components } => components,
            _ => panic!("expected structure"),
        };
        assert_eq!(
            index.component_type(&components[0]).unwrap().unwrap().id,
            "type:int"
        );
        let signature = match &index.require("type:callback").unwrap().details {
            ProgramTypeDetails::FunctionDefinition { signature } => signature,
            _ => panic!("expected function definition"),
        };
        assert_eq!(
            index.signature_return_type(signature).unwrap().unwrap().id,
            "type:int"
        );
        assert_eq!(
            index
                .parameter_type(&signature.parameters[0])
                .unwrap()
                .unwrap()
                .id,
            "type:int"
        );
        index.validate_references().unwrap();
    }

    #[test]
    fn reports_missing_nested_references() {
        let types = vec![ProgramType {
            id: "type:int_ptr".to_string(),
            name: "int *".to_string(),
            display_name: "int *".to_string(),
            size: 8,
            alignment: 8,
            category_path: None,
            details: ProgramTypeDetails::Pointer {
                pointee_type_id: Some("type:missing".to_string()),
            },
        }];
        let index = ProgramTypeIndex::new(&types).expect("index builds");
        assert_eq!(
            index.validate_references(),
            Err(ProgramTypeIndexError::MissingType {
                id: "type:missing".to_string()
            })
        );
    }

    /// Regression: the parameter-type path must be resolved by `validate_references`.
    #[test]
    fn reports_missing_parameter_type() {
        let mut callback_signature = signature();
        callback_signature.parameters[0].type_id = Some("type:missing".to_string());
        let mut callback = builtin("type:callback");
        callback.details = ProgramTypeDetails::FunctionDefinition {
            signature: callback_signature,
        };
        let types = vec![builtin("type:int"), callback];
        let index = ProgramTypeIndex::new(&types).expect("index builds");
        assert_eq!(
            index.validate_references(),
            Err(ProgramTypeIndexError::MissingType {
                id: "type:missing".to_string()
            })
        );
    }

    #[test]
    fn require_reports_unknown_id() {
        let types = vec![builtin("type:int")];
        let index = ProgramTypeIndex::new(&types).expect("index builds");
        assert!(index.get("type:nope").is_none());
        assert_eq!(
            index.require("type:nope"),
            Err(ProgramTypeIndexError::MissingType {
                id: "type:nope".to_string()
            })
        );
    }

    #[test]
    fn absent_reference_resolves_to_none() {
        let mut opaque_ptr = builtin("type:opaque_ptr");
        opaque_ptr.details = ProgramTypeDetails::Pointer {
            pointee_type_id: None,
        };
        let types = vec![opaque_ptr];
        let index = ProgramTypeIndex::new(&types).expect("index builds");
        assert_eq!(index.pointee("type:opaque_ptr"), Ok(None));
        index.validate_references().unwrap();
    }

    #[test]
    fn reports_unexpected_kind_for_specific_traversal() {
        let types = vec![builtin("type:int")];
        let index = ProgramTypeIndex::new(&types).expect("index builds");
        assert_eq!(
            index.pointee("type:int"),
            Err(ProgramTypeIndexError::UnexpectedKind {
                id: "type:int".to_string(),
                expected: "pointer",
                actual: "builtin"
            })
        );
    }

    #[test]
    fn builds_from_metadata() {
        let metadata = ProgramMetadata {
            symbols: Vec::new(),
            functions: Vec::new(),
            types: vec![builtin("type:int")],
        };
        assert_eq!(metadata.type_index().unwrap().len(), 1);
    }

    fn builtin(id: &str) -> ProgramType {
        ProgramType {
            id: id.to_string(),
            name: id.to_string(),
            display_name: id.to_string(),
            size: 4,
            alignment: 4,
            category_path: None,
            details: ProgramTypeDetails::Builtin,
        }
    }

    fn component(name: &str, type_id: &str) -> ProgramTypeComponent {
        ProgramTypeComponent {
            ordinal: 0,
            name: name.to_string(),
            offset: 0,
            length: 4,
            type_id: Some(type_id.to_string()),
            bit_size: None,
            bit_offset: None,
            comment: None,
        }
    }

    fn signature() -> ProgramFunctionSignature {
        ProgramFunctionSignature {
            display: "int callback(int value)".to_string(),
            calling_convention: "__stdcall".to_string(),
            return_type_id: Some("type:int".to_string()),
            parameters: vec![ProgramParameter {
                ordinal: 0,
                name: "value".to_string(),
                type_id: Some("type:int".to_string()),
                storage: "unknown".to_string(),
            }],
            varargs: false,
            no_return: false,
        }
    }
}