systemprompt-models 0.2.1

Foundation data models for systemprompt.io AI governance infrastructure. Shared DTOs, config, and domain types consumed by every layer of the MCP governance pipeline.
Documentation
use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::execution_plan::{PlannedToolCall, TemplateRef};
use super::tools::McpTool;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanValidationError {
    pub tool_index: usize,
    pub argument: String,
    pub template: String,
    pub error: ValidationErrorKind,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ValidationErrorKind {
    InvalidTemplateSyntax,
    IndexOutOfBounds {
        referenced_index: usize,
        max_valid_index: usize,
    },
    SelfReference,
    ForwardReference {
        referenced_index: usize,
    },
    FieldNotFound {
        tool_name: String,
        field: String,
        available_fields: Vec<String>,
    },
    NoOutputSchema {
        tool_name: String,
    },
}

impl std::fmt::Display for PlanValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.error {
            ValidationErrorKind::InvalidTemplateSyntax => {
                write!(
                    f,
                    "Tool {}: Invalid template syntax '{}' for argument '{}'",
                    self.tool_index, self.template, self.argument
                )
            },
            ValidationErrorKind::IndexOutOfBounds {
                referenced_index,
                max_valid_index,
            } => {
                write!(
                    f,
                    "Tool {}: Template '{}' references tool {} but only tools 0-{} are available",
                    self.tool_index, self.template, referenced_index, max_valid_index
                )
            },
            ValidationErrorKind::SelfReference => {
                write!(
                    f,
                    "Tool {}: Template '{}' cannot reference itself",
                    self.tool_index, self.template
                )
            },
            ValidationErrorKind::ForwardReference { referenced_index } => {
                write!(
                    f,
                    "Tool {}: Template '{}' references tool {} which hasn't executed yet",
                    self.tool_index, self.template, referenced_index
                )
            },
            ValidationErrorKind::FieldNotFound {
                tool_name,
                field,
                available_fields,
            } => {
                write!(
                    f,
                    "Tool {}: Template '{}' references field '{}' but tool '{}' outputs: [{}]",
                    self.tool_index,
                    self.template,
                    field,
                    tool_name,
                    available_fields.join(", ")
                )
            },
            ValidationErrorKind::NoOutputSchema { tool_name } => {
                write!(
                    f,
                    "Tool {}: Template '{}' references '{}' which has no output schema",
                    self.tool_index, self.template, tool_name
                )
            },
        }
    }
}

impl std::error::Error for PlanValidationError {}

#[derive(Debug, Clone, Copy)]
pub struct TemplateValidator;

impl TemplateValidator {
    pub fn get_tool_output_schemas(
        calls: &[PlannedToolCall],
        tools: &[McpTool],
    ) -> Vec<(String, Option<Value>)> {
        calls
            .iter()
            .map(|call| {
                let output_schema = tools
                    .iter()
                    .find(|t| t.name == call.tool_name)
                    .and_then(|t| t.output_schema.clone());
                (call.tool_name.clone(), output_schema)
            })
            .collect()
    }

    pub fn find_templates_in_value(value: &Value) -> Vec<String> {
        let mut templates = Vec::new();
        Self::collect_templates(value, &mut templates);
        templates
    }

    fn collect_templates(value: &Value, templates: &mut Vec<String>) {
        match value {
            Value::String(s) if s.starts_with('$') && s.contains(".output.") => {
                templates.push(s.clone());
            },
            Value::Array(arr) => {
                for v in arr {
                    Self::collect_templates(v, templates);
                }
            },
            Value::Object(obj) => {
                for v in obj.values() {
                    Self::collect_templates(v, templates);
                }
            },
            Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {},
        }
    }

    pub fn validate_plan(
        calls: &[PlannedToolCall],
        tool_output_schemas: &[(String, Option<Value>)],
    ) -> Result<(), Vec<PlanValidationError>> {
        let mut errors = Vec::new();

        for (tool_index, call) in calls.iter().enumerate() {
            for template in Self::find_templates_in_value(&call.arguments) {
                if let Some(err) =
                    Self::validate_template(tool_index, call, &template, tool_output_schemas)
                {
                    errors.push(err);
                }
            }
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }

    fn validate_template(
        tool_index: usize,
        call: &PlannedToolCall,
        template: &str,
        tool_output_schemas: &[(String, Option<Value>)],
    ) -> Option<PlanValidationError> {
        let make_error = |error: ValidationErrorKind| PlanValidationError {
            tool_index,
            argument: Self::find_argument_for_template(&call.arguments, template),
            template: template.to_string(),
            error,
        };

        let Some(template_ref) = TemplateRef::parse(template) else {
            return Some(make_error(ValidationErrorKind::InvalidTemplateSyntax));
        };

        if template_ref.tool_index == tool_index {
            return Some(make_error(ValidationErrorKind::SelfReference));
        }
        if template_ref.tool_index > tool_index {
            return Some(make_error(ValidationErrorKind::ForwardReference {
                referenced_index: template_ref.tool_index,
            }));
        }
        if template_ref.tool_index >= tool_output_schemas.len() {
            return Some(make_error(ValidationErrorKind::IndexOutOfBounds {
                referenced_index: template_ref.tool_index,
                max_valid_index: tool_output_schemas.len().saturating_sub(1),
            }));
        }

        let (ref_tool_name, ref_output_schema) = &tool_output_schemas[template_ref.tool_index];

        ref_output_schema.as_ref().map_or_else(
            || {
                Some(make_error(ValidationErrorKind::NoOutputSchema {
                    tool_name: ref_tool_name.clone(),
                }))
            },
            |schema| {
                Self::validate_field_access(&template_ref, schema, ref_tool_name).map(make_error)
            },
        )
    }

    fn validate_field_access(
        template_ref: &TemplateRef,
        schema: &Value,
        tool_name: &str,
    ) -> Option<ValidationErrorKind> {
        let first_field = template_ref.field_path.first()?;
        let available_fields = Self::get_schema_fields(schema);

        if available_fields.contains(first_field) {
            None
        } else {
            Some(ValidationErrorKind::FieldNotFound {
                tool_name: tool_name.to_string(),
                field: first_field.clone(),
                available_fields,
            })
        }
    }

    fn find_argument_for_template(value: &Value, template: &str) -> String {
        if let Value::Object(obj) = value {
            for (key, val) in obj {
                if let Value::String(s) = val {
                    if s == template {
                        return key.clone();
                    }
                }
                let nested = Self::find_argument_for_template(val, template);
                if !nested.is_empty() {
                    return format!("{key}.{nested}");
                }
            }
        }
        String::new()
    }

    fn get_schema_fields(schema: &Value) -> Vec<String> {
        schema
            .get("properties")
            .and_then(|p| p.as_object())
            .map_or_else(Vec::new, |obj| obj.keys().cloned().collect())
    }
}