ijzer_lib/parser/
minus.rs

//! Parses the minus operation `-`.
//!
//! `-` falls under neither the binary nor the unary operations, so it is treated separately.
//! It first determines whether it is given one operand or two, and based on that infers the
//! correct operation: negation for one operand, subtraction for two.
5use super::{check_ok_needed_outputs, gather_operands, ParseNode, ParseNodeFunctional};
6
7use crate::ast_node::{ASTContext, Node, TokenSlice};
8use crate::operations::Operation;
9use crate::syntax_error::SyntaxError;
10use crate::tokens::{Number, Token};
11use crate::types::{FunctionSignature, IJType};
12use anyhow::Result;
13use std::rc::Rc;
14
/// Parser for the minus operator `-`.
///
/// Depending on how many operands it receives, `-` is parsed as either negation (one operand)
/// or subtraction (two operands).
#[derive(PartialEq, Debug)]
pub struct MinusOp;
17impl ParseNode for MinusOp {
18    fn next_node(
19        _: Token,
20        slice: TokenSlice,
21        context: &mut ASTContext,
22    ) -> Result<(Rc<Node>, TokenSlice)> {
23        // let single_tensor = vec![IJType::Tensor];
24        // let two_tensors = vec![IJType::Tensor; 2];
25        let allowed_types = vec![
26            vec![IJType::Tensor(None)],
27            vec![IJType::Number(None)],
28            vec![IJType::Tensor(None), IJType::Tensor(None)],
29            vec![IJType::Number(None), IJType::Number(None)],
30            vec![IJType::Tensor(None), IJType::Number(None)],
31            vec![IJType::Number(None), IJType::Tensor(None)],
32        ];
33        let (operands, rest) = gather_operands(allowed_types, slice, context)?;
34        let input_types = operands
35            .iter()
36            .map(|n| n.output_type.clone())
37            .collect::<Vec<IJType>>();
38        let output_type = if input_types
39            .iter()
40            .all(|t| t.type_match(&IJType::Number(None)))
41        {
42            IJType::Number(None)
43        } else if input_types
44            .iter()
45            .any(|t| t.type_match(&IJType::Tensor(None)))
46        {
47            IJType::Tensor(None)
48        } else {
49            return Err(SyntaxError::TypeError(
50                "S,S or S,T, or T,S or T,T".to_string(),
51                format!("{:?}", input_types),
52            )
53            .into());
54        };
55        if operands.len() == 1 {
56            if let Operation::Number(x) = operands[0].op.clone() {
57                let x_val = x.value;
58                let modified_x_val = if x_val.starts_with('-') {
59                    x_val.trim_start_matches('-').to_string()
60                } else {
61                    format!("-{}", x_val)
62                };
63
64                let node = Node::new(
65                    Operation::Number(Number {
66                        value: modified_x_val,
67                    }),
68                    vec![],
69                    output_type,
70                    vec![],
71                    context,
72                )?;
73                return Ok((Rc::new(node), rest));
74            }
75        }
76        match operands.len() {
77            1 => Ok((
78                Rc::new(Node::new(
79                    Operation::Negate,
80                    input_types,
81                    output_type,
82                    operands,
83                    context,
84                )?),
85                rest,
86            )),
87            2 => Ok((
88                Rc::new(Node::new(
89                    Operation::Subtract,
90                    input_types,
91                    output_type,
92                    operands,
93                    context,
94                )?),
95                rest,
96            )),
97            _ => unreachable!(),
98        }
99    }
100}
101
102impl ParseNodeFunctional for MinusOp {
103    fn next_node_functional_impl(
104        _op: Token,
105        slice: TokenSlice,
106        context: &mut ASTContext,
107        needed_output: Option<&[IJType]>,
108    ) -> Result<(Vec<Rc<Node>>, TokenSlice)> {
109        let rest = slice.move_start(1)?;
110        let mut nodes = vec![];
111        if check_ok_needed_outputs(needed_output, &IJType::Number(None)) {
112            nodes.push(Rc::new(Node::new(
113                Operation::Subtract,
114                vec![],
115                IJType::number_function(2),
116                vec![],
117                context,
118            )?));
119            nodes.push(Rc::new(Node::new(
120                Operation::Negate,
121                vec![],
122                IJType::number_function(1),
123                vec![],
124                context,
125            )?));
126        }
127
128        if check_ok_needed_outputs(needed_output, &IJType::Tensor(None)) {
129            let output_type = IJType::Tensor(None);
130            let input_types = vec![
131                vec![IJType::Tensor(None), IJType::Tensor(None)],
132                vec![IJType::Number(None), IJType::Tensor(None)],
133                vec![IJType::Tensor(None), IJType::Number(None)],
134                vec![IJType::Tensor(None)],
135            ];
136            for input_type in input_types {
137                let op = if input_type.len() == 1 {
138                    Operation::Negate
139                } else {
140                    Operation::Subtract
141                };
142                nodes.push(Rc::new(Node::new(
143                    op,
144                    vec![],
145                    IJType::Function(FunctionSignature::new(input_type, output_type.clone())),
146                    vec![],
147                    context,
148                )?));
149            }
150        }
151        if nodes.is_empty() {
152            return Err(SyntaxError::FunctionSignatureMismatch(
153                format!("{:?}", needed_output),
154                "Fn(T,T->T) or Fn(S,S->S) or Fn(T,S->T) or Fn(S,T->S) or Fn(N,N->N) or Fn(N->N) or Fn(S->S) or Fn(T->T)".to_string(),
155            )
156            .into());
157        }
158        Ok((nodes, rest))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_str_no_context;
    use crate::tokens::Number;
    use std::str::FromStr;

    /// Number-literal operation with the given textual value.
    fn literal(value: &str) -> Operation {
        Operation::Number(Number {
            value: value.to_string(),
        })
    }

    /// `IJType::Number` annotated with the number type `a`.
    fn number_a() -> IJType {
        IJType::Number(Some("a".to_string()))
    }

    #[test]
    fn test_negate_with_scalar() -> Result<()> {
        // Negating a literal folds into a negative literal without inputs.
        let (node, _) = parse_str_no_context("- 1.0")?;
        assert_eq!(node.op, literal("-1.0"));
        assert!(node.input_types.is_empty());
        assert_eq!(node.output_type, IJType::Number(None));
        Ok(())
    }

    #[test]
    fn test_double_negate_with_scalar() -> Result<()> {
        // Two negations of a literal cancel out at parse time.
        let (node, _) = parse_str_no_context("-- 1.0")?;
        assert_eq!(node.op, literal("1.0"));
        assert!(node.input_types.is_empty());
        assert_eq!(node.output_type, IJType::Number(None));
        Ok(())
    }

    #[test]
    fn test_negate_with_tensor() -> Result<()> {
        let (node, _) = parse_str_no_context("- [1]")?;
        assert_eq!(node.op, Operation::Negate);
        assert_eq!(node.input_types, vec![IJType::Tensor(None)]);
        assert_eq!(node.output_type, IJType::Tensor(None));
        Ok(())
    }

    #[test]
    fn test_subtract_with_tensors() -> Result<()> {
        let (node, _) = parse_str_no_context("- [1] [2]")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Tensor(None), IJType::Tensor(None)]
        );
        assert_eq!(node.output_type, IJType::Tensor(None));
        Ok(())
    }

    #[test]
    fn test_subtract_with_tensor_and_scalar() -> Result<()> {
        let (node, _) = parse_str_no_context("- [1] 1")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Tensor(None), IJType::Number(None)]
        );
        assert_eq!(node.output_type, IJType::Tensor(None));
        Ok(())
    }

    #[test]
    fn test_subtract_with_scalars() -> Result<()> {
        let (node, _) = parse_str_no_context("- 1 2")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Number(None), IJType::Number(None)]
        );
        assert_eq!(node.output_type, IJType::Number(None));
        Ok(())
    }

    #[test]
    fn test_negate_single_group() -> Result<()> {
        // A parenthesised literal is still folded into a negative literal.
        let (node, _) = parse_str_no_context("-(1)")?;
        assert_eq!(node.op, Operation::Number(Number::from_str("-1")?));
        assert!(node.operands.is_empty());
        Ok(())
    }

    #[test]
    fn test_subtract_with_number_types() -> Result<()> {
        // An annotation on either operand (or matching annotations on both) carries over.
        let cases = [
            ("- 1<a> 2", vec![number_a(), IJType::Number(None)]),
            ("- 1 2<a>", vec![IJType::Number(None), number_a()]),
            ("- 1<a> 2<a>", vec![number_a(), number_a()]),
        ];
        for (source, expected_inputs) in cases.iter() {
            let (node, _) = parse_str_no_context(source)?;
            assert_eq!(node.op, Operation::Subtract);
            assert_eq!(&node.input_types, expected_inputs);
            assert_eq!(node.output_type, number_a());
        }

        // Conflicting annotations must be rejected.
        assert!(parse_str_no_context("- 1<a> 2<b>").is_err());
        Ok(())
    }

    #[test]
    fn test_minus_functional() -> Result<()> {
        let (node, _) = parse_str_no_context("~-: Fn(T->T)")?;
        let expected_type = IJType::Function(FunctionSignature::new(
            vec![IJType::Tensor(None)],
            IJType::Tensor(None),
        ));
        assert_eq!(node.op, Operation::Negate);
        assert!(node.input_types.is_empty());
        assert_eq!(node.output_type, expected_type);
        Ok(())
    }
}