#[cfg(test)]
#[path = "mom_files_test.rs"]
mod mom_files_test;
use crate::cli::Version;
use crate::merge_map_values;
use crate::serde_common::CommonFields;
use crate::tasks::Task;
use crate::types::DynErrResult;
use crate::utils::{get_task_dependency_graph, to_os_task_name};
use lazy_static::lazy_static;
use petgraph::algo::toposort;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::{env, fs};
/// In-memory representation of a mom file.
///
/// `filepath` and `directory` are marked `skip_deserializing`, so parsing leaves them at
/// `PathBuf::default()` (empty); they must be filled in by the caller after parsing.
///
/// NOTE(review): serde documents `deny_unknown_fields` as not fully supported in combination
/// with `flatten` (used on `common`); confirm unknown top-level keys are really rejected.
#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub(crate) struct MomFile {
    /// Value of the file's `version` key.
    pub(crate) version: Version,
    /// Path the file was loaded from; empty unless set after parsing.
    #[serde(skip_deserializing)]
    pub(crate) filepath: PathBuf,
    /// Directory used when setting up `common` and the tasks; empty unless set after parsing.
    #[serde(skip_deserializing)]
    pub(crate) directory: PathBuf,
    /// `CommonFields` keys, read from the top level of the file (flattened).
    #[serde(flatten)]
    pub(crate) common: CommonFields,
    /// Tasks keyed by name, with names validated by `deserialize_tasks`; empty when absent.
    #[serde(default, deserialize_with = "deserialize_tasks")]
    pub(crate) tasks: HashMap<String, Task>,
}
/// Deserializes the `tasks` mapping of a mom file.
///
/// Every key must match `^[_a-zA-Z][a-zA-Z0-9_-]*(\.(windows|linux|macos))?$`: an identifier
/// optionally followed by an OS suffix. The first key that does not match aborts
/// deserialization with a custom serde error.
fn deserialize_tasks<'de, D>(deserializer: D) -> Result<HashMap<String, Task>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    struct TaskMapVisitor;

    impl<'de> serde::de::Visitor<'de> for TaskMapVisitor {
        type Value = HashMap<String, Task>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a map of tasks")
        }

        fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
        where
            M: serde::de::MapAccess<'de>,
        {
            // Compiled lazily, once, and shared by every call.
            lazy_static! {
                static ref KEY_REGEX: regex::Regex =
                    regex::Regex::new(r"^[_a-zA-Z][a-zA-Z0-9_-]*(\.(windows|linux|macos))?$")
                        .unwrap();
            }

            let mut parsed = HashMap::new();
            loop {
                let (name, task) = match access.next_entry::<String, Task>()? {
                    Some(entry) => entry,
                    None => return Ok(parsed),
                };
                if !KEY_REGEX.is_match(&name) {
                    return Err(serde::de::Error::custom(format!(
                        "Invalid task name `{}`. Task names must start with a letter or underscore and can only \
                         contain letters, numbers, underscores and dashes. They can also end with .windows, .linux \
                         or .macos to specify an OS specific task.",
                        name
                    )));
                }
                parsed.insert(name, task);
            }
        }
    }

    deserializer.deserialize_map(TaskMapVisitor)
}
impl MomFile {
    /// Reads the file at `path` and parses it as YAML.
    ///
    /// Only the YAML-backed fields are populated; `filepath` and `directory` keep their
    /// defaults and no setup is performed.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or its contents do not parse.
    fn deserialize_from_path(path: &Path) -> DynErrResult<MomFile> {
        let contents = fs::read_to_string(path)
            .map_err(|e| format!("There was an error reading the file:\n{}", e))?;
        Ok(serde_yaml::from_str(&contents)?)
    }

    /// Parses `contents` as YAML without touching the filesystem (tests only).
    #[cfg(test)]
    fn deserialize_from_str(contents: &str) -> DynErrResult<MomFile> {
        Ok(serde_yaml::from_str(contents)?)
    }

    /// Loads, parses and sets up the mom file at `path`.
    ///
    /// `filepath` is set to `path` and `directory` to its parent (empty if it has none)
    /// before `setup` runs.
    ///
    /// # Errors
    /// Propagates read/parse errors and any error from [`MomFile::setup`].
    pub(crate) fn from_path(path: PathBuf) -> DynErrResult<MomFile> {
        let mut mom_file = MomFile::deserialize_from_path(&path)?;
        mom_file.directory = path.parent().map(Path::to_path_buf).unwrap_or_default();
        mom_file.filepath = path;
        mom_file.setup()?;
        Ok(mom_file)
    }

    /// Parses and sets up a mom file from a YAML string (tests only).
    ///
    /// `filepath` and `directory` stay empty.
    #[cfg(test)]
    pub(crate) fn from_str(contents: &str) -> DynErrResult<MomFile> {
        let mut mom_file = MomFile::deserialize_from_str(contents)?;
        mom_file.setup()?;
        Ok(mom_file)
    }

