magi-tool 0.0.3

Provides tools for Magi AI agents.
Documentation
use schemars::{
    JsonSchema,
    generate::SchemaSettings,
    transform::{self, Transform},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::borrow::Cow;
use thiserror::Error;

#[cfg(feature = "pegboard")]
mod pegboard;
#[cfg(feature = "pegboard")]
pub use pegboard::{PegBoard, PegBoardError};

// Re-export commonly used rmcp types for convenience
#[cfg(feature = "pegboard")]
pub use rmcp::model::CallToolResult;

/// Error type for tool creation and schema parsing
///
/// Returned by [`parse_input_schema`] and, through it, by [`get_tool`].
#[derive(Debug, Error)]
pub enum ToolError {
    /// The generated schema could not be converted into a `serde_json::Value`.
    ///
    /// `#[from]` lets `?` convert a `serde_json::Error` into this variant.
    #[error("Failed to serialize JSON schema: {0}")]
    SchemaSerialization(#[from] serde_json::Error),
    /// The generated root schema has no top-level `"type"` key.
    #[error("Schema type field is required")]
    MissingSchemaType,
    /// The root schema's `"type"` is something other than the string `"object"`.
    ///
    /// Carries the offending type name, or the `Debug` rendering of the value
    /// when it is not a JSON string.
    #[error("Schema must be of type 'object', got: {0}")]
    InvalidSchemaType(String),
}

/// Tool schema definition for function calling.
/// Compatible with Claude's tool format.
///
/// Usually built with [`get_tool`], which fills `input_schema` from a type
/// implementing `JsonSchema`.
#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
pub struct Tool {
    /// Name of the tool.
    pub name: Cow<'static, str>,
    /// Optional human-readable description of the tool.
    pub description: Option<Cow<'static, str>>,
    /// JSON Schema describing the tool's input.
    ///
    /// When produced by [`get_tool`] this is the output of
    /// [`parse_input_schema`], i.e. a schema whose top-level `type` is `"object"`.
    pub input_schema: serde_json::Value,
}

/// Schema transform that strips `"null"` from a schema's `type` and removes
/// `null` entries from its `enum`, then recurses into all subschemas.
///
/// Despite its name, this transform does not add a `nullable` keyword; it only
/// removes the null markers.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AddNullable {
    /// When set to `true` (the default), `"null"` will be removed from the schema's `type`
    /// array and `null` values will be removed from the schema's `enum` array.
    /// When `false`, schemas are left unmodified.
    pub remove_null_type: bool,
}

impl Default for AddNullable {
    fn default() -> Self {
        Self {
            remove_null_type: true,
        }
    }
}

impl AddNullable {
    /// Reports whether the schema's `type` keyword names `ty`.
    ///
    /// `type` may be a single string or an array of strings; any other shape
    /// (or a missing `type`) yields `false`.
    fn has_type(schema: &schemars::Schema, ty: &str) -> bool {
        let Some(declared) = schema.get("type") else {
            return false;
        };

        match declared {
            Value::String(single) => single == ty,
            Value::Array(many) => many
                .iter()
                .filter_map(Value::as_str)
                .any(|name| name == ty),
            _ => false,
        }
    }
}

impl Transform for AddNullable {
    /// Removes null markers from `schema` (when enabled) and then applies the
    /// same transform to every subschema.
    ///
    /// For a schema whose `type` mentions `"null"`:
    /// - a `type` array loses its `"null"` entry and, if exactly one entry
    ///   remains, collapses to that single value;
    /// - a `type` given as the bare string `"null"` is left as is;
    /// - an `enum` array, if present, loses its `null` entries.
    fn transform(&mut self, schema: &mut schemars::Schema) {
        // No `nullable` property is added; only the null markers are removed.
        if self.remove_null_type && Self::has_type(schema, "null") {
            if let Some(ty) = schema.get_mut("type") {
                if let Value::Array(kinds) = ty {
                    kinds.retain(|kind| kind.as_str() != Some("null"));
                    if kinds.len() == 1 {
                        // Collapse `["string"]` into plain `"string"`.
                        *ty = kinds.remove(0);
                    }
                }
            }

            // Enum arrays may also list `null` as an allowed value.
            if let Some(Value::Array(variants)) = schema.get_mut("enum") {
                variants.retain(|variant| !variant.is_null());
            }
        }

        transform::transform_subschemas(self, schema);
    }
}

/// Builds a [`Tool`] whose `input_schema` is generated from the type `T`.
///
/// `name` and the optional `desc` are converted into `Cow<'static, str>` and
/// stored on the returned tool unchanged.
///
/// # Errors
///
/// Returns whatever [`ToolError`] [`parse_input_schema`] reports for `T`.
pub fn get_tool<T: JsonSchema, S1, S2>(name: S1, desc: Option<S2>) -> Result<Tool, ToolError>
where
    S1: Into<Cow<'static, str>>,
    S2: Into<Cow<'static, str>>,
{
    parse_input_schema::<T>().map(|input_schema| Tool {
        name: name.into(),
        description: desc.map(Into::into),
        input_schema,
    })
}

/// Defines a lazily-initialised, process-wide [`Tool`](crate::Tool) and an accessor for it.
///
/// `define_tool!(NAME, "function_name", "description", ParamType)` expands to:
///
/// - a `static NAME_ONCE_LOCK` holding the cached result of
///   `get_tool::<ParamType, _, _>("function_name", Some("description"))`;
/// - a `pub fn get_<name>()` (the identifier lower-cased) that builds the tool
///   on first call and afterwards returns a reference to the cached value.
///
/// Because the whole `Result` is cached, a failure on the first call is
/// returned (by reference) on every subsequent call as well.
///
/// The expansion refers to `paste::paste!`, so `paste` must be resolvable at
/// the call site. All other paths are absolute (`::std::…` / `$crate::…`) so
/// that a local `Result` alias or a shadowed `Some` in the calling module
/// cannot break the generated code. A trailing comma after the last argument
/// is accepted.
#[macro_export]
macro_rules! define_tool {
    ($tool_name:ident, $function_name:expr, $description:expr, $param_type:ty $(,)?) => {
        paste::paste! {
            static [<$tool_name _ONCE_LOCK>]: ::std::sync::OnceLock<
                ::std::result::Result<$crate::Tool, $crate::ToolError>,
            > = ::std::sync::OnceLock::new();

            pub fn [<get_ $tool_name:lower>]()
                -> ::std::result::Result<&'static $crate::Tool, &'static $crate::ToolError>
            {
                [<$tool_name _ONCE_LOCK>].get_or_init(|| {
                    $crate::get_tool::<$param_type, _, _>(
                        $function_name,
                        ::std::option::Option::Some($description),
                    )
                }).as_ref()
            }
        }
    };
}

/// Generates the JSON Schema for `T` as a `serde_json::Value` suitable for a
/// tool's `input_schema`.
///
/// The schema is generated with draft 2019-09 settings, subschemas inlined,
/// and the [`AddNullable`] transform applied. The top-level `$schema`, `title`
/// and `definitions` keys are removed from the result.
///
/// # Errors
///
/// - [`ToolError::SchemaSerialization`] if the schema cannot be converted to a
///   JSON value.
/// - [`ToolError::MissingSchemaType`] if the root schema has no `type` key.
/// - [`ToolError::InvalidSchemaType`] if the root `type` is not the string
///   `"object"`.
pub fn parse_input_schema<T: JsonSchema>() -> Result<serde_json::Value, ToolError> {
    let generator = SchemaSettings::draft2019_09()
        .with(|s| {
            s.inline_subschemas = true;
        })
        .with_transform(AddNullable::default())
        .into_generator();
    let mut json_value = serde_json::to_value(generator.into_root_schema_for::<T>())?;

    // Tool inputs must be described by an object schema; reject anything else.
    match json_value.get("type") {
        None => return Err(ToolError::MissingSchemaType),
        Some(Value::String(kind)) if kind == "object" => {}
        Some(Value::String(kind)) => return Err(ToolError::InvalidSchemaType(kind.clone())),
        Some(other) => return Err(ToolError::InvalidSchemaType(format!("{:?}", other))),
    }

    // Strip top-level metadata keys from the root schema.
    if let Value::Object(fields) = &mut json_value {
        for key in ["$schema", "title", "definitions"] {
            fields.remove(key);
        }
    }
    Ok(json_value)
}

#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;

    /// Sample tool parameters: one required field and two optional ones.
    #[derive(JsonSchema, serde::Deserialize)]
    #[allow(dead_code)]
    pub struct WeatherParams {
        /// The city and state, e.g. San Francisco, CA
        pub location: String,
        pub unit: Option<UnitEnum>,
        pub arr: Option<Vec<String>>,
    }

    /// Temperature unit, serialized in lowercase.
    #[derive(JsonSchema, serde::Deserialize, PartialEq, Debug)]
    #[serde(rename_all = "lowercase")]
    pub enum UnitEnum {
        Celsius,
        Fahrenheit,
    }

    /// The generated schema matches the expected shape, with optional fields
    /// stripped of their null markers.
    #[test]
    fn test_parse_input_schema() {
        let got = parse_input_schema::<WeatherParams>().unwrap();

        let want = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"]
                },
                "arr": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["location"]
        });
        // Compare JSON values rather than their serialized strings so the test
        // does not depend on the key order of the underlying map.
        assert_eq!(want, got, "Expected: {} Got: {}", want, got);
    }

    /// `get_tool` stores the name and description and attaches an object schema.
    #[test]
    fn test_get_tool() {
        let tool = get_tool::<WeatherParams, _, _>(
            "get_weather",
            Some("Get the current weather in a given location"),
        )
        .unwrap();

        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather in a given location"))
        );

        // Verify input_schema structure
        let schema = &tool.input_schema;
        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
        assert!(schema.get("properties").is_some());
        assert!(schema.get("required").is_some());
    }

    /// The parameter struct deserializes from the JSON a model would send.
    #[test]
    fn test_deserialize_params() {
        let s = r#"{"location":"Boston, MA","unit":"celsius"}"#;
        let got: WeatherParams = serde_json::from_str(s).unwrap();
        assert_eq!(
            "Boston, MA", got.location,
            "Expected: Boston, MA Got: {}",
            got.location
        );
        assert_eq!(
            Some(UnitEnum::Celsius),
            got.unit,
            "Expected: celsius got: {:?}",
            got.unit
        );
    }

    /// `define_tool!` generates a working `get_<name>()` accessor.
    #[test]
    fn test_define_tool_macro() {
        define_tool!(
            WEATHER_TOOL,
            "get_weather",
            "Get the current weather",
            WeatherParams
        );

        let tool = get_weather_tool().unwrap();
        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather"))
        );
        assert!(tool.input_schema.get("properties").is_some());
        assert!(tool.input_schema.get("required").is_some());
    }

    /// Optional fields lose their null type and stay out of `required`.
    #[test]
    fn test_nullable_handling() {
        #[derive(JsonSchema)]
        #[allow(dead_code)]
        struct TestNullable {
            required_field: String,
            optional_field: Option<String>,
        }

        let schema = parse_input_schema::<TestNullable>().unwrap();

        let props = schema.get("properties").unwrap().as_object().unwrap();
        let optional = props.get("optional_field").unwrap();

        // The AddNullable transform must collapse the optional field's type to
        // plain "string". Checking for the exact value (rather than "not the
        // string null") also catches a leftover `["string", "null"]` array.
        assert_eq!(
            optional.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "unexpected schema for optional_field: {}",
            optional
        );

        // Only required_field should be in required array
        let required = schema.get("required").unwrap().as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0].as_str(), Some("required_field"));
    }
}