use crate::metadata::{
method::{Method, MethodMap, MethodRc, MethodRefList},
token::Token,
typesystem::{CilType, CilTypeRc, TypeRegistry},
};
/// Boxed predicate over a [`Method`]; a [`MethodQuery`] keeps a method only if
/// every one of its predicates returns `true`.
type MethodFilter<'a> = Box<dyn Fn(&Method) -> bool + 'a>;
/// Boxed predicate over a [`CilType`]; a [`TypeQuery`] keeps a type only if
/// every one of its predicates returns `true`.
type TypeFilter<'a> = Box<dyn Fn(&CilType) -> bool + 'a>;
/// Builder-style query over the types held in a [`TypeRegistry`].
///
/// Filters accumulate through chained calls and are applied lazily by the
/// terminal methods; a type matches only if it satisfies every filter.
pub struct TypeQuery<'a> {
    /// Predicates added by the builder methods, evaluated in insertion order.
    filters: Vec<TypeFilter<'a>>,
    /// Registry whose entries are scanned on every terminal call.
    registry: &'a TypeRegistry,
}
impl<'a> TypeQuery<'a> {
    /// Creates a query over `registry` with no filters; until filters are
    /// added it matches every type the registry yields.
    pub fn new(registry: &'a TypeRegistry) -> Self {
        Self {
            registry,
            filters: Vec::new(),
        }
    }

    /// Returns `true` if `t` satisfies every filter added so far (vacuously
    /// `true` when there are none). Evaluation stops at the first rejecting
    /// filter.
    fn accepts(&self, t: &CilType) -> bool {
        self.filters.iter().all(|f| f(t))
    }

    /// Keeps only types that are not type references, i.e. those for which
    /// [`CilType::is_typeref`] returns `false`.
    #[must_use]
    pub fn defined(mut self) -> Self {
        self.filters.push(Box::new(|t| !t.is_typeref()));
        self
    }

