hushspec 0.1.0

Portable specification types for AI agent security rules
Documentation
use crate::extensions::{
    DetectionExtension, Extensions, JailbreakDetection, OriginsExtension, PostureExtension,
    PromptInjectionDetection, ThreatIntelDetection,
};
use crate::rules::Rules;
use crate::schema::{HushSpec, MergeStrategy};

/// Combines `base` and `child` into a new spec according to the child's
/// merge strategy.
///
/// When the child declares no strategy, the `Default` value of
/// [`MergeStrategy`] is used.
///
/// * `Replace` — the result is a clone of `child` with `extends` cleared;
///   `base` is not consulted.
/// * `Merge` — delegates to `merge_with_strategy` with `deep = false`.
/// * `DeepMerge` — delegates to `merge_with_strategy` with `deep = true`.
#[must_use = "merged spec is returned, not applied in place"]
pub fn merge(base: &HushSpec, child: &HushSpec) -> HushSpec {
    let deep = match child.merge_strategy.unwrap_or_default() {
        MergeStrategy::Replace => {
            // The child stands alone; only the inheritance link is dropped.
            let mut standalone = child.clone();
            standalone.extends = None;
            return standalone;
        }
        MergeStrategy::Merge => false,
        MergeStrategy::DeepMerge => true,
    };

    merge_with_strategy(base, child, deep)
}

/// Builds the merged spec for the non-replacing strategies.
///
/// `hushspec` and `merge_strategy` always come from `child`, and `extends`
/// is cleared. `name`, `description` and `metadata` use the child's value
/// when it is set and fall back to the base's otherwise. Rules are combined
/// by `merge_rules` in both modes; `deep` only selects which extension merger
/// runs (`merge_extensions_deep` when `true`, `merge_extensions_merge` when
/// `false`).
fn merge_with_strategy(base: &HushSpec, child: &HushSpec, deep: bool) -> HushSpec {
    let extensions = if deep {
        merge_extensions_deep(&base.extensions, &child.extensions)
    } else {
        merge_extensions_merge(&base.extensions, &child.extensions)
    };

    HushSpec {
        hushspec: child.hushspec.clone(),
        name: child.name.as_ref().or(base.name.as_ref()).cloned(),
        description: child
            .description
            .as_ref()
            .or(base.description.as_ref())
            .cloned(),
        extends: None,
        merge_strategy: child.merge_strategy,
        rules: merge_rules(&base.rules, &child.rules),
        extensions,
        metadata: child.metadata.as_ref().or(base.metadata.as_ref()).cloned(),
    }
}

/// Merges two optional rule sets one rule block at a time.
///
/// If `child` is `None` the result is a clone of `base` (which may itself be
/// `None`). Otherwise every rule block is taken from the child when the child
/// sets it, and from the base (or `Rules::default()` when there is no base)
/// when it does not. A rule block is never merged internally: the child's
/// block replaces the base's block as a whole.
fn merge_rules(base: &Option<Rules>, child: &Option<Rules>) -> Option<Rules> {
    // Child value when set, otherwise the inherited one.
    fn pick<T: Clone>(overriding: &Option<T>, inherited: Option<T>) -> Option<T> {
        overriding.clone().or(inherited)
    }

    let Some(overrides) = child else {
        return base.clone();
    };
    let inherited = base.clone().unwrap_or_default();

    Some(Rules {
        forbidden_paths: pick(&overrides.forbidden_paths, inherited.forbidden_paths),
        path_allowlist: pick(&overrides.path_allowlist, inherited.path_allowlist),
        egress: pick(&overrides.egress, inherited.egress),
        secret_patterns: pick(&overrides.secret_patterns, inherited.secret_patterns),
        patch_integrity: pick(&overrides.patch_integrity, inherited.patch_integrity),
        shell_commands: pick(&overrides.shell_commands, inherited.shell_commands),
        tool_access: pick(&overrides.tool_access, inherited.tool_access),
        computer_use: pick(&overrides.computer_use, inherited.computer_use),
        remote_desktop_channels: pick(
            &overrides.remote_desktop_channels,
            inherited.remote_desktop_channels,
        ),
        input_injection: pick(&overrides.input_injection, inherited.input_injection),
    })
}

