use http::Method;
use crate::request::{
RequestProtocol,
WsSendKindTemplate,
parse_within_duration,
};
use super::{
Rule,
request_error,
};
/// Rejects protocol-specific directives that do not belong to `protocol`.
///
/// Each request protocol owns at most one directive family (GraphQL, MCP,
/// WebSocket, SSE); plain HTTP owns none. If a family that `protocol` does not
/// own is flagged as present, an error anchored at `request_span` is returned.
/// When several foreign families are present, the first one in that protocol's
/// check order (listed in the table below) is the one reported.
pub(in crate::parser) fn validate_protocol_directives(
    protocol: RequestProtocol,
    has_graphql_directives: bool,
    has_mcp_directives: bool,
    has_ws_directives: bool,
    has_sse_directives: bool,
    request_span: pest::Span<'_>,
) -> Result<(), pest::error::Error<Rule>> {
    const GRAPHQL_ONLY: &str = "GraphQL directives require 'protocol = graphql'";
    const MCP_ONLY: &str = "MCP directives require 'protocol = mcp'";
    const WS_ONLY: &str = "WebSocket directives require 'protocol = ws'";
    const SSE_ONLY: &str = "SSE directives require 'protocol = sse'";

    /// Returns the message of the first family in `checks` that is present.
    fn first_foreign(checks: &[(bool, &'static str)]) -> Option<&'static str> {
        checks
            .iter()
            .find(|(present, _)| *present)
            .map(|&(_, message)| message)
    }

    let graphql = (has_graphql_directives, GRAPHQL_ONLY);
    let mcp = (has_mcp_directives, MCP_ONLY);
    let ws = (has_ws_directives, WS_ONLY);
    let sse = (has_sse_directives, SSE_ONLY);

    // Order within each slice decides which error wins when several foreign
    // families are present at once.
    let offending = match protocol {
        RequestProtocol::Http => first_foreign(&[graphql, mcp, ws, sse]),
        RequestProtocol::Graphql => first_foreign(&[mcp, ws, sse]),
        RequestProtocol::Mcp => first_foreign(&[graphql, sse, ws]),
        RequestProtocol::Sse => first_foreign(&[graphql, mcp, ws]),
        RequestProtocol::Ws => first_foreign(&[graphql, mcp]),
    };

    match offending {
        Some(message) => Err(request_error(message, request_span)),
        None => Ok(()),
    }
}
/// MCP call selected by a request's `call = ...` directive, as accepted by
/// `validate_mcp_call`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(in crate::parser) enum ValidatedMcpCallKind {
    /// `call = initialize`
    Initialize,
    /// `call = tools/list`
    ToolsList,
    /// `call = resources/list`
    ResourcesList,
    /// `call = tools/call` (a `tool = ...` directive was present)
    ToolsCall,
}
/// Classification of an SSE step produced by `validate_sse_action`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(in crate::parser) enum ValidatedSseActionKind {
    /// No `receive` step was present and no `within = ...` was given.
    Open,
    /// A `receive` step with a `within = ...` duration that parsed successfully.
    Receive,
}
/// Classification of a WebSocket step produced by `validate_ws_action`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(in crate::parser) enum ValidatedWsActionKind {
    /// Neither a `receive` step nor a send kind was present.
    Open,
    /// A send kind with a body block and no `within = ...`.
    Send,
    /// A send kind with a body block and a `within = ...` duration that parsed.
    Exchange,
    /// A `receive` step with a `within = ...` duration that parsed.
    Receive,
}
/// Checks that a GraphQL request has the shape the GraphQL transport expects.
///
/// Checks run in this order and stop at the first failure:
/// 1. the method must be `POST` (error at `request_span`);
/// 2. no form fields may be present (error at `request_span`);
/// 3. a document block must be present and its content type must have been
///    judged valid by the caller (error at `graphql_document_error_span`).
///
/// Once all of those pass, `validate_variables_json` is invoked and its result
/// is returned as-is.
pub(in crate::parser) fn validate_graphql_request<'a, F>(
    request_span: pest::Span<'a>,
    method: Method,
    has_form_fields: bool,
    has_document: bool,
    has_valid_document_content_type: bool,
    graphql_document_error_span: pest::Span<'a>,
    mut validate_variables_json: F,
) -> Result<(), pest::error::Error<Rule>>
where
    F: FnMut() -> Result<(), pest::error::Error<Rule>>,
{
    // Request-level problems are both reported against the whole request.
    let request_problem = if method != Method::POST {
        Some("GraphQL requests currently require POST")
    } else if has_form_fields {
        Some("GraphQL requests do not support form fields")
    } else {
        None
    };
    if let Some(message) = request_problem {
        return Err(request_error(message, request_span));
    }

    let document_ok = has_document && has_valid_document_content_type;
    if !document_ok {
        return Err(request_error(
            "GraphQL requests require a ~~~graphql document block",
            graphql_document_error_span,
        ));
    }

    validate_variables_json()
}
/// Validates the MCP `call = ...` directive and the directives that only make
/// sense for particular calls, returning which call was requested.
///
/// Supported calls:
/// - `initialize`: may carry initialize overrides (protocol version, client
///   name, client version, capabilities); rejects `tool`/`arguments`. A
///   present `capabilities` value is checked with `validate_object_json`.
/// - `tools/list`, `resources/list`: accept neither tool inputs nor
///   initialize overrides.
/// - `tools/call`: requires `tool`; rejects initialize overrides. A present
///   `arguments` value is checked with `validate_object_json`.
///
/// # Errors
/// A missing `call` is reported at `request_span`. Unknown calls and
/// missing/disallowed directives are reported at `call_error_span`. Errors
/// from `validate_object_json` are propagated unchanged.
pub(in crate::parser) fn validate_mcp_call<'a, 'b, F>(
    request_span: pest::Span<'a>,
    call: Option<&'b str>,
    has_tool: bool,
    arguments: Option<(&'b str, pest::Span<'a>)>,
    has_protocol_version: bool,
    has_client_name: bool,
    has_client_version: bool,
    capabilities: Option<(&'b str, pest::Span<'a>)>,
    call_error_span: pest::Span<'a>,
    mut validate_object_json: F,
) -> Result<ValidatedMcpCallKind, pest::error::Error<Rule>>
where
    F: FnMut(&'b str, &str, pest::Span<'a>) -> Result<(), pest::error::Error<Rule>>,
{
    let call = match call {
        Some(name) => name,
        None => return Err(request_error("MCP requests require 'call = ...'", request_span)),
    };

    // Directive groups, computed once; each is a side-effect-free flag check.
    let has_tool_inputs = has_tool || arguments.is_some();
    let has_initialize_overrides =
        has_protocol_version || has_client_name || has_client_version || capabilities.is_some();

    match call {
        "initialize" => {
            if has_tool_inputs {
                return Err(request_error(
                    "'tool' and 'arguments' are only valid with 'call = tools/call'",
                    call_error_span,
                ));
            }
            if let Some((raw, span)) = capabilities {
                validate_object_json(raw, "capabilities", span)?;
            }
            Ok(ValidatedMcpCallKind::Initialize)
        }
        "tools/list" | "resources/list" => {
            if has_tool_inputs || has_initialize_overrides {
                return Err(request_error(
                    format!(
                        "'call = {}' does not accept tool, arguments, or initialize override directives",
                        call
                    )
                    .as_str(),
                    call_error_span,
                ));
            }
            if call == "tools/list" {
                Ok(ValidatedMcpCallKind::ToolsList)
            } else {
                Ok(ValidatedMcpCallKind::ResourcesList)
            }
        }
        "tools/call" => {
            if has_initialize_overrides {
                return Err(request_error(
                    "initialize override directives are only valid with 'call = initialize'",
                    call_error_span,
                ));
            }
            if !has_tool {
                return Err(request_error(
                    "'call = tools/call' requires 'tool = ...'",
                    call_error_span,
                ));
            }
            if let Some((raw, span)) = arguments {
                validate_object_json(raw, "arguments", span)?;
            }
            Ok(ValidatedMcpCallKind::ToolsCall)
        }
        unsupported => Err(request_error(
            format!(
                "unsupported MCP call '{}' (supported: initialize, tools/list, resources/list, tools/call)",
                unsupported
            )
            .as_str(),
            call_error_span,
        )),
    }
}
/// Validates an SSE step and classifies it as an open or a receive step.
///
/// Requires a `session`, the `GET` method, and no body error reported by the
/// caller (reported at its own span). A receive step (`receive_seen`) must
/// carry a `within` duration accepted by `parse_within_duration`; an open step
/// must not carry `within` at all.
///
/// Duration parse failures and a stray `within` are reported at
/// `within_error_span`; the span paired with `within` itself is not consulted.
pub(in crate::parser) fn validate_sse_action<'a, 'b>(
    request_span: pest::Span<'a>,
    has_session: bool,
    method: Method,
    body_error: Option<(&str, pest::Span<'a>)>,
    receive_seen: bool,
    within: Option<(&'b str, pest::Span<'a>)>,
    within_error_span: pest::Span<'a>,
) -> Result<ValidatedSseActionKind, pest::error::Error<Rule>> {
    if !has_session {
        return Err(request_error("SSE requests require 'session = ...'", request_span));
    }
    if method != Method::GET {
        return Err(request_error("SSE requests currently require GET", request_span));
    }
    if let Some((message, span)) = body_error {
        return Err(request_error(message, span));
    }

    match (receive_seen, within) {
        (true, Some((raw_duration, _))) => {
            parse_within_duration(raw_duration)
                .map_err(|message| request_error(message.as_str(), within_error_span))?;
            Ok(ValidatedSseActionKind::Receive)
        }
        (true, None) => Err(request_error(
            "SSE receive steps require 'within = ...'",
            request_span,
        )),
        (false, Some(_)) => Err(request_error(
            "'within = ...' is only valid with 'receive'",
            within_error_span,
        )),
        (false, None) => Ok(ValidatedSseActionKind::Open),
    }
}
/// Validates a WebSocket step and classifies it as open, send, exchange, or
/// receive.
///
/// Every step needs a `session`, the `GET` method, and no form error reported
/// by the caller. After that, the first matching case applies:
/// - `receive_seen`: `send = ...` is rejected (at `send_conflict_span`), any
///   `forbidden_body_error` is reported, and a parsable `within` is required
///   → `Receive`.
/// - a send kind is present: a body is required (otherwise an error at
///   `missing_body_span`); with a parsable `within` → `Exchange`, without
///   one → `Send`.
/// - otherwise: any `forbidden_body_error` is reported and `within` is
///   rejected → `Open`.
///
/// Duration errors are anchored at the span carried alongside `within`.
pub(in crate::parser) fn validate_ws_action<'a, 'b>(
    request_span: pest::Span<'a>,
    has_session: bool,
    method: Method,
    form_error: Option<(&str, pest::Span<'a>)>,
    receive_seen: bool,
    has_send_directive: bool,
    send_conflict_span: pest::Span<'a>,
    send_kind: Option<WsSendKindTemplate>,
    has_body: bool,
    missing_body_span: pest::Span<'a>,
    forbidden_body_error: Option<(&str, pest::Span<'a>)>,
    within: Option<(&'b str, pest::Span<'a>)>,
) -> Result<ValidatedWsActionKind, pest::error::Error<Rule>> {
    if !has_session {
        return Err(request_error(
            "WebSocket requests require 'session = ...'",
            request_span,
        ));
    }
    if method != Method::GET {
        return Err(request_error(
            "WebSocket requests currently require GET",
            request_span,
        ));
    }
    if let Some((message, span)) = form_error {
        return Err(request_error(message, span));
    }

    if receive_seen {
        if has_send_directive {
            return Err(request_error(
                "WebSocket requests cannot combine 'send = ...' with 'receive'",
                send_conflict_span,
            ));
        }
        if let Some((message, span)) = forbidden_body_error {
            return Err(request_error(message, span));
        }
        return match within {
            Some((raw_duration, duration_span)) => parse_within_duration(raw_duration)
                .map(|_| ValidatedWsActionKind::Receive)
                .map_err(|message| request_error(message.as_str(), duration_span)),
            None => Err(request_error(
                "WebSocket receive steps require 'within = ...'",
                request_span,
            )),
        };
    }

    if send_kind.is_some() {
        if !has_body {
            return Err(request_error(
                "WebSocket send steps require a body block",
                missing_body_span,
            ));
        }
        // A send that also waits for a reply (`within`) is an exchange.
        return match within {
            Some((raw_duration, duration_span)) => parse_within_duration(raw_duration)
                .map(|_| ValidatedWsActionKind::Exchange)
                .map_err(|message| request_error(message.as_str(), duration_span)),
            None => Ok(ValidatedWsActionKind::Send),
        };
    }

    if let Some((message, span)) = forbidden_body_error {
        return Err(request_error(message, span));
    }
    match within {
        Some((_, duration_span)) => Err(request_error(
            "'within = ...' is only valid with 'receive'",
            duration_span,
        )),
        None => Ok(ValidatedWsActionKind::Open),
    }
}