    /// Keeps only types for which [`CilType::is_class`] returns `true`.
    #[must_use]
    pub fn classes(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_class));
        self
    }

    /// Keeps only types for which [`CilType::is_interface`] returns `true`.
    #[must_use]
    pub fn interfaces(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_interface));
        self
    }

    /// Keeps only types for which [`CilType::is_value_type`] returns `true`.
    #[must_use]
    pub fn value_types(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_value_type));
        self
    }

    /// Keeps only types for which [`CilType::is_enum`] returns `true`.
    #[must_use]
    pub fn enums(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_enum));
        self
    }

    /// Keeps only types for which [`CilType::is_delegate`] returns `true`.
    #[must_use]
    pub fn delegates(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_delegate));
        self
    }

    /// Keeps only types for which [`CilType::is_public`] returns `true`.
    #[must_use]
    pub fn public(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_public));
        self
    }

    /// Keeps only types for which [`CilType::is_internal`] returns `true`.
    #[must_use]
    pub fn internal(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_internal));
        self
    }

    /// Keeps only types for which [`CilType::is_sealed`] returns `true`.
    #[must_use]
    pub fn sealed(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_sealed));
        self
    }

    /// Keeps only types for which [`CilType::is_abstract`] returns `true`.
    #[must_use]
    pub fn abstract_types(mut self) -> Self {
        self.filters.push(Box::new(CilType::is_abstract));
        self
    }

    /// Keeps only types whose `namespace` is exactly `ns`.
    #[must_use]
    pub fn namespace(mut self, ns: &'a str) -> Self {
        self.filters.push(Box::new(move |t| t.namespace == ns));
        self
    }

    /// Keeps only types whose `namespace` starts with `prefix`.
    ///
    /// This is a plain string-prefix test, so `"System"` also matches a
    /// namespace such as `"SystemX"`.
    #[must_use]
    pub fn namespace_prefix(mut self, prefix: &'a str) -> Self {
        self.filters
            .push(Box::new(move |t| t.namespace.starts_with(prefix)));
        self
    }

    /// Keeps only types whose simple `name` is exactly `name`.
    #[must_use]
    pub fn name(mut self, name: &'a str) -> Self {
        self.filters.push(Box::new(move |t| t.name == name));
        self
    }

    /// Keeps only types whose `name` contains `substr` (case-sensitive).
    #[must_use]
    pub fn name_contains(mut self, substr: &'a str) -> Self {
        self.filters
            .push(Box::new(move |t| t.name.contains(substr)));
        self
    }

    /// Keeps only types whose [`CilType::fullname`] equals `fqn`.
    #[must_use]
    pub fn fullname(mut self, fqn: &'a str) -> Self {
        self.filters.push(Box::new(move |t| t.fullname() == fqn));
        self
    }

    /// Keeps only types with a non-empty `methods` list.
    #[must_use]
    pub fn has_methods(mut self) -> Self {
        self.filters.push(Box::new(|t| !t.methods.is_empty()));
        self
    }

    /// Keeps only types with a non-empty `fields` list.
    #[must_use]
    pub fn has_fields(mut self) -> Self {
        self.filters.push(Box::new(|t| !t.fields.is_empty()));
        self
    }

    /// Keeps only types whose `enclosing_type` is set.
    #[must_use]
    pub fn nested(mut self) -> Self {
        self.filters
            .push(Box::new(|t| t.enclosing_type.get().is_some()));
        self
    }

    /// Keeps only types whose `enclosing_type` is not set.
    #[must_use]
    pub fn top_level(mut self) -> Self {
        self.filters
            .push(Box::new(|t| t.enclosing_type.get().is_none()));
        self
    }

    /// Keeps only types for which [`CilType::base`] returns `Some`.
    #[must_use]
    pub fn has_base_type(mut self) -> Self {
        self.filters.push(Box::new(|t| t.base().is_some()));
        self
    }

    /// Keeps only types with at least one entry in `generic_params`.
    #[must_use]
    pub fn generic(mut self) -> Self {
        self.filters
            .push(Box::new(|t| !t.generic_params.is_empty()));
        self
    }

    /// Adds an arbitrary predicate; types for which `f` returns `false` are
    /// rejected.
    #[must_use]
    pub fn filter(mut self, f: impl Fn(&CilType) -> bool + 'a) -> Self {
        self.filters.push(Box::new(f));
        self
    }

    /// Collects every matching type into a `Vec`.
    #[must_use]
    pub fn find_all(&self) -> Vec<CilTypeRc> {
        self.iter().collect()
    }

    /// Returns the first matching type in registry iteration order, if any.
    #[must_use]
    pub fn find_first(&self) -> Option<CilTypeRc> {
        self.iter().next()
    }

    /// Counts the matching types; this scans the whole registry.
    #[must_use]
    pub fn count(&self) -> usize {
        self.iter().count()
    }

    /// Returns `true` if at least one type matches; stops at the first match.
    #[must_use]
    pub fn exists(&self) -> bool {
        self.iter().next().is_some()
    }

    /// Returns the `token` of every matching type.
    #[must_use]
    pub fn tokens(&self) -> Vec<Token> {
        self.iter().map(|t| t.token).collect()
    }

    /// Lazily yields the types that satisfy every filter.
    ///
    /// Each call re-scans the registry. A type's handle is cloned only after
    /// the type has passed the filters, so rejected entries cost no clone.
    pub fn iter(&self) -> impl Iterator<Item = CilTypeRc> + '_ {
        self.registry.iter().filter_map(move |entry| {
            let candidate = entry.value();
            self.accepts(candidate).then(|| candidate.clone())
        })
    }

    /// Pivots to a [`MethodQuery`] over the methods of every type this query
    /// matches.
    ///
    /// The method handles are collected eagerly, so the returned query does not
    /// consult the registry again. Method references whose `upgrade()` returns
    /// `None` are skipped.
    #[must_use]
    pub fn methods(self) -> MethodQuery<'a> {
        let mut methods: Vec<MethodRc> = Vec::new();
        for t in self.iter() {
            // `t.methods.iter()` borrows the handle owned by this iteration, so
            // it is drained here; extending in place avoids an intermediate Vec
            // per type.
            methods.extend(
                t.methods
                    .iter()
                    .filter_map(|(_, method_ref)| method_ref.upgrade()),
            );
        }
        MethodQuery::from_collected(methods)
    }
}
/// Where a [`MethodQuery`] draws its candidate methods from.
enum MethodQuerySource<'a> {
    /// A list of method handles collected up front and owned by the query.
    Collected(Vec<MethodRc>),
    /// Every entry of a borrowed [`MethodMap`], scanned on each terminal call.
    Assembly(&'a MethodMap),
}
/// Builder-style query over methods, drawn either from a borrowed
/// [`MethodMap`] or from a pre-collected list of method handles.
///
/// Filters accumulate through chained calls; a method matches only if it
/// satisfies every filter.
pub struct MethodQuery<'a> {
    /// Candidate methods that are tested against `filters`.
    source: MethodQuerySource<'a>,
    /// Predicates added by the builder methods, evaluated in insertion order.
    filters: Vec<MethodFilter<'a>>,
}
impl<'a> MethodQuery<'a> {
    /// Creates a query over every method stored in `methods`, borrowed for `'a`.
    pub fn from_assembly(methods: &'a MethodMap) -> Self {
        Self {
            source: MethodQuerySource::Assembly(methods),
            filters: Vec::new(),
        }
    }

    /// Creates a query over the methods referenced by `methods`.
    ///
    /// References are resolved eagerly; any whose `upgrade()` returns `None`
    /// are skipped.
    pub fn from_type(methods: &MethodRefList) -> Self {
        Self::from_collected(
            methods
                .iter()
                .filter_map(|(_, method_ref)| method_ref.upgrade())
                .collect(),
        )
    }

