1use alloc::{boxed::Box, string::String, sync::Arc};
2use core::fmt;
3
4use vm_core::FieldElement;
5
6use super::DocString;
7use crate::{Felt, SourceSpan, Span, Spanned, ast::Ident, parser::ParsingError};
8
9pub struct Constant {
14 pub span: SourceSpan,
16 pub docs: Option<DocString>,
18 pub name: Ident,
20 pub value: ConstantExpr,
22}
23
24impl Constant {
25 pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
27 Self { span, docs: None, name, value }
28 }
29
30 pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
32 self.docs = docs.map(DocString::new);
33 self
34 }
35}
36
37impl fmt::Debug for Constant {
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 f.debug_struct("Constant")
40 .field("docs", &self.docs)
41 .field("name", &self.name)
42 .field("value", &self.value)
43 .finish()
44 }
45}
46
47impl crate::prettier::PrettyPrint for Constant {
48 fn render(&self) -> crate::prettier::Document {
49 use crate::prettier::*;
50
51 let mut doc = self
52 .docs
53 .as_ref()
54 .map(|docstring| docstring.render())
55 .unwrap_or(Document::Empty);
56
57 doc += flatten(const_text("const") + const_text(".") + display(&self.name));
58 doc += const_text("=");
59
60 doc + self.value.render()
61 }
62}
63
64impl Eq for Constant {}
65
66impl PartialEq for Constant {
67 fn eq(&self, other: &Self) -> bool {
68 self.name == other.name && self.value == other.value
69 }
70}
71
72impl Spanned for Constant {
73 fn span(&self) -> SourceSpan {
74 self.span
75 }
76}
77
78#[derive(Clone)]
83pub enum ConstantExpr {
84 Literal(Span<Felt>),
86 Var(Ident),
88 BinaryOp {
90 span: SourceSpan,
91 op: ConstantOp,
92 lhs: Box<ConstantExpr>,
93 rhs: Box<ConstantExpr>,
94 },
95 String(Ident),
96}
97
98impl ConstantExpr {
99 #[track_caller]
104 pub fn expect_literal(&self) -> Felt {
105 match self {
106 Self::Literal(spanned) => spanned.into_inner(),
107 other => panic!("expected constant expression to be a literal, got {other:#?}"),
108 }
109 }
110
111 pub fn expect_string(&self) -> Arc<str> {
112 match self {
113 Self::String(spanned) => spanned.clone().into_inner(),
114 other => panic!("expected constant expression to be a string, got {other:#?}"),
115 }
116 }
117
118 pub fn try_fold(self) -> Result<Self, ParsingError> {
125 match self {
126 Self::String(_) | Self::Literal(_) | Self::Var(_) => Ok(self),
127 Self::BinaryOp { span, op, lhs, rhs } => {
128 if rhs.is_literal() {
129 let rhs = Self::into_inner(rhs).try_fold()?;
130 match rhs {
131 Self::String(ident) => {
132 Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
133 },
134 Self::Literal(rhs) => {
135 let lhs = Self::into_inner(lhs).try_fold()?;
136 match lhs {
137 Self::String(ident) => {
138 Err(ParsingError::StringInArithmeticExpression {
139 span: ident.span(),
140 })
141 },
142 Self::Literal(lhs) => {
143 let lhs = lhs.into_inner();
144 let rhs = rhs.into_inner();
145 let is_division =
146 matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
147 let is_division_by_zero = is_division && rhs == Felt::ZERO;
148 if is_division_by_zero {
149 return Err(ParsingError::DivisionByZero { span });
150 }
151 match op {
152 ConstantOp::Add => {
153 Ok(Self::Literal(Span::new(span, lhs + rhs)))
154 },
155 ConstantOp::Sub => {
156 Ok(Self::Literal(Span::new(span, lhs - rhs)))
157 },
158 ConstantOp::Mul => {
159 Ok(Self::Literal(Span::new(span, lhs * rhs)))
160 },
161 ConstantOp::Div => {
162 Ok(Self::Literal(Span::new(span, lhs / rhs)))
163 },
164 ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
165 span,
166 Felt::new(lhs.as_int() / rhs.as_int()),
167 ))),
168 }
169 },
170 lhs => Ok(Self::BinaryOp {
171 span,
172 op,
173 lhs: Box::new(lhs),
174 rhs: Box::new(Self::Literal(rhs)),
175 }),
176 }
177 },
178 rhs => {
179 let lhs = Self::into_inner(lhs).try_fold()?;
180 Ok(Self::BinaryOp {
181 span,
182 op,
183 lhs: Box::new(lhs),
184 rhs: Box::new(rhs),
185 })
186 },
187 }
188 } else {
189 let lhs = Self::into_inner(lhs).try_fold()?;
190 Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
191 }
192 },
193 }
194 }
195
196 fn is_literal(&self) -> bool {
197 match self {
198 Self::Literal(_) | Self::String(_) => true,
199 Self::Var(_) => false,
200 Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
201 }
202 }
203
204 #[inline(always)]
205 #[allow(clippy::boxed_local)]
206 fn into_inner(self: Box<Self>) -> Self {
207 *self
208 }
209}
210
211impl Eq for ConstantExpr {}
212
213impl PartialEq for ConstantExpr {
214 fn eq(&self, other: &Self) -> bool {
215 match (self, other) {
216 (Self::Literal(l), Self::Literal(y)) => l == y,
217 (Self::Var(l), Self::Var(y)) => l == y,
218 (
219 Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
220 Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
221 ) => lop == rop && llhs == rlhs && lrhs == rrhs,
222 _ => false,
223 }
224 }
225}
226
227impl fmt::Debug for ConstantExpr {
228 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
229 match self {
230 Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
231 Self::Var(name) | Self::String(name) => fmt::Debug::fmt(&**name, f),
232 Self::BinaryOp { op, lhs, rhs, .. } => {
233 f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
234 },
235 }
236 }
237}
238
239impl crate::prettier::PrettyPrint for ConstantExpr {
240 fn render(&self) -> crate::prettier::Document {
241 use crate::prettier::*;
242
243 match self {
244 Self::Literal(literal) => display(literal),
245 Self::Var(ident) | Self::String(ident) => display(ident),
246 Self::BinaryOp { op, lhs, rhs, .. } => {
247 let single_line = lhs.render() + display(op) + rhs.render();
248 let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
249 single_line | multi_line
250 },
251 }
252 }
253}
254
255impl Spanned for ConstantExpr {
256 fn span(&self) -> SourceSpan {
257 match self {
258 Self::Literal(spanned) => spanned.span(),
259 Self::Var(spanned) | Self::String(spanned) => spanned.span(),
260 Self::BinaryOp { span, .. } => *span,
261 }
262 }
263}
264
/// A binary arithmetic operator usable in a constant expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Field division, written `/`.
    Div,
    /// Integer division of the operands' integer representations, written `//`.
    IntDiv,
}
277
278impl ConstantOp {
279 const fn name(&self) -> &'static str {
280 match self {
281 Self::Add => "Add",
282 Self::Sub => "Sub",
283 Self::Mul => "Mul",
284 Self::Div => "Div",
285 Self::IntDiv => "IntDiv",
286 }
287 }
288}
289
290impl fmt::Display for ConstantOp {
291 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292 match self {
293 Self::Add => f.write_str("+"),
294 Self::Sub => f.write_str("-"),
295 Self::Mul => f.write_str("*"),
296 Self::Div => f.write_str("/"),
297 Self::IntDiv => f.write_str("//"),
298 }
299 }
300}