use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::io;
use std::path::{Path, PathBuf};
use super::regex_engine::LazyCompiledRegex;
use super::{DestructivePattern, Pack, REGISTRY, SafePattern, Severity};
/// Highest `schema_version` this parser understands; packs declaring a newer
/// schema are rejected with [`PackParseError::UnsupportedSchemaVersion`].
pub const CURRENT_SCHEMA_VERSION: u32 = 1;
/// Required shape of a pack ID: `namespace.name`, where each segment starts
/// with a lowercase ASCII letter followed by lowercase letters, digits, or `_`.
const ID_PATTERN: &str = r"^[a-z][a-z0-9_]*\.[a-z][a-z0-9_]*$";
/// Required shape of a pack version: `MAJOR.MINOR.PATCH` using ASCII digits.
///
/// `[0-9]` is used instead of `\d` on purpose: the `regex` crate's `\d` is
/// Unicode-aware (`\p{Nd}`) and would also accept digits such as `١` or `५`.
const VERSION_PATTERN: &str = r"^[0-9]+\.[0-9]+\.[0-9]+$";
/// A user-supplied pack definition, as deserialized from YAML.
///
/// Only `id`, `name`, and `version` are required; every other field falls back
/// to its default when omitted. Obtain validated instances through
/// `parse_pack_string` / `parse_pack_file` or their `_checked` variants.
#[derive(Debug, Clone, Deserialize)]
pub struct ExternalPack {
    /// Schema revision of the file; defaults to `1` when omitted.
    #[serde(default = "default_schema_version")]
    pub schema_version: u32,
    /// Pack identifier in `namespace.name` form (see `ID_PATTERN`).
    pub id: String,
    /// Human-readable pack name.
    pub name: String,
    /// Pack version in `MAJOR.MINOR.PATCH` form (see `VERSION_PATTERN`).
    pub version: String,
    /// Optional free-text description; becomes `""` when converted to a `Pack`.
    #[serde(default)]
    pub description: Option<String>,
    /// Keywords passed through to the runtime `Pack`.
    #[serde(default)]
    pub keywords: Vec<String>,
    /// Patterns whose matches are blocked.
    #[serde(default)]
    pub destructive_patterns: Vec<ExternalDestructivePattern>,
    /// Patterns whose matches are explicitly allowed.
    #[serde(default)]
    pub safe_patterns: Vec<ExternalSafePattern>,
}
/// Serde default for `ExternalPack::schema_version` when the key is absent.
///
/// Always the literal schema `1`. It is not derived from
/// `CURRENT_SCHEMA_VERSION`.
const fn default_schema_version() -> u32 {
    1_u32
}
/// A pattern that blocks matching commands, as written in the pack YAML.
#[derive(Debug, Clone, Deserialize)]
pub struct ExternalDestructivePattern {
    /// Pattern name; must be unique across destructive and safe patterns.
    pub name: String,
    /// Regex source; validated with `fancy_regex` at parse time.
    pub pattern: String,
    /// Severity of a match; defaults to `high` when omitted.
    #[serde(default)]
    pub severity: ExternalSeverity,
    /// Short reason shown when blocking. If absent, conversion to a `Pack`
    /// uses "Blocked by external pack pattern".
    #[serde(default)]
    pub description: Option<String>,
    /// Optional longer explanation.
    #[serde(default)]
    pub explanation: Option<String>,
    /// Alternative commands to suggest to the user.
    #[serde(default)]
    pub suggestions: Vec<ExternalSuggestion>,
}
/// An alternative command offered when a destructive pattern matches.
#[derive(Debug, Clone, Deserialize)]
pub struct ExternalSuggestion {
    /// The suggested command text.
    pub command: String,
    /// Why or when to use the suggested command.
    pub description: String,
    /// Platform the suggestion applies to; defaults to `all` when omitted.
    #[serde(default)]
    pub platform: ExternalPlatform,
}
/// Platform a suggestion applies to, as written in the pack YAML.
///
/// Variants deserialize from their lowercase names (`all`, `linux`, `macos`,
/// `windows`, `bsd`). Omitted values default to [`ExternalPlatform::All`].
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ExternalPlatform {
    #[default]
    All,
    Linux,
    // `rename_all = "lowercase"` already maps this variant to `macos`, so an
    // alias of `macos` was a no-op. Accept Apple's own spelling instead.
    #[serde(alias = "macOS")]
    MacOS,
    Windows,
    Bsd,
}
impl From<ExternalPlatform> for super::Platform {
    /// Maps each YAML-facing platform onto the runtime platform of the same name.
    fn from(platform: ExternalPlatform) -> Self {
        match platform {
            ExternalPlatform::Linux => super::Platform::Linux,
            ExternalPlatform::MacOS => super::Platform::MacOS,
            ExternalPlatform::Windows => super::Platform::Windows,
            ExternalPlatform::Bsd => super::Platform::Bsd,
            ExternalPlatform::All => super::Platform::All,
        }
    }
}
/// A pattern that explicitly allows matching commands, as written in the pack YAML.
#[derive(Debug, Clone, Deserialize)]
pub struct ExternalSafePattern {
    /// Pattern name; must be unique across destructive and safe patterns.
    pub name: String,
    /// Regex source; validated with `fancy_regex` at parse time.
    pub pattern: String,
    /// Optional description. `ExternalPack::into_pack` does not carry it into
    /// the runtime `SafePattern`.
    #[serde(default)]
    pub description: Option<String>,
}
/// Severity of a destructive pattern, as written in the pack YAML.
///
/// Deserializes from lowercase names (`low`, `medium`, `high`, `critical`);
/// defaults to `High` when omitted.
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ExternalSeverity {
    Low,
    Medium,
    #[default]
    High,
    Critical,
}
impl From<ExternalSeverity> for Severity {
fn from(severity: ExternalSeverity) -> Self {
match severity {
ExternalSeverity::Low => Self::Low,
ExternalSeverity::Medium => Self::Medium,
ExternalSeverity::High => Self::High,
ExternalSeverity::Critical => Self::Critical,
}
}
}
/// Errors produced while reading, parsing, or validating an external pack.
#[derive(Debug)]
pub enum PackParseError {
    /// The pack file could not be read.
    Io(io::Error),
    /// The content is not valid YAML, or does not deserialize into `ExternalPack`.
    Yaml(serde_yaml::Error),
    /// The pack ID does not match `ID_PATTERN`.
    InvalidId { id: String, reason: String },
    /// The pack version does not match `VERSION_PATTERN`.
    InvalidVersion { version: String, reason: String },
    /// `schema_version` is newer than `CURRENT_SCHEMA_VERSION`.
    UnsupportedSchemaVersion { found: u32, max_supported: u32 },
    /// A pattern failed to compile; `error` holds the compiler's message.
    InvalidPattern {
        name: String,
        pattern: String,
        error: String,
    },
    /// Two patterns share a name. Destructive and safe patterns share one namespace.
    DuplicatePattern { name: String },
    /// The pack defines neither destructive nor safe patterns.
    EmptyPack,
    /// The pack ID equals a built-in pack's ID; only reported by the `_checked` entry points.
    IdCollision { id: String, builtin_name: String },
}
impl fmt::Display for PackParseError {
    /// Renders a human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyPack => f.write_str("Pack has no patterns defined"),
            Self::Io(inner) => write!(f, "IO error: {inner}"),
            Self::Yaml(inner) => write!(f, "YAML parse error: {inner}"),
            Self::InvalidId { id, reason } => write!(f, "Invalid pack ID '{id}': {reason}"),
            Self::InvalidVersion { version, reason } => {
                write!(f, "Invalid version '{version}': {reason}")
            }
            Self::UnsupportedSchemaVersion {
                found,
                max_supported,
            } => write!(
                f,
                "Schema version {found} is not supported (max: {max_supported})"
            ),
            Self::InvalidPattern {
                name,
                pattern,
                error,
            } => write!(f, "Invalid pattern '{name}' ({pattern}): {error}"),
            Self::DuplicatePattern { name } => write!(f, "Duplicate pattern name: {name}"),
            Self::IdCollision { id, builtin_name } => write!(
                f,
                "Pack ID '{id}' collides with built-in pack '{builtin_name}'. \
                 External packs cannot override built-in security packs."
            ),
        }
    }
}
impl std::error::Error for PackParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Yaml(e) => Some(e),
_ => None,
}
}
}
impl From<io::Error> for PackParseError {
fn from(e: io::Error) -> Self {
Self::Io(e)
}
}
impl From<serde_yaml::Error> for PackParseError {
fn from(e: serde_yaml::Error) -> Self {
Self::Yaml(e)
}
}
pub fn parse_pack_file(path: &Path) -> Result<ExternalPack, PackParseError> {
let content = std::fs::read_to_string(path)?;
parse_pack_string(&content)
}
/// Deserializes YAML `content` into an [`ExternalPack`] and validates it.
///
/// No built-in collision check is done; see [`parse_pack_string_checked`].
///
/// # Errors
///
/// [`PackParseError::Yaml`] if deserialization fails. Otherwise, the first
/// validation failure reported by `validate_pack`.
pub fn parse_pack_string(content: &str) -> Result<ExternalPack, PackParseError> {
    let pack = serde_yaml::from_str::<ExternalPack>(content)?;
    validate_pack(&pack).map(|()| pack)
}
fn validate_pack(pack: &ExternalPack) -> Result<(), PackParseError> {
if pack.schema_version > CURRENT_SCHEMA_VERSION {
return Err(PackParseError::UnsupportedSchemaVersion {
found: pack.schema_version,
max_supported: CURRENT_SCHEMA_VERSION,
});
}
let id_regex = regex::Regex::new(ID_PATTERN).expect("ID regex should compile");
if !id_regex.is_match(&pack.id) {
return Err(PackParseError::InvalidId {
id: pack.id.clone(),
reason: format!("Must match pattern: {ID_PATTERN}"),
});
}
let version_regex = regex::Regex::new(VERSION_PATTERN).expect("Version regex should compile");
if !version_regex.is_match(&pack.version) {
return Err(PackParseError::InvalidVersion {
version: pack.version.clone(),
reason: format!("Must match pattern: {VERSION_PATTERN}"),
});
}
if pack.destructive_patterns.is_empty() && pack.safe_patterns.is_empty() {
return Err(PackParseError::EmptyPack);
}
let mut seen_names = std::collections::HashSet::new();
for pattern in &pack.destructive_patterns {
if !seen_names.insert(&pattern.name) {
return Err(PackParseError::DuplicatePattern {
name: pattern.name.clone(),
});
}
if let Err(e) = fancy_regex::Regex::new(&pattern.pattern) {
return Err(PackParseError::InvalidPattern {
name: pattern.name.clone(),
pattern: pattern.pattern.clone(),
error: e.to_string(),
});
}
}
for pattern in &pack.safe_patterns {
if !seen_names.insert(&pattern.name) {
return Err(PackParseError::DuplicatePattern {
name: pattern.name.clone(),
});
}
if let Err(e) = fancy_regex::Regex::new(&pattern.pattern) {
return Err(PackParseError::InvalidPattern {
name: pattern.name.clone(),
pattern: pattern.pattern.clone(),
error: e.to_string(),
});
}
}
Ok(())
}
/// Looks up `pack_id` in the built-in pack registry.
///
/// Returns the built-in pack's display name if the ID is already taken, or
/// `None` if an external pack may use it.
#[must_use]
pub fn check_builtin_collision(pack_id: &str) -> Option<&'static str> {
    let builtin = REGISTRY.get(pack_id)?;
    Some(builtin.name)
}
/// Validates `pack` and also rejects IDs that shadow a built-in pack.
///
/// # Errors
///
/// Any error from `validate_pack`. If validation passes but the ID belongs to
/// a built-in pack, [`PackParseError::IdCollision`].
pub fn validate_pack_with_collision_check(pack: &ExternalPack) -> Result<(), PackParseError> {
    validate_pack(pack)?;
    match check_builtin_collision(&pack.id) {
        None => Ok(()),
        Some(builtin_name) => Err(PackParseError::IdCollision {
            id: pack.id.clone(),
            builtin_name: builtin_name.to_owned(),
        }),
    }
}
pub fn parse_pack_file_checked(path: &Path) -> Result<ExternalPack, PackParseError> {
let content = std::fs::read_to_string(path)?;
parse_pack_string_checked(&content)
}
/// Deserializes YAML `content`, validates it, and rejects built-in ID collisions.
///
/// # Errors
///
/// [`PackParseError::Yaml`] if deserialization fails. Otherwise, any error
/// from [`validate_pack_with_collision_check`].
pub fn parse_pack_string_checked(content: &str) -> Result<ExternalPack, PackParseError> {
    let pack = serde_yaml::from_str::<ExternalPack>(content)?;
    validate_pack_with_collision_check(&pack).map(|()| pack)
}
impl ExternalPack {
    /// Converts this parsed pack into a runtime [`Pack`].
    ///
    /// The runtime pack stores `&'static` data, so every owned string and
    /// non-empty slice is leaked with [`Box::leak`]. That memory is never
    /// reclaimed, and each call leaks a fresh copy.
    ///
    /// `schema_version`, `version`, and safe-pattern descriptions have no
    /// runtime counterpart and are dropped. A destructive pattern without a
    /// `description` gets the reason `"Blocked by external pack pattern"`.
    #[must_use]
    pub fn into_pack(self) -> Pack {
        /// Leaks an owned string to obtain the `'static` borrow `Pack` needs.
        fn leak(text: String) -> &'static str {
            Box::leak(text.into_boxed_str())
        }

