use http::Method;
use crate::request::{
RequestProtocol,
WsSendKindTemplate,
parse_within_duration,
};
use super::{
Rule,
request_error,
};
/// Rejects directives that belong to a protocol family other than the declared `protocol`.
///
/// Each `has_*_directives` flag reports whether the request used directives from that
/// family. For each protocol, the disallowed families are checked in a fixed order. The
/// first family that is present produces the error.
///
/// NOTE(review): two asymmetries are preserved exactly as found:
/// - the `Mcp` arm checks SSE before WebSocket, while the other arms check WebSocket first;
/// - the `Ws` arm does not reject SSE directives at all.
///
/// The second may be deliberate if SSE and WebSocket share directives. Confirm against
/// the caller before changing either.
///
/// # Errors
///
/// Returns an error anchored at `request_span` naming the protocol that the offending
/// directives require.
pub(in crate::parser) fn validate_protocol_directives(
    protocol: RequestProtocol,
    has_graphql_directives: bool,
    has_mcp_directives: bool,
    has_ws_directives: bool,
    has_sse_directives: bool,
    request_span: pest::Span<'_>,
) -> Result<(), pest::error::Error<Rule>> {
    // Returns the message of the first family (in slice order) whose directives are present.
    fn first_present(checks: &[(bool, &'static str)]) -> Option<&'static str> {
        checks
            .iter()
            .find(|(present, _)| *present)
            .map(|&(_, message)| message)
    }

    // Each entry pairs "were these directives used?" with the error to report when that
    // family is not allowed under `protocol`.
    let graphql = (
        has_graphql_directives,
        "GraphQL directives require 'protocol = graphql'",
    );
    let mcp = (has_mcp_directives, "MCP directives require 'protocol = mcp'");
    let ws = (
        has_ws_directives,
        "WebSocket directives require 'protocol = ws'",
    );
    let sse = (has_sse_directives, "SSE directives require 'protocol = sse'");

    let disallowed = match protocol {
        RequestProtocol::Http => first_present(&[graphql, mcp, ws, sse]),
        RequestProtocol::Graphql => first_present(&[mcp, ws, sse]),
        RequestProtocol::Mcp => first_present(&[graphql, sse, ws]),
        RequestProtocol::Sse => first_present(&[graphql, mcp, ws]),
        RequestProtocol::Ws => first_present(&[graphql, mcp]),
    };

    match disallowed {
        Some(message) => Err(request_error(message, request_span)),
        None => Ok(()),
    }
}
/// Which MCP call a request performs, as returned by `validate_mcp_call`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(in crate::parser) enum ValidatedMcpCallKind {
    /// `call = initialize`.
    Initialize,
    /// `call = tools/list`.
    ToolsList,
    /// `call = resources/list`.
    ResourcesList,
    /// `call = tools/call` (validated to also have `tool = ...`).
    ToolsCall,
}
/// The kind of SSE step accepted by `validate_sse_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(in crate::parser) enum ValidatedSseActionKind {
    /// No `receive` step; `within = ...` was absent.
    Open,
    /// A `receive` step with a `within = ...` duration that parsed successfully.
    Receive,
}
/// The kind of WebSocket step accepted by `validate_ws_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(in crate::parser) enum ValidatedWsActionKind {
    /// Neither a send kind nor `receive`; no forbidden body and no `within = ...`.
    Open,
    /// A send kind with a body block and no `within = ...`.
    Send,
    /// A send kind with a body block plus a `within = ...` duration that parsed successfully.
    Exchange,
    /// A `receive` step (without `send = ...`) with a `within = ...` duration that parsed.
    Receive,
}
/// Validates the GraphQL-specific shape of a request.
///
/// The checks run in this order:
/// 1. the method must be `POST`;
/// 2. there must be no form fields;
/// 3. a document block must be present with a valid content type.
///
/// Only when all three hold is `validate_variables_json` invoked. Its result is returned
/// unchanged.
///
/// # Errors
///
/// - Method and form-field violations are reported at `request_span`.
/// - A missing or mistyped document is reported at `graphql_document_error_span`.
/// - Any error from `validate_variables_json` is propagated as-is.
pub(in crate::parser) fn validate_graphql_request<'a, F>(
    request_span: pest::Span<'a>,
    method: Method,
    has_form_fields: bool,
    has_document: bool,
    has_valid_document_content_type: bool,
    graphql_document_error_span: pest::Span<'a>,
    mut validate_variables_json: F,
) -> Result<(), pest::error::Error<Rule>>
where
    F: FnMut() -> Result<(), pest::error::Error<Rule>>,
{
    // The first failing rule (message + where to report it), if any.
    let violation = if method != Method::POST {
        Some(("GraphQL requests currently require POST", request_span))
    } else if has_form_fields {
        Some(("GraphQL requests do not support form fields", request_span))
    } else if !(has_document && has_valid_document_content_type) {
        Some((
            "GraphQL requests require a ~~~graphql document block",
            graphql_document_error_span,
        ))
    } else {
        None
    };

    match violation {
        Some((message, span)) => Err(request_error(message, span)),
        None => validate_variables_json(),
    }
}
pub(in crate::parser) fn validate_mcp_call<'a, 'b, F>(
request_span: pest::Span<'a>,
call: Option<&'b str>,
has_tool: bool,
arguments: Option<(&'b str, pest::Span<'a>)>,
has_protocol_version: bool,
has_client_name: bool,
has_client_version: bool,
capabilities: Option<(&'b str, pest::Span<'a>)>,
call_error_span: pest::Span<'a>,
mut validate_object_json: F,
) -> Result<ValidatedMcpCallKind, pest::error::Error<Rule>>
where
F: FnMut(&'b str, &str, pest::Span<'a>) -> Result<(), pest::error::Error<Rule>>,
{
let call = call.ok_or_else(|| request_error("MCP requests require 'call = ...'", request_span))?;
match call {
"initialize" => {
if has_tool || arguments.is_some() {
return Err(request_error(
"'tool' and 'arguments' are only valid with 'call = tools/call'",
call_error_span,
));
}
if let Some((raw, span)) = capabilities {
validate_object_json(raw, "capabilities", span)?;
}
Ok(ValidatedMcpCallKind::Initialize)
}
"tools/list" => {
let has_tool_arguments = has_tool || arguments.is_some();
let has_initialize_overrides = has_protocol_version
|| has_client_name
|| has_client_version
|| capabilities.is_some();
if has_tool_arguments || has_initialize_overrides {
let message = mcp_list_call_conflict_message(
"tools/list",
has_tool_arguments,
has_initialize_overrides,
);
return Err(request_error(
message.as_str(),
call_error_span,
));
}
Ok(ValidatedMcpCallKind::ToolsList)
}
"resources/list" => {
let has_tool_arguments = has_tool || arguments.is_some();
let has_initialize_overrides = has_protocol_version
|| has_client_name
|| has_client_version
|| capabilities.is_some();
if has_tool_arguments || has_initialize_overrides {
let message = mcp_list_call_conflict_message(
"resources/list",
has_tool_arguments,
has_initialize_overrides,
);
return Err(request_error(
message.as_str(),
call_error_span,
));
}
Ok(ValidatedMcpCallKind::ResourcesList)
}
"tools/call" => {
if has_protocol_version || has_client_name || has_client_version || capabilities.is_some() {
return Err(request_error(
"initialize override directives are only valid with 'call = initialize'",
call_error_span,
));
}
if !has_tool {
return Err(request_error(
"'call = tools/call' requires 'tool = ...'",
call_error_span,
));
}
if let Some((raw, span)) = arguments {
validate_object_json(raw, "arguments", span)?;
}
Ok(ValidatedMcpCallKind::ToolsCall)
}
other => Err(request_error(
format!(
"unsupported MCP call '{}' (supported: initialize, tools/list, resources/list, tools/call)",
other
)
.as_str(),
call_error_span,
)),
}
}
/// Builds the error message for a list call (`tools/list` / `resources/list`) that was
/// given directives it does not accept.
///
/// `has_tool_arguments` and `has_initialize_overrides` select which incompatible
/// directive families are named and which alternative call(s) are suggested.
///
/// # Panics
///
/// Panics if both flags are `false`, since there is then no conflict to describe.
fn mcp_list_call_conflict_message(
    current_call: &str,
    has_tool_arguments: bool,
    has_initialize_overrides: bool,
) -> String {
    let (rejected, remedy) = match (has_tool_arguments, has_initialize_overrides) {
        (true, true) => (
            "tool/arguments or initialize override",
            "'call = tools/call' or 'call = initialize'",
        ),
        (true, false) => ("tool/arguments", "'call = tools/call'"),
        (false, true) => ("initialize override", "'call = initialize'"),
        (false, false) => unreachable!(
            "conflict message requires at least one incompatible directive family"
        ),
    };
    format!(
        "'call = {}' does not accept {} directives; use {}",
        current_call, rejected, remedy
    )
}
/// Validates an SSE request step and classifies it as an open or a receive step.
///
/// The checks run in this order:
/// 1. `session = ...` must be present;
/// 2. the method must be `GET`;
/// 3. `body_error` (if supplied by the caller) is reported.
///
/// After that, a `receive` step requires a `within = ...` duration that
/// `parse_within_duration` accepts. A non-receive step must not have `within`.
///
/// # Errors
///
/// - Session, method and missing-`within` errors are reported at `request_span`.
/// - `body_error` is reported at its own span.
/// - Duration parse errors and a misplaced `within` are reported at `within_error_span`.
///   The span carried inside `within` is not used.
pub(in crate::parser) fn validate_sse_action<'a, 'b>(
    request_span: pest::Span<'a>,
    has_session: bool,
    method: Method,
    body_error: Option<(&str, pest::Span<'a>)>,
    receive_seen: bool,
    within: Option<(&'b str, pest::Span<'a>)>,
    within_error_span: pest::Span<'a>,
) -> Result<ValidatedSseActionKind, pest::error::Error<Rule>> {
    if !has_session {
        return Err(request_error(
            "SSE requests require 'session = ...'",
            request_span,
        ));
    }
    if method != Method::GET {
        return Err(request_error(
            "SSE requests currently require GET",
            request_span,
        ));
    }
    if let Some((message, span)) = body_error {
        return Err(request_error(message, span));
    }

    // `within = ...` is mandatory for a receive step and forbidden otherwise.
    match (receive_seen, within) {
        (true, Some((duration, _))) => {
            parse_within_duration(duration)
                .map_err(|message| request_error(message.as_str(), within_error_span))?;
            Ok(ValidatedSseActionKind::Receive)
        }
        (true, None) => Err(request_error(
            "SSE receive steps require 'within = ...'",
            request_span,
        )),
        (false, Some(_)) => Err(request_error(
            "'within = ...' is only valid with 'receive'",
            within_error_span,
        )),
        (false, None) => Ok(ValidatedSseActionKind::Open),
    }
}
pub(in crate::parser) fn validate_ws_action<'a, 'b>(
request_span: pest::Span<'a>,
has_session: bool,
method: Method,
form_error: Option<(&str, pest::Span<'a>)>,
receive_seen: bool,
has_send_directive: bool,
send_conflict_span: pest::Span<'a>,
send_kind: Option<WsSendKindTemplate>,
has_body: bool,
missing_body_span: pest::Span<'a>,
forbidden_body_error: Option<(&str, pest::Span<'a>)>,
within: Option<(&'b str, pest::Span<'a>)>,
) -> Result<ValidatedWsActionKind, pest::error::Error<Rule>> {
if !has_session {
return Err(request_error(
"WebSocket requests require 'session = ...'",
request_span.clone(),
));
}
if method != Method::GET {
return Err(request_error(
"WebSocket requests currently require GET",
request_span.clone(),
));
}
if let Some((message, span)) = form_error {
return Err(request_error(message, span));
}
if receive_seen {
if has_send_directive {
return Err(request_error(
"WebSocket requests cannot combine 'send = ...' with 'receive'",
send_conflict_span,
));
}
if let Some((message, span)) = forbidden_body_error {
return Err(request_error(message, span));
}
let (within, span) = within.ok_or_else(|| {
request_error(
"WebSocket receive steps require 'within = ...'",
request_span.clone(),
)
})?;
parse_within_duration(within).map_err(|message| request_error(message.as_str(), span))?;
return Ok(ValidatedWsActionKind::Receive);
}
if let Some(send_kind) = send_kind {
if !has_body {
return Err(request_error(
&format!(
"WebSocket send kind '{}' requires a body block",
ws_send_kind_name(send_kind)
),
missing_body_span,
));
}
if let Some((within, span)) = within {
parse_within_duration(within)
.map_err(|message| request_error(message.as_str(), span))?;
return Ok(ValidatedWsActionKind::Exchange);
}
return Ok(ValidatedWsActionKind::Send);
}
if let Some((message, span)) = forbidden_body_error {
return Err(request_error(message, span));
}
if let Some((_, span)) = within {
return Err(request_error(
"'within = ...' is only valid with 'receive'",
span,
));
}
Ok(ValidatedWsActionKind::Open)
}
/// Returns the lowercase name of a WebSocket send kind (`"json"` or `"text"`).
///
/// This name is interpolated into the "requires a body block" error message.
fn ws_send_kind_name(send_kind: WsSendKindTemplate) -> &'static str {
    match send_kind {
        WsSendKindTemplate::Json => "json",
        WsSendKindTemplate::Text => "text",
    }
}