use ralph_core::UrgentSteerStore;
use ralph_proto::{GuidanceTarget, RpcCommand, RpcEvent, RpcState, emit_event_line, parse_command};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
use tokio::sync::{mpsc, watch};
use tracing::{debug, info, warn};
/// Routes parsed RPC commands to the orchestration loop's control channels.
///
/// The closure type `F` is left unconstrained here; the `Fn() -> RpcState + Send + Sync`
/// requirement is stated on the `impl` blocks and functions that actually call it, so
/// signatures that merely name `RpcDispatcher<F>` do not have to repeat the bound.
pub struct RpcDispatcher<F> {
    /// Set to `true` when an `Abort` command is dispatched.
    pub interrupt_tx: watch::Sender<bool>,
    /// Receives guidance from `Guidance`/`FollowUp` (target `Next`) and `Steer`
    /// (target `Current`) commands.
    pub guidance_tx: mpsc::Sender<GuidanceMessage>,
    /// Outbound event channel: `GuidanceAck` events pushed during dispatch, plus the
    /// per-command responses forwarded by `run_stdin_reader`.
    pub response_tx: mpsc::Sender<RpcEvent>,
    /// Produces the state snapshot used by `GetState` and `GetIterations`.
    pub state_fn: Arc<F>,
    /// Whether the orchestration loop has started; set via `mark_loop_started` and
    /// consulted when rejecting `Prompt` commands.
    pub loop_started: Arc<std::sync::atomic::AtomicBool>,
    /// When set, `Steer` messages are also appended to an `UrgentSteerStore` at this
    /// path before being queued.
    pub urgent_steer_path: Option<PathBuf>,
}
/// A piece of guidance queued for the orchestration loop.
///
/// Built by [`RpcDispatcher`] when it handles `Guidance`, `FollowUp` and `Steer`
/// commands, and delivered over the dispatcher's `guidance_tx` channel.
#[derive(Clone, Debug)]
pub struct GuidanceMessage {
    /// Free-form guidance text exactly as supplied by the RPC client.
    pub message: String,
    /// Which iteration the guidance applies to (this file uses `Next` and `Current`).
    pub target: GuidanceTarget,
}
impl<F> RpcDispatcher<F>
where
F: Fn() -> RpcState + Send + Sync,
{
pub fn new(
interrupt_tx: watch::Sender<bool>,
guidance_tx: mpsc::Sender<GuidanceMessage>,
response_tx: mpsc::Sender<RpcEvent>,
urgent_steer_path: Option<PathBuf>,
state_fn: F,
) -> Self {
Self {
interrupt_tx,
guidance_tx,
response_tx,
state_fn: Arc::new(state_fn),
loop_started: Arc::new(std::sync::atomic::AtomicBool::new(false)),
urgent_steer_path,
}
}
pub fn mark_loop_started(&self) {
self.loop_started
.store(true, std::sync::atomic::Ordering::SeqCst);
}
fn loop_has_started(&self) -> bool {
self.loop_started.load(std::sync::atomic::Ordering::SeqCst)
}
async fn dispatch(&self, cmd: RpcCommand) -> RpcEvent {
let id = cmd.id().map(|s| s.to_string());
let cmd_type = cmd.command_type();
match cmd {
RpcCommand::Prompt { .. } => {
if self.loop_has_started() {
return RpcEvent::error_response(
cmd_type,
id,
"loop already running; use guidance or steer instead",
);
}
RpcEvent::error_response(
cmd_type,
id,
"prompt command is not supported after startup; pass -p/--prompt when launching",
)
}
RpcCommand::Guidance { message, .. } => {
let msg = GuidanceMessage {
message: message.clone(),
target: GuidanceTarget::Next,
};
match self.guidance_tx.send(msg).await {
Ok(()) => {
let _ = self
.response_tx
.send(RpcEvent::GuidanceAck {
message: message.clone(),
applies_to: GuidanceTarget::Next,
})
.await;
RpcEvent::success_response(cmd_type, id, None)
}
Err(_) => RpcEvent::error_response(cmd_type, id, "guidance channel closed"),
}
}
RpcCommand::Steer { message, .. } => {
if let Some(path) = &self.urgent_steer_path
&& let Err(err) =
UrgentSteerStore::new(path.clone()).append_message(message.clone())
{
return RpcEvent::error_response(
cmd_type,
id,
format!("failed to persist urgent steer: {err}"),
);
}
let msg = GuidanceMessage {
message: message.clone(),
target: GuidanceTarget::Current,
};
match self.guidance_tx.send(msg).await {
Ok(()) => {
let _ = self
.response_tx
.send(RpcEvent::GuidanceAck {
message: message.clone(),
applies_to: GuidanceTarget::Current,
})
.await;
RpcEvent::success_response(cmd_type, id, None)
}
Err(_) => RpcEvent::error_response(cmd_type, id, "guidance channel closed"),
}
}
RpcCommand::FollowUp { message, .. } => {
let msg = GuidanceMessage {
message: message.clone(),
target: GuidanceTarget::Next,
};
match self.guidance_tx.send(msg).await {
Ok(()) => {
let _ = self
.response_tx
.send(RpcEvent::GuidanceAck {
message: message.clone(),
applies_to: GuidanceTarget::Next,
})
.await;
RpcEvent::success_response(cmd_type, id, None)
}
Err(_) => RpcEvent::error_response(cmd_type, id, "guidance channel closed"),
}
}
RpcCommand::Abort { reason, .. } => {
debug!(reason = ?reason, "Received abort command");
match self.interrupt_tx.send(true) {
Ok(()) => RpcEvent::success_response(cmd_type, id, None),
Err(_) => RpcEvent::error_response(cmd_type, id, "interrupt channel closed"),
}
}
RpcCommand::GetState { .. } => {
let state = (self.state_fn)();
let data = serde_json::to_value(&state).ok();
RpcEvent::success_response(cmd_type, id, data)
}
RpcCommand::GetIterations {
include_content, ..
} => {
let state = (self.state_fn)();
let data = serde_json::json!({
"iteration": state.iteration,
"max_iterations": state.max_iterations,
"include_content": include_content,
});
RpcEvent::success_response(cmd_type, id, Some(data))
}
RpcCommand::SetHat { .. } => {
RpcEvent::error_response(cmd_type, id, "not yet implemented")
}
RpcCommand::ExtensionUiResponse { .. } => {
RpcEvent::error_response(cmd_type, id, "not yet implemented")
}
}
}
}
pub async fn run_stdin_reader<F, R>(dispatcher: RpcDispatcher<F>, reader: R)
where
F: Fn() -> RpcState + Send + Sync + 'static,
R: AsyncRead + Unpin,
{
let mut lines = BufReader::new(reader).lines();
loop {
match lines.next_line().await {
Ok(Some(line)) => {
let line = line.trim();
if line.is_empty() {
continue;
}
debug!(line = %line, "Received stdin command");
let response = match parse_command(line) {
Ok(cmd) => dispatcher.dispatch(cmd).await,
Err(parse_error) => {
warn!(error = %parse_error, line = %line, "Failed to parse command");
RpcEvent::error_response("parse", None, parse_error)
}
};
if dispatcher.response_tx.send(response).await.is_err() {
warn!("Response channel closed, stopping stdin reader");
break;
}
}
Ok(None) => {
info!("Stdin closed (EOF), stopping reader task");
break;
}
Err(e) => {
warn!(error = %e, "Error reading from stdin, stopping reader task");
break;
}
}
}
}
/// Writes every event received on `rx` to process stdout, one encoded line per event.
///
/// Events are encoded with `emit_event_line`, and stdout is flushed after each write.
/// The task stops when the channel yields `None` (all senders dropped and the queue
/// drained) or when a write fails; flush failures are logged and otherwise ignored.
pub async fn run_stdout_emitter(mut rx: mpsc::Receiver<RpcEvent>) {
    use std::io::Write;

    loop {
        let Some(event) = rx.recv().await else {
            break;
        };
        let encoded = emit_event_line(&event);

        // The stdout lock lives only for this iteration, so it is released before the
        // next `recv().await`.
        let mut out = std::io::stdout().lock();
        let write_failed = out.write_all(encoded.as_bytes()).is_err();
        if write_failed {
            warn!("Failed to write to stdout, stopping emitter");
            break;
        }
        if out.flush().is_err() {
            warn!("Failed to flush stdout");
        }
    }
    debug!("Stdout emitter task finished");
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RpcDispatcher::dispatch` and the stdin reader task, including the
    //! error paths (closed channels, prompt rejection) and the `GuidanceAck` side event.
    use super::*;
    use tokio::sync::mpsc;

    /// A fixed state snapshot returned by every dispatcher under test.
    fn default_state() -> RpcState {
        RpcState {
            iteration: 1,
            max_iterations: Some(10),
            hat: "builder".to_string(),
            hat_display: "🔨Builder".to_string(),
            backend: "claude".to_string(),
            completed: false,
            started_at: 1_700_000_000_000,
            iteration_started_at: Some(1_700_000_001_000),
            task_counts: ralph_proto::RpcTaskCounts::default(),
            active_task: None,
            total_cost_usd: 0.0,
        }
    }

    #[tokio::test]
    async fn test_abort_triggers_interrupt() {
        let (interrupt_tx, interrupt_rx) = watch::channel(false);
        let (guidance_tx, _guidance_rx) = mpsc::channel(10);
        let (response_tx, _response_rx) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::Abort {
            id: Some("abort-1".to_string()),
            reason: Some("test abort".to_string()),
        };
        let response = dispatcher.dispatch(cmd).await;
        assert!(*interrupt_rx.borrow());
        match response {
            RpcEvent::Response {
                command, success, ..
            } => {
                assert_eq!(command, "abort");
                assert!(success);
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_abort_reports_closed_interrupt_channel() {
        let (interrupt_tx, interrupt_rx) = watch::channel(false);
        // With every receiver gone, `watch::Sender::send` fails.
        drop(interrupt_rx);
        let (guidance_tx, _guidance_rx) = mpsc::channel(10);
        let (response_tx, _response_rx) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::Abort {
            id: None,
            reason: None,
        };
        let response = dispatcher.dispatch(cmd).await;
        match response {
            RpcEvent::Response { success, error, .. } => {
                assert!(!success);
                assert!(error.unwrap().contains("interrupt channel closed"));
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_guidance_routes_to_channel() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, mut guidance_rx) = mpsc::channel(10);
        let (response_tx, _) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::Guidance {
            id: None,
            message: "focus on tests".to_string(),
        };
        let _response = dispatcher.dispatch(cmd).await;
        let msg = guidance_rx.recv().await.expect("should receive guidance");
        assert_eq!(msg.message, "focus on tests");
        assert_eq!(msg.target, GuidanceTarget::Next);
    }

    #[tokio::test]
    async fn test_guidance_emits_ack_event() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _guidance_rx) = mpsc::channel(10);
        let (response_tx, mut response_rx) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::FollowUp {
            id: None,
            message: "follow up later".to_string(),
        };
        let response = dispatcher.dispatch(cmd).await;
        let ack = response_rx.recv().await.expect("should receive ack");
        match ack {
            RpcEvent::GuidanceAck {
                message,
                applies_to,
            } => {
                assert_eq!(message, "follow up later");
                assert_eq!(applies_to, GuidanceTarget::Next);
            }
            _ => panic!("Expected GuidanceAck event"),
        }
        match response {
            RpcEvent::Response { success, .. } => assert!(success),
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_guidance_channel_closed_returns_error() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, guidance_rx) = mpsc::channel(10);
        // Closing the receiver makes every guidance send fail.
        drop(guidance_rx);
        let (response_tx, _) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::Guidance {
            id: Some("guidance-1".to_string()),
            message: "unreachable".to_string(),
        };
        let response = dispatcher.dispatch(cmd).await;
        match response {
            RpcEvent::Response { success, error, .. } => {
                assert!(!success);
                assert!(error.unwrap().contains("guidance channel closed"));
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_get_state_returns_snapshot() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _) = mpsc::channel(10);
        let (response_tx, _) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::GetState {
            id: Some("state-1".to_string()),
        };
        let response = dispatcher.dispatch(cmd).await;
        match response {
            RpcEvent::Response {
                command,
                id,
                success,
                data,
                ..
            } => {
                assert_eq!(command, "get_state");
                assert_eq!(id, Some("state-1".to_string()));
                assert!(success);
                let data = data.expect("should have data");
                assert_eq!(data["iteration"], 1);
                assert_eq!(data["hat"], "builder");
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_steer_vs_follow_up_semantics() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, mut guidance_rx) = mpsc::channel(10);
        let (response_tx, _) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let steer_cmd = RpcCommand::Steer {
            id: None,
            message: "steer now".to_string(),
        };
        let _response = dispatcher.dispatch(steer_cmd).await;
        let steer_msg = guidance_rx.recv().await.expect("steer message");
        assert_eq!(steer_msg.target, GuidanceTarget::Current);
        let follow_up_cmd = RpcCommand::FollowUp {
            id: None,
            message: "follow up later".to_string(),
        };
        let _response = dispatcher.dispatch(follow_up_cmd).await;
        let follow_up_msg = guidance_rx.recv().await.expect("follow_up message");
        assert_eq!(follow_up_msg.target, GuidanceTarget::Next);
    }

    #[tokio::test]
    async fn test_steer_persists_urgent_marker() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _guidance_rx) = mpsc::channel(10);
        let (response_tx, _) = mpsc::channel(10);
        let temp_dir = tempfile::tempdir().expect("temp dir");
        let urgent_path = temp_dir.path().join("urgent-steer.json");
        let dispatcher = RpcDispatcher::new(
            interrupt_tx,
            guidance_tx,
            response_tx,
            Some(urgent_path.clone()),
            || default_state(),
        );
        let steer_cmd = RpcCommand::Steer {
            id: None,
            message: "steer now".to_string(),
        };
        let _response = dispatcher.dispatch(steer_cmd).await;
        let record = UrgentSteerStore::new(urgent_path)
            .load()
            .expect("load marker")
            .expect("record");
        assert_eq!(record.messages, vec!["steer now"]);
    }

    #[tokio::test]
    async fn test_prompt_rejected_after_loop_started() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _) = mpsc::channel(10);
        let (response_tx, _) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        dispatcher.mark_loop_started();
        let cmd = RpcCommand::Prompt {
            id: Some("prompt-1".to_string()),
            prompt: "do something".to_string(),
            backend: None,
            max_iterations: None,
        };
        let response = dispatcher.dispatch(cmd).await;
        match response {
            RpcEvent::Response { success, error, .. } => {
                assert!(!success);
                assert!(error.unwrap().contains("loop already running"));
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_prompt_rejected_before_loop_started() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _) = mpsc::channel(10);
        let (response_tx, _) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let cmd = RpcCommand::Prompt {
            id: Some("prompt-0".to_string()),
            prompt: "do something".to_string(),
            backend: None,
            max_iterations: None,
        };
        let response = dispatcher.dispatch(cmd).await;
        match response {
            RpcEvent::Response { success, error, .. } => {
                assert!(!success);
                assert!(error.unwrap().contains("pass -p/--prompt"));
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_stdin_reader_parses_json_commands() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _) = mpsc::channel(10);
        let (response_tx, mut response_rx) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let input = r#"{"type": "get_state", "id": "test-1"}"#;
        let reader = std::io::Cursor::new(input.as_bytes().to_vec());
        tokio::spawn(async move {
            run_stdin_reader(dispatcher, reader).await;
        });
        let response = tokio::time::timeout(std::time::Duration::from_secs(1), response_rx.recv())
            .await
            .expect("timeout")
            .expect("should receive response");
        match response {
            RpcEvent::Response {
                command,
                id,
                success,
                ..
            } => {
                assert_eq!(command, "get_state");
                assert_eq!(id, Some("test-1".to_string()));
                assert!(success);
            }
            _ => panic!("Expected Response event"),
        }
    }

    #[tokio::test]
    async fn test_parse_error_returns_error_response() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _) = mpsc::channel(10);
        let (response_tx, mut response_rx) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let input = r#"{"type": "nonexistent_command"}"#;
        let reader = std::io::Cursor::new(input.as_bytes().to_vec());
        tokio::spawn(async move {
            run_stdin_reader(dispatcher, reader).await;
        });
        let response = tokio::time::timeout(std::time::Duration::from_secs(1), response_rx.recv())
            .await
            .expect("timeout")
            .expect("should receive response");
        match response {
            RpcEvent::Response {
                command,
                success,
                error,
                ..
            } => {
                assert_eq!(command, "parse");
                assert!(!success);
                assert!(error.is_some());
            }
            _ => panic!("Expected error Response event"),
        }
    }

    #[tokio::test]
    async fn test_stdin_eof_exits_gracefully() {
        let (interrupt_tx, _) = watch::channel(false);
        let (guidance_tx, _) = mpsc::channel(10);
        let (response_tx, _response_rx) = mpsc::channel(10);
        let dispatcher = RpcDispatcher::new(interrupt_tx, guidance_tx, response_tx, None, || {
            default_state()
        });
        let reader = std::io::Cursor::new(Vec::<u8>::new());
        run_stdin_reader(dispatcher, reader).await;
    }
}