use std::{ops::Deref, str::FromStr};
use cairo_felt::Felt252;
use cairo_lang_runner::Arg;
use serde::{de::Visitor, Deserialize};
use serde_json::Value;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum VecArgError {
#[error("failed to parse array")]
ArrayParseError,
#[error("failed to parse number: {0}")]
NumberParseError(#[from] std::num::ParseIntError),
#[error("failed to parse bigint: {0}")]
BigIntParseError(#[from] num_bigint::ParseBigIntError),
#[error("number out of range")]
NumberOutOfRange,
}
/// Ordered list of Cairo runner [`Arg`]s.
///
/// Thin newtype over `Vec<Arg>`: it dereferences to the inner vector and converts to and
/// from it without copying.
#[derive(Debug)]
pub struct VecArg(Vec<Arg>);

impl VecArg {
    /// Wraps an existing argument list.
    #[must_use]
    pub fn new(args: Vec<Arg>) -> Self {
        VecArg(args)
    }
}

impl Deref for VecArg {
    type Target = Vec<Arg>;

    /// Borrows the wrapped argument list.
    fn deref(&self) -> &Vec<Arg> {
        let VecArg(inner) = self;
        inner
    }
}

impl From<VecArg> for Vec<Arg> {
    /// Unwraps the newtype, yielding the inner argument list by value.
    fn from(args: VecArg) -> Self {
        let VecArg(inner) = args;
        inner
    }
}

impl From<Vec<Arg>> for VecArg {
    /// Equivalent to [`VecArg::new`].
    fn from(args: Vec<Arg>) -> Self {
        VecArg::new(args)
    }
}
impl VecArg {
fn visit_seq_helper(seq: &[Value]) -> Result<Self, VecArgError> {
let iterator = seq.iter();
let mut args = Vec::new();
for arg in iterator {
match arg {
Value::Number(n) => {
let n = Felt252::from(n.as_u64().ok_or(VecArgError::NumberOutOfRange)?);
args.push(Arg::Value(n));
}
Value::String(n) => {
let n = Felt252::from(num_bigint::BigUint::from_str(n)?);
args.push(Arg::Value(n));
}
Value::Array(a) => {
let mut inner_args = Vec::new();
for x in a {
match x {
Value::Number(n) => {
let n =
Felt252::from(n.as_u64().ok_or(VecArgError::NumberOutOfRange)?);
inner_args.push(Felt252::new(n));
}
Value::String(n) => {
let n = Felt252::from(num_bigint::BigUint::from_str(n)?);
inner_args.push(Felt252::new(n));
}
_ => return Err(VecArgError::ArrayParseError),
}
}
args.push(Arg::Array(inner_args));
}
_ => (),
}
}
Ok(Self::new(args))
}
}
impl<'de> Visitor<'de> for VecArg {
type Value = VecArg;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a list of arguments")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut args = Vec::new();
while let Some(arg) = seq.next_element()? {
match arg {
Value::Number(n) => args.push(Value::Number(n)),
Value::String(n) => args.push(Value::String(n)),
Value::Array(a) => args.push(Value::Array(a)),
_ => return Err(serde::de::Error::custom("Invalid type")),
}
}
Self::visit_seq_helper(&args).map_err(|e| serde::de::Error::custom(e.to_string()))
}
}
impl<'de> Deserialize<'de> for VecArg {
    /// Deserializes from a sequence, using an empty `VecArg` as the serde [`Visitor`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = Self(Vec::new());
        deserializer.deserialize_seq(visitor)
    }
}