    /// Creates a query over an already-collected list of method handles.
    fn from_collected(methods: Vec<MethodRc>) -> Self {
        Self {
            source: MethodQuerySource::Collected(methods),
            filters: Vec::new(),
        }
    }

    /// Returns `true` if `m` satisfies every filter added so far (vacuously
    /// `true` when there are none). Evaluation stops at the first rejecting
    /// filter.
    fn accepts(&self, m: &Method) -> bool {
        self.filters.iter().all(|f| f(m))
    }

    /// Keeps only methods for which [`Method::is_public`] returns `true`.
    #[must_use]
    pub fn public(mut self) -> Self {
        self.filters.push(Box::new(Method::is_public));
        self
    }

    /// Keeps only methods for which [`Method::is_static`] returns `true`.
    #[must_use]
    pub fn static_methods(mut self) -> Self {
        self.filters.push(Box::new(Method::is_static));
        self
    }

    /// Keeps only methods for which [`Method::is_static`] returns `false`.
    #[must_use]
    pub fn instance(mut self) -> Self {
        self.filters.push(Box::new(|m| !m.is_static()));
        self
    }

    /// Keeps only methods for which [`Method::is_virtual`] returns `true`.
    #[must_use]
    pub fn virtual_methods(mut self) -> Self {
        self.filters.push(Box::new(Method::is_virtual));
        self
    }

    /// Keeps only methods for which [`Method::is_abstract`] returns `true`.
    #[must_use]
    pub fn abstract_methods(mut self) -> Self {
        self.filters.push(Box::new(Method::is_abstract));
        self
    }

    /// Keeps only methods for which [`Method::is_ctor`] returns `true`.
    #[must_use]
    pub fn constructors(mut self) -> Self {
        self.filters.push(Box::new(Method::is_ctor));
        self
    }

    /// Keeps only methods for which [`Method::is_cctor`] returns `true`.
    #[must_use]
    pub fn static_constructors(mut self) -> Self {
        self.filters.push(Box::new(Method::is_cctor));
        self
    }

    /// Keeps only methods whose `name` is exactly `name`.
    #[must_use]
    pub fn name(mut self, name: &'a str) -> Self {
        self.filters.push(Box::new(move |m| m.name == name));
        self
    }

    /// Keeps only methods whose `name` contains `substr` (case-sensitive).
    #[must_use]
    pub fn name_contains(mut self, substr: &'a str) -> Self {
        self.filters
            .push(Box::new(move |m| m.name.contains(substr)));
        self
    }

    /// Keeps only methods for which [`Method::has_body`] returns `true`.
    #[must_use]
    pub fn has_body(mut self) -> Self {
        self.filters.push(Box::new(Method::has_body));
        self
    }

    /// Keeps only methods for which [`Method::has_body`] returns `false`.
    #[must_use]
    pub fn without_body(mut self) -> Self {
        self.filters.push(Box::new(|m| !m.has_body()));
        self
    }

    /// Keeps only methods for which [`Method::is_code_native`] returns `true`.
    #[must_use]
    pub fn native(mut self) -> Self {
        self.filters.push(Box::new(Method::is_code_native));
        self
    }

    /// Keeps only methods for which [`Method::is_code_il`] returns `true`.
    #[must_use]
    pub fn il(mut self) -> Self {
        self.filters.push(Box::new(Method::is_code_il));
        self
    }

    /// Keeps only methods for which [`Method::is_pinvoke`] returns `true`.
    #[must_use]
    pub fn pinvoke(mut self) -> Self {
        self.filters.push(Box::new(Method::is_pinvoke));
        self
    }

    /// Keeps only methods whose signature has at least `n` parameters
    /// (inclusive).
    #[must_use]
    pub fn min_params(mut self, n: usize) -> Self {
        self.filters
            .push(Box::new(move |m| m.signature.params.len() >= n));
        self
    }

    /// Keeps only methods whose signature has at most `n` parameters
    /// (inclusive).
    #[must_use]
    pub fn max_params(mut self, n: usize) -> Self {
        self.filters
            .push(Box::new(move |m| m.signature.params.len() <= n));
        self
    }

    /// Keeps only methods whose `declaring_type_fullname()` equals `type_name`.
    ///
    /// Methods for which `declaring_type_fullname()` returns `None` are
    /// rejected.
    #[must_use]
    pub fn declaring_type(mut self, type_name: &'a str) -> Self {
        self.filters.push(Box::new(move |m| {
            m.declaring_type_fullname().is_some_and(|n| n == type_name)
        }));
        self
    }

    /// Keeps only methods for which [`Method::is_event_handler`] returns `true`.
    #[must_use]
    pub fn event_handlers(mut self) -> Self {
        self.filters.push(Box::new(Method::is_event_handler));
        self
    }

