use clap::Parser;
use colored::Colorize;
use std::collections::HashMap;
use std::fmt::Display;
use serde::de::{self, Deserializer};
use serde::{Deserialize, Serialize};
mod error;
use error::ShieldError;
mod node;
use node::{Node, ResolvedNode, StringChunk, UnresolvedNode, extract_string_chunks};
mod log;
/// Maximum display length of a single value in the human-readable report;
/// longer values are shortened by `truncated`.
const MAX_DISPLAY_LEN: usize = 50;
/// Variable name → value mapping, used for the process environment snapshot and for the
/// per-category buckets of a [`SchemaCheck`].
type EnvMap = HashMap<String, String>;
fn try_to_resolve(
unresolved: &mut Vec<UnresolvedNode>,
resolved: &mut Vec<ResolvedNode>,
) -> Result<(), ShieldError> {
let unresolved_node_opt = unresolved.first_mut();
if let Some(unresolved_node) = unresolved_node_opt {
let available: HashMap<&String, &String> =
resolved.iter().map(|r| (&r.key, &r.value)).collect();
let resolved_string_chunks: Vec<String> = unresolved_node
.chunks
.iter()
.map(|chunk| match chunk {
StringChunk::Original(s) => Ok(s.to_string().clone()),
StringChunk::Reference(r) => {
if let Some(resolved) = available.get(r) {
Ok(resolved.to_string().clone())
} else {
info!(" -- couldn't resolve {}", r);
unresolved_node.offending_node = Some(r.clone());
Err(ShieldError::UnresolvedReference)
}
}
})
.collect::<Result<Vec<_>, _>>()?;
let resolved_string: String = resolved_string_chunks.join("");
info!(" >> resolved {}", unresolved_node.key);
resolved.push(ResolvedNode {
key: unresolved_node.key.clone(),
value: resolved_string,
});
unresolved.remove(0);
return Ok(());
}
Err(ShieldError::UnresolvedReference)
}
/// Schema whose attributes passed option validation (conversion from [`ParsedSchema`]).
/// Its value/default strings may still contain unresolved references to other variables.
#[derive(Debug, Clone, Serialize)]
enum ValidatedSchema {
    /// Schema format version 1: variable name → validated attribute.
    Version1(HashMap<String, ValidatedAttribute>),
}
impl TryFrom<ParsedSchema> for ValidatedSchema {
    type Error = ShieldError;

