Skip to main content

batuta/agent/
repl.rs

1//! Interactive REPL for `apr code`.
2//!
3//! Terminal UI: user types prompts, agent streams responses token-by-token.
4//! Crossterm raw mode input, tokio agent loop, mpsc streaming events.
5//! See: apr-code.md §3, agent-and-playbook.md §7
6
7use std::io::{self, Write};
8use std::sync::Arc;
9
10use std::sync::atomic::{AtomicBool, Ordering};
11
12use tokio::sync::mpsc;
13
14use crate::agent::driver::{LlmDriver, Message, StreamEvent};
15use crate::agent::memory::MemorySubstrate;
16use crate::agent::result::AgentLoopResult;
17use crate::agent::session::SessionStore;
18use crate::agent::tool::ToolRegistry;
19use crate::agent::AgentManifest;
20use crate::ansi_colors::Colorize;
21use crate::serve::context::TokenEstimator;
22
/// A slash command typed at the REPL prompt.
///
/// PMAT-CODE-SLASH-PARITY-001: the command set grew from 11 to 21 named
/// variants to mirror Claude Code's built-in slash command surface. The ten
/// parity variants are recognized by the parser so that `/help` can advertise
/// them and the user gets a deliberate message instead of `Unknown`.
#[derive(Debug, PartialEq)]
enum SlashCommand {
    Help,
    Quit,
    Cost,
    Context,
    Model,
    Compact,
    Clear,
    Session,
    Sessions,
    Test,
    Quality,
    // PMAT-CODE-SLASH-PARITY-001: the ten Claude-Code-parity variants.
    Mcp,
    Config,
    Review,
    Memory,
    Permissions,
    Hooks,
    Init,
    Resume,
    AddDir,
    Agents,
    /// Any other `/word`; carries the first token exactly as typed (slash included).
    Unknown(String),
}

