use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::Path;
use anyhow::{Result, bail};
use crossterm::event::Event;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc;
use crate::event::AppEvent;
use crate::session::runner::SessionRunner;
/// Model identifier string.
// NOTE(review): presumably the model used when a test does not name one — confirm against callers.
pub const DEFAULT_TEST_MODEL: &str = "claude-haiku-4-5-20251001";
/// A value that can be converted to, and rebuilt from, a serializable "recorded" form.
///
/// The VCR call methods in this file require it for both call arguments and
/// successful results.
pub trait Recordable: Sized {
/// Serializable stand-in for `Self`.
type Recorded: Serialize + DeserializeOwned;
/// Converts `self` into its recorded form; may fail.
fn to_recorded(&self) -> Result<Self::Recorded>;
/// Rebuilds a value from its recorded form; may fail.
fn from_recorded(recorded: Self::Recorded) -> Result<Self>;
}
impl<T: Serialize + DeserializeOwned> Recordable for T {
type Recorded = Value;
fn to_recorded(&self) -> Result<Value> {
Ok(serde_json::to_value(self)?)
}
fn from_recorded(v: Value) -> Result<Self> {
Ok(serde_json::from_value(v)?)
}
}
/// An error type that can be converted to, and rebuilt from, a serializable
/// "recorded" form.
///
/// Used by `VcrContext::call_typed_err` so that a typed error can be stored in a
/// recording and reconstructed on replay.
pub trait RecordableError: std::error::Error + Sized {
/// Serializable stand-in for the error.
type Recorded: Serialize + DeserializeOwned;
/// Converts the error into its recorded form; may fail.
fn to_recorded_err(&self) -> Result<Self::Recorded>;
/// Rebuilds the error from its recorded form; may fail.
fn from_recorded_err(recorded: Self::Recorded) -> Result<Self>;
}
impl<E: Serialize + DeserializeOwned + std::error::Error> RecordableError for E {
type Recorded = Value;
fn to_recorded_err(&self) -> Result<Value> {
Ok(serde_json::to_value(self)?)
}
fn from_recorded_err(v: Value) -> Result<Self> {
Ok(serde_json::from_value(v)?)
}
}
/// A `SessionRunner` carries nothing into a recording: it is stored as `()` and
/// comes back from replay as `SessionRunner::stub()`.
impl Recordable for SessionRunner {
    type Recorded = ();

    /// Records the runner as the unit value; never fails.
    fn to_recorded(&self) -> Result<Self::Recorded> {
        Ok(())
    }

