arc_agi_rs/python.rs
1// Copyright 2026 Mahmoud Harmouch.
2//
3// Licensed under the MIT license
4// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
5// option. This file may not be copied, modified, or distributed
6// except according to those terms.
7
8//! # Python Bindings
9//!
10//! Exposes the `arc-agi-rs` library to Python via [`pyo3`].
11//! Every type and function is gated behind the `python` cargo feature.
12//!
13//! The bindings provide **synchronous** wrappers around the async Rust API by
14//! driving a temporary, single-threaded [`tokio`] runtime inside each call,
15//! making the API feel native and straightforward for Python callers.
16//!
17//! # See Also
18//!
19//! - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
20
21use std::fmt::Display;
22use std::future::Future;
23
24use pyo3::exceptions::PyRuntimeError;
25use pyo3::prelude::*;
26use pyo3::types::{PyAny, PyDict, PyList};
27use serde_json::{Map, Value};
28use tokio::runtime::Builder as RuntimeBuilder;
29
30use crate::client::Client;
31use crate::params::{MakeParams, ScorecardParams, StepParams};
32
33/// Convert a Python dict/object to a [`Value`] without the `pythonize` crate.
34///
35/// Only `dict`, `list`, `str`, `int`, `float`, `bool`, and `None` are
36/// supported - sufficient for the `data` and `reasoning` step payloads.
37fn pyany_to_json(obj: &Bound<'_, PyAny>) -> PyResult<Value> {
38 if obj.is_none() {
39 return Ok(Value::Null);
40 }
41 if let Ok(b) = obj.extract::<bool>() {
42 return Ok(Value::Bool(b));
43 }
44 if let Ok(i) = obj.extract::<i64>() {
45 return Ok(Value::from(i));
46 }
47 if let Ok(f) = obj.extract::<f64>() {
48 return Ok(Value::from(f));
49 }
50 if let Ok(s) = obj.extract::<String>() {
51 return Ok(Value::String(s));
52 }
53 if let Ok(list) = obj.extract::<Bound<'_, PyList>>() {
54 let arr: PyResult<Vec<_>> = list.iter().map(|v| pyany_to_json(&v)).collect();
55 return Ok(Value::Array(arr?));
56 }
57 if let Ok(dict) = obj.extract::<Bound<'_, PyDict>>() {
58 let mut map = Map::new();
59 for (k, v) in dict.iter() {
60 let key: String = k.extract()?;
61 map.insert(key, pyany_to_json(&v)?);
62 }
63 return Ok(Value::Object(map));
64 }
65 Err(PyRuntimeError::new_err(format!(
66 "cannot convert Python value of type {} to JSON",
67 obj.get_type().name()?
68 )))
69}
70
71/// Converts a Python `dict` to a [`Value`].
72fn pydict_to_json(d: &Bound<'_, PyDict>) -> PyResult<Value> {
73 pyany_to_json(d.as_any())
74}
75
76/// Drives an async future on a fresh single-threaded Tokio runtime.
77///
78/// This function is adapted from the `duckduckgo` crate's Python binding helper:
79/// <https://github.com/wiseaidotdev/duckduckgo/blob/main/src/python.rs>
80fn block_on<F, T, E>(future: F) -> PyResult<T>
81where
82 F: Future<Output = Result<T, E>>,
83 E: Display,
84{
85 RuntimeBuilder::new_current_thread()
86 .enable_all()
87 .build()
88 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?
89 .block_on(future)
90 .map_err(|e| PyRuntimeError::new_err(e.to_string()))
91}
92
93/// Metadata for a single ARC-AGI-3 game environment.
94///
95/// All fields are read-only once the object is constructed by a
96/// :meth:`ArcAgiClient.list_environments` or
97/// :meth:`ArcAgiClient.get_environment` call.
98///
99/// # See Also
100///
101/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
102#[pyclass(name = "EnvironmentInfo", frozen, skip_from_py_object)]
103#[derive(Debug, Clone)]
104pub struct PyEnvironmentInfo {
105 /// Unique game identifier (e.g. ``"ls20"``).
106 #[pyo3(get)]
107 pub game_id: String,
108 /// Human-readable title.
109 #[pyo3(get)]
110 pub title: Option<String>,
111 /// Default frames-per-second.
112 #[pyo3(get)]
113 pub default_fps: Option<u32>,
114 /// Classification tags.
115 #[pyo3(get)]
116 pub tags: Option<Vec<String>>,
117}
118
119#[pymethods]
120impl PyEnvironmentInfo {
121 pub fn __repr__(&self) -> String {
122 format!(
123 "EnvironmentInfo(game_id={:?}, title={:?})",
124 self.game_id, self.title
125 )
126 }
127}
128
129/// The current state of a game run returned by reset and step calls.
130///
131/// All fields are read-only once the object is constructed by a
132/// :meth:`ArcAgiClient.reset` or :meth:`ArcAgiClient.step` call.
133///
134/// # See Also
135///
136/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
137#[pyclass(name = "FrameData", frozen, skip_from_py_object)]
138#[derive(Debug, Clone)]
139pub struct PyFrameData {
140 /// Game identifier.
141 #[pyo3(get)]
142 pub game_id: String,
143 /// Unique run identifier assigned by the server.
144 #[pyo3(get)]
145 pub guid: Option<String>,
146 /// Current lifecycle state (e.g. ``"NOT_FINISHED"``, ``"WIN"``).
147 #[pyo3(get)]
148 pub state: String,
149 /// Number of levels completed in this run.
150 #[pyo3(get)]
151 pub levels_completed: u32,
152 /// Total levels that must be completed to win.
153 #[pyo3(get)]
154 pub win_levels: u32,
155 /// Action IDs the agent may send on the next step.
156 #[pyo3(get)]
157 pub available_actions: Vec<u32>,
158 /// Whether this response corresponds to a full game reset.
159 #[pyo3(get)]
160 pub full_reset: bool,
161}
162
163#[pymethods]
164impl PyFrameData {
165 pub fn __repr__(&self) -> String {
166 format!(
167 "FrameData(game_id={:?}, guid={:?}, state={:?})",
168 self.game_id, self.guid, self.state
169 )
170 }
171}
172
173/// A per-game score entry inside an :class:`EnvironmentScorecard`.
174///
175/// All fields are read-only.
176#[pyclass(name = "EnvironmentScore", frozen, skip_from_py_object)]
177#[derive(Debug, Clone)]
178pub struct PyEnvironmentScore {
179 /// Game identifier.
180 #[pyo3(get)]
181 pub id: Option<String>,
182 /// Aggregate score (0.0–115.0).
183 #[pyo3(get)]
184 pub score: f64,
185 /// Number of levels completed.
186 #[pyo3(get)]
187 pub levels_completed: u32,
188 /// Total actions taken.
189 #[pyo3(get)]
190 pub actions: u32,
191 /// Whether the environment was fully completed.
192 #[pyo3(get)]
193 pub completed: Option<bool>,
194}
195
196#[pymethods]
197impl PyEnvironmentScore {
198 pub fn __repr__(&self) -> String {
199 format!(
200 "EnvironmentScore(id={:?}, score={:.2})",
201 self.id, self.score
202 )
203 }
204}
205
206/// The full scored scorecard returned by scorecard endpoints.
207///
208/// All fields are read-only once the object is constructed by a
209/// :meth:`ArcAgiClient.get_scorecard` or :meth:`ArcAgiClient.close_scorecard`
210/// call.
211///
212/// # See Also
213///
214/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
215#[pyclass(name = "EnvironmentScorecard", frozen, skip_from_py_object)]
216#[derive(Debug, Clone)]
217pub struct PyEnvironmentScorecard {
218 /// Unique scorecard identifier.
219 #[pyo3(get)]
220 pub card_id: String,
221 /// Aggregate score across all environments (0.0–115.0).
222 #[pyo3(get)]
223 pub score: f64,
224 /// Whether the scorecard was created in competition mode.
225 #[pyo3(get)]
226 pub competition_mode: Option<bool>,
227 /// Total environments completed.
228 #[pyo3(get)]
229 pub total_environments_completed: Option<u32>,
230 /// Total environments.
231 #[pyo3(get)]
232 pub total_environments: Option<u32>,
233 /// Total levels completed.
234 #[pyo3(get)]
235 pub total_levels_completed: Option<u32>,
236 /// Total levels.
237 #[pyo3(get)]
238 pub total_levels: Option<u32>,
239 /// Total actions taken.
240 #[pyo3(get)]
241 pub total_actions: Option<u32>,
242}
243
244#[pymethods]
245impl PyEnvironmentScorecard {
246 pub fn __repr__(&self) -> String {
247 format!(
248 "EnvironmentScorecard(card_id={:?}, score={:.2})",
249 self.card_id, self.score
250 )
251 }
252}
253
/// An HTTP client for interacting with the ARC-AGI-3 REST API.
///
/// Construct with ``ArcAgiClient()`` for a zero-configuration default that
/// reads credentials from the ``ARC_API_KEY`` and ``ARC_BASE_URL`` environment
/// variables, or supply keyword arguments to configure them explicitly. All
/// network methods are **synchronous** from Python's perspective; they drive
/// an internal single-threaded Tokio runtime for each call.
///
/// # See Also
///
/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
#[pyclass(name = "ArcAgiClient")]
pub struct PyArcAgiClient {
    /// The wrapped Rust [`Client`]; the Python-facing methods delegate to it.
    inner: Client,
}
269
270#[pymethods]
271impl PyArcAgiClient {
272 /// Create a new ``ArcAgiClient``.
273 ///
274 /// Args:
275 /// api_key: Optional API key string. Falls back to ``ARC_API_KEY``
276 /// environment variable and then to an empty string.
277 /// base_url: Optional server base URL. Falls back to ``ARC_BASE_URL``
278 /// environment variable and then to
279 /// ``"https://three.arcprize.org"``.
280 /// cookie_store: Enable cookie persistence across requests
281 /// (default ``False``).
282 /// proxy: Optional proxy URL, e.g. ``"socks5://127.0.0.1:9050"``.
283 ///
284 /// Raises:
285 /// RuntimeError: If the proxy URL is invalid or the HTTP client cannot
286 /// be constructed.
287 #[new]
288 #[pyo3(signature = (api_key=None, base_url=None, cookie_store=false, proxy=None))]
289 pub fn new(
290 api_key: Option<String>,
291 base_url: Option<String>,
292 cookie_store: bool,
293 proxy: Option<String>,
294 ) -> PyResult<Self> {
295 let mut builder = Client::builder();
296 if let Some(key) = api_key {
297 builder = builder.api_key(key);
298 }
299 if let Some(url) = base_url {
300 builder = builder.base_url(url);
301 }
302 if cookie_store {
303 builder = builder.cookie_store(true);
304 }
305 if let Some(proxy_url) = proxy {
306 builder = builder.proxy(proxy_url);
307 }
308 let inner = builder
309 .build()
310 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
311 Ok(Self { inner })
312 }
313
314 /// Retrieve an anonymous API key from the server.
315 ///
316 /// Returns:
317 /// The anonymous API key string.
318 ///
319 /// Raises:
320 /// RuntimeError: On network or HTTP failure.
321 pub fn get_anonymous_key(&self) -> PyResult<String> {
322 block_on(self.inner.get_anonymous_key())
323 }
324
325 /// Return the list of all available game environments.
326 ///
327 /// Returns:
328 /// A list of :class:`EnvironmentInfo` objects.
329 ///
330 /// Raises:
331 /// RuntimeError: On network or HTTP failure.
332 pub fn list_environments(&self) -> PyResult<Vec<PyEnvironmentInfo>> {
333 let envs = block_on(self.inner.list_environments())?;
334 Ok(envs
335 .into_iter()
336 .map(|e| PyEnvironmentInfo {
337 game_id: e.game_id,
338 title: e.title,
339 default_fps: e.default_fps,
340 tags: e.tags,
341 })
342 .collect())
343 }
344
345 /// Return metadata for a single game environment.
346 ///
347 /// Args:
348 /// game_id: The game identifier (e.g. ``"ls20"``).
349 ///
350 /// Returns:
351 /// An :class:`EnvironmentInfo` object.
352 ///
353 /// Raises:
354 /// RuntimeError: If the game is not found or on network failure.
355 pub fn get_environment(&self, game_id: String) -> PyResult<PyEnvironmentInfo> {
356 let info = block_on(self.inner.get_environment(&game_id))?;
357 Ok(PyEnvironmentInfo {
358 game_id: info.game_id,
359 title: info.title,
360 default_fps: info.default_fps,
361 tags: info.tags,
362 })
363 }
364
365 /// Create a new scorecard and return its ID.
366 ///
367 /// Args:
368 /// source_url: Optional URL linking to the agent being evaluated.
369 /// tags: Optional list of classification tag strings.
370 /// competition_mode: When ``True``, enables one-way competition semantics.
371 ///
372 /// Returns:
373 /// The ``card_id`` string of the newly created scorecard.
374 ///
375 /// Raises:
376 /// RuntimeError: On network or HTTP failure.
377 #[pyo3(signature = (source_url=None, tags=None, competition_mode=None))]
378 pub fn open_scorecard(
379 &self,
380 source_url: Option<String>,
381 tags: Option<Vec<String>>,
382 competition_mode: Option<bool>,
383 ) -> PyResult<String> {
384 let mut params = ScorecardParams::new();
385 if let Some(url) = source_url {
386 params = params.source_url(url);
387 }
388 if let Some(t) = tags {
389 params = params.tags(t);
390 }
391 if let Some(cm) = competition_mode {
392 params = params.competition_mode(cm);
393 }
394 block_on(self.inner.open_scorecard(Some(params)))
395 }
396
397 /// Retrieve an existing scorecard by its ID.
398 ///
399 /// Args:
400 /// card_id: The scorecard identifier.
401 ///
402 /// Returns:
403 /// An :class:`EnvironmentScorecard` object.
404 ///
405 /// Raises:
406 /// RuntimeError: If the scorecard is not found or on network failure.
407 pub fn get_scorecard(&self, card_id: String) -> PyResult<PyEnvironmentScorecard> {
408 let card = block_on(self.inner.get_scorecard(&card_id))?;
409 Ok(PyEnvironmentScorecard {
410 card_id: card.card_id,
411 score: card.score,
412 competition_mode: card.competition_mode,
413 total_environments_completed: card.total_environments_completed,
414 total_environments: card.total_environments,
415 total_levels_completed: card.total_levels_completed,
416 total_levels: card.total_levels,
417 total_actions: card.total_actions,
418 })
419 }
420
421 /// Close and finalise a scorecard.
422 ///
423 /// Args:
424 /// card_id: The scorecard identifier.
425 ///
426 /// Returns:
427 /// An :class:`EnvironmentScorecard` with the final scores.
428 ///
429 /// Raises:
430 /// RuntimeError: If the scorecard is not found or on network failure.
431 pub fn close_scorecard(&self, card_id: String) -> PyResult<PyEnvironmentScorecard> {
432 let card = block_on(self.inner.close_scorecard(&card_id))?;
433 Ok(PyEnvironmentScorecard {
434 card_id: card.card_id,
435 score: card.score,
436 competition_mode: card.competition_mode,
437 total_environments_completed: card.total_environments_completed,
438 total_environments: card.total_environments,
439 total_levels_completed: card.total_levels_completed,
440 total_levels: card.total_levels,
441 total_actions: card.total_actions,
442 })
443 }
444
445 /// Reset (or start) a game environment.
446 ///
447 /// Args:
448 /// game_id: The game identifier.
449 /// scorecard_id: The scorecard to record this run under.
450 /// guid: Optional existing run GUID to re-use.
451 /// seed: Random seed for reproducible level ordering (default ``0``).
452 ///
453 /// Returns:
454 /// A :class:`FrameData` object with the initial game state.
455 ///
456 /// Raises:
457 /// RuntimeError: On network or HTTP failure.
458 #[pyo3(signature = (game_id, scorecard_id, guid=None, seed=0))]
459 pub fn reset(
460 &self,
461 game_id: String,
462 scorecard_id: String,
463 guid: Option<String>,
464 seed: u32,
465 ) -> PyResult<PyFrameData> {
466 let mut params = MakeParams::new(&game_id, &scorecard_id).seed(seed);
467 if let Some(g) = guid {
468 params = params.guid(g);
469 }
470 let frame = block_on(self.inner.reset(params))?;
471 Ok(PyFrameData {
472 game_id: frame.game_id,
473 guid: frame.guid,
474 state: frame.state.as_str().to_string(),
475 levels_completed: frame.levels_completed,
476 win_levels: frame.win_levels,
477 available_actions: frame.available_actions,
478 full_reset: frame.full_reset,
479 })
480 }
481
482 /// Send one game action and receive the resulting frame.
483 ///
484 /// Args:
485 /// game_id: The game identifier.
486 /// scorecard_id: The scorecard identifier.
487 /// guid: The run GUID from a prior :meth:`reset` call.
488 /// action_id: Numeric action ID (0 = RESET).
489 /// data: Optional dict of action data (e.g. ``{"x": 3, "y": 4}``).
490 /// reasoning: Optional dict of freeform reasoning.
491 ///
492 /// Returns:
493 /// A :class:`FrameData` object with the updated game state.
494 ///
495 /// Raises:
496 /// RuntimeError: On network or HTTP failure.
497 #[pyo3(signature = (game_id, scorecard_id, guid, action_id, data=None, reasoning=None))]
498 pub fn step(
499 &self,
500 game_id: String,
501 scorecard_id: String,
502 guid: String,
503 action_id: u32,
504 data: Option<&pyo3::Bound<'_, pyo3::types::PyDict>>,
505 reasoning: Option<&pyo3::Bound<'_, pyo3::types::PyDict>>,
506 ) -> PyResult<PyFrameData> {
507 let data_json = data.map(pydict_to_json).transpose()?;
508 let reasoning_json = reasoning.map(pydict_to_json).transpose()?;
509
510 let mut params = StepParams::new(&game_id, &scorecard_id, &guid, action_id);
511 if let Some(d) = data_json {
512 params = params.data(d);
513 }
514 if let Some(r) = reasoning_json {
515 params = params.reasoning(r);
516 }
517 let frame = block_on(self.inner.step(params))?;
518 Ok(PyFrameData {
519 game_id: frame.game_id,
520 guid: frame.guid,
521 state: frame.state.as_str().to_string(),
522 levels_completed: frame.levels_completed,
523 win_levels: frame.win_levels,
524 available_actions: frame.available_actions,
525 full_reset: frame.full_reset,
526 })
527 }
528
529 pub fn __repr__(&self) -> String {
530 format!("ArcAgiClient(base_url={:?})", self.inner.base_url())
531 }
532}
533
534/// Register all Python-exposed types and functions into the ``arc_agi_rs`` module.
535pub fn register_python_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
536 m.add_class::<PyEnvironmentInfo>()?;
537 m.add_class::<PyFrameData>()?;
538 m.add_class::<PyEnvironmentScore>()?;
539 m.add_class::<PyEnvironmentScorecard>()?;
540 m.add_class::<PyArcAgiClient>()?;
541 Ok(())
542}
543
544// Copyright 2026 Mahmoud Harmouch.
545//
546// Licensed under the MIT license
547// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
548// option. This file may not be copied, modified, or distributed
549// except according to those terms.