use std::collections::{HashMap, HashSet};
use std::ffi::OsStr;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::Instant;
use quick_xml::reader::Reader;
use crate::protocol::{
Dialect, DialectId, DialectVersion, Enum, Message, MessageField, MessageId, Protocol,
};
use super::definition::DialectXmlDefinition;
use super::errors::{XmlInspectionError, XmlParseError};
use super::xml::XmlParser;
use crate::errors::Result;
use crate::parser::metadata::MavInspectMetadata;
use crate::protocol::DialectMetadata;
use crate::utils::{dialect_canonical_name, Buildable, Builder};
/// Holds the dialect XML definitions selected for parsing.
///
/// Instances are created through [`InspectorBuilder::build`]; [`Inspector::parse`] turns the
/// stored definitions into a [`Protocol`].
#[derive(Debug)]
pub struct Inspector {
/// Definitions that remained after discovery and include/exclude filtering.
definitions: Vec<DialectXmlDefinition>,
}
/// Collects the configuration used to construct an [`Inspector`].
///
/// The default value has no sources and no include/exclude filters.
#[derive(Clone, Debug, Default)]
pub struct InspectorBuilder {
/// Directories that are scanned for XML definitions when the inspector is built.
sources: Vec<PathBuf>,
/// When set, only dialects whose canonical name matches one of these names are kept.
include: Option<Vec<String>>,
/// When set, dialects whose canonical name matches one of these names are dropped.
exclude: Option<Vec<String>>,
}
impl InspectorBuilder {
    /// Discovers definitions in the configured sources and applies the include/exclude filters.
    ///
    /// Filter names and definition names are both passed through `dialect_canonical_name`
    /// before being compared. The include filter is applied first, then the exclude filter.
    ///
    /// # Errors
    ///
    /// Returns whatever error the definition discovery reports.
    pub fn build(&self) -> Result<Inspector> {
        // Canonicalizes a list of user-supplied dialect names.
        let canonicalize = |raw: &[String]| -> Vec<String> {
            raw.iter()
                .map(|name| dialect_canonical_name(name))
                .collect()
        };

        let source_dirs: Vec<&Path> = self.sources.iter().map(PathBuf::as_path).collect();
        let mut definitions = Inspector::discover(&source_dirs)?;

        if let Some(names) = &self.include {
            log::info!("Only the following dialects will be included: {names:?}");
            let allowed = canonicalize(names.as_slice());
            definitions.retain(|def| allowed.contains(&dialect_canonical_name(def.name())));
        }

        if let Some(names) = &self.exclude {
            log::info!("The following dialects will be excluded: {names:?}");
            let rejected = canonicalize(names.as_slice());
            definitions.retain(|def| !rejected.contains(&dialect_canonical_name(def.name())));
        }

        Ok(Inspector { definitions })
    }

    /// Replaces the list of source directories with `sources`.
    pub fn set_sources<T>(&mut self, sources: &[T]) -> &mut Self
    where
        T: Into<PathBuf> + Clone,
    {
        let converted: Vec<PathBuf> = sources.iter().map(|src| src.clone().into()).collect();
        self.sources = converted;
        self
    }

    /// Appends a single source directory to the list.
    pub fn add_source<T>(&mut self, source: T) -> &mut Self
    where
        T: Into<PathBuf> + Clone,
    {
        let source: PathBuf = source.into();
        self.sources.push(source);
        self
    }

    /// Restricts the built inspector to the dialects named in `dialect_names`.
    pub fn set_include<T: ?Sized + ToString>(&mut self, dialect_names: &[&T]) -> &mut Self {
        let names: Vec<String> = dialect_names.iter().map(|s| s.to_string()).collect();
        self.include = Some(names);
        self
    }