/// Shallow extension merge used by the `Merge` strategy.
///
/// With no child extensions the base is cloned through unchanged. Otherwise
/// each extension (`posture`, `origins`, `detection`) is taken whole from the
/// child when the child sets it, and from the base (or
/// `Extensions::default()` when there is no base) when it does not.
fn merge_extensions_merge(
    base: &Option<Extensions>,
    child: &Option<Extensions>,
) -> Option<Extensions> {
    let overrides = match child {
        Some(ext) => ext,
        None => return base.clone(),
    };

    // Take ownership of the inherited pieces so they can be moved into the result.
    let Extensions {
        posture,
        origins,
        detection,
    } = base.clone().unwrap_or_default();

    Some(Extensions {
        posture: overrides.posture.clone().or(posture),
        origins: overrides.origins.clone().or(origins),
        detection: overrides.detection.clone().or(detection),
    })
}

/// Deep extension merge used by the `DeepMerge` strategy.
///
/// With no child extensions the base is cloned through unchanged. Otherwise
/// each extension is combined by its dedicated helper (`merge_posture`,
/// `merge_origins`, `merge_detection`), using `Extensions::default()` in
/// place of a missing base.
fn merge_extensions_deep(
    base: &Option<Extensions>,
    child: &Option<Extensions>,
) -> Option<Extensions> {
    child
        .as_ref()
        .map(|overrides| {
            let inherited = base.clone().unwrap_or_default();
            Extensions {
                posture: merge_posture(&inherited.posture, &overrides.posture),
                origins: merge_origins(&inherited.origins, &overrides.origins),
                detection: merge_detection(&inherited.detection, &overrides.detection),
            }
        })
        .or_else(|| base.clone())
}

/// Deep-merges the posture extension.
///
/// When only one side is present it is cloned as-is. When both are present,
/// the state maps are unioned with the child's entry replacing a base entry
/// of the same name, while `initial` and `transitions` are taken from the
/// child alone.
fn merge_posture(
    base: &Option<PostureExtension>,
    child: &Option<PostureExtension>,
) -> Option<PostureExtension> {
    match (base, child) {
        (Some(inherited), Some(overrides)) => {
            let mut states = inherited.states.clone();
            // Inserting child entries after the base ones lets the child win on name clashes.
            states.extend(
                overrides
                    .states
                    .iter()
                    .map(|(name, state)| (name.clone(), state.clone())),
            );

            Some(PostureExtension {
                initial: overrides.initial.clone(),
                states,
                transitions: overrides.transitions.clone(),
            })
        }
        (None, Some(only)) | (Some(only), None) => Some(only.clone()),
        (None, None) => None,
    }
}

/// Deep-merges the origins extension.
///
/// When only one side is present it is cloned as-is. When both are present:
///
/// * `default_behavior` is the child's value if set, otherwise the base's;
/// * profiles start from the base list; each child profile replaces, in
///   place, the first profile with an equal `id`, or is appended at the end
///   when no such profile exists.
fn merge_origins(
    base: &Option<OriginsExtension>,
    child: &Option<OriginsExtension>,
) -> Option<OriginsExtension> {
    let (inherited, overrides) = match (base, child) {
        (Some(inherited), Some(overrides)) => (inherited, overrides),
        (None, Some(only)) | (Some(only), None) => return Some(only.clone()),
        (None, None) => return None,
    };

    let mut profiles = inherited.profiles.clone();
    for incoming in &overrides.profiles {
        // Entries appended earlier in this loop are also candidates for replacement.
        match profiles
            .iter_mut()
            .find(|existing| existing.id == incoming.id)
        {
            Some(slot) => *slot = incoming.clone(),
            None => profiles.push(incoming.clone()),
        }
    }

    Some(OriginsExtension {
        default_behavior: overrides.default_behavior.or(inherited.default_behavior),
        profiles,
    })
}

