use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use async_stream::stream;
use futures::StreamExt;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::{Value, json};
use tokio::sync::{Mutex, RwLock};
use tokio_stream::Stream;
use crate::conversation::Conversation;
use crate::error::Error;
use crate::hooks::{Hooks, PostToolUseInput, PreToolUseInput, StopInput, UserPromptSubmitInput};
use crate::mcp_server::McpServer;
use crate::options::Options;
use crate::proto::control::{HookCallbackRequest, Request, ResponseEnvelope};
use crate::proto::{
ContentBlock, Incoming, Message, OutgoingUserMessage, RequestEnvelope, UserContent,
};
use crate::response::{RateLimitResponse, Response, Responses};
use crate::transport::Transport;
/// Identifies which registered hook a callback id (`hook_N`) refers to.
///
/// Each variant stores the hook's position within its own event kind (for
/// example `PostToolUse(1)` is the second post-tool-use hook), not the global
/// `hook_N` number that is sent over the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HookCallbackEntry {
    /// Index into the configured pre-tool-use hooks.
    PreToolUse(usize),
    /// Index into the configured post-tool-use hooks.
    PostToolUse(usize),
    /// Index into the configured user-prompt-submit hooks.
    UserPromptSubmit(usize),
    /// Index into the configured stop hooks.
    Stop(usize),
}
/// Protocol client over a [`Transport`].
///
/// Besides sending and receiving messages, it keeps the SDK-side state needed
/// to answer control requests that arrive on the transport: registered hooks
/// and in-process MCP servers.
pub struct Client {
    /// The transport; every send and receive goes through this lock.
    transport: Mutex<Transport>,

    /// Session id taken from the system init message, once one has been seen.
    session_id: RwLock<Option<String>>,

    /// Tool-use ids that have already been answered via `respond_to_tool`.
    responded_tool_ids: Mutex<HashSet<String>>,

    /// In-process MCP servers, keyed by the server name advertised at init.
    mcp_servers: HashMap<String, Arc<McpServer>>,

    /// User-supplied hooks, if any were configured.
    hooks: Option<Hooks>,

    /// Maps wire callback ids (`hook_N`) to the hook they should invoke.
    hook_callbacks: HashMap<String, HookCallbackEntry>,

