use std::collections::HashMap;
use http::Method;
use crate::request::{RequestProtocol, WsSendKindTemplate};
use super::{Rule, context};
/// The method/URL target last recorded for a named session, together with the
/// protocol it was recorded under.
///
/// Entries are written by `remember_session_target` and read back by
/// `resolve_request_target`, so that a later session-backed request may omit
/// its own method and URL.
#[derive(Debug, Clone)]
pub(super) struct SessionRequestTarget {
/// Protocol of the request that recorded this target; inheriting requests must match it.
pub(super) protocol: RequestProtocol,
/// HTTP method to reuse.
pub(super) method: Method,
/// URL to reuse, stored exactly as passed to `remember_session_target`.
pub(super) url: String,
}
pub(super) fn resolve_request_target(
method: Option<Method>,
url: Option<String>,
session_name: Option<&str>,
protocol: RequestProtocol,
session_targets: &HashMap<String, SessionRequestTarget>,
request_span: pest::Span<'_>,
) -> Result<(Method, String), pest::error::Error<Rule>> {
match (method, url) {
(Some(method), Some(url)) => Ok((method, url)),
(Some(_), None) => Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "request is missing a URL".to_string(),
},
request_span,
)),
(None, Some(_)) => Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "request is missing an HTTP method".to_string(),
},
request_span,
)),
(None, None) => {
let Some(session_name) = session_name else {
return Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "request is missing an HTTP method".to_string(),
},
request_span,
));
};
let inherited = session_targets.get(session_name).ok_or_else(|| {
pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!(
"session-backed request omitted its method/URL, but session '{}' has no earlier request to inherit from",
session_name
),
},
request_span.clone(),
)
})?;
if inherited.protocol != protocol {
return Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!(
"session-backed request for session '{}' uses protocol '{}', but the inherited target belongs to protocol '{}'",
session_name,
protocol.as_str(),
inherited.protocol.as_str()
),
},
request_span,
));
}
Ok((inherited.method.clone(), inherited.url.clone()))
}
}
}
/// Chooses the protocol for a request.
///
/// Precedence, highest first:
/// 1. an explicit protocol (`protocol_raw`), which may fail to parse;
/// 2. the protocol of the target already stored for `session_name`;
/// 3. directive presence: MCP, then WebSocket, then SSE;
/// 4. plain HTTP.
///
/// GraphQL is only ever selected explicitly.
///
/// # Errors
///
/// Propagates the error from `parse_request_protocol` for an unsupported
/// explicit protocol.
pub(super) fn infer_request_protocol<'a>(
    protocol_raw: Option<(&'a str, pest::Span<'a>)>,
    session_name: Option<&str>,
    has_mcp_directives: bool,
    has_ws_directives: bool,
    has_sse_directives: bool,
    session_targets: &HashMap<String, SessionRequestTarget>,
) -> Result<RequestProtocol, pest::error::Error<Rule>> {
    if let Some((raw, span)) = protocol_raw {
        return parse_request_protocol(raw, span);
    }

    let session_protocol = session_name
        .and_then(|name| session_targets.get(name))
        .map(|target| target.protocol);
    if let Some(protocol) = session_protocol {
        return Ok(protocol);
    }

    let inferred = if has_mcp_directives {
        RequestProtocol::Mcp
    } else if has_ws_directives {
        RequestProtocol::Ws
    } else if has_sse_directives {
        RequestProtocol::Sse
    } else {
        RequestProtocol::Http
    };
    Ok(inferred)
}
pub(super) fn ensure_session_protocol_compatible(
session_name: Option<&str>,
protocol: RequestProtocol,
session_targets: &HashMap<String, SessionRequestTarget>,
request_span: pest::Span<'_>,
) -> Result<(), pest::error::Error<Rule>> {
let Some(session_name) = session_name else {
return Ok(());
};
let Some(existing) = session_targets.get(session_name) else {
return Ok(());
};
if existing.protocol != protocol {
return Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!(
"session '{}' already uses protocol '{}', so this step cannot use protocol '{}'",
session_name,
existing.protocol.as_str(),
protocol.as_str()
),
},
request_span,
));
}
Ok(())
}
fn parse_request_protocol(
raw: &str,
span: pest::Span<'_>,
) -> Result<RequestProtocol, pest::error::Error<Rule>> {
match raw {
"http" => Ok(RequestProtocol::Http),
"graphql" => Ok(RequestProtocol::Graphql),
"mcp" => Ok(RequestProtocol::Mcp),
"sse" => Ok(RequestProtocol::Sse),
"ws" => Ok(RequestProtocol::Ws),
other => Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!(
"unsupported protocol '{}' (supported: http, graphql, mcp, sse, ws)",
other
),
},
span,
)),
}
}
/// Returns `true` when `value`, ignoring surrounding whitespace, is exactly
/// `graphql` or `application/graphql` (case-sensitive).
pub(super) fn is_graphql_document_content_type(value: &str) -> bool {
    const GRAPHQL_DOCUMENT_TYPES: [&str; 2] = ["graphql", "application/graphql"];
    GRAPHQL_DOCUMENT_TYPES.contains(&value.trim())
}
/// Returns the directive spelling of a WebSocket send kind (`"json"` or `"text"`).
pub(super) fn ws_send_kind_as_str(kind: WsSendKindTemplate) -> &'static str {
    // Exhaustive on purpose: a new variant must get its own spelling here.
    match kind {
        WsSendKindTemplate::Json => "json",
        WsSendKindTemplate::Text => "text",
    }
}
/// Parses the value of a WebSocket send directive (`text` or `json`, with
/// surrounding whitespace ignored).
///
/// # Errors
///
/// Returns a message naming the trimmed value when it is unsupported.
fn parse_ws_send_directive_kind(raw: &str) -> Result<WsSendKindTemplate, String> {
    let kind = raw.trim();
    if kind == "text" {
        Ok(WsSendKindTemplate::Text)
    } else if kind == "json" {
        Ok(WsSendKindTemplate::Json)
    } else {
        Err(format!(
            "unsupported WebSocket send kind '{}' (supported: text, json)",
            kind
        ))
    }
}
/// Maps a WebSocket body block type to a send kind.
///
/// `json`/`application/json` map to JSON and `text`/`text/plain` map to text.
/// Surrounding whitespace is ignored.
///
/// # Errors
///
/// Returns a message naming the trimmed value when it is not one of the
/// accepted block types.
fn parse_ws_body_content_type_kind(raw: &str) -> Result<WsSendKindTemplate, String> {
    let block_type = raw.trim();
    if matches!(block_type, "json" | "application/json") {
        Ok(WsSendKindTemplate::Json)
    } else if matches!(block_type, "text" | "text/plain") {
        Ok(WsSendKindTemplate::Text)
    } else {
        Err(format!(
            "unsupported WebSocket body block type '{}' (supported: plain ~~~, ~~~text, ~~~text/plain, ~~~json, or ~~~application/json)",
            block_type
        ))
    }
}
pub(super) fn infer_ws_send_kind(
explicit_send: Option<&str>,
body_content_type: Option<&str>,
has_body: bool,
) -> Result<Option<WsSendKindTemplate>, String> {
if explicit_send.is_none() && body_content_type.is_none() && !has_body {
return Ok(None);
}
let explicit_kind = explicit_send.map(parse_ws_send_directive_kind).transpose()?;
let normalized_body_content_type = body_content_type.map(str::trim);
let body_kind = normalized_body_content_type
.map(parse_ws_body_content_type_kind)
.transpose()?;
match (explicit_kind, body_kind, normalized_body_content_type) {
(Some(explicit_kind), Some(body_kind), Some(raw_body_content_type)) => {
if explicit_kind != body_kind {
return Err(format!(
"WebSocket send kind '{}' conflicts with body block type '{}'",
ws_send_kind_as_str(explicit_kind),
raw_body_content_type
));
}
Ok(Some(explicit_kind))
}
(Some(explicit_kind), Some(_), None) => Ok(Some(explicit_kind)),
(Some(explicit_kind), None, _) => Ok(Some(explicit_kind)),
(None, Some(body_kind), _) => Ok(Some(body_kind)),
(None, None, _) if has_body => Ok(Some(WsSendKindTemplate::Text)),
(None, None, _) => Ok(None),
}
}
pub(super) fn infer_syntax_ws_send_kind<'a>(
request_span: pest::Span<'a>,
ws_send_raw: Option<(&'a str, pest::Span<'a>)>,
body_content_type_raw: Option<(&'a str, pest::Span<'a>)>,
body_seen: bool,
scalar_map: &HashMap<String, String>,
) -> Result<Option<WsSendKindTemplate>, pest::error::Error<Rule>> {
let ws_send = ws_send_raw
.map(|(raw, _)| context::inject_from_variable(raw, scalar_map));
let ws_body_content_type = body_content_type_raw
.map(|(raw, _)| context::inject_from_variable(raw, scalar_map));
infer_ws_send_kind(
ws_send.as_deref(),
ws_body_content_type.as_deref(),
body_seen,
)
.map_err(|message| {
let span = body_content_type_raw
.as_ref()
.map(|(_, span)| span.clone())
.or_else(|| ws_send_raw.as_ref().map(|(_, span)| span.clone()))
.unwrap_or(request_span);
pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError { message },
span,
)
})
}
/// Builds the protocol-context JSON object for a GraphQL request.
///
/// * `operation_name` is stored under `operationName`.
/// * `variables_json` is parsed as JSON and stored under `variables`. If it
///   does not parse, the raw text is stored under `variablesText` instead.
///
/// Returns `None` when both inputs are `None`.
pub(super) fn graphql_protocol_context_json(
    operation_name: Option<&str>,
    variables_json: Option<&str>,
) -> Option<serde_json::Value> {
    let mut context = serde_json::Map::new();

    if let Some(name) = operation_name {
        context.insert(
            String::from("operationName"),
            serde_json::Value::String(name.to_owned()),
        );
    }

    if let Some(raw) = variables_json {
        let (key, value) = match serde_json::from_str::<serde_json::Value>(raw) {
            Ok(parsed) => ("variables", parsed),
            Err(_) => ("variablesText", serde_json::Value::String(raw.to_owned())),
        };
        context.insert(key.to_owned(), value);
    }

    (!context.is_empty()).then(|| serde_json::Value::Object(context))
}
/// Builds the protocol-context JSON object for an MCP request.
///
/// * Plain string inputs are copied under `sessionName`, `call`, `tool`,
///   `protocolVersion`, `clientName` and `clientVersion`.
/// * `arguments_json` and `capabilities_json` are parsed as JSON and stored
///   under `arguments` / `capabilities`. If parsing fails, the raw text is kept
///   under `argumentsText` / `capabilitiesText` instead.
///
/// Keys are inserted in the order listed above. Returns `None` when every
/// input is `None`.
pub(super) fn mcp_protocol_context_json(
    session_name: Option<&str>,
    call: Option<&str>,
    tool: Option<&str>,
    arguments_json: Option<&str>,
    protocol_version: Option<&str>,
    client_name: Option<&str>,
    client_version: Option<&str>,
    capabilities_json: Option<&str>,
) -> Option<serde_json::Value> {
    let mut context = serde_json::Map::new();

    let string_fields = [
        ("sessionName", session_name),
        ("call", call),
        ("tool", tool),
        ("protocolVersion", protocol_version),
        ("clientName", client_name),
        ("clientVersion", client_version),
    ];
    for (key, value) in string_fields {
        if let Some(value) = value {
            context.insert(key.to_owned(), serde_json::Value::String(value.to_owned()));
        }
    }

    // (key when the text parses as JSON, key when it does not, raw text)
    let json_fields = [
        ("arguments", "argumentsText", arguments_json),
        ("capabilities", "capabilitiesText", capabilities_json),
    ];
    for (parsed_key, text_key, raw) in json_fields {
        let Some(raw) = raw else {
            continue;
        };
        match serde_json::from_str::<serde_json::Value>(raw) {
            Ok(parsed) => {
                context.insert(parsed_key.to_owned(), parsed);
            }
            Err(_) => {
                context.insert(text_key.to_owned(), serde_json::Value::String(raw.to_owned()));
            }
        }
    }

    (!context.is_empty()).then(|| serde_json::Value::Object(context))
}
/// Builds the protocol-context JSON object for an SSE request.
///
/// `action` is `"receive"` when `receive_seen` is set and `"open"` otherwise.
/// `within` is copied verbatim when present. Always returns `Some`.
pub(super) fn sse_protocol_context_json(
    receive_seen: bool,
    within: Option<&str>,
) -> Option<serde_json::Value> {
    let action = if receive_seen { "receive" } else { "open" };
    let context: serde_json::Map<String, serde_json::Value> =
        std::iter::once(("action", action))
            .chain(within.map(|within_value| ("within", within_value)))
            .map(|(key, value)| (key.to_string(), serde_json::Value::String(value.to_string())))
            .collect();
    Some(serde_json::Value::Object(context))
}
/// Builds the protocol-context JSON object for a WebSocket request.
///
/// Always contains `action`. `kind` and `within` are added, in that order,
/// when present. All values are stored as JSON strings. Always returns `Some`.
pub(super) fn ws_protocol_context_json(
    action: &str,
    send_kind: Option<&str>,
    within: Option<&str>,
) -> Option<serde_json::Value> {
    let context: serde_json::Map<String, serde_json::Value> =
        std::iter::once(("action", action))
            .chain(send_kind.map(|kind| ("kind", kind)))
            .chain(within.map(|within_value| ("within", within_value)))
            .map(|(key, value)| (key.to_string(), serde_json::Value::String(value.to_string())))
            .collect();
    Some(serde_json::Value::Object(context))
}
/// Names the WebSocket action a request performs:
/// * `"receive"` when a receive was seen (regardless of the other inputs);
/// * `"exchange"` when sending with a `within` value;
/// * `"send"` when sending without one;
/// * `"open"` otherwise.
fn ws_action_name(
    receive_seen: bool,
    send_kind: Option<WsSendKindTemplate>,
    within: Option<&str>,
) -> &'static str {
    match (receive_seen, send_kind, within) {
        (true, _, _) => "receive",
        (false, Some(_), Some(_)) => "exchange",
        (false, Some(_), None) => "send",
        (false, None, _) => "open",
    }
}
/// Builds the protocol-context JSON for a request by dispatching on
/// `protocol` to the matching per-protocol builder.
///
/// Plain HTTP has no context and yields `None`. Despite their names,
/// `sse_receive_seen` and `sse_within` are also used for WebSocket requests:
/// they select the WS action and supply its `within` value.
pub(super) fn syntax_protocol_context_json(
    protocol: RequestProtocol,
    session_name: Option<&str>,
    operation_name: Option<&str>,
    graphql_variables: Option<&str>,
    mcp_call: Option<&str>,
    mcp_tool: Option<&str>,
    mcp_arguments: Option<&str>,
    mcp_protocol_version: Option<&str>,
    mcp_client_name: Option<&str>,
    mcp_client_version: Option<&str>,
    mcp_capabilities: Option<&str>,
    sse_receive_seen: bool,
    ws_send_kind: Option<WsSendKindTemplate>,
    sse_within: Option<&str>,
) -> Option<serde_json::Value> {
    match protocol {
        RequestProtocol::Http => None,
        RequestProtocol::Graphql => graphql_protocol_context_json(operation_name, graphql_variables),
        RequestProtocol::Mcp => mcp_protocol_context_json(
            session_name,
            mcp_call,
            mcp_tool,
            mcp_arguments,
            mcp_protocol_version,
            mcp_client_name,
            mcp_client_version,
            mcp_capabilities,
        ),
        RequestProtocol::Sse => sse_protocol_context_json(sse_receive_seen, sse_within),
        RequestProtocol::Ws => {
            let action = ws_action_name(sse_receive_seen, ws_send_kind, sse_within);
            let kind_label = ws_send_kind.map(ws_send_kind_as_str);
            ws_protocol_context_json(action, kind_label, sse_within)
        }
    }
}
/// Stores `method` and `url` as the target for `session_name` under
/// `protocol`, replacing any target previously stored for that session.
///
/// Does nothing when the request has no session.
pub(super) fn remember_session_target(
    session_targets: &mut HashMap<String, SessionRequestTarget>,
    session_name: Option<&str>,
    protocol: RequestProtocol,
    method: &Method,
    url: &str,
) {
    let Some(name) = session_name else {
        return;
    };
    let target = SessionRequestTarget {
        protocol,
        method: method.clone(),
        url: url.to_owned(),
    };
    session_targets.insert(name.to_owned(), target);
}
pub(super) fn validate_mcp_object_json(
raw: &str,
field_name: &str,
span: pest::Span<'_>,
) -> Result<(), pest::error::Error<Rule>> {
let parsed = serde_json::from_str::<serde_json::Value>(raw).map_err(|err| {
pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!("invalid MCP {} JSON: {}", field_name, err),
},
span.clone(),
)
})?;
if !parsed.is_object() {
return Err(pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!("MCP {} must be a JSON object", field_name),
},
span,
));
}
Ok(())
}
pub(super) fn validate_graphql_variables_json(
raw: &str,
span: pest::Span<'_>,
) -> Result<(), pest::error::Error<Rule>> {
serde_json::from_str::<serde_json::Value>(raw).map_err(|err| {
pest::error::Error::<Rule>::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!("invalid GraphQL variables JSON: {}", err),
},
span,
)
})?;
Ok(())
}