1use super::{check_ok_needed_outputs, gather_operands, ParseNode, ParseNodeFunctional};
6
7use crate::ast_node::{ASTContext, Node, TokenSlice};
8use crate::operations::Operation;
9use crate::syntax_error::SyntaxError;
10use crate::tokens::{Number, Token};
11use crate::types::{FunctionSignature, IJType};
12use anyhow::Result;
13use std::rc::Rc;
14
15#[derive(Debug, PartialEq)]
16pub struct MinusOp;
17impl ParseNode for MinusOp {
18 fn next_node(
19 _: Token,
20 slice: TokenSlice,
21 context: &mut ASTContext,
22 ) -> Result<(Rc<Node>, TokenSlice)> {
23 let allowed_types = vec![
26 vec![IJType::Tensor(None)],
27 vec![IJType::Number(None)],
28 vec![IJType::Tensor(None), IJType::Tensor(None)],
29 vec![IJType::Number(None), IJType::Number(None)],
30 vec![IJType::Tensor(None), IJType::Number(None)],
31 vec![IJType::Number(None), IJType::Tensor(None)],
32 ];
33 let (operands, rest) = gather_operands(allowed_types, slice, context)?;
34 let input_types = operands
35 .iter()
36 .map(|n| n.output_type.clone())
37 .collect::<Vec<IJType>>();
38 let output_type = if input_types
39 .iter()
40 .all(|t| t.type_match(&IJType::Number(None)))
41 {
42 IJType::Number(None)
43 } else if input_types
44 .iter()
45 .any(|t| t.type_match(&IJType::Tensor(None)))
46 {
47 IJType::Tensor(None)
48 } else {
49 return Err(SyntaxError::TypeError(
50 "S,S or S,T, or T,S or T,T".to_string(),
51 format!("{:?}", input_types),
52 )
53 .into());
54 };
55 if operands.len() == 1 {
56 if let Operation::Number(x) = operands[0].op.clone() {
57 let x_val = x.value;
58 let modified_x_val = if x_val.starts_with('-') {
59 x_val.trim_start_matches('-').to_string()
60 } else {
61 format!("-{}", x_val)
62 };
63
64 let node = Node::new(
65 Operation::Number(Number {
66 value: modified_x_val,
67 }),
68 vec![],
69 output_type,
70 vec![],
71 context,
72 )?;
73 return Ok((Rc::new(node), rest));
74 }
75 }
76 match operands.len() {
77 1 => Ok((
78 Rc::new(Node::new(
79 Operation::Negate,
80 input_types,
81 output_type,
82 operands,
83 context,
84 )?),
85 rest,
86 )),
87 2 => Ok((
88 Rc::new(Node::new(
89 Operation::Subtract,
90 input_types,
91 output_type,
92 operands,
93 context,
94 )?),
95 rest,
96 )),
97 _ => unreachable!(),
98 }
99 }
100}
101
102impl ParseNodeFunctional for MinusOp {
103 fn next_node_functional_impl(
104 _op: Token,
105 slice: TokenSlice,
106 context: &mut ASTContext,
107 needed_output: Option<&[IJType]>,
108 ) -> Result<(Vec<Rc<Node>>, TokenSlice)> {
109 let rest = slice.move_start(1)?;
110 let mut nodes = vec![];
111 if check_ok_needed_outputs(needed_output, &IJType::Number(None)) {
112 nodes.push(Rc::new(Node::new(
113 Operation::Subtract,
114 vec![],
115 IJType::number_function(2),
116 vec![],
117 context,
118 )?));
119 nodes.push(Rc::new(Node::new(
120 Operation::Negate,
121 vec![],
122 IJType::number_function(1),
123 vec![],
124 context,
125 )?));
126 }
127
128 if check_ok_needed_outputs(needed_output, &IJType::Tensor(None)) {
129 let output_type = IJType::Tensor(None);
130 let input_types = vec![
131 vec![IJType::Tensor(None), IJType::Tensor(None)],
132 vec![IJType::Number(None), IJType::Tensor(None)],
133 vec![IJType::Tensor(None), IJType::Number(None)],
134 vec![IJType::Tensor(None)],
135 ];
136 for input_type in input_types {
137 let op = if input_type.len() == 1 {
138 Operation::Negate
139 } else {
140 Operation::Subtract
141 };
142 nodes.push(Rc::new(Node::new(
143 op,
144 vec![],
145 IJType::Function(FunctionSignature::new(input_type, output_type.clone())),
146 vec![],
147 context,
148 )?));
149 }
150 }
151 if nodes.is_empty() {
152 return Err(SyntaxError::FunctionSignatureMismatch(
153 format!("{:?}", needed_output),
154 "Fn(T,T->T) or Fn(S,S->S) or Fn(T,S->T) or Fn(S,T->S) or Fn(N,N->N) or Fn(N->N) or Fn(S->S) or Fn(T->T)".to_string(),
155 )
156 .into());
157 }
158 Ok((nodes, rest))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_str_no_context;
    use crate::tokens::Number;
    use std::str::FromStr;

    // Tests return `Result` and use `?` so that a parse failure reports the underlying
    // error instead of a bare `assertion failed: result.is_ok()`.

    #[test]
    fn test_negate_with_scalar() -> Result<()> {
        let (node, _) = parse_str_no_context("- 1.0")?;
        assert_eq!(
            node.op,
            Operation::Number(Number {
                value: "-1.0".to_string()
            })
        );
        assert_eq!(node.input_types, vec![]);
        assert_eq!(node.output_type, IJType::Number(None));
        Ok(())
    }

    #[test]
    fn test_double_negate_with_scalar() -> Result<()> {
        let (node, _) = parse_str_no_context("-- 1.0")?;
        assert_eq!(
            node.op,
            Operation::Number(Number {
                value: "1.0".to_string()
            })
        );
        assert_eq!(node.input_types, vec![]);
        assert_eq!(node.output_type, IJType::Number(None));
        Ok(())
    }

    #[test]
    fn test_negate_with_tensor() -> Result<()> {
        let (node, _) = parse_str_no_context("- [1]")?;
        assert_eq!(node.op, Operation::Negate);
        assert_eq!(node.input_types, vec![IJType::Tensor(None)]);
        assert_eq!(node.output_type, IJType::Tensor(None));
        Ok(())
    }

    #[test]
    fn test_subtract_with_tensors() -> Result<()> {
        let (node, _) = parse_str_no_context("- [1] [2]")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Tensor(None), IJType::Tensor(None)]
        );
        assert_eq!(node.output_type, IJType::Tensor(None));
        Ok(())
    }

    #[test]
    fn test_subtract_with_tensor_and_scalar() -> Result<()> {
        let (node, _) = parse_str_no_context("- [1] 1")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Tensor(None), IJType::Number(None)]
        );
        assert_eq!(node.output_type, IJType::Tensor(None));
        Ok(())
    }

    #[test]
    fn test_subtract_with_scalars() -> Result<()> {
        let (node, _) = parse_str_no_context("- 1 2")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Number(None), IJType::Number(None)]
        );
        assert_eq!(node.output_type, IJType::Number(None));
        Ok(())
    }

    #[test]
    fn test_negate_single_group() -> Result<()> {
        let (node, _) = parse_str_no_context("-(1)")?;
        assert_eq!(node.op, Operation::Number(Number::from_str("-1")?));
        assert_eq!(node.operands.len(), 0);
        Ok(())
    }

    #[test]
    fn test_subtract_with_number_types() -> Result<()> {
        let (node, _) = parse_str_no_context("- 1<a> 2")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Number(Some("a".to_string())), IJType::Number(None)]
        );
        assert_eq!(node.output_type, IJType::Number(Some("a".to_string())));

        let (node, _) = parse_str_no_context("- 1 2<a>")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![IJType::Number(None), IJType::Number(Some("a".to_string()))]
        );
        assert_eq!(node.output_type, IJType::Number(Some("a".to_string())));

        let (node, _) = parse_str_no_context("- 1<a> 2<a>")?;
        assert_eq!(node.op, Operation::Subtract);
        assert_eq!(
            node.input_types,
            vec![
                IJType::Number(Some("a".to_string())),
                IJType::Number(Some("a".to_string()))
            ]
        );
        assert_eq!(node.output_type, IJType::Number(Some("a".to_string())));

        // Conflicting number annotations must be rejected.
        let result = parse_str_no_context("- 1<a> 2<b>");
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_minus_functional() -> Result<()> {
        let (node, _) = parse_str_no_context("~-: Fn(T->T)")?;
        assert_eq!(node.op, Operation::Negate);
        assert_eq!(node.input_types, vec![]);
        assert_eq!(
            node.output_type,
            IJType::Function(FunctionSignature::new(
                vec![IJType::Tensor(None)],
                IJType::Tensor(None)
            ))
        );
        Ok(())
    }
}