    /// Structured-output JSON schema string, when one was configured.
    json_schema: Option<String>,
}
impl Client {
pub async fn new(mut options: Options) -> Result<Self, Error> {
let transport_options = options.to_transport_options();
let transport = Transport::new(&transport_options).await?;
let mcp_servers = options.mcp_servers().clone();
let hooks = options.take_hooks();
let json_schema = options.json_schema().map(|s| s.to_owned());
let hook_callbacks = Self::build_hook_callbacks(&hooks);
let client = Self {
transport: Mutex::new(transport),
session_id: RwLock::new(None),
responded_tool_ids: Mutex::new(HashSet::new()),
mcp_servers,
hooks,
hook_callbacks,
json_schema,
};
client.initialize().await?;
Ok(client)
}
fn build_hook_callbacks(hooks: &Option<Hooks>) -> HashMap<String, HookCallbackEntry> {
let mut callbacks = HashMap::new();
let Some(hooks) = hooks else {
return callbacks;
};
let mut id = 0;
for (idx, _) in hooks.pre_tool_use_hooks().enumerate() {
callbacks.insert(format!("hook_{id}"), HookCallbackEntry::PreToolUse(idx));
id += 1;
}
for (idx, _) in hooks.post_tool_use_hooks().enumerate() {
callbacks.insert(format!("hook_{id}"), HookCallbackEntry::PostToolUse(idx));
id += 1;
}
for (idx, _) in hooks.user_prompt_submit_hooks().enumerate() {
callbacks.insert(
format!("hook_{id}"),
HookCallbackEntry::UserPromptSubmit(idx),
);
id += 1;
}
for (idx, _) in hooks.stop_hooks().enumerate() {
callbacks.insert(format!("hook_{id}"), HookCallbackEntry::Stop(idx));
id += 1;
}
callbacks
}
async fn initialize(&self) -> Result<(), Error> {
let mut init_request = crate::proto::control::InitializeRequest::new();
if let Some(hooks) = self.build_hooks_config() {
init_request = init_request.with_hooks(hooks);
}
let mcp_names = self.mcp_servers.keys().cloned().collect::<Vec<_>>();
if !mcp_names.is_empty() {
init_request = init_request.with_sdk_mcp_servers(mcp_names);
}
let request = crate::proto::Request::Initialize(init_request);
let envelope = RequestEnvelope::new(request);
self.transport.lock().await.send_request(&envelope).await?;
tracing::debug!("sent initialize control request, waiting for response");
loop {
let incoming = {
let mut transport = self.transport.lock().await;
transport.receive().await
};
match incoming {
Ok(Some(incoming)) => {
match &incoming {
Incoming::System(crate::proto::SystemMessage::HookResponse(hook))
if hook.outcome() == Some("error") =>
{
return Err(Error::ProtocolError(format!(
"hook failed during initialization: hook={} event={} exit_code={:?}",
hook.hook_name().unwrap_or("unknown"),
hook.hook_event().unwrap_or("unknown"),
hook.exit_code()
)));
}
Incoming::System(crate::proto::SystemMessage::Error(err)) => {
return Err(Error::ProtocolError(format!(
"system error during initialization: {}",
err.error()
)));
}
_ => {}
}
if let Some(ctrl) = incoming.as_control_request() {
let response = match ctrl.request() {
Request::McpMessage(mcp_req) => {
self.handle_mcp_message(
ctrl.request_id(),
mcp_req.server_name(),
mcp_req.message(),
)
.await
}
Request::HookCallback(hook_req) => {
self.handle_hook_callback(ctrl.request_id(), hook_req).await
}
_ => continue,
};
let mut transport = self.transport.lock().await;
if let Err(e) = transport.send_response(&response).await {
tracing::warn!(error = %e, "failed to send control response during initialization");
}
continue;
}
if let Some(response) = incoming.as_control_response() {
match response.response() {
crate::proto::Response::Success(success) => {
tracing::debug!(
request_id = %success.request_id(),
"received initialize response"
);
return Ok(());
}
crate::proto::Response::Error(err) => {
return Err(Error::ControlError {
request_id: err.request_id().to_owned(),
message: err.error().message().to_owned(),
});
}
}
}
tracing::debug!("initialization loop: skipping non-control message");
}
Ok(None) => {
return Err(Error::ProtocolError(
"stream ended during initialization".to_owned(),
));
}
Err(e) => return Err(e),
}
}
}
fn build_hooks_config(&self) -> Option<HashMap<String, Value>> {
let hooks = self.hooks.as_ref()?;
let mut result = HashMap::new();
if hooks.has_pre_tool_use_hooks() {
let entries = hooks
.pre_tool_use_hooks()
.enumerate()
.map(|(id, (pattern, _))| {
json!({"matcher": pattern, "hookCallbackIds": [format!("hook_{id}")]})
})
.collect::<Vec<_>>();
result.insert("PreToolUse".to_owned(), json!(entries));
}
if hooks.has_post_tool_use_hooks() {
let base_id = hooks.pre_tool_use_hooks().len();
let entries = hooks
.post_tool_use_hooks()
.enumerate()
.map(|(idx, (pattern, _))| {
json!({"matcher": pattern, "hookCallbackIds": [format!("hook_{}", base_id + idx)]})
})
.collect::<Vec<_>>();
result.insert("PostToolUse".to_owned(), json!(entries));
}
if hooks.has_user_prompt_submit_hooks() {
let base_id = hooks.pre_tool_use_hooks().len() + hooks.post_tool_use_hooks().len();
let ids = (0..hooks.user_prompt_submit_hooks().len())
.map(|i| format!("hook_{}", base_id + i))
.collect::<Vec<_>>();
result.insert(
"UserPromptSubmit".to_owned(),
json!([{ "hookCallbackIds": ids }]),
);
}
if hooks.has_stop_hooks() {
let base_id = hooks.pre_tool_use_hooks().len()
+ hooks.post_tool_use_hooks().len()
+ hooks.user_prompt_submit_hooks().len();
let ids = (0..hooks.stop_hooks().len())
.map(|i| format!("hook_{}", base_id + i))
.collect::<Vec<_>>();
result.insert("Stop".to_owned(), json!([{ "hookCallbackIds": ids }]));
}
Some(result)
}
pub async fn session_id(&self) -> Option<String> {
self.session_id.read().await.clone()
}
pub fn conversation(&self) -> Conversation<'_> {
Conversation::new(self)
}
pub async fn query(&self, prompt: &str) -> Result<(), Error> {
let msg = OutgoingUserMessage::text(prompt);
let json = serde_json::to_value(&msg)?;
self.transport.lock().await.send(&json).await
}
pub async fn send_message(&self, content: UserContent) -> Result<(), Error> {
let msg = OutgoingUserMessage::new(content);
let json = serde_json::to_value(&msg)?;
self.transport.lock().await.send(&json).await
}
pub async fn respond_to_tool(
&self,
tool_use_id: &str,
content: Value,
is_error: bool,
) -> Result<(), Error> {
let mut responded = self.responded_tool_ids.lock().await;
if responded.contains(tool_use_id) {
tracing::warn!(tool_use_id, "already responded to tool, skipping");
return Ok(());
}
let tool_result = ContentBlock::ToolResult(
crate::proto::content_block::ToolResult::new(tool_use_id)
.with_content(content)
.with_error(is_error),
);
let msg = OutgoingUserMessage::new(UserContent::Blocks(vec![tool_result]));
let json = serde_json::to_value(&msg)?;
self.transport.lock().await.send(&json).await?;
responded.insert(tool_use_id.to_owned());
Ok(())
}
pub async fn clear_tool_response_tracking(&self) {
self.responded_tool_ids.lock().await.clear();
}
pub fn receive(&self) -> impl Stream<Item = Result<Response, Error>> + '_ {
stream! {
loop {
let incoming = {
let mut transport = self.transport.lock().await;
transport.receive().await
};
match incoming {
Ok(Some(incoming)) => {
if let Some(ctrl) = incoming.as_control_request() {
let response = match ctrl.request() {
Request::McpMessage(mcp_req) => {
self.handle_mcp_message(
ctrl.request_id(),
mcp_req.server_name(),
mcp_req.message(),
)
.await
}
Request::HookCallback(hook_req) => {
self.handle_hook_callback(ctrl.request_id(), hook_req)
.await
}
_ => continue,
};
let mut transport = self.transport.lock().await;
if let Err(e) = transport.send_response(&response).await {
tracing::warn!(error = %e, "failed to send control response");
}
continue;
}
if let Incoming::RateLimitEvent(event) = incoming {
tracing::trace!(
status = %event.status(),
utilization = ?event.utilization(),
resets_at = ?event.resets_at(),
"rate limit event",
);
let response = RateLimitResponse::from(event);
if let Some(delay) = response.backoff_delay() {
tracing::warn!(delay_secs = delay.as_secs_f64(), "rate limited, backing off");
tokio::time::sleep(delay).await;
}
yield Ok(Response::RateLimit(response));
continue;
}
if let Some(msg) = incoming.to_message() {
if let Message::System(crate::proto::SystemMessage::Init(init)) = &msg
&& let Some(sid) = init.session_id()
{
*self.session_id.write().await = Some(sid.to_owned());
tracing::debug!(session_id = %sid, "session initialized");
}
for response in Response::from_message(&msg) {
let is_complete = matches!(response, Response::Complete(_));
yield Ok(response);
if is_complete {
return;
}
}
}
}
Ok(None) => {
tracing::info!("stream ended (EOF)");
return;
}
Err(e) => {
yield Err(e);
return;
}
}
}
}
}
async fn handle_mcp_message(
&self,
request_id: &str,
server_name: &str,
message: &Value,
) -> ResponseEnvelope {
tracing::debug!(server_name, "handling MCP message");
match self.mcp_servers.get(server_name) {
Some(server) => {
let mcp_response = server.handle_json_message(message).await;
let response_data = json!({ "mcp_response": mcp_response });
ResponseEnvelope::success(request_id, Some(response_data))
}
None => {
tracing::warn!(server_name, "MCP server not found");
let error_response = json!({
"mcp_response": {
"jsonrpc": "2.0",
"id": null,
"error": {
"code": -32601,
"message": format!("MCP server '{}' not found", server_name)
}
}
});
ResponseEnvelope::success(request_id, Some(error_response))
}
}
}
async fn handle_hook_callback(
&self,
request_id: &str,
hook_req: &HookCallbackRequest,
) -> ResponseEnvelope {
let callback_id = hook_req.callback_id();
let input = hook_req.input();
tracing::debug!(callback_id, "handling hook callback");
let Some(entry) = self.hook_callbacks.get(callback_id) else {
tracing::warn!(callback_id, "hook callback not found");
return ResponseEnvelope::success(request_id, Some(json!({})));
};
let Some(hooks) = &self.hooks else {
tracing::warn!("hooks not available");
return ResponseEnvelope::success(request_id, Some(json!({})));
};
let session_id = input["session_id"].as_str().unwrap_or_default();
let transcript_path = input["transcript_path"].as_str().unwrap_or_default();
let response_data = match entry {
HookCallbackEntry::PreToolUse(idx) => {
let tool_name = input["tool_name"].as_str().unwrap_or_default();
let tool_input = input["tool_input"].clone();
let hook_input =
PreToolUseInput::new(session_id, transcript_path, tool_name, tool_input.into());
if let Some((_, callback)) = hooks.get_pre_tool_use_hook(*idx) {
let output = callback(hook_input).await;
output.to_hook_response()
} else {
json!({})
}
}
HookCallbackEntry::PostToolUse(idx) => {
let tool_name = input["tool_name"].as_str().unwrap_or_default();
let tool_input = input["tool_input"].clone();
let tool_response = input["tool_response"].clone();
let hook_input = PostToolUseInput::new(
session_id,
transcript_path,
tool_name,
tool_input.into(),
tool_response,
);
if let Some((_, callback)) = hooks.get_post_tool_use_hook(*idx) {
let output = callback(hook_input).await;
output.to_hook_response()
} else {
json!({})
}
}
HookCallbackEntry::UserPromptSubmit(idx) => {
let prompt = input["prompt"].as_str().unwrap_or_default();
let hook_input = UserPromptSubmitInput::new(session_id, transcript_path, prompt);
if let Some(callback) = hooks.user_prompt_submit_hooks().nth(*idx) {
let output = callback(hook_input).await;
output.to_hook_response()
} else {
json!({})
}
}
HookCallbackEntry::Stop(idx) => {
let stop_hook_active = input["stop_hook_active"].as_bool().unwrap_or_default();
let hook_input = StopInput::new(session_id, transcript_path, stop_hook_active);
if let Some(callback) = hooks.stop_hooks().nth(*idx) {
let output = callback(hook_input).await;
output.to_hook_response()
} else {
json!({})
}
}
};
ResponseEnvelope::success(request_id, Some(response_data))
}
pub async fn receive_all(&self) -> Result<Vec<Response>, Error> {
let mut responses = Vec::new();
let mut stream = std::pin::pin!(self.receive());
while let Some(result) = stream.next().await {
responses.push(result?);
}
Ok(responses)
}
pub async fn query_once(&self, prompt: &str) -> Result<(String, Responses), Error> {
self.query(prompt).await?;
let responses = Responses::from(self.receive_all().await?);
let text = responses.text_content();
Ok((text, responses))
}
pub async fn query_once_as<T>(&self, prompt: &str) -> Result<(T, Responses), Error>
where
T: DeserializeOwned + JsonSchema,
{
let expected_schema = crate::util::schema_for_structured_output::<T>().to_string();
match &self.json_schema {
Some(configured) if configured == &expected_schema => {}
Some(configured) => {
return Err(Error::SchemaMismatch {
expected: expected_schema,
configured: configured.clone(),
});
}
None => {
return Err(Error::NoSchemaConfigured);
}
}
self.query(prompt).await?;
let responses = Responses::from(self.receive_all().await?);
let structured_output = responses
.completion()
.and_then(|c| c.structured_output())
.cloned()
.ok_or_else(|| Error::ProtocolError("no structured output in response".to_owned()))?;
let result = serde_json::from_value::<T>(structured_output)?;
Ok((result, responses))
}
pub async fn interrupt(&self) -> Result<(), Error> {
self.transport.lock().await.interrupt().await
}
pub async fn set_permission_mode(
&self,
mode: crate::proto::PermissionMode,
) -> Result<(), Error> {
let request = crate::proto::Request::SetPermissionMode(
crate::proto::control::SetPermissionModeRequest::new(mode),
);
let envelope = RequestEnvelope::new(request);
self.transport.lock().await.send_request(&envelope).await
}
pub async fn set_model(&self, model: &str) -> Result<(), Error> {
let request =
crate::proto::Request::SetModel(crate::proto::control::SetModelRequest::new(model));
let envelope = RequestEnvelope::new(request);
self.transport.lock().await.send_request(&envelope).await
}
pub async fn get_server_info(&self) -> Result<crate::proto::ServerInfo, Error> {
let request = crate::proto::Request::GetServerInfo;
let envelope = RequestEnvelope::new(request);
let mut transport = self.transport.lock().await;
transport.send_request(&envelope).await?;
loop {
match transport.receive().await? {
Some(Incoming::ControlResponse(resp)) => match resp.response() {
crate::proto::Response::Success(success) => {
if let Some(data) = success.response() {
let info =
serde_json::from_value::<crate::proto::ServerInfo>(data.clone())?;
return Ok(info);
}
return Err(Error::ProtocolError("empty response".to_owned()));
}
crate::proto::Response::Error(err) => {
return Err(Error::ControlError {
request_id: err.request_id().to_owned(),
message: err.error().message().to_owned(),
});
}
},
Some(_) => continue,
None => return Err(Error::ConnectionError("stream ended".to_owned())),
}
}
}
}