Skip to main content

code_executor/
judge.rs

1use std::{env, io, marker::PhantomData, path::PathBuf, process::Stdio, time::Duration};
2
3use bon::bon;
4use byte_unit::Byte;
5use futures_lite::{Stream, StreamExt};
6use state_shift::{impl_state, type_state};
7use tokio::{
8    fs,
9    io::{AsyncReadExt, AsyncWriteExt},
10};
11use uuid::Uuid;
12
13use crate::{AggregatedMetrics, Language, Metrics, Resource, Sandbox, Verdict};
14
/// File stem of the submitted program inside the workspace (an extension is appended when written).
const MAIN: &str = "main";
/// File stem of the checker inside the workspace (extension added only for interpreted languages).
const CHECKER: &str = "checker";
/// Chunk size, in bytes, used when relaying the program's stdout to the checker (8 KiB).
const BUFFER_SIZE: usize = 8192;
18
19pub struct Code<'a> {
20    pub language: Language,
21    pub content: &'a [u8],
22}
23
24#[type_state(
25    states = (Created, Compiled),
26    slots = (Created)
27)]
28#[derive(Default)]
29pub struct Judge {
30    pub project_path: PathBuf,
31    pub language: Language,
32    pub checker_language: Option<Language>,
33    pub is_interactive: bool,
34    pub resource: Resource,
35    pub time_limit: Duration,
36    pub idle_time_limit: Duration,
37}
38
39#[bon]
40impl Judge<Created> {
41    #[builder]
42    pub async fn new<'a>(
43        main: Code<'a>,
44        checker: Option<Code<'a>>,
45        #[builder(default = false, name = "interactive")] is_interactive: bool,
46        #[builder(default)] resource: Resource,
47        #[builder(default)] time_limit: Duration,
48        #[builder(default = Duration::from_secs(1))] idle_time_limit: Duration,
49    ) -> io::Result<Judge<Created>> {
50        let project_path = env::temp_dir().join(Uuid::new_v4().to_string());
51        fs::create_dir(&project_path).await?;
52
53        let main_path = project_path
54            .join(MAIN)
55            .with_extension(main.language.extension);
56        fs::write(&main_path, main.content).await?;
57        if let Some(checker) = &checker {
58            let mut checker_path = project_path.join(CHECKER);
59            if checker.language.is_interpreted() {
60                checker_path.set_extension(checker.language.extension);
61            }
62            let mut checker_file = fs::OpenOptions::new()
63                .create(true)
64                .write(true)
65                .truncate(true)
66                .mode(0o755)
67                .open(&checker_path)
68                .await?;
69            checker_file.write_all(checker.content).await?;
70            checker_file.sync_all().await?;
71        }
72
73        Ok(Judge {
74            project_path,
75            language: main.language,
76            checker_language: checker.map(|checker| checker.language),
77            is_interactive,
78            resource,
79            time_limit,
80            idle_time_limit,
81            _state: PhantomData,
82        })
83    }
84}
85
86#[impl_state]
87impl Judge {
88    #[require(Created)]
89    #[switch_to(Compiled)]
90    pub async fn compile(self) -> io::Result<Result<Judge<Compiled>, Verdict>> {
91        if let Some(mut cmd) = self.language.get_compile_command(MAIN) {
92            let mut process = cmd.current_dir(&self.project_path).spawn()?;
93            let status = process.wait().await?;
94            if !status.success() {
95                return Ok(Err(Verdict::CompilationError));
96            }
97        }
98
99        Ok(Ok(Judge {
100            project_path: self.project_path,
101            language: self.language,
102            checker_language: self.checker_language,
103            is_interactive: self.is_interactive,
104            resource: self.resource,
105            time_limit: self.time_limit,
106            idle_time_limit: self.idle_time_limit,
107        }))
108    }
109
110    #[require(Compiled)]
111    pub async fn read_executable(&self) -> io::Result<Vec<u8>> {
112        let mut path = self.project_path.join(MAIN);
113        if self.language.is_interpreted() {
114            path.set_extension(self.language.extension);
115        }
116
117        fs::read(path).await
118    }
119
120    #[require(Compiled)]
121    pub async fn run(&self, input: &[u8]) -> io::Result<Metrics> {
122        let checker_language = self
123            .checker_language
124            .ok_or(io::Error::other("Missing checker"))?;
125        let mut checker = checker_language
126            .get_run_command(CHECKER)
127            .current_dir(&self.project_path)
128            .stdin(Stdio::piped())
129            .stdout(Stdio::piped())
130            .stderr(Stdio::null())
131            .spawn()?;
132        let mut cstdin = checker.stdin.take().unwrap();
133        let mut cstdout = checker.stdout.take().unwrap();
134        cstdin.write_all(input).await?;
135        cstdin.write_all(b"\n").await?;
136        cstdin.flush().await?;
137
138        let sandbox = Sandbox::new(self.resource, self.time_limit, self.idle_time_limit)?;
139        let mut cmd = self.language.get_run_command(MAIN);
140        cmd.current_dir(&self.project_path)
141            .stdin(Stdio::piped())
142            .stdout(Stdio::piped())
143            .stderr(Stdio::piped());
144        let mut main = sandbox.spawn(cmd)?;
145        let mut stdin = main.stdin.take().unwrap();
146        let mut stdout = main.stdout.take().unwrap();
147        let mut stderr = main.stderr.take().unwrap();
148
149        let monitor = tokio::spawn(async move { sandbox.monitor(main).await });
150        if !self.is_interactive {
151            stdin.write_all(input).await?;
152            stdin.write_all(b"\n").await?;
153            stdin.flush().await?;
154        }
155        let stdin_thread =
156            tokio::spawn(async move { tokio::io::copy(&mut cstdout, &mut stdin).await });
157        let stdout_thread = tokio::spawn(async move {
158            let mut out = vec![];
159            let mut buffer = [0u8; BUFFER_SIZE];
160            loop {
161                let n = stdout.read(&mut buffer).await?;
162                if n == 0 {
163                    break;
164                }
165                if cstdin.write_all(&buffer[..n]).await.is_err() {
166                    break;
167                }
168                cstdin.flush().await?;
169                out.extend_from_slice(&buffer[0..n]);
170            }
171
172            Ok::<_, io::Error>(out)
173        });
174
175        let (verdict, run_time, memory_usage) = monitor.await.unwrap()?;
176        let checker_status = checker.wait().await?;
177        drop(checker);
178
179        let _ = stdin_thread.await;
180        let stdout = stdout_thread.await.unwrap()?;
181        let mut err = vec![];
182        stderr.read_to_end(&mut err).await?;
183
184        if let Some(verdict) = verdict {
185            return Ok(Metrics {
186                verdict,
187                run_time,
188                stdout,
189                stderr: err,
190                memory_usage,
191            });
192        }
193
194        let verdict = if checker_status.success() {
195            Verdict::Accepted
196        } else {
197            Verdict::WrongAnswer
198        };
199
200        Ok(Metrics {
201            verdict,
202            run_time,
203            stdout,
204            stderr: err,
205            memory_usage,
206        })
207    }
208
209    #[require(Compiled)]
210    pub async fn batch_run(
211        &self,
212        inputs: impl Iterator<Item = &[u8]>,
213    ) -> io::Result<AggregatedMetrics> {
214        let mut verdict = Verdict::Accepted;
215        let mut total_run_time = Duration::ZERO;
216        let mut total_memory_usage = Byte::default();
217        let mut count = 0;
218
219        // running sequentially to enable early exit, saving resources
220        for input in inputs {
221            let metrics = self.run(input).await?;
222            total_run_time += metrics.run_time;
223            total_memory_usage = total_memory_usage
224                .add(metrics.memory_usage)
225                .expect("memory usage should not overflow u32");
226            count += 1;
227            if metrics.verdict != Verdict::Accepted {
228                verdict = metrics.verdict;
229                break;
230            }
231        }
232
233        Ok(AggregatedMetrics {
234            verdict,
235            average_run_time: total_run_time / count,
236            average_memory_usage: total_memory_usage
237                .divide(count as usize)
238                .expect("count must be greater than 0"),
239        })
240    }
241
242    #[require(Compiled)]
243    pub async fn streamed_batch_run(
244        &self,
245        mut inputs: impl Stream<Item = &[u8]> + std::marker::Unpin,
246    ) -> io::Result<AggregatedMetrics> {
247        let mut verdict = Verdict::Accepted;
248        let mut total_run_time = Duration::ZERO;
249        let mut total_memory_usage = Byte::default();
250        let mut count = 0;
251
252        // running sequentially to enable early exit, saving resources
253        while let Some(input) = inputs.next().await {
254            let metrics = self.run(input).await?;
255            total_run_time += metrics.run_time;
256            total_memory_usage = total_memory_usage
257                .add(metrics.memory_usage)
258                .expect("memory usage should not overflow u32");
259            count += 1;
260            if metrics.verdict != Verdict::Accepted {
261                verdict = metrics.verdict;
262                break;
263            }
264        }
265
266        Ok(AggregatedMetrics {
267            verdict,
268            average_run_time: total_run_time / count,
269            average_memory_usage: total_memory_usage
270                .divide(count as usize)
271                .expect("count must be greater than 0"),
272        })
273    }
274}