use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use oatf::primitives::{interpolate_template, interpolate_value};
use serde_json::{Value, json};
use tokio::io::AsyncReadExt;
use tokio::process::{Child, ChildStderr};
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::engine::driver::PhaseDriver;
use crate::engine::types::{Direction, DriveResult, ProtocolEvent};
use crate::error::EngineError;
use super::handler::{normalize_action, server_request_handler};
use super::multiplexer::MessageMultiplexer;
use super::transport::{
McpClientTransportReader, McpClientTransportWriter, create_http_transport,
spawn_stdio_transport,
};
use super::{
CorrelatedResponse, DEFAULT_PHASE_TIMEOUT, DEFAULT_REQUEST_TIMEOUT, HANDLER_EVENT_BUFFER_SIZE,
HandlerState, INIT_TIMEOUT, NOTIFICATION_BUFFER_SIZE, NotificationMessage,
POST_ACTION_IDLE_TIMEOUT, PendingRequest, SERVER_REQUEST_BUFFER_SIZE,
};
/// Phase driver that acts as an MCP client over a stdio or HTTP transport.
///
/// Instances are built by `create_mcp_client_driver`; the background tasks are only
/// started later by `bootstrap`, which is why several fields are `Option`s that start
/// out as `None`.
pub struct McpClientDriver {
/// Transport writer, shared (via `Arc` + async mutex) with the multiplexer and the
/// server-request handler task.
pub(super) writer: Arc<tokio::sync::Mutex<Box<dyn McpClientTransportWriter>>>,
/// In-flight requests keyed by the string form of their JSON request id; shared with
/// the multiplexer.
pub(super) pending: Arc<std::sync::Mutex<HashMap<String, PendingRequest>>>,
/// Message multiplexer; `None` until `bootstrap` has run.
pub(super) mux: Option<MessageMultiplexer>,
/// Receiving end of the notification channel whose sender is handed to the multiplexer.
pub(super) notification_rx: Option<mpsc::Receiver<NotificationMessage>>,
/// Receiving end of the event channel whose sender is handed to the server-request handler.
pub(super) handler_event_rx: Option<mpsc::Receiver<ProtocolEvent>>,
/// State shared with the server-request handler; its `state` is overwritten at the
/// start of every `drive_phase` call.
pub(super) handler_state: Arc<tokio::sync::RwLock<HandlerState>>,
/// Join handle of the spawned server-request handler task.
pub(super) handler_handle: Option<JoinHandle<()>>,
/// The complete result object of the `initialize` response (not only its capabilities
/// member), stored once initialization succeeds.
pub(super) server_capabilities: Option<Value>,
/// `protocolVersion` string taken from the `initialize` response, if it contained one.
pub(super) negotiated_version: Option<String>,
/// Limit applied separately to sending a request and to waiting for its response.
pub(super) request_timeout: std::time::Duration,
/// Idle wait used by `drive_phase` when the phase declares no actions.
pub(super) phase_timeout: std::time::Duration,
/// Set to `true` once the initialize handshake has completed.
pub(super) initialized: bool,
/// Next request id to hand out; see `next_id`.
pub(super) next_request_id: u64,
/// Forwarded unchanged to the server-request handler.
// NOTE(review): the exact effect of this flag lives in the handler module — confirm there.
pub(super) raw_synthesize: bool,
/// Transport reader; taken (left as `None`) when `bootstrap` gives it to the multiplexer.
pub(super) reader: Option<Box<dyn McpClientTransportReader>>,
/// Cancellation token cloned into the multiplexer and handler; cancelled on drop.
pub(super) transport_cancel: CancellationToken,
/// Child process backing the stdio transport (`None` for HTTP); a kill is requested on drop.
pub(super) child: Option<Child>,
/// The child's stderr pipe, consumed by the first `capture_stderr` call.
pub(super) child_stderr: Option<ChildStderr>,
}
impl McpClientDriver {
/// Forgets a request that will never complete (send failure, timeout, closed multiplexer).
///
/// Removes `id_key` from the pending-request table and, when a multiplexer exists,
/// removes the response registration stored under the same key.
fn cleanup_pending_request(&self, id_key: &str) {
    {
        // A poisoned lock is recovered instead of propagated: the map is still usable
        // and cleanup must not panic on an error path.
        let mut in_flight = match self.pending.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        in_flight.remove(id_key);
    }
    if let Some(multiplexer) = &self.mux {
        multiplexer.remove_response(id_key);
    }
}
async fn capture_stderr(&mut self) -> String {
let Some(mut stderr) = self.child_stderr.take() else {
return String::new();
};
let mut buf = vec![0u8; 4096];
match tokio::time::timeout(std::time::Duration::from_millis(100), stderr.read(&mut buf))
.await
{
Ok(Ok(n)) if n > 0 => String::from_utf8_lossy(&buf[..n]).to_string(),
_ => String::new(),
}
}
/// Returns the current request id and advances the counter by one.
pub(super) const fn next_id(&mut self) -> u64 {
    let allocated = self.next_request_id;
    self.next_request_id = allocated + 1;
    allocated
}
/// Sends one request and waits for the correlated response.
///
/// The request is recorded in the pending table and registered with the multiplexer
/// before it is written, so a fast response cannot be missed. An `Outgoing` event is
/// emitted once the request has been written, and an `Incoming` event once the response
/// arrives. When the response result is a JSON object and the original request params
/// are available, they are attached to the event content under `_request`.
///
/// # Errors
///
/// Returns `EngineError::Driver` when the multiplexer has not been started, when sending
/// or awaiting the response exceeds `request_timeout`, or when the multiplexer closes
/// before a response arrives (the message then includes any captured server stderr).
/// Errors reported by the transport writer are returned unchanged. On every failure
/// after registration the pending entry and response registration are removed again.
async fn send_and_await(
    &mut self,
    method: &str,
    params: Option<Value>,
    event_tx: &mpsc::Sender<ProtocolEvent>,
) -> Result<CorrelatedResponse, EngineError> {
    let id = json!(self.next_id());
    let id_key = id.to_string();

    // Resolve the multiplexer *before* recording the request: bailing out afterwards
    // would leave a stale entry in `pending` that nothing ever removes.
    let mux = self
        .mux
        .as_ref()
        .ok_or_else(|| EngineError::Driver("multiplexer not started".to_string()))?;

    self.pending
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .insert(
            id_key.clone(),
            PendingRequest {
                method: method.to_string(),
                params: params.clone(),
            },
        );
    let response_rx = mux.register_response(&id);

    // The timeout covers both acquiring the writer lock and the write itself.
    match tokio::time::timeout(self.request_timeout, async {
        self.writer
            .lock()
            .await
            .send_request_with_id(method, params.clone(), &id)
            .await
    })
    .await
    {
        Ok(Ok(())) => {}
        Ok(Err(err)) => {
            self.cleanup_pending_request(&id_key);
            return Err(err);
        }
        Err(_) => {
            self.cleanup_pending_request(&id_key);
            return Err(EngineError::Driver(format!(
                "request timeout while sending '{method}' after {:?}",
                self.request_timeout
            )));
        }
    }

    // Event delivery is best-effort: a dropped receiver must not fail the request.
    let _ = event_tx
        .send(ProtocolEvent {
            direction: Direction::Outgoing,
            method: method.to_string(),
            content: params.unwrap_or(Value::Null),
        })
        .await;

    let response = match tokio::time::timeout(self.request_timeout, response_rx).await {
        Ok(Ok(resp)) => resp,
        Ok(Err(_)) => {
            // The response sender was dropped, i.e. the multiplexer shut down.
            self.cleanup_pending_request(&id_key);
            let reason = mux.close_reason();
            let stderr = self.capture_stderr().await;
            let mut msg = format!("multiplexer closed while awaiting '{method}': {reason}");
            if !stderr.is_empty() {
                use std::fmt::Write;
                let _ = write!(msg, "\nserver stderr: {stderr}");
            }
            return Err(EngineError::Driver(msg));
        }
        Err(_) => {
            self.cleanup_pending_request(&id_key);
            return Err(EngineError::Driver(format!(
                "request timeout for '{method}' after {:?}",
                self.request_timeout
            )));
        }
    };

    let mut content = response.result.clone();
    if let Some(ref req_params) = response.request_params
        && let Some(obj) = content.as_object_mut()
    {
        obj.insert("_request".to_string(), req_params.clone());
    }
    let _ = event_tx
        .send(ProtocolEvent {
            direction: Direction::Incoming,
            method: response.method.clone(),
            content,
        })
        .await;
    Ok(response)
}
pub(super) fn forward_pending_events(&mut self, event_tx: &mpsc::Sender<ProtocolEvent>) {
if let Some(ref mut rx) = self.handler_event_rx {
while let Ok(evt) = rx.try_recv() {
let _ = event_tx.try_send(evt);
}
}
if let Some(ref mut rx) = self.notification_rx {
while let Ok(notif) = rx.try_recv() {
let _ = event_tx.try_send(ProtocolEvent {
direction: Direction::Incoming,
method: notif.method,
content: notif.params.unwrap_or(Value::Null),
});
}
}
}
/// Performs the `initialize` handshake with the server.
///
/// Sends an `initialize` request carrying the client's protocol version, the
/// capabilities derived from `state` (see `build_client_capabilities`) and client info,
/// then waits for the response. On success it stores the whole response result in
/// `server_capabilities`, records the server's `protocolVersion` (if any), and configures
/// the writer with that version, or with the client's own when the server sent none.
/// It then emits the `Incoming` event, sends `notifications/initialized` and marks the
/// driver initialized.
///
/// # Errors
///
/// Returns `EngineError::Driver` when the multiplexer has not been started, when sending
/// or awaiting the response exceeds `INIT_TIMEOUT`, when the multiplexer closes first
/// (including any captured server stderr), or when the server answers with an error.
/// Transport errors from the writer are returned unchanged.
#[allow(clippy::too_many_lines)]
pub(super) async fn initialize(
    &mut self,
    state: &Value,
    event_tx: &mpsc::Sender<ProtocolEvent>,
) -> Result<(), EngineError> {
    // Version offered to the server; also the fallback when the response omits one.
    const PROTOCOL_VERSION: &str = "2025-11-25";

    let init_params = json!({
        "protocolVersion": PROTOCOL_VERSION,
        "capabilities": build_client_capabilities(state),
        "clientInfo": {
            "name": "ThoughtJack",
            "version": env!("CARGO_PKG_VERSION"),
        }
    });
    let id = json!(self.next_id());
    let id_key = id.to_string();

    // Resolve the multiplexer *before* recording the request: bailing out afterwards
    // would leave a stale entry in `pending` that nothing ever removes.
    let mux = self
        .mux
        .as_ref()
        .ok_or_else(|| EngineError::Driver("multiplexer not started".to_string()))?;

    self.pending
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .insert(
            id_key.clone(),
            PendingRequest {
                method: "initialize".to_string(),
                params: Some(init_params.clone()),
            },
        );
    let response_rx = mux.register_response(&id);

    // The timeout covers both acquiring the writer lock and the write itself.
    match tokio::time::timeout(INIT_TIMEOUT, async {
        self.writer
            .lock()
            .await
            .send_request_with_id("initialize", Some(init_params.clone()), &id)
            .await
    })
    .await
    {
        Ok(Ok(())) => {}
        Ok(Err(err)) => {
            self.cleanup_pending_request(&id_key);
            return Err(err);
        }
        Err(_) => {
            self.cleanup_pending_request(&id_key);
            return Err(EngineError::Driver("initialization timeout".to_string()));
        }
    }

    // Event delivery is best-effort: a dropped receiver must not fail initialization.
    let _ = event_tx
        .send(ProtocolEvent {
            direction: Direction::Outgoing,
            method: "initialize".to_string(),
            content: init_params,
        })
        .await;

    let response = match tokio::time::timeout(INIT_TIMEOUT, response_rx).await {
        Ok(Ok(resp)) => resp,
        Ok(Err(_)) => {
            // The response sender was dropped, i.e. the multiplexer shut down.
            self.cleanup_pending_request(&id_key);
            let reason = mux.close_reason();
            let stderr = self.capture_stderr().await;
            let mut msg = format!("multiplexer closed during initialization: {reason}");
            if !stderr.is_empty() {
                use std::fmt::Write;
                let _ = write!(msg, "\nserver stderr: {stderr}");
            }
            return Err(EngineError::Driver(msg));
        }
        Err(_) => {
            self.cleanup_pending_request(&id_key);
            return Err(EngineError::Driver("initialization timeout".to_string()));
        }
    };

    if response.is_error {
        return Err(EngineError::Driver(format!(
            "server rejected initialization: {}",
            response.result
        )));
    }

    self.server_capabilities = Some(response.result.clone());
    self.negotiated_version = response
        .result
        .get("protocolVersion")
        .and_then(Value::as_str)
        .map(String::from);
    let version = self
        .negotiated_version
        .clone()
        .unwrap_or_else(|| PROTOCOL_VERSION.to_string());
    self.writer.lock().await.set_protocol_version(version).await;

    let _ = event_tx
        .send(ProtocolEvent {
            direction: Direction::Incoming,
            method: "initialize".to_string(),
            content: response.result,
        })
        .await;

    self.writer
        .lock()
        .await
        .send_notification("notifications/initialized", None)
        .await?;
    self.initialized = true;
    tracing::info!("MCP client initialization complete");
    Ok(())
}
/// Translates one normalized action object into an MCP request and executes it.
///
/// The action's `"type"` selects the request method. String fields (`name`, `uri`) and
/// the `arguments` value are interpolated against `extractors` first; interpolation
/// diagnostics are logged as warnings and never abort the action. A missing `name` or
/// `uri` is treated as an empty string and missing `arguments` as an empty object.
/// Unknown action types are logged and skipped.
///
/// # Errors
///
/// Propagates any error from `send_and_await`.
#[allow(clippy::cognitive_complexity)]
async fn execute_action(
    &mut self,
    action: &Value,
    extractors: &HashMap<String, String>,
    event_tx: &mpsc::Sender<ProtocolEvent>,
) -> Result<(), EngineError> {
    // Each arm only decides *what* to send; the single send happens after the match.
    let (method, params): (&str, Option<Value>) = match action["type"].as_str().unwrap_or("") {
        "list_tools" => ("tools/list", None),
        "call_tool" => {
            let raw_name = action["name"].as_str().unwrap_or_default();
            let (tool, tool_diags) = interpolate_template(raw_name, extractors, None, None);
            if !tool_diags.is_empty() {
                tracing::warn!(raw_name, diagnostics = ?tool_diags, "interpolation warnings in call_tool name");
            }
            let raw_args = action
                .get("arguments")
                .cloned()
                .unwrap_or_else(|| json!({}));
            let (tool_args, arg_diags) = interpolate_value(&raw_args, extractors, None, None);
            if !arg_diags.is_empty() {
                tracing::warn!(diagnostics = ?arg_diags, "interpolation warnings in call_tool arguments");
            }
            (
                "tools/call",
                Some(json!({"name": tool, "arguments": tool_args})),
            )
        }
        "list_resources" => ("resources/list", None),
        "read_resource" => {
            let template = action["uri"].as_str().unwrap_or_default();
            let (target, target_diags) = interpolate_template(template, extractors, None, None);
            if !target_diags.is_empty() {
                tracing::warn!(diagnostics = ?target_diags, "interpolation warnings in uri");
            }
            ("resources/read", Some(json!({"uri": target})))
        }
        "list_prompts" => ("prompts/list", None),
        "get_prompt" => {
            let template = action["name"].as_str().unwrap_or_default();
            let (prompt, prompt_diags) = interpolate_template(template, extractors, None, None);
            if !prompt_diags.is_empty() {
                tracing::warn!(diagnostics = ?prompt_diags, "interpolation warnings in name");
            }
            let raw_args = action
                .get("arguments")
                .cloned()
                .unwrap_or_else(|| json!({}));
            let (prompt_args, arg_diags) = interpolate_value(&raw_args, extractors, None, None);
            if !arg_diags.is_empty() {
                tracing::warn!(diagnostics = ?arg_diags, "interpolation warnings in interpolated_args");
            }
            (
                "prompts/get",
                Some(json!({"name": prompt, "arguments": prompt_args})),
            )
        }
        "subscribe_resource" => {
            let template = action["uri"].as_str().unwrap_or_default();
            let (target, target_diags) = interpolate_template(template, extractors, None, None);
            if !target_diags.is_empty() {
                tracing::warn!(diagnostics = ?target_diags, "interpolation warnings in uri");
            }
            ("resources/subscribe", Some(json!({"uri": target})))
        }
        unknown => {
            tracing::warn!(action_type = %unknown, "unknown MCP client action type, skipping");
            return Ok(());
        }
    };

    // The response itself is not needed here; events were already emitted by the call.
    self.send_and_await(method, params, event_tx).await?;
    Ok(())
}
/// Starts the background machinery: the message multiplexer and the server-request handler.
///
/// Hands the transport reader to the multiplexer, wires up the channels for server
/// requests, notifications and handler events, and stores the resulting handles and
/// receivers on `self`.
///
/// # Errors
///
/// Returns `EngineError::Driver` if the reader has already been taken, i.e. `bootstrap`
/// was called before.
pub(super) fn bootstrap(
    &mut self,
    extractors: watch::Receiver<HashMap<String, String>>,
) -> Result<(), EngineError> {
    let Some(transport_reader) = self.reader.take() else {
        return Err(EngineError::Driver(
            "MCP client reader already consumed (bootstrap called twice?)".into(),
        ));
    };

    let (handler_event_tx, handler_event_rx) = mpsc::channel(HANDLER_EVENT_BUFFER_SIZE);
    let (notification_tx, notification_rx) = mpsc::channel(NOTIFICATION_BUFFER_SIZE);
    let (server_request_tx, server_request_rx) = mpsc::channel(SERVER_REQUEST_BUFFER_SIZE);

    // The multiplexer owns the reader and routes incoming traffic to the channels above.
    let multiplexer = MessageMultiplexer::spawn(
        transport_reader,
        Arc::clone(&self.writer),
        Arc::clone(&self.pending),
        server_request_tx,
        notification_tx,
        Arc::new(std::sync::Mutex::new(HashMap::new())),
        Arc::new(std::sync::Mutex::new(None)),
        self.transport_cancel.clone(),
    );

    // The handler consumes server-initiated requests and reports what it did as events.
    let handler_task = tokio::spawn(server_request_handler(
        server_request_rx,
        Arc::clone(&self.writer),
        Arc::clone(&self.handler_state),
        extractors,
        handler_event_tx,
        self.raw_synthesize,
        self.transport_cancel.clone(),
    ));

    self.handler_handle = Some(handler_task);
    self.handler_event_rx = Some(handler_event_rx);
    self.notification_rx = Some(notification_rx);
    self.mux = Some(multiplexer);
    Ok(())
}
}
/// Tears down everything the driver started: signals cancellation, aborts the
/// background tasks and asks the child process (if any) to terminate.
impl Drop for McpClientDriver {
    fn drop(&mut self) {
        // Signal cancellation first; the aborts below do not depend on tasks observing it.
        self.transport_cancel.cancel();

        if let Some(handler_task) = self.handler_handle.take() {
            handler_task.abort();
        }
        if let Some(multiplexer) = self.mux.take() {
            multiplexer.abort();
        }
        if let Some(process) = &mut self.child {
            // Best-effort: the error from `start_kill` is deliberately ignored,
            // since `drop` has no way to report it.
            let _ = process.start_kill();
        }
    }
}
#[async_trait]
impl PhaseDriver for McpClientDriver {
/// Runs one phase: executes its actions, then relays server-side traffic until idle.
///
/// On first use the driver is bootstrapped and the initialize handshake is performed.
/// The phase `state` is published to the server-request handler, every entry of
/// `state["actions"]` is normalized and executed in order, and afterwards handler
/// events and notifications are forwarded to `event_tx` until one of the following:
/// `cancel` fires, the handler task ends, either internal channel closes, or no event
/// arrives for the idle timeout (`POST_ACTION_IDLE_TIMEOUT` when the phase had
/// actions, otherwise `phase_timeout`).
///
/// # Errors
///
/// Propagates errors from `bootstrap`, `initialize` and action execution.
async fn drive_phase(
    &mut self,
    _phase_index: usize,
    state: &Value,
    extractors: watch::Receiver<HashMap<String, String>>,
    event_tx: mpsc::Sender<ProtocolEvent>,
    cancel: CancellationToken,
) -> Result<DriveResult, EngineError> {
    if self.mux.is_none() {
        self.bootstrap(extractors.clone())?;
    }
    if !self.initialized {
        self.initialize(state, &event_tx).await?;
    }
    {
        // Scoped so the write lock is released before any action runs.
        let mut hs = self.handler_state.write().await;
        hs.state = state.clone();
    }

    // Snapshot taken once: all actions of this phase see the same extractor values.
    let current_extractors = extractors.borrow().clone();
    let actions = state.get("actions").and_then(Value::as_array);
    let has_actions = actions.is_some_and(|list| !list.is_empty());
    if let Some(actions) = actions {
        for action_value in actions {
            // Flush anything buffered so far so events stay roughly in arrival order.
            self.forward_pending_events(&event_tx);
            let normalized = normalize_action(action_value);
            self.execute_action(&normalized, &current_extractors, &event_tx)
                .await?;
        }
    }

    let idle_timeout = if has_actions {
        POST_ACTION_IDLE_TIMEOUT
    } else {
        self.phase_timeout
    };

    loop {
        tokio::select! {
            // Branches are polled in declaration order, so cancellation wins ties.
            biased;
            () = cancel.cancelled() => {
                break;
            }
            result = async {
                if let Some(ref mut handle) = self.handler_handle {
                    handle.await
                } else {
                    std::future::pending().await
                }
            } => {
                if let Err(join_err) = result {
                    tracing::error!(
                        error = %join_err,
                        "server request handler task panicked"
                    );
                }
                // A finished JoinHandle must not be polled again.
                self.handler_handle = None;
                break;
            }
            evt = async {
                if let Some(ref mut rx) = self.handler_event_rx {
                    rx.recv().await
                } else {
                    std::future::pending().await
                }
            } => {
                if let Some(evt) = evt {
                    let _ = event_tx.send(evt).await;
                } else {
                    break;
                }
            }
            notif = async {
                if let Some(ref mut rx) = self.notification_rx {
                    rx.recv().await
                } else {
                    std::future::pending().await
                }
            } => {
                if let Some(n) = notif {
                    let _ = event_tx.send(ProtocolEvent {
                        direction: Direction::Incoming,
                        method: n.method,
                        content: n.params.unwrap_or(Value::Null),
                    }).await;
                } else {
                    break;
                }
            }
            // Created anew on every iteration, so the timer restarts after each event.
            () = tokio::time::sleep(idle_timeout) => {
                break;
            }
        }
    }
    Ok(DriveResult::Complete)
}
/// Phase-advance hook; this driver does nothing here and always returns `Ok(())`.
async fn on_phase_advanced(&mut self, _from: usize, _to: usize) -> Result<(), EngineError> {
Ok(())
}
}
/// Derives the client capabilities object advertised during `initialize` from phase state.
///
/// A capability is advertised only when the corresponding key is present in `state`
/// (its value is not inspected):
///
/// * `sampling_responses` → `"sampling": {}`
/// * `roots` → `"roots": {"listChanged": false}`
/// * `elicitation_responses` → `"elicitation": {}`
///
/// Returns an empty JSON object when none of the keys are present.
pub(super) fn build_client_capabilities(state: &Value) -> Value {
    // (state key that enables it, capability name, capability descriptor)
    let advertised = [
        ("sampling_responses", "sampling", json!({})),
        ("roots", "roots", json!({"listChanged": false})),
        ("elicitation_responses", "elicitation", json!({})),
    ];

    let mut capabilities = json!({});
    for (state_key, capability, descriptor) in advertised {
        if state.get(state_key).is_some() {
            capabilities[capability] = descriptor;
        }
    }
    capabilities
}
/// Builds an [`McpClientDriver`] on top of a stdio or HTTP transport.
///
/// When `command` is given, the server is spawned as a child process with `args` and
/// spoken to over stdio; `endpoint` is ignored in that case. Otherwise `endpoint` is
/// used to create an HTTP transport with the supplied `headers`. The returned driver is
/// not yet bootstrapped or initialized.
///
/// # Errors
///
/// Returns `EngineError::Driver` when neither `command` nor `endpoint` is provided, and
/// propagates any error from creating the chosen transport.
pub fn create_mcp_client_driver(
    command: Option<&str>,
    args: &[String],
    endpoint: Option<&str>,
    headers: &[(String, String)],
    raw_synthesize: bool,
) -> Result<McpClientDriver, EngineError> {
    // Box each half explicitly so both branches produce the same trait-object types.
    let (reader, writer, child, child_stderr) = if let Some(cmd) = command {
        let (stdio_reader, stdio_writer, mut process) = spawn_stdio_transport(cmd, args)?;
        // Detach stderr from the child so it can be read independently for diagnostics.
        let stderr_pipe = process.stderr.take();
        let reader: Box<dyn McpClientTransportReader> = Box::new(stdio_reader);
        let writer: Box<dyn McpClientTransportWriter> = Box::new(stdio_writer);
        (reader, writer, Some(process), stderr_pipe)
    } else if let Some(url) = endpoint {
        let (http_reader, http_writer) = create_http_transport(url, headers)?;
        let reader: Box<dyn McpClientTransportReader> = Box::new(http_reader);
        let writer: Box<dyn McpClientTransportWriter> = Box::new(http_writer);
        (reader, writer, None, None)
    } else {
        return Err(EngineError::Driver(
            "mcp_client mode requires --mcp-client-command (stdio) \
             or --mcp-client-endpoint (HTTP)"
                .to_string(),
        ));
    };

    let driver = McpClientDriver {
        // Transport and process handles.
        writer: Arc::new(tokio::sync::Mutex::new(writer)),
        reader: Some(reader),
        transport_cancel: CancellationToken::new(),
        child,
        child_stderr,
        // Background machinery, populated later by `bootstrap`.
        mux: None,
        notification_rx: None,
        handler_event_rx: None,
        handler_handle: None,
        // Shared bookkeeping.
        pending: Arc::new(std::sync::Mutex::new(HashMap::new())),
        handler_state: Arc::new(tokio::sync::RwLock::new(HandlerState {
            state: Value::Null,
        })),
        // Handshake results, filled in by `initialize`.
        server_capabilities: None,
        negotiated_version: None,
        initialized: false,
        // Configuration and counters.
        request_timeout: DEFAULT_REQUEST_TIMEOUT,
        phase_timeout: DEFAULT_PHASE_TIMEOUT,
        next_request_id: 1,
        raw_synthesize,
    };
    Ok(driver)
}