1use crate::{Expression, IntegerType, Literal, LiteralVariant, ProgramId, Type};
18use snarkvm::console::program::ArrayType as ConsoleArrayType;
19
20use leo_span::Span;
21use serde::{Deserialize, Serialize};
22use snarkvm::prelude::Network;
23use std::fmt;
24
/// An array type in the AST: an element type together with a length expression.
///
/// The length is stored as an arbitrary `Expression`; conversion to snarkVM
/// (`to_snarkvm`) only succeeds once it is an integer literal.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArrayType {
    /// The type of each element (may itself be an array for multi-dimensional arrays).
    pub element_type: Box<Type>,
    /// The number of elements, as an expression (e.g. an integer literal).
    pub length: Box<Expression>,
}
31
32impl ArrayType {
33 pub fn new(element: Type, length: Expression) -> Self {
35 Self { element_type: Box::new(element), length: Box::new(length) }
36 }
37
38 pub fn bit_array(length: u32) -> Self {
40 Self {
41 element_type: Box::new(Type::Boolean),
42 length: Box::new(Expression::Literal(Literal {
43 variant: LiteralVariant::Integer(IntegerType::U32, length.to_string()),
44 id: Default::default(),
45 span: Span::default(),
46 })),
47 }
48 }
49
50 pub fn element_type(&self) -> &Type {
52 &self.element_type
53 }
54
55 pub fn base_element_type(&self) -> &Type {
57 match self.element_type.as_ref() {
58 Type::Array(array_type) => array_type.base_element_type(),
59 type_ => type_,
60 }
61 }
62
63 pub fn from_snarkvm<N: Network>(array_type: &ConsoleArrayType<N>, program_id: ProgramId) -> Self {
64 Self {
65 element_type: Box::new(Type::from_snarkvm(array_type.next_element_type(), program_id)),
66 length: Box::new(Expression::Literal(Literal {
67 variant: LiteralVariant::Integer(IntegerType::U32, array_type.length().to_string().replace("u32", "")),
68 id: Default::default(),
69 span: Span::default(),
70 })),
71 }
72 }
73
74 pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<ConsoleArrayType<N>> {
75 let length = if let Expression::Literal(literal) = &*self.length {
76 match &literal.variant {
77 LiteralVariant::Integer(_, s) => {
78 s.parse::<u32>().map_err(|e| anyhow::anyhow!("Array length is not a valid u32: {e}"))?
79 }
80 LiteralVariant::Unsuffixed(s) => {
81 s.parse::<u32>().map_err(|e| anyhow::anyhow!("Array length is not a valid u32: {e}"))?
82 }
83 _ => anyhow::bail!("Array length must be an integer literal"),
84 }
85 } else {
86 anyhow::bail!("Array length must be an integer literal")
87 };
88
89 ConsoleArrayType::new(self.element_type.to_snarkvm()?, vec![snarkvm::console::types::U32::new(length)])
90 }
91}
92
93impl fmt::Display for ArrayType {
94 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
95 if let Expression::Literal(literal) = &*self.length
97 && let LiteralVariant::Integer(_, s) = &literal.variant
98 {
99 return write!(f, "[{}; {s}]", self.element_type);
100 }
101
102 write!(f, "[{}; {}]", self.element_type, self.length)
103 }
104}
105
106impl From<ArrayType> for Type {
107 fn from(value: ArrayType) -> Self {
108 Type::Array(value)
109 }
110}