use crate::Context;
use crate::error::{Error, ErrorCode};
use crate::types::request::RequestParamsMeta;
use crate::types::{Meta, ProgressToken};
use serde::de::DeserializeOwned;
use std::collections::hash_map::Iter;
#[cfg(feature = "tasks")]
use crate::types::RelatedTaskMetadata;
pub(crate) enum Payload<'a> {
Args(serde_json::Value),
Meta(&'a Option<RequestParamsMeta>),
}
pub(crate) enum Source {
Args,
Meta,
}
pub(crate) trait RequestArgument: Sized {
type Error;
fn extract(payload: Payload<'_>) -> Result<Self, Self::Error>;
#[inline]
fn source() -> Source {
Source::Args
}
}
impl<'a> Payload<'a> {
#[inline]
pub(crate) fn expect_args(self) -> serde_json::Value {
match self {
Payload::Args(val) => val,
_ => unreachable!("Expected Args variant"),
}
}
#[inline]
pub(crate) fn expect_meta(self) -> &'a Option<RequestParamsMeta> {
match self {
Payload::Meta(meta) => meta,
_ => unreachable!("Expected Meta variant"),
}
}
}
impl<T: DeserializeOwned> RequestArgument for T {
type Error = Error;
#[inline]
fn extract(payload: Payload<'_>) -> Result<Self, Self::Error> {
let arg = payload.expect_args();
T::deserialize(arg).map_err(Error::from)
}
}
impl RequestArgument for Meta<RequestParamsMeta> {
type Error = Error;
#[inline]
fn extract(payload: Payload<'_>) -> Result<Self, Self::Error> {
let meta = payload.expect_meta();
meta.clone()
.ok_or(Error::new(ErrorCode::InvalidParams, "Missing metadata"))
.map(Meta)
}
#[inline]
fn source() -> Source {
Source::Meta
}
}
impl RequestArgument for Meta<ProgressToken> {
type Error = Error;
#[inline]
fn extract(payload: Payload<'_>) -> Result<Self, Self::Error> {
let meta = payload.expect_meta();
meta.as_ref()
.and_then(|meta| meta.progress_token.clone())
.ok_or(Error::new(
ErrorCode::InvalidParams,
"Missing progress token",
))
.map(Meta)
}
#[inline]
fn source() -> Source {
Source::Meta
}
}
#[cfg(feature = "tasks")]
impl RequestArgument for Meta<RelatedTaskMetadata> {
type Error = Error;
#[inline]
fn extract(payload: Payload<'_>) -> Result<Self, Self::Error> {
let meta = payload.expect_meta();
meta.as_ref()
.and_then(|meta| meta.task.clone())
.ok_or(Error::new(
ErrorCode::InvalidParams,
"Missing progress token",
))
.map(Meta)
}
#[inline]
fn source() -> Source {
Source::Meta
}
}
impl RequestArgument for Context {
type Error = Error;
#[inline]
fn extract(payload: Payload<'_>) -> Result<Self, Self::Error> {
let meta = payload.expect_meta();
meta.as_ref()
.and_then(|meta| meta.context.clone())
.ok_or(Error::new(ErrorCode::InvalidParams, "Missing MCP context"))
}
#[inline]
fn source() -> Source {
Source::Meta
}
}
#[inline]
pub(crate) fn extract_arg<T: RequestArgument<Error = Error>>(
meta: &Option<RequestParamsMeta>,
iter: &mut Iter<'_, String, serde_json::Value>,
) -> Result<T, Error> {
match T::source() {
Source::Meta => T::extract(Payload::Meta(meta)),
Source::Args => T::extract(Payload::Args(
iter.next()
.ok_or(Error::new(
ErrorCode::InvalidParams,
"Invalid param provided",
))?
.1
.clone(),
)),
}
}