Skip to main content

sbpf_assembler/
astnode.rs

1use {
2    crate::{errors::CompileError, parser::Token},
3    sbpf_common::{inst_param::Number, instruction::Instruction},
4    std::ops::Range,
5};
6
7#[derive(Debug, Clone)]
8pub enum ASTNode {
9    // only present in AST
10    Directive {
11        directive: Directive,
12    },
13    GlobalDecl {
14        global_decl: GlobalDecl,
15    },
16    EquDecl {
17        equ_decl: EquDecl,
18    },
19    ExternDecl {
20        extern_decl: ExternDecl,
21    },
22    RodataDecl {
23        rodata_decl: RodataDecl,
24    },
25    Label {
26        label: Label,
27        offset: u64,
28    },
29    // present in both AST and bytecode
30    ROData {
31        rodata: ROData,
32        offset: u64,
33    },
34    Instruction {
35        instruction: Instruction,
36        offset: u64,
37    },
38}
39
40#[derive(Debug, Clone)]
41pub struct Directive {
42    pub name: String,
43    pub args: Vec<Token>,
44    pub span: Range<usize>,
45}
46
/// A global declaration, recording which label is the entry label.
#[derive(Debug, Clone)]
pub struct GlobalDecl {
    /// Name of the label declared as the entry point.
    pub entry_label: String,
    /// Range into the source text covered by the declaration.
    pub span: Range<usize>,
}