    /// Prepares the parsed file for use.
    ///
    /// 1. Sets up `common` relative to `directory`.
    /// 2. Flattens OS-specific sub-tasks into their own entries (see `get_flat_tasks`).
    /// 3. Applies each task's `extend` bases, preferring a base's `<base>.<current OS>`
    ///    variant over the plain `<base>`.
    ///
    /// # Errors
    /// Fails on setup errors, duplicate task names, a cyclic `extend` chain, or a task /
    /// base name that is referenced by the dependency graph but not defined.
    pub(crate) fn setup(&mut self) -> DynErrResult<()> {
        self.common.setup(&self.directory)?;
        let mut tasks = self.get_flat_tasks()?;
        let dep_graph = get_task_dependency_graph(&tasks)?;
        // The reversed topological order is relied on to place every base before the tasks
        // that extend it, so bases are already in `self.tasks` when they are looked up below.
        let ordered_names: Vec<String> = match toposort(&dep_graph, None) {
            Ok(sorted) => sorted.iter().rev().map(|v| String::from(*v)).collect(),
            Err(e) => {
                return Err(format!("Found a cyclic dependency for task: {}", e.node_id()).into());
            }
        };
        for task_name in ordered_names {
            let mut task = tasks
                .remove(&task_name)
                .ok_or_else(|| format!("Task `{}` was not found", task_name))?;
            // Move the list out so `task` can be mutated while iterating it; restored below.
            let bases = std::mem::take(&mut task.common.extend);
            for base in bases.iter() {
                let os_base_name = format!("{}.{}", base, env::consts::OS);
                let base_task = self
                    .tasks
                    .get(&os_base_name)
                    .or_else(|| self.tasks.get(base))
                    .ok_or_else(|| {
                        format!("Task `{}` extends `{}`, which was not found", task_name, base)
                    })?;
                task.extend(base_task);
            }
            task.common.extend = bases;
            self.tasks.insert(task_name, task);
        }
        Ok(())
    }

    /// Merges `other` into `self`: common fields via `CommonFields::extend`, tasks via
    /// `merge_map_values!`.
    pub(crate) fn extend(&mut self, other: &MomFile) {
        self.common.extend(&other.common);
        merge_map_values!(self.tasks, &other.tasks);
    }

    /// Moves every task out of `self.tasks` and returns them as a flat map.
    ///
    /// Each nested `linux` / `windows` / `macos` sub-task becomes its own entry named
    /// `<parent>.<os>`. Every entry is set up with its flat name and `self.directory`.
    ///
    /// # Errors
    /// Returns `Duplicate task` if two entries end up with the same flat name (e.g. a
    /// top-level `foo.linux` and a nested `foo: { linux: ... }`), regardless of the order in
    /// which they are visited, and propagates `Task::setup` errors.
    fn get_flat_tasks(&mut self) -> DynErrResult<HashMap<String, Task>> {
        let tasks = std::mem::take(&mut self.tasks);
        let mut flat_tasks = HashMap::with_capacity(tasks.len());
        for (name, mut task) in tasks {
            if let Some(os_task) = task.linux.take() {
                self.insert_flat_task(&mut flat_tasks, format!("{}.{}", name, "linux"), *os_task)?;
            }
            if let Some(os_task) = task.windows.take() {
                self.insert_flat_task(
                    &mut flat_tasks,
                    format!("{}.{}", name, "windows"),
                    *os_task,
                )?;
            }
            if let Some(os_task) = task.macos.take() {
                self.insert_flat_task(&mut flat_tasks, format!("{}.{}", name, "macos"), *os_task)?;
            }
            self.insert_flat_task(&mut flat_tasks, name, task)?;
        }
        Ok(flat_tasks)
    }

    /// Sets up `task` under `name` and stores it in `flat_tasks`.
    ///
    /// Rejects the insertion if `name` is already present.
    fn insert_flat_task(
        &self,
        flat_tasks: &mut HashMap<String, Task>,
        name: String,
        mut task: Task,
    ) -> DynErrResult<()> {
        if flat_tasks.contains_key(&name) {
            return Err(format!("Duplicate task `{}`", name).into());
        }
        task.setup(&name, &self.directory)?;
        flat_tasks.insert(name, task);
        Ok(())
    }

    /// Returns a clone of the task that [`MomFile::get_task`] would return.
    pub(crate) fn clone_task(&self, task_name: &str) -> Option<Task> {
        self.get_task(task_name).cloned()
    }

    /// Looks up `task_name`, preferring its OS-specific variant (as named by
    /// `to_os_task_name`) over the plain name.
    pub(crate) fn get_task(&self, task_name: &str) -> Option<&Task> {
        let os_task_name = to_os_task_name(task_name);
        self.tasks
            .get(&os_task_name)
            .or_else(|| self.tasks.get(task_name))
    }

    /// Like [`MomFile::clone_task`], but returns `None` when the resolved task is private.
    pub(crate) fn clone_public_task(&self, task_name: &str) -> Option<Task> {
        self.get_task(task_name)
            .filter(|task| !task.is_private())
            .cloned()
    }

    /// Whether [`MomFile::get_task`] would find `task_name` (tests only).
    #[cfg(test)]
    pub(crate) fn has_task(&self, task_name: &str) -> bool {
        self.get_task(task_name).is_some()
    }

    /// Names (via `Task::get_name`) of every non-private task stored in this file,
    /// in unspecified order.
    pub(crate) fn get_public_task_names(&self) -> Vec<&str> {
        self.tasks
            .values()
            .filter(|t| !t.is_private())
            .map(|t| t.get_name())
            .collect()
    }
}