use std::collections::HashSet;
use std::ffi::OsStr;
use std::fs::remove_dir_all;
use std::path::{Path, PathBuf};
use std::sync::Arc;
extern crate cargo_manifest;
use crate::error::RustGenResult;
use cargo_manifest::{Manifest, Value};
use mavinspect::parser::InspectorBuilder;
use mavinspect::protocol::{Filter, Microservices, Protocol};
use mavinspect::Inspector;
use crate::generator::{Generator, GeneratorParams};
/// Settings for generating Rust sources from MAVLink definitions.
///
/// Values are populated through [`BuildHelperBuilder`]; [`BuildHelper::generate`] runs the
/// generator with them.
#[derive(Clone, Debug, Default)]
pub struct BuildHelper {
/// Output directory handed to the generator; `generate` tries to delete it before writing.
out_path: PathBuf,
/// Paths passed to the MAVLink inspector as definition sources.
sources: Option<Vec<PathBuf>>,
/// Path to a Cargo manifest whose `package.metadata.mavspec` table supplies extra settings.
manifest_path: Option<PathBuf>,
/// Dialect names forwarded to the inspector's include list.
include_dialects: Option<HashSet<String>>,
/// Dialect names forwarded to the inspector's exclude list.
exclude_dialects: Option<HashSet<String>>,
/// Message entries used to build the protocol filter.
messages: Option<Vec<String>>,
/// Enum entries used to build the protocol filter.
enums: Option<Vec<String>>,
/// Command entries used to build the protocol filter.
commands: Option<Vec<String>>,
/// Pre-built protocol; when set it is used as is instead of parsing `sources`.
protocol: Option<Arc<Protocol>>,
/// Microservice flags used to build the protocol filter.
microservices: Option<Microservices>,
/// Forwarded to `GeneratorParams::serde`.
serde: bool,
/// Forwarded to `GeneratorParams::specta`.
specta: bool,
/// Forwarded to `GeneratorParams::generate_tests`; `None` is treated as `false`.
generate_tests: Option<bool>,
/// Forwarded to `GeneratorParams::internal`.
internal: bool,
}
/// Builder for [`BuildHelper`].
///
/// Wraps the helper under construction; setters mutate the wrapped value and `build` returns
/// a clone of it (after applying manifest settings when a manifest path is set).
#[derive(Clone, Debug, Default)]
pub struct BuildHelperBuilder(BuildHelper);
impl BuildHelper {
pub fn builder<T: Into<PathBuf>>(out_path: T) -> BuildHelperBuilder {
BuildHelperBuilder(Self {
out_path: out_path.into(),
..Default::default()
})
}
/// Generates Rust code into the configured output directory.
///
/// The output directory is deleted first; a failure to delete it is only logged at debug
/// level. The protocol is then loaded and handed to [`Generator`] together with the
/// `serde`, `specta`, `generate_tests` and `internal` settings.
///
/// # Errors
///
/// Returns an error if loading the protocol or running the generator fails.
///
/// # Panics
///
/// Panics if neither a protocol nor sources were configured.
pub fn generate(&self) -> RustGenResult<()> {
    // Best-effort cleanup: the directory may simply not exist yet.
    let cleanup = remove_dir_all(&self.out_path);
    if let Err(err) = cleanup {
        log::debug!("Error while cleaning output directory: {err:?}");
    }

    let protocol = self.load_filtered_protocol()?;
    let params = GeneratorParams {
        serde: self.serde,
        specta: self.specta,
        generate_tests: self.generate_tests.unwrap_or(false),
        internal: self.internal,
        ..Default::default()
    };

    Generator::make(protocol, &self.out_path, params).generate()?;
    Ok(())
}
/// Returns the output directory.
pub fn out_path(&self) -> &Path {
    &self.out_path
}
pub fn sources(&self) -> Option<Vec<&Path>> {
self.sources
.as_ref()
.map(|sources| sources.iter().map(|src| src.as_path()).collect())
}
/// Returns the Cargo manifest path, or `None` when no manifest was configured.
pub fn manifest_path(&self) -> Option<&Path> {
self.manifest_path.as_deref()
}
/// Returns the names of dialects to include, or `None` when no include list was set.
pub fn include_dialects(&self) -> Option<HashSet<&str>> {
    let dialects = self.include_dialects.as_ref()?;
    Some(dialects.iter().map(String::as_str).collect())
}
/// Returns the names of dialects to exclude, or `None` when no exclude list was set.
pub fn exclude_dialects(&self) -> Option<HashSet<&str>> {
    let dialects = self.exclude_dialects.as_ref()?;
    Some(dialects.iter().map(String::as_str).collect())
}
/// Returns the message filter entries, or `None` when no message filter was set.
pub fn messages(&self) -> Option<Vec<&str>> {
    let names = self.messages.as_ref()?;
    Some(names.iter().map(String::as_str).collect())
}
/// Returns the enum filter entries, or `None` when no enum filter was set.
pub fn enums(&self) -> Option<Vec<&str>> {
    let names = self.enums.as_ref()?;
    Some(names.iter().map(String::as_str).collect())
}
/// Returns the command filter entries, or `None` when no command filter was set.
pub fn commands(&self) -> Option<Vec<&str>> {
    let names = self.commands.as_ref()?;
    Some(names.iter().map(String::as_str).collect())
}
/// Returns the microservice flags, or `None` when no microservices were configured.
pub fn microservices(&self) -> Option<&Microservices> {
self.microservices.as_ref()
}
/// Returns the explicitly provided protocol, or `None` when none was set.
pub fn protocol(&self) -> Option<&Protocol> {
    self.protocol.as_deref()
}
/// Returns the `serde` flag that is forwarded to the generator parameters.
pub fn serde(&self) -> bool {
self.serde
}
/// Returns the `specta` flag that is forwarded to the generator parameters.
pub fn specta(&self) -> bool {
self.specta
}
/// Returns whether test generation is enabled; an unset value counts as `false`.
pub fn generate_tests(&self) -> bool {
    matches!(self.generate_tests, Some(true))
}
/// Returns the `internal` flag that is forwarded to the generator parameters.
pub fn internal(&self) -> bool {
self.internal
}
/// Produces the protocol that code is generated from.
///
/// An explicitly set protocol is returned as is (a clone of the shared pointer), without
/// applying the entity filters. Otherwise the protocol is parsed from the configured
/// sources and then reduced by [`Self::retain_protocol_entities`].
///
/// # Errors
///
/// Returns an error if building the inspector or parsing the definitions fails.
///
/// # Panics
///
/// Panics if neither a protocol nor sources were configured.
fn load_filtered_protocol(&self) -> RustGenResult<Arc<Protocol>> {
    if let Some(shared) = &self.protocol {
        return Ok(Arc::clone(shared));
    }

    let mut parsed = self.make_mavlink_inspector_builder().build()?.parse()?;
    self.retain_protocol_entities(&mut parsed);
    Ok(Arc::new(parsed))
}
fn make_mavlink_inspector_builder(&self) -> InspectorBuilder {
let mut inspector_builder = Inspector::builder();
let sources: Vec<&Path> = self
.sources
.as_ref()
.unwrap()
.iter()
.map(|s| s.as_path())
.collect();
inspector_builder.set_sources(&sources);
if let Some(include_dialects) = &self.include_dialects {
inspector_builder
.set_include(&Vec::from_iter(include_dialects.iter().map(|d| d.as_str())));
}
if let Some(exclude_dialects) = &self.exclude_dialects {
inspector_builder
.set_exclude(&Vec::from_iter(exclude_dialects.iter().map(|d| d.as_str())));
}
inspector_builder
}
/// Reduces `protocol` in place to the entities selected by the configured microservices,
/// messages, enums and commands.
///
/// Settings that are `None` add nothing to the filter.
fn retain_protocol_entities(&self, protocol: &mut Protocol) {
    let filters = Filter::new();
    let filters = match self.microservices {
        Some(microservices) => filters.with_microservices(microservices),
        None => filters,
    };
    let filters = match &self.messages {
        Some(messages) => filters.with_messages(messages),
        None => filters,
    };
    let filters = match &self.enums {
        Some(enums) => filters.with_enums(enums),
        None => filters,
    };
    let filters = match &self.commands {
        Some(commands) => filters.with_commands(commands),
        None => filters,
    };

    protocol.retain(&filters);
}
/// Reads the Cargo manifest at `manifest_path` and applies the settings found under
/// `package.metadata.mavspec`.
///
/// Does nothing when no manifest path is set, or when the manifest has no package,
/// no metadata or no `mavspec` entry.
///
/// # Errors
///
/// Returns an error if the manifest cannot be loaded.
fn apply_manifest_config(&mut self) -> RustGenResult<()> {
    let manifest_path = match &self.manifest_path {
        Some(path) => path,
        None => return Ok(()),
    };
    let manifest = Manifest::from_path(manifest_path)?;

    let metadata = manifest.package.and_then(|package| package.metadata);
    if let Some(spec) = metadata.as_ref().and_then(|metadata| metadata.get("mavspec")) {
        self.apply_manifest_config_spec(spec);
    }

    Ok(())
}
/// Applies all settings of the manifest's `mavspec` value to this helper.
///
/// The list settings are delegated to the dedicated `apply_manifest_config_*` methods.
/// A boolean `generate_tests` entry always overwrites the current value.
fn apply_manifest_config_spec(&mut self, spec: &Value) {
    self.apply_manifest_config_messages(spec);
    self.apply_manifest_config_enums(spec);
    self.apply_manifest_config_commands(spec);
    self.apply_manifest_config_microservices(spec);

    // Non-boolean or missing entries leave the current setting untouched.
    if let Some(&Value::Boolean(flag)) = spec.get("generate_tests") {
        self.generate_tests = Some(flag);
    }
}
/// Takes the message filter from the `messages` array of `spec`.
///
/// Has no effect when a message filter is already set or when `spec` has no `messages`
/// array.
fn apply_manifest_config_messages(&mut self, spec: &Value) {
    if self.messages.is_some() {
        return;
    }
    if let Some(Value::Array(items)) = spec.get("messages") {
        // Each value is rendered to text and every double quote is removed from it.
        let names: Vec<String> = items
            .iter()
            .map(|item| item.to_string().replace('"', ""))
            .collect();
        self.messages = Some(names);
    }
}
/// Takes the enum filter from the `enums` array of `spec`.
///
/// Has no effect when an enum filter is already set or when `spec` has no `enums` array.
fn apply_manifest_config_enums(&mut self, spec: &Value) {
    if self.enums.is_some() {
        return;
    }
    if let Some(Value::Array(items)) = spec.get("enums") {
        // Each value is rendered to text and every double quote is removed from it.
        let names: Vec<String> = items
            .iter()
            .map(|item| item.to_string().replace('"', ""))
            .collect();
        self.enums = Some(names);
    }
}
/// Takes the command filter from the `commands` array of `spec`.
///
/// Has no effect when a command filter is already set or when `spec` has no `commands`
/// array.
fn apply_manifest_config_commands(&mut self, spec: &Value) {
    if self.commands.is_some() {
        return;
    }
    if let Some(Value::Array(items)) = spec.get("commands") {
        // Each value is rendered to text and every double quote is removed from it.
        let names: Vec<String> = items
            .iter()
            .map(|item| item.to_string().replace('"', ""))
            .collect();
        self.commands = Some(names);
    }
}
/// Takes the microservice flags from the `microservices` array of `spec`.
///
/// Has no effect when microservices are already set or when `spec` has no `microservices`
/// array. Names that are not present in `Microservices::flags_map()` are skipped.
fn apply_manifest_config_microservices(&mut self, spec: &Value) {
    if self.microservices.is_some() {
        return;
    }
    let entries = match spec.get("microservices") {
        Some(Value::Array(entries)) => entries,
        _ => return,
    };

    let mut selected = Microservices::default();
    let known_flags = Microservices::flags_map();
    for entry in entries {
        // Each value is rendered to text and every double quote is removed from it.
        let flag_name = entry.to_string().replace('"', "");
        if let Some(flag) = known_flags.get(flag_name.as_str()) {
            selected |= *flag;
        }
    }
    self.microservices = Some(selected);
}
}
impl BuildHelperBuilder {
pub fn new() -> Self {
Self::default()
}
/// Builds a [`BuildHelper`] from a copy of the current settings.
///
/// When a manifest path is set, the settings from the manifest are applied to the copy.
///
/// # Errors
///
/// Returns an error if applying the manifest configuration fails.
pub fn build(&self) -> RustGenResult<BuildHelper> {
    let has_manifest = self.0.manifest_path.is_some();
    let mut helper = self.0.clone();
    if has_manifest {
        helper.apply_manifest_config()?;
    }
    Ok(helper)
}
/// Builds the helper and immediately runs code generation with it.
///
/// # Errors
///
/// Returns an error if building the helper or generating the code fails.
pub fn generate(&self) -> RustGenResult<()> {
    let helper = self.build()?;
    helper.generate()
}
/// Sets the paths that are parsed as MAVLink definition sources.
///
/// Discards a protocol previously provided through `set_protocol`: an explicit protocol
/// takes precedence when the protocol is loaded, so keeping it would make the new sources
/// ineffective. A configured manifest path is left untouched.
pub fn set_sources<T>(&mut self, sources: &[T]) -> &mut Self
where
    T: ?Sized + Into<PathBuf> + Clone,
{
    self.0.sources = Some(sources.iter().cloned().map(Into::into).collect());
    // Sources and an explicit protocol are mutually exclusive (`set_protocol` clears the
    // sources in turn).
    self.0.protocol = None;
    self
}
/// Sets the path of the Cargo manifest that `build` reads additional settings from.
pub fn set_manifest_path<T: ?Sized + AsRef<OsStr>>(&mut self, manifest_path: &T) -> &mut Self {
    let path = Path::new(manifest_path).to_path_buf();
    self.0.manifest_path = Some(path);
    self
}
/// Sets the names of dialects to include, replacing any previously set include list.
pub fn set_include_dialects<T: ToString>(&mut self, include_dialects: &[T]) -> &mut Self {
    let names: HashSet<String> = include_dialects
        .iter()
        .map(ToString::to_string)
        .collect();
    self.0.include_dialects = Some(names);
    self
}
/// Sets the names of dialects to exclude, replacing any previously set exclude list.
///
/// The include list is not affected.
pub fn set_exclude_dialects<T: ToString>(&mut self, exclude_dialects: &[T]) -> &mut Self {
    let names: HashSet<String> = exclude_dialects
        .iter()
        .map(ToString::to_string)
        .collect();
    self.0.exclude_dialects = Some(names);
    self
}
/// Sets the message entries used to filter the parsed protocol.
pub fn set_messages<T: ToString>(&mut self, messages: &[T]) -> &mut Self {
    let names: Vec<String> = messages.iter().map(ToString::to_string).collect();
    self.0.messages = Some(names);
    self
}
/// Sets the enum entries used to filter the parsed protocol.
pub fn set_enums<T: ToString>(&mut self, enums: &[T]) -> &mut Self {
    let names: Vec<String> = enums.iter().map(ToString::to_string).collect();
    self.0.enums = Some(names);
    self
}
/// Sets the command entries used to filter the parsed protocol.
pub fn set_commands<T: ToString>(&mut self, commands: &[T]) -> &mut Self {
    let names: Vec<String> = commands.iter().map(ToString::to_string).collect();
    self.0.commands = Some(names);
    self
}
/// Sets the microservices used to filter the parsed protocol.
///
/// Each name is converted with `Microservices::from` and the results are combined with `|=`,
/// starting from `Microservices::default()`.
pub fn set_microservices<T: ToString>(&mut self, microservices: &[T]) -> &mut Self {
    let combined = microservices
        .iter()
        .fold(Microservices::default(), |mut flags, name| {
            flags |= Microservices::from(name.to_string());
            flags
        });
    self.0.microservices = Some(combined);
    self
}
/// Sets the `serde` flag that is forwarded to the generator parameters.
pub fn set_serde(&mut self, serde: bool) -> &mut Self {
self.0.serde = serde;
self
}
/// Sets the `specta` flag that is forwarded to the generator parameters.
pub fn set_specta(&mut self, specta: bool) -> &mut Self {
self.0.specta = specta;
self
}
/// Provides an already built protocol to generate code from.
///
/// Discards previously configured sources; the provided protocol is used as is.
pub fn set_protocol(&mut self, protocol: Protocol) -> &mut Self {
    let shared = Arc::new(protocol);
    self.0.sources = None;
    self.0.protocol = Some(shared);
    self
}
/// Explicitly enables or disables generation of tests.
///
/// NOTE(review): a boolean `generate_tests` entry in the manifest still overwrites this
/// value during `build` — confirm that this precedence is intended.
pub fn set_generate_tests(&mut self, generate_tests: bool) -> &mut Self {
self.0.generate_tests = Some(generate_tests);
self
}
/// Sets the `internal` flag that is forwarded to the generator parameters.
pub fn set_internal(&mut self, internal: bool) -> &mut Self {
self.0.internal = internal;
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::remove_dir_all;
    use std::path::Path;

    /// Generates code for the `minimal` dialect and removes the produced directory.
    #[test]
    fn build_helper_basic() {
        let out_path = "../tmp/mavlink/helper_basics";
        let sources = [
            PathBuf::from("message_definitions").join("standard"),
            PathBuf::from("message_definitions").join("test"),
        ];

        let mut builder = BuildHelper::builder(Path::new(out_path));
        builder
            .set_sources(&sources)
            .set_include_dialects(&["minimal"]);
        builder.generate().unwrap();

        remove_dir_all(out_path).unwrap();
    }

    /// Compile-time check: `builder` accepts the common path-like argument types.
    #[test]
    fn build_helper_new_generic() {
        let _from_str = BuildHelper::builder("../tmp/mavlink");
        let _from_string = BuildHelper::builder("../tmp/mavlink".to_string());
        let _from_path = BuildHelper::builder(Path::new("../tmp/mavlink"));
        let _from_path_buf = BuildHelper::builder(Path::new("../tmp").join("mavlink"));
    }

    /// Compile-time check: `set_sources` accepts the common path-like element types.
    #[test]
    fn build_helper_set_sources_generic() {
        let out_path = "../tmp/mavlink";

        BuildHelper::builder(out_path).set_sources(&["./message_definitions/standard"]);
        BuildHelper::builder(out_path)
            .set_sources(&["./message_definitions/standard".to_string()]);
        BuildHelper::builder(out_path)
            .set_sources(&[Path::new("./message_definitions/standard")]);
        BuildHelper::builder(out_path)
            .set_sources(&[Path::new("./message_definitions").join("extra")]);
    }

    /// Checks which entities of the `common` dialect remain after filtering.
    #[test]
    fn build_helper_protocol_filtering() {
        let out_path = "../tmp/mavlink/protocol_filtering";
        let sources = [
            PathBuf::from("message_definitions").join("standard"),
            PathBuf::from("message_definitions").join("test"),
        ];

        let mut builder = BuildHelper::builder(Path::new(out_path));
        builder
            .set_sources(&sources)
            .set_microservices(&["HEARTBEAT", "FTP", "GIMBAL_V1"])
            .set_messages(&["PROTOCOL_VERSION", "MAV_INSPECT_V1"])
            .set_commands(&["MAV_CMD_DO_CHANGE_SPEED", "MAV_CMD_DO_SET_ROI*"])
            .set_enums(&["STORAGE_STATUS", "GIMBAL_*"])
            .set_include_dialects(&["minimal", "standard", "common", "mav_inspect_test"]);
        let helper = builder.build().unwrap();
        let protocol = helper.load_filtered_protocol().unwrap();

        let dialect = protocol.get_dialect_by_canonical_name("common").unwrap();

        // The command enum itself has to be retained.
        assert!(dialect.contains_enum_with_name("MAV_CMD"));
        let mav_cmd = dialect.get_enum_by_name("MAV_CMD").unwrap();

        // Messages of the command protocol.
        assert!(dialect.contains_message_with_name("COMMAND_LONG"));
        assert!(dialect.contains_message_with_name("COMMAND_INT"));
        assert!(dialect.contains_message_with_name("COMMAND_ACK"));
        assert!(dialect.contains_message_with_name("COMMAND_CANCEL"));

        // Enums expected to be retained along with the selected commands.
        assert!(dialect.contains_enum_with_name("MAV_FRAME"));
        assert!(dialect.contains_enum_with_name("SPEED_TYPE"));
        assert!(dialect.contains_enum_with_name("MAV_ROI"));

        // Commands selected by exact name or by the `MAV_CMD_DO_SET_ROI*` pattern.
        assert!(mav_cmd.has_entry_with_name("MAV_CMD_DO_CHANGE_SPEED"));
        assert!(mav_cmd.has_entry_with_name("MAV_CMD_DO_SET_ROI"));
        assert!(mav_cmd.has_entry_with_name("MAV_CMD_DO_SET_ROI_LOCATION"));
        assert!(mav_cmd.has_entry_with_name("MAV_CMD_DO_SET_ROI_NONE"));

        // Commands that were not selected have to be gone.
        assert!(!mav_cmd.has_entry_with_name("MAV_CMD_DO_INVERTED_FLIGHT"));
        assert!(!mav_cmd.has_entry_with_name("MAV_CMD_DO_GRIPPER"));
        assert!(!mav_cmd.has_entry_with_name("MAV_CMD_PREFLIGHT_CALIBRATION"));
    }
}