tmcp 0.4.0

Complete, ergonomic implementation of the Model Context Protocol (MCP)
Documentation
use std::collections::HashMap;

use serde::{
    Deserialize, Serialize,
    de::{DeserializeOwned, Error as DeError},
};
use serde_json::{Map, Value};

use crate::error::{Error, ToolError};

/// Generic argument map used for passing parameters to tools and prompts.
///
/// This is a thin newtype over a JSON object (`Map<String, Value>`). Because of
/// `#[serde(transparent)]`, it serializes and deserializes exactly as the inner
/// map does, with no wrapper layer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Arguments(pub(crate) Map<String, Value>);

impl Arguments {
    /// Create an empty argument set.
    pub fn new() -> Self {
        Self(Map::new())
    }

    /// Convert any serializable value into an argument map.
    ///
    /// The value is first serialized to JSON. A JSON object becomes the
    /// argument map, JSON `null` becomes an empty map, and any other JSON
    /// shape is rejected.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, or if the value serializes to
    /// something other than an object or `null`.
    pub fn from_struct<T: Serialize>(value: T) -> Result<Self, serde_json::Error> {
        match serde_json::to_value(value) {
            Err(failure) => Err(failure),
            Ok(Value::Object(fields)) => Ok(Self(fields)),
            Ok(Value::Null) => Ok(Self::new()),
            Ok(_) => Err(DeError::custom("arguments must be a struct")),
        }
    }

    /// Serialize `value` and store it under `key`, handing back the updated
    /// `Arguments` so calls can be chained.
    ///
    /// Prefer [`insert`](Self::insert) when the value is a plain string,
    /// number, or bool: that variant cannot fail and needs no `?`.
    ///
    /// # Errors
    ///
    /// Returns the serialization error if `value` cannot be converted to JSON;
    /// in that case nothing is inserted.
    pub fn set(
        mut self,
        key: impl Into<String>,
        value: impl Serialize,
    ) -> Result<Self, serde_json::Error> {
        serde_json::to_value(value).map(|json| {
            self.0.insert(key.into(), json);
            self
        })
    }

    /// Store anything convertible into a JSON `Value` under `key`, returning
    /// the updated `Arguments` for infallible fluent chaining.
    ///
    /// Typical inputs that convert to `Value`:
    /// - Strings: `"hello"`, `String::from("hello")`
    /// - Numbers: `42`, `3.14`
    /// - Booleans: `true`, `false`
    /// - Null: `()`
    ///
    /// # Example
    ///
    /// ```ignore
    /// let args = Arguments::new()
    ///     .insert("name", "Alice")
    ///     .insert("count", 42)
    ///     .insert("enabled", true);
    /// ```
    pub fn insert(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        let (name, json) = (key.into(), value.into());
        self.0.insert(name, json);
        self
    }

    /// Deserialize the arguments into the desired type.
    pub fn deserialize<T: DeserializeOwned>(self) -> Result<T, serde_json::Error> {
        serde_json::from_value(Value::Object(self.0))
    }

    /// Get a typed value by key.
    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        self.0
            .get(key)
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Get the raw JSON value for a key.
    pub fn get_value(&self, key: &str) -> Option<&Value> {
        self.0.get(key)
    }

    /// Require a typed parameter, returning an error if missing or wrong type.
    pub fn require<T: DeserializeOwned>(&self, key: &str) -> crate::Result<T> {
        match self.0.get(key) {
            Some(value) => serde_json::from_value(value.clone()).map_err(|e| {
                Error::InvalidParams(format!("parameter '{}' is invalid: {}", key, e))
            }),
            None => Err(Error::InvalidParams(format!(
                "missing required parameter: {}",
                key
            ))),
        }
    }

    /// Deserialize into a typed struct with proper error handling.
    ///
    /// This wraps `deserialize()` to return the crate's `Result` type
    /// with an `InvalidParams` error on failure.
    pub fn into_params<T: DeserializeOwned>(self) -> crate::Result<T> {
        self.deserialize()
            .map_err(|e| Error::InvalidParams(format!("invalid parameters: {}", e)))
    }

