Skip to main content

run/engine/
cpp.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::sync::{Mutex, OnceLock};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{
11    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, cache_lookup, cache_store,
12    compiler_command, hash_source, perf_record, run_version_command, try_cached_execution,
13};
14
/// Language engine that compiles and runs C++ code with a system compiler.
///
/// The compiler is looked up once, when the engine is constructed. An engine
/// without a compiler can still be created; operations that need one report
/// an error at the point of use.
#[derive(Debug)]
pub struct CppEngine {
    /// Path of the compiler executable, or `None` when none was found.
    compiler: Option<PathBuf>,
}
18
19impl Default for CppEngine {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl CppEngine {
26    pub fn new() -> Self {
27        Self {
28            compiler: resolve_cpp_compiler(),
29        }
30    }
31
32    fn ensure_compiler(&self) -> Result<&Path> {
33        self.compiler.as_deref().ok_or_else(|| {
34            anyhow::anyhow!(
35                "C++ support requires a C++ compiler such as `c++`, `clang++`, or `g++`. Install one and ensure it is on your PATH."
36            )
37        })
38    }
39
40    fn write_source(&self, code: &str, dir: &Path) -> Result<PathBuf> {
41        let source_path = dir.join("main.cpp");
42        std::fs::write(&source_path, code).with_context(|| {
43            format!(
44                "failed to write temporary C++ source to {}",
45                source_path.display()
46            )
47        })?;
48        Ok(source_path)
49    }
50
51    fn copy_source(&self, original: &Path, dir: &Path) -> Result<PathBuf> {
52        let target = dir.join("main.cpp");
53        std::fs::copy(original, &target).with_context(|| {
54            format!(
55                "failed to copy C++ source from {} to {}",
56                original.display(),
57                target.display()
58            )
59        })?;
60        Ok(target)
61    }
62
63    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
64        let compiler = self.ensure_compiler()?;
65        let mut cmd = compiler_command(compiler);
66        cmd.arg(source)
67            .arg("-std=c++17")
68            .arg("-O0")
69            .arg("-w")
70            .arg("-o")
71            .arg(output)
72            .stdout(Stdio::piped())
73            .stderr(Stdio::piped());
74        if let Some(pch_header) = ensure_global_cpp_pch(compiler) {
75            cmd.arg("-include").arg(pch_header);
76        }
77        cmd.output().with_context(|| {
78            format!(
79                "failed to invoke {} to compile {}",
80                compiler.display(),
81                source.display()
82            )
83        })
84    }
85
86    fn run_binary(&self, binary: &Path, args: &[String]) -> Result<std::process::Output> {
87        let mut cmd = Command::new(binary);
88        cmd.args(args).stdout(Stdio::piped()).stderr(Stdio::piped());
89        cmd.stdin(Stdio::inherit());
90        cmd.output()
91            .with_context(|| format!("failed to execute compiled binary {}", binary.display()))
92    }
93
94    fn binary_path(dir: &Path) -> PathBuf {
95        let mut path = dir.join("run_cpp_binary");
96        let suffix = std::env::consts::EXE_SUFFIX;
97        if !suffix.is_empty() {
98            if let Some(stripped) = suffix.strip_prefix('.') {
99                path.set_extension(stripped);
100            } else {
101                path = PathBuf::from(format!("{}{}", path.display(), suffix));
102            }
103        }
104        path
105    }
106
107    fn execute_file_incremental(&self, source: &Path, args: &[String]) -> Result<ExecutionOutcome> {
108        let start = Instant::now();
109        let source_text = fs::read_to_string(source).unwrap_or_default();
110        let source_hash = hash_source(&source_text);
111
112        let compiler = self.ensure_compiler()?;
113        let source_key = source
114            .canonicalize()
115            .unwrap_or_else(|_| source.to_path_buf());
116        let workspace =
117            crate::cache::workspace("cpp-file", hash_source(&source_key.to_string_lossy()))?;
118        fs::create_dir_all(&workspace).with_context(|| {
119            format!(
120                "failed to create C++ incremental workspace {}",
121                workspace.display()
122            )
123        })?;
124        let obj = workspace.join("main.o");
125        let dep = workspace.join("main.d");
126        let bin = workspace.join("run_cpp_incremental_binary");
127
128        let needs_compile = cpp_needs_recompile(source, &obj, &dep);
129        if !needs_compile && bin.exists() {
130            perf_record("cpp", "file.workspace_hit");
131            cache_store("cpp-file", source_hash, &bin);
132            let run_output = self.run_binary(&bin, args)?;
133            return Ok(ExecutionOutcome {
134                language: self.id().to_string(),
135                exit_code: run_output.status.code(),
136                stdout: String::from_utf8_lossy(&run_output.stdout).into_owned(),
137                stderr: String::from_utf8_lossy(&run_output.stderr).into_owned(),
138                duration: start.elapsed(),
139            });
140        }
141
142        if let Some(cached_bin) = cache_lookup("cpp-file", source_hash) {
143            perf_record("cpp", "file.cache_hit");
144            let _ = fs::copy(&cached_bin, &bin);
145            let run_output = self.run_binary(&bin, args)?;
146            return Ok(ExecutionOutcome {
147                language: self.id().to_string(),
148                exit_code: run_output.status.code(),
149                stdout: String::from_utf8_lossy(&run_output.stdout).into_owned(),
150                stderr: String::from_utf8_lossy(&run_output.stderr).into_owned(),
151                duration: start.elapsed(),
152            });
153        }
154        perf_record("cpp", "file.cache_miss");
155
156        if needs_compile {
157            perf_record("cpp", "file.compile");
158            let mut compile = compiler_command(compiler);
159            compile
160                .arg(source)
161                .arg("-std=c++17")
162                .arg("-O0")
163                .arg("-w")
164                .arg("-c")
165                .arg("-MMD")
166                .arg("-MF")
167                .arg(&dep)
168                .arg("-o")
169                .arg(&obj)
170                .stdout(Stdio::piped())
171                .stderr(Stdio::piped());
172            let compile_out = compile.output().with_context(|| {
173                format!(
174                    "failed to invoke {} for incremental C++ compile",
175                    compiler.display()
176                )
177            })?;
178            if !compile_out.status.success() {
179                perf_record("cpp", "file.compile_fail");
180                return Ok(ExecutionOutcome {
181                    language: self.id().to_string(),
182                    exit_code: compile_out.status.code(),
183                    stdout: String::from_utf8_lossy(&compile_out.stdout).into_owned(),
184                    stderr: String::from_utf8_lossy(&compile_out.stderr).into_owned(),
185                    duration: start.elapsed(),
186                });
187            }
188
189            let mut link = compiler_command(compiler);
190            perf_record("cpp", "file.link");
191            link.arg(&obj)
192                .arg("-o")
193                .arg(&bin)
194                .stdout(Stdio::piped())
195                .stderr(Stdio::piped());
196            let link_out = link.output().with_context(|| {
197                format!(
198                    "failed to invoke {} for incremental C++ link",
199                    compiler.display()
200                )
201            })?;
202            if !link_out.status.success() {
203                perf_record("cpp", "file.link_fail");
204                return Ok(ExecutionOutcome {
205                    language: self.id().to_string(),
206                    exit_code: link_out.status.code(),
207                    stdout: String::from_utf8_lossy(&link_out.stdout).into_owned(),
208                    stderr: String::from_utf8_lossy(&link_out.stderr).into_owned(),
209                    duration: start.elapsed(),
210                });
211            }
212            cache_store("cpp-file", source_hash, &bin);
213        } else {
214            // Rehydrate persistent cache even when incremental workspace is already up-to-date.
215            perf_record("cpp", "file.rehydrate_cache");
216            cache_store("cpp-file", source_hash, &bin);
217        }
218
219        let run_output = self.run_binary(&bin, args)?;
220        Ok(ExecutionOutcome {
221            language: self.id().to_string(),
222            exit_code: run_output.status.code(),
223            stdout: String::from_utf8_lossy(&run_output.stdout).into_owned(),
224            stderr: String::from_utf8_lossy(&run_output.stderr).into_owned(),
225            duration: start.elapsed(),
226        })
227    }
228}
229
230impl LanguageEngine for CppEngine {
231    fn id(&self) -> &'static str {
232        "cpp"
233    }
234
235    fn display_name(&self) -> &'static str {
236        "C++"
237    }
238
239    fn aliases(&self) -> &[&'static str] {
240        &["c++"]
241    }
242
243    fn supports_sessions(&self) -> bool {
244        self.compiler.is_some()
245    }
246
247    fn validate(&self) -> Result<()> {
248        let compiler = self.ensure_compiler()?;
249        let mut cmd = Command::new(compiler);
250        cmd.arg("--version")
251            .stdout(Stdio::null())
252            .stderr(Stdio::null());
253        cmd.status()
254            .with_context(|| format!("failed to invoke {}", compiler.display()))?
255            .success()
256            .then_some(())
257            .ok_or_else(|| anyhow::anyhow!("{} is not executable", compiler.display()))
258    }
259
260    fn toolchain_version(&self) -> Result<Option<String>> {
261        let compiler = self.ensure_compiler()?;
262        let mut cmd = Command::new(compiler);
263        cmd.arg("--version");
264        let context = format!("{}", compiler.display());
265        run_version_command(cmd, &context)
266    }
267
268    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
269        let args = payload.args();
270        if let ExecutionPayload::File { path, .. } = payload {
271            return self.execute_file_incremental(path, args);
272        }
273
274        // Try cache for inline/stdin payloads
275        if let Some(code) = match payload {
276            ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
277                Some(code.as_str())
278            }
279            _ => None,
280        } {
281            let src_hash = hash_source(code);
282            if let Some(output) = try_cached_execution("cpp", src_hash) {
283                perf_record("cpp", "inline.cache_hit");
284                let start = Instant::now();
285                return Ok(ExecutionOutcome {
286                    language: self.id().to_string(),
287                    exit_code: output.status.code(),
288                    stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
289                    stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
290                    duration: start.elapsed(),
291                });
292            }
293            perf_record("cpp", "inline.cache_miss");
294        }
295
296        let temp_dir = Builder::new()
297            .prefix("run-cpp")
298            .tempdir()
299            .context("failed to create temporary directory for cpp build")?;
300        let dir_path = temp_dir.path();
301
302        let (source_path, cache_key) = match payload {
303            ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
304                let h = hash_source(code);
305                (self.write_source(code, dir_path)?, Some(h))
306            }
307            ExecutionPayload::File { path, .. } => (self.copy_source(path, dir_path)?, None),
308        };
309
310        let binary_path = Self::binary_path(dir_path);
311        let start = Instant::now();
312
313        let compile_output = self.compile(&source_path, &binary_path)?;
314        if !compile_output.status.success() {
315            return Ok(ExecutionOutcome {
316                language: self.id().to_string(),
317                exit_code: compile_output.status.code(),
318                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
319                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
320                duration: start.elapsed(),
321            });
322        }
323
324        if let Some(h) = cache_key {
325            cache_store("cpp", h, &binary_path);
326        }
327
328        let run_output = self.run_binary(&binary_path, args)?;
329        Ok(ExecutionOutcome {
330            language: self.id().to_string(),
331            exit_code: run_output.status.code(),
332            stdout: String::from_utf8_lossy(&run_output.stdout).into_owned(),
333            stderr: String::from_utf8_lossy(&run_output.stderr).into_owned(),
334            duration: start.elapsed(),
335        })
336    }
337
338    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
339        let compiler = self.ensure_compiler().map(Path::to_path_buf)?;
340
341        let temp_dir = Builder::new()
342            .prefix("run-cpp-repl")
343            .tempdir()
344            .context("failed to create temporary directory for cpp repl")?;
345        let dir_path = temp_dir.path();
346        let source_path = dir_path.join("main.cpp");
347        let binary_path = Self::binary_path(dir_path);
348
349        Ok(Box::new(CppSession {
350            compiler,
351            _temp_dir: temp_dir,
352            source_path,
353            binary_path,
354            definitions: Vec::new(),
355            statements: Vec::new(),
356            previous_stdout: String::new(),
357            previous_stderr: String::new(),
358        }))
359    }
360}
361
362fn resolve_cpp_compiler() -> Option<PathBuf> {
363    ["c++", "clang++", "g++"]
364        .into_iter()
365        .find_map(|candidate| which::which(candidate).ok())
366}
367
/// Text placed at the top of every generated session program: a fixed set of
/// standard-library includes followed by `using namespace std;`.
///
/// Each trailing `\` joins the next source line without adding characters,
/// so only the explicit `\n` escapes end up in the string.
const SESSION_PREAMBLE: &str = "\
#include <iostream>\n\
#include <iomanip>\n\
#include <string>\n\
#include <vector>\n\
#include <map>\n\
#include <set>\n\
#include <unordered_map>\n\
#include <unordered_set>\n\
#include <deque>\n\
#include <list>\n\
#include <queue>\n\
#include <stack>\n\
#include <memory>\n\
#include <functional>\n\
#include <algorithm>\n\
#include <numeric>\n\
#include <cmath>\n\
\n\
using namespace std;\n\
\n";
388
/// State of one interactive C++ session.
///
/// Every evaluation regenerates a complete program from the snippets
/// accumulated here, compiles it and runs it; only the output that is new
/// compared with the previous successful run is reported.
struct CppSession {
    /// Compiler executable used for every build in this session.
    compiler: PathBuf,
    /// Owns the temporary directory that `source_path` and `binary_path`
    /// point into; never read directly.
    _temp_dir: TempDir,
    /// Location of the generated `main.cpp`.
    source_path: PathBuf,
    /// Location of the executable built from the generated source.
    binary_path: PathBuf,
    /// Snippets emitted before `main`, in the order they were accepted.
    definitions: Vec<String>,
    /// Snippets emitted inside `main`, in the order they were accepted.
    statements: Vec<String>,
    /// Full stdout of the last successful run, used as the diff baseline.
    previous_stdout: String,
    /// Full stderr of the last successful run, used as the diff baseline.
    previous_stderr: String,
}
399
400impl CppSession {
401    fn render_prelude(&self) -> String {
402        let mut source = String::from(SESSION_PREAMBLE);
403        for def in &self.definitions {
404            source.push_str(def);
405            if !def.ends_with('\n') {
406                source.push('\n');
407            }
408            source.push('\n');
409        }
410        source
411    }
412
413    fn render_source(&self) -> String {
414        let mut source = self.render_prelude();
415        source.push_str("int main()\n{\n    ios::sync_with_stdio(false);\n    cin.tie(nullptr);\n    cout.setf(std::ios::boolalpha);\n");
416        for stmt in &self.statements {
417            for line in stmt.lines() {
418                source.push_str("    ");
419                source.push_str(line);
420                source.push('\n');
421            }
422            if !stmt.ends_with('\n') {
423                source.push('\n');
424            }
425        }
426        source.push_str("    return 0;\n}\n");
427        source
428    }
429
430    fn write_source(&self, contents: &str) -> Result<()> {
431        fs::write(&self.source_path, contents).with_context(|| {
432            format!(
433                "failed to write generated C++ REPL source to {}",
434                self.source_path.display()
435            )
436        })
437    }
438
439    fn compile_and_run(&mut self) -> Result<(std::process::Output, Duration)> {
440        let start = Instant::now();
441        let source = self.render_source();
442        self.write_source(&source)?;
443        let compile_output =
444            invoke_cpp_compiler(&self.compiler, &self.source_path, &self.binary_path)?;
445        if !compile_output.status.success() {
446            let duration = start.elapsed();
447            return Ok((compile_output, duration));
448        }
449        let execution_output = run_cpp_binary(&self.binary_path)?;
450        let duration = start.elapsed();
451        Ok((execution_output, duration))
452    }
453
454    fn run_standalone_program(&mut self, code: &str) -> Result<ExecutionOutcome> {
455        let start = Instant::now();
456        let mut source = self.render_prelude();
457        if !source.ends_with('\n') {
458            source.push('\n');
459        }
460        source.push_str(code);
461        if !code.ends_with('\n') {
462            source.push('\n');
463        }
464
465        let standalone_path = self
466            .source_path
467            .parent()
468            .unwrap_or_else(|| Path::new("."))
469            .join("standalone.cpp");
470        fs::write(&standalone_path, &source)
471            .with_context(|| "failed to write standalone C++ source".to_string())?;
472
473        let compile_output =
474            invoke_cpp_compiler(&self.compiler, &standalone_path, &self.binary_path)?;
475        if !compile_output.status.success() {
476            return Ok(ExecutionOutcome {
477                language: "cpp".to_string(),
478                exit_code: compile_output.status.code(),
479                stdout: normalize_output(&compile_output.stdout),
480                stderr: normalize_output(&compile_output.stderr),
481                duration: start.elapsed(),
482            });
483        }
484
485        let run_output = run_cpp_binary(&self.binary_path)?;
486        Ok(ExecutionOutcome {
487            language: "cpp".to_string(),
488            exit_code: run_output.status.code(),
489            stdout: normalize_output(&run_output.stdout),
490            stderr: normalize_output(&run_output.stderr),
491            duration: start.elapsed(),
492        })
493    }
494
495    fn reset_state(&mut self) -> Result<()> {
496        self.definitions.clear();
497        self.statements.clear();
498        self.previous_stdout.clear();
499        self.previous_stderr.clear();
500        let source = self.render_source();
501        self.write_source(&source)
502    }
503
504    fn diff_outputs(
505        &mut self,
506        output: &std::process::Output,
507        duration: Duration,
508    ) -> ExecutionOutcome {
509        let stdout_full = normalize_output(&output.stdout);
510        let stderr_full = normalize_output(&output.stderr);
511
512        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
513        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
514
515        if output.status.success() {
516            self.previous_stdout = stdout_full;
517            self.previous_stderr = stderr_full;
518        }
519
520        ExecutionOutcome {
521            language: "cpp".to_string(),
522            exit_code: output.status.code(),
523            stdout: stdout_delta,
524            stderr: stderr_delta,
525            duration,
526        }
527    }
528
529    fn add_definition(&mut self, snippet: String) {
530        self.definitions.push(snippet);
531    }
532
533    fn add_statement(&mut self, snippet: String) {
534        self.statements.push(snippet);
535    }
536
537    fn remove_last_definition(&mut self) {
538        let _ = self.definitions.pop();
539    }
540
541    fn remove_last_statement(&mut self) {
542        let _ = self.statements.pop();
543    }
544}
545
546impl LanguageSession for CppSession {
547    fn language_id(&self) -> &str {
548        "cpp"
549    }
550
551    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
552        let trimmed = code.trim();
553        if trimmed.is_empty() {
554            return Ok(ExecutionOutcome {
555                language: self.language_id().to_string(),
556                exit_code: None,
557                stdout: String::new(),
558                stderr: String::new(),
559                duration: Instant::now().elapsed(),
560            });
561        }
562
563        if trimmed.eq_ignore_ascii_case(":reset") {
564            self.reset_state()?;
565            return Ok(ExecutionOutcome {
566                language: self.language_id().to_string(),
567                exit_code: None,
568                stdout: String::new(),
569                stderr: String::new(),
570                duration: Duration::default(),
571            });
572        }
573
574        if trimmed.eq_ignore_ascii_case(":help") {
575            return Ok(ExecutionOutcome {
576                language: self.language_id().to_string(),
577                exit_code: None,
578                stdout:
579                    "C++ commands:\n  :reset - clear session state\n  :help  - show this message\n"
580                        .to_string(),
581                stderr: String::new(),
582                duration: Duration::default(),
583            });
584        }
585
586        if contains_main_definition(code) {
587            return self.run_standalone_program(code);
588        }
589
590        let classification = classify_snippet(trimmed);
591        match classification {
592            SnippetKind::Definition => {
593                self.add_definition(code.to_string());
594                let (output, duration) = self.compile_and_run()?;
595                if !output.status.success() {
596                    self.remove_last_definition();
597                }
598                Ok(self.diff_outputs(&output, duration))
599            }
600            SnippetKind::Expression => {
601                let wrapped = wrap_cpp_expression(trimmed);
602                self.add_statement(wrapped);
603                let (output, duration) = self.compile_and_run()?;
604                if !output.status.success() {
605                    self.remove_last_statement();
606                    return Ok(self.diff_outputs(&output, duration));
607                }
608                Ok(self.diff_outputs(&output, duration))
609            }
610            SnippetKind::Statement => {
611                let stmt = ensure_trailing_newline(code);
612                self.add_statement(stmt);
613                let (output, duration) = self.compile_and_run()?;
614                if !output.status.success() {
615                    self.remove_last_statement();
616                }
617                Ok(self.diff_outputs(&output, duration))
618            }
619        }
620    }
621
622    fn shutdown(&mut self) -> Result<()> {
623        Ok(())
624    }
625}
626
/// How a piece of REPL input is incorporated into the generated program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SnippetKind {
    /// Emitted before `main` (includes, types, functions, ...).
    Definition,
    /// Emitted inside `main` as written.
    Statement,
    /// Wrapped by `wrap_cpp_expression` before being emitted inside `main`.
    Expression,
}
633
634fn classify_snippet(code: &str) -> SnippetKind {
635    let trimmed = code.trim();
636    if trimmed.starts_with("#include")
637        || trimmed.starts_with("using ")
638        || trimmed.starts_with("namespace ")
639        || trimmed.starts_with("class ")
640        || trimmed.starts_with("struct ")
641        || trimmed.starts_with("enum ")
642        || trimmed.starts_with("template ")
643        || trimmed.ends_with("};")
644    {
645        return SnippetKind::Definition;
646    }
647
648    if trimmed.contains('{') && trimmed.contains('}') && trimmed.contains('(') {
649        const CONTROL_KEYWORDS: [&str; 8] =
650            ["if", "for", "while", "switch", "do", "else", "try", "catch"];
651        let first = trimmed.split_whitespace().next().unwrap_or("");
652        if !CONTROL_KEYWORDS.iter().any(|kw| {
653            first == *kw
654                || trimmed.starts_with(&format!("{} ", kw))
655                || trimmed.starts_with(&format!("{}(", kw))
656        }) {
657            return SnippetKind::Definition;
658        }
659    }
660
661    if is_cpp_expression(trimmed) {
662        return SnippetKind::Expression;
663    }
664
665    SnippetKind::Statement
666}
667
/// Heuristically reports whether `code` reads as a bare expression.
///
/// Rejections are checked first: multi-line input, a trailing `;`, a
/// statement-like prefix, or a `std::` call. What remains is accepted when it
/// is a quoted string, a number, `true`/`false`, contains an operator, or
/// consists only of ASCII alphanumerics, `_` and `.`.
fn is_cpp_expression(code: &str) -> bool {
    const STATEMENT_PREFIXES: [&str; 9] = [
        "return ", "if ", "for ", "while ", "switch ", "do ", "auto ", "cout", "cin",
    ];
    const COMPARISONS: [&str; 4] = ["==", "!=", "<=", ">="];
    const OPERATOR_CHARS: &str = "+-*/%<>^|&";

    let rejected = code.contains('\n')
        || code.ends_with(';')
        || STATEMENT_PREFIXES
            .iter()
            .any(|prefix| code.starts_with(*prefix))
        || (code.starts_with("std::") && code.contains('('));
    if rejected {
        return false;
    }

    let quoted = code.starts_with('"') && code.ends_with('"');
    let literal = quoted || code.parse::<f64>().is_ok() || matches!(code, "true" | "false");

    let has_operator = COMPARISONS.iter().any(|op| code.contains(*op))
        || code.contains(|c: char| OPERATOR_CHARS.contains(c));

    let identifier_like = code
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '.'));

    literal || has_operator || identifier_like
}
716
717fn contains_main_definition(code: &str) -> bool {
718    let bytes = code.as_bytes();
719    let len = bytes.len();
720    let mut i = 0;
721    let mut in_line_comment = false;
722    let mut in_block_comment = false;
723    let mut in_string = false;
724    let mut string_delim = b'"';
725    let mut in_char = false;
726
727    while i < len {
728        let b = bytes[i];
729
730        if in_line_comment {
731            if b == b'\n' {
732                in_line_comment = false;
733            }
734            i += 1;
735            continue;
736        }
737
738        if in_block_comment {
739            if b == b'*' && i + 1 < len && bytes[i + 1] == b'/' {
740                in_block_comment = false;
741                i += 2;
742                continue;
743            }
744            i += 1;
745            continue;
746        }
747
748        if in_string {
749            if b == b'\\' {
750                i = (i + 2).min(len);
751                continue;
752            }
753            if b == string_delim {
754                in_string = false;
755            }
756            i += 1;
757            continue;
758        }
759
760        if in_char {
761            if b == b'\\' {
762                i = (i + 2).min(len);
763                continue;
764            }
765            if b == b'\'' {
766                in_char = false;
767            }
768            i += 1;
769            continue;
770        }
771
772        match b {
773            b'/' if i + 1 < len && bytes[i + 1] == b'/' => {
774                in_line_comment = true;
775                i += 2;
776                continue;
777            }
778            b'/' if i + 1 < len && bytes[i + 1] == b'*' => {
779                in_block_comment = true;
780                i += 2;
781                continue;
782            }
783            b'"' | b'\'' => {
784                if b == b'"' {
785                    in_string = true;
786                    string_delim = b;
787                } else {
788                    in_char = true;
789                }
790                i += 1;
791                continue;
792            }
793            b'm' if i + 4 <= len && &bytes[i..i + 4] == b"main" => {
794                if i > 0 {
795                    let prev = bytes[i - 1];
796                    if prev.is_ascii_alphanumeric() || prev == b'_' {
797                        i += 1;
798                        continue;
799                    }
800                }
801
802                let after_name = i + 4;
803                if after_name < len {
804                    let next = bytes[after_name];
805                    if next.is_ascii_alphanumeric() || next == b'_' {
806                        i += 1;
807                        continue;
808                    }
809                }
810
811                let mut j = after_name;
812                while j < len && bytes[j].is_ascii_whitespace() {
813                    j += 1;
814                }
815                if j >= len || bytes[j] != b'(' {
816                    i += 1;
817                    continue;
818                }
819
820                let mut depth = 1usize;
821                let mut k = j + 1;
822                let mut inner_line_comment = false;
823                let mut inner_block_comment = false;
824                let mut inner_string = false;
825                let mut inner_char = false;
826
827                while k < len {
828                    let ch = bytes[k];
829
830                    if inner_line_comment {
831                        if ch == b'\n' {
832                            inner_line_comment = false;
833                        }
834                        k += 1;
835                        continue;
836                    }
837
838                    if inner_block_comment {
839                        if ch == b'*' && k + 1 < len && bytes[k + 1] == b'/' {
840                            inner_block_comment = false;
841                            k += 2;
842                            continue;
843                        }
844                        k += 1;
845                        continue;
846                    }
847
848                    if inner_string {
849                        if ch == b'\\' {
850                            k = (k + 2).min(len);
851                            continue;
852                        }
853                        if ch == b'"' {
854                            inner_string = false;
855                        }
856                        k += 1;
857                        continue;
858                    }
859
860                    if inner_char {
861                        if ch == b'\\' {
862                            k = (k + 2).min(len);
863                            continue;
864                        }
865                        if ch == b'\'' {
866                            inner_char = false;
867                        }
868                        k += 1;
869                        continue;
870                    }
871
872                    match ch {
873                        b'/' if k + 1 < len && bytes[k + 1] == b'/' => {
874                            inner_line_comment = true;
875                            k += 2;
876                            continue;
877                        }
878                        b'/' if k + 1 < len && bytes[k + 1] == b'*' => {
879                            inner_block_comment = true;
880                            k += 2;
881                            continue;
882                        }
883                        b'"' => {
884                            inner_string = true;
885                            k += 1;
886                            continue;
887                        }
888                        b'\'' => {
889                            inner_char = true;
890                            k += 1;
891                            continue;
892                        }
893                        b'(' => {
894                            depth += 1;
895                        }
896                        b')' => {
897                            depth -= 1;
898                            k += 1;
899                            if depth == 0 {
900                                break;
901                            } else {
902                                continue;
903                            }
904                        }
905                        _ => {}
906                    }
907
908                    k += 1;
909                }
910
911                if depth != 0 {
912                    i += 1;
913                    continue;
914                }
915
916                let mut after = k;
917                loop {
918                    while after < len && bytes[after].is_ascii_whitespace() {
919                        after += 1;
920                    }
921                    if after + 1 < len && bytes[after] == b'/' && bytes[after + 1] == b'/' {
922                        after += 2;
923                        while after < len && bytes[after] != b'\n' {
924                            after += 1;
925                        }
926                        continue;
927                    }
928                    if after + 1 < len && bytes[after] == b'/' && bytes[after + 1] == b'*' {
929                        after += 2;
930                        while after + 1 < len {
931                            if bytes[after] == b'*' && bytes[after + 1] == b'/' {
932                                after += 2;
933                                break;
934                            }
935                            after += 1;
936                        }
937                        continue;
938                    }
939                    break;
940                }
941
942                while after < len {
943                    match bytes[after] {
944                        b'{' => return true,
945                        b';' => break,
946                        b'/' if after + 1 < len && bytes[after + 1] == b'/' => {
947                            after += 2;
948                            while after < len && bytes[after] != b'\n' {
949                                after += 1;
950                            }
951                        }
952                        b'/' if after + 1 < len && bytes[after + 1] == b'*' => {
953                            after += 2;
954                            while after + 1 < len {
955                                if bytes[after] == b'*' && bytes[after + 1] == b'/' {
956                                    after += 2;
957                                    break;
958                                }
959                                after += 1;
960                            }
961                        }
962                        b'"' => {
963                            after += 1;
964                            while after < len {
965                                if bytes[after] == b'"' {
966                                    after += 1;
967                                    break;
968                                }
969                                if bytes[after] == b'\\' {
970                                    after = (after + 2).min(len);
971                                } else {
972                                    after += 1;
973                                }
974                            }
975                        }
976                        b'\'' => {
977                            after += 1;
978                            while after < len {
979                                if bytes[after] == b'\'' {
980                                    after += 1;
981                                    break;
982                                }
983                                if bytes[after] == b'\\' {
984                                    after = (after + 2).min(len);
985                                } else {
986                                    after += 1;
987                                }
988                            }
989                        }
990                        b'-' if after + 1 < len && bytes[after + 1] == b'>' => {
991                            after += 2;
992                        }
993                        b'(' => {
994                            let mut depth = 1usize;
995                            after += 1;
996                            while after < len && depth > 0 {
997                                match bytes[after] {
998                                    b'(' => depth += 1,
999                                    b')' => depth -= 1,
1000                                    b'"' => {
1001                                        after += 1;
1002                                        while after < len {
1003                                            if bytes[after] == b'"' {
1004                                                after += 1;
1005                                                break;
1006                                            }
1007                                            if bytes[after] == b'\\' {
1008                                                after = (after + 2).min(len);
1009                                            } else {
1010                                                after += 1;
1011                                            }
1012                                        }
1013                                        continue;
1014                                    }
1015                                    b'\'' => {
1016                                        after += 1;
1017                                        while after < len {
1018                                            if bytes[after] == b'\'' {
1019                                                after += 1;
1020                                                break;
1021                                            }
1022                                            if bytes[after] == b'\\' {
1023                                                after = (after + 2).min(len);
1024                                            } else {
1025                                                after += 1;
1026                                            }
1027                                        }
1028                                        continue;
1029                                    }
1030                                    _ => {}
1031                                }
1032                                after += 1;
1033                            }
1034                        }
1035                        _ => {
1036                            after += 1;
1037                        }
1038                    }
1039                }
1040            }
1041            _ => {}
1042        }
1043
1044        i += 1;
1045    }
1046
1047    false
1048}
1049
/// Turns a bare C++ expression into a statement that streams its value to
/// `std::cout`, followed by `std::endl`.
///
/// The expression is parenthesised so that operators inside `code` cannot
/// bind to the surrounding `<<`. The returned snippet ends with a newline.
fn wrap_cpp_expression(code: &str) -> String {
    ["std::cout << (", code, ") << std::endl;\n"].concat()
}
1053
/// Returns an owned copy of `code` that is guaranteed to end with `'\n'`.
///
/// Input that already ends with a newline is copied unchanged; anything else
/// (including the empty string) gets exactly one newline appended.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
1061
/// Returns the part of `current` that follows `previous`.
///
/// When `current` starts with `previous`, only the remainder after that prefix
/// is returned; otherwise the whole of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
1069
/// Decodes captured process output as text and removes carriage returns.
///
/// Bytes are decoded as UTF-8 lossily (invalid sequences become U+FFFD).
/// Every `'\r'` is then removed: a `"\r\n"` pair collapses to `"\n"` and a
/// lone `'\r'` disappears entirely.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}
1075
1076fn invoke_cpp_compiler(
1077    compiler: &Path,
1078    source: &Path,
1079    output: &Path,
1080) -> Result<std::process::Output> {
1081    let mut cmd = compiler_command(compiler);
1082    cmd.arg(source)
1083        .arg("-std=c++17")
1084        .arg("-O0")
1085        .arg("-w")
1086        .arg("-o")
1087        .arg(output)
1088        .stdout(Stdio::piped())
1089        .stderr(Stdio::piped());
1090    if let Some(pch_header) = ensure_global_cpp_pch(compiler) {
1091        cmd.arg("-include").arg(pch_header);
1092    }
1093    cmd.output().with_context(|| {
1094        format!(
1095            "failed to invoke {} to compile {}",
1096            compiler.display(),
1097            source.display()
1098        )
1099    })
1100}
1101
1102fn run_cpp_binary(binary: &Path) -> Result<std::process::Output> {
1103    let mut cmd = Command::new(binary);
1104    cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
1105    cmd.output()
1106        .with_context(|| format!("failed to execute compiled binary {}", binary.display()))
1107}
1108
1109fn ensure_global_cpp_pch(compiler: &Path) -> Option<PathBuf> {
1110    static PCH_BUILD_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
1111    let lock = PCH_BUILD_LOCK.get_or_init(|| Mutex::new(()));
1112    let _guard = lock.lock().ok()?;
1113
1114    let cache_dir = std::env::temp_dir()
1115        .join("run-compile-cache")
1116        .join(cpp_pch_cache_key(compiler));
1117    std::fs::create_dir_all(&cache_dir).ok()?;
1118    let header = cache_dir.join("run_cpp_pch.hpp");
1119    let gch = cache_dir.join("run_cpp_pch.hpp.gch");
1120    if !header.exists() {
1121        let contents = concat!(
1122            "#include <iostream>\n",
1123            "#include <vector>\n",
1124            "#include <string>\n",
1125            "#include <map>\n",
1126            "#include <unordered_map>\n",
1127            "#include <algorithm>\n",
1128            "#include <utility>\n"
1129        );
1130        let tmp_header = cache_dir.join(format!("run_cpp_pch.hpp.tmp.{}", std::process::id()));
1131        std::fs::write(&tmp_header, contents).ok()?;
1132        std::fs::rename(&tmp_header, &header).ok()?;
1133    }
1134    let needs_build = if !gch.exists() {
1135        true
1136    } else {
1137        let h = header.metadata().ok()?.modified().ok()?;
1138        let g = gch.metadata().ok()?.modified().ok()?;
1139        h > g
1140    };
1141    if needs_build {
1142        let tmp_gch = cache_dir.join(format!("run_cpp_pch.hpp.gch.tmp.{}", std::process::id()));
1143        let mut pch_cmd = compiler_command(compiler);
1144        let out = pch_cmd
1145            .arg("-std=c++17")
1146            .arg("-x")
1147            .arg("c++-header")
1148            .arg(&header)
1149            .arg("-o")
1150            .arg(&tmp_gch)
1151            .stdout(Stdio::piped())
1152            .stderr(Stdio::piped())
1153            .output()
1154            .ok()?;
1155        if !out.status.success() {
1156            let _ = std::fs::remove_file(&tmp_gch);
1157            return None;
1158        }
1159        std::fs::rename(&tmp_gch, &gch).ok()?;
1160    }
1161    Some(header)
1162}
1163
1164fn cpp_pch_cache_key(compiler: &Path) -> String {
1165    let mut hasher = blake3::Hasher::new();
1166    hasher.update(b"run-kit/cpp-pch/v2");
1167    hasher.update(compiler.to_string_lossy().as_bytes());
1168
1169    if let Ok(output) = Command::new(compiler).arg("--version").output() {
1170        hasher.update(&output.stdout);
1171        hasher.update(&output.stderr);
1172    }
1173
1174    hasher.finalize().to_hex()[..16].to_string()
1175}
1176
/// Splits the text of a Make-style dependency file into whitespace-separated
/// tokens.
///
/// Two escapes are recognised: a backslash directly before a line break
/// (`\n` or `\r\n`) is a line continuation and acts as a separator, and a
/// backslash directly before a space is an escaped space that stays inside the
/// current token. Any other backslash is kept literally, so Windows-style path
/// separators pass through untouched.
fn split_depfile_tokens(text: &str) -> Vec<String> {
    fn flush(current: &mut String, tokens: &mut Vec<String>) {
        if !current.is_empty() {
            tokens.push(std::mem::take(current));
        }
    }

    let mut tokens = Vec::new();
    let mut current = String::new();
    let mut rest = text;

    while let Some(ch) = rest.chars().next() {
        let continuation = rest
            .strip_prefix("\\\r\n")
            .or_else(|| rest.strip_prefix("\\\n"));
        if let Some(tail) = continuation {
            flush(&mut current, &mut tokens);
            rest = tail;
        } else if let Some(tail) = rest.strip_prefix("\\ ") {
            current.push(' ');
            rest = tail;
        } else {
            if ch.is_whitespace() {
                flush(&mut current, &mut tokens);
            } else {
                current.push(ch);
            }
            rest = &rest[ch.len_utf8()..];
        }
    }
    flush(&mut current, &mut tokens);
    tokens
}