    /// Removes the dialects named in `dialect_names` from the built inspector.
    pub fn set_exclude<T: ?Sized + ToString>(&mut self, dialect_names: &[&T]) -> &mut Self {
        let names: Vec<String> = dialect_names.iter().map(|s| s.to_string()).collect();
        self.exclude = Some(names);
        self
    }
}
impl Inspector {
pub fn builder() -> InspectorBuilder {
InspectorBuilder::default()
}
/// Returns the definitions held by this inspector.
pub fn definitions(&self) -> &[DialectXmlDefinition] {
    self.definitions.as_slice()
}
/// Parses every stored definition into a [`Protocol`].
///
/// Definitions are parsed in order; a definition whose canonical name is already present in
/// the result (for instance because it was parsed as a dependency of an earlier one) is
/// skipped. Timing and the list of parsed dialects are reported at `info` level.
///
/// # Errors
///
/// Returns the first error produced while parsing a definition.
pub fn parse(&self) -> Result<Protocol> {
    let mut dialects: HashMap<String, Dialect> = HashMap::new();

    log::info!("Parsing dialects.");
    let started_at = Instant::now();

    for def in &self.definitions {
        if !dialects.contains_key(&def.canonical_name()) {
            Self::parse_definition(def, &mut dialects)?;
        }
    }

    let duration = started_at.elapsed();

    log::info!("All dialects parsed.");
    log::info!("Parsed dialects: {:?}", dialects.keys());
    log::info!(
        "Parse duration: {}s",
        (duration.as_micros() as f64) / 1000000.0
    );

    // The map is not needed afterwards, so move the dialects out instead of cloning each one.
    Ok(Protocol::new(dialects.into_values().collect()))
}
/// Returns the names (as reported by each definition's `name()`) of all dialects discovered
/// in `paths`.
///
/// # Errors
///
/// Returns whatever error the definition discovery reports.
pub fn discover_dialect_names<T: ?Sized + AsRef<OsStr>>(paths: &[&T]) -> Result<Vec<String>> {
    let names = Self::discover(paths)?
        .iter()
        .map(|definition| definition.name().to_string())
        .collect();
    Ok(names)
}
/// Returns the canonical names of all dialects discovered in `paths`.
///
/// # Errors
///
/// Returns whatever error the definition discovery reports.
pub fn discover_dialect_canonical_names<T: ?Sized + AsRef<OsStr>>(
    paths: &[&T],
) -> Result<Vec<String>> {
    Self::discover(paths).map(|definitions| {
        definitions
            .iter()
            .map(|definition| definition.canonical_name())
            .collect()
    })
}
/// Loads a definition from every `*.xml` file found directly inside the given directories.
///
/// Only regular files are considered and the extension is matched case-insensitively;
/// sub-directories are not entered.
///
/// # Errors
///
/// * An I/O error if a directory cannot be canonicalized or listed.
/// * `XmlInspectionError::InvalidPath` if the path of an XML file is not valid UTF-8.
/// * `XmlInspectionError::NamingCollision` if two definitions share a canonical name.
fn discover<T: ?Sized + AsRef<OsStr>>(paths: &[&T]) -> Result<Vec<DialectXmlDefinition>> {
    let mut dialects: Vec<DialectXmlDefinition> = Vec::new();
    // Canonical dialect name -> name of the definition that claimed it first.
    let mut dialect_ids: HashMap<String, String> = HashMap::new();

    for path in paths {
        let directory_path = Path::new(path).canonicalize()?;
        // `display()` cannot fail, so enabling debug logging never changes the outcome of
        // discovery for directories whose path is not valid UTF-8.
        log::debug!(
            "Entering XML definitions directory: {}",
            directory_path.display()
        );

        for entry in fs::read_dir(directory_path)? {
            let entry_path = entry?.path();

            let has_xml_extension = entry_path
                .extension()
                .and_then(OsStr::to_str)
                .map_or(false, |ext| ext.eq_ignore_ascii_case("xml"));
            if !has_xml_extension || !entry_path.is_file() {
                continue;
            }

            let file_path = entry_path
                .to_str()
                .ok_or(XmlInspectionError::InvalidPath)?
                .to_string();
            let definition = DialectXmlDefinition::load_from_path(&file_path);
            let canonical = definition.canonical_name();

            if let Some(existing) = dialect_ids.get(&canonical) {
                return Err(XmlInspectionError::NamingCollision {
                    first: definition.name().to_string(),
                    second: existing.clone(),
                    canonical,
                }
                .into());
            }

            dialect_ids.insert(canonical, definition.name().to_string());
            dialects.push(definition);
        }
    }

    Ok(dialects)
}
fn parse_definition(
definition: &DialectXmlDefinition,
dialects: &mut HashMap<String, Dialect>,
) -> Result<()> {
if dialects.contains_key(&definition.canonical_name()) {
return Ok(());
}
let metadata = Self::load_metadata(definition);
let mut enums: HashMap<String, Enum> = HashMap::new();
let mut messages: HashMap<MessageId, Message> = HashMap::new();
Self::load_dependencies(definition, dialects, &mut enums, &mut messages)?;
let started_at = Instant::now();
let mut parser: XmlParser = XmlParser::new(&mut enums, &mut messages);
let mut file_reader = Reader::from_file(definition.path()).unwrap();
parser.parse(definition.name(), &mut file_reader)?;
Self::update_messages_defined_in(
&mut messages,
&mut enums,
definition.canonical_name().as_str(),
);
Self::validate_field_enum_types(&enums, &messages)?;
let version = Self::derive_dialect_version(definition);
let dialect_id = Self::derive_dialect_id(definition);
let includes = Self::derive_includes(definition);
dialects.insert(
definition.canonical_name(),
Dialect::new(
definition.name(),
version,
dialect_id,
messages.values().cloned().collect(),
enums.values().cloned().collect(),
includes,
metadata,
),
);
Self::report_definition_parsing(definition, started_at, Instant::now());
Ok(())
}
/// Parses every definition included by `definition`, then merges the enums and messages of
/// those dialects into `enums` and `messages`.
///
/// # Errors
///
/// Returns the first error produced while parsing an included definition.
fn load_dependencies(
    definition: &DialectXmlDefinition,
    dialects: &mut HashMap<String, Dialect>,
    enums: &mut HashMap<String, Enum>,
    messages: &mut HashMap<MessageId, Message>,
) -> Result<()> {
    definition
        .includes()
        .into_iter()
        .try_for_each(|dependency| Self::parse_definition(dependency, dialects))?;

    Self::merge_enums(definition, dialects, enums);
    Self::merge_messages(definition, dialects, messages);

    Ok(())
}
/// Logs a summary of a parsed definition (name, path, version, dialect number and duration)
/// at `debug` level. Nothing is emitted when debug logging is disabled.
fn report_definition_parsing(
    definition: &DialectXmlDefinition,
    started_at: Instant,
    finished_at: Instant,
) {
    if !log::log_enabled!(log::Level::Debug) {
        return;
    }

    let elapsed = finished_at.duration_since(started_at);
    let seconds = (elapsed.as_micros() as f64) / 1000000.0;

    log::debug!("Parsed definition '{}'.", definition.name());
    log::debug!("Definition path: {}", definition.path());
    log::debug!("Definition version: {:?}", definition.version());
    log::debug!("Definition dialect #: {:?}", definition.dialect());
    log::debug!("Parse duration: {}s", seconds);
}
/// Returns the definition's own version, or, when it has none, the version of the last
/// included definition that specifies one.
fn derive_dialect_version(definition: &DialectXmlDefinition) -> Option<DialectVersion> {
    definition.version().or_else(|| {
        definition
            .includes()
            .into_iter()
            .filter_map(|dependency| dependency.version())
            .last()
    })
}
/// Returns the definition's own dialect id, or, when it has none, the dialect id of the last
/// included definition that specifies one.
///
/// The fallback inspects each include's *dialect id* (not its version), so an include that
/// has a version but no dialect id can no longer reset an id found in an earlier include,
/// and an include with an id but no version is no longer ignored.
fn derive_dialect_id(definition: &DialectXmlDefinition) -> Option<DialectId> {
    let own = definition.dialect();
    if own.is_some() {
        return own;
    }

    let mut inherited = None;
    for dependency in definition.includes() {
        let candidate = dependency.dialect();
        if candidate.is_some() {
            inherited = candidate;
        }
    }
    inherited
}
/// Returns the canonical names of the definitions directly included by `definition`.
fn derive_includes(definition: &DialectXmlDefinition) -> Vec<String> {
    definition
        .includes()
        .into_iter()
        .map(|included| included.canonical_name().to_string())
        .collect()
}
/// Checks every message field against every enum and fails on the first field whose type is
/// too small for the enum it refers to.
///
/// # Errors
///
/// Returns the error produced by `validate_field_enum_type` for the first offending pair.
fn validate_field_enum_types(
    enums: &HashMap<String, Enum>,
    messages: &HashMap<MessageId, Message>,
) -> Result<()> {
    for mav_enum in enums.values() {
        let all_fields = messages.values().flat_map(|msg| msg.fields());
        for field in all_fields {
            Self::validate_field_enum_type(mav_enum, field)?;
        }
    }
    Ok(())
}
/// Verifies that `field`, if it refers to `mav_enum` by name, has a type that is not smaller
/// than the enum's inferred type.
///
/// # Errors
///
/// Returns `XmlParseError::MessageFieldISTooSmallForEnum` when the field type compares less
/// than the enum's inferred type.
fn validate_field_enum_type(mav_enum: &Enum, field: &MessageField) -> Result<()> {
    let field_enum_name = match field.r#enum() {
        Some(name) => name,
        // Fields without an enum reference have nothing to validate.
        None => return Ok(()),
    };