    /// Decode optional tool arguments into `T`, reporting problems as tool errors.
    ///
    /// When `arguments` is `None`, an empty argument map is used if `defaults`
    /// is `true`; otherwise the call fails with a "Missing arguments" tool
    /// error. Deserialization failures are likewise turned into invalid-input
    /// tool errors carrying the underlying message.
    ///
    /// This is used by the tool routing generated by the macros.
    pub fn into_tool_params<T: DeserializeOwned>(
        arguments: Option<Self>,
        defaults: bool,
    ) -> Result<T, ToolError> {
        let supplied = match (arguments, defaults) {
            (Some(given), _) => given,
            // No arguments sent, but the caller allows an empty map to stand in.
            (None, true) => Self::new(),
            (None, false) => return Err(ToolError::invalid_input("Missing arguments")),
        };

        match supplied.deserialize::<T>() {
            Ok(params) => Ok(params),
            Err(cause) => Err(ToolError::invalid_input(cause.to_string())),
        }
    }
}

impl From<HashMap<String, Value>> for Arguments {
    fn from(map: HashMap<String, Value>) -> Self {
        Self(map.into_iter().collect())
    }
}

impl From<HashMap<String, String>> for Arguments {
    fn from(map: HashMap<String, String>) -> Self {
        Self(
            map.into_iter()
                .map(|(k, v)| (k, Value::String(v)))
                .collect(),
        )
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for `Arguments`: typed lookup, whole-map deserialization,
    // tool-parameter handling, and the `HashMap` conversions.
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn test_require() {
        let args = Arguments::new()
            .insert("name", "Alice")
            .insert("count", 42)
            .insert("enabled", true);

        // Present keys with various types
        assert_eq!(args.require::<String>("name").unwrap(), "Alice");
        assert_eq!(args.require::<i64>("count").unwrap(), 42);
        assert!(args.require::<bool>("enabled").unwrap());

        // Missing key
        let err = args.require::<String>("missing").unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
        assert!(err.to_string().contains("missing"));
    }

    #[test]
    fn test_require_wrong_type() {
        let args = Arguments::new().insert("count", 42);

        // The key exists but holds a number, so asking for a string must fail
        // with an error that names the offending parameter.
        let err = args.require::<String>("count").unwrap_err();
        assert!(matches!(err, Error::InvalidParams(ref msg) if msg.contains("'count'")));
    }

    #[test]
    fn test_get() {
        let args = Arguments::new().insert("name", "Alice").insert("count", 42);

        assert_eq!(args.get::<String>("name").as_deref(), Some("Alice"));
        assert_eq!(args.get::<i64>("count"), Some(42));

        // Wrong type and missing key both collapse to `None`.
        assert_eq!(args.get::<i64>("name"), None);
        assert_eq!(args.get::<String>("missing"), None);
    }

    #[test]
    fn test_from_string_map() {
        let mut map = HashMap::new();
        map.insert("name".to_string(), "Alice".to_string());

        let args = Arguments::from(map);
        assert_eq!(
            args.get_value("name"),
            Some(&Value::String("Alice".to_string()))
        );
        assert_eq!(args.get_value("missing"), None);
    }

    #[test]
    fn test_into_params() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Params {
            name: String,
            count: i64,
        }

        let args = Arguments::new().insert("name", "test").insert("count", 42);

        let params: Params = args.into_params().unwrap();
        assert_eq!(params.name, "test");
        assert_eq!(params.count, 42);
    }

    #[test]
    fn test_into_tool_params_defaults() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Params {
            name: Option<String>,
        }

        let params: Params = Arguments::into_tool_params(None, true).unwrap();
        assert_eq!(params.name, None);
    }

    #[test]
    fn test_into_tool_params_missing() {
        let err = Arguments::into_tool_params::<HashMap<String, String>>(None, false).unwrap_err();
        assert_eq!(err.code, crate::TOOL_ERROR_INVALID_INPUT);
        assert!(err.message.contains("Missing arguments"));
    }

    #[test]
    fn test_into_params_missing_field() {
        #[allow(dead_code)] // Fields are read via serde deserialization
        #[derive(Debug, serde::Deserialize)]
        struct Params {
            name: String,
            count: i64,
        }

        let args = Arguments::new().insert("name", "test");
        // Missing 'count' field

        let err = args.into_params::<Params>().unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
        assert!(err.to_string().contains("count"));
    }
}