    /// Adds an arbitrary predicate; methods for which `f` returns `false` are
    /// rejected.
    #[must_use]
    pub fn filter(mut self, f: impl Fn(&Method) -> bool + 'a) -> Self {
        self.filters.push(Box::new(f));
        self
    }

    /// Collects every matching method into a `Vec`.
    #[must_use]
    pub fn find_all(&self) -> Vec<MethodRc> {
        self.iter().collect()
    }

    /// Returns the first matching method in source iteration order, if any.
    #[must_use]
    pub fn find_first(&self) -> Option<MethodRc> {
        self.iter().next()
    }

    /// Counts the matching methods; this scans the whole source.
    #[must_use]
    pub fn count(&self) -> usize {
        self.iter().count()
    }

    /// Returns `true` if at least one method matches; stops at the first match.
    #[must_use]
    pub fn exists(&self) -> bool {
        self.iter().next().is_some()
    }

    /// Returns the `token` of every matching method.
    #[must_use]
    pub fn tokens(&self) -> Vec<Token> {
        self.iter().map(|m| m.token).collect()
    }

    /// Lazily yields the methods that satisfy every filter.
    ///
    /// Each call re-scans the source. A method's handle is cloned only after
    /// the method has passed the filters. Each source produces a single boxed
    /// iterator rather than one box wrapped in another.
    #[must_use]
    pub fn iter(&self) -> Box<dyn Iterator<Item = MethodRc> + '_> {
        match &self.source {
            MethodQuerySource::Assembly(map) => Box::new(map.iter().filter_map(move |entry| {
                let candidate = entry.value();
                self.accepts(candidate).then(|| candidate.clone())
            })),
            MethodQuerySource::Collected(methods) => {
                Box::new(methods.iter().filter(move |m| self.accepts(m)).cloned())
            }
        }
    }
}
impl<'b> IntoIterator for &'b MethodQuery<'_> {
type Item = MethodRc;
type IntoIter = Box<dyn Iterator<Item = MethodRc> + 'b>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
mod tests {
    use crate::CilObject;

    /// Loads the WindowsBase.dll sample that every test in this module runs
    /// against; panics if it cannot be loaded.
    fn windows_base() -> CilObject {
        CilObject::from_path("tests/samples/WindowsBase.dll").expect("Failed to load assembly")
    }

    #[test]
    fn test_type_query_defined_filters_typerefs() {
        let asm = windows_base();
        let total = asm.types().all_types().len();
        let defined = asm.query_types().defined().find_all();
        assert!(defined.len() < total);
        assert!(!defined.iter().any(|t| t.is_typeref()));
    }

    #[test]
    fn test_type_query_chained_filters() {
        let asm = windows_base();
        let matched = asm.query_types().defined().public().find_all();
        for ty in matched.iter() {
            assert!(!ty.is_typeref());
            assert!(ty.is_public());
        }
    }

    #[test]
    fn test_type_query_exists_and_count() {
        let asm = windows_base();
        let any_defined = asm.query_types().defined().exists();
        assert!(any_defined);
        let defined_count = asm.query_types().defined().count();
        assert!(defined_count > 0);
    }

    #[test]
    fn test_method_query_from_assembly() {
        let asm = windows_base();
        let statics = asm.query_methods().static_methods().find_all();
        assert!(!statics.is_empty());
        assert!(!statics.iter().any(|m| !m.is_static()));
    }

    #[test]
    fn test_method_query_static_constructors() {
        let asm = windows_base();
        let cctors = asm.query_methods().static_constructors().find_all();
        assert!(!cctors.iter().any(|m| !m.is_cctor()));
    }

    #[test]
    fn test_type_query_methods_pivot() {
        let asm = windows_base();
        let public_types = asm.query_types().defined().public();
        let pivoted = public_types.methods().public().find_all();
        assert!(!pivoted.is_empty());
        assert!(!pivoted.iter().any(|m| !m.is_public()));
    }

    #[test]
    fn test_method_query_by_name() {
        let asm = windows_base();
        let ctors = asm.query_methods().name(".ctor").find_all();
        assert!(!ctors.is_empty());
        assert!(!ctors.iter().any(|m| m.name != ".ctor"));
    }

    #[test]
    fn test_type_query_tokens() {
        let asm = windows_base();
        let token_count = asm.query_types().defined().public().tokens().len();
        let type_count = asm.query_types().defined().public().find_all().len();
        assert_eq!(token_count, type_count);
    }

    #[test]
    fn test_type_query_from_type_methods() {
        let asm = windows_base();
        let Some(ty) = asm.query_types().defined().has_methods().find_first() else {
            return;
        };
        let methods = ty.query_methods().find_all();
        assert!(!methods.is_empty());
    }
}