    if field_enum_name != mav_enum.name() {
        return Ok(());
    }

    let required_type = mav_enum.inferred_type();
    if *field.r#type() < required_type {
        return Err(XmlParseError::MessageFieldISTooSmallForEnum {
            enum_name: mav_enum.name().into(),
            enum_type: required_type,
            field_name: field.name().into(),
            field_type: field.r#type().clone(),
        }
        .into());
    }

    Ok(())
}
/// Collects the enums of all dialects included by `definition`, groups them by enum name and
/// stores one merged enum per name in `enums`.
///
/// # Panics
///
/// Panics if an included definition has no entry in `dialects`.
fn merge_enums(
    definition: &DialectXmlDefinition,
    dialects: &mut HashMap<String, Dialect>,
    enums: &mut HashMap<String, Enum>,
) {
    let mut enums_variants_map: HashMap<&str, Vec<&Enum>> = HashMap::new();

    for dependency in definition.includes() {
        let dialect = dialects
            .get(&dependency.canonical_name())
            .expect("included dialects must be parsed before their enums are merged");

        for enm in dialect.enums() {
            // Every variant is recorded, including ones that share `defined_in` with a variant
            // seen earlier.
            // NOTE(review): an earlier inner loop compared `defined_in` values but its
            // `continue` never skipped the push. Confirm whether de-duplication is wanted
            // before adding it: it could drop entries that only one of the variants carries.
            enums_variants_map.entry(enm.name()).or_default().push(enm);
        }
    }

    let dialect_canonical_name = definition.canonical_name();
    for (enum_name, variants) in enums_variants_map {
        let enm = Self::merged_enum(variants.as_slice(), dialect_canonical_name.as_str());
        enums.insert(enum_name.to_string(), enm);
    }
}
/// Collects the messages of all dialects included by `definition`, groups them by message id
/// and stores one merged message per id in `messages`.
///
/// # Panics
///
/// Panics if an included definition has no entry in `dialects`.
fn merge_messages(
    definition: &DialectXmlDefinition,
    dialects: &mut HashMap<String, Dialect>,
    messages: &mut HashMap<MessageId, Message>,
) {
    let mut messages_variants_map: HashMap<MessageId, Vec<&Message>> = HashMap::new();

    for dependency in definition.includes() {
        let dialect = dialects
            .get(&dependency.canonical_name())
            .expect("included dialects must be parsed before their messages are merged");

        for msg in dialect.messages() {
            // Every variant is recorded, including ones that share `defined_in` with a variant
            // seen earlier.
            // NOTE(review): an earlier inner loop compared `defined_in` values but its
            // `continue` never skipped the push. Confirm whether de-duplication is wanted
            // before adding it.
            messages_variants_map.entry(msg.id()).or_default().push(msg);
        }
    }

    for (msg_id, variants) in messages_variants_map {
        let msg = Self::merged_message(variants.as_slice());
        messages.insert(msg_id, msg);
    }
}
/// Builds a single enum out of several variants that share a name.
///
/// Name, description, bitmask flag, deprecation and origin are taken from the last variant;
/// entries from all variants are combined, keyed by entry value, with later variants
/// overriding earlier ones. When more than one variant is given, the origin becomes that of
/// the first variant containing every combined entry value, or `dialect_canonical_name` if no
/// variant does.
fn merged_enum(variants: &[&Enum], dialect_canonical_name: &str) -> Enum {
    let mut builder = Enum::builder();
    let mut entries_by_value = HashMap::new();

    for &variant in variants {
        builder.set_name(variant.name());
        builder.set_description(variant.description());
        builder.set_bitmask(variant.bitmask());
        builder.set_deprecated(variant.deprecated().cloned());
        builder.set_defined_in(variant.defined_in());

        for entry in variant.entries() {
            entries_by_value.insert(entry.value(), entry.clone());
        }
    }

    let all_values: HashSet<u32> = entries_by_value.keys().copied().collect();
    let merged_entries: Vec<_> = entries_by_value.values().cloned().collect();
    builder.set_entries(merged_entries.as_slice());

    if variants.len() > 1 {
        let complete_variant = variants.iter().find(|variant| {
            let values: HashSet<u32> = variant
                .entries()
                .iter()
                .map(|entry| entry.value())
                .collect();
            values.is_superset(&all_values)
        });

        let origin = match complete_variant {
            Some(variant) => variant.defined_in(),
            None => dialect_canonical_name,
        };
        builder.set_defined_in(origin);
    }

    builder.build()
}
/// Picks one message out of several variants that share an id.
///
/// Starting from the last variant, the variants are scanned in order and the current choice
/// is replaced whenever a variant's `appears_in` set is a superset of the current one.
///
/// # Panics
///
/// Panics if `messages` is empty.
fn merged_message(messages: &[&Message]) -> Message {
    let dialects_of = |msg: &Message| -> HashSet<String> {
        msg.appears_in()
            .iter()
            .map(|s| s.as_ref().to_string())
            .collect()
    };

    let mut chosen: &Message = messages.last().copied().unwrap();
    let mut chosen_dialects = dialects_of(chosen);

    for &candidate in messages {
        let candidate_dialects = dialects_of(candidate);
        if candidate_dialects.is_superset(&chosen_dialects) {
            chosen = candidate;
            chosen_dialects = candidate_dialects;
        }
    }

    chosen.clone()
}
/// Re-assigns the origin of messages that depend on an enum defined in this dialect.
///
/// A message whose `defined_in` differs from `dialect_canonical_name` is rebuilt with
/// `defined_in` set to `dialect_canonical_name` as soon as one of its fields refers to an
/// enum (looked up by name in `enums`) whose `defined_in` equals `dialect_canonical_name`.
fn update_messages_defined_in(
    messages: &mut HashMap<MessageId, Message>,
    enums: &mut HashMap<String, Enum>,
    dialect_canonical_name: &str,
) {
    // Snapshot the keys so the map can be modified while walking over them.
    let ids = messages.keys().cloned().collect::<Vec<_>>();

    for id in ids {
        let message = match messages.get(&id) {
            Some(message) => message,
            None => continue,
        };
        if message.defined_in() == dialect_canonical_name {
            continue;
        }

        let mut uses_local_enum = false;
        for field in message.fields() {
            let enum_name = match field.r#enum() {
                Some(name) => name,
                None => continue,
            };
            if let Some(field_enum) = enums.get(&enum_name.to_string()) {
                if field_enum.defined_in() == dialect_canonical_name {
                    uses_local_enum = true;
                    break;
                }
            }
        }

        if uses_local_enum {
            let updated = message
                .to_builder()
                .set_defined_in(dialect_canonical_name)
                .build();
            messages.insert(id, updated);
        }
    }
}
/// Loads the metadata for `definition` from a `.dialects-metadata.yml` file located next to
/// the definition file.
///
/// Any failure along the way (no parent directory, unreadable file, YAML that does not
/// deserialize) is ignored and results in `DialectMetadata::default()`.
fn load_metadata(definition: &DialectXmlDefinition) -> DialectMetadata {
    let definition_path = PathBuf::from(definition.path());

    let loaded: Option<MavInspectMetadata> = definition_path
        .parent()
        .map(|directory| directory.join(".dialects-metadata.yml"))
        .and_then(|metadata_path| fs::read_to_string(metadata_path).ok())
        .and_then(|contents| serde_yml::from_str(&contents).ok());

    match loaded {
        Some(metadata) => metadata.metadata_for_dialect(definition.name()),
        None => DialectMetadata::default(),
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Directories with the XML definitions used by the tests.
    fn default_dialect_paths() -> Vec<&'static str> {
        let mut paths = Vec::with_capacity(2);
        paths.push("./message_definitions/standard");
        paths.push("./message_definitions/extra");
        paths
    }

    /// Asserts that the inspector holds exactly one definition, named `CrazyFlight`.
    fn assert_only_crazy_flight(inspector: &Inspector) {
        let definitions = inspector.definitions();
        assert!(!definitions.is_empty());
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name(), "CrazyFlight");
    }

    /// Asserts that the inspector holds some definitions, none of them named `CrazyFlight`.
    fn assert_without_crazy_flight(inspector: &Inspector) {
        let definitions = inspector.definitions();
        assert!(!definitions.is_empty());
        for def in definitions {
            assert_ne!(def.name(), "CrazyFlight");
        }
    }

    #[test]
    fn dialects_are_available() {
        let inspector = Inspector::builder()
            .set_sources(&default_dialect_paths())
            .build()
            .unwrap();

        assert!(!inspector.definitions().is_empty());
    }

    #[test]
    fn builder_can_add_sources() {
        let inspector = Inspector::builder()
            .add_source("./message_definitions/standard")
            .build()
            .unwrap();

        assert!(!inspector.definitions().is_empty());
    }

    #[test]
    fn inclusion_rules() {
        let inspector = Inspector::builder()
            .set_sources(&default_dialect_paths())
            .set_include(&["CrazyFlight"])
            .build()
            .unwrap();

        assert_only_crazy_flight(&inspector);
    }

    #[test]
    fn exclusion_rules() {
        let inspector = Inspector::builder()
            .set_sources(&default_dialect_paths())
            .set_exclude(&["CrazyFlight"])
            .build()
            .unwrap();

        assert_without_crazy_flight(&inspector);
    }

    #[test]
    fn inclusion_by_canonical_names() {
        let inspector = Inspector::builder()
            .set_sources(&default_dialect_paths())
            .set_include(&["crazy_flight"])
            .build()
            .unwrap();

        assert_only_crazy_flight(&inspector);
    }

    #[test]
    fn exclusion_by_canonical_name() {
        let inspector = Inspector::builder()
            .set_sources(&default_dialect_paths())
            .set_exclude(&["crazy_flight"])
            .build()
            .unwrap();

        assert_without_crazy_flight(&inspector);
    }
}