// systemprompt-ai 0.1.21
//
// Core AI module for systemprompt.io
// Documentation
use super::analyzer::DiscriminatedUnion;
use super::capabilities::ProviderCapabilities;
use super::sanitizer::SchemaSanitizer;
use crate::models::tools::McpTool;
use anyhow::{Result, anyhow};
use serde_json::{Map, Value, json};

fn merge_properties_into(
    target: &mut Map<String, Value>,
    source: &Value,
    exclude_field: Option<&str>,
) {
    let Some(source_props) = source.get("properties").and_then(|p| p.as_object()) else {
        return;
    };

    for (key, value) in source_props {
        if exclude_field.is_some_and(|f| f == key) {
            continue;
        }
        target.insert(key.clone(), value.clone());
    }
}

fn collect_required_fields(base: &Value, variant: &Value, discriminator_field: &str) -> Vec<Value> {
    let mut all_required = Vec::new();

    if let Some(base_arr) = base.get("required").and_then(|r| r.as_array()) {
        for item in base_arr {
            if let Some(field) = item.as_str() {
                if field != discriminator_field {
                    all_required.push(json!(field));
                }
            }
        }
    }

    if let Some(variant_arr) = variant.get("required").and_then(|r| r.as_array()) {
        for item in variant_arr {
            if !all_required.contains(item) {
                all_required.push(item.clone());
            }
        }
    }

    all_required
}

/// A tool definition produced by [`SchemaTransformer::transform`].
///
/// One source tool yields either a single pass-through entry or one entry
/// per variant of a detected discriminated union.
#[derive(Debug, Clone)]
pub struct TransformedTool {
    /// Name of the emitted tool: the source tool's name for a pass-through,
    /// or a sanitized `<tool>_<variant>` name for a split variant.
    pub name: String,
    /// Description of the emitted tool; split variants append the
    /// capitalized variant name to the source description.
    pub description: String,
    /// Input schema after it has been run through the schema sanitizer.
    pub input_schema: Value,
    /// Name of the source tool this entry was derived from.
    pub original_name: String,
    /// The variant value this entry stands for when it came from a split
    /// union; `None` for pass-through tools.
    pub discriminator_value: Option<String>,
}

/// Turns tool definitions into [`TransformedTool`]s, splitting a tool whose
/// schema is a discriminated union into one tool per variant when the
/// provider capabilities report that the schema requires transformation.
#[derive(Debug, Copy, Clone)]
pub struct SchemaTransformer {
    /// Provider capabilities consulted to decide whether a schema needs transforming.
    capabilities: ProviderCapabilities,
    /// Sanitizer built from the same capabilities; applied to every emitted schema.
    sanitizer: SchemaSanitizer,
}

impl SchemaTransformer {
    /// Creates a transformer for the given provider capabilities.
    ///
    /// The schema sanitizer is constructed from the same `capabilities` value.
    pub const fn new(capabilities: ProviderCapabilities) -> Self {
        Self {
            capabilities,
            sanitizer: SchemaSanitizer::new(capabilities),
        }
    }

    /// Rewrites `name` into an ASCII-only function name of at most 64 bytes.
    ///
    /// Rules applied:
    /// - The first character is kept if it is an ASCII letter or `_`. A
    ///   leading ASCII digit is kept but prefixed with `_`. Anything else is
    ///   replaced by `_`.
    /// - Every later character is kept if it is an ASCII letter, ASCII digit,
    ///   or one of `_`, `.`, `:`, `-`; otherwise it is replaced by `_`.
    /// - The result is cut to 64 bytes.
    ///
    /// The leading-character checks are deliberately ASCII-only so that they
    /// agree with the rule for the remaining characters; Unicode letters or
    /// digits in first position would otherwise leak into the output.
    ///
    /// An empty `name` yields an empty string.
    fn sanitize_function_name(name: &str) -> String {
        const MAX_LEN: usize = 64;

        let mut result = String::with_capacity(name.len().min(MAX_LEN) + 1);

        // Each input char yields at least one output byte, so the first
        // MAX_LEN chars are all that can survive truncation.
        for (i, ch) in name.chars().take(MAX_LEN).enumerate() {
            if i == 0 {
                if ch.is_ascii_alphabetic() || ch == '_' {
                    result.push(ch);
                } else if ch.is_ascii_digit() {
                    result.push('_');
                    result.push(ch);
                } else {
                    result.push('_');
                }
            } else if ch.is_ascii_alphanumeric() || matches!(ch, '_' | '.' | ':' | '-') {
                result.push(ch);
            } else {
                result.push('_');
            }
        }

        // The output is pure ASCII, so byte 64 is always a char boundary.
        // A leading digit can push the length to 65, hence the final cut.
        result.truncate(MAX_LEN);

        result
    }

