Skip to main content

batuta/agent/
repl.rs

//! Interactive REPL for `apr code`.
//!
//! Terminal UI: the user types prompts and the agent streams responses token-by-token.
//! Crossterm raw mode input, tokio agent loop, mpsc streaming events.
//! See: apr-code.md §3, agent-and-playbook.md §7
6
7use std::io::{self, Write};
8use std::sync::Arc;
9
10use std::sync::atomic::{AtomicBool, Ordering};
11
12use tokio::sync::mpsc;
13
14use crate::agent::driver::{LlmDriver, Message, StreamEvent};
15use crate::agent::memory::MemorySubstrate;
16use crate::agent::result::AgentLoopResult;
17use crate::agent::session::SessionStore;
18use crate::agent::tool::ToolRegistry;
19use crate::agent::AgentManifest;
20use crate::ansi_colors::Colorize;
21use crate::serve::context::TokenEstimator;
22
/// Slash commands recognized by the REPL.
///
/// Values are produced by `SlashCommand::parse` from a line whose first
/// token starts with `/`, and dispatched by `handle_slash_command`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SlashCommand {
    /// `/help`, `/h`, `/?` — print the command reference.
    Help,
    /// `/quit`, `/q`, `/exit` — say goodbye and leave the REPL.
    Quit,
    /// `/cost` — show accumulated cost, token totals, turns and tool calls.
    Cost,
    /// `/context`, `/ctx` — show history composition and context-window usage.
    Context,
    /// `/model` — placeholder; model switching is not implemented yet.
    Model,
    /// `/compact` — compact the conversation history on demand.
    Compact,
    /// `/clear` — clear the screen and the in-memory conversation history.
    Clear,
    /// `/session` — show the current session id, turn and message counts.
    Session,
    /// `/sessions` — list recent sessions.
    Sessions,
    /// `/test` — run the library test suite via a shell shortcut.
    Test,
    /// `/quality` — run the clippy + test quality gate via a shell shortcut.
    Quality,
    /// Any other `/`-prefixed token; carries that token verbatim (slash included).
    Unknown(String),
}
39
40impl SlashCommand {
41    fn parse(input: &str) -> Option<Self> {
42        let trimmed = input.trim();
43        if !trimmed.starts_with('/') {
44            return None;
45        }
46        let cmd = trimmed.split_whitespace().next().unwrap_or("");
47        Some(match cmd {
48            "/help" | "/h" | "/?" => Self::Help,
49            "/quit" | "/q" | "/exit" => Self::Quit,
50            "/cost" => Self::Cost,
51            "/context" | "/ctx" => Self::Context,
52            "/model" => Self::Model,
53            "/compact" => Self::Compact,
54            "/clear" => Self::Clear,
55            "/session" => Self::Session,
56            "/sessions" => Self::Sessions,
57            "/test" => Self::Test,
58            "/quality" => Self::Quality,
59            other => Self::Unknown(other.to_string()),
60        })
61    }
62}
63
/// Auto-compaction threshold (80% of context window). See apr-code.md §7.3.
///
/// Compared against the fraction returned by `ReplSession::context_usage`:
/// compaction is attempted once usage is at or above this value.
const AUTO_COMPACT_THRESHOLD: f64 = 0.80;
66
/// Session state tracked across turns.
///
/// The counters are advanced by `record_turn` after each successful turn.
/// When a session is resumed only `turn_count` is restored; the token,
/// tool-call and cost totals start again from zero.
pub(super) struct ReplSession {
    /// Number of recorded turns (seeded from the stored manifest on resume).
    pub(super) turn_count: u32,
    /// Sum of `usage.input_tokens` over the turns recorded in this run.
    pub(super) total_input_tokens: u64,
    /// Sum of `usage.output_tokens` over the turns recorded in this run.
    pub(super) total_output_tokens: u64,
    /// Sum of tool calls over the turns recorded in this run.
    pub(super) total_tool_calls: u32,
    /// Accumulated per-turn cost estimates, in US dollars.
    pub(super) estimated_cost_usd: f64,
    /// Persistent session store (JSONL). None if persistence failed to init.
    pub(super) store: Option<SessionStore>,
    /// Context window size in tokens (from driver).
    pub(super) context_window: usize,
}
79
80impl ReplSession {
81    fn new(agent_name: &str, context_window: usize) -> Self {
82        let store = SessionStore::create(agent_name).ok();
83        Self {
84            turn_count: 0,
85            total_input_tokens: 0,
86            total_output_tokens: 0,
87            total_tool_calls: 0,
88            estimated_cost_usd: 0.0,
89            store,
90            context_window,
91        }
92    }
93
94    fn record_turn(&mut self, result: &AgentLoopResult, cost: f64) {
95        self.turn_count += 1;
96        self.total_input_tokens += result.usage.input_tokens;
97        self.total_output_tokens += result.usage.output_tokens;
98        self.total_tool_calls += result.tool_calls;
99        self.estimated_cost_usd += cost;
100        // Persist turn count
101        if let Some(ref mut store) = self.store {
102            let _ = store.record_turn();
103        }
104    }
105
106    /// Persist new messages from this turn to JSONL.
107    fn persist_messages(&self, history: &[Message], prev_len: usize) {
108        if let Some(ref store) = self.store {
109            let new_msgs = &history[prev_len..];
110            if !new_msgs.is_empty() {
111                let _ = store.append_messages(new_msgs);
112            }
113        }
114    }
115
116    pub(crate) fn session_id(&self) -> Option<&str> {
117        self.store.as_ref().map(|s| s.id())
118    }
119
120    /// Estimate total tokens used by conversation history.
121    fn estimate_history_tokens(history: &[Message]) -> usize {
122        let estimator = TokenEstimator::new();
123        let chat_msgs: Vec<_> = history.iter().map(Message::to_chat_message).collect();
124        estimator.estimate_messages(&chat_msgs)
125    }
126
127    /// Context usage as fraction (0.0–1.0).
128    pub(super) fn context_usage(&self, history: &[Message]) -> f64 {
129        if self.context_window == 0 {
130            return 0.0;
131        }
132        Self::estimate_history_tokens(history) as f64 / self.context_window as f64
133    }
134
135    /// Auto-compact if history exceeds 80% of context window.
136    /// Returns true if compaction was triggered.
137    fn auto_compact_if_needed(&self, history: &mut Vec<Message>) -> bool {
138        let usage = self.context_usage(history);
139        if usage >= AUTO_COMPACT_THRESHOLD {
140            let before = history.len();
141            compact_history(history);
142            let after = history.len();
143            if after < before {
144                println!(
145                    "  {} Auto-compacted: {} → {} messages ({:.0}% context)",
146                    "⚙".dimmed(),
147                    before,
148                    after,
149                    self.context_usage(history) * 100.0
150                );
151                return true;
152            }
153        }
154        false
155    }
156}
157
158/// Resume an existing session or create a new one.
159fn resume_or_new(
160    resume_id: Option<&str>,
161    agent_name: &str,
162    ctx_window: usize,
163) -> (ReplSession, Vec<Message>) {
164    if let Some(id) = resume_id {
165        if let Ok(store) = SessionStore::resume(id) {
166            let msgs = store.load_messages().unwrap_or_default();
167            let turns = store.manifest.turns;
168            println!("  {} Resumed {} ({turns} turns, {} msgs)", "✓".green(), id, msgs.len());
169            let s = ReplSession {
170                turn_count: turns,
171                total_input_tokens: 0,
172                total_output_tokens: 0,
173                total_tool_calls: 0,
174                estimated_cost_usd: 0.0,
175                store: Some(store),
176                context_window: ctx_window,
177            };
178            return (s, msgs);
179        }
180        println!("  {} Could not resume session: {id}", "⚠".bright_yellow());
181    }
182    let s = ReplSession::new(agent_name, ctx_window);
183    if let Some(id) = s.session_id() {
184        println!("  {} {}", "Session:".dimmed(), id.dimmed());
185    }
186    (s, Vec::new())
187}
188
189/// Run the interactive REPL.
190///
191/// This is the main entry point for `apr code` interactive mode.
192/// Returns when the user types `/quit` or Ctrl+D.
193///
194/// If `resume_id` is provided, loads conversation history from
195/// the corresponding session in `~/.apr/sessions/`.
196pub fn run_repl(
197    manifest: &AgentManifest,
198    driver: &dyn LlmDriver,
199    tools: &ToolRegistry,
200    memory: &dyn MemorySubstrate,
201    max_turns: u32,
202    budget_usd: f64,
203    resume_id: Option<&str>,
204) -> anyhow::Result<()> {
205    let rt = tokio::runtime::Builder::new_current_thread()
206        .enable_all()
207        .build()
208        .map_err(|e| anyhow::anyhow!("tokio runtime: {e}"))?;
209
210    print_welcome(manifest, driver);
211
212    let ctx_window = driver.context_window();
213
214    let (mut session, mut history) = resume_or_new(resume_id, &manifest.name, ctx_window);
215
216    let stdin = io::stdin();
217    let mut line_buf = String::new();
218
219    loop {
220        // Check turn budget
221        if session.turn_count >= max_turns {
222            println!(
223                "\n{} Max turns ({}) reached. Session complete.",
224                "⚠".bright_yellow(),
225                max_turns
226            );
227            break;
228        }
229        if session.estimated_cost_usd >= budget_usd {
230            println!(
231                "\n{} Budget (${:.2}) exhausted. Session complete.",
232                "⚠".bright_yellow(),
233                budget_usd
234            );
235            break;
236        }
237
238        // Read input
239        let input = match read_input(&stdin, &mut line_buf, &session, budget_usd, &mut history) {
240            InputResult::Prompt(s) => s,
241            InputResult::SlashHandled => continue,
242            InputResult::Exit => break,
243            InputResult::Empty => continue,
244        };
245
246        // Execute turn with streaming
247        let cancel = Arc::new(AtomicBool::new(false));
248        let cancel_clone = Arc::clone(&cancel);
249
250        rt.block_on(async {
251            let flag = cancel_clone;
252            tokio::spawn(async move {
253                if tokio::signal::ctrl_c().await.is_ok() {
254                    flag.store(true, Ordering::SeqCst);
255                }
256            });
257        });
258
259        let (tx, rx) = mpsc::channel::<StreamEvent>(64);
260
261        println!();
262
263        let history_len_before = history.len();
264        let result = rt.block_on(run_turn_streaming(
265            manifest,
266            &input,
267            driver,
268            tools,
269            memory,
270            &mut history,
271            tx,
272            rx,
273            &cancel,
274        ));
275
276        match result {
277            Ok(r) => {
278                let cost = driver.estimate_cost(&r.usage);
279                session.record_turn(&r, cost);
280                // Persist new messages to JSONL
281                session.persist_messages(&history, history_len_before);
282                // Auto-compact at 80% context window (spec §7.3)
283                session.auto_compact_if_needed(&mut history);
284                print_turn_footer(&r, cost, &session, budget_usd);
285            }
286            Err(e) => {
287                if cancel.load(Ordering::SeqCst) {
288                    println!("\n{} Generation cancelled.", "⚠".bright_yellow());
289                } else {
290                    println!("\n{} Error: {e}", "✗".bright_red());
291                }
292            }
293        }
294    }
295
296    print_session_summary(&session);
297    Ok(())
298}
299
/// Outcome of reading one line of input in `read_input`.
#[derive(Debug, PartialEq, Eq)]
enum InputResult {
    /// A non-empty, non-command line to send to the agent (already trimmed).
    Prompt(String),
    /// A slash command other than quit was handled; read another line.
    SlashHandled,
    /// End of input, a read error, or a quit command: leave the REPL.
    Exit,
    /// The line was blank; read another line.
    Empty,
}
307
308/// Read one line of input, handling slash commands inline.
309fn read_input(
310    stdin: &io::Stdin,
311    buf: &mut String,
312    session: &ReplSession,
313    budget: f64,
314    history: &mut Vec<Message>,
315) -> InputResult {
316    let cost_str = if session.estimated_cost_usd > 0.0 {
317        format!(" ${:.3}", session.estimated_cost_usd)
318    } else {
319        String::new()
320    };
321    print!(
322        "\n{}{} ",
323        format!("[{}/{}{}]", session.turn_count + 1, "?", cost_str).dimmed(),
324        " >".bright_green().bold(),
325    );
326    io::stdout().flush().ok();
327
328    buf.clear();
329    let bytes = match stdin.read_line(buf) {
330        Ok(b) => b,
331        Err(_) => return InputResult::Exit,
332    };
333    if bytes == 0 {
334        println!();
335        return InputResult::Exit;
336    }
337
338    let trimmed = buf.trim();
339    if trimmed.is_empty() {
340        return InputResult::Empty;
341    }
342
343    // Handle slash commands
344    if let Some(cmd) = SlashCommand::parse(trimmed) {
345        handle_slash_command(&cmd, session, budget, history);
346        return match cmd {
347            SlashCommand::Quit => InputResult::Exit,
348            _ => InputResult::SlashHandled,
349        };
350    }
351
352    InputResult::Prompt(trimmed.to_string())
353}
354
355/// Handle a slash command.
356fn handle_slash_command(
357    cmd: &SlashCommand,
358    session: &ReplSession,
359    budget: f64,
360    history: &mut Vec<Message>,
361) {
362    match cmd {
363        SlashCommand::Help => print_help(),
364        SlashCommand::Quit => println!("{} Goodbye.", "✓".green()),
365        SlashCommand::Cost => {
366            // PMAT-169: local inference is free — show tokens, not misleading dollars
367            if session.estimated_cost_usd < 0.0001 {
368                println!("  Cost: {} (local inference)", "free".green());
369            } else {
370                println!(
371                    "  Cost: ${:.4} / ${:.2} ({:.1}%)",
372                    session.estimated_cost_usd,
373                    budget,
374                    (session.estimated_cost_usd / budget * 100.0).min(100.0)
375                );
376            }
377            println!(
378                "  Tokens: {} in / {} out",
379                session.total_input_tokens, session.total_output_tokens
380            );
381            println!("  Turns: {}, Tool calls: {}", session.turn_count, session.total_tool_calls);
382        }
383        SlashCommand::Context => {
384            let user_msgs = history.iter().filter(|m| matches!(m, Message::User(_))).count();
385            let asst_msgs = history.iter().filter(|m| matches!(m, Message::Assistant(_))).count();
386            let tool_msgs = history
387                .iter()
388                .filter(|m| matches!(m, Message::AssistantToolUse(_) | Message::ToolResult(_)))
389                .count();
390            let usage_pct = session.context_usage(history) * 100.0;
391            let est_tokens = ReplSession::estimate_history_tokens(history);
392            println!(
393                "  History: {} messages ({} user, {} assistant, {} tool)",
394                history.len(),
395                user_msgs,
396                asst_msgs,
397                tool_msgs
398            );
399            println!(
400                "  Context: ~{} / {} tokens ({:.0}%)",
401                est_tokens, session.context_window, usage_pct
402            );
403            if usage_pct >= 80.0 {
404                println!("  {} Near context limit — /compact to free space", "⚠".bright_yellow());
405            }
406            println!("  Turns: {}", session.turn_count);
407        }
408        SlashCommand::Model => {
409            println!("  Model switching not yet implemented.");
410        }
411        SlashCommand::Compact => {
412            let before = history.len();
413            compact_history(history);
414            println!("  Compacted: {} -> {} messages", before, history.len());
415        }
416        SlashCommand::Clear => {
417            history.clear();
418            print!("\x1B[2J\x1B[1;1H");
419            io::stdout().flush().ok();
420            println!("  Screen and conversation history cleared.");
421        }
422        SlashCommand::Session => {
423            if let Some(id) = session.session_id() {
424                println!("  Session: {id}");
425                println!("  Turns: {}, Messages: {}", session.turn_count, history.len());
426            } else {
427                println!("  No active session (persistence disabled).");
428            }
429        }
430        SlashCommand::Sessions => {
431            list_recent_sessions();
432        }
433        SlashCommand::Test => {
434            println!("  Running tests...");
435            let _ = io::stdout().flush();
436            run_shell_shortcut("cargo test --lib 2>&1 | tail -5");
437        }
438        SlashCommand::Quality => {
439            println!("  Running quality gate...");
440            let _ = io::stdout().flush();
441            run_shell_shortcut("cargo clippy -- -D warnings 2>&1 | tail -3 && cargo test --lib --quiet 2>&1 | tail -3");
442        }
443        SlashCommand::Unknown(name) => {
444            println!("  {} Unknown command: {name}. Type /help for commands.", "?".bright_yellow());
445        }
446    }
447}
448
449/// Execute one turn with streaming output and multi-turn history.
450#[allow(clippy::too_many_arguments)]
451async fn run_turn_streaming(
452    manifest: &AgentManifest,
453    prompt: &str,
454    driver: &dyn LlmDriver,
455    tools: &ToolRegistry,
456    memory: &dyn MemorySubstrate,
457    history: &mut Vec<Message>,
458    tx: mpsc::Sender<StreamEvent>,
459    mut rx: mpsc::Receiver<StreamEvent>,
460    cancel: &Arc<AtomicBool>,
461) -> Result<AgentLoopResult, crate::agent::result::AgentError> {
462    // Drain task: print streaming events as they arrive
463    let drain = tokio::spawn(async move {
464        while let Some(event) = rx.recv().await {
465            print_stream_event_repl(&event);
466        }
467    });
468
469    let result = crate::agent::runtime::run_agent_turn(
470        manifest,
471        history,
472        prompt,
473        driver,
474        tools,
475        memory,
476        Some(tx),
477    )
478    .await;
479
480    // If cancelled, wrap the error
481    if cancel.load(Ordering::SeqCst) && result.is_err() {
482        return Err(crate::agent::result::AgentError::CircuitBreak("cancelled by user".into()));
483    }
484
485    // Ensure drain task finishes
486    let _ = drain.await;
487    result
488}
489
490// Display functions extracted to repl_display.rs for file size compliance.
491use super::repl_display::{
492    compact_history, list_recent_sessions, print_help, print_session_summary,
493    print_stream_event_repl, print_turn_footer, print_welcome, run_shell_shortcut,
494};
495
496#[cfg(test)]
497#[path = "repl_tests.rs"]
498mod tests;