/// Deep-merges the detection extension.
///
/// When only one side is present it is cloned as-is. When both are present,
/// each detector section (`prompt_injection`, `jailbreak`, `threat_intel`)
/// is combined by its own helper.
fn merge_detection(
    base: &Option<DetectionExtension>,
    child: &Option<DetectionExtension>,
) -> Option<DetectionExtension> {
    let (Some(inherited), Some(overrides)) = (base, child) else {
        // At most one side exists here: prefer the child, then the base.
        return child.clone().or_else(|| base.clone());
    };

    Some(DetectionExtension {
        prompt_injection: merge_prompt_injection(
            &inherited.prompt_injection,
            &overrides.prompt_injection,
        ),
        jailbreak: merge_jailbreak(&inherited.jailbreak, &overrides.jailbreak),
        threat_intel: merge_threat_intel(&inherited.threat_intel, &overrides.threat_intel),
    })
}

/// Merges prompt-injection detection settings field by field.
///
/// When only one side is present it is cloned as-is. When both are present,
/// every field takes the child's value if set and the base's value otherwise.
fn merge_prompt_injection(
    base: &Option<PromptInjectionDetection>,
    child: &Option<PromptInjectionDetection>,
) -> Option<PromptInjectionDetection> {
    match base.as_ref().zip(child.as_ref()) {
        Some((inherited, overrides)) => Some(PromptInjectionDetection {
            enabled: overrides.enabled.or(inherited.enabled),
            warn_at_or_above: overrides.warn_at_or_above.or(inherited.warn_at_or_above),
            block_at_or_above: overrides
                .block_at_or_above
                .or(inherited.block_at_or_above),
            max_scan_bytes: overrides.max_scan_bytes.or(inherited.max_scan_bytes),
        }),
        // At most one side exists: prefer the child, then the base.
        None => child.as_ref().or(base.as_ref()).cloned(),
    }
}

/// Merges jailbreak detection settings field by field.
///
/// When only one side is present it is cloned as-is. When both are present,
/// every field takes the child's value if set and the base's value otherwise.
fn merge_jailbreak(
    base: &Option<JailbreakDetection>,
    child: &Option<JailbreakDetection>,
) -> Option<JailbreakDetection> {
    if let (Some(inherited), Some(overrides)) = (base, child) {
        return Some(JailbreakDetection {
            enabled: overrides.enabled.or(inherited.enabled),
            block_threshold: overrides.block_threshold.or(inherited.block_threshold),
            warn_threshold: overrides.warn_threshold.or(inherited.warn_threshold),
            max_input_bytes: overrides.max_input_bytes.or(inherited.max_input_bytes),
        });
    }

    // At most one side exists: prefer the child, then the base.
    child.as_ref().or(base.as_ref()).cloned()
}

/// Merges threat-intel detection settings field by field.
///
/// When only one side is present it is cloned as-is. When both are present,
/// every field (`enabled`, `pattern_db`, `similarity_threshold`, `top_k`)
/// takes the child's value if set and the base's value otherwise.
fn merge_threat_intel(
    base: &Option<ThreatIntelDetection>,
    child: &Option<ThreatIntelDetection>,
) -> Option<ThreatIntelDetection> {
    base.as_ref()
        .zip(child.as_ref())
        .map(|(inherited, overrides)| ThreatIntelDetection {
            enabled: overrides.enabled.or(inherited.enabled),
            pattern_db: overrides
                .pattern_db
                .as_ref()
                .or(inherited.pattern_db.as_ref())
                .cloned(),
            similarity_threshold: overrides
                .similarity_threshold
                .or(inherited.similarity_threshold),
            top_k: overrides.top_k.or(inherited.top_k),
        })
        // At most one side exists: prefer the child, then the base.
        .or_else(|| child.as_ref().or(base.as_ref()).cloned())
}