use std::io::BufRead;
use std::process::{Child, Command, Stdio};
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex};
/// Selects which MCP server configuration is handed to the `claude` CLI via `--mcp-config`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum McpConfig {
    /// Point the CLI at the crate's built-in `nightshade` HTTP MCP server.
    ///
    /// Only takes effect when the crate is compiled with the `mcp` feature; without it no
    /// `--mcp-config` flag is passed. When active, default `--allowedTools` /
    /// `--disallowedTools` values are also supplied for any tool list left as `None`.
    #[default]
    Auto,
    /// Pass this JSON string verbatim as the `--mcp-config` value.
    Custom(String),
    /// Do not pass `--mcp-config` at all.
    None,
}
/// Options applied to every `claude` CLI invocation made by [`spawn_cli_worker`].
#[derive(Default)]
pub struct ClaudeConfig {
    /// When set, passed as `--append-system-prompt <text>`.
    pub system_prompt: Option<String>,
    /// Each entry is passed as its own `--allowedTools <tool>` pair. When `None` and the
    /// built-in MCP config is in effect, `--allowedTools mcp__nightshade__*` is passed instead.
    pub allowed_tools: Option<Vec<String>>,
    /// Each entry is passed as its own `--disallowedTools <tool>` pair. When `None` and the
    /// built-in MCP config is in effect, `--disallowedTools Bash,Edit,Write,NotebookEdit,Task`
    /// is passed instead.
    pub disallowed_tools: Option<Vec<String>>,
    /// Which MCP server configuration (if any) to pass via `--mcp-config`.
    pub mcp_config: McpConfig,
    /// Extra arguments appended verbatim after the generated flags and before
    /// `--resume` / `--model`.
    pub custom_args: Vec<String>,
    /// Environment variables set on the child process (applied after `CLAUDECODE` is removed).
    pub env: Vec<(String, String)>,
}
/// Requests sent to the worker thread started by [`spawn_cli_worker`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CliCommand {
    /// Launch a new `claude` run, killing any run that is still in progress.
    StartQuery {
        /// Prompt text, passed to the CLI right after `-p`.
        prompt: String,
        /// Session to continue; passed as `--resume <id>` when set.
        session_id: Option<String>,
        /// Model name; passed as `--model <name>` when set.
        model: Option<String>,
    },
    /// Kill the running `claude` process (if any); the worker then emits
    /// [`CliEvent::TurnComplete`].
    Cancel,
}
/// Events reported by the worker while a `claude` run progresses.
#[derive(Debug, Clone, PartialEq)]
pub enum CliEvent {
    /// The CLI reported its session id (from a `system` message).
    SessionStarted { session_id: String },
    /// A chunk of assistant text (`text_delta`).
    TextDelta { text: String },
    /// A chunk of model thinking (`thinking_delta`).
    ThinkingDelta { text: String },
    /// A `tool_use` content block started.
    ToolUseStarted { tool_name: String, tool_id: String },
    /// A fragment of a tool call's input JSON (`input_json_delta`). Presumably fragments
    /// must be concatenated to obtain the complete input.
    ToolUseInputDelta {
        tool_id: String,
        partial_json: String,
    },
    /// The content block of the tool call `tool_id` ended.
    ToolUseFinished { tool_id: String },
    /// A message finished (`message_stop`), or the run was cancelled.
    TurnComplete { session_id: String },
    /// The CLI emitted its final `result` message.
    Complete {
        session_id: String,
        total_cost_usd: Option<f64>,
        num_turns: u32,
    },
    /// Something went wrong, e.g. the CLI could not be spawned.
    Error { message: String },
}
/// Creates the channel pairs used to talk to [`spawn_cli_worker`].
///
/// Returns `(command_sender, command_receiver, event_sender, event_receiver)`. Commands flow
/// to the worker and events flow back, so the worker takes the command receiver and the
/// event sender while the caller keeps the other two ends.
pub fn create_cli_channels() -> (
    Sender<CliCommand>,
    Receiver<CliCommand>,
    Sender<CliEvent>,
    Receiver<CliEvent>,
) {
    use std::sync::mpsc::channel;

    let (to_worker, from_caller) = channel::<CliCommand>();
    let (to_caller, from_worker) = channel::<CliEvent>();
    (to_worker, from_caller, to_caller, from_worker)
}
/// Builds the `--mcp-config` JSON that registers the crate's `nightshade` MCP server, reached
/// over HTTP at `crate::mcp::MCP_DEFAULT_HOST:MCP_DEFAULT_PORT` under the `/mcp` path.
#[cfg(feature = "mcp")]
fn auto_mcp_config() -> String {
    let url = format!(
        "http://{}:{}/mcp",
        crate::mcp::MCP_DEFAULT_HOST,
        crate::mcp::MCP_DEFAULT_PORT
    );
    let nightshade_server = serde_json::json!({ "type": "http", "url": url });
    let config = serde_json::json!({ "mcpServers": { "nightshade": nightshade_server } });
    config.to_string()
}
/// Returns the JSON to pass as `--mcp-config`, or `None` when the flag should be omitted.
///
/// `Auto` produces the built-in server config only when compiled with the `mcp` feature;
/// `Custom` is returned verbatim.
fn resolve_mcp_config(config: &McpConfig) -> Option<String> {
    match config {
        McpConfig::None => Option::None,
        McpConfig::Custom(json) => Some(json.to_owned()),
        #[cfg(feature = "mcp")]
        McpConfig::Auto => Some(auto_mcp_config()),
        #[cfg(not(feature = "mcp"))]
        McpConfig::Auto => Option::None,
    }
}
fn is_auto_mcp(config: &McpConfig) -> bool {
#[cfg(feature = "mcp")]
{
matches!(config, McpConfig::Auto)
}
#[cfg(not(feature = "mcp"))]
{
let _ = config;
false
}
}
pub fn spawn_cli_worker(
command_receiver: Receiver<CliCommand>,
event_sender: Sender<CliEvent>,
config: ClaudeConfig,
) {
std::thread::spawn(move || {
let mut current_child: Option<Child> = None;
let shared_session_id: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new()));
loop {
match command_receiver.recv() {
Ok(CliCommand::StartQuery {
prompt,
session_id,
model,
}) => {
if let Some(mut child) = current_child.take() {
let _ = child.kill();
let _ = child.wait();
}
let mut args = vec![
"-p".to_string(),
prompt,
"--output-format".to_string(),
"stream-json".to_string(),
"--verbose".to_string(),
"--include-partial-messages".to_string(),
];
if let Some(ref system_prompt) = config.system_prompt {
args.push("--append-system-prompt".to_string());
args.push(system_prompt.clone());
}
if let Some(ref allowed) = config.allowed_tools {
for tool in allowed {
args.push("--allowedTools".to_string());
args.push(tool.clone());
}
}
if let Some(ref disallowed) = config.disallowed_tools {
for tool in disallowed {
args.push("--disallowedTools".to_string());
args.push(tool.clone());
}
}
if let Some(mcp_json) = resolve_mcp_config(&config.mcp_config) {
args.push("--mcp-config".to_string());
args.push(mcp_json);
}
if is_auto_mcp(&config.mcp_config) {
if config.allowed_tools.is_none() {
args.push("--allowedTools".to_string());
args.push("mcp__nightshade__*".to_string());
}
if config.disallowed_tools.is_none() {
args.push("--disallowedTools".to_string());
args.push("Bash,Edit,Write,NotebookEdit,Task".to_string());
}
}
for arg in &config.custom_args {
args.push(arg.clone());
}
if let Some(session) = session_id {
args.push("--resume".to_string());
args.push(session);
}
if let Some(model_name) = model {
args.push("--model".to_string());
args.push(model_name);
}
let mut command = Command::new("claude");
command
.args(&args)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.env_remove("CLAUDECODE");
for (key, value) in &config.env {
command.env(key, value);
}
#[cfg(target_os = "windows")]
{
use std::os::windows::process::CommandExt;
command.creation_flags(0x08000000);
}
match command.spawn() {
Ok(mut child) => {
let stdout = child.stdout.take().expect("stdout was piped");
let stderr = child.stderr.take().expect("stderr was piped");
current_child = Some(child);
std::thread::spawn(move || {
let reader = std::io::BufReader::new(stderr);
for _ in reader.lines() {}
});
let event_sender_clone = event_sender.clone();
let session_id_writer = shared_session_id.clone();
*shared_session_id.lock().unwrap() = String::new();
std::thread::spawn(move || {
let reader = std::io::BufReader::new(stdout);
let mut session_id = String::new();
let mut current_tool_id = String::new();
for line_result in reader.lines() {
let line = match line_result {
Ok(line) => line,
Err(_) => break,
};
if line.trim().is_empty() {
continue;
}
let json_value: serde_json::Value =
match serde_json::from_str(&line) {
Ok(value) => value,
Err(_) => continue,
};
let events = parse_stream_json_line(
&json_value,
&mut session_id,
&mut current_tool_id,
);
for event in &events {
if let CliEvent::SessionStarted { session_id: sid } = event
{
*session_id_writer.lock().unwrap() = sid.clone();
}
}
for event in events {
if event_sender_clone.send(event).is_err() {
return;
}
}
}
});
}
Err(error) => {
let _ = event_sender.send(CliEvent::Error {
message: format!("Failed to spawn claude CLI: {error}"),
});
}
}
}
Ok(CliCommand::Cancel) => {
if let Some(mut child) = current_child.take() {
let _ = child.kill();
let _ = child.wait();
}
let session_id = shared_session_id.lock().unwrap().clone();
let _ = event_sender.send(CliEvent::TurnComplete { session_id });
}
Err(_) => {
if let Some(mut child) = current_child.take() {
let _ = child.kill();
let _ = child.wait();
}
break;
}
}
}
});
}
/// Translates one decoded `stream-json` line from the `claude` CLI into zero or more
/// [`CliEvent`]s.
///
/// Recognised top-level `type`s:
/// * `"system"` with a string `session_id`: stores it in `session_id` and emits
///   `SessionStarted`.
/// * `"stream_event"`: dispatches on the nested `event.type`:
///   - `content_block_start` of a `tool_use` block emits `ToolUseStarted`. The block id is
///     remembered in `current_tool_id`; a missing name becomes `"unknown"` and a missing id
///     becomes `""`.
///   - `content_block_delta` emits `TextDelta` for `text_delta`, `ToolUseInputDelta` for
///     `input_json_delta` (attributed to `current_tool_id`), and `ThinkingDelta` for
///     `thinking_delta`.
///   - `content_block_stop`, while a tool id is remembered, emits `ToolUseFinished` and
///     clears `current_tool_id`.
///   - `message_stop` emits `TurnComplete` with the current `session_id`.
/// * `"result"` emits `Complete`. It carries the current `session_id`, the optional
///   `total_cost_usd`, and `num_turns`; a missing count becomes 0 and values above
///   `u32::MAX` saturate.
///
/// Anything else, including missing or non-string fields, produces no events.
///
/// `session_id` and `current_tool_id` carry parser state between successive lines of one
/// stream.
pub fn parse_stream_json_line(
    value: &serde_json::Value,
    session_id: &mut String,
    current_tool_id: &mut String,
) -> Vec<CliEvent> {
    /// Returns `object[key]` as a `&str`, or `None` when the key is missing or not a string.
    fn str_field<'a>(object: &'a serde_json::Value, key: &str) -> Option<&'a str> {
        object.get(key).and_then(serde_json::Value::as_str)
    }

    let mut events = Vec::new();
    match str_field(value, "type").unwrap_or("") {
        "system" => {
            if let Some(sid) = str_field(value, "session_id") {
                *session_id = sid.to_string();
                events.push(CliEvent::SessionStarted {
                    session_id: sid.to_string(),
                });
            }
        }
        "stream_event" => {
            if let Some(event) = value.get("event") {
                match str_field(event, "type").unwrap_or("") {
                    "content_block_start" => {
                        let tool_block = event
                            .get("content_block")
                            .filter(|block| str_field(block, "type") == Some("tool_use"));
                        if let Some(block) = tool_block {
                            let tool_name =
                                str_field(block, "name").unwrap_or("unknown").to_string();
                            let tool_id = str_field(block, "id").unwrap_or("").to_string();
                            current_tool_id.clone_from(&tool_id);
                            events.push(CliEvent::ToolUseStarted { tool_name, tool_id });
                        }
                    }
                    "content_block_delta" => {
                        if let Some(delta) = event.get("delta") {
                            match str_field(delta, "type").unwrap_or("") {
                                "text_delta" => {
                                    if let Some(text) = str_field(delta, "text") {
                                        events.push(CliEvent::TextDelta {
                                            text: text.to_string(),
                                        });
                                    }
                                }
                                "input_json_delta" => {
                                    if let Some(partial) = str_field(delta, "partial_json") {
                                        events.push(CliEvent::ToolUseInputDelta {
                                            tool_id: current_tool_id.clone(),
                                            partial_json: partial.to_string(),
                                        });
                                    }
                                }
                                "thinking_delta" => {
                                    if let Some(text) = str_field(delta, "thinking") {
                                        events.push(CliEvent::ThinkingDelta {
                                            text: text.to_string(),
                                        });
                                    }
                                }
                                _ => {}
                            }
                        }
                    }
                    "content_block_stop" if !current_tool_id.is_empty() => {
                        // `take` hands the id over and leaves `current_tool_id` empty.
                        events.push(CliEvent::ToolUseFinished {
                            tool_id: std::mem::take(current_tool_id),
                        });
                    }
                    "message_stop" => {
                        events.push(CliEvent::TurnComplete {
                            session_id: session_id.clone(),
                        });
                    }
                    _ => {}
                }
            }
        }
        "result" => {
            // Saturate instead of silently truncating with `as` if the count exceeds u32.
            let num_turns = value
                .get("num_turns")
                .and_then(serde_json::Value::as_u64)
                .map_or(0, |turns| u32::try_from(turns).unwrap_or(u32::MAX));
            events.push(CliEvent::Complete {
                session_id: session_id.clone(),
                total_cost_usd: value
                    .get("total_cost_usd")
                    .and_then(serde_json::Value::as_f64),
                num_turns,
            });
        }
        _ => {}
    }
    events
}