    /// Converts `tool` into one or more [`TransformedTool`]s.
    ///
    /// When the provider capabilities report that the schema requires
    /// transformation and a discriminated union is detected in it, the tool
    /// is split into one entry per variant. In every other case the tool is
    /// passed through as a single entry.
    ///
    /// # Errors
    ///
    /// Returns an error if the tool has no `input_schema`, or if the tool's
    /// description is missing or empty.
    pub fn transform(&self, tool: &McpTool) -> Result<Vec<TransformedTool>> {
        let Some(schema) = tool.input_schema.as_ref() else {
            return Err(anyhow!(
                "Tool '{}' missing required input_schema",
                tool.name
            ));
        };

        if self.capabilities.requires_transformation(schema) {
            if let Some(union) = DiscriminatedUnion::detect(schema) {
                return self.auto_split(tool, &union);
            }
        }

        self.pass_through(tool).map(|single| vec![single])
    }

    /// Emits `tool` as a single [`TransformedTool`] under its own name, with
    /// its schema run through the sanitizer and no discriminator value.
    ///
    /// # Errors
    ///
    /// Returns an error if the tool has no `input_schema`; otherwise, if its
    /// description is missing or empty.
    fn pass_through(&self, tool: &McpTool) -> Result<TransformedTool> {
        let Some(raw_schema) = tool.input_schema.as_ref() else {
            return Err(anyhow!(
                "Tool '{}' missing required input_schema",
                tool.name
            ));
        };

        let description = match tool.description.as_ref() {
            Some(text) if !text.is_empty() => text.clone(),
            _ => {
                return Err(anyhow!(
                    "Tool '{}' has empty or missing description",
                    tool.name
                ));
            }
        };

        Ok(TransformedTool {
            name: tool.name.clone(),
            description,
            input_schema: self.sanitizer.sanitize(raw_schema.clone()),
            original_name: tool.name.clone(),
            discriminator_value: None,
        })
    }

    /// Expands `tool` into one [`TransformedTool`] per variant of `union`.
    ///
    /// Each entry is named with the sanitized form of `<tool>_<variant>`,
    /// described as `<description> - <Variant>`, carries the merged variant
    /// schema, and records the variant value as its discriminator value.
    ///
    /// # Errors
    ///
    /// Returns an error if the tool's description is missing or empty.
    fn auto_split(
        &self,
        tool: &McpTool,
        union: &DiscriminatedUnion,
    ) -> Result<Vec<TransformedTool>> {
        let base_description = match tool.description.as_ref() {
            Some(text) if !text.is_empty() => text,
            _ => {
                return Err(anyhow!(
                    "Tool '{}' has empty or missing description",
                    tool.name
                ));
            }
        };

        let mut split_tools = Vec::new();

        for (variant_value, variant_schema) in &union.variants {
            let input_schema = self.build_variant_schema(union, variant_schema, variant_value);
            let name = Self::sanitize_function_name(&format!("{}_{}", tool.name, variant_value));
            let description = format!(
                "{} - {}",
                base_description,
                Self::humanize_variant_name(variant_value)
            );

            split_tools.push(TransformedTool {
                name,
                description,
                input_schema,
                original_name: tool.name.clone(),
                discriminator_value: Some(variant_value.clone()),
            });
        }

        Ok(split_tools)
    }

    fn build_variant_schema(
        &self,
        union: &DiscriminatedUnion,
        variant_schema: &Value,
        _variant_value: &str,
    ) -> Value {
        let mut properties = Map::new();

        merge_properties_into(
            &mut properties,
            &union.base_properties,
            Some(&union.discriminator_field),
        );
        merge_properties_into(&mut properties, variant_schema, None);

        let all_required = collect_required_fields(
            &union.base_properties,
            variant_schema,
            &union.discriminator_field,
        );

        let mut merged_schema = json!({
            "type": "object",
            "properties": properties
        });

        if !all_required.is_empty() {
            merged_schema["required"] = json!(all_required);
        }

        self.sanitizer.sanitize(merged_schema)
    }

    /// Returns `variant` with its first character converted to uppercase.
    ///
    /// All remaining characters are copied unchanged. An empty input yields
    /// an empty string. The uppercase form of the first character may be
    /// more than one character long.
    fn humanize_variant_name(variant: &str) -> String {
        let mut rest = variant.chars();
        match rest.next() {
            Some(first) => first.to_uppercase().chain(rest).collect(),
            None => String::new(),
        }
    }
}