        let Self {
            id,
            name,
            description,
            keywords,
            safe_patterns,
            destructive_patterns,
            ..
        } = self;

        let name = leak(name);
        let description = match description {
            Some(text) => leak(text),
            None => "",
        };

        // Empty lists use a static empty slice, so nothing is leaked for them.
        let keywords: &'static [&'static str] = if keywords.is_empty() {
            &[]
        } else {
            let leaked: Vec<&'static str> = keywords.into_iter().map(leak).collect();
            Box::leak(leaked.into_boxed_slice())
        };

        let safe_patterns: Vec<SafePattern> = safe_patterns
            .into_iter()
            .map(|external| SafePattern {
                regex: LazyCompiledRegex::new(leak(external.pattern)),
                name: leak(external.name),
            })
            .collect();

        let destructive_patterns: Vec<DestructivePattern> = destructive_patterns
            .into_iter()
            .map(|external| {
                let suggestions: &'static [super::PatternSuggestion] =
                    if external.suggestions.is_empty() {
                        &[]
                    } else {
                        let converted: Vec<super::PatternSuggestion> = external
                            .suggestions
                            .into_iter()
                            .map(|hint| super::PatternSuggestion {
                                command: leak(hint.command),
                                description: leak(hint.description),
                                platform: hint.platform.into(),
                            })
                            .collect();
                        Box::leak(converted.into_boxed_slice())
                    };
                DestructivePattern {
                    regex: LazyCompiledRegex::new(leak(external.pattern)),
                    reason: external
                        .description
                        .map_or("Blocked by external pack pattern", leak),
                    name: Some(leak(external.name)),
                    severity: external.severity.into(),
                    explanation: external.explanation.map(leak),
                    suggestions,
                }
            })
            .collect();

        Pack::new(
            id,
            name,
            description,
            keywords,
            safe_patterns,
            destructive_patterns,
        )
    }
}
/// Which regex engine a pattern is classified as needing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegexEngineType {
    /// The pattern does not need backtracking features.
    Linear,
    /// The pattern needs the backtracking engine (as judged by
    /// `needs_backtracking_engine`).
    Backtracking,
}
impl fmt::Display for RegexEngineType {
    /// Writes the lowercase engine label, ignoring any width or fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Linear => "linear",
            Self::Backtracking => "backtracking",
        };
        f.write_str(label)
    }
}
/// Per-engine pattern counts for a pack.
#[derive(Debug)]
pub struct EngineSummary {
    /// Patterns classified as [`RegexEngineType::Linear`].
    pub linear_count: usize,
    /// Patterns classified as [`RegexEngineType::Backtracking`].
    pub backtracking_count: usize,
}
impl EngineSummary {
    /// Total number of patterns counted.
    #[must_use]
    pub const fn total(&self) -> usize {
        self.linear_count + self.backtracking_count
    }

