use std::fmt::Display;
use std::future::Future;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyList};
use serde_json::{Map, Value};
use tokio::runtime::Builder as RuntimeBuilder;
use crate::client::Client;
use crate::params::{MakeParams, ScorecardParams, StepParams};
fn pyany_to_json(obj: &Bound<'_, PyAny>) -> PyResult<Value> {
if obj.is_none() {
return Ok(Value::Null);
}
if let Ok(b) = obj.extract::<bool>() {
return Ok(Value::Bool(b));
}
if let Ok(i) = obj.extract::<i64>() {
return Ok(Value::from(i));
}
if let Ok(f) = obj.extract::<f64>() {
return Ok(Value::from(f));
}
if let Ok(s) = obj.extract::<String>() {
return Ok(Value::String(s));
}
if let Ok(list) = obj.extract::<Bound<'_, PyList>>() {
let arr: PyResult<Vec<_>> = list.iter().map(|v| pyany_to_json(&v)).collect();
return Ok(Value::Array(arr?));
}
if let Ok(dict) = obj.extract::<Bound<'_, PyDict>>() {
let mut map = Map::new();
for (k, v) in dict.iter() {
let key: String = k.extract()?;
map.insert(key, pyany_to_json(&v)?);
}
return Ok(Value::Object(map));
}
Err(PyRuntimeError::new_err(format!(
"cannot convert Python value of type {} to JSON",
obj.get_type().name()?
)))
}
fn pydict_to_json(d: &Bound<'_, PyDict>) -> PyResult<Value> {
pyany_to_json(d.as_any())
}
fn block_on<F, T, E>(future: F) -> PyResult<T>
where
F: Future<Output = Result<T, E>>,
E: Display,
{
RuntimeBuilder::new_current_thread()
.enable_all()
.build()
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?
.block_on(future)
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
/// Environment (game) metadata handed to Python as `EnvironmentInfo`.
///
/// Produced by `ArcAgiClient.list_environments()` and `ArcAgiClient.get_environment()`.
/// Declared `frozen`; every field is exposed through a read-only getter.
#[pyclass(name = "EnvironmentInfo", frozen, skip_from_py_object)]
#[derive(Debug, Clone)]
pub struct PyEnvironmentInfo {
    /// Identifier of the environment (the value passed back as `game_id` to other calls).
    #[pyo3(get)]
    pub game_id: String,
    /// Display title; `None` when the underlying record has none.
    #[pyo3(get)]
    pub title: Option<String>,
    /// Presumably the default frames-per-second for the environment; `None` when absent.
    #[pyo3(get)]
    pub default_fps: Option<u32>,
    /// Tags attached to the environment; `None` when absent.
    #[pyo3(get)]
    pub tags: Option<Vec<String>>,
}
#[pymethods]
impl PyEnvironmentInfo {
    /// Python `repr()`: shows `game_id` and `title` in Rust `Debug` form
    /// (strings quoted, the optional title as `Some(..)`/`None`).
    pub fn __repr__(&self) -> String {
        let Self { game_id, title, .. } = self;
        format!("EnvironmentInfo(game_id={game_id:?}, title={title:?})")
    }
}
/// One frame of environment state handed to Python as `FrameData`.
///
/// Produced by `ArcAgiClient.reset()` and `ArcAgiClient.step()`.
/// Declared `frozen`; every field is exposed through a read-only getter.
#[pyclass(name = "FrameData", frozen, skip_from_py_object)]
#[derive(Debug, Clone)]
pub struct PyFrameData {
    /// Identifier of the environment this frame belongs to.
    #[pyo3(get)]
    pub game_id: String,
    /// Session identifier, if any; can be passed as `guid` to `reset()`/`step()`.
    #[pyo3(get)]
    pub guid: Option<String>,
    /// String form of the frame's state (produced via the state's `as_str()`).
    #[pyo3(get)]
    pub state: String,
    /// Number of levels completed so far.
    #[pyo3(get)]
    pub levels_completed: u32,
    /// Presumably the number of levels needed to win — TODO confirm against the API.
    #[pyo3(get)]
    pub win_levels: u32,
    /// Action ids; presumably the values currently accepted as `step(action_id=...)`.
    #[pyo3(get)]
    pub available_actions: Vec<u32>,
    /// Flag copied from the underlying frame; NOTE(review): meaning not visible here.
    #[pyo3(get)]
    pub full_reset: bool,
}
#[pymethods]
impl PyFrameData {
    /// Python `repr()`: shows `game_id`, `guid` and `state` in Rust `Debug` form.
    pub fn __repr__(&self) -> String {
        let Self {
            game_id,
            guid,
            state,
            ..
        } = self;
        format!("FrameData(game_id={game_id:?}, guid={guid:?}, state={state:?})")
    }
}
/// Per-environment score entry exposed to Python as `EnvironmentScore`.
///
/// Declared `frozen`; every field is exposed through a read-only getter.
#[pyclass(name = "EnvironmentScore", frozen, skip_from_py_object)]
#[derive(Debug, Clone)]
pub struct PyEnvironmentScore {
    /// Identifier of the scored entry, if any.
    #[pyo3(get)]
    pub id: Option<String>,
    /// Numeric score.
    #[pyo3(get)]
    pub score: f64,
    /// Number of levels completed.
    #[pyo3(get)]
    pub levels_completed: u32,
    /// Presumably the number of actions taken — TODO confirm.
    #[pyo3(get)]
    pub actions: u32,
    /// Completion flag; `None` when not reported.
    #[pyo3(get)]
    pub completed: Option<bool>,
}
#[pymethods]
impl PyEnvironmentScore {
    /// Python `repr()`: shows `id` (Rust `Debug` form) and `score` rounded to two decimals.
    pub fn __repr__(&self) -> String {
        let Self { id, score, .. } = self;
        format!("EnvironmentScore(id={id:?}, score={score:.2})")
    }
}
/// Scorecard summary handed to Python as `EnvironmentScorecard`.
///
/// Produced by `ArcAgiClient.get_scorecard()` and `ArcAgiClient.close_scorecard()`.
/// Declared `frozen`; every field is exposed through a read-only getter. The
/// `Option` fields are `None` when the underlying scorecard does not carry them.
#[pyclass(name = "EnvironmentScorecard", frozen, skip_from_py_object)]
#[derive(Debug, Clone)]
pub struct PyEnvironmentScorecard {
    /// Scorecard identifier (as returned by `open_scorecard()`).
    #[pyo3(get)]
    pub card_id: String,
    /// Aggregate score.
    #[pyo3(get)]
    pub score: f64,
    /// Whether the scorecard was opened in competition mode.
    #[pyo3(get)]
    pub competition_mode: Option<bool>,
    /// Count of environments completed.
    #[pyo3(get)]
    pub total_environments_completed: Option<u32>,
    /// Count of environments on the scorecard.
    #[pyo3(get)]
    pub total_environments: Option<u32>,
    /// Count of levels completed across environments.
    #[pyo3(get)]
    pub total_levels_completed: Option<u32>,
    /// Count of levels across environments.
    #[pyo3(get)]
    pub total_levels: Option<u32>,
    /// Count of actions taken across environments.
    #[pyo3(get)]
    pub total_actions: Option<u32>,
}
#[pymethods]
impl PyEnvironmentScorecard {
    /// Python `repr()`: shows `card_id` (Rust `Debug` form) and `score` rounded to two decimals.
    pub fn __repr__(&self) -> String {
        let Self { card_id, score, .. } = self;
        format!("EnvironmentScorecard(card_id={card_id:?}, score={score:.2})")
    }
}
/// Python-facing `ArcAgiClient`: a synchronous wrapper around the async
/// `crate::client::Client`.
///
/// Each Python method runs the corresponding async client call to completion
/// via `block_on`, converting errors into `RuntimeError`.
#[pyclass(name = "ArcAgiClient")]
pub struct PyArcAgiClient {
    /// The wrapped async client, configured once in `new`.
    inner: Client,
}
#[pymethods]
impl PyArcAgiClient {
    /// Create a client.
    ///
    /// Each option is forwarded to `Client::builder()` only when supplied
    /// (`cookie_store` only when true); otherwise the builder's own default
    /// applies. Raises `RuntimeError` if building the client fails.
    #[new]
    #[pyo3(signature = (api_key=None, base_url=None, cookie_store=false, proxy=None))]
    pub fn new(
        api_key: Option<String>,
        base_url: Option<String>,
        cookie_store: bool,
        proxy: Option<String>,
    ) -> PyResult<Self> {
        let builder = Client::builder();
        let builder = match api_key {
            Some(key) => builder.api_key(key),
            None => builder,
        };
        let builder = match base_url {
            Some(url) => builder.base_url(url),
            None => builder,
        };
        let builder = if cookie_store {
            builder.cookie_store(true)
        } else {
            builder
        };
        let builder = match proxy {
            Some(proxy_url) => builder.proxy(proxy_url),
            None => builder,
        };
        builder
            .build()
            .map(|inner| Self { inner })
            .map_err(|e| PyRuntimeError::new_err(e.to_string()))
    }

    /// Run `Client::get_anonymous_key` and return the key string.
    pub fn get_anonymous_key(&self) -> PyResult<String> {
        let pending = self.inner.get_anonymous_key();
        block_on(pending)
    }

    /// List every environment as `EnvironmentInfo` objects, in the order the client returns them.
    pub fn list_environments(&self) -> PyResult<Vec<PyEnvironmentInfo>> {
        let mut infos = Vec::new();
        for env in block_on(self.inner.list_environments())? {
            infos.push(PyEnvironmentInfo {
                game_id: env.game_id,
                title: env.title,
                default_fps: env.default_fps,
                tags: env.tags,
            });
        }
        Ok(infos)
    }

    /// Fetch a single environment's metadata by `game_id`.
    pub fn get_environment(&self, game_id: String) -> PyResult<PyEnvironmentInfo> {
        block_on(self.inner.get_environment(&game_id)).map(|env| PyEnvironmentInfo {
            game_id: env.game_id,
            title: env.title,
            default_fps: env.default_fps,
            tags: env.tags,
        })
    }

    /// Open a new scorecard and return its id.
    ///
    /// Only the options that are supplied are set on the `ScorecardParams`;
    /// the params object itself is always sent.
    #[pyo3(signature = (source_url=None, tags=None, competition_mode=None))]
    pub fn open_scorecard(
        &self,
        source_url: Option<String>,
        tags: Option<Vec<String>>,
        competition_mode: Option<bool>,
    ) -> PyResult<String> {
        let opts = ScorecardParams::new();
        let opts = match source_url {
            Some(url) => opts.source_url(url),
            None => opts,
        };
        let opts = match tags {
            Some(t) => opts.tags(t),
            None => opts,
        };
        let opts = match competition_mode {
            Some(cm) => opts.competition_mode(cm),
            None => opts,
        };
        block_on(self.inner.open_scorecard(Some(opts)))
    }

    /// Fetch the current state of scorecard `card_id`.
    pub fn get_scorecard(&self, card_id: String) -> PyResult<PyEnvironmentScorecard> {
        block_on(self.inner.get_scorecard(&card_id)).map(|sc| PyEnvironmentScorecard {
            card_id: sc.card_id,
            score: sc.score,
            competition_mode: sc.competition_mode,
            total_environments_completed: sc.total_environments_completed,
            total_environments: sc.total_environments,
            total_levels_completed: sc.total_levels_completed,
            total_levels: sc.total_levels,
            total_actions: sc.total_actions,
        })
    }

    /// Close scorecard `card_id` and return its final summary.
    pub fn close_scorecard(&self, card_id: String) -> PyResult<PyEnvironmentScorecard> {
        block_on(self.inner.close_scorecard(&card_id)).map(|sc| PyEnvironmentScorecard {
            card_id: sc.card_id,
            score: sc.score,
            competition_mode: sc.competition_mode,
            total_environments_completed: sc.total_environments_completed,
            total_environments: sc.total_environments,
            total_levels_completed: sc.total_levels_completed,
            total_levels: sc.total_levels,
            total_actions: sc.total_actions,
        })
    }

    /// Reset (or start) `game_id` under `scorecard_id` and return the first frame.
    ///
    /// `seed` is always sent (default 0); `guid` is sent only when given.
    #[pyo3(signature = (game_id, scorecard_id, guid=None, seed=0))]
    pub fn reset(
        &self,
        game_id: String,
        scorecard_id: String,
        guid: Option<String>,
        seed: u32,
    ) -> PyResult<PyFrameData> {
        let seeded = MakeParams::new(&game_id, &scorecard_id).seed(seed);
        let request = match guid {
            Some(g) => seeded.guid(g),
            None => seeded,
        };
        block_on(self.inner.reset(request)).map(|fd| PyFrameData {
            game_id: fd.game_id,
            guid: fd.guid,
            state: fd.state.as_str().to_string(),
            levels_completed: fd.levels_completed,
            win_levels: fd.win_levels,
            available_actions: fd.available_actions,
            full_reset: fd.full_reset,
        })
    }

    /// Send action `action_id` for session `guid` and return the resulting frame.
    ///
    /// `data` and `reasoning`, when given, are converted from Python dicts to
    /// JSON first (data before reasoning); a conversion failure raises before
    /// any request is made.
    #[pyo3(signature = (game_id, scorecard_id, guid, action_id, data=None, reasoning=None))]
    pub fn step(
        &self,
        game_id: String,
        scorecard_id: String,
        guid: String,
        action_id: u32,
        data: Option<&pyo3::Bound<'_, pyo3::types::PyDict>>,
        reasoning: Option<&pyo3::Bound<'_, pyo3::types::PyDict>>,
    ) -> PyResult<PyFrameData> {
        let payload = match data {
            Some(d) => Some(pydict_to_json(d)?),
            None => None,
        };
        let rationale = match reasoning {
            Some(r) => Some(pydict_to_json(r)?),
            None => None,
        };
        let base = StepParams::new(&game_id, &scorecard_id, &guid, action_id);
        let base = match payload {
            Some(d) => base.data(d),
            None => base,
        };
        let request = match rationale {
            Some(r) => base.reasoning(r),
            None => base,
        };
        block_on(self.inner.step(request)).map(|fd| PyFrameData {
            game_id: fd.game_id,
            guid: fd.guid,
            state: fd.state.as_str().to_string(),
            levels_completed: fd.levels_completed,
            win_levels: fd.win_levels,
            available_actions: fd.available_actions,
            full_reset: fd.full_reset,
        })
    }

    /// Python `repr()`: shows the configured base URL in Rust `Debug` form.
    pub fn __repr__(&self) -> String {
        let url = self.inner.base_url();
        format!("ArcAgiClient(base_url={url:?})")
    }
}
/// Register every Python-visible class from this file on module `m`.
///
/// Classes are added in a fixed order; the first failing registration
/// short-circuits and its error is returned. `_py` is accepted but unused.
pub fn register_python_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyEnvironmentInfo>()
        .and_then(|()| m.add_class::<PyFrameData>())
        .and_then(|()| m.add_class::<PyEnvironmentScore>())
        .and_then(|()| m.add_class::<PyEnvironmentScorecard>())
        .and_then(|()| m.add_class::<PyArcAgiClient>())
}