1use std::{env, io, marker::PhantomData, path::PathBuf, process::Stdio, time::Duration};
2
3use bon::bon;
4use byte_unit::Byte;
5use futures_lite::{Stream, StreamExt};
6use state_shift::{impl_state, type_state};
7use tokio::{
8 fs,
9 io::{AsyncReadExt, AsyncWriteExt},
10};
11use uuid::Uuid;
12
13use crate::{AggregatedMetrics, Language, Metrics, Resource, Sandbox, Verdict};
14
/// File stem of the contestant's program inside the project directory.
const MAIN: &str = "main";
/// File stem of the checker (interactor in interactive mode) inside the project directory.
const CHECKER: &str = "checker";
/// Chunk size in bytes (8 KiB) used when relaying the program's stdout to the checker.
const BUFFER_SIZE: usize = 8192;
18
/// File contents paired with the language they are written in.
///
/// Used for both the main program and the optional checker passed to the judge builder.
pub struct Code<'a> {
    /// Language used to choose the file extension and the compile/run commands.
    pub language: Language,
    /// Raw bytes, written verbatim into the judge's project directory.
    pub content: &'a [u8],
}
23
/// State-tracked handle to a judging workspace on disk.
///
/// A `Judge<Created>` holds freshly written sources; `compile` turns it into a
/// `Judge<Compiled>`, which can run test cases. The state type parameter and its
/// `_state` marker field are supplied by the `type_state` attribute.
#[type_state(
    states = (Created, Compiled),
    slots = (Created)
)]
#[derive(Default)]
pub struct Judge {
    /// Per-judge directory under the system temp dir (named by a random UUID) holding the
    /// sources; compilation runs here and the `main` executable is read from here.
    pub project_path: PathBuf,
    /// Language of the main program.
    pub language: Language,
    /// Language of the checker, or `None` if no checker was supplied (running then fails).
    pub checker_language: Option<Language>,
    /// When `true`, the checker's stdout is relayed to the program's stdin; when `false`,
    /// the test input is written to the program's stdin directly.
    pub is_interactive: bool,
    /// Resource limits handed to the sandbox for every run.
    pub resource: Resource,
    /// Time limit handed to the sandbox for every run (exact semantics defined by `Sandbox`).
    pub time_limit: Duration,
    /// Idle-time limit handed to the sandbox for every run; presumably bounds how long the
    /// program may make no progress — TODO confirm against `Sandbox`.
    pub idle_time_limit: Duration,
}
38
39#[bon]
40impl Judge<Created> {
41 #[builder]
42 pub async fn new<'a>(
43 main: Code<'a>,
44 checker: Option<Code<'a>>,
45 #[builder(default = false, name = "interactive")] is_interactive: bool,
46 #[builder(default)] resource: Resource,
47 #[builder(default)] time_limit: Duration,
48 #[builder(default = Duration::from_secs(1))] idle_time_limit: Duration,
49 ) -> io::Result<Judge<Created>> {
50 let project_path = env::temp_dir().join(Uuid::new_v4().to_string());
51 fs::create_dir(&project_path).await?;
52
53 let main_path = project_path
54 .join(MAIN)
55 .with_extension(main.language.extension);
56 fs::write(&main_path, main.content).await?;
57 if let Some(checker) = &checker {
58 let mut checker_path = project_path.join(CHECKER);
59 if checker.language.is_interpreted() {
60 checker_path.set_extension(checker.language.extension);
61 }
62 let mut checker_file = fs::OpenOptions::new()
63 .create(true)
64 .write(true)
65 .truncate(true)
66 .mode(0o755)
67 .open(&checker_path)
68 .await?;
69 checker_file.write_all(checker.content).await?;
70 checker_file.sync_all().await?;
71 }
72
73 Ok(Judge {
74 project_path,
75 language: main.language,
76 checker_language: checker.map(|checker| checker.language),
77 is_interactive,
78 resource,
79 time_limit,
80 idle_time_limit,
81 _state: PhantomData,
82 })
83 }
84}
85
86#[impl_state]
87impl Judge {
88 #[require(Created)]
89 #[switch_to(Compiled)]
90 pub async fn compile(self) -> io::Result<Result<Judge<Compiled>, Verdict>> {
91 if let Some(mut cmd) = self.language.get_compile_command(MAIN) {
92 let mut process = cmd.current_dir(&self.project_path).spawn()?;
93 let status = process.wait().await?;
94 if !status.success() {
95 return Ok(Err(Verdict::CompilationError));
96 }
97 }
98
99 Ok(Ok(Judge {
100 project_path: self.project_path,
101 language: self.language,
102 checker_language: self.checker_language,
103 is_interactive: self.is_interactive,
104 resource: self.resource,
105 time_limit: self.time_limit,
106 idle_time_limit: self.idle_time_limit,
107 }))
108 }
109
110 #[require(Compiled)]
111 pub async fn read_executable(&self) -> io::Result<Vec<u8>> {
112 let mut path = self.project_path.join(MAIN);
113 if self.language.is_interpreted() {
114 path.set_extension(self.language.extension);
115 }
116
117 fs::read(path).await
118 }
119
120 #[require(Compiled)]
121 pub async fn run(&self, input: &[u8]) -> io::Result<Metrics> {
122 let checker_language = self
123 .checker_language
124 .ok_or(io::Error::other("Missing checker"))?;
125 let mut checker = checker_language
126 .get_run_command(CHECKER)
127 .current_dir(&self.project_path)
128 .stdin(Stdio::piped())
129 .stdout(Stdio::piped())
130 .stderr(Stdio::null())
131 .spawn()?;
132 let mut cstdin = checker.stdin.take().unwrap();
133 let mut cstdout = checker.stdout.take().unwrap();
134 cstdin.write_all(input).await?;
135 cstdin.write_all(b"\n").await?;
136 cstdin.flush().await?;
137
138 let sandbox = Sandbox::new(self.resource, self.time_limit, self.idle_time_limit)?;
139 let mut cmd = self.language.get_run_command(MAIN);
140 cmd.current_dir(&self.project_path)
141 .stdin(Stdio::piped())
142 .stdout(Stdio::piped())
143 .stderr(Stdio::piped());
144 let mut main = sandbox.spawn(cmd)?;
145 let mut stdin = main.stdin.take().unwrap();
146 let mut stdout = main.stdout.take().unwrap();
147 let mut stderr = main.stderr.take().unwrap();
148
149 let monitor = tokio::spawn(async move { sandbox.monitor(main).await });
150 if !self.is_interactive {
151 stdin.write_all(input).await?;
152 stdin.write_all(b"\n").await?;
153 stdin.flush().await?;
154 }
155 let stdin_thread =
156 tokio::spawn(async move { tokio::io::copy(&mut cstdout, &mut stdin).await });
157 let stdout_thread = tokio::spawn(async move {
158 let mut out = vec![];
159 let mut buffer = [0u8; BUFFER_SIZE];
160 loop {
161 let n = stdout.read(&mut buffer).await?;
162 if n == 0 {
163 break;
164 }
165 if cstdin.write_all(&buffer[..n]).await.is_err() {
166 break;
167 }
168 cstdin.flush().await?;
169 out.extend_from_slice(&buffer[0..n]);
170 }
171
172 Ok::<_, io::Error>(out)
173 });
174
175 let (verdict, run_time, memory_usage) = monitor.await.unwrap()?;
176 let checker_status = checker.wait().await?;
177 drop(checker);
178
179 let _ = stdin_thread.await;
180 let stdout = stdout_thread.await.unwrap()?;
181 let mut err = vec![];
182 stderr.read_to_end(&mut err).await?;
183
184 if let Some(verdict) = verdict {
185 return Ok(Metrics {
186 verdict,
187 run_time,
188 stdout,
189 stderr: err,
190 memory_usage,
191 });
192 }
193
194 let verdict = if checker_status.success() {
195 Verdict::Accepted
196 } else {
197 Verdict::WrongAnswer
198 };
199
200 Ok(Metrics {
201 verdict,
202 run_time,
203 stdout,
204 stderr: err,
205 memory_usage,
206 })
207 }
208
209 #[require(Compiled)]
210 pub async fn batch_run(
211 &self,
212 inputs: impl Iterator<Item = &[u8]>,
213 ) -> io::Result<AggregatedMetrics> {
214 let mut verdict = Verdict::Accepted;
215 let mut total_run_time = Duration::ZERO;
216 let mut total_memory_usage = Byte::default();
217 let mut count = 0;
218
219 for input in inputs {
221 let metrics = self.run(input).await?;
222 total_run_time += metrics.run_time;
223 total_memory_usage = total_memory_usage
224 .add(metrics.memory_usage)
225 .expect("memory usage should not overflow u32");
226 count += 1;
227 if metrics.verdict != Verdict::Accepted {
228 verdict = metrics.verdict;
229 break;
230 }
231 }
232
233 Ok(AggregatedMetrics {
234 verdict,
235 average_run_time: total_run_time / count,
236 average_memory_usage: total_memory_usage
237 .divide(count as usize)
238 .expect("count must be greater than 0"),
239 })
240 }
241
242 #[require(Compiled)]
243 pub async fn streamed_batch_run(
244 &self,
245 mut inputs: impl Stream<Item = &[u8]> + std::marker::Unpin,
246 ) -> io::Result<AggregatedMetrics> {
247 let mut verdict = Verdict::Accepted;
248 let mut total_run_time = Duration::ZERO;
249 let mut total_memory_usage = Byte::default();
250 let mut count = 0;
251
252 while let Some(input) = inputs.next().await {
254 let metrics = self.run(input).await?;
255 total_run_time += metrics.run_time;
256 total_memory_usage = total_memory_usage
257 .add(metrics.memory_usage)
258 .expect("memory usage should not overflow u32");
259 count += 1;
260 if metrics.verdict != Verdict::Accepted {
261 verdict = metrics.verdict;
262 break;
263 }
264 }
265
266 Ok(AggregatedMetrics {
267 verdict,
268 average_run_time: total_run_time / count,
269 average_memory_usage: total_memory_usage
270 .divide(count as usize)
271 .expect("count must be greater than 0"),
272 })
273 }
274}