    /// Share of linear-engine patterns, in percent (`0.0..=100.0`).
    ///
    /// An empty summary reports `100.0`, since no pattern needs backtracking.
    #[must_use]
    // usize -> f64 can lose precision above 2^53; pattern counts are assumed small.
    #[allow(clippy::cast_precision_loss)]
    pub fn linear_percentage(&self) -> f64 {
        match self.total() {
            0 => 100.0,
            total => self.linear_count as f64 / total as f64 * 100.0,
        }
    }
}
/// Engine classification for a single pattern of an external pack.
#[derive(Debug)]
pub struct PatternEngineInfo {
    /// The pattern's name.
    pub name: String,
    /// The pattern's regex source.
    pub pattern: String,
    /// `true` for destructive patterns, `false` for safe patterns.
    pub is_destructive: bool,
    /// Engine the pattern was classified as needing.
    pub engine: RegexEngineType,
}
/// Classifies every pattern in `pack` by the regex engine it requires.
///
/// Destructive patterns come first, then safe patterns, each in declaration order.
#[must_use]
pub fn analyze_pack_engines(pack: &ExternalPack) -> Vec<PatternEngineInfo> {
    use crate::packs::regex_engine::needs_backtracking_engine;

    // One classifier, so destructive and safe patterns are judged identically.
    let engine_for = |source: &str| {
        if needs_backtracking_engine(source) {
            RegexEngineType::Backtracking
        } else {
            RegexEngineType::Linear
        }
    };

    let destructive = pack
        .destructive_patterns
        .iter()
        .map(|entry| PatternEngineInfo {
            name: entry.name.clone(),
            pattern: entry.pattern.clone(),
            is_destructive: true,
            engine: engine_for(entry.pattern.as_str()),
        });
    let safe = pack.safe_patterns.iter().map(|entry| PatternEngineInfo {
        name: entry.name.clone(),
        pattern: entry.pattern.clone(),
        is_destructive: false,
        engine: engine_for(entry.pattern.as_str()),
    });

    destructive.chain(safe).collect()
}
/// Counts how many of `pack`'s patterns use each regex engine.
#[must_use]
pub fn summarize_pack_engines(pack: &ExternalPack) -> EngineSummary {
    analyze_pack_engines(pack).into_iter().fold(
        EngineSummary {
            linear_count: 0,
            backtracking_count: 0,
        },
        |mut tally, info| {
            match info.engine {
                RegexEngineType::Linear => tally.linear_count += 1,
                RegexEngineType::Backtracking => tally.backtracking_count += 1,
            }
            tally
        },
    )
}
/// A pack file that parsed and validated successfully.
#[derive(Debug)]
pub struct LoadedExternalPack {
    /// Copy of `pack.id`.
    pub id: String,
    /// The validated pack definition.
    pub pack: ExternalPack,
    /// Path the pack was loaded from.
    pub path: PathBuf,
}
/// A pack file that failed to load. The loader records it and moves on to the next path.
#[derive(Debug)]
pub struct PackLoadWarning {
    /// Path of the file that failed.
    pub path: PathBuf,
    /// Why it failed.
    pub error: PackParseError,
}
/// Outcome of loading a set of pack files.
#[derive(Debug)]
pub struct ExternalPackLoadResult {
    /// Successfully loaded packs.
    pub packs: Vec<LoadedExternalPack>,
    /// One entry per file that failed to load, in path order.
    pub warnings: Vec<PackLoadWarning>,
}
#[derive(Debug, Default)]
pub struct ExternalPackLoader {
paths: Vec<PathBuf>,
}
impl ExternalPackLoader {
#[must_use]
pub fn from_paths(paths: &[String]) -> Self {
let paths = paths.iter().map(PathBuf::from).collect();
Self { paths }
}
#[must_use]
pub fn paths(&self) -> &[PathBuf] {
&self.paths
}
#[must_use]
pub fn load_all(&self) -> ExternalPackLoadResult {
let mut packs = Vec::new();
let mut warnings = Vec::new();
for path in &self.paths {
match parse_pack_file_checked(path) {
Ok(pack) => {
let id = pack.id.clone();
packs.push(LoadedExternalPack {
id,
pack,
path: path.clone(),
});
}
Err(error) => {
warnings.push(PackLoadWarning {
path: path.clone(),
error,
});
}
}
}
ExternalPackLoadResult { packs, warnings }
}
#[must_use]
pub fn load_all_deduped(&self) -> ExternalPackLoadResult {
let mut warnings = Vec::new();
let mut order: Vec<String> = Vec::new();
let mut by_id: HashMap<String, LoadedExternalPack> = HashMap::new();
for path in &self.paths {
match parse_pack_file_checked(path) {
Ok(pack) => {
let id = pack.id.clone();
order.push(id.clone());
by_id.insert(
id.clone(),
LoadedExternalPack {
id,
pack,
path: path.clone(),
},
);
}
Err(error) => {
warnings.push(PackLoadWarning {
path: path.clone(),
error,
});
}
}
}
let mut seen = HashSet::new();
let mut packs_rev = Vec::new();
for id in order.iter().rev() {
if seen.insert(id.clone()) {
if let Some(pack) = by_id.remove(id) {
packs_rev.push(pack);
}
}
}
packs_rev.reverse();
ExternalPackLoadResult {
packs: packs_rev,
warnings,
}
}
}
/// Unit tests for external pack parsing, validation, conversion, and loading.
///
/// YAML fixtures keep their nested indentation. Without it, list-item keys
/// such as `pattern:` would be parsed as top-level keys.
#[cfg(test)]
#[allow(clippy::needless_raw_string_hashes)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_valid_pack() {
        let yaml = r#"
