use std::collections::HashMap;
use rsigma_eval::pipeline::sources::SourceType;
use rsigma_eval::pipeline::transformations::Transformation;
use rsigma_eval::{Pipeline, TransformationItem};
/// Largest `depth` value accepted by `expand_includes_with_depth`; any call
/// with a depth greater than this is rejected with a "recursive includes"
/// error. Note that the error messages in this file spell out "max depth 1"
/// literally rather than formatting this constant.
const MAX_INCLUDE_DEPTH: usize = 1;
/// Replaces every `include` transformation in `pipeline` with the
/// transformation items parsed from the matching entry of `resolved`.
///
/// `resolved` maps a source id to the JSON data already fetched for that
/// source. When `allow_remote_include` is `false`, an include whose source id
/// matches an HTTP or NATS source declared in `pipeline.sources` is refused.
///
/// # Errors
///
/// Returns a human-readable message when an include points at a disallowed
/// remote source, when its source id has no entry in `resolved`, when the
/// resolved data cannot be parsed as a transformation array, or when the
/// included content itself contains an include directive. The pipeline's
/// transformation list is only replaced once every include expanded
/// successfully.
pub fn expand_includes(
    pipeline: &mut Pipeline,
    resolved: &HashMap<String, serde_json::Value>,
    allow_remote_include: bool,
) -> Result<(), String> {
    // Public entry point: expansion always starts at the outermost level.
    let initial_depth: usize = 0;
    expand_includes_with_depth(pipeline, resolved, allow_remote_include, initial_depth)
}
/// Depth-aware worker behind [`expand_includes`].
///
/// Walks `pipeline.transformations` in order. Non-include items are copied
/// through unchanged; each include item is replaced, in place, by the items
/// parsed from `resolved[source_id]`. The rebuilt list is written back to the
/// pipeline only if at least one include was seen and no error occurred.
///
/// # Errors
///
/// * `depth` exceeds [`MAX_INCLUDE_DEPTH`].
/// * `allow_remote_include` is `false` and the first source in
///   `pipeline.sources` with the include's id is an HTTP or NATS source.
/// * The include's source id has no entry in `resolved`.
/// * The resolved data fails to parse as a transformation array.
/// * The parsed items themselves contain an include directive.
fn expand_includes_with_depth(
    pipeline: &mut Pipeline,
    resolved: &HashMap<String, serde_json::Value>,
    allow_remote_include: bool,
    depth: usize,
) -> Result<(), String> {
    if depth > MAX_INCLUDE_DEPTH {
        return Err(
            "recursive includes are not allowed (max depth 1); included content cannot itself contain include directives".to_string()
        );
    }

    let mut flattened = Vec::new();
    let mut saw_include = false;

    for entry in &pipeline.transformations {
        // Anything that is not an include is carried over verbatim.
        let Transformation::Include { template } = &entry.transformation else {
            flattened.push(entry.clone());
            continue;
        };

        saw_include = true;
        let source_id = extract_source_id(template);

        if !allow_remote_include {
            // Only the first declared source with this id is inspected; an id
            // that matches no declared source is not treated as remote.
            let targets_remote = pipeline
                .sources
                .iter()
                .find(|s| s.id == source_id)
                .is_some_and(|s| {
                    matches!(
                        &s.source_type,
                        SourceType::Http { .. } | SourceType::Nats { .. }
                    )
                });
            if targets_remote {
                return Err(format!(
                    "include references remote source '{}'; use --allow-remote-include to permit",
                    source_id
                ));
            }
        }

        let Some(data) = resolved.get(&source_id) else {
            return Err(format!(
                "include references unresolved source '{source_id}'"
            ));
        };

        let included = parse_transformation_array(data)?;

        // Included content must be flat: reject it if any item is an include.
        let has_nested_include = included
            .iter()
            .any(|inner| matches!(inner.transformation, Transformation::Include { .. }));
        if has_nested_include {
            return Err(format!(
                "included content from source '{}' contains nested include directives; recursive includes are not allowed (max depth 1)",
                source_id
            ));
        }

        flattened.extend(included);
    }

    // Leave the pipeline untouched when there was nothing to expand.
    if saw_include {
        pipeline.transformations = flattened;
    }
    Ok(())
}
/// Extracts the source id from an include template.
///
/// After trimming surrounding whitespace, a template of the form
/// `${source.<id>}` or `${source.<id>.<path...>}` yields `<id>` (the text up
/// to the first `.` inside the braces). Any other input is returned as the
/// trimmed string itself.
fn extract_source_id(template: &str) -> String {
    let candidate = template.trim();

    // `Some(body)` only when both the `${source.` opener and the closing `}`
    // are present.
    let reference_body = candidate
        .strip_prefix("${source.")
        .and_then(|rest| rest.strip_suffix('}'));

    match reference_body {
        Some(body) => match body.split_once('.') {
            // Drop any sub-path after the id.
            Some((id, _sub_path)) => id.to_string(),
            None => body.to_string(),
        },
        None => candidate.to_string(),
    }
}
/// Converts resolved include data into a list of transformation items.
///
/// The JSON value must be an array. It is serialized to JSON text, that text
/// is parsed by the YAML parser into a `yaml_serde::Value`, and the result is
/// handed to `rsigma_eval::parse_transformation_items`.
///
/// # Errors
///
/// Returns a message when `data` is not an array, when serialization fails,
/// when the serialized text is rejected by the YAML parser, or when the
/// transformation parser rejects the value.
fn parse_transformation_array(data: &serde_json::Value) -> Result<Vec<TransformationItem>, String> {
    if !matches!(data, serde_json::Value::Array(_)) {
        return Err("include source data must be an array of transformation objects".to_string());
    }

    let json_text = match serde_json::to_string(data) {
        Ok(text) => text,
        Err(e) => return Err(format!("include serialization: {e}")),
    };

    let yaml_value = yaml_serde::from_str::<yaml_serde::Value>(&json_text)
        .map_err(|e| format!("include data is not valid YAML: {e}"))?;

    match rsigma_eval::parse_transformation_items(&yaml_value) {
        Ok(items) => Ok(items),
        Err(e) => Err(format!("include parse error: {e}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pipeline named "test" whose only transformation is an
    /// unconditional include of `template`, with no sources declared.
    fn pipeline_with_single_include(template: &str) -> Pipeline {
        let include_item = TransformationItem {
            id: None,
            transformation: Transformation::Include {
                template: template.to_string(),
            },
            rule_conditions: vec![],
            rule_cond_expr: None,
            detection_item_conditions: vec![],
            field_name_conditions: vec![],
            field_name_cond_not: false,
        };

        Pipeline {
            name: "test".to_string(),
            priority: 0,
            vars: HashMap::new(),
            transformations: vec![include_item],
            finalizers: vec![],
            sources: vec![],
            source_refs: vec![],
        }
    }

    /// A bare `${source.<id>}` template yields the id.
    #[test]
    fn extract_source_id_simple() {
        let id = extract_source_id("${source.my_transforms}");
        assert_eq!(id, "my_transforms");
    }

    /// A dotted path after the id is discarded.
    #[test]
    fn extract_source_id_with_path() {
        let id = extract_source_id("${source.config.transforms}");
        assert_eq!(id, "config");
    }

    /// Input that is not a `${source.…}` reference is returned unchanged.
    #[test]
    fn extract_source_id_plain_string() {
        let id = extract_source_id("my_source");
        assert_eq!(id, "my_source");
    }

    /// Included content that itself contains an include must be refused.
    #[test]
    fn nested_include_rejected() {
        let mut pipeline = pipeline_with_single_include("${source.transforms}");

        let resolved = HashMap::from([(
            "transforms".to_string(),
            serde_json::json!([
                {"type": "include", "include": "${source.other}"}
            ]),
        )]);

        let outcome = expand_includes(&mut pipeline, &resolved, true);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(
            err.contains("nested include") || err.contains("recursive"),
            "error should mention nesting: {err}"
        );
    }
}