use std::path::{Path, PathBuf};
use toml::value::*;
use crate::utils::find_pyproject_toml;
/// TOML text of the built-in `default` profile, embedded at compile time from
/// `../profiles/default.toml` (path resolved relative to this source file by `include_str!`).
const DEFAULT_PROFILE: &str = include_str!("../profiles/default.toml");
/// Names of the `[tool.<name>]` sections a profile may configure.
const KNOWN_TOOLS: [&str; 3] = ["mypy", "pytest", "ruff"];

/// Reports whether `tool` exactly matches (case-sensitively) one of [`KNOWN_TOOLS`].
///
/// Accepts anything convertible into a `String` (`&str`, `String`, `&String`, ...).
pub fn is_known_tool<S: Into<String>>(tool: S) -> bool {
    let wanted: String = tool.into();
    KNOWN_TOOLS.contains(&wanted.as_str())
}
/// Errors produced while loading, validating input for, or materializing a [`Profile`].
///
/// The `#[error(transparent)]` variants forward `Display` and `source()` to the wrapped
/// error; `#[from]` lets `?` convert those errors automatically.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Serializing a TOML table back to text failed.
    #[error(transparent)]
    Ser(#[from] toml::ser::Error),
    /// Parsing text as a TOML table failed.
    #[error(transparent)]
    De(#[from] toml::de::Error),
    /// A filesystem or current-directory operation failed.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// An HTTP request issued through `reqwest` failed.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// A profile name that is not recognized (in this file, raised by `Profile::load`
    /// for names other than `default` or an `http://`/`https://` URL).
    #[error("invalid profile {0:?}")]
    InvalidProfile(String),
}
/// A parsed configuration profile whose tables are layered under a project's
/// `pyproject.toml`.
#[derive(Debug)]
pub struct Profile {
    /// Where the profile came from: `default`, a URL, or a `file://` path, depending
    /// on which loader produced it.
    pub name: String,
    /// The profile's top-level TOML table.
    pub root: Table,
}
impl Profile {
    /// Loads a profile by name.
    ///
    /// `None` and `"default"` select the built-in profile (`DEFAULT_PROFILE`); names
    /// starting with `http://` or `https://` are fetched with [`Profile::load_url`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidProfile`] for any other name, and propagates any fetch or
    /// TOML parse error from the selected loader.
    pub fn load(name: Option<String>) -> Result<Self, Error> {
        let name = name.unwrap_or_else(|| "default".to_owned());
        if name == "default" {
            Self::load_string(name, DEFAULT_PROFILE.to_owned())
        } else if name.starts_with("http://") || name.starts_with("https://") {
            Self::load_url(name)
        } else {
            Err(Error::InvalidProfile(name))
        }
    }

    /// Fetches a profile over HTTP(S) (blocking) and parses it; the URL becomes the
    /// profile name.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Reqwest`] if the request fails or the server answers with a
    /// 4xx/5xx status, and [`Error::De`] if the body is not valid TOML.
    pub fn load_url(url: String) -> Result<Self, Error> {
        let content = reqwest::blocking::get(url.as_str())?
            // Without this, an error page (e.g. a 404 body) would be handed to the TOML
            // parser and surface as a misleading parse error.
            .error_for_status()?
            .text()?;
        Self::load_string(url, content)
    }

    /// Reads and parses a profile from a local TOML file.
    ///
    /// The profile is named `file://<path>` (path converted lossily to UTF-8).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if the file cannot be read and [`Error::De`] if it is not
    /// valid TOML.
    pub fn load_file(toml_file: &Path) -> Result<Self, Error> {
        Self::load_string(
            format!("file://{}", toml_file.to_string_lossy()),
            std::fs::read_to_string(toml_file)?,
        )
    }

    /// Parses `toml_text` as a TOML table and wraps it in a profile called `name`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::De`] if `toml_text` is not a valid TOML document.
    pub fn load_string(name: String, toml_text: String) -> Result<Self, Error> {
        Ok(Profile {
            name,
            root: toml_text.parse::<Table>()?,
        })
    }

    /// Strips everything the profile is not allowed to configure, logging a warning for
    /// each kind of problem found:
    ///
    /// * every top-level key other than `tool` is removed;
    /// * every `[tool.<name>]` entry whose name is not a known tool is removed from the
    ///   `tool` table;
    /// * a `tool` key whose value is not a table is removed.
    ///
    /// A profile without a `tool` key is left as is.
    pub fn validate(&mut self) {
        let unexpected_keys: Vec<String> = self
            .root
            .keys()
            .filter(|k| k.as_str() != "tool")
            .cloned()
            .collect();
        if !unexpected_keys.is_empty() {
            log::warn!(
                "Unexpected top-level `[*]` keys found in profile `{}`: {}",
                self.name,
                unexpected_keys.join(", ")
            );
            for key in &unexpected_keys {
                self.root.remove(key);
            }
        }

        match self.root.get_mut("tool") {
            Some(Value::Table(tools)) => {
                let unknown_tools: Vec<String> = tools
                    .keys()
                    .filter(|k| !is_known_tool(k.as_str()))
                    .cloned()
                    .collect();
                if !unknown_tools.is_empty() {
                    log::warn!(
                        "Unexpected `[tool.*]` keys found in profile `{}`: {}",
                        self.name,
                        unknown_tools.join(", ")
                    );
                    // The unknown entries live inside the `tool` table, not at the root.
                    for key in &unknown_tools {
                        tools.remove(key);
                    }
                }
            }
            Some(_) => {
                log::warn!("The `tool` key in profile `{}` is not a table.", self.name);
                self.root.remove("tool");
            }
            // No `tool` section at all: nothing to validate, and nothing wrong to report.
            None => {}
        }
    }

    /// Deep-merges this profile with a project's `pyproject.toml` table.
    ///
    /// Values from `pyproject_toml` take precedence over the profile's. When both sides
    /// hold a table under the same key, the two tables are merged recursively. Any other
    /// pair of values, including two arrays, is resolved by taking the `pyproject_toml`
    /// value whole. Keys present on only one side are copied unchanged.
    pub fn merge(&self, pyproject_toml: &Table) -> Table {
        fn merge_tables(base: &Table, overrides: &Table) -> Table {
            let mut result = Table::new();
            for (key, base_value) in base {
                let merged = match (base_value, overrides.get(key)) {
                    (Value::Table(base_table), Some(Value::Table(override_table))) => {
                        Value::Table(merge_tables(base_table, override_table))
                    }
                    (_, Some(override_value)) => override_value.clone(),
                    (base_value, None) => base_value.clone(),
                };
                result.insert(key.clone(), merged);
            }
            // Keys only present in the overrides; shared keys were handled above, so each
            // key is visited exactly once.
            for (key, override_value) in overrides {
                if !base.contains_key(key) {
                    result.insert(key.clone(), override_value.clone());
                }
            }
            result
        }
        merge_tables(&self.root, pyproject_toml)
    }

    /// Writes the merged configuration to `<project root>/.tire/pyproject.toml` and
    /// returns that path.
    ///
    /// The `pyproject.toml` is located with `find_pyproject_toml`, starting from `cwd`. If
    /// `cwd` is `None`, the search starts from the process's current directory. The
    /// project root is the directory containing the `pyproject.toml` that was found, or
    /// the search directory if none was found. In the latter case the profile is written
    /// out on its own.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] on filesystem failures and if the current directory cannot
    /// be determined. Returns [`Error::De`] if the project's `pyproject.toml` is not
    /// valid TOML, and [`Error::Ser`] if the merged table cannot be serialized.
    pub fn materialize(&self, cwd: Option<PathBuf>) -> Result<PathBuf, Error> {
        let cwd = match cwd {
            Some(dir) => dir,
            None => std::env::current_dir()?,
        };
        let pyproject_toml_file = find_pyproject_toml(Some(cwd.clone()));
        let project_root = match &pyproject_toml_file {
            Some(file) => file
                .parent()
                .map_or_else(|| cwd.clone(), Path::to_path_buf),
            // Honor the caller-supplied directory rather than the process's cwd.
            None => cwd,
        };

        let out_dir = project_root.join(".tire");
        std::fs::create_dir_all(&out_dir)?;
        let out_file = out_dir.join("pyproject.toml");

        let pyproject_toml = match pyproject_toml_file {
            Some(file) => std::fs::read_to_string(file)?.parse::<Table>()?,
            None => Table::new(),
        };
        std::fs::write(&out_file, toml::to_string(&self.merge(&pyproject_toml))?)?;
        Ok(out_file)
    }
}