schema_version: 1
id: test.example
name: Test Pack
version: 1.0.0
description: A test pack for unit testing
keywords:
  - test
  - example
destructive_patterns:
  - name: test-pattern
    pattern: test.*dangerous
    severity: high
    description: Blocks test dangerous commands
safe_patterns:
  - name: test-safe
    pattern: test.*safe
    description: Allows test safe commands
"#;
        let pack = parse_pack_string(yaml).unwrap();
        assert_eq!(pack.id, "test.example");
        assert_eq!(pack.name, "Test Pack");
        assert_eq!(pack.version, "1.0.0");
        assert_eq!(pack.keywords.len(), 2);
        assert_eq!(pack.destructive_patterns.len(), 1);
        assert_eq!(pack.safe_patterns.len(), 1);
    }

    #[test]
    fn test_parse_minimal_pack() {
        let yaml = r#"
id: minimal.pack
name: Minimal
version: 0.1.0
destructive_patterns:
  - name: block-all
    pattern: danger
"#;
        let pack = parse_pack_string(yaml).unwrap();
        assert_eq!(pack.id, "minimal.pack");
        assert_eq!(pack.schema_version, 1);
        assert!(pack.keywords.is_empty());
        assert_eq!(pack.destructive_patterns.len(), 1);
        assert!(pack.safe_patterns.is_empty());
    }

    #[test]
    fn test_invalid_id_format() {
        let yaml = r#"
id: InvalidID
name: Test
version: 1.0.0
destructive_patterns:
  - name: test
    pattern: test
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(result, Err(PackParseError::InvalidId { .. })));
    }

    #[test]
    fn test_invalid_id_missing_dot() {
        let yaml = r#"
id: nodotinid
name: Test
version: 1.0.0
destructive_patterns:
  - name: test
    pattern: test
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(result, Err(PackParseError::InvalidId { .. })));
    }

    #[test]
    fn test_invalid_version_format() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0
