1use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
18
19use serde::Serialize;
20use thiserror::Error;
21use tokio::time::sleep;
22
/// SDK version; sent in the `User-Agent` header as `rustbox-sdk-rust/<VERSION>`.
pub const VERSION: &str = "0.1.0";

/// Base URL used by [`Rustbox::new`] unless overridden with `with_base_url`.
pub const DEFAULT_BASE_URL: &str = "https://rustbox-api.orkait.com";

/// Per-request timeout applied to the HTTP client built by `Rustbox::new` (65 s).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(65_000);
/// Number of retries (on top of the first attempt) for 5xx and transport failures.
const DEFAULT_MAX_RETRIES: u32 = 2;
31
32#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
38#[serde(rename_all = "lowercase")]
39pub enum Profile {
40 Judge,
41 Agent,
42}
43
44#[derive(Serialize, Debug, Clone, Default)]
45pub struct SubmitRequest {
46 pub language: String,
47 pub code: String,
48 pub stdin: String,
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub profile: Option<Profile>,
52 #[serde(skip_serializing_if = "Option::is_none")]
55 pub webhook_url: Option<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub webhook_secret: Option<String>,
58}
59
60#[derive(Debug, Error)]
62pub enum RustboxError {
63 #[error("api_key required")]
64 MissingApiKey,
65 #[error("invalid base_url")]
66 InvalidBaseUrl,
67 #[error("invalid or missing API key (HTTP {0})")]
68 Auth(u16),
69 #[error("rate limit exceeded (HTTP 429)")]
70 RateLimit,
71 #[error("server error (HTTP {0})")]
72 Server(u16),
73 #[error("API error (HTTP {status}): {body}")]
74 Api { status: u16, body: String },
75 #[error("request timed out")]
76 Timeout,
77 #[error(transparent)]
78 Transport(#[from] reqwest::Error),
79 #[error("response decode failed: {0}")]
80 Decode(String),
81}
82
/// Per-call options accepted by [`Rustbox::submit`].
#[derive(Clone, Debug, Default)]
pub struct SubmitOptions {
    /// Sent as the `Idempotency-Key` request header when set.
    pub idempotency_key: Option<String>,
}
89
90pub struct Rustbox {
91 api_key: String,
92 base_url: String,
93 client: reqwest::Client,
94 max_retries: u32,
95}
96
97impl Rustbox {
98 pub fn new(api_key: &str) -> Result<Self, RustboxError> {
101 if api_key.is_empty() {
102 return Err(RustboxError::MissingApiKey);
103 }
104 let client = reqwest::Client::builder()
105 .timeout(DEFAULT_TIMEOUT)
106 .build()
107 .map_err(RustboxError::Transport)?;
108 Ok(Self {
109 api_key: api_key.to_string(),
110 base_url: DEFAULT_BASE_URL.to_string(),
111 client,
112 max_retries: DEFAULT_MAX_RETRIES,
113 })
114 }
115
116 pub fn with_base_url(mut self, base_url: &str) -> Result<Self, RustboxError> {
119 if base_url.is_empty() {
120 return Err(RustboxError::InvalidBaseUrl);
121 }
122 self.base_url = base_url.trim_end_matches('/').to_string();
123 Ok(self)
124 }
125
126 pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, RustboxError> {
128 let mut builder = reqwest::Client::builder();
129 if !timeout.is_zero() {
130 builder = builder.timeout(timeout);
131 }
132 self.client = builder.build().map_err(RustboxError::Transport)?;
133 Ok(self)
134 }
135
136 pub fn with_max_retries(mut self, n: u32) -> Self {
138 self.max_retries = n;
139 self
140 }
141
142 pub fn base_url(&self) -> &str {
143 &self.base_url
144 }
145
146 fn backoff_delay(&self, attempt: u32) -> Duration {
147 Duration::from_millis((100u64 * (1u64 << attempt.min(8))).min(5_000))
148 }
149
150 async fn send_with_retry(
151 &self,
152 build: impl Fn() -> reqwest::RequestBuilder,
153 ) -> Result<reqwest::Response, RustboxError> {
154 let mut last_err: Option<RustboxError> = None;
155 for attempt in 0..=self.max_retries {
156 let req = build()
157 .header("X-API-Key", &self.api_key)
158 .header("User-Agent", format!("rustbox-sdk-rust/{VERSION}"));
159 match req.send().await {
160 Ok(resp) => {
161 if resp.status().as_u16() >= 500 && attempt < self.max_retries {
162 sleep(self.backoff_delay(attempt)).await;
163 continue;
164 }
165 return Ok(resp);
166 }
167 Err(e) => {
168 let is_timeout = e.is_timeout();
169 last_err = Some(if is_timeout {
170 RustboxError::Timeout
171 } else {
172 RustboxError::Transport(e)
173 });
174 if attempt >= self.max_retries {
175 return Err(last_err.unwrap());
176 }
177 sleep(self.backoff_delay(attempt)).await;
178 }
179 }
180 }
181 Err(last_err.unwrap_or(RustboxError::Decode("retry exhausted".into())))
182 }
183
184 async fn handle(&self, resp: reqwest::Response) -> Result<serde_json::Value, RustboxError> {
185 let status = resp.status();
186 let code = status.as_u16();
187 if status.is_success() || code == 408 {
188 return resp
189 .json()
190 .await
191 .map_err(|e| RustboxError::Decode(e.to_string()));
192 }
193 match code {
194 401 | 403 => Err(RustboxError::Auth(code)),
195 429 => Err(RustboxError::RateLimit),
196 500..=599 => Err(RustboxError::Server(code)),
197 _ => {
198 let body = resp.text().await.unwrap_or_default();
199 Err(RustboxError::Api { status: code, body })
200 }
201 }
202 }
203
204 pub async fn submit(
205 &self,
206 req: &SubmitRequest,
207 wait: bool,
208 opts: SubmitOptions,
209 ) -> Result<serde_json::Value, RustboxError> {
210 let url = format!("{}/api/submit?wait={}", self.base_url, wait);
211 let body = serde_json::to_vec(req).map_err(|e| RustboxError::Decode(e.to_string()))?;
212
213 let resp = self
214 .send_with_retry(|| {
215 let mut rb = self
216 .client
217 .post(&url)
218 .header("Content-Type", "application/json")
219 .body(body.clone());
220 if let Some(ref key) = opts.idempotency_key {
221 rb = rb.header("Idempotency-Key", key);
222 }
223 rb
224 })
225 .await?;
226 self.handle(resp).await
227 }
228
229 pub async fn get_result(&self, id: &str) -> Result<serde_json::Value, RustboxError> {
230 let url = format!("{}/api/result/{}", self.base_url, id);
231 let resp = self.send_with_retry(|| self.client.get(&url)).await?;
232 self.handle(resp).await
233 }
234
235 pub async fn get_languages(&self) -> Result<Vec<String>, RustboxError> {
236 let url = format!("{}/api/languages", self.base_url);
237 let resp = self.send_with_retry(|| self.client.get(&url)).await?;
238 let val = self.handle(resp).await?;
239 serde_json::from_value(val).map_err(|e| RustboxError::Decode(e.to_string()))
240 }
241
242 pub async fn get_health(&self) -> Result<serde_json::Value, RustboxError> {
243 let url = format!("{}/api/health", self.base_url);
244 let resp = self.send_with_retry(|| self.client.get(&url)).await?;
245 self.handle(resp).await
246 }
247
248 pub async fn get_ready(&self) -> Result<serde_json::Value, RustboxError> {
249 let url = format!("{}/api/health/ready", self.base_url);
250 let resp = self.send_with_retry(|| self.client.get(&url)).await?;
251 self.handle(resp).await
252 }
253
254 pub async fn run(&self, req: &SubmitRequest) -> Result<serde_json::Value, RustboxError> {
257 let opts = SubmitOptions {
258 idempotency_key: Some(idempotency_id()),
259 };
260 let mut res = self.submit(req, true, opts).await?;
261 if res.get("verdict").is_some() {
262 return Ok(res);
263 }
264
265 let id = match res.get("id").and_then(|v| v.as_str()) {
266 Some(i) => i.to_string(),
267 None => return Ok(res),
268 };
269
270 for i in 0..45 {
271 let delay_ms = (40.0 * (1.5_f64).powi(i)).min(600.0) as u64;
272 sleep(Duration::from_millis(delay_ms)).await;
273
274 res = self.get_result(&id).await?;
275 if res.get("verdict").is_some() {
276 return Ok(res);
277 }
278 }
279 Ok(res)
280 }
281}
282
/// Per-process sequence number.
/// Makes keys distinct even when two calls read the same clock value.
static COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generates an idempotency key of the form `<nanos>-<pid>-<seq>` (lowercase hex).
///
/// The counter only guarantees uniqueness within one process, and it restarts at 0
/// in every process. Mixing in the process id stops concurrently started workers on
/// the same host from producing identical keys when the clock resolution is coarse.
fn idempotency_id() -> String {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // Nanoseconds since the epoch fit in a u64 until the year 2554.
        .map(|d| d.as_nanos() as u64)
        // Clock set before 1970: pid + counter still keep keys distinct.
        .unwrap_or(0);
    let pid = std::process::id();
    // Relaxed is sufficient: only the atomicity of fetch_add matters here, not
    // ordering relative to other memory operations.
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:016x}-{pid:08x}-{seq:016x}")
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header_regex, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A trivial Python submission shared by every test.
    fn req() -> SubmitRequest {
        SubmitRequest {
            language: "python".into(),
            code: "print(1)".into(),
            stdin: "".into(),
            profile: None,
            webhook_url: None,
            webhook_secret: None,
        }
    }

    /// Builds a client with a dummy key that talks to `server`.
    fn client_for(server: &MockServer) -> Rustbox {
        Rustbox::new("test")
            .unwrap()
            .with_base_url(&server.uri())
            .unwrap()
    }

    /// Registers a mock that answers every `verb` request on `route` with `reply`.
    async fn stub(server: &MockServer, verb: &str, route: &str, reply: ResponseTemplate) {
        Mock::given(method(verb))
            .and(path(route))
            .respond_with(reply)
            .mount(server)
            .await;
    }

    /// A 200 reply for submission "1" carrying `verdict`.
    fn verdict_reply(verdict: &str) -> ResponseTemplate {
        ResponseTemplate::new(200)
            .set_body_json(serde_json::json!({"id": "1", "verdict": verdict}))
    }

    #[tokio::test]
    async fn new_should_default_base_url_to_production() {
        let client = Rustbox::new("k").unwrap();
        assert_eq!(client.base_url(), DEFAULT_BASE_URL);
    }

    #[tokio::test]
    async fn new_should_return_err_when_api_key_empty() {
        assert!(matches!(Rustbox::new(""), Err(RustboxError::MissingApiKey)));
    }

    #[tokio::test]
    async fn with_base_url_should_override_default_and_trim_slash() {
        let client = Rustbox::new("k")
            .and_then(|c| c.with_base_url("https://custom.example.com/"))
            .unwrap();
        assert_eq!(client.base_url(), "https://custom.example.com");
    }

    #[tokio::test]
    async fn with_base_url_should_return_err_when_empty() {
        let outcome = Rustbox::new("k").unwrap().with_base_url("");
        assert!(matches!(outcome, Err(RustboxError::InvalidBaseUrl)));
    }

    #[tokio::test]
    async fn run_should_return_verdict_on_first_response_when_complete() {
        let server = MockServer::start().await;
        stub(&server, "POST", "/api/submit", verdict_reply("AC")).await;

        let res = client_for(&server).run(&req()).await.unwrap();
        assert_eq!(res["verdict"].as_str(), Some("AC"));
    }

    #[tokio::test]
    async fn run_should_poll_until_verdict_when_initial_returns_408() {
        let server = MockServer::start().await;
        let pending = ResponseTemplate::new(408).set_body_json(serde_json::json!({"id": "1"}));
        stub(&server, "POST", "/api/submit", pending).await;
        stub(&server, "GET", "/api/result/1", verdict_reply("TLE")).await;

        let res = client_for(&server).run(&req()).await.unwrap();
        assert_eq!(res["verdict"].as_str(), Some("TLE"));
    }

    #[tokio::test]
    async fn submit_should_return_auth_err_on_401() {
        let server = MockServer::start().await;
        stub(&server, "POST", "/api/submit", ResponseTemplate::new(401)).await;

        let err = client_for(&server)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, RustboxError::Auth(401)));
    }

    #[tokio::test]
    async fn submit_should_return_rate_limit_on_429() {
        let server = MockServer::start().await;
        stub(&server, "POST", "/api/submit", ResponseTemplate::new(429)).await;

        let err = client_for(&server)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, RustboxError::RateLimit));
    }

    #[tokio::test]
    async fn submit_should_return_server_err_on_503_after_retries() {
        let server = MockServer::start().await;
        stub(&server, "POST", "/api/submit", ResponseTemplate::new(503)).await;

        let err = client_for(&server)
            .with_max_retries(1)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, RustboxError::Server(503)));
    }

    #[tokio::test]
    async fn submit_should_send_user_agent_header() {
        let server = MockServer::start().await;
        // Only matches when the SDK's User-Agent prefix is present.
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .and(header_regex("user-agent", r"^rustbox-sdk-rust/"))
            .respond_with(verdict_reply("AC"))
            .mount(&server)
            .await;

        let res = client_for(&server)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap();
        assert_eq!(res["verdict"].as_str(), Some("AC"));
    }
}