1use alloc::{boxed::Box, string::String};
2use core::fmt;
3
4use vm_core::FieldElement;
5
6use crate::{Felt, SourceSpan, Span, Spanned, ast::Ident, parser::ParsingError};
7
/// A named constant declaration: a name bound to a [`ConstantExpr`], optionally preceded by
/// documentation.
pub struct Constant {
    /// The source span recorded for this declaration.
    pub span: SourceSpan,
    /// Documentation attached to the constant, if any (rendered as `#!` comment lines).
    pub docs: Option<Span<String>>,
    /// The name of the constant.
    pub name: Ident,
    /// The expression the constant is bound to.
    pub value: ConstantExpr,
}
22
23impl Constant {
24 pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
26 Self { span, docs: None, name, value }
27 }
28
29 pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
31 self.docs = docs;
32 self
33 }
34}
35
36impl fmt::Debug for Constant {
37 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38 f.debug_struct("Constant")
39 .field("docs", &self.docs)
40 .field("name", &self.name)
41 .field("value", &self.value)
42 .finish()
43 }
44}
45
46impl crate::prettier::PrettyPrint for Constant {
47 fn render(&self) -> crate::prettier::Document {
48 use crate::prettier::*;
49
50 let mut doc = Document::Empty;
51 if let Some(docs) = self.docs.as_ref() {
52 let fragment =
53 docs.lines().map(text).reduce(|acc, line| acc + nl() + const_text("#! ") + line);
54
55 if let Some(fragment) = fragment {
56 doc += fragment;
57 }
58 }
59
60 doc += nl();
61 doc += flatten(const_text("const") + const_text(".") + display(&self.name));
62 doc += const_text("=");
63
64 doc + self.value.render()
65 }
66}
67
68impl Eq for Constant {}
69
70impl PartialEq for Constant {
71 fn eq(&self, other: &Self) -> bool {
72 self.name == other.name && self.value == other.value
73 }
74}
75
76impl Spanned for Constant {
77 fn span(&self) -> SourceSpan {
78 self.span
79 }
80}
81
/// An expression that a [`Constant`] can be bound to.
pub enum ConstantExpr {
    /// A literal field element.
    Literal(Span<Felt>),
    /// A reference to a named value (presumably another constant — confirm against the parser);
    /// `try_fold` leaves it unevaluated.
    Var(Ident),
    /// An arithmetic operation applied to two sub-expressions.
    BinaryOp {
        /// The span reported for this operation by `Spanned::span`.
        span: SourceSpan,
        /// The operator to apply.
        op: ConstantOp,
        /// The left-hand operand.
        lhs: Box<ConstantExpr>,
        /// The right-hand operand.
        rhs: Box<ConstantExpr>,
    },
}
99
100impl ConstantExpr {
101 #[track_caller]
106 pub fn expect_literal(&self) -> Felt {
107 match self {
108 Self::Literal(spanned) => spanned.into_inner(),
109 other => panic!("expected constant expression to be a literal, got {other:#?}"),
110 }
111 }
112
113 pub fn try_fold(self) -> Result<Self, ParsingError> {
120 match self {
121 Self::Literal(_) | Self::Var(_) => Ok(self),
122 Self::BinaryOp { span, op, lhs, rhs } => {
123 if rhs.is_literal() {
124 let rhs = Self::into_inner(rhs).try_fold()?;
125 match rhs {
126 Self::Literal(rhs) => {
127 let lhs = Self::into_inner(lhs).try_fold()?;
128 match lhs {
129 Self::Literal(lhs) => {
130 let lhs = lhs.into_inner();
131 let rhs = rhs.into_inner();
132 let is_division =
133 matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
134 let is_division_by_zero = is_division && rhs == Felt::ZERO;
135 if is_division_by_zero {
136 return Err(ParsingError::DivisionByZero { span });
137 }
138 match op {
139 ConstantOp::Add => {
140 Ok(Self::Literal(Span::new(span, lhs + rhs)))
141 },
142 ConstantOp::Sub => {
143 Ok(Self::Literal(Span::new(span, lhs - rhs)))
144 },
145 ConstantOp::Mul => {
146 Ok(Self::Literal(Span::new(span, lhs * rhs)))
147 },
148 ConstantOp::Div => {
149 Ok(Self::Literal(Span::new(span, lhs / rhs)))
150 },
151 ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
152 span,
153 Felt::new(lhs.as_int() / rhs.as_int()),
154 ))),
155 }
156 },
157 lhs => Ok(Self::BinaryOp {
158 span,
159 op,
160 lhs: Box::new(lhs),
161 rhs: Box::new(Self::Literal(rhs)),
162 }),
163 }
164 },
165 rhs => {
166 let lhs = Self::into_inner(lhs).try_fold()?;
167 Ok(Self::BinaryOp {
168 span,
169 op,
170 lhs: Box::new(lhs),
171 rhs: Box::new(rhs),
172 })
173 },
174 }
175 } else {
176 let lhs = Self::into_inner(lhs).try_fold()?;
177 Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
178 }
179 },
180 }
181 }
182
183 fn is_literal(&self) -> bool {
184 match self {
185 Self::Literal(_) => true,
186 Self::Var(_) => false,
187 Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
188 }
189 }
190
191 #[inline(always)]
192 #[allow(clippy::boxed_local)]
193 fn into_inner(self: Box<Self>) -> Self {
194 *self
195 }
196}
197
198impl Eq for ConstantExpr {}
199
200impl PartialEq for ConstantExpr {
201 fn eq(&self, other: &Self) -> bool {
202 match (self, other) {
203 (Self::Literal(l), Self::Literal(y)) => l == y,
204 (Self::Var(l), Self::Var(y)) => l == y,
205 (
206 Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
207 Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
208 ) => lop == rop && llhs == rlhs && lrhs == rrhs,
209 _ => false,
210 }
211 }
212}
213
214impl fmt::Debug for ConstantExpr {
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 match self {
217 Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
218 Self::Var(name) => fmt::Debug::fmt(&**name, f),
219 Self::BinaryOp { op, lhs, rhs, .. } => {
220 f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
221 },
222 }
223 }
224}
225
226impl crate::prettier::PrettyPrint for ConstantExpr {
227 fn render(&self) -> crate::prettier::Document {
228 use crate::prettier::*;
229
230 match self {
231 Self::Literal(literal) => display(literal),
232 Self::Var(ident) => display(ident),
233 Self::BinaryOp { op, lhs, rhs, .. } => {
234 let single_line = lhs.render() + display(op) + rhs.render();
235 let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
236 single_line | multi_line
237 },
238 }
239 }
240}
241
242impl Spanned for ConstantExpr {
243 fn span(&self) -> SourceSpan {
244 match self {
245 Self::Literal(spanned) => spanned.span(),
246 Self::Var(spanned) => spanned.span(),
247 Self::BinaryOp { span, .. } => *span,
248 }
249 }
250}
251
/// A binary arithmetic operator usable in a constant expression.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ConstantOp {
    /// Field addition, written `+`.
    Add,
    /// Field subtraction, written `-`.
    Sub,
    /// Field multiplication, written `*`.
    Mul,
    /// Field division, written `/`.
    Div,
    /// Division of the operands' integer representations, written `//`.
    IntDiv,
}

impl ConstantOp {
    /// Returns the variant name of this operator (e.g. `"Add"`), used as the tuple name when a
    /// binary constant expression is `Debug`-formatted.
    const fn name(&self) -> &'static str {
        match *self {
            Self::Add => "Add",
            Self::Sub => "Sub",
            Self::Mul => "Mul",
            Self::Div => "Div",
            Self::IntDiv => "IntDiv",
        }
    }

    /// Returns the source-level symbol written for this operator.
    const fn symbol(self) -> &'static str {
        match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::IntDiv => "//",
        }
    }
}

impl fmt::Display for ConstantOp {
    /// Writes the operator's symbol, e.g. `+` or `//`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.symbol())
    }
}