destructive_patterns:
  - name: test
    pattern: test
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(result, Err(PackParseError::InvalidVersion { .. })));
    }

    #[test]
    fn test_unsupported_schema_version() {
        let yaml = r#"
schema_version: 999
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: test
    pattern: test
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(
            result,
            Err(PackParseError::UnsupportedSchemaVersion { .. })
        ));
    }

    #[test]
    fn test_invalid_regex_pattern() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: bad-regex
    pattern: "[invalid(regex"
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(result, Err(PackParseError::InvalidPattern { .. })));
    }

    #[test]
    fn test_duplicate_pattern_name() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: duplicate
    pattern: pattern1
  - name: duplicate
    pattern: pattern2
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(
            result,
            Err(PackParseError::DuplicatePattern { .. })
        ));
    }

    #[test]
    fn test_duplicate_across_safe_and_destructive() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: duplicate
    pattern: pattern1
safe_patterns:
  - name: duplicate
    pattern: pattern2
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(
            result,
            Err(PackParseError::DuplicatePattern { .. })
        ));
    }

    #[test]
    fn test_empty_pack() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
"#;
        let result = parse_pack_string(yaml);
        assert!(matches!(result, Err(PackParseError::EmptyPack)));
    }

    #[test]
    fn test_severity_levels() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: low
    pattern: low
    severity: low
  - name: medium
    pattern: medium
    severity: medium
  - name: high
    pattern: high
    severity: high
  - name: critical
    pattern: critical
    severity: critical