/// Decides whether `object` is stale with respect to `source` and the
/// dependencies recorded in `depfile`.
///
/// Returns `true` when any of the following holds:
/// - the modification time of `object` or `source` cannot be read (this
///   includes a missing file);
/// - `source` was modified after `object`;
/// - `depfile` cannot be read (this includes a missing file);
/// - any dependency listed in `depfile` was modified after `object`.
///
/// The first token of the depfile (the rule target) is skipped. Trailing `:`
/// is trimmed from the remaining tokens. Dependencies whose metadata cannot be
/// read are ignored rather than treated as stale.
///
/// Dependency paths containing Make-escaped spaces (`dir/my\ header.h`) and
/// continuation lines ending in `\` + CRLF are parsed correctly, so such
/// dependencies take part in the staleness check.
fn cpp_needs_recompile(source: &Path, object: &Path, depfile: &Path) -> bool {
    let modified = |path: &Path| path.metadata().and_then(|meta| meta.modified());

    let Ok(obj_time) = modified(object) else {
        return true;
    };
    let Ok(src_time) = modified(source) else {
        return true;
    };
    if src_time > obj_time {
        return true;
    }
    let Ok(dep_text) = fs::read_to_string(depfile) else {
        return true;
    };

    split_depfile_tokens(&dep_text)
        .iter()
        .skip(1)
        .map(|token| token.trim_end_matches(':'))
        .filter(|path| !path.is_empty())
        .any(|path| modified(Path::new(path)).is_ok_and(|dep_time| dep_time > obj_time))
}
1213
#[cfg(test)]
mod tests {
    use super::cpp_pch_cache_key;
    use std::path::Path;

    /// Hashing the same compiler path twice must yield the same key.
    #[test]
    fn pch_cache_key_is_stable_for_same_compiler_path() {
        let first = cpp_pch_cache_key(Path::new("/missing/c++"));
        let second = cpp_pch_cache_key(Path::new("/missing/c++"));
        assert_eq!(first, second);
    }

    /// Two different compiler paths must not share a key.
    #[test]
    fn pch_cache_key_changes_with_compiler_path() {
        let missing_key = cpp_pch_cache_key(Path::new("/missing/c++"));
        let other_key = cpp_pch_cache_key(Path::new("/other/c++"));
        assert_ne!(missing_key, other_key);
    }
}