use std::collections::HashMap;
use serde::{
Deserialize, Serialize,
de::{DeserializeOwned, Error as DeError},
};
use serde_json::{Map, Value};
use crate::error::{Error, ToolError};
/// A string-keyed map of JSON argument values.
///
/// `#[serde(transparent)]` makes this type serialize and deserialize exactly
/// as the inner map, i.e. as a bare JSON object with no wrapper.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct Arguments(pub(crate) Map<String, Value>);
impl Arguments {
pub fn new() -> Self {
Self(Map::new())
}
pub fn from_struct<T: Serialize>(value: T) -> Result<Self, serde_json::Error> {
match serde_json::to_value(value)? {
Value::Object(map) => Ok(Self(map)),
Value::Null => Ok(Self::new()),
_ => Err(DeError::custom("arguments must be a struct")),
}
}
pub fn set(
mut self,
key: impl Into<String>,
value: impl Serialize,
) -> Result<Self, serde_json::Error> {
let v = serde_json::to_value(value)?;
self.0.insert(key.into(), v);
Ok(self)
}
pub fn insert(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
self.0.insert(key.into(), value.into());
self
}
pub fn deserialize<T: DeserializeOwned>(self) -> Result<T, serde_json::Error> {
serde_json::from_value(Value::Object(self.0))
}
pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
self.0
.get(key)
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
pub fn get_value(&self, key: &str) -> Option<&Value> {
self.0.get(key)
}
pub fn require<T: DeserializeOwned>(&self, key: &str) -> crate::Result<T> {
match self.0.get(key) {
Some(value) => serde_json::from_value(value.clone()).map_err(|e| {
Error::InvalidParams(format!("parameter '{}' is invalid: {}", key, e))
}),
None => Err(Error::InvalidParams(format!(
"missing required parameter: {}",
key
))),
}
}
pub fn into_params<T: DeserializeOwned>(self) -> crate::Result<T> {
self.deserialize()
.map_err(|e| Error::InvalidParams(format!("invalid parameters: {}", e)))
}
pub fn into_tool_params<T: DeserializeOwned>(
arguments: Option<Self>,
defaults: bool,
) -> Result<T, ToolError> {
let args = match arguments {
Some(args) => args,
None => {
if defaults {
Self::new()
} else {
return Err(ToolError::invalid_input("Missing arguments"));
}
}
};
args.deserialize()
.map_err(|err| ToolError::invalid_input(err.to_string()))
}
}
impl From<HashMap<String, Value>> for Arguments {
fn from(map: HashMap<String, Value>) -> Self {
Self(map.into_iter().collect())
}
}
impl From<HashMap<String, String>> for Arguments {
fn from(map: HashMap<String, String>) -> Self {
Self(
map.into_iter()
.map(|(k, v)| (k, Value::String(v)))
.collect(),
)
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn test_require() {
        let args = Arguments::new()
            .insert("name", "Alice")
            .insert("count", 42)
            .insert("enabled", true);
        assert_eq!(args.require::<String>("name").unwrap(), "Alice");
        assert_eq!(args.require::<i64>("count").unwrap(), 42);
        assert!(args.require::<bool>("enabled").unwrap());
        let err = args.require::<String>("missing").unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
        assert!(err.to_string().contains("missing"));
    }

    #[test]
    fn test_get() {
        let args = Arguments::new().insert("name", "Alice");
        assert_eq!(args.get::<String>("name"), Some("Alice".to_string()));
        // A present-but-mistyped value is reported as `None`, like a missing key.
        assert_eq!(args.get::<i64>("name"), None);
        assert_eq!(args.get::<String>("missing"), None);
    }

    #[test]
    fn test_set() {
        let args = Arguments::new().set("list", vec![1, 2, 3]).unwrap();
        assert_eq!(args.require::<Vec<i32>>("list").unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_from_struct() {
        #[derive(serde::Serialize)]
        struct Input {
            name: String,
            count: i64,
        }
        let args = Arguments::from_struct(Input {
            name: "x".to_string(),
            count: 3,
        })
        .unwrap();
        assert_eq!(args.require::<String>("name").unwrap(), "x");
        assert_eq!(args.require::<i64>("count").unwrap(), 3);

        // `()` serializes to JSON null, which maps to an empty argument set.
        let empty = Arguments::from_struct(()).unwrap();
        assert!(empty.0.is_empty());

        // Values that are neither objects nor null are rejected.
        assert!(Arguments::from_struct(42).is_err());
    }

    #[test]
    fn test_from_string_map() {
        let mut map = HashMap::new();
        map.insert("key".to_string(), "value".to_string());
        let args = Arguments::from(map);
        assert_eq!(
            args.get_value("key"),
            Some(&Value::String("value".to_string()))
        );
    }

    #[test]
    fn test_into_params() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Params {
            name: String,
            count: i64,
        }
        let args = Arguments::new().insert("name", "test").insert("count", 42);
        let params: Params = args.into_params().unwrap();
        assert_eq!(params.name, "test");
        assert_eq!(params.count, 42);
    }

    #[test]
    fn test_into_tool_params_defaults() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Params {
            name: Option<String>,
        }
        let params: Params = Arguments::into_tool_params(None, true).unwrap();
        assert_eq!(params.name, None);
    }

    #[test]
    fn test_into_tool_params_provided() {
        let args = Arguments::new().insert("a", "b");
        let params =
            Arguments::into_tool_params::<HashMap<String, String>>(Some(args), false).unwrap();
        assert_eq!(params.get("a").map(String::as_str), Some("b"));
    }

    #[test]
    fn test_into_tool_params_missing() {
        let err = Arguments::into_tool_params::<HashMap<String, String>>(None, false).unwrap_err();
        assert_eq!(err.code, crate::TOOL_ERROR_INVALID_INPUT);
        assert!(err.message.contains("Missing arguments"));
    }

    #[test]
    fn test_into_params_missing_field() {
        // The fields exist only to drive deserialization. They are never read,
        // because this call is expected to fail.
        #[allow(dead_code)]
        #[derive(Debug, serde::Deserialize)]
        struct Params {
            name: String,
            count: i64,
        }
        let args = Arguments::new().insert("name", "test");
        let err = args.into_params::<Params>().unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
        assert!(err.to_string().contains("count"));
    }
}