"#;
        let pack = parse_pack_string(yaml).unwrap();
        assert_eq!(pack.destructive_patterns[0].severity, ExternalSeverity::Low);
        assert_eq!(
            pack.destructive_patterns[1].severity,
            ExternalSeverity::Medium
        );
        assert_eq!(
            pack.destructive_patterns[2].severity,
            ExternalSeverity::High
        );
        assert_eq!(
            pack.destructive_patterns[3].severity,
            ExternalSeverity::Critical
        );
    }

    #[test]
    fn test_default_severity() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: no-severity
    pattern: test
"#;
        let pack = parse_pack_string(yaml).unwrap();
        assert_eq!(
            pack.destructive_patterns[0].severity,
            ExternalSeverity::High
        );
    }

    #[test]
    fn test_suggestions_and_platforms() {
        let yaml = r#"
id: test.pack
name: Test
version: 1.0.0
destructive_patterns:
  - name: with-suggestions
    pattern: rm -rf
    suggestions:
      - command: rm -ri
        description: Interactive removal
      - command: trash
        description: Move to trash instead
        platform: macos
"#;
        let pack = parse_pack_string(yaml).unwrap();
        let suggestions = &pack.destructive_patterns[0].suggestions;
        assert_eq!(suggestions.len(), 2);
        // An omitted platform falls back to `All`.
        assert_eq!(suggestions[0].platform, ExternalPlatform::All);
        assert_eq!(suggestions[1].platform, ExternalPlatform::MacOS);
    }

    #[test]
    fn test_convert_to_pack() {
        let yaml = r#"
id: test.example
name: Test Example Pack
version: 1.0.0
description: Testing conversion
keywords:
  - test
destructive_patterns:
  - name: block-test
    pattern: dangerous
    severity: critical
    description: Blocks dangerous commands
    explanation: This is a detailed explanation
safe_patterns:
  - name: allow-safe
    pattern: safe
    description: Allows safe commands
"#;
        let external = parse_pack_string(yaml).unwrap();
        let pack = external.into_pack();
        assert_eq!(pack.id, "test.example");
        assert_eq!(pack.name, "Test Example Pack");
        assert_eq!(pack.description, "Testing conversion");
        assert_eq!(pack.keywords.len(), 1);
        assert_eq!(pack.keywords[0], "test");
        assert_eq!(pack.safe_patterns.len(), 1);
        assert_eq!(pack.destructive_patterns.len(), 1);
        assert_eq!(pack.destructive_patterns[0].severity, Severity::Critical);
    }

    #[test]
    fn test_yaml_parse_error() {
        let yaml = "invalid: yaml: content: [";
        let result = parse_pack_string(yaml);
        assert!(matches!(result, Err(PackParseError::Yaml(_))));
    }

    #[test]
    fn test_error_display() {
        let err = PackParseError::InvalidId {
            id: "bad".to_string(),
            reason: "test".to_string(),
        };
        assert!(err.to_string().contains("bad"));
        let err = PackParseError::DuplicatePattern {
            name: "dup".to_string(),
        };
        assert!(err.to_string().contains("dup"));
        let err = PackParseError::IdCollision {
            id: "core.git".to_string(),
            builtin_name: "Git".to_string(),
        };
        assert!(err.to_string().contains("core.git"));
        assert!(err.to_string().contains("Git"));
        assert!(err.to_string().contains("collides"));
    }

    #[test]
    fn test_collision_with_builtin_pack() {
        let yaml = r#"
id: core.git
name: Malicious Override
version: 1.0.0
destructive_patterns:
  - name: allow-everything
    pattern: never-match-anything-12345
"#;
        let result = parse_pack_string_checked(yaml);
        assert!(matches!(result, Err(PackParseError::IdCollision { .. })));
        if let Err(PackParseError::IdCollision { id, builtin_name }) = result {
            assert_eq!(id, "core.git");
            assert!(!builtin_name.is_empty());
        }
    }

    #[test]
    fn test_no_collision_with_custom_namespace() {
        let yaml = r#"
id: mycompany.deploy
name: MyCompany Deploy
version: 1.0.0
destructive_patterns:
  - name: block-prod
    pattern: deploy.*prod
"#;
        let result = parse_pack_string_checked(yaml);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_builtin_collision_function() {
        let result = check_builtin_collision("core.git");
        assert!(result.is_some());
        let result = check_builtin_collision("database.postgresql");
        assert!(result.is_some());
        let result = check_builtin_collision("mycompany.custom");
        assert!(result.is_none());
        let result = check_builtin_collision("database.oracle");
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_without_collision_check_allows_override() {
        let yaml = r#"
id: core.git
name: Override Git
version: 1.0.0
destructive_patterns:
  - name: test
    pattern: test
"#;
        let result = parse_pack_string(yaml);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_pack_with_collision_check() {
        let yaml = r#"
id: core.filesystem
name: Override Filesystem
version: 1.0.0
destructive_patterns:
  - name: test
    pattern: test
"#;
        let pack: ExternalPack = serde_yaml::from_str(yaml).unwrap();
        assert!(validate_pack(&pack).is_ok());
        let result = validate_pack_with_collision_check(&pack);
        assert!(matches!(result, Err(PackParseError::IdCollision { .. })));
    }

    #[test]
    fn test_regex_engine_type_display() {
        assert_eq!(RegexEngineType::Linear.to_string(), "linear");
        assert_eq!(RegexEngineType::Backtracking.to_string(), "backtracking");
    }

    #[test]
    fn test_engine_summary_percentage() {
        let empty = EngineSummary {
            linear_count: 0,
            backtracking_count: 0,
        };
        assert_eq!(empty.total(), 0);
        assert!((empty.linear_percentage() - 100.0).abs() < f64::EPSILON);
        let mixed = EngineSummary {
            linear_count: 3,
            backtracking_count: 1,
        };
        assert_eq!(mixed.total(), 4);
        assert!((mixed.linear_percentage() - 75.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_loader_reports_missing_file_as_warning() {
        let loader = ExternalPackLoader::from_paths(&[
            "/nonexistent-external-pack-dir/pack.yaml".to_string(),
        ]);
        assert_eq!(loader.paths().len(), 1);
        let result = loader.load_all();
        assert!(result.packs.is_empty());
        assert_eq!(result.warnings.len(), 1);
        assert!(matches!(result.warnings[0].error, PackParseError::Io(_)));
    }
}