impl SlashCommand {
    /// Parse a line of user input into a slash command.
    ///
    /// Returns `None` when the first non-whitespace character is not `/`
    /// (including blank input). Only the first whitespace-delimited token is
    /// inspected; anything after it is ignored. Matching is case-sensitive.
    fn parse(input: &str) -> Option<Self> {
        // The first token of the raw input is the same token the trimmed input
        // would start with, so no separate trim is needed.
        let token = input.split_whitespace().next()?;
        let name = token.strip_prefix('/')?;

        let command = match name {
            "help" | "h" | "?" => Self::Help,
            "quit" | "q" | "exit" => Self::Quit,
            "cost" => Self::Cost,
            "context" | "ctx" => Self::Context,
            "model" => Self::Model,
            "compact" => Self::Compact,
            "clear" => Self::Clear,
            "session" => Self::Session,
            "sessions" => Self::Sessions,
            "test" => Self::Test,
            "quality" => Self::Quality,
            // PMAT-CODE-SLASH-PARITY-001
            "mcp" => Self::Mcp,
            "config" | "cfg" => Self::Config,
            "review" => Self::Review,
            "memory" => Self::Memory,
            "permissions" | "perms" => Self::Permissions,
            "hooks" => Self::Hooks,
            "init" => Self::Init,
            "resume" => Self::Resume,
            "add-dir" | "adddir" => Self::AddDir,
            "agents" => Self::Agents,
            _ => Self::Unknown(token.to_string()),
        };
        Some(command)
    }
}
92
/// Auto-compaction threshold (80% of context window). See apr-code.md §7.3.
///
/// Expressed as a fraction so it can be compared directly against
/// `ReplSession::context_usage`; history is compacted once usage reaches it.
const AUTO_COMPACT_THRESHOLD: f64 = 0.80;
95
/// Session state tracked across turns.
///
/// Counters are running totals updated by `record_turn`; they feed the
/// `/cost`, `/context` and `/session` commands and the turn/budget limits.
pub(super) struct ReplSession {
    /// Number of turns recorded so far (seeded from the stored manifest on resume).
    pub(super) turn_count: u32,
    /// Sum of input tokens reported by each recorded turn.
    pub(super) total_input_tokens: u64,
    /// Sum of output tokens reported by each recorded turn.
    pub(super) total_output_tokens: u64,
    /// Sum of tool calls reported by each recorded turn.
    pub(super) total_tool_calls: u32,
    /// Sum of the per-turn cost estimates, in US dollars.
    pub(super) estimated_cost_usd: f64,
    /// Persistent session store (JSONL). None if persistence failed to init.
    pub(super) store: Option<SessionStore>,
    /// Context window size in tokens (from driver).
    pub(super) context_window: usize,
}
108
109impl ReplSession {
110    fn new(agent_name: &str, context_window: usize) -> Self {
111        let store = SessionStore::create(agent_name).ok();
112        Self {
113            turn_count: 0,
114            total_input_tokens: 0,
115            total_output_tokens: 0,
116            total_tool_calls: 0,
117            estimated_cost_usd: 0.0,
118            store,
119            context_window,
120        }
121    }
122
123    fn record_turn(&mut self, result: &AgentLoopResult, cost: f64) {
124        self.turn_count += 1;
125        self.total_input_tokens += result.usage.input_tokens;
126        self.total_output_tokens += result.usage.output_tokens;
127        self.total_tool_calls += result.tool_calls;
128        self.estimated_cost_usd += cost;
129        // Persist turn count
130        if let Some(ref mut store) = self.store {
131            let _ = store.record_turn();
132        }
133    }
134
135    /// Persist new messages from this turn to JSONL.
136    fn persist_messages(&self, history: &[Message], prev_len: usize) {
137        if let Some(ref store) = self.store {
138            let new_msgs = &history[prev_len..];
139            if !new_msgs.is_empty() {
140                let _ = store.append_messages(new_msgs);
141            }
142        }
143    }
144
145    pub(crate) fn session_id(&self) -> Option<&str> {
146        self.store.as_ref().map(|s| s.id())
147    }
148
149    /// Estimate total tokens used by conversation history.
150    fn estimate_history_tokens(history: &[Message]) -> usize {
151        let estimator = TokenEstimator::new();
152        let chat_msgs: Vec<_> = history.iter().map(Message::to_chat_message).collect();
153        estimator.estimate_messages(&chat_msgs)
154    }
155
156    /// Context usage as fraction (0.0–1.0).
157    pub(super) fn context_usage(&self, history: &[Message]) -> f64 {
158        if self.context_window == 0 {
159            return 0.0;
160        }
161        Self::estimate_history_tokens(history) as f64 / self.context_window as f64
162    }
163
164    /// Auto-compact if history exceeds 80% of context window.
165    /// Returns true if compaction was triggered.
166    fn auto_compact_if_needed(&self, history: &mut Vec<Message>) -> bool {
167        let usage = self.context_usage(history);
168        if usage >= AUTO_COMPACT_THRESHOLD {
169            let before = history.len();
170            compact_history(history);
171            let after = history.len();
172            if after < before {
173                println!(
174                    "  {} Auto-compacted: {} → {} messages ({:.0}% context)",
175                    "⚙".dimmed(),
176                    before,
177                    after,
178                    self.context_usage(history) * 100.0
179                );
180                return true;
181            }
182        }
183        false
184    }
185}
186
187/// Resume an existing session or create a new one.
188fn resume_or_new(
189    resume_id: Option<&str>,
190    agent_name: &str,
191    ctx_window: usize,
192) -> (ReplSession, Vec<Message>) {
193    if let Some(id) = resume_id {
194        if let Ok(store) = SessionStore::resume(id) {
195            let msgs = store.load_messages().unwrap_or_default();
196            let turns = store.manifest.turns;
197            println!("  {} Resumed {} ({turns} turns, {} msgs)", "✓".green(), id, msgs.len());
198            let s = ReplSession {
199                turn_count: turns,
200                total_input_tokens: 0,
201                total_output_tokens: 0,
202                total_tool_calls: 0,
203                estimated_cost_usd: 0.0,
204                store: Some(store),
205                context_window: ctx_window,
206            };
207            return (s, msgs);
208        }
209        println!("  {} Could not resume session: {id}", "⚠".bright_yellow());
210    }
211    let s = ReplSession::new(agent_name, ctx_window);
212    if let Some(id) = s.session_id() {
213        println!("  {} {}", "Session:".dimmed(), id.dimmed());
214    }
215    (s, Vec::new())
216}
217
218/// Run the interactive REPL.
219///
220/// This is the main entry point for `apr code` interactive mode.
221/// Returns when the user types `/quit` or Ctrl+D.
222///
223/// If `resume_id` is provided, loads conversation history from
224/// the corresponding session in `~/.apr/sessions/`.
225pub fn run_repl(
226    manifest: &AgentManifest,
227    driver: &dyn LlmDriver,
228    tools: &ToolRegistry,
229    memory: &dyn MemorySubstrate,
230    max_turns: u32,
231    budget_usd: f64,
232    resume_id: Option<&str>,
233) -> anyhow::Result<()> {
234    let rt = tokio::runtime::Builder::new_current_thread()
235        .enable_all()
236        .build()
237        .map_err(|e| anyhow::anyhow!("tokio runtime: {e}"))?;
238
239    print_welcome(manifest, driver);
240
241    let ctx_window = driver.context_window();
242
243    let (mut session, mut history) = resume_or_new(resume_id, &manifest.name, ctx_window);
244
245    let stdin = io::stdin();
246    let mut line_buf = String::new();
247
248    loop {
249        // Check turn budget
250        if session.turn_count >= max_turns {
251            println!(
252                "\n{} Max turns ({}) reached. Session complete.",
253                "⚠".bright_yellow(),
254                max_turns
255            );
256            break;
257        }
258        if session.estimated_cost_usd >= budget_usd {
259            println!(
260                "\n{} Budget (${:.2}) exhausted. Session complete.",
261                "⚠".bright_yellow(),
262                budget_usd
263            );
264            break;
265        }
266
267        // Read input
268        let input = match read_input(&stdin, &mut line_buf, &session, budget_usd, &mut history) {
269            InputResult::Prompt(s) => s,
270            InputResult::SlashHandled => continue,
271            InputResult::Exit => break,
272            InputResult::Empty => continue,
273        };
274
275        // Execute turn with streaming
276        let cancel = Arc::new(AtomicBool::new(false));
277        let cancel_clone = Arc::clone(&cancel);
278
279        rt.block_on(async {
280            let flag = cancel_clone;
281            tokio::spawn(async move {
282                if tokio::signal::ctrl_c().await.is_ok() {
283                    flag.store(true, Ordering::SeqCst);
284                }
285            });
286        });
287
288        let (tx, rx) = mpsc::channel::<StreamEvent>(64);
289
290        println!();
291
292        let history_len_before = history.len();
293        let result = rt.block_on(run_turn_streaming(
294            manifest,
295            &input,
296            driver,
297            tools,
298            memory,
299            &mut history,
300            tx,
301            rx,
302            &cancel,
303        ));
304
305        match result {
306            Ok(r) => {
307                let cost = driver.estimate_cost(&r.usage);
308                session.record_turn(&r, cost);
309                // Persist new messages to JSONL
310                session.persist_messages(&history, history_len_before);
311                // Auto-compact at 80% context window (spec §7.3)
312                session.auto_compact_if_needed(&mut history);
313                print_turn_footer(&r, cost, &session, budget_usd);
314            }
315            Err(e) => {
316                if cancel.load(Ordering::SeqCst) {
317                    println!("\n{} Generation cancelled.", "⚠".bright_yellow());
318                } else {
319                    println!("\n{} Error: {e}", "✗".bright_red());
320                }
321            }
322        }
323    }
324
325    print_session_summary(&session);
326    Ok(())
327}
328
/// Input reading result.
///
/// Tells the REPL loop what to do after one line has been read from the prompt.
enum InputResult {
    /// A non-blank, non-slash line (trimmed) to send to the agent as the next prompt.
    Prompt(String),
    /// A slash command other than `/quit` was handled; prompt again.
    SlashHandled,
    /// End of input, a read error, or `/quit`: leave the REPL loop.
    Exit,
    /// The line was blank after trimming; prompt again.
    Empty,
}
336
337/// Read one line of input, handling slash commands inline.
338fn read_input(
339    stdin: &io::Stdin,
340    buf: &mut String,
341    session: &ReplSession,
342    budget: f64,
343    history: &mut Vec<Message>,
344) -> InputResult {
345    let cost_str = if session.estimated_cost_usd > 0.0 {
346        format!(" ${:.3}", session.estimated_cost_usd)
347    } else {
348        String::new()
349    };
350    print!(
351        "\n{}{} ",
352        format!("[{}/{}{}]", session.turn_count + 1, "?", cost_str).dimmed(),
353        " >".bright_green().bold(),
354    );
355    io::stdout().flush().ok();
356
357    buf.clear();
358    let bytes = match stdin.read_line(buf) {
359        Ok(b) => b,
360        Err(_) => return InputResult::Exit,
361    };
362    if bytes == 0 {
363        println!();
364        return InputResult::Exit;
365    }
366
367    let trimmed = buf.trim();
368    if trimmed.is_empty() {
369        return InputResult::Empty;
370    }
371
372    // Handle slash commands
373    if let Some(cmd) = SlashCommand::parse(trimmed) {
374        handle_slash_command(&cmd, session, budget, history);
375        return match cmd {
376            SlashCommand::Quit => InputResult::Exit,
377            _ => InputResult::SlashHandled,
378        };
379    }
380
381    InputResult::Prompt(trimmed.to_string())
382}
383
384/// Handle a slash command.
385fn handle_slash_command(
386    cmd: &SlashCommand,
387    session: &ReplSession,
388    budget: f64,
389    history: &mut Vec<Message>,
390) {
391    match cmd {
392        SlashCommand::Help => print_help(),
393        SlashCommand::Quit => println!("{} Goodbye.", "✓".green()),
394        SlashCommand::Cost => {
395            // PMAT-169: local inference is free — show tokens, not misleading dollars
396            if session.estimated_cost_usd < 0.0001 {
397                println!("  Cost: {} (local inference)", "free".green());
398            } else {
399                println!(
400                    "  Cost: ${:.4} / ${:.2} ({:.1}%)",
401                    session.estimated_cost_usd,
402                    budget,
403                    (session.estimated_cost_usd / budget * 100.0).min(100.0)
404                );
405            }
406            println!(
407                "  Tokens: {} in / {} out",
408                session.total_input_tokens, session.total_output_tokens
409            );
410            println!("  Turns: {}, Tool calls: {}", session.turn_count, session.total_tool_calls);
411        }
412        SlashCommand::Context => {
413            let user_msgs = history.iter().filter(|m| matches!(m, Message::User(_))).count();
414            let asst_msgs = history.iter().filter(|m| matches!(m, Message::Assistant(_))).count();
415            let tool_msgs = history
416                .iter()
417                .filter(|m| matches!(m, Message::AssistantToolUse(_) | Message::ToolResult(_)))
418                .count();
419            let usage_pct = session.context_usage(history) * 100.0;
420            let est_tokens = ReplSession::estimate_history_tokens(history);
421            println!(
422                "  History: {} messages ({} user, {} assistant, {} tool)",
423                history.len(),
424                user_msgs,
425                asst_msgs,
426                tool_msgs
427            );
428            println!(
429                "  Context: ~{} / {} tokens ({:.0}%)",
430                est_tokens, session.context_window, usage_pct
431            );
432            if usage_pct >= 80.0 {
433                println!("  {} Near context limit — /compact to free space", "⚠".bright_yellow());
434            }
435            println!("  Turns: {}", session.turn_count);
436        }
437        SlashCommand::Model => {
438            println!("  Model switching not yet implemented.");
439        }
440        SlashCommand::Compact => {
441            let before = history.len();
442            compact_history(history);
443            println!("  Compacted: {} -> {} messages", before, history.len());
444        }
445        SlashCommand::Clear => {
446            history.clear();
447            print!("\x1B[2J\x1B[1;1H");
448            io::stdout().flush().ok();
449            println!("  Screen and conversation history cleared.");
450        }
451        SlashCommand::Session => {
452            if let Some(id) = session.session_id() {
453                println!("  Session: {id}");
454                println!("  Turns: {}, Messages: {}", session.turn_count, history.len());
455            } else {
456                println!("  No active session (persistence disabled).");
457            }
458        }
459        SlashCommand::Sessions => {
460            list_recent_sessions();
461        }
462        SlashCommand::Test => {
463            println!("  Running tests...");
464            let _ = io::stdout().flush();
465            run_shell_shortcut("cargo test --lib 2>&1 | tail -5");
466        }
467        SlashCommand::Quality => {
468            println!("  Running quality gate...");
469            let _ = io::stdout().flush();
470            run_shell_shortcut("cargo clippy -- -D warnings 2>&1 | tail -3 && cargo test --lib --quiet 2>&1 | tail -3");
471        }
472        // PMAT-CODE-SLASH-PARITY-001: 10 new Claude-Code-parity variants.
473        // Kept as minimal stubs so the parser + /help advertise them; each
474        // points to its closure ticket so users see a deliberate message
475        // instead of "Unknown command".
476        SlashCommand::Mcp => {
477            println!(
478                "  MCP servers are configured under {} in the AgentManifest TOML.",
479                "mcp_servers[]".bright_yellow()
480            );
481            println!("  Project-root .mcp.json loader: PMAT-CODE-MCP-JSON-LOADER-001 (P2).");
482        }
483        SlashCommand::Config => {
484            println!(
485                "  Config source: {} (TOML). User-global ladder tracked in PMAT-CODE-CONFIG-LADDER-001.",
486                "AgentManifest".bright_yellow()
487            );
488        }
489        SlashCommand::Review => {
490            println!("  /review not yet implemented — tracked by PMAT-CODE-REVIEW-001.");
491        }
492        SlashCommand::Memory => {
493            println!(
494                "  Use the {} tool for CRUD on project memory; /memory TUI: PMAT-CODE-MEMORY-TUI-001.",
495                "memory".bright_yellow()
496            );
497        }
498        SlashCommand::Permissions => {
499            println!(
500                "  Permission modes not yet implemented — tracked by PMAT-CODE-PERMISSIONS-001."
501            );
502        }
503        SlashCommand::Hooks => {
504            println!("  Hooks not yet implemented — tracked by PMAT-CODE-HOOKS-001.");
505        }
506        SlashCommand::Init => {
507            println!("  /init scaffold not yet implemented — tracked by PMAT-CODE-INIT-001.");
508        }
509        SlashCommand::Resume => {
510            println!("  REPL-scope /resume not yet implemented — CLI `apr code --resume [id]` works today.");
511        }
512        SlashCommand::AddDir => {
513            println!("  /add-dir not yet implemented — tracked by PMAT-CODE-ADDDIR-001.");
514        }
515        SlashCommand::Agents => {
516            println!(
517                "  Custom agents not yet implemented — tracked by PMAT-CODE-CUSTOM-AGENTS-001."
518            );
519        }
520        SlashCommand::Unknown(name) => {
521            println!("  {} Unknown command: {name}. Type /help for commands.", "?".bright_yellow());
522        }
523    }
524}
525
526/// Execute one turn with streaming output and multi-turn history.
527#[allow(clippy::too_many_arguments)]
528async fn run_turn_streaming(
529    manifest: &AgentManifest,
530    prompt: &str,
531    driver: &dyn LlmDriver,
532    tools: &ToolRegistry,
533    memory: &dyn MemorySubstrate,
534    history: &mut Vec<Message>,
535    tx: mpsc::Sender<StreamEvent>,
536    mut rx: mpsc::Receiver<StreamEvent>,
537    cancel: &Arc<AtomicBool>,
538) -> Result<AgentLoopResult, crate::agent::result::AgentError> {
539    // Drain task: print streaming events as they arrive
540    let drain = tokio::spawn(async move {
541        while let Some(event) = rx.recv().await {
542            print_stream_event_repl(&event);
543        }
544    });
545
546    let result = crate::agent::runtime::run_agent_turn(
547        manifest,
548        history,
549        prompt,
550        driver,
551        tools,
552        memory,
553        Some(tx),
554    )
555    .await;
556
557    // If cancelled, wrap the error
558    if cancel.load(Ordering::SeqCst) && result.is_err() {
559        return Err(crate::agent::result::AgentError::CircuitBreak("cancelled by user".into()));
560    }
561
562    // Ensure drain task finishes
563    let _ = drain.await;
564    result
565}
566
567// Display functions extracted to repl_display.rs for file size compliance.
568use super::repl_display::{
569    compact_history, list_recent_sessions, print_help, print_session_summary,
570    print_stream_event_repl, print_turn_footer, print_welcome, run_shell_shortcut,
571};
572
573#[cfg(test)]
574#[path = "repl_tests.rs"]
575mod tests;