    /// Ignores the recorded unit value and returns a stub runner.
    fn from_recorded(_recorded: Self::Recorded) -> Result<Self> {
        let runner = SessionRunner::stub();
        Ok(runner)
    }
}
/// One recorded call. `write_recording` emits each entry as a single JSON line
/// and `replay` parses one entry per non-blank line.
#[derive(Serialize, Deserialize)]
struct VcrEntry {
// Label the call was made under; checked for equality on replay.
label: String,
// Recorded form of the call arguments, as JSON.
args: Value,
// Recorded outcome of the call (a serialized `Result`), as JSON.
result: Value,
}
/// Operating mode of a `VcrContext`.
enum VcrMode {
// Calls run their closure directly; nothing is stored.
Live,
// Calls run their closure and append an entry to this log.
Record(RefCell<Vec<VcrEntry>>),
// Calls are answered from previously recorded entries; the closure is not run.
Replay(RefCell<ReplayState>),
}
/// Cursor over the entries of a loaded recording.
struct ReplayState {
// Recorded entries, in the order they were parsed.
entries: Vec<VcrEntry>,
// Index of the next entry to hand out; incremented once per replayed call.
position: usize,
}
/// Record/replay context through which calls are routed (see `call` and
/// `call_typed_err`).
pub struct VcrContext {
// Live, Record or Replay.
mode: VcrMode,
// Only set by `record_with_triggers`; consulted after each recorded entry.
trigger_controller: Option<RefCell<TriggerController>>,
}
impl VcrContext {
/// Creates a context in live mode with no trigger controller.
pub fn live() -> Self {
    let mode = VcrMode::Live;
    Self {
        mode,
        trigger_controller: None,
    }
}
/// Creates a context in record mode with an empty entry log and no trigger
/// controller.
pub fn record() -> Self {
    let log = RefCell::new(Vec::new());
    Self {
        mode: VcrMode::Record(log),
        trigger_controller: None,
    }
}
pub fn record_with_triggers(controller: TriggerController) -> Self {
Self {
mode: VcrMode::Record(RefCell::new(Vec::new())),
trigger_controller: Some(RefCell::new(controller)),
}
}
pub fn replay(data: &str) -> Result<Self> {
let mut entries = Vec::new();
for line in data.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
entries.push(serde_json::from_str(line)?);
}
Ok(Self {
mode: VcrMode::Replay(RefCell::new(ReplayState {
entries,
position: 0,
})),
trigger_controller: None,
})
}
/// Writes every recorded entry to `path`, one JSON document per line, each
/// followed by `\n`.
///
/// # Errors
///
/// Fails if the context is not in record mode, if an entry cannot be
/// serialized, or if writing the file fails.
pub fn write_recording(&self, path: &Path) -> Result<()> {
    let log = match &self.mode {
        VcrMode::Record(log) => log.borrow(),
        VcrMode::Live | VcrMode::Replay(_) => {
            bail!("write_recording called on non-Record VcrContext")
        }
    };
    let mut jsonl = String::new();
    for entry in log.iter() {
        let line = serde_json::to_string(entry)?;
        jsonl.push_str(&line);
        jsonl.push('\n');
    }
    std::fs::write(path, jsonl)?;
    Ok(())
}
/// Routes one fallible async call through the VCR.
///
/// * Live: runs `f(&args)` and returns its result unchanged.
/// * Record: runs `f(&args)`, stores `label`, the recorded arguments and the
///   outcome (a success as `T::Recorded`, a failure as the error formatted with
///   `{:#}`), then returns the original result.
/// * Replay: `f` is not run. The next recorded entry must match `label` and the
///   recorded form of `args`; its stored outcome is returned, with a stored
///   failure turned into a fresh error carrying the recorded message.
///
/// # Errors
///
/// Besides errors produced by `f` (or replayed from the recording), fails when
/// converting to or from the recorded form fails, or when replay validation
/// fails.
pub async fn call<A, T>(
    &self,
    label: &str,
    args: A,
    f: impl AsyncFnOnce(&A) -> Result<T>,
) -> Result<T>
where
    A: Recordable,
    A::Recorded: PartialEq + Debug,
    T: Recordable,
{
    match &self.mode {
        VcrMode::Live => f(&args).await,
        VcrMode::Record(entries) => {
            let outcome = f(&args).await;
            let to_store: std::result::Result<T::Recorded, String> = match outcome.as_ref() {
                Ok(value) => Ok(value.to_recorded()?),
                Err(err) => Err(format!("{err:#}")),
            };
            self.push_entry(entries, label, &args, &to_store)?;
            outcome
        }
        VcrMode::Replay(state) => {
            let raw = Self::advance_replay(state, label, &args)?;
            let stored: std::result::Result<T::Recorded, String> = serde_json::from_value(raw)?;
            match stored {
                Ok(recorded) => T::from_recorded(recorded),
                Err(message) => Err(anyhow::anyhow!("{message}")),
            }
        }
    }
}
/// Like `call`, but for closures with a typed error `E` that must survive the
/// record/replay round trip.
///
/// The inner `Result<T, E>` is the closure's outcome (or its replayed
/// equivalent). The outer `Result` reports VCR failures only: conversion to or
/// from the recorded form, or replay validation.
///
/// * Live: runs `f(&args)` and wraps its outcome in `Ok`.
/// * Record: runs `f(&args)`, stores the outcome with both arms in recorded
///   form, then returns the original outcome.
/// * Replay: `f` is not run; the stored outcome is rebuilt via
///   `T::from_recorded` / `E::from_recorded_err`.
pub async fn call_typed_err<A, T, E>(
    &self,
    label: &str,
    args: A,
    f: impl AsyncFnOnce(&A) -> std::result::Result<T, E>,
) -> Result<std::result::Result<T, E>>
where
    A: Recordable,
    A::Recorded: PartialEq + Debug,
    T: Recordable,
    E: RecordableError,
{
    match &self.mode {
        VcrMode::Live => Ok(f(&args).await),
        VcrMode::Record(entries) => {
            let outcome = f(&args).await;
            let to_store: std::result::Result<T::Recorded, E::Recorded> =
                match outcome.as_ref() {
                    Ok(value) => Ok(value.to_recorded()?),
                    Err(err) => Err(err.to_recorded_err()?),
                };
            self.push_entry(entries, label, &args, &to_store)?;
            Ok(outcome)
        }
        VcrMode::Replay(state) => {
            let raw = Self::advance_replay(state, label, &args)?;
            let stored: std::result::Result<T::Recorded, E::Recorded> =
                serde_json::from_value(raw)?;
            let replayed = match stored {
                Ok(recorded) => Ok(T::from_recorded(recorded)?),
                Err(recorded) => Err(E::from_recorded_err(recorded)?),
            };
            Ok(replayed)
        }
    }
}
fn advance_replay<A>(state: &RefCell<ReplayState>, label: &str, args: &A) -> Result<Value>
where
A: Recordable,
A::Recorded: PartialEq + Debug,
{
let (entry_label, entry_args, entry_result, pos) = {
let mut state = state.borrow_mut();
anyhow::ensure!(
state.position < state.entries.len(),
"VCR replay exhausted: expected more entries after position {}",
state.position
);
let pos = state.position;
let entry = &state.entries[pos];
let result = (
entry.label.clone(),
entry.args.clone(),
entry.result.clone(),
pos,
);
state.position += 1;
result
};
anyhow::ensure!(
entry_label == label,
"VCR label mismatch at position {pos}: expected '{entry_label}', got '{label}'"
);
let recorded_args: A::Recorded = serde_json::from_value(entry_args)?;
let actual_args = args.to_recorded()?;
anyhow::ensure!(
recorded_args == actual_args,
"VCR args mismatch for '{label}' at position {pos}: expected {recorded_args:?}, got {actual_args:?}"
);
Ok(entry_result)
}
/// Appends one entry to the recording log and, if a trigger controller is
/// configured, lets it inspect the entry's label and serialized result.
///
/// `args` is stored in its recorded form; `recorded_result` is stored as
/// serialized JSON.
///
/// # Errors
///
/// Fails if `args` cannot be converted to its recorded form or if either value
/// cannot be serialized to JSON. Nothing is appended in that case.
fn push_entry<A, R: Serialize>(
    &self,
    entries: &RefCell<Vec<VcrEntry>>,
    label: &str,
    args: &A,
    recorded_result: &R,
) -> Result<()>
where
    A: Recordable,
{
    let entry = VcrEntry {
        label: label.to_string(),
        args: serde_json::to_value(args.to_recorded()?)?,
        result: serde_json::to_value(recorded_result)?,
    };
    // Run the triggers against the result by reference before the entry is moved
    // into the log; this avoids deep-cloning the JSON value on every recorded
    // call. `check` only looks at the label and result and never touches the log.
    if let Some(tc) = &self.trigger_controller {
        tc.borrow_mut().check(label, &entry.result);
    }
    entries.borrow_mut().push(entry);
    Ok(())
}
/// Returns `true` when the context is in live mode.
pub fn is_live(&self) -> bool {
    matches!(self.mode, VcrMode::Live)
}
/// Returns `true` when the context is in replay mode.
pub fn is_replay(&self) -> bool {
    matches!(self.mode, VcrMode::Replay(_))
}
/// Returns `true` when the context is in record mode.
pub fn is_record(&self) -> bool {
    matches!(self.mode, VcrMode::Record(_))
}
}
/// An event produced by `Io::next_event`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IoEvent {
// An application event (received on `Io`'s app-event channel, or synthesized
// as `ProcessExit(None)` when a channel is closed).
Claude(AppEvent),
// A terminal event received on `Io`'s terminal channel.
Terminal(Event),
}
/// Multiplexes an application-event channel and a terminal-event channel.
pub struct Io {
// Receiver for application events.
event_rx: mpsc::UnboundedReceiver<AppEvent>,
// Receiver for terminal events.
term_rx: mpsc::UnboundedReceiver<Event>,
// Sender held by `clear_event_channel` for the receiver it installs;
// reset to `None` by `replace_event_channel`.
idle_tx: Option<mpsc::UnboundedSender<AppEvent>>,
}
impl Io {
/// Wraps the two receivers; no idle sender is held initially.
pub fn new(
    event_rx: mpsc::UnboundedReceiver<AppEvent>,
    term_rx: mpsc::UnboundedReceiver<Event>,
) -> Self {
    let idle_tx = None;
    Self {
        event_rx,
        term_rx,
        idle_tx,
    }
}
pub fn dummy() -> Self {
let (_tx1, rx1) = mpsc::unbounded_channel();
let (_tx2, rx2) = mpsc::unbounded_channel();
Self {
event_rx: rx1,
term_rx: rx2,
idle_tx: None,
}
}
/// Waits for the next event from either channel.
///
/// An application event is returned as `IoEvent::Claude`, a terminal event as
/// `IoEvent::Terminal`. If either receiver yields `None`, the result is
/// `IoEvent::Claude(AppEvent::ProcessExit(None))`. As written this never
/// returns `Err`.
pub async fn next_event(&mut self) -> Result<IoEvent> {
    let next = tokio::select! {
        app_event = self.event_rx.recv() => match app_event {
            Some(ev) => IoEvent::Claude(ev),
            None => IoEvent::Claude(AppEvent::ProcessExit(None)),
        },
        term_event = self.term_rx.recv() => term_event.map_or(
            IoEvent::Claude(AppEvent::ProcessExit(None)),
            IoEvent::Terminal,
        ),
    };
    Ok(next)
}
/// Installs a fresh application-event channel and returns its sender.
///
/// Any idle sender held from `clear_event_channel` is released first; the
/// previous receiver is dropped when it is overwritten.
pub fn replace_event_channel(&mut self) -> mpsc::UnboundedSender<AppEvent> {
    drop(self.idle_tx.take());
    let (sender, receiver) = mpsc::unbounded_channel();
    self.event_rx = receiver;
    sender
}
/// Replaces the application-event receiver with one from a fresh channel and
/// keeps that channel's sender in `idle_tx`.
///
/// Because the sender stays alive inside `self`, the new channel has a live
/// sender even though nothing is sent on it.
pub fn clear_event_channel(&mut self) {
    let (keepalive, receiver) = mpsc::unbounded_channel();
    self.event_rx = receiver;
    self.idle_tx = Some(keepalive);
}
}
/// Injects synthetic terminal key events when recorded VCR entries match
/// configured triggers (see `check`).
pub struct TriggerController {
// Triggers built from the test messages; each fires at most once.
triggers: Vec<PendingTrigger>,
// Channel the synthetic key events are sent on.
term_tx: mpsc::UnboundedSender<Event>,
// When set, `check` may inject an exit once every trigger has fired.
auto_exit: bool,
}
/// A single trigger and its fired state.
struct PendingTrigger {
// JSON pattern that must be a subset of the recorded result; `None` matches any result.
condition: Option<Value>,
// VCR label the entry must carry; `None` matches any label.
label: Option<String>,
// Text typed into the terminal channel when the trigger fires.
text: String,
// How the text (or exit/interrupt) is injected.
mode: TriggerInputMode,
// Set once the trigger has fired so it is never matched again.
fired: bool,
}
/// How a fired trigger is injected into the terminal channel. Deserialized from
/// the lowercase variant name.
#[derive(Clone, Copy, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum TriggerInputMode {
// Types the text, then sends Enter with the ALT modifier. Default mode.
#[default]
Followup,
// Types the text, then sends Enter with no modifiers.
Steering,
// Sends Ctrl+D; the text is not typed.
Exit,
// Sends Ctrl+C, then types any non-empty text as `Steering`.
Interrupt,
}
impl TriggerController {
pub fn new(messages: &[TestMessage], term_tx: mpsc::UnboundedSender<Event>) -> Result<Self> {
let triggers = messages
.iter()
.map(|m| {
anyhow::ensure!(
m.trigger.is_some() || m.label.is_some(),
"trigger message must have at least one of `trigger` or `label`"
);
let condition = m
.trigger
.as_ref()
.map(|t| serde_json::from_str(t))
.transpose()?;
Ok(PendingTrigger {
condition,
label: m.label.clone(),
text: m.content.clone(),
mode: m.mode,
fired: false,
})
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
triggers,
term_tx,
auto_exit: false,
})
}
#[must_use]
pub fn with_auto_exit(mut self) -> Self {
self.auto_exit = true;
self
}
/// Fires every not-yet-fired trigger that matches the given recorded entry and
/// injects the corresponding key events, in trigger order.
///
/// A trigger matches when its label (if any) equals `vcr_label` and its
/// condition (if any) is a subset of `recorded_result`.
///
/// When auto-exit is enabled, nothing fired during this call, every trigger has
/// already fired, and `recorded_result` contains the "result" pattern below, an
/// exit key event is injected.
pub fn check(&mut self, vcr_label: &str, recorded_result: &Value) {
    let mut fired_now = false;
    for trigger in &mut self.triggers {
        if trigger.fired {
            continue;
        }
        if matches!(trigger.label.as_deref(), Some(wanted) if wanted != vcr_label) {
            continue;
        }
        if let Some(cond) = &trigger.condition {
            if !is_subset(cond, recorded_result) {
                continue;
            }
        }

        trigger.fired = true;
        fired_now = true;
        match trigger.mode {
            TriggerInputMode::Exit => inject_exit(&self.term_tx),
            TriggerInputMode::Interrupt => {
                inject_interrupt(&self.term_tx);
                // Text attached to an interrupt is typed as a steering message.
                if !trigger.text.is_empty() {
                    inject_text(&self.term_tx, &trigger.text, TriggerInputMode::Steering);
                }
            }
            mode @ (TriggerInputMode::Followup | TriggerInputMode::Steering) => {
                inject_text(&self.term_tx, &trigger.text, mode);
            }
        }
    }

    if !self.auto_exit || fired_now {
        return;
    }
    if self.triggers.iter().any(|t| !t.fired) {
        return;
    }
    let result_pattern = serde_json::json!({"Ok": {"Claude": {"Claude": {"type": "result"}}}});
    if is_subset(&result_pattern, recorded_result) {
        inject_exit(&self.term_tx);
    }
}
}
/// Sends `text` to the terminal channel one character key event at a time (no
/// modifiers), followed by an Enter key press.
///
/// The Enter event carries the ALT modifier for `Followup` and no modifiers for
/// `Steering`. Send failures are ignored.
///
/// # Panics
///
/// Panics if `mode` is `Exit` or `Interrupt`; the characters have already been
/// sent by then.
fn inject_text(term_tx: &mpsc::UnboundedSender<Event>, text: &str, mode: TriggerInputMode) {
    use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyEventState, KeyModifiers};

    text.chars()
        .map(|ch| Event::Key(KeyEvent::new(KeyCode::Char(ch), KeyModifiers::NONE)))
        .for_each(|key| {
            let _ = term_tx.send(key);
        });

    let submit = KeyEvent {
        code: KeyCode::Enter,
        modifiers: match mode {
            TriggerInputMode::Followup => KeyModifiers::ALT,
            TriggerInputMode::Steering => KeyModifiers::NONE,
            TriggerInputMode::Exit | TriggerInputMode::Interrupt => {
                unreachable!("Exit/Interrupt triggers are handled in check(), not inject_text()")
            }
        },
        kind: KeyEventKind::Press,
        state: KeyEventState::NONE,
    };
    let _ = term_tx.send(Event::Key(submit));
}
/// Sends a Ctrl+D key press to the terminal channel; a send failure is ignored.
fn inject_exit(term_tx: &mpsc::UnboundedSender<Event>) {
    use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};

    // `KeyEvent::new` produces a `Press` event with an empty key-event state.
    let ctrl_d = KeyEvent::new(KeyCode::Char('d'), KeyModifiers::CONTROL);
    let _ = term_tx.send(Event::Key(ctrl_d));
}
/// Sends a Ctrl+C key press to the terminal channel; a send failure is ignored.
fn inject_interrupt(term_tx: &mpsc::UnboundedSender<Event>) {
    use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};

    // `KeyEvent::new` produces a `Press` event with an empty key-event state.
    let ctrl_c = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL);
    let _ = term_tx.send(Event::Key(ctrl_c));
}
/// Returns `true` when `pattern` is structurally contained in `event`.
///
/// If both are JSON objects, every key of `pattern` must exist in `event` with
/// a value that is itself a subset match (recursively); extra keys in `event`
/// are ignored. In every other case, arrays included, the two values must be
/// equal.
fn is_subset(pattern: &Value, event: &Value) -> bool {
    if let (Value::Object(wanted), Value::Object(actual)) = (pattern, event) {
        for (key, sub_pattern) in wanted {
            match actual.get(key) {
                Some(sub_event) if is_subset(sub_pattern, sub_event) => {}
                _ => return false,
            }
        }
        true
    } else {
        pattern == event
    }
}
/// Deserialized description of a test case. Every field is optional in the
/// input: the three config sections are `Option`s and the rest fall back to
/// their defaults.
#[derive(Deserialize, Default)]
pub struct TestCase {
// Optional `run` section.
pub run: Option<RunConfig>,
// Optional `ralph` section; its presence is what `is_ralph` reports.
pub ralph: Option<RalphConfig>,
// Optional `worker` section; its presence is what `is_worker` reports.
pub worker: Option<WorkerTestConfig>,
// Display options.
#[serde(default)]
pub display: DisplayConfig,
// NOTE(review): presumably file path -> file contents to set up for the test — confirm against the consumer.
#[serde(default)]
pub files: HashMap<String, String>,
// Messages with trigger conditions (see `TriggerController::new`).
#[serde(default)]
pub messages: Vec<TestMessage>,
// NOTE(review): meaning of the view names is not visible in this file — confirm against the consumer.
#[serde(default)]
pub views: Vec<String>,
}
/// Display options of a test case.
#[derive(Deserialize, Default)]
pub struct DisplayConfig {
// Defaults to `false` when omitted.
// NOTE(review): presumably toggles rendering of thinking output — confirm against the consumer.
#[serde(default)]
pub show_thinking: bool,
}
/// The `run` section of a test case.
#[derive(Deserialize)]
pub struct RunConfig {
// Required prompt text.
pub prompt: String,
// Extra arguments; empty when omitted.
// NOTE(review): presumably forwarded to the claude process — confirm against the consumer.
#[serde(default)]
pub claude_args: Vec<String>,
}
/// The `ralph` section of a test case.
#[derive(Deserialize)]
pub struct RalphConfig {
// Required prompt text.
pub prompt: String,
// Defaults to "break" (see `default_break_tag`) when omitted.
#[serde(default = "default_break_tag")]
pub break_tag: String,
// Extra arguments; empty when omitted.
// NOTE(review): presumably forwarded to the claude process — confirm against the consumer.
#[serde(default)]
pub claude_args: Vec<String>,
}
/// Serde default for `RalphConfig::break_tag`: the string `"break"`.
fn default_break_tag() -> String {
    String::from("break")
}
/// The `worker` section of a test case.
#[derive(Deserialize)]
pub struct WorkerTestConfig {
// Extra arguments; empty when omitted.
// NOTE(review): presumably forwarded to the claude process — confirm against the consumer.
#[serde(default)]
pub claude_args: Vec<String>,
}
/// A message to inject during a test, with the condition under which it fires.
/// `TriggerController::new` requires at least one of `trigger` or `label`.
#[derive(Deserialize)]
pub struct TestMessage {
// Text to type when the trigger fires (required in the input).
pub content: String,
// JSON pattern, as a string, that the recorded result must contain.
#[serde(default)]
pub trigger: Option<String>,
// VCR label the recorded entry must carry.
#[serde(default)]
pub label: Option<String>,
// Injection mode; defaults to `Followup`.
#[serde(default)]
pub mode: TriggerInputMode,
}
impl TestCase {
/// Returns `true` when the test case has a `ralph` section.
pub fn is_ralph(&self) -> bool {
self.ralph.is_some()
}
/// Returns `true` when the test case has a `worker` section.
pub fn is_worker(&self) -> bool {
self.worker.is_some()
}
}