    /// Validates the option combination of every variable in a parsed schema.
    ///
    /// At most one of `value`, `default` or `optional = true` may be set per variable.
    /// Setting none of them marks the variable as a secret.
    ///
    /// # Errors
    ///
    /// `ShieldError::InvalidSchema` when `optional = false` is given on its own, or when more
    /// than one of `value` / `default` / `optional` is set for a variable.
    fn try_from(value: ParsedSchema) -> Result<ValidatedSchema, ShieldError> {
        let ParsedSchema::Version1(attributes) = value;
        let validated = attributes
            .into_iter()
            .map(|(key, attr)| {
                // Each legal combination maps to exactly one option kind.
                let options = match (attr.value, attr.default, attr.optional) {
                    (None, None, None) => ValidatedOptions::Secret,
                    (None, Some(default), None) => ValidatedOptions::WithDefault(default),
                    (None, None, Some(true)) => ValidatedOptions::Optional,
                    (None, None, Some(false)) => {
                        return Err(ShieldError::InvalidSchema(
                            "'optional' can only be set to true".to_string(),
                        ));
                    }
                    (Some(fixed), None, None) => ValidatedOptions::WithValue(fixed),
                    _ => {
                        return Err(ShieldError::InvalidSchema(format!(
                            "illegal combination of options for variable [{}]",
                            key
                        )));
                    }
                };
                Ok((
                    key,
                    ValidatedAttribute {
                        description: attr.description,
                        options,
                    },
                ))
            })
            .collect::<Result<HashMap<_, _>, _>>()?;
        Ok(ValidatedSchema::Version1(validated))
    }
}
/// Schema after reference resolution: the `WithValue`/`WithDefault` strings hold their
/// resolved values.
#[derive(Debug, Serialize)]
enum FinalizedSchema {
    /// Schema format version 1: variable name → attribute with resolved value/default.
    Version1(HashMap<String, ValidatedAttribute>),
}
/// Result of comparing a [`FinalizedSchema`] against an environment snapshot.
#[derive(Debug, Deserialize, Serialize)]
struct SchemaCheck {
    /// Schema variables present in the environment, with their environment values.
    existing_subset: EnvMap,
    /// `WithValue` variables whose environment value differs; maps to the value the schema expects.
    incorrect_values: EnvMap,
    /// `WithValue` variables absent from the environment; maps to the schema value.
    missing_values: EnvMap,
    /// `WithDefault` variables absent from the environment; maps to the schema default.
    missing_default: EnvMap,
    /// `Optional` variables absent from the environment.
    missing_optional: Vec<String>,
    /// `Secret` variables (no value/default/optional in the schema) absent from the environment.
    missing_secrets: Vec<String>,
    /// Environment variables the schema does not declare.
    not_in_schema: Vec<String>,
}
impl SchemaCheck {
    /// Compares `final_schema` with the environment snapshot `env_map`.
    ///
    /// For each schema variable present in `env_map`, its environment value is recorded in
    /// `existing_subset`. If the schema fixes a value (`WithValue`) and the environment value
    /// differs, the schema's expected value is also recorded in `incorrect_values`.
    ///
    /// Absent variables are bucketed by kind into `missing_optional`, `missing_secrets`,
    /// `missing_values` (with the schema value) or `missing_default` (with the default).
    /// Environment variables the schema does not declare are listed in `not_in_schema`.
    ///
    /// The `Vec` buckets are sorted so the result does not depend on `HashMap` iteration order.
    fn new(final_schema: &FinalizedSchema, env_map: &EnvMap) -> Self {
        let mut result = SchemaCheck {
            existing_subset: HashMap::new(),
            incorrect_values: HashMap::new(),
            missing_default: HashMap::new(),
            missing_values: HashMap::new(),
            missing_optional: Vec::new(),
            missing_secrets: Vec::new(),
            not_in_schema: Vec::new(),
        };
        match final_schema {
            FinalizedSchema::Version1(schema_map) => {
                for (schema_key, validated_attribute) in schema_map {
                    // Direct hash lookup instead of a linear scan over the whole environment.
                    if let Some(env_value) = env_map.get(schema_key) {
                        result
                            .existing_subset
                            .insert(schema_key.clone(), env_value.clone());
                        if let ValidatedOptions::WithValue(expected_value) =
                            &validated_attribute.options
                            && expected_value != env_value
                        {
                            result
                                .incorrect_values
                                .insert(schema_key.clone(), expected_value.clone());
                        }
                        continue;
                    }
                    match &validated_attribute.options {
                        ValidatedOptions::Optional => {
                            result.missing_optional.push(schema_key.clone());
                        }
                        ValidatedOptions::Secret => {
                            result.missing_secrets.push(schema_key.clone());
                        }
                        ValidatedOptions::WithValue(value) => {
                            result
                                .missing_values
                                .insert(schema_key.clone(), value.clone());
                        }
                        ValidatedOptions::WithDefault(default) => {
                            result
                                .missing_default
                                .insert(schema_key.clone(), default.clone());
                        }
                    }
                }
                result.not_in_schema = env_map
                    .keys()
                    .filter(|env_key| !schema_map.contains_key(*env_key))
                    .cloned()
                    .collect();
            }
        }
        // Deterministic ordering for both the text report and the JSON output.
        result.missing_optional.sort_unstable();
        result.missing_secrets.sort_unstable();
        result.not_in_schema.sort_unstable();
        result
    }
}
impl TryFrom<ValidatedSchema> for FinalizedSchema {
    type Error = ShieldError;
    /// Resolves references between schema variables.
    ///
    /// Only `WithValue` and `WithDefault` strings take part. Each is split into chunks by
    /// `extract_string_chunks`. Strings without `Reference` chunks count as resolved
    /// immediately; the rest are resolved iteratively against already-resolved values. On
    /// success the `WithValue`/`WithDefault` strings of the returned schema are replaced by
    /// their resolved values. `Optional` and `Secret` attributes are left untouched.
    ///
    /// # Errors
    ///
    /// - any error returned by `extract_string_chunks`;
    /// - `DeadEndReference(key, r)` when `key` references a variable `r` that has no string to
    ///   substitute (an `Optional` or `Secret` one);
    /// - `MissingReferenceExtended(key, r)` when `r` is not declared in the schema;
    /// - `CyclicReference(key)` when a variable references itself directly;
    /// - `MissingReferenceExtended` / `MissingReference` when resolution stops making progress.
    fn try_from(value: ValidatedSchema) -> Result<FinalizedSchema, ShieldError> {
        match value {
            ValidatedSchema::Version1(hash_map) => {
                // `result` is the schema that gets returned; its value/default strings are
                // overwritten with the resolved values at the end.
                let mut result = hash_map.clone();
                // One node per `WithValue`/`WithDefault` attribute: `Resolved` when the string
                // has no references, `Unresolved` (kept as chunks) otherwise.
                // `Optional`/`Secret` attributes produce no node.
                let nodes: Vec<Node> = result
                    .clone()
                    .into_iter()
                    .map(|(key, attr)| match &attr.options {
                        ValidatedOptions::WithValue(input_string)
                        | ValidatedOptions::WithDefault(input_string) => {
                            let string_chunks = extract_string_chunks(&key, input_string)?;
                            let contains_references = string_chunks
                                .iter()
                                .any(|chunk| matches!(chunk, StringChunk::Reference(_)));
                            if contains_references {
                                Ok(Some(Node::Unresolved(UnresolvedNode {
                                    key,
                                    chunks: string_chunks,
                                    offending_node: None,
                                })))
                            } else {
                                Ok(Some(Node::Resolved(ResolvedNode {
                                    key,
                                    value: input_string.clone(),
                                })))
                            }
                        }
                        _ => Ok(None),
                    })
                    .collect::<Result<Vec<Option<Node>>, ShieldError>>()?
                    .into_iter()
                    .flatten()
                    .collect();
                // Split the nodes into the two work lists.
                let mut resolved: Vec<ResolvedNode> = nodes
                    .clone()
                    .into_iter()
                    .filter_map(|node| {
                        if let Node::Resolved(root) = node {
                            Some(root)
                        } else {
                            None
                        }
                    })
                    .collect();
                let mut unresolved: Vec<UnresolvedNode> = nodes
                    .into_iter()
                    .filter_map(|node| {
                        if let Node::Unresolved(child) = node {
                            Some(child)
                        } else {
                            None
                        }
                    })
                    .collect();
                // Variables that produced no node (the `Optional`/`Secret` ones): they have no
                // string that could be substituted for a reference to them.
                let dead_ends: HashMap<&String, &ValidatedAttribute> = hash_map
                    .iter()
                    .filter(|(key, _)| {
                        let not_resolved = resolved
                            .iter()
                            .all(|resolved_node| &resolved_node.key != *key);
                        let not_unresolved = unresolved
                            .iter()
                            .all(|unresolved_node| &unresolved_node.key != *key);
                        not_resolved && not_unresolved
                    })
                    .collect();
                // Static checks on every reference before the resolution loop: no references
                // to dead ends, to undeclared variables, or to the variable itself.
                for unresolved_node in unresolved.iter() {
                    for chunk in unresolved_node.chunks.iter() {
                        match chunk {
                            StringChunk::Original(_) => (),
                            StringChunk::Reference(r) => {
                                if dead_ends.contains_key(r) {
                                    return Err(ShieldError::DeadEndReference(
                                        unresolved_node.key.clone(),
                                        r.clone(),
                                    ));
                                }
                                if !result.contains_key(r) {
                                    return Err(ShieldError::MissingReferenceExtended(
                                        unresolved_node.key.clone(),
                                        r.clone(),
                                    ));
                                }
                                if &unresolved_node.key == r {
                                    return Err(ShieldError::CyclicReference(r.clone()));
                                }
                            }
                        }
                    }
                }
                // Work-queue resolution: try the front node (removed on success), then rotate
                // the queue by moving the last node to the front. `stagnation_counter` counts
                // consecutive failed attempts and is reset by every success.
                let mut stagnation_counter = 0;
                let mut iteration = 0;
                loop {
                    info!(
                        "{} ({}/{})",
                        iteration,
                        stagnation_counter,
                        unresolved.len()
                    );
                    iteration += 1;
                    if unresolved.is_empty() {
                        break;
                    }
                    match try_to_resolve(&mut unresolved, &mut resolved) {
                        Ok(_) => {
                            stagnation_counter = 0;
                        }
                        Err(_) => {
                            stagnation_counter += 1;
                        }
                    }
                    if let Some(last) = unresolved.pop() {
                        unresolved.insert(0, last);
                    } else {
                        break;
                    }
                    // Give up after more than 2 × (remaining nodes) consecutive failures, i.e.
                    // once every remaining node has been retried without any progress.
                    // NOTE(review): after the static checks above, a reference that still
                    // cannot be resolved points to a variable that is itself still unresolved,
                    // so stagnation suggests a reference cycle among the remaining nodes. It is
                    // reported as a missing reference, which may be misleading — confirm intent.
                    if stagnation_counter > (2 * unresolved.len()) && !unresolved.is_empty() {
                        if let Some(missing) = unresolved.first() {
                            error!("total unresolved: {}", unresolved.len());
                            if let Some(offender) = &missing.offending_node {
                                return Err(ShieldError::MissingReferenceExtended(
                                    missing.key.clone(),
                                    offender.clone(),
                                ));
                            } else {
                                return Err(ShieldError::MissingReference(missing.key.clone()));
                            }
                        } else {
                            // Unreachable: the enclosing condition requires a non-empty queue.
                            error!("expected to see at least one unresolved node, but didn't");
                        }
                    }
                }
                info!("all references resolved");
                // Write the resolved strings back into the schema that is returned.
                for (key, validated_attr) in result.iter_mut() {
                    if let Some(resolved_node) = resolved.iter().find(|node| &node.key == key) {
                        match validated_attr.options {
                            ValidatedOptions::WithValue(ref mut value) => {
                                *value = resolved_node.value.clone();
                            }
                            // The binding named `description` holds the default value here.
                            ValidatedOptions::WithDefault(ref mut description) => {
                                *description = resolved_node.value.clone();
                            }
                            _ => (),
                        }
                    }
                }
                Ok(FinalizedSchema::Version1(result))
            }
        }
    }
}
/// A schema variable after validation: its description plus exactly one option kind.
#[derive(Debug, Clone, Serialize)]
struct ValidatedAttribute {
    /// Human-readable description, taken from [`Attributes::description`] (checked non-empty).
    description: String,
    /// How the variable is expected to be provided.
    options: ValidatedOptions,
}
/// The single option a validated variable carries.
#[derive(Debug, Clone, Serialize)]
enum ValidatedOptions {
    /// `optional = true`: may be absent from the environment (reported as a warning).
    Optional,
    /// None of `value` / `default` / `optional` set: must come from the environment.
    Secret,
    /// `value = ...`: the environment must hold exactly this value; may reference other
    /// variables until resolved.
    WithValue(String),
    /// `default = ...`: schema-provided default; may reference other variables until resolved.
    WithDefault(String),
}
/// Raw schema as deserialized from TOML, before any validation.
/// The `version` tag selects the variant (`"1"` → `Version1`).
#[derive(Debug, Deserialize)]
#[serde(tag = "version")]
#[serde(deny_unknown_fields)]
enum ParsedSchema {
    /// Version 1: variable name → its attribute table.
    #[serde(rename = "1")]
    Version1(HashMap<String, Attributes>),
}
/// One variable's table as written in the schema file. Which combinations of
/// `value` / `optional` / `default` are legal is checked when converting to [`ValidatedSchema`].
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct Attributes {
    /// Required; rejected when empty.
    #[serde(deserialize_with = "validate_description")]
    description: String,
    /// Fixed value the variable must have.
    value: Option<String>,
    /// Marks the variable optional; only `true` is accepted.
    optional: Option<bool>,
    /// Default value for the variable.
    default: Option<String>,
}
fn validate_description<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if s.is_empty() {
Err(de::Error::custom("host cannot be empty"))
} else {
Ok(s)
}
}
/// Overall outcome category of a run; serialized in lowercase.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
enum ShieldStatus {
    /// Loading the schema failed with `ShieldError::Unrecoverable`.
    Hopeless,
    /// Loading the schema failed with any other error.
    Recoverable,
    /// The schema loaded and the environment was checked against it.
    Operational,
}
/// Complete result of a run; printed as text via `Display` or serialized as JSON.
#[derive(Debug, Deserialize, Serialize)]
struct ShieldResponse {
    /// Path of the schema file that was checked.
    schema_file: String,
    /// Outcome category.
    status: ShieldStatus,
    /// Error message or environment check details.
    kind: ShieldResponseKind,
}
/// Payload of a [`ShieldResponse`]. `untagged`: serialized as just the variant's fields,
/// i.e. `{"error": ...}` or `{"checks_from_env": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
enum ShieldResponseKind {
    /// The schema could not be loaded, validated or resolved.
    Failed {
        error: String, },
    /// The schema loaded; result of checking the environment against it.
    Success {
        // Boxed, presumably to keep the enum small — TODO confirm.
        checks_from_env: Box<SchemaCheck>,
    },
}
/// Shortens `s` for display to at most `MAX_DISPLAY_LEN` characters.
///
/// Strings that fit are returned unchanged. Longer ones keep their first
/// `MAX_DISPLAY_LEN - 3` characters followed by `...`. Lengths are counted in `char`s, so
/// multi-byte UTF-8 text is never cut inside a character (byte slicing would panic there).
fn truncated(s: &str) -> String {
    const ELLIPSIS: &str = "...";
    if s.chars().count() <= MAX_DISPLAY_LEN {
        return s.to_string();
    }
    let keep = MAX_DISPLAY_LEN.saturating_sub(ELLIPSIS.len());
    let head: String = s.chars().take(keep).collect();
    format!("{head}{ELLIPSIS}")
}
impl Display for ShieldResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.kind {
ShieldResponseKind::Failed { error } => {
let _ = writeln!(f, "{} {}", "Error:".red().bold(), error);
write!(f, "")
}
ShieldResponseKind::Success { checks_from_env } => {
let _ = writeln!(
f,
"{} schema at: ./{}",
"Parsed:".green().bold(),
self.schema_file
);
let total_missing = checks_from_env.missing_values.len()
+ checks_from_env.missing_default.len()
+ checks_from_env.missing_secrets.len();
let num_correct = checks_from_env.existing_subset.len();
if num_correct > 0 {
let _ = writeln!(
f,
"{} {} variables",
"Correct: ".green().bold(),
num_correct
);
if total_missing == 0 && checks_from_env.missing_optional.is_empty() {
let _ = writeln!(f, "{}", "Success!".green().bold(),);
}
}
let _ = writeln!(f);
let max_key_len = checks_from_env
.missing_values
.keys()
.chain(checks_from_env.missing_default.keys())
.chain(checks_from_env.incorrect_values.keys())
.chain(checks_from_env.missing_secrets.iter())
.chain(checks_from_env.missing_optional.iter())
.map(|k| k.len())
.max()
.unwrap_or(0);
if !checks_from_env.missing_optional.is_empty() {
let _ = writeln!(
f,
"{} {} optional variables missing from env:",
"Warning:".yellow().bold(),
checks_from_env.missing_optional.len(),
);
for key in checks_from_env.missing_optional.iter() {
let _ = writeln!(f, " {:width$} ", key, width = max_key_len);
}
let _ = writeln!(f);
}
if !checks_from_env.incorrect_values.is_empty() {
let _ = writeln!(
f,
"{} Variables with incorrect values:",
"Error:".red().bold()
);
}
for (key, incorrect_value) in checks_from_env.incorrect_values.iter() {
let _ = writeln!(
f,
" {:width$}: '{}'",
key,
truncated(incorrect_value),
width = max_key_len
);
}
if !checks_from_env.incorrect_values.is_empty() {
let _ = writeln!(f);
}
if total_missing > 0 {
let _ = writeln!(
f,
"{} {} {}",
"Error:".red().bold(),
total_missing.to_string().bold(),
"required variables missing from env:".bold(),
);
}
for (key, missing_value) in checks_from_env.missing_values.iter() {
let _ = writeln!(
f,
"(value) {:width$}: '{}'",
key,
truncated(missing_value),
width = max_key_len
);
}
for (key, missing_value) in checks_from_env.missing_default.iter() {
let _ = writeln!(
f,
"(default) {:width$}: '{}'",
key,
truncated(missing_value),
width = max_key_len
);
}
for key in checks_from_env.missing_secrets.iter() {
let _ = writeln!(f, "(secret) {:width$}", key, width = max_key_len);
}
write!(f, "")
}
}
}
}
impl FinalizedSchema {
    /// Loads the schema at `filename` and runs it through the full pipeline.
    ///
    /// The steps are: read the file, then TOML → [`ParsedSchema`], then [`ValidatedSchema`]
    /// (option combinations checked), then [`FinalizedSchema`] (references resolved).
    ///
    /// # Errors
    ///
    /// The first I/O, TOML, validation or resolution error, converted into [`ShieldError`].
    fn new(filename: &str) -> Result<FinalizedSchema, ShieldError> {
        info!("reading: {}", filename);
        let contents = std::fs::read_to_string(filename)?;

        info!("parsing {} into ShieldSchema", filename);
        let parsed = toml::from_str::<ParsedSchema>(&contents)?;

        info!("validating schema");
        let validated = ValidatedSchema::try_from(parsed)?;

        info!("resolving references");
        FinalizedSchema::try_from(validated)
    }
}
impl ShieldResponse {
fn new(filename: &str) -> ShieldResponse {
let schema = match FinalizedSchema::new(filename) {
Ok(s) => s,
Err(err) => match err {
ShieldError::Unrecoverable(err) => {
return Self {
status: ShieldStatus::Hopeless,
schema_file: filename.to_string(),
kind: ShieldResponseKind::Failed {
error: err.to_string(),
},
};
}
_ => {
return Self {
status: ShieldStatus::Recoverable,
schema_file: filename.to_string(),
kind: ShieldResponseKind::Failed {
error: err.to_string(),
},
};
}
},
};
let env_vars: EnvMap = std::env::vars().collect();
let checked_from_env = SchemaCheck::new(&schema, &env_vars);
Self {
status: ShieldStatus::Operational,
schema_file: filename.to_string(),
kind: ShieldResponseKind::Success {
checks_from_env: Box::new(checked_from_env),
},
}
}
}
// Command-line arguments. The `///` comments on the fields are rendered by clap as their
// `--help` text.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct InputArgs {
    /// Path to the TOML schema file to check the environment against.
    #[arg(short, long, default_value = "env.toml")]
    file: String,
    /// Print the report as pretty-printed JSON instead of colored text.
    #[arg(short, long, default_value_t = false)]
    json: bool,
}
/// CLI entry point.
///
/// Prints the report (pretty JSON with `--json`, colored text otherwise). Exits with status 1
/// when the schema failed to load, a required variable is missing, or a variable's value
/// differs from the value fixed by the schema (all reported as errors). Exits with 0 otherwise.
fn main() {
    let args = InputArgs::parse();
    let response = ShieldResponse::new(&args.file);
    if args.json {
        match serde_json::to_string_pretty(&response) {
            Ok(json) => println!("{}", json),
            Err(_) => println!("{{ \"status\": \"JsonParsingError\" }}"),
        }
    } else {
        print!("{}", response);
    }
    let failed = match &response.kind {
        ShieldResponseKind::Failed { .. } => true,
        ShieldResponseKind::Success { checks_from_env } => {
            let total_missing = checks_from_env.missing_values.len()
                + checks_from_env.missing_default.len()
                + checks_from_env.missing_secrets.len();
            // Incorrect values are reported as errors too, so they must fail the run.
            total_missing > 0 || !checks_from_env.incorrect_values.is_empty()
        }
    };
    std::process::exit(if failed { 1 } else { 0 });
}
use test_generator::test_resources;
// `test_resources` generates one test per matching file that calls this function; the
// function itself is presumably unused outside those generated tests, hence `allow(unused)`.
#[allow(unused)]
#[test_resources("test-files/invalid/*.toml")]
fn invalid_test(filename: &str) {
    assert!(
        std::path::Path::new(filename).exists(),
        "test schema not found: {filename}"
    );
    let response = ShieldResponse::new(filename);
    // Deliberately not printing `response.kind`: a `Success` holds an env snapshot.
    assert!(
        matches!(response.kind, ShieldResponseKind::Failed { .. }),
        "expected schema {filename} to be rejected, but it was accepted"
    );
}
#[allow(unused)]
#[test_resources("test-files/valid/*.toml")]
fn valid_test(filename: &str) {
assert!(std::path::Path::new(filename).exists());
let response = ShieldResponse::new(filename);
assert!(matches!(
response.kind,
ShieldResponseKind::Success { checks_from_env: _ }
));
}