ijzer_lib/parser/
lambda_variable.rs1use super::{ParseNode, ParseNodeFunctional};
6
7use crate::ast_node::{ASTContext, Node, TokenSlice};
8use crate::operations::Operation;
9use crate::syntax_error::SyntaxError;
10use crate::tokens::Token;
11use crate::types::IJType;
12use anyhow::Result;
13use std::rc::Rc;
14
15pub fn parse_lambda_assign_lhs(
16 op: Token,
17 slice: TokenSlice,
18 context: &mut ASTContext,
19) -> Result<(Rc<Node>, TokenSlice)> {
20 let name = match op {
21 Token::LambdaVariable(v) => v.name,
22 _ => unreachable!(),
23 };
24 let (node_type, rest) = if !slice.is_empty()
25 && context.get_token_at_index(slice.start)? == &Token::TypeDeclaration
26 {
27 let mut rest = slice.move_start(1)?;
28 let (parsed_type, type_end) = IJType::parse_tokens(&context.get_tokens_from_slice(rest))?;
29 rest = rest.move_start(type_end)?;
30 (parsed_type, rest)
31 } else {
32 (IJType::Tensor(None), slice)
33 };
34
35 let node = Node::new(
36 Operation::LambdaVariable(name),
37 vec![node_type.clone()],
38 node_type,
39 vec![],
40 context,
41 )?;
42 Ok((Rc::new(node), rest))
43}
44
45pub struct LambdaVariable;
46impl ParseNode for LambdaVariable {
47 fn next_node(
48 op: Token,
49 slice: TokenSlice,
50 context: &mut ASTContext,
51 ) -> Result<(Rc<Node>, TokenSlice)> {
52 let name = match op {
53 Token::LambdaVariable(v) => v.name,
54 _ => unreachable!(),
55 };
56 let var_type = match context.get_lambda_var_type(name.clone()) {
57 Some(var_type) => var_type,
58 None => return Err(SyntaxError::UnknownSymbol(name.clone()).into()),
59 };
60 let node = Node::new(
61 Operation::LambdaVariable(name),
62 vec![],
63 var_type,
64 vec![],
65 context,
66 )?;
67 Ok((Rc::new(node), slice))
68 }
69}
70
71impl ParseNodeFunctional for LambdaVariable {
72 fn next_node_functional_impl(
73 op: Token,
74 slice: TokenSlice,
75 context: &mut ASTContext,
76 needed_outputs: Option<&[IJType]>,
77 ) -> Result<(Vec<Rc<Node>>, TokenSlice)> {
78 let name = match op {
79 Token::LambdaVariable(v) => v.name,
80 _ => unreachable!(),
81 };
82 let signature = match context.get_lambda_var_type(name.clone()) {
83 Some(IJType::Function(signature)) => signature,
84 Some(_) => return Err(SyntaxError::ExpectedFunction(name.clone()).into()),
85 _ => return Err(SyntaxError::UnknownSymbol(name.clone()).into()),
86 };
87
88 if let Some(outputs) = needed_outputs {
89 if outputs.iter().any(|output| *output != *signature.output) {
90 return Err(SyntaxError::RequiredOutputsDoNotMatchFunctionOutputs(
91 needed_outputs
92 .unwrap()
93 .iter()
94 .map(|output| format!("{:?}", output))
95 .collect::<Vec<String>>()
96 .join(", "),
97 signature.output.to_string(),
98 )
99 .into());
100 }
101 }
102
103 let node = Node::new(
104 Operation::LambdaVariable(name),
105 vec![],
106 IJType::Function(signature),
107 vec![],
108 context,
109 )?;
110 Ok((vec![Rc::new(node)], slice.move_start(1)?))
111 }
112}