use std::collections::HashMap;
use std::sync::Arc;
use oatf::ResponseEntry;
use oatf::primitives::{interpolate_value, select_response};
use serde_json::{Value, json};
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use crate::engine::types::{Direction, ProtocolEvent};
use crate::error::EngineError;
use crate::transport::jsonrpc::error_codes;
use super::transport::McpClientTransportWriter;
use super::{HandlerState, ServerRequestMessage};
/// Serves server-initiated JSON-RPC requests for an MCP client session.
///
/// The loop runs until `cancel` fires or `server_request_rx` yields `None`.
/// For every request it:
///
/// 1. publishes an `Incoming` [`ProtocolEvent`] carrying the request params
///    (`Value::Null` when the request has none),
/// 2. computes a result for the method from the current [`HandlerState`] and a
///    snapshot of the extractor map,
/// 3. publishes an `Outgoing` event carrying the result (or the error text), and
/// 4. writes the response, or an `INTERNAL_ERROR` error response, via `writer`.
///
/// Methods other than `sampling/createMessage`, `elicitation/create` and
/// `roots/list` (including `ping`) are answered with an empty JSON object.
///
/// Events are published with `try_send`, so this task never waits on the event
/// channel. If an incoming event, or the outgoing event for a successful
/// result, cannot be queued, `cancel` is triggered and the loop exits. Errors
/// returned by the writer are deliberately ignored (best effort).
#[allow(clippy::too_many_arguments, clippy::cognitive_complexity)]
pub(super) async fn server_request_handler(
    mut server_request_rx: mpsc::Receiver<ServerRequestMessage>,
    writer: Arc<tokio::sync::Mutex<Box<dyn McpClientTransportWriter>>>,
    handler_state: Arc<tokio::sync::RwLock<HandlerState>>,
    extractors_rx: watch::Receiver<HashMap<String, String>>,
    handler_event_tx: mpsc::Sender<ProtocolEvent>,
    raw_synthesize: bool,
    cancel: CancellationToken,
) {
    loop {
        tokio::select! {
            // Poll branches in declaration order so cancellation is checked first.
            biased;

            () = cancel.cancelled() => break,

            msg = server_request_rx.recv() => {
                // `None` means every sender has been dropped.
                let Some(req) = msg else { break };
                let content = req.params.clone().unwrap_or(Value::Null);

                if handler_event_tx.try_send(ProtocolEvent {
                    direction: Direction::Incoming,
                    method: req.method.clone(),
                    content: content.clone(),
                }).is_err() {
                    tracing::error!("handler event backlog exceeded; cancelling MCP client session");
                    // Answer the request before tearing the session down.
                    let _ = writer.lock().await
                        .send_error_response(&req.id, error_codes::INTERNAL_ERROR, "handler backlog exceeded")
                        .await;
                    cancel.cancel();
                    break;
                }

                // The handlers below are synchronous, so the state read guard is
                // released before the next `.await` on the writer.
                let hs = handler_state.read().await;
                let current_extractors = extractors_rx.borrow().clone();
                let result = match req.method.as_str() {
                    "sampling/createMessage" => {
                        handle_sampling(&hs.state, &current_extractors, &content, raw_synthesize)
                    }
                    "elicitation/create" => {
                        Ok(handle_elicitation(&hs.state, &current_extractors, &content))
                    }
                    "roots/list" => Ok(handle_roots_list(&hs.state)),
                    "ping" => Ok(json!({})),
                    other => {
                        tracing::debug!(method = %other, "unknown server-initiated request, returning empty result");
                        Ok(json!({}))
                    }
                };
                drop(hs);

                match result {
                    Ok(result_value) => {
                        if handler_event_tx.try_send(ProtocolEvent {
                            direction: Direction::Outgoing,
                            method: req.method.clone(),
                            content: result_value.clone(),
                        }).is_err() {
                            tracing::error!("handler event backlog exceeded on response; cancelling MCP client session");
                            // NOTE(review): unlike the incoming-event overflow path
                            // above, no reply is written for this request before
                            // cancelling — confirm that this is intended.
                            cancel.cancel();
                            break;
                        }
                        let _ = writer.lock().await.send_response(&req.id, result_value).await;
                    }
                    Err(e) => {
                        tracing::warn!(method = %req.method, error = %e, "handler error, sending error response");
                        // Format the error once; it is used for both the event and the reply.
                        let message = e.to_string();
                        // Best effort: a full event queue does not cancel the session here.
                        let _ = handler_event_tx.try_send(ProtocolEvent {
                            direction: Direction::Outgoing,
                            method: req.method.clone(),
                            content: json!({"error": message}),
                        });
                        let _ = writer.lock().await
                            .send_error_response(&req.id, error_codes::INTERNAL_ERROR, &message)
                            .await;
                    }
                }
            }
        }
    }
}
/// Builds the result for a `sampling/createMessage` request.
///
/// Reads the `sampling_responses` entry of `state`, deserializes it into a list
/// of [`ResponseEntry`] values and lets `select_response` pick one for `params`.
/// The chosen entry's `extra` fields are serialized to JSON and passed through
/// `interpolate_value` together with `extractors` and `params`; any diagnostics
/// it reports are logged at debug level.
///
/// The value from `default_sampling_response` is returned when the state has no
/// `sampling_responses` key, when that value cannot be deserialized (a warning
/// is logged), or when no entry is selected.
///
/// `_raw_synthesize` is accepted but not read by this function.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the chosen entry has a `synthesize`
/// block and no `extra` fields.
pub(super) fn handle_sampling(
    state: &Value,
    extractors: &HashMap<String, String>,
    params: &Value,
    _raw_synthesize: bool,
) -> Result<Value, EngineError> {
    let configured = match state.get("sampling_responses") {
        Some(raw) => raw,
        None => return Ok(default_sampling_response()),
    };

    let candidates = match serde_json::from_value::<Vec<ResponseEntry>>(configured.clone()) {
        Ok(parsed) => parsed,
        Err(err) => {
            tracing::warn!(error = %err, "failed to deserialize sampling_responses");
            return Ok(default_sampling_response());
        }
    };

    let chosen = match select_response(&candidates, params) {
        Some(hit) => hit,
        None => return Ok(default_sampling_response()),
    };

    // A synthesize block with no static fields cannot be answered here.
    let synthesize_only = chosen.synthesize.is_some() && chosen.extra.is_empty();
    if synthesize_only {
        tracing::info!(
            "sampling synthesize block encountered but GenerationProvider not available"
        );
        return Err(EngineError::Driver(
            "synthesize not yet supported — GenerationProvider not available".to_string(),
        ));
    }

    // Fall back to JSON null if the extra fields cannot be serialized.
    let template = serde_json::to_value(&chosen.extra).unwrap_or(Value::Null);
    let (rendered, notes) = interpolate_value(&template, extractors, Some(params), None);
    for diag in &notes {
        tracing::debug!(diagnostic = ?diag, "sampling interpolation diagnostic");
    }

    Ok(rendered)
}
/// Builds the result for an `elicitation/create` request.
///
/// Reads the `elicitation_responses` entry of `state`, deserializes it into a
/// list of [`ResponseEntry`] values and lets `select_response` pick one for
/// `params`. The chosen entry's `extra` fields are serialized to JSON and passed
/// through `interpolate_value` together with `extractors` and `params`; any
/// diagnostics it reports are logged at debug level.
///
/// Returns `{"action": "cancel"}` when the state has no `elicitation_responses`
/// key, when that value cannot be deserialized (a warning is logged), or when
/// no entry is selected.
pub(super) fn handle_elicitation(
    state: &Value,
    extractors: &HashMap<String, String>,
    params: &Value,
) -> Value {
    // Shared fallback for every "nothing usable configured" path.
    let cancelled = || json!({"action": "cancel"});

    let configured = match state.get("elicitation_responses") {
        Some(raw) => raw,
        None => return cancelled(),
    };

    let candidates = match serde_json::from_value::<Vec<ResponseEntry>>(configured.clone()) {
        Ok(parsed) => parsed,
        Err(err) => {
            tracing::warn!(error = %err, "failed to deserialize elicitation_responses");
            return cancelled();
        }
    };

    let chosen = match select_response(&candidates, params) {
        Some(hit) => hit,
        None => return cancelled(),
    };

    // Fall back to JSON null if the extra fields cannot be serialized.
    let template = serde_json::to_value(&chosen.extra).unwrap_or(Value::Null);
    let (rendered, notes) = interpolate_value(&template, extractors, Some(params), None);
    for diag in &notes {
        tracing::debug!(diagnostic = ?diag, "elicitation interpolation diagnostic");
    }

    rendered
}
/// Builds the result for a `roots/list` request.
///
/// Wraps the `roots` entry of `state` as `{"roots": <value>}`; when the state
/// has no `roots` key the result is `{"roots": []}`.
pub(super) fn handle_roots_list(state: &Value) -> Value {
    match state.get("roots") {
        Some(configured) => json!({"roots": configured}),
        None => json!({"roots": []}),
    }
}
/// Returns the fallback sampling result: an assistant message with empty text
/// content, model `"default"` and stop reason `"endTurn"`.
pub(super) fn default_sampling_response() -> Value {
    let empty_text = json!({"type": "text", "text": ""});
    json!({
        "role": "assistant",
        "content": empty_text,
        "model": "default",
        "stopReason": "endTurn"
    })
}
/// Rewrites shorthand action values into an object with a `type` field.
///
/// * A string `"name"` becomes `{"type": "name"}`.
/// * An object with exactly one key, where that key is not `type`, becomes
///   `{"type": <key>}`. If the key's value is itself an object, its fields are
///   copied into the result afterwards, so an inner `type` field replaces the
///   one derived from the key.
/// * Any other value is returned as a clone.
pub(super) fn normalize_action(value: &Value) -> Value {
    if let Value::String(kind) = value {
        return json!({"type": kind});
    }

    let Value::Object(fields) = value else {
        return value.clone();
    };
    if fields.len() != 1 || fields.contains_key("type") {
        return value.clone();
    }

    // Exactly one entry exists here, so the fallback is never taken.
    let Some((kind, payload)) = fields.iter().next() else {
        return value.clone();
    };

    let mut action = json!({"type": kind});
    if let Some(inner) = payload.as_object() {
        for (name, field) in inner {
            action[name] = field.clone();
        }
    }
    action
}