impl GlobalDecl {
    /// Returns an owned copy of the entry label name.
    pub fn get_entry_label(&self) -> String {
        self.entry_label.as_str().to_owned()
    }
}
58
59#[derive(Debug, Clone)]
60pub struct EquDecl {
61    pub name: String,
62    pub value: Token,
63    pub span: Range<usize>,
64}
65
66impl EquDecl {
67    pub fn get_name(&self) -> String {
68        self.name.clone()
69    }
70    pub fn get_val(&self) -> Number {
71        match &self.value {
72            Token::ImmediateValue(val, _) => val.clone(),
73            _ => panic!("Invalid Equ declaration"),
74        }
75    }
76}
77
78#[derive(Debug, Clone)]
79pub struct ExternDecl {
80    pub args: Vec<Token>,
81    pub span: Range<usize>,
82}
83
/// A read-only data section declaration; it carries only its source span.
#[derive(Debug, Clone)]
pub struct RodataDecl {
    /// Range into the source text covered by the declaration.
    pub span: Range<usize>,
}
88
/// A label definition: its name and where it appeared in the source.
#[derive(Debug, Clone)]
pub struct Label {
    /// Name of the label.
    pub name: String,
    /// Range into the source text covered by the label.
    pub span: Range<usize>,
}
94
95#[derive(Debug, Clone)]
96pub struct ROData {
97    pub name: String,
98    pub args: Vec<Token>,
99    pub span: Range<usize>,
100}
101
102impl ROData {
103    /// Validates that an immediate value is within the specified range
104    fn validate_immediate_range(
105        value: &Number,
106        min: i64,
107        max: u64,
108        span: Range<usize>,
109    ) -> Result<(), CompileError> {
110        let raw = value.to_i64();
111
112        if raw < min || (raw >= 0 && (raw as u64) > max) {
113            return Err(CompileError::OutOfRangeLiteral {
114                span,
115                custom_label: None,
116            });
117        }
118        Ok(())
119    }
120
121    pub fn get_size(&self) -> u64 {
122        let size: u64;
123        match (&self.args[0], &self.args[1]) {
124            (Token::Directive(_, _), Token::StringLiteral(s, _)) => {
125                size = s.len() as u64;
126            }
127            (Token::Directive(directive, _), Token::VectorLiteral(values, _)) => {
128                match directive.as_str() {
129                    "byte" => {
130                        size = values.len() as u64;
131                    }
132                    "short" | "word" => {
133                        size = values.len() as u64 * 2;
134                    }
135                    "int" | "long" => {
136                        size = values.len() as u64 * 4;
137                    }
138                    "quad" => {
139                        size = values.len() as u64 * 8;
140                    }
141                    _ => panic!("Invalid ROData declaration"),
142                }
143            }
144            _ => panic!("Invalid ROData declaration"),
145        }
146        size
147    }
148    pub fn verify(&self) -> Result<(), CompileError> {
149        match (&self.args[0], &self.args[1]) {
150            (Token::Directive(directive, directive_span), Token::StringLiteral(_, _)) => {
151                if directive.as_str() != "ascii" {
152                    return Err(CompileError::InvalidRODataDirective {
153                        span: directive_span.clone(),
154                        custom_label: None,
155                    });
156                }
157            }
158            (
159                Token::Directive(directive, directive_span),
160                Token::VectorLiteral(values, vector_literal_span),
161            ) => match directive.as_str() {
162                "byte" => {
163                    for value in values {
164                        Self::validate_immediate_range(
165                            value,
166                            i8::MIN as i64,
167                            u8::MAX as u64,
168                            vector_literal_span.clone(),
169                        )?;
170                    }
171                }
172                "short" | "word" => {
173                    for value in values {
174                        Self::validate_immediate_range(
175                            value,
176                            i16::MIN as i64,
177                            u16::MAX as u64,
178                            vector_literal_span.clone(),
179                        )?;
180                    }
181                }
182                "int" | "long" => {
183                    for value in values {
184                        Self::validate_immediate_range(
185                            value,
186                            i32::MIN as i64,
187                            u32::MAX as u64,
188                            vector_literal_span.clone(),
189                        )?;
190                    }
191                }
192                "quad" => {
193                    for value in values {
194                        Self::validate_immediate_range(
195                            value,
196                            i64::MIN,
197                            u64::MAX,
198                            vector_literal_span.clone(),
199                        )?;
200                    }
201                }
202                _ => {
203                    return Err(CompileError::InvalidRODataDirective {
204                        span: directive_span.clone(),
205                        custom_label: None,
206                    });
207                }
208            },
209            _ => {
210                return Err(CompileError::InvalidRodataDecl {
211                    span: self.span.clone(),
212                    custom_label: None,
213                });
214            }
215        }
216        Ok(())
217    }
218}
219
220impl ASTNode {
221    pub fn bytecode(&self) -> Option<Vec<u8>> {
222        match self {
223            ASTNode::Instruction { instruction, .. } => Some(instruction.to_bytes().unwrap()),
224            ASTNode::ROData {
225                rodata: ROData { args, .. },
226                ..
227            } => {
228                let mut bytes = Vec::new();
229                match (&args[0], &args[1]) {
230                    (Token::Directive(_, _), Token::StringLiteral(str_literal, _)) => {
231                        let str_bytes = str_literal.as_bytes().to_vec();
232                        bytes.extend(str_bytes);
233                    }
234                    (Token::Directive(directive, _), Token::VectorLiteral(values, _)) => {
235                        if directive == "byte" {
236                            for value in values {
237                                let imm8 = match value {
238                                    Number::Int(val) => *val as i8,
239                                    Number::Addr(val) => *val as i8,
240                                };
241                                bytes.extend(imm8.to_le_bytes());
242                            }
243                        } else if directive == "short" || directive == "word" {
244                            for value in values {
245                                let imm16 = match value {
246                                    Number::Int(val) => *val as i16,
247                                    Number::Addr(val) => *val as i16,
248                                };
249                                bytes.extend(imm16.to_le_bytes());
250                            }
251                        } else if directive == "int" || directive == "long" {
252                            for value in values {
253                                let imm32 = match value {
254                                    Number::Int(val) => *val as i32,
255                                    Number::Addr(val) => *val as i32,
256                                };
257                                bytes.extend(imm32.to_le_bytes());
258                            }
259                        } else if directive == "quad" {
260                            for value in values {
261                                let imm64 = match value {
262                                    Number::Int(val) => *val,
263                                    Number::Addr(val) => *val,
264                                };
265                                bytes.extend(imm64.to_le_bytes());
266                            }
267                        } else {
268                            panic!("Invalid ROData declaration");
269                        }
270                    }
271                    _ => panic!("Invalid ROData declaration"),
272                }
273                Some(bytes)
274            }
275            _ => None,
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use {
        super::*,
        sbpf_common::{instruction::Instruction, opcode::Opcode},
    };

    /// Builds an `ROData` whose arguments are a directive token followed by
    /// one literal token.
    ///
    /// Spans follow a single layout: the directive covers
    /// `0..directive.len()`, the literal starts one position after it and
    /// runs to `end`, and the whole item covers `0..end`.
    fn rodata(
        name: &str,
        directive: &str,
        end: usize,
        literal: impl FnOnce(Range<usize>) -> Token,
    ) -> ROData {
        let literal_start = directive.len() + 1;
        ROData {
            name: name.to_string(),
            args: vec![
                Token::Directive(directive.to_string(), 0..directive.len()),
                literal(literal_start..end),
            ],
            span: 0..end,
        }
    }

    /// `ROData` carrying a string literal under the `ascii` directive.
    fn ascii_rodata(name: &str, text: &str, end: usize) -> ROData {
        rodata(name, "ascii", end, |span| {
            Token::StringLiteral(text.to_string(), span)
        })
    }

    /// `ROData` carrying a vector literal under the given directive.
    fn vector_rodata(name: &str, directive: &str, values: Vec<Number>, end: usize) -> ROData {
        rodata(name, directive, end, |span| Token::VectorLiteral(values, span))
    }

    /// Wraps `rodata` in a node at offset 0 and returns its encoded bytes.
    fn rodata_bytes(rodata: ROData) -> Vec<u8> {
        let node = ASTNode::ROData { rodata, offset: 0 };
        let bytecode = node.bytecode();
        assert!(bytecode.is_some());
        bytecode.unwrap()
    }

    #[test]
    fn test_global_decl_get_entry_label() {
        let global = GlobalDecl {
            entry_label: "entrypoint".to_string(),
            span: 0..10,
        };
        assert_eq!(global.get_entry_label(), "entrypoint");
    }

    #[test]
    fn test_equ_decl_methods() {
        let equ = EquDecl {
            name: "MY_CONST".to_string(),
            value: Token::ImmediateValue(Number::Int(42), 5..7),
            span: 0..15,
        };
        assert_eq!(equ.get_name(), "MY_CONST");
        assert_eq!(equ.get_val(), Number::Int(42));
    }

    #[test]
    #[should_panic(expected = "Invalid Equ declaration")]
    fn test_equ_decl_invalid_value() {
        let equ = EquDecl {
            name: "INVALID".to_string(),
            value: Token::Identifier("not_a_number".to_string(), 0..5),
            span: 0..10,
        };
        let _ = equ.get_val(); // Should panic
    }

    #[test]
    fn test_rodata_get_size_ascii() {
        let rodata = ascii_rodata("my_string", "Hello", 13);
        assert_eq!(rodata.get_size(), 5);
    }

    #[test]
    fn test_rodata_get_size_byte() {
        let rodata = vector_rodata(
            "my_bytes",
            "byte",
            vec![Number::Int(1), Number::Int(2), Number::Int(3)],
            14,
        );
        assert_eq!(rodata.get_size(), 3);
    }

    #[test]
    fn test_rodata_get_size_short() {
        let rodata = vector_rodata(
            "my_shorts",
            "short",
            vec![Number::Int(1), Number::Int(2)],
            12,
        );
        assert_eq!(rodata.get_size(), 4); // 2 shorts * 2 bytes
    }

    #[test]
    fn test_rodata_get_size_int() {
        let rodata = vector_rodata("my_ints", "int", vec![Number::Int(100)], 7);
        assert_eq!(rodata.get_size(), 4); // 1 int * 4 bytes
    }

    #[test]
    fn test_rodata_get_size_quad() {
        let rodata = vector_rodata("my_quads", "quad", vec![Number::Int(1000)], 9);
        assert_eq!(rodata.get_size(), 8); // 1 quad * 8 bytes
    }

    #[test]
    fn test_rodata_verify_ascii() {
        let rodata = ascii_rodata("str", "test", 12);
        assert!(rodata.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_byte_valid() {
        let rodata = vector_rodata(
            "bytes",
            "byte",
            vec![Number::Int(0), Number::Int(127), Number::Int(-128)],
            15,
        );
        assert!(rodata.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_byte_out_of_range() {
        let rodata = vector_rodata("bytes", "byte", vec![Number::Int(256)], 10);
        // Pin the error kind, not just the fact that verification failed.
        assert!(matches!(
            rodata.verify(),
            Err(CompileError::OutOfRangeLiteral { .. })
        ));
    }

    #[test]
    fn test_rodata_verify_byte_boundaries() {
        // 255 fits an unsigned byte and is accepted; -129 fits neither the
        // signed nor the unsigned interpretation and is rejected.
        let upper = vector_rodata("bytes", "byte", vec![Number::Int(255)], 10);
        assert!(upper.verify().is_ok());

        let lower = vector_rodata("bytes", "byte", vec![Number::Int(-129)], 10);
        assert!(matches!(
            lower.verify(),
            Err(CompileError::OutOfRangeLiteral { .. })
        ));
    }

    #[test]
    fn test_rodata_verify_short_valid() {
        let rodata = vector_rodata(
            "shorts",
            "short",
            vec![Number::Int(32767), Number::Int(-32768)],
            16,
        );
        assert!(rodata.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_int_valid() {
        let rodata = vector_rodata("ints", "int", vec![Number::Int(2147483647)], 14);
        assert!(rodata.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_quad_valid() {
        let rodata = vector_rodata(
            "quads",
            "quad",
            vec![Number::Int(9223372036854775807)],
            20,
        );
        assert!(rodata.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_invalid_directive() {
        let rodata = vector_rodata("invalid", "invalid", vec![Number::Int(1)], 11);
        assert!(matches!(
            rodata.verify(),
            Err(CompileError::InvalidRODataDirective { .. })
        ));
    }

    #[test]
    fn test_astnode_instruction_bytecode() {
        let inst = Instruction {
            opcode: Opcode::Exit,
            dst: None,
            src: None,
            off: None,
            imm: None,
            span: 0..4,
        };
        let node = ASTNode::Instruction {
            instruction: inst,
            offset: 0,
        };

        let bytecode = node.bytecode();
        assert!(bytecode.is_some());
        assert_eq!(bytecode.unwrap().len(), 8);
    }

    #[test]
    fn test_astnode_rodata_bytecode_ascii() {
        let bytes = rodata_bytes(ascii_rodata("msg", "Hi", 10));
        assert_eq!(bytes, b"Hi");
    }

    #[test]
    fn test_astnode_rodata_bytecode_byte() {
        let bytes = rodata_bytes(vector_rodata(
            "data",
            "byte",
            vec![Number::Int(0x42), Number::Int(0x43)],
            13,
        ));
        assert_eq!(bytes, vec![0x42u8, 0x43u8]);
    }

    #[test]
    fn test_astnode_rodata_bytecode_short() {
        let bytes = rodata_bytes(vector_rodata(
            "data",
            "short",
            vec![Number::Int(0x1234)],
            12,
        ));
        assert_eq!(bytes.len(), 2);
        assert_eq!(i16::from_le_bytes([bytes[0], bytes[1]]), 0x1234);
    }

    #[test]
    fn test_astnode_rodata_bytecode_int() {
        let bytes = rodata_bytes(vector_rodata(
            "data",
            "int",
            vec![Number::Int(0x12345678)],
            14,
        ));
        assert_eq!(bytes.len(), 4);
        // Content check: the value is emitted little-endian.
        assert_eq!(bytes, 0x12345678i32.to_le_bytes());
    }

    #[test]
    fn test_astnode_rodata_bytecode_quad() {
        let bytes = rodata_bytes(vector_rodata(
            "data",
            "quad",
            vec![Number::Int(0x123456789ABCDEF0)],
            21,
        ));
        assert_eq!(bytes.len(), 8);
        // Content check: the value is emitted little-endian.
        assert_eq!(bytes, 0x123456789ABCDEF0i64.to_le_bytes());
    }

    #[test]
    fn test_astnode_label_no_bytecode() {
        let node = ASTNode::Label {
            label: Label {
                name: "loop".to_string(),
                span: 0..4,
            },
            offset: 0,
        };
        assert!(node.bytecode().is_none());
    }

    #[test]
    fn test_astnode_directive_no_bytecode() {
        let node = ASTNode::Directive {
            directive: Directive {
                name: "section".to_string(),
                args: vec![],
                span: 0..7,
            },
        };
        assert!(node.bytecode().is_none());
    }
}