// systemprompt-ai 0.1.19
//
// Core AI module for systemprompt.io
// Documentation
use super::transformer::TransformedTool;
use serde_json::Value;
use std::collections::HashMap;

/// What a variant tool name resolves to:
/// `(original tool name, discriminator value, discriminator field)`.
type VariantTarget = (String, Option<String>, String);

/// Bidirectional lookup between variant tool names and the original tools
/// they were derived from.
#[derive(Debug)]
pub struct ToolNameMapper {
    /// Variant tool name -> the original tool and discriminator it stands for.
    forward_map: HashMap<String, VariantTarget>,
    /// Original tool name -> names of the variants registered for it.
    reverse_map: HashMap<String, Vec<String>>,
}

impl ToolNameMapper {
    pub fn new() -> Self {
        Self {
            forward_map: HashMap::new(),
            reverse_map: HashMap::new(),
        }
    }

    pub fn register_transformation(
        &mut self,
        transformed: &TransformedTool,
        discriminator_field: Option<String>,
    ) {
        let disc_field = discriminator_field.unwrap_or_else(|| "action".to_string());

        self.forward_map.insert(
            transformed.name.clone(),
            (
                transformed.original_name.clone(),
                transformed.discriminator_value.clone(),
                disc_field,
            ),
        );

        self.reverse_map
            .entry(transformed.original_name.clone())
            .or_default()
            .push(transformed.name.clone());
    }

    pub fn resolve_tool_call(&self, variant_name: &str, mut params: Value) -> (String, Value) {
        match self.forward_map.get(variant_name) {
            Some((original_name, Some(discriminator_value), discriminator_field)) => {
                if let Some(params_obj) = params.as_object_mut() {
                    params_obj.insert(
                        discriminator_field.clone(),
                        serde_json::json!(discriminator_value),
                    );
                }
                (original_name.clone(), params)
            },
            Some((original_name, None, _)) => (original_name.clone(), params),
            None => (variant_name.to_string(), params),
        }
    }

    pub fn get_variants(&self, original_name: &str) -> Option<&Vec<String>> {
        self.reverse_map.get(original_name)
    }

    pub fn is_variant(&self, tool_name: &str) -> bool {
        self.forward_map.contains_key(tool_name)
    }
}

impl Default for ToolNameMapper {
    fn default() -> Self {
        Self::new()
    }
}