1use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Stdio;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
15use tokio::process::{Child, Command};
16use tokio::time::timeout;
17
18use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
19
/// Execution backend requested for an RLM REPL session.
///
/// Serialized with lowercase variant names (`"rust"`, `"bun"`, `"python"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ReplRuntime {
    /// Built-in, in-process DSL interpreter (the default).
    #[default]
    Rust,
    /// JavaScript runtime; `ExternalRepl::spawn_bun` starts one (falling back to `node`).
    Bun,
    /// Python runtime. NOTE(review): no Python process is spawned in this file;
    /// `RlmRepl::execute` routes this variant to the built-in DSL interpreter.
    Python,
}
32
/// In-process REPL over a single text context.
///
/// Interprets a small line-oriented DSL (`head`, `tail`, `grep`, `count`, `slice`,
/// `chunks`, `print`, assignments and `FINAL`) against the loaded context.
pub struct RlmRepl {
    /// Backend requested at construction; `execute` dispatches on it.
    runtime: ReplRuntime,
    /// The full context text.
    context: String,
    /// `context` split with `str::lines`, cached once in `new`.
    context_lines: Vec<String>,
    /// Variables bound by DSL assignments, stored as strings.
    variables: HashMap<String, String>,
}
40
/// Outcome of executing a code snippet in a REPL.
#[derive(Debug, Clone)]
pub struct ReplResult {
    /// Output produced by the snippet, one entry per line, joined with `\n`.
    pub stdout: String,
    /// Error output, when the executing backend captures any.
    pub stderr: String,
    /// The answer passed to `FINAL(...)`, if the snippet called it.
    pub final_answer: Option<String>,
}
48
49impl RlmRepl {
50 pub fn new(context: String, runtime: ReplRuntime) -> Self {
52 let context_lines = context.lines().map(|s| s.to_string()).collect();
53 Self {
54 runtime,
55 context,
56 context_lines,
57 variables: HashMap::new(),
58 }
59 }
60
61 pub fn context(&self) -> &str {
63 &self.context
64 }
65
66 pub fn lines(&self) -> &[String] {
68 &self.context_lines
69 }
70
71 pub fn head(&self, n: usize) -> Vec<&str> {
73 self.context_lines
74 .iter()
75 .take(n)
76 .map(|s| s.as_str())
77 .collect()
78 }
79
80 pub fn tail(&self, n: usize) -> Vec<&str> {
82 let start = self.context_lines.len().saturating_sub(n);
83 self.context_lines
84 .iter()
85 .skip(start)
86 .map(|s| s.as_str())
87 .collect()
88 }
89
90 pub fn grep(&self, pattern: &str) -> Vec<(usize, &str)> {
92 let re = match regex::Regex::new(pattern) {
93 Ok(r) => r,
94 Err(_) => {
95 return self
97 .context_lines
98 .iter()
99 .enumerate()
100 .filter(|(_, line)| line.contains(pattern))
101 .map(|(i, line)| (i + 1, line.as_str()))
102 .collect();
103 }
104 };
105
106 self.context_lines
107 .iter()
108 .enumerate()
109 .filter(|(_, line)| re.is_match(line))
110 .map(|(i, line)| (i + 1, line.as_str()))
111 .collect()
112 }
113
114 pub fn count(&self, pattern: &str) -> usize {
116 let re = match regex::Regex::new(pattern) {
117 Ok(r) => r,
118 Err(_) => return self.context.matches(pattern).count(),
119 };
120 re.find_iter(&self.context).count()
121 }
122
123 pub fn slice(&self, start: usize, end: usize) -> &str {
125 let end = end.min(self.context.len());
126 let start = start.min(end);
127 &self.context[start..end]
128 }
129
130 pub fn chunks(&self, n: usize) -> Vec<String> {
132 if n == 0 {
133 return vec![self.context.clone()];
134 }
135
136 let chunk_size = self.context_lines.len().div_ceil(n);
137 self.context_lines
138 .chunks(chunk_size)
139 .map(|chunk| chunk.join("\n"))
140 .collect()
141 }
142
143 pub fn set_var(&mut self, name: &str, value: String) {
145 self.variables.insert(name.to_string(), value);
146 }
147
148 pub fn get_var(&self, name: &str) -> Option<&str> {
150 self.variables.get(name).map(|s| s.as_str())
151 }
152
153 pub fn execute(&mut self, code: &str) -> ReplResult {
164 match self.runtime {
165 ReplRuntime::Rust => self.execute_rust_dsl(code),
166 ReplRuntime::Bun | ReplRuntime::Python => {
167 self.execute_rust_dsl(code)
170 }
171 }
172 }
173
174 fn execute_rust_dsl(&mut self, code: &str) -> ReplResult {
175 let mut stdout = Vec::new();
176 let mut final_answer = None;
177
178 for line in code.lines() {
179 let line = line.trim();
180 if line.is_empty() || line.starts_with("//") || line.starts_with('#') {
181 continue;
182 }
183
184 if let Some(result) = self.execute_dsl_line(line) {
186 match result {
187 DslResult::Output(s) => stdout.push(s),
188 DslResult::Final(s) => {
189 final_answer = Some(s);
190 break;
191 }
192 DslResult::Error(s) => stdout.push(format!("Error: {}", s)),
193 }
194 }
195 }
196
197 ReplResult {
198 stdout: stdout.join("\n"),
199 stderr: String::new(),
200 final_answer,
201 }
202 }
203
204 pub fn execute_dsl_line(&mut self, line: &str) -> Option<DslResult> {
205 if line.starts_with("FINAL(") || line.starts_with("FINAL!(") {
207 let start = line.find('(').unwrap() + 1;
208 let end = line.rfind(')').unwrap_or(line.len());
209 let answer = line[start..end]
210 .trim()
211 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
212 return Some(DslResult::Final(answer.to_string()));
213 }
214
215 if line.starts_with("print(")
217 || line.starts_with("println!(")
218 || line.starts_with("console.log(")
219 {
220 let start = line.find('(').unwrap() + 1;
221 let end = line.rfind(')').unwrap_or(line.len());
222 let content = line[start..end]
223 .trim()
224 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
225
226 let expanded = self.expand_expression(content);
228 return Some(DslResult::Output(expanded));
229 }
230
231 if let Some(eq_pos) = line.find('=') {
233 if !line.contains("==") && !line.starts_with("if ") {
234 let var_name = line[..eq_pos]
235 .trim()
236 .trim_start_matches("let ")
237 .trim_start_matches("const ")
238 .trim_start_matches("var ")
239 .trim();
240 let expr = line[eq_pos + 1..].trim().trim_end_matches(';');
241
242 let value = self.evaluate_expression(expr);
243 self.set_var(var_name, value);
244 return None;
245 }
246 }
247
248 if line.starts_with("head(")
250 || line.starts_with("tail(")
251 || line.starts_with("grep(")
252 || line.starts_with("count(")
253 || line.starts_with("lines()")
254 || line.starts_with("slice(")
255 || line.starts_with("chunks(")
256 || line.starts_with("context")
257 {
258 let result = self.evaluate_expression(line);
259 return Some(DslResult::Output(result));
260 }
261
262 None
263 }
264
265 fn expand_expression(&self, expr: &str) -> String {
266 let mut result = expr.to_string();
268
269 for (name, value) in &self.variables {
270 let patterns = [
271 format!("${{{}}}", name),
272 format!("${}", name),
273 format!("{{{}}}", name),
274 ];
275 for p in patterns {
276 result = result.replace(&p, value);
277 }
278 }
279
280 if result.contains("context.len()") || result.contains("context.length") {
282 result = result
283 .replace("context.len()", &self.context.len().to_string())
284 .replace("context.length", &self.context.len().to_string());
285 }
286
287 if result.contains("lines().len()") || result.contains("lines().length") {
288 result = result
289 .replace("lines().len()", &self.context_lines.len().to_string())
290 .replace("lines().length", &self.context_lines.len().to_string());
291 }
292
293 result
294 }
295
296 pub fn evaluate_expression(&mut self, expr: &str) -> String {
297 let expr = expr.trim().trim_end_matches(';');
298
299 if expr.starts_with("head(") {
301 let n = self.extract_number(expr).unwrap_or(10);
302 return self.head(n).join("\n");
303 }
304
305 if expr.starts_with("tail(") {
307 let n = self.extract_number(expr).unwrap_or(10);
308 return self.tail(n).join("\n");
309 }
310
311 if expr.starts_with("grep(") {
313 let pattern = self.extract_string(expr).unwrap_or_default();
314 let matches = self.grep(&pattern);
315 return matches
316 .iter()
317 .map(|(i, line)| format!("{}:{}", i, line))
318 .collect::<Vec<_>>()
319 .join("\n");
320 }
321
322 if expr.starts_with("count(") {
324 let pattern = self.extract_string(expr).unwrap_or_default();
325 return self.count(&pattern).to_string();
326 }
327
328 if expr == "lines()" || expr == "lines" {
330 return format!("Lines: {}", self.context_lines.len());
331 }
332
333 if expr.starts_with("slice(") {
335 let nums = self.extract_numbers(expr);
336 if nums.len() >= 2 {
337 return self.slice(nums[0], nums[1]).to_string();
338 }
339 }
340
341 if expr.starts_with("chunks(") || expr.starts_with("chunk(") {
343 let n = self.extract_number(expr).unwrap_or(5);
344 let chunks = self.chunks(n);
345 return format!(
346 "[{} chunks of {} lines each]",
347 chunks.len(),
348 chunks.first().map(|c| c.lines().count()).unwrap_or(0)
349 );
350 }
351
352 if expr == "context" || expr.starts_with("context.slice") || expr.starts_with("context[") {
354 return format!(
355 "[Context: {} chars, {} lines]",
356 self.context.len(),
357 self.context_lines.len()
358 );
359 }
360
361 if let Some(val) = self.get_var(expr) {
363 return val.to_string();
364 }
365
366 if (expr.starts_with('"') && expr.ends_with('"'))
368 || (expr.starts_with('\'') && expr.ends_with('\''))
369 {
370 return expr[1..expr.len() - 1].to_string();
371 }
372
373 expr.to_string()
374 }
375
376 fn extract_number(&self, expr: &str) -> Option<usize> {
377 let start = expr.find('(')?;
378 let end = expr.find(')')?;
379 let inner = expr[start + 1..end].trim();
380 inner.parse().ok()
381 }
382
383 fn extract_numbers(&self, expr: &str) -> Vec<usize> {
384 let start = expr.find('(').unwrap_or(0);
385 let end = expr.find(')').unwrap_or(expr.len());
386 let inner = &expr[start + 1..end];
387
388 inner
389 .split(',')
390 .filter_map(|s| s.trim().parse().ok())
391 .collect()
392 }
393
394 fn extract_string(&self, expr: &str) -> Option<String> {
395 let start = expr.find('(')?;
396 let end = expr.rfind(')')?;
397 let inner = expr[start + 1..end].trim();
398
399 let unquoted = inner
401 .trim_start_matches(['"', '\'', '`', '/'])
402 .trim_end_matches(['"', '\'', '`', '/']);
403
404 Some(unquoted.to_string())
405 }
406}
407
/// Outcome of interpreting a single DSL line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DslResult {
    /// Text to append to the snippet's output.
    Output(String),
    /// Answer passed to `FINAL(...)`; the executors stop processing the snippet.
    Final(String),
    /// Error text; the executors render it as `Error: <text>` in the output.
    // Not constructed by the interpreter yet, but both executors already handle it.
    #[allow(dead_code)]
    Error(String),
}
414
/// Drives an iterative analysis loop over one context.
///
/// Each iteration asks the model, executes the DSL commands it emits against an
/// [`RlmRepl`], and feeds the output back. The loop stops on `FINAL(...)` or
/// after `max_iterations`.
pub struct RlmExecutor {
    /// REPL holding the context under analysis.
    repl: RlmRepl,
    /// Model provider used for the main loop and for `llm_query` sub-calls.
    provider: Arc<dyn Provider>,
    /// Model identifier sent with every completion request.
    model: String,
    /// Upper bound on main-loop round-trips (set to 5 by `new`).
    max_iterations: usize,
    /// `llm_query` sub-calls issued by this executor.
    sub_queries: Vec<SubQuery>,
    /// When true, progress is printed to stdout.
    verbose: bool,
}
430
/// Record of one `llm_query` sub-call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubQuery {
    /// The question sent to the sub-model.
    pub query: String,
    /// Context passed as the second `llm_query` argument, if any.
    /// `None` means the full REPL context was sent.
    pub context_slice: Option<String>,
    /// Text of the sub-model's reply.
    pub response: String,
    /// `usage.total_tokens` reported by the provider for this sub-call.
    pub tokens_used: usize,
}
439
440impl RlmExecutor {
441 pub fn new(context: String, provider: Arc<dyn Provider>, model: String) -> Self {
443 Self {
444 repl: RlmRepl::new(context, ReplRuntime::Rust),
445 provider,
446 model,
447 max_iterations: 5, sub_queries: Vec::new(),
449 verbose: false,
450 }
451 }
452
453 pub fn with_max_iterations(mut self, max: usize) -> Self {
455 self.max_iterations = max;
456 self
457 }
458
459 pub fn with_verbose(mut self, verbose: bool) -> Self {
464 self.verbose = verbose;
465 self
466 }
467
468 pub async fn analyze(&mut self, query: &str) -> Result<RlmAnalysisResult> {
470 let start = std::time::Instant::now();
471 let mut iterations = 0;
472 let mut total_input_tokens = 0;
473 let mut total_output_tokens = 0;
474
475 let context_summary = format!(
477 "=== CONTEXT LOADED ===\n\
478 Total: {} chars, {} lines\n\
479 Available functions:\n\
480 - head(n) - first n lines\n\
481 - tail(n) - last n lines\n\
482 - grep(\"pattern\") - find lines matching regex\n\
483 - count(\"pattern\") - count regex matches\n\
484 - slice(start, end) - slice by char position\n\
485 - chunks(n) - split into n chunks\n\
486 - llm_query(\"question\", context?) - ask sub-LM a question\n\
487 - FINAL(\"answer\") - return final answer\n\
488 === END CONTEXT INFO ===",
489 self.repl.context().len(),
490 self.repl.lines().len()
491 );
492
493 if self.verbose {
495 tracing::info!("RLM Context Summary:\n{}", context_summary);
496 println!(
497 "[RLM] Context loaded: {} chars, {} lines",
498 self.repl.context().len(),
499 self.repl.lines().len()
500 );
501 }
502
503 let system_prompt = format!(
504 "You are a code analysis assistant. Answer questions by examining the provided context.\n\n\
505 IMPORTANT: You MUST end your response with FINAL(\"your answer\") in 1-3 iterations.\n\n\
506 Available commands:\n\
507 - head(n), tail(n): See first/last n lines\n\
508 - grep(\"pattern\"): Search for patterns\n\
509 - llm_query(\"question\"): Ask a focused sub-question\n\
510 - FINAL(\"answer\"): Return your final answer (REQUIRED)\n\n\
511 The context has {} chars across {} lines. A preview follows:\n\n\
512 {}\n\n\
513 Now analyze the context. Use 1-2 commands if needed, then call FINAL() with your answer.",
514 self.repl.context().len(),
515 self.repl.lines().len(),
516 self.repl.head(25).join("\n")
517 );
518
519 let mut messages = vec![
520 Message {
521 role: Role::System,
522 content: vec![ContentPart::Text {
523 text: system_prompt,
524 }],
525 },
526 Message {
527 role: Role::User,
528 content: vec![ContentPart::Text {
529 text: format!("Analyze and answer: {}", query),
530 }],
531 },
532 ];
533
534 let mut final_answer = None;
535
536 while iterations < self.max_iterations {
537 iterations += 1;
538 tracing::info!("RLM iteration {}", iterations);
539
540 tracing::debug!("Sending LLM request...");
542 let response = match tokio::time::timeout(
543 std::time::Duration::from_secs(60),
544 self.provider.complete(CompletionRequest {
545 messages: messages.clone(),
546 tools: vec![],
547 model: self.model.clone(),
548 temperature: Some(0.3),
549 top_p: None,
550 max_tokens: Some(2000),
551 stop: vec![],
552 }),
553 )
554 .await
555 {
556 Ok(Ok(r)) => {
557 tracing::debug!("LLM response received");
558 r
559 }
560 Ok(Err(e)) => return Err(e),
561 Err(_) => return Err(anyhow::anyhow!("LLM request timed out after 60 seconds")),
562 };
563
564 total_input_tokens += response.usage.prompt_tokens;
565 total_output_tokens += response.usage.completion_tokens;
566
567 let assistant_text = response
569 .message
570 .content
571 .iter()
572 .filter_map(|p| match p {
573 ContentPart::Text { text } => Some(text.as_str()),
574 _ => None,
575 })
576 .collect::<Vec<_>>()
577 .join("");
578
579 messages.push(Message {
581 role: Role::Assistant,
582 content: vec![ContentPart::Text {
583 text: assistant_text.clone(),
584 }],
585 });
586
587 let code = self.extract_code(&assistant_text);
589
590 if self.verbose {
592 println!("[RLM] Iteration {}: Executing code:\n{}", iterations, code);
593 }
594
595 let execution_result = self.execute_with_llm_query(&code).await?;
596
597 if self.verbose {
599 if let Some(ref answer) = execution_result.final_answer {
600 println!("[RLM] Final answer received: {}", answer);
601 } else if !execution_result.stdout.is_empty() {
602 let preview = if execution_result.stdout.len() > 200 {
603 format!("{}...", &execution_result.stdout[..200])
604 } else {
605 execution_result.stdout.clone()
606 };
607 println!("[RLM] Execution output:\n{}", preview);
608 }
609 }
610
611 if let Some(answer) = &execution_result.final_answer {
613 final_answer = Some(answer.clone());
614 break;
615 }
616
617 let result_text = if execution_result.stdout.is_empty() {
619 "[No output]".to_string()
620 } else {
621 format!("Execution result:\n{}", execution_result.stdout)
622 };
623
624 messages.push(Message {
625 role: Role::User,
626 content: vec![ContentPart::Text { text: result_text }],
627 });
628 }
629
630 let elapsed = start.elapsed();
631
632 Ok(RlmAnalysisResult {
633 answer: final_answer.unwrap_or_else(|| "Analysis incomplete".to_string()),
634 iterations,
635 sub_queries: self.sub_queries.clone(),
636 stats: super::RlmStats {
637 input_tokens: total_input_tokens,
638 output_tokens: total_output_tokens,
639 iterations,
640 subcalls: self.sub_queries.len(),
641 elapsed_ms: elapsed.as_millis() as u64,
642 compression_ratio: 1.0,
643 },
644 })
645 }
646
647 fn extract_code(&self, text: &str) -> String {
649 let mut code_lines = Vec::new();
651 let mut in_code_block = false;
652
653 for line in text.lines() {
654 if line.starts_with("```") {
655 in_code_block = !in_code_block;
656 continue;
657 }
658 if in_code_block {
659 code_lines.push(line);
660 }
661 }
662
663 if !code_lines.is_empty() {
664 return code_lines.join("\n");
665 }
666
667 text.lines()
669 .filter(|line| {
670 let l = line.trim();
671 l.starts_with("head(")
672 || l.starts_with("tail(")
673 || l.starts_with("grep(")
674 || l.starts_with("count(")
675 || l.starts_with("llm_query(")
676 || l.starts_with("FINAL(")
677 || l.starts_with("let ")
678 || l.starts_with("const ")
679 || l.starts_with("print")
680 || l.starts_with("console.")
681 })
682 .collect::<Vec<_>>()
683 .join("\n")
684 }
685
686 async fn execute_with_llm_query(&mut self, code: &str) -> Result<ReplResult> {
688 let mut stdout = Vec::new();
689 let mut final_answer = None;
690
691 for line in code.lines() {
692 let line = line.trim();
693 if line.is_empty() || line.starts_with("//") || line.starts_with('#') {
694 continue;
695 }
696
697 if line.starts_with("llm_query(") || line.contains("= llm_query(") {
699 let result = self.handle_llm_query(line).await?;
700 stdout.push(result);
701 continue;
702 }
703
704 if let Some(result) = self.repl.execute_dsl_line(line) {
706 match result {
707 DslResult::Output(s) => stdout.push(s),
708 DslResult::Final(s) => {
709 final_answer = Some(s);
710 break;
711 }
712 DslResult::Error(s) => stdout.push(format!("Error: {}", s)),
713 }
714 }
715 }
716
717 Ok(ReplResult {
718 stdout: stdout.join("\n"),
719 stderr: String::new(),
720 final_answer,
721 })
722 }
723
724 async fn handle_llm_query(&mut self, line: &str) -> Result<String> {
726 let (query, context_slice) = self.parse_llm_query(line);
728
729 let context_to_analyze = context_slice
731 .clone()
732 .unwrap_or_else(|| self.repl.context().to_string());
733
734 let truncated_context = if context_to_analyze.len() > 8000 {
736 format!(
737 "{}...\n[truncated, {} chars total]",
738 &context_to_analyze[..7500],
739 context_to_analyze.len()
740 )
741 } else {
742 context_to_analyze.clone()
743 };
744
745 let messages = vec![
747 Message {
748 role: Role::System,
749 content: vec![ContentPart::Text {
750 text: "You are a focused analysis assistant. Answer the question based on the provided context. Be concise.".to_string(),
751 }],
752 },
753 Message {
754 role: Role::User,
755 content: vec![ContentPart::Text {
756 text: format!("Context:\n{}\n\nQuestion: {}", truncated_context, query),
757 }],
758 },
759 ];
760
761 let response = self
762 .provider
763 .complete(CompletionRequest {
764 messages,
765 tools: vec![],
766 model: self.model.clone(),
767 temperature: Some(0.3),
768 top_p: None,
769 max_tokens: Some(500),
770 stop: vec![],
771 })
772 .await?;
773
774 let answer = response
775 .message
776 .content
777 .iter()
778 .filter_map(|p| match p {
779 ContentPart::Text { text } => Some(text.as_str()),
780 _ => None,
781 })
782 .collect::<Vec<_>>()
783 .join("");
784
785 self.sub_queries.push(SubQuery {
787 query: query.clone(),
788 context_slice,
789 response: answer.clone(),
790 tokens_used: response.usage.total_tokens,
791 });
792
793 Ok(format!("llm_query result: {}", answer))
794 }
795
796 fn parse_llm_query(&mut self, line: &str) -> (String, Option<String>) {
798 let start = line.find('(').unwrap_or(0) + 1;
800 let end = line.rfind(')').unwrap_or(line.len());
801 let args = &line[start..end];
802
803 let mut query = String::new();
805 let mut context = None;
806 let mut in_quotes = false;
807 let mut current = String::new();
808 let mut parts = Vec::new();
809
810 for c in args.chars() {
811 if c == '"' || c == '\'' {
812 in_quotes = !in_quotes;
813 } else if c == ',' && !in_quotes {
814 parts.push(current.trim().to_string());
815 current = String::new();
816 continue;
817 }
818 current.push(c);
819 }
820 if !current.is_empty() {
821 parts.push(current.trim().to_string());
822 }
823
824 if let Some(q) = parts.first() {
826 query = q.trim_matches(|c| c == '"' || c == '\'').to_string();
827 }
828
829 if let Some(ctx_expr) = parts.get(1) {
831 let ctx = self.repl.evaluate_expression(ctx_expr);
833 if !ctx.is_empty() && !ctx.starts_with('[') {
834 context = Some(ctx);
835 }
836 }
837
838 (query, context)
839 }
840}
841
/// Result returned by `RlmExecutor::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RlmAnalysisResult {
    /// The `FINAL(...)` answer, or `"Analysis incomplete"` when none was produced
    /// within the iteration budget.
    pub answer: String,
    /// Number of main-loop iterations performed.
    pub iterations: usize,
    /// `llm_query` sub-calls reported with this analysis.
    pub sub_queries: Vec<SubQuery>,
    /// Token, call-count and timing statistics.
    pub stats: super::RlmStats,
}
850
851pub struct ExternalRepl {
853 child: Child,
854 #[allow(dead_code)]
855 runtime: ReplRuntime,
856}
857
858impl ExternalRepl {
859 pub async fn spawn_bun(context: &str) -> Result<Self> {
861 let init_script = Self::generate_bun_init(context);
862
863 let temp_dir = std::env::temp_dir().join("rlm-repl");
865 tokio::fs::create_dir_all(&temp_dir).await?;
866 let script_path = temp_dir.join(format!("init_{}.js", std::process::id()));
867 tokio::fs::write(&script_path, init_script).await?;
868
869 let runtime = if Self::is_bun_available().await {
871 "bun"
872 } else {
873 "node"
874 };
875
876 let child = Command::new(runtime)
877 .arg(&script_path)
878 .stdin(Stdio::piped())
879 .stdout(Stdio::piped())
880 .stderr(Stdio::piped())
881 .spawn()?;
882
883 Ok(Self {
884 child,
885 runtime: ReplRuntime::Bun,
886 })
887 }
888
889 async fn is_bun_available() -> bool {
890 Command::new("bun")
891 .arg("--version")
892 .output()
893 .await
894 .map(|o| o.status.success())
895 .unwrap_or(false)
896 }
897
898 fn generate_bun_init(context: &str) -> String {
899 let escaped = context
900 .replace('\\', "\\\\")
901 .replace('"', "\\\"")
902 .replace('\n', "\\n");
903
904 format!(
905 r#"
906const readline = require('readline');
907const rl = readline.createInterface({{ input: process.stdin, output: process.stdout, terminal: false }});
908
909const context = "{escaped}";
910
911function lines() {{ return context.split("\n"); }}
912function head(n = 10) {{ return lines().slice(0, n).join("\n"); }}
913function tail(n = 10) {{ return lines().slice(-n).join("\n"); }}
914function grep(pattern) {{
915 const re = pattern instanceof RegExp ? pattern : new RegExp(pattern, 'gi');
916 return lines().filter(l => re.test(l));
917}}
918function count(pattern) {{
919 const re = pattern instanceof RegExp ? pattern : new RegExp(pattern, 'gi');
920 return (context.match(re) || []).length;
921}}
922function FINAL(answer) {{
923 console.log("__FINAL__" + String(answer) + "__FINAL_END__");
924}}
925
926console.log("READY");
927
928rl.on('line', async (line) => {{
929 try {{
930 const result = eval(line);
931 if (result !== undefined) console.log(result);
932 }} catch (e) {{
933 console.error("Error:", e.message);
934 }}
935 console.log("__DONE__");
936}});
937"#
938 )
939 }
940
941 pub async fn execute(&mut self, code: &str) -> Result<ReplResult> {
943 let stdin = self
944 .child
945 .stdin
946 .as_mut()
947 .ok_or_else(|| anyhow::anyhow!("No stdin"))?;
948 let stdout = self
949 .child
950 .stdout
951 .as_mut()
952 .ok_or_else(|| anyhow::anyhow!("No stdout"))?;
953
954 stdin.write_all(code.as_bytes()).await?;
955 stdin.write_all(b"\n").await?;
956 stdin.flush().await?;
957
958 let mut reader = BufReader::new(stdout);
959 let mut output = Vec::new();
960 let mut final_answer = None;
961
962 loop {
963 let mut line = String::new();
964 match timeout(Duration::from_secs(30), reader.read_line(&mut line)).await {
965 Ok(Ok(0)) | Err(_) => break, Ok(Ok(_)) => {
967 let line = line.trim();
968 if line == "__DONE__" {
969 break;
970 }
971 if let Some(answer) = Self::extract_final(line) {
972 final_answer = Some(answer);
973 break;
974 }
975 output.push(line.to_string());
976 }
977 Ok(Err(e)) => return Err(anyhow::anyhow!("Read error: {}", e)),
978 }
979 }
980
981 Ok(ReplResult {
982 stdout: output.join("\n"),
983 stderr: String::new(),
984 final_answer,
985 })
986 }
987
988 fn extract_final(line: &str) -> Option<String> {
989 if line.contains("__FINAL__") {
990 let start = line.find("__FINAL__")? + 9;
991 let end = line.find("__FINAL_END__")?;
992 return Some(line[start..end].to_string());
993 }
994 None
995 }
996
997 pub async fn destroy(&mut self) -> Result<()> {
999 tracing::debug!(runtime = ?self.runtime, "Destroying external REPL");
1000 self.child.kill().await?;
1001 Ok(())
1002 }
1003
1004 pub fn runtime(&self) -> ReplRuntime {
1006 self.runtime
1007 }
1008}
1009
#[cfg(test)]
mod tests {
    use super::*;

