// spade_hir/expression.rs

1use std::borrow::BorrowMut;
2
3use crate::{ConstGenericWithId, Pattern, TypeExpression, TypeParam, UnitKind};
4
5use super::{Block, NameID};
6use num::{BigInt, BigUint};
7use serde::{Deserialize, Serialize};
8use spade_common::{
9    id_tracker::ExprID,
10    location_info::Loc,
11    name::{Identifier, Path},
12    num_ext::InfallibleToBigInt,
13};
14
15#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
16pub enum BinaryOperator {
17    Add,
18    Sub,
19    Mul,
20    Div,
21    Mod,
22    Eq,
23    NotEq,
24    Gt,
25    Lt,
26    Ge,
27    Le,
28    LeftShift,
29    RightShift,
30    ArithmeticRightShift,
31    LogicalAnd,
32    LogicalOr,
33    LogicalXor,
34    BitwiseOr,
35    BitwiseAnd,
36    BitwiseXor,
37}
38
39impl std::fmt::Display for BinaryOperator {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            BinaryOperator::Add => write!(f, "+"),
43            BinaryOperator::Sub => write!(f, "-"),
44            BinaryOperator::Mul => write!(f, "*"),
45            BinaryOperator::Div => write!(f, "/"),
46            BinaryOperator::Mod => write!(f, "%"),
47            BinaryOperator::Eq => write!(f, "=="),
48            BinaryOperator::NotEq => write!(f, "!="),
49            BinaryOperator::Gt => write!(f, ">"),
50            BinaryOperator::Lt => write!(f, "<"),
51            BinaryOperator::Ge => write!(f, ">="),
52            BinaryOperator::Le => write!(f, "<="),
53            BinaryOperator::LeftShift => write!(f, ">>"),
54            BinaryOperator::RightShift => write!(f, "<<"),
55            BinaryOperator::ArithmeticRightShift => write!(f, ">>>"),
56            BinaryOperator::LogicalAnd => write!(f, "&&"),
57            BinaryOperator::LogicalOr => write!(f, "||"),
58            BinaryOperator::LogicalXor => write!(f, "^^"),
59            BinaryOperator::BitwiseOr => write!(f, "|"),
60            BinaryOperator::BitwiseAnd => write!(f, "&"),
61            BinaryOperator::BitwiseXor => write!(f, "^"),
62        }
63    }
64}
65
66#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
67pub enum UnaryOperator {
68    Sub,
69    Not,
70    BitwiseNot,
71    Dereference,
72    Reference,
73}
74
75impl std::fmt::Display for UnaryOperator {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            UnaryOperator::Sub => write!(f, "-"),
79            UnaryOperator::Not => write!(f, "!"),
80            UnaryOperator::BitwiseNot => write!(f, "~"),
81            UnaryOperator::Dereference => write!(f, "*"),
82            UnaryOperator::Reference => write!(f, "&"),
83        }
84    }
85}
86
87// Named arguments are used for both type parameters in turbofishes and in argument lists. T is the
88// right hand side of a binding, i.e. an expression in an argument list
89#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
90pub enum NamedArgument<T> {
91    /// Binds the arguent named LHS in the outer scope to the expression
92    Full(Loc<Identifier>, Loc<T>),
93    /// Binds a local variable to an argument with the same name
94    Short(Loc<Identifier>, Loc<T>),
95}
96
97/// Specifies how an argument is bound. Mainly used for error reporting without
98/// code duplication
99#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
100pub enum ArgumentKind {
101    Positional,
102    Named,
103    ShortNamed,
104}
105
106#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
107pub enum ArgumentList<T> {
108    Named(Vec<NamedArgument<T>>),
109    Positional(Vec<Loc<T>>),
110}
111
112impl<T> ArgumentList<T> {
113    pub fn expressions(&self) -> Vec<&Loc<T>> {
114        match self {
115            ArgumentList::Named(n) => n
116                .iter()
117                .map(|arg| match &arg {
118                    NamedArgument::Full(_, expr) => expr,
119                    NamedArgument::Short(_, expr) => expr,
120                })
121                .collect(),
122            ArgumentList::Positional(arg) => arg.iter().collect(),
123        }
124    }
125    pub fn expressions_mut(&mut self) -> Vec<&mut Loc<T>> {
126        match self {
127            ArgumentList::Named(n) => n
128                .iter_mut()
129                .map(|arg| match arg {
130                    NamedArgument::Full(_, expr) => expr,
131                    NamedArgument::Short(_, expr) => expr,
132                })
133                .collect(),
134            ArgumentList::Positional(arg) => arg.iter_mut().collect(),
135        }
136    }
137}
138
139#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
140pub struct Argument<T> {
141    pub target: Loc<Identifier>,
142    pub value: Loc<T>,
143    pub kind: ArgumentKind,
144}
145
146// FIXME: Migrate entity, pipeline and fn instantiation to this
147#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
148pub enum CallKind {
149    Function,
150    Entity(Loc<()>),
151    Pipeline {
152        inst_loc: Loc<()>,
153        depth: Loc<TypeExpression>,
154        /// An expression ID for which the type inferer will infer the depth of the instantiated
155        /// pipeline, i.e. inst(<this>)
156        depth_typeexpr_id: ExprID,
157    },
158}
159
160#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
161pub enum TriLiteral {
162    Low,
163    High,
164    HighImp,
165}
166
167#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
168pub enum IntLiteralKind {
169    Unsized,
170    Signed(BigUint),
171    Unsigned(BigUint),
172}
173
174#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
175pub enum PipelineRefKind {
176    Absolute(Loc<NameID>),
177    Relative(Loc<TypeExpression>),
178}
179
180#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
181pub struct OuterLambdaParam {
182    pub name_in_lambda: NameID,
183    pub name_in_body: Loc<NameID>,
184}
185
186#[derive(PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)]
187pub enum Safety {
188    Default,
189    Unsafe,
190}
191
192#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
193pub struct LambdaTypeParams {
194    /// The parameters that will contain the types of each argument
195    pub arg: Vec<Loc<TypeParam>>,
196    pub output: Loc<TypeParam>,
197    /// The parameters that will contain the types of the captured variables
198    pub captures: Vec<Loc<TypeParam>>,
199    /// The type parameters that are inherited from the unit in which the lambda is defined
200    pub outer: Vec<Loc<TypeParam>>,
201}
202
203impl LambdaTypeParams {
204    pub fn all(&self) -> impl Iterator<Item = &Loc<TypeParam>> {
205        let Self {
206            arg,
207            output,
208            captures,
209            outer,
210        } = self;
211        arg.iter().chain(Some(output)).chain(captures).chain(outer)
212    }
213}
214
215#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
216pub enum ExprKind {
217    Error,
218    Identifier(NameID),
219    IntLiteral(BigInt, IntLiteralKind),
220    BoolLiteral(bool),
221    TriLiteral(TriLiteral),
222    TypeLevelInteger(NameID),
223    CreatePorts,
224    TupleLiteral(Vec<Loc<Expression>>),
225    ArrayLiteral(Vec<Loc<Expression>>),
226    ArrayShorthandLiteral(Box<Loc<Expression>>, Loc<ConstGenericWithId>),
227    Index(Box<Loc<Expression>>, Box<Loc<Expression>>),
228    RangeIndex {
229        target: Box<Loc<Expression>>,
230        start: Loc<ConstGenericWithId>,
231        end: Loc<ConstGenericWithId>,
232    },
233    TupleIndex(Box<Loc<Expression>>, Loc<u128>),
234    FieldAccess(Box<Loc<Expression>>, Loc<Identifier>),
235    MethodCall {
236        target: Box<Loc<Expression>>,
237        name: Loc<Identifier>,
238        args: Loc<ArgumentList<Expression>>,
239        call_kind: CallKind,
240        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
241        safety: Safety,
242    },
243    Call {
244        kind: CallKind,
245        callee: Loc<NameID>,
246        args: Loc<ArgumentList<Expression>>,
247        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
248        safety: Safety,
249    },
250    BinaryOperator(
251        Box<Loc<Expression>>,
252        Loc<BinaryOperator>,
253        Box<Loc<Expression>>,
254    ),
255    UnaryOperator(Loc<UnaryOperator>, Box<Loc<Expression>>),
256    Match(Box<Loc<Expression>>, Vec<(Loc<Pattern>, Loc<Expression>)>),
257    Block(Box<Block>),
258    If(
259        Box<Loc<Expression>>,
260        Box<Loc<Expression>>,
261        Box<Loc<Expression>>,
262    ),
263    TypeLevelIf(
264        // FIXME: Having a random u64 is not great, let's make TypeExpressions always have associated ids
265        Loc<ConstGenericWithId>,
266        Box<Loc<Expression>>,
267        Box<Loc<Expression>>,
268    ),
269    PipelineRef {
270        stage: Loc<PipelineRefKind>,
271        name: Loc<NameID>,
272        declares_name: bool,
273        /// An expression ID which after typeinference will contain the absolute depth
274        /// of this referenced value
275        depth_typeexpr_id: ExprID,
276    },
277    LambdaDef {
278        unit_kind: Loc<UnitKind>,
279        /// The type that this lambda definition creates
280        lambda_type: NameID,
281        type_params: LambdaTypeParams,
282        outer_generic_params: Vec<OuterLambdaParam>,
283        /// The unit which is the `call` method on this lambda
284        lambda_unit: NameID,
285        arguments: Vec<Loc<Pattern>>,
286        body: Box<Loc<Expression>>,
287        clock: Option<Loc<NameID>>,
288        captures: Vec<(Loc<Identifier>, Loc<NameID>)>,
289    },
290    StageValid,
291    StageReady,
292    StaticUnreachable(Loc<String>),
293    // This is a special case expression which is never created in user code, but which can be used
294    // in type inference to create virtual expressions with specific IDs
295    Null,
296}
297
298impl ExprKind {
299    pub fn with_id(self, id: ExprID) -> Expression {
300        Expression { kind: self, id }
301    }
302
303    // FIXME: These really should be #[cfg(test)]'d away
304    pub fn idless(self) -> Expression {
305        Expression {
306            kind: self,
307            id: ExprID(0),
308        }
309    }
310
311    pub fn int_literal(val: i32) -> Self {
312        Self::IntLiteral(val.to_bigint(), IntLiteralKind::Unsized)
313    }
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct Expression {
318    pub kind: ExprKind,
319    // This ID is used to associate types with the expression
320    pub id: ExprID,
321}
322
323impl Expression {
324    /// Create a new expression referencing an identifier with the specified
325    /// id and name
326    pub fn ident(expr_id: ExprID, name_id: u64, name: &str) -> Expression {
327        ExprKind::Identifier(NameID(name_id, Path::from_strs(&[name]))).with_id(expr_id)
328    }
329
330    /// Returns the block that is this expression if it is a block, an error if it is an Error node, and panics if the expression is not a block or error
331    pub fn assume_block(&self) -> std::result::Result<&Block, ()> {
332        if let ExprKind::Block(ref block) = self.kind {
333            Ok(block)
334        } else if let ExprKind::Error = self.kind {
335            Err(())
336        } else {
337            panic!("Expression is not a block")
338        }
339    }
340
341    /// Returns the block that is this expression. Panics if the expression is not a block
342    pub fn assume_block_mut(&mut self) -> &mut Block {
343        if let ExprKind::Block(block) = &mut self.kind {
344            block.borrow_mut()
345        } else {
346            panic!("Expression is not a block")
347        }
348    }
349}
350
351impl PartialEq for Expression {
352    fn eq(&self, other: &Self) -> bool {
353        self.kind == other.kind
354    }
355}