use acir_field::AcirField;
use serde::{Deserialize, Serialize};
/// A single input to, or output from, a foreign call: either one value of
/// type `F` or a list of them.
///
/// Serialized with `#[serde(untagged)]`, so no variant tag appears in the
/// encoded form: `Single` is encoded as its bare value and `Array` as a
/// sequence.
/// NOTE(review): serde tries untagged variants in declaration order when
/// deserializing, so `Single` is attempted before `Array` — keep this order.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum ForeignCallParam<F> {
    /// Exactly one value.
    Single(F),
    /// Zero or more values.
    Array(Vec<F>),
}
impl<F> From<F> for ForeignCallParam<F> {
fn from(value: F) -> Self {
ForeignCallParam::Single(value)
}
}
impl<F> From<Vec<F>> for ForeignCallParam<F> {
fn from(values: Vec<F>) -> Self {
ForeignCallParam::Array(values)
}
}
impl<F: AcirField> ForeignCallParam<F> {
    /// Returns every value held by this parameter as a flat vector.
    ///
    /// A [`ForeignCallParam::Single`] yields a one-element vector; a
    /// [`ForeignCallParam::Array`] yields a copy of its contents (possibly empty).
    pub fn fields(&self) -> Vec<F> {
        match self {
            ForeignCallParam::Single(value) => vec![*value],
            ForeignCallParam::Array(values) => values.clone(),
        }
    }

    /// Returns the value held by a [`ForeignCallParam::Single`].
    ///
    /// # Panics
    ///
    /// Panics with `"Expected single value, found array"` if `self` is a
    /// [`ForeignCallParam::Array`].
    // `#[track_caller]` makes the panic report the caller's location rather than
    // this method's, as std's `Option::unwrap` does.
    #[track_caller]
    pub fn unwrap_field(&self) -> F {
        match self {
            ForeignCallParam::Single(value) => *value,
            ForeignCallParam::Array(_) => panic!("Expected single value, found array"),
        }
    }
}
/// The values returned by a foreign call, held as an ordered list of
/// [`ForeignCallParam`] outputs.
///
/// The derived `Default` produces a result with an empty `values` list.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[derive(Deserialize, Serialize)]
pub struct ForeignCallResult<F> {
    /// The outputs, in order.
    pub values: Vec<ForeignCallParam<F>>,
}
impl<F> From<F> for ForeignCallResult<F> {
fn from(value: F) -> Self {
ForeignCallResult { values: vec![value.into()] }
}
}
impl<F> From<Vec<F>> for ForeignCallResult<F> {
fn from(values: Vec<F>) -> Self {
ForeignCallResult { values: vec![values.into()] }
}
}
impl<F> From<Vec<ForeignCallParam<F>>> for ForeignCallResult<F> {
fn from(values: Vec<ForeignCallParam<F>>) -> Self {
ForeignCallResult { values }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use acir_field::FieldElement;

    /// Shorthand for building a `FieldElement` from a small integer.
    fn fe(n: u128) -> FieldElement {
        FieldElement::from(n)
    }

    #[test]
    fn test_foreign_call_param_from_single() {
        let x = fe(42);
        let p: ForeignCallParam<FieldElement> = ForeignCallParam::from(x);
        assert_eq!(p, ForeignCallParam::Single(x));
        assert_eq!(p.fields(), vec![x]);
        assert_eq!(p.unwrap_field(), x);
    }

    #[test]
    fn test_foreign_call_param_from_array() {
        let xs: Vec<FieldElement> = (1..=3).map(fe).collect();
        let p: ForeignCallParam<FieldElement> = ForeignCallParam::from(xs.clone());
        assert_eq!(p, ForeignCallParam::Array(xs.clone()));
        assert_eq!(p.fields(), xs);
    }

    #[test]
    fn test_foreign_call_param_array_roundtrip() {
        let expected = vec![fe(10), fe(20), fe(30)];
        let p: ForeignCallParam<FieldElement> = expected.clone().into();
        assert_eq!(p.fields(), expected);
    }

    #[test]
    fn test_foreign_call_param_single_to_array() {
        let x = fe(42);
        assert_eq!(ForeignCallParam::Single(x).fields(), vec![x]);
    }

    #[test]
    #[should_panic(expected = "Expected single value, found array")]
    fn test_foreign_call_param_unwrap_field_panics_on_array() {
        let p = ForeignCallParam::Array(vec![fe(1), fe(2)]);
        let _ = p.unwrap_field();
    }

    #[test]
    fn test_foreign_call_result_from_single_value() {
        let x = fe(42);
        let r: ForeignCallResult<FieldElement> = ForeignCallResult::from(x);
        assert_eq!(r.values, vec![ForeignCallParam::Single(x)]);
    }

    #[test]
    fn test_foreign_call_result_from_vec_creates_single_array_output() {
        let xs: Vec<FieldElement> = (1..=3).map(fe).collect();
        let r: ForeignCallResult<FieldElement> = ForeignCallResult::from(xs.clone());
        assert_eq!(r.values, vec![ForeignCallParam::Array(xs)]);
    }

    #[test]
    fn test_foreign_call_result_from_params_creates_multiple_outputs() {
        let params: Vec<ForeignCallParam<FieldElement>> =
            (1..=3).map(|n| ForeignCallParam::Single(fe(n))).collect();
        let r: ForeignCallResult<FieldElement> = ForeignCallResult::from(params.clone());
        assert_eq!(r.values.len(), 3);
        assert_eq!(r.values, params);
    }
}