    /// Context of `count` lines reading "line 1" .. "line {count}", joined with '\n'.
    fn numbered_context(count: usize) -> String {
        let rows: Vec<String> = (1..=count).map(|i| format!("line {}", i)).collect();
        rows.join("\n")
    }

    #[test]
    fn test_repl_head_tail() {
        let repl = RlmRepl::new(numbered_context(100), ReplRuntime::Rust);

        let first = repl.head(5);
        assert_eq!(first.len(), 5);
        assert_eq!(first[0], "line 1");

        let last = repl.tail(5);
        assert_eq!(last.len(), 5);
        assert_eq!(last[4], "line 100");
    }

    #[test]
    fn test_repl_grep() {
        let log = String::from("error: something failed\ninfo: all good\nerror: another failure");
        let repl = RlmRepl::new(log, ReplRuntime::Rust);

        assert_eq!(repl.grep("error").len(), 2);
    }

    #[test]
    fn test_repl_execute_final() {
        let mut repl = RlmRepl::new(String::from("test content"), ReplRuntime::Rust);

        let outcome = repl.execute(r#"FINAL("This is the answer")"#);
        assert_eq!(outcome.final_answer.as_deref(), Some("This is the answer"));
    }

    #[test]
    fn test_repl_chunks() {
        let repl = RlmRepl::new(numbered_context(100), ReplRuntime::Rust);

        assert_eq!(repl.chunks(5).len(), 5);
    }
}