Skip to main content

arc_agi_rs/
python.rs

1// Copyright 2026 Mahmoud Harmouch.
2//
3// Licensed under the MIT license
4// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
5// option. This file may not be copied, modified, or distributed
6// except according to those terms.
7
8//! # Python Bindings
9//!
10//! Exposes the `arc-agi-rs` library to Python via [`pyo3`].
11//! Every type and function is gated behind the `python` cargo feature.
12//!
13//! The bindings provide **synchronous** wrappers around the async Rust API by
14//! driving a temporary, single-threaded [`tokio`] runtime inside each call,
15//! making the API feel native and straightforward for Python callers.
16//!
17//! # See Also
18//!
19//! - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
20
21use std::fmt::Display;
22use std::future::Future;
23
24use pyo3::exceptions::PyRuntimeError;
25use pyo3::prelude::*;
26use pyo3::types::{PyAny, PyDict, PyList};
27use serde_json::{Map, Value};
28use tokio::runtime::Builder as RuntimeBuilder;
29
30use crate::client::Client;
31use crate::params::{MakeParams, ScorecardParams, StepParams};
32
33/// Convert a Python dict/object to a [`Value`] without the `pythonize` crate.
34///
35/// Only `dict`, `list`, `str`, `int`, `float`, `bool`, and `None` are
36/// supported - sufficient for the `data` and `reasoning` step payloads.
37fn pyany_to_json(obj: &Bound<'_, PyAny>) -> PyResult<Value> {
38    if obj.is_none() {
39        return Ok(Value::Null);
40    }
41    if let Ok(b) = obj.extract::<bool>() {
42        return Ok(Value::Bool(b));
43    }
44    if let Ok(i) = obj.extract::<i64>() {
45        return Ok(Value::from(i));
46    }
47    if let Ok(f) = obj.extract::<f64>() {
48        return Ok(Value::from(f));
49    }
50    if let Ok(s) = obj.extract::<String>() {
51        return Ok(Value::String(s));
52    }
53    if let Ok(list) = obj.extract::<Bound<'_, PyList>>() {
54        let arr: PyResult<Vec<_>> = list.iter().map(|v| pyany_to_json(&v)).collect();
55        return Ok(Value::Array(arr?));
56    }
57    if let Ok(dict) = obj.extract::<Bound<'_, PyDict>>() {
58        let mut map = Map::new();
59        for (k, v) in dict.iter() {
60            let key: String = k.extract()?;
61            map.insert(key, pyany_to_json(&v)?);
62        }
63        return Ok(Value::Object(map));
64    }
65    Err(PyRuntimeError::new_err(format!(
66        "cannot convert Python value of type {} to JSON",
67        obj.get_type().name()?
68    )))
69}
70
71/// Converts a Python `dict` to a [`Value`].
72fn pydict_to_json(d: &Bound<'_, PyDict>) -> PyResult<Value> {
73    pyany_to_json(d.as_any())
74}
75
76/// Drives an async future on a fresh single-threaded Tokio runtime.
77///
78/// This function is adapted from the `duckduckgo` crate's Python binding helper:
79/// <https://github.com/wiseaidotdev/duckduckgo/blob/main/src/python.rs>
80fn block_on<F, T, E>(future: F) -> PyResult<T>
81where
82    F: Future<Output = Result<T, E>>,
83    E: Display,
84{
85    RuntimeBuilder::new_current_thread()
86        .enable_all()
87        .build()
88        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?
89        .block_on(future)
90        .map_err(|e| PyRuntimeError::new_err(e.to_string()))
91}
92
93/// Metadata for a single ARC-AGI-3 game environment.
94///
95/// All fields are read-only once the object is constructed by a
96/// :meth:`ArcAgiClient.list_environments` or
97/// :meth:`ArcAgiClient.get_environment` call.
98///
99/// # See Also
100///
101/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
102#[pyclass(name = "EnvironmentInfo", frozen, skip_from_py_object)]
103#[derive(Debug, Clone)]
104pub struct PyEnvironmentInfo {
105    /// Unique game identifier (e.g. ``"ls20"``).
106    #[pyo3(get)]
107    pub game_id: String,
108    /// Human-readable title.
109    #[pyo3(get)]
110    pub title: Option<String>,
111    /// Default frames-per-second.
112    #[pyo3(get)]
113    pub default_fps: Option<u32>,
114    /// Classification tags.
115    #[pyo3(get)]
116    pub tags: Option<Vec<String>>,
117}
118
119#[pymethods]
120impl PyEnvironmentInfo {
121    pub fn __repr__(&self) -> String {
122        format!(
123            "EnvironmentInfo(game_id={:?}, title={:?})",
124            self.game_id, self.title
125        )
126    }
127}
128
129/// The current state of a game run returned by reset and step calls.
130///
131/// All fields are read-only once the object is constructed by a
132/// :meth:`ArcAgiClient.reset` or :meth:`ArcAgiClient.step` call.
133///
134/// # See Also
135///
136/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
137#[pyclass(name = "FrameData", frozen, skip_from_py_object)]
138#[derive(Debug, Clone)]
139pub struct PyFrameData {
140    /// Game identifier.
141    #[pyo3(get)]
142    pub game_id: String,
143    /// Unique run identifier assigned by the server.
144    #[pyo3(get)]
145    pub guid: Option<String>,
146    /// Current lifecycle state (e.g. ``"NOT_FINISHED"``, ``"WIN"``).
147    #[pyo3(get)]
148    pub state: String,
149    /// Number of levels completed in this run.
150    #[pyo3(get)]
151    pub levels_completed: u32,
152    /// Total levels that must be completed to win.
153    #[pyo3(get)]
154    pub win_levels: u32,
155    /// Action IDs the agent may send on the next step.
156    #[pyo3(get)]
157    pub available_actions: Vec<u32>,
158    /// Whether this response corresponds to a full game reset.
159    #[pyo3(get)]
160    pub full_reset: bool,
161}
162
163#[pymethods]
164impl PyFrameData {
165    pub fn __repr__(&self) -> String {
166        format!(
167            "FrameData(game_id={:?}, guid={:?}, state={:?})",
168            self.game_id, self.guid, self.state
169        )
170    }
171}
172
173/// A per-game score entry inside an :class:`EnvironmentScorecard`.
174///
175/// All fields are read-only.
176#[pyclass(name = "EnvironmentScore", frozen, skip_from_py_object)]
177#[derive(Debug, Clone)]
178pub struct PyEnvironmentScore {
179    /// Game identifier.
180    #[pyo3(get)]
181    pub id: Option<String>,
182    /// Aggregate score (0.0–115.0).
183    #[pyo3(get)]
184    pub score: f64,
185    /// Number of levels completed.
186    #[pyo3(get)]
187    pub levels_completed: u32,
188    /// Total actions taken.
189    #[pyo3(get)]
190    pub actions: u32,
191    /// Whether the environment was fully completed.
192    #[pyo3(get)]
193    pub completed: Option<bool>,
194}
195
196#[pymethods]
197impl PyEnvironmentScore {
198    pub fn __repr__(&self) -> String {
199        format!(
200            "EnvironmentScore(id={:?}, score={:.2})",
201            self.id, self.score
202        )
203    }
204}
205
206/// The full scored scorecard returned by scorecard endpoints.
207///
208/// All fields are read-only once the object is constructed by a
209/// :meth:`ArcAgiClient.get_scorecard` or :meth:`ArcAgiClient.close_scorecard`
210/// call.
211///
212/// # See Also
213///
214/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
215#[pyclass(name = "EnvironmentScorecard", frozen, skip_from_py_object)]
216#[derive(Debug, Clone)]
217pub struct PyEnvironmentScorecard {
218    /// Unique scorecard identifier.
219    #[pyo3(get)]
220    pub card_id: String,
221    /// Aggregate score across all environments (0.0–115.0).
222    #[pyo3(get)]
223    pub score: f64,
224    /// Whether the scorecard was created in competition mode.
225    #[pyo3(get)]
226    pub competition_mode: Option<bool>,
227    /// Total environments completed.
228    #[pyo3(get)]
229    pub total_environments_completed: Option<u32>,
230    /// Total environments.
231    #[pyo3(get)]
232    pub total_environments: Option<u32>,
233    /// Total levels completed.
234    #[pyo3(get)]
235    pub total_levels_completed: Option<u32>,
236    /// Total levels.
237    #[pyo3(get)]
238    pub total_levels: Option<u32>,
239    /// Total actions taken.
240    #[pyo3(get)]
241    pub total_actions: Option<u32>,
242}
243
244#[pymethods]
245impl PyEnvironmentScorecard {
246    pub fn __repr__(&self) -> String {
247        format!(
248            "EnvironmentScorecard(card_id={:?}, score={:.2})",
249            self.card_id, self.score
250        )
251    }
252}
253
254/// An HTTP client for interacting with the ARC-AGI-3 REST API.
255///
256/// Construct with ``ArcAgiClient()`` for a zero-configuration default that
257/// reads credentials from the ``ARC_API_KEY`` and ``ARC_BASE_URL`` environment
258/// variables, or supply keyword arguments to configure them explicitly. All
259/// network methods are **synchronous** from Python's perspective; they drive
260/// an internal single-threaded Tokio runtime for each call.
261///
262/// # See Also
263///
264/// - [ARC-AGI-3 Reference](https://arcprize.org/arc-agi/3)
265#[pyclass(name = "ArcAgiClient")]
266pub struct PyArcAgiClient {
267    inner: Client,
268}
269
270#[pymethods]
271impl PyArcAgiClient {
272    /// Create a new ``ArcAgiClient``.
273    ///
274    /// Args:
275    ///     api_key:      Optional API key string. Falls back to ``ARC_API_KEY``
276    ///                   environment variable and then to an empty string.
277    ///     base_url:     Optional server base URL. Falls back to ``ARC_BASE_URL``
278    ///                   environment variable and then to
279    ///                   ``"https://three.arcprize.org"``.
280    ///     cookie_store: Enable cookie persistence across requests
281    ///                   (default ``False``).
282    ///     proxy:        Optional proxy URL, e.g. ``"socks5://127.0.0.1:9050"``.
283    ///
284    /// Raises:
285    ///     RuntimeError: If the proxy URL is invalid or the HTTP client cannot
286    ///                   be constructed.
287    #[new]
288    #[pyo3(signature = (api_key=None, base_url=None, cookie_store=false, proxy=None))]
289    pub fn new(
290        api_key: Option<String>,
291        base_url: Option<String>,
292        cookie_store: bool,
293        proxy: Option<String>,
294    ) -> PyResult<Self> {
295        let mut builder = Client::builder();
296        if let Some(key) = api_key {
297            builder = builder.api_key(key);
298        }
299        if let Some(url) = base_url {
300            builder = builder.base_url(url);
301        }
302        if cookie_store {
303            builder = builder.cookie_store(true);
304        }
305        if let Some(proxy_url) = proxy {
306            builder = builder.proxy(proxy_url);
307        }
308        let inner = builder
309            .build()
310            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
311        Ok(Self { inner })
312    }
313
314    /// Retrieve an anonymous API key from the server.
315    ///
316    /// Returns:
317    ///     The anonymous API key string.
318    ///
319    /// Raises:
320    ///     RuntimeError: On network or HTTP failure.
321    pub fn get_anonymous_key(&self) -> PyResult<String> {
322        block_on(self.inner.get_anonymous_key())
323    }
324
325    /// Return the list of all available game environments.
326    ///
327    /// Returns:
328    ///     A list of :class:`EnvironmentInfo` objects.
329    ///
330    /// Raises:
331    ///     RuntimeError: On network or HTTP failure.
332    pub fn list_environments(&self) -> PyResult<Vec<PyEnvironmentInfo>> {
333        let envs = block_on(self.inner.list_environments())?;
334        Ok(envs
335            .into_iter()
336            .map(|e| PyEnvironmentInfo {
337                game_id: e.game_id,
338                title: e.title,
339                default_fps: e.default_fps,
340                tags: e.tags,
341            })
342            .collect())
343    }
344
345    /// Return metadata for a single game environment.
346    ///
347    /// Args:
348    ///     game_id: The game identifier (e.g. ``"ls20"``).
349    ///
350    /// Returns:
351    ///     An :class:`EnvironmentInfo` object.
352    ///
353    /// Raises:
354    ///     RuntimeError: If the game is not found or on network failure.
355    pub fn get_environment(&self, game_id: String) -> PyResult<PyEnvironmentInfo> {
356        let info = block_on(self.inner.get_environment(&game_id))?;
357        Ok(PyEnvironmentInfo {
358            game_id: info.game_id,
359            title: info.title,
360            default_fps: info.default_fps,
361            tags: info.tags,
362        })
363    }
364
365    /// Create a new scorecard and return its ID.
366    ///
367    /// Args:
368    ///     source_url:       Optional URL linking to the agent being evaluated.
369    ///     tags:             Optional list of classification tag strings.
370    ///     competition_mode: When ``True``, enables one-way competition semantics.
371    ///
372    /// Returns:
373    ///     The ``card_id`` string of the newly created scorecard.
374    ///
375    /// Raises:
376    ///     RuntimeError: On network or HTTP failure.
377    #[pyo3(signature = (source_url=None, tags=None, competition_mode=None))]
378    pub fn open_scorecard(
379        &self,
380        source_url: Option<String>,
381        tags: Option<Vec<String>>,
382        competition_mode: Option<bool>,
383    ) -> PyResult<String> {
384        let mut params = ScorecardParams::new();
385        if let Some(url) = source_url {
386            params = params.source_url(url);
387        }
388        if let Some(t) = tags {
389            params = params.tags(t);
390        }
391        if let Some(cm) = competition_mode {
392            params = params.competition_mode(cm);
393        }
394        block_on(self.inner.open_scorecard(Some(params)))
395    }
396
397    /// Retrieve an existing scorecard by its ID.
398    ///
399    /// Args:
400    ///     card_id: The scorecard identifier.
401    ///
402    /// Returns:
403    ///     An :class:`EnvironmentScorecard` object.
404    ///
405    /// Raises:
406    ///     RuntimeError: If the scorecard is not found or on network failure.
407    pub fn get_scorecard(&self, card_id: String) -> PyResult<PyEnvironmentScorecard> {
408        let card = block_on(self.inner.get_scorecard(&card_id))?;
409        Ok(PyEnvironmentScorecard {
410            card_id: card.card_id,
411            score: card.score,
412            competition_mode: card.competition_mode,
413            total_environments_completed: card.total_environments_completed,
414            total_environments: card.total_environments,
415            total_levels_completed: card.total_levels_completed,
416            total_levels: card.total_levels,
417            total_actions: card.total_actions,
418        })
419    }
420
421    /// Close and finalise a scorecard.
422    ///
423    /// Args:
424    ///     card_id: The scorecard identifier.
425    ///
426    /// Returns:
427    ///     An :class:`EnvironmentScorecard` with the final scores.
428    ///
429    /// Raises:
430    ///     RuntimeError: If the scorecard is not found or on network failure.
431    pub fn close_scorecard(&self, card_id: String) -> PyResult<PyEnvironmentScorecard> {
432        let card = block_on(self.inner.close_scorecard(&card_id))?;
433        Ok(PyEnvironmentScorecard {
434            card_id: card.card_id,
435            score: card.score,
436            competition_mode: card.competition_mode,
437            total_environments_completed: card.total_environments_completed,
438            total_environments: card.total_environments,
439            total_levels_completed: card.total_levels_completed,
440            total_levels: card.total_levels,
441            total_actions: card.total_actions,
442        })
443    }
444
445    /// Reset (or start) a game environment.
446    ///
447    /// Args:
448    ///     game_id:      The game identifier.
449    ///     scorecard_id: The scorecard to record this run under.
450    ///     guid:         Optional existing run GUID to re-use.
451    ///     seed:         Random seed for reproducible level ordering (default ``0``).
452    ///
453    /// Returns:
454    ///     A :class:`FrameData` object with the initial game state.
455    ///
456    /// Raises:
457    ///     RuntimeError: On network or HTTP failure.
458    #[pyo3(signature = (game_id, scorecard_id, guid=None, seed=0))]
459    pub fn reset(
460        &self,
461        game_id: String,
462        scorecard_id: String,
463        guid: Option<String>,
464        seed: u32,
465    ) -> PyResult<PyFrameData> {
466        let mut params = MakeParams::new(&game_id, &scorecard_id).seed(seed);
467        if let Some(g) = guid {
468            params = params.guid(g);
469        }
470        let frame = block_on(self.inner.reset(params))?;
471        Ok(PyFrameData {
472            game_id: frame.game_id,
473            guid: frame.guid,
474            state: frame.state.as_str().to_string(),
475            levels_completed: frame.levels_completed,
476            win_levels: frame.win_levels,
477            available_actions: frame.available_actions,
478            full_reset: frame.full_reset,
479        })
480    }
481
482    /// Send one game action and receive the resulting frame.
483    ///
484    /// Args:
485    ///     game_id:      The game identifier.
486    ///     scorecard_id: The scorecard identifier.
487    ///     guid:         The run GUID from a prior :meth:`reset` call.
488    ///     action_id:    Numeric action ID (0 = RESET).
489    ///     data:         Optional dict of action data (e.g. ``{"x": 3, "y": 4}``).
490    ///     reasoning:    Optional dict of freeform reasoning.
491    ///
492    /// Returns:
493    ///     A :class:`FrameData` object with the updated game state.
494    ///
495    /// Raises:
496    ///     RuntimeError: On network or HTTP failure.
497    #[pyo3(signature = (game_id, scorecard_id, guid, action_id, data=None, reasoning=None))]
498    pub fn step(
499        &self,
500        game_id: String,
501        scorecard_id: String,
502        guid: String,
503        action_id: u32,
504        data: Option<&pyo3::Bound<'_, pyo3::types::PyDict>>,
505        reasoning: Option<&pyo3::Bound<'_, pyo3::types::PyDict>>,
506    ) -> PyResult<PyFrameData> {
507        let data_json = data.map(pydict_to_json).transpose()?;
508        let reasoning_json = reasoning.map(pydict_to_json).transpose()?;
509
510        let mut params = StepParams::new(&game_id, &scorecard_id, &guid, action_id);
511        if let Some(d) = data_json {
512            params = params.data(d);
513        }
514        if let Some(r) = reasoning_json {
515            params = params.reasoning(r);
516        }
517        let frame = block_on(self.inner.step(params))?;
518        Ok(PyFrameData {
519            game_id: frame.game_id,
520            guid: frame.guid,
521            state: frame.state.as_str().to_string(),
522            levels_completed: frame.levels_completed,
523            win_levels: frame.win_levels,
524            available_actions: frame.available_actions,
525            full_reset: frame.full_reset,
526        })
527    }
528
529    pub fn __repr__(&self) -> String {
530        format!("ArcAgiClient(base_url={:?})", self.inner.base_url())
531    }
532}
533
534/// Register all Python-exposed types and functions into the ``arc_agi_rs`` module.
535pub fn register_python_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
536    m.add_class::<PyEnvironmentInfo>()?;
537    m.add_class::<PyFrameData>()?;
538    m.add_class::<PyEnvironmentScore>()?;
539    m.add_class::<PyEnvironmentScorecard>()?;
540    m.add_class::<PyArcAgiClient>()?;
541    Ok(())
542}
543
544// Copyright 2026 Mahmoud Harmouch.
545//
546// Licensed under the MIT license
547// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
548// option. This file may not be copied, modified, or distributed
549// except according to those terms.