// sbpf_assembler/astnode.rs

use {
    crate::{debuginfo::DebugInfo, errors::CompileError, parser::Token},
    sbpf_common::{inst_param::Number, instruction::Instruction},
    std::{collections::HashMap, ops::Range},
};

7#[derive(Debug, Clone)]
8pub enum ASTNode {
9    // only present in AST
10    Directive {
11        directive: Directive,
12    },
13    GlobalDecl {
14        global_decl: GlobalDecl,
15    },
16    EquDecl {
17        equ_decl: EquDecl,
18    },
19    ExternDecl {
20        extern_decl: ExternDecl,
21    },
22    RodataDecl {
23        rodata_decl: RodataDecl,
24    },
25    Label {
26        label: Label,
27        offset: u64,
28    },
29    // present in both AST and bytecode
30    ROData {
31        rodata: ROData,
32        offset: u64,
33    },
34    Instruction {
35        instruction: Instruction,
36        offset: u64,
37    },
38}
39
40#[derive(Debug, Clone)]
41pub struct Directive {
42    pub name: String,
43    pub args: Vec<Token>,
44    pub span: Range<usize>,
45}
46
/// A global declaration naming an entry label.
#[derive(Debug, Clone)]
pub struct GlobalDecl {
    /// Label named by the declaration.
    pub entry_label: String,
    /// Range covered by the declaration (presumably offsets into the source
    /// text; confirm against the parser).
    pub span: Range<usize>,
}

impl GlobalDecl {
    /// Returns an owned copy of the entry label.
    pub fn get_entry_label(&self) -> String {
        self.entry_label.clone()
    }
}

59#[derive(Debug, Clone)]
60pub struct EquDecl {
61    pub name: String,
62    pub value: Token,
63    pub span: Range<usize>,
64}
65
66impl EquDecl {
67    pub fn get_name(&self) -> String {
68        self.name.clone()
69    }
70    pub fn get_val(&self) -> Number {
71        match &self.value {
72            Token::ImmediateValue(val, _) => val.clone(),
73            _ => panic!("Invalid Equ declaration"),
74        }
75    }
76}
77
78#[derive(Debug, Clone)]
79pub struct ExternDecl {
80    pub args: Vec<Token>,
81    pub span: Range<usize>,
82}
83
/// A rodata declaration; it carries no payload beyond its span.
#[derive(Debug, Clone)]
pub struct RodataDecl {
    /// Range covered by the declaration (presumably offsets into the source
    /// text; confirm against the parser).
    pub span: Range<usize>,
}

/// A label: its name and its span.
#[derive(Debug, Clone)]
pub struct Label {
    /// Name of the label.
    pub name: String,
    /// Range covered by the label (presumably offsets into the source text;
    /// confirm against the parser).
    pub span: Range<usize>,
}

95#[derive(Debug, Clone)]
96pub struct ROData {
97    pub name: String,
98    pub args: Vec<Token>,
99    pub span: Range<usize>,
100}
101
102impl ROData {
103    /// Validates that an immediate value is within the specified range
104    fn validate_immediate_range(
105        value: &Number,
106        min: i64,
107        max: u64,
108        span: Range<usize>,
109    ) -> Result<(), CompileError> {
110        let raw = value.to_i64();
111
112        if raw < min || (raw >= 0 && (raw as u64) > max) {
113            return Err(CompileError::OutOfRangeLiteral {
114                span,
115                custom_label: None,
116            });
117        }
118        Ok(())
119    }
120
121    pub fn get_size(&self) -> u64 {
122        let size: u64;
123        match (&self.args[0], &self.args[1]) {
124            (Token::Directive(_, _), Token::StringLiteral(s, _)) => {
125                size = s.len() as u64;
126            }
127            (Token::Directive(directive, _), Token::VectorLiteral(values, _)) => {
128                match directive.as_str() {
129                    "byte" => {
130                        size = values.len() as u64;
131                    }
132                    "short" | "word" => {
133                        size = values.len() as u64 * 2;
134                    }
135                    "int" | "long" => {
136                        size = values.len() as u64 * 4;
137                    }
138                    "quad" => {
139                        size = values.len() as u64 * 8;
140                    }
141                    _ => panic!("Invalid ROData declaration"),
142                }
143            }
144            _ => panic!("Invalid ROData declaration"),
145        }
146        size
147    }
148    pub fn verify(&self) -> Result<(), CompileError> {
149        match (&self.args[0], &self.args[1]) {
150            (Token::Directive(directive, directive_span), Token::StringLiteral(_, _)) => {
151                if directive.as_str() != "ascii" {
152                    return Err(CompileError::InvalidRODataDirective {
153                        span: directive_span.clone(),
154                        custom_label: None,
155                    });
156                }
157            }
158            (
159                Token::Directive(directive, directive_span),
160                Token::VectorLiteral(values, vector_literal_span),
161            ) => match directive.as_str() {
162                "byte" => {
163                    for value in values {
164                        Self::validate_immediate_range(
165                            value,
166                            i8::MIN as i64,
167                            u8::MAX as u64,
168                            vector_literal_span.clone(),
169                        )?;
170                    }
171                }
172                "short" | "word" => {
173                    for value in values {
174                        Self::validate_immediate_range(
175                            value,
176                            i16::MIN as i64,
177                            u16::MAX as u64,
178                            vector_literal_span.clone(),
179                        )?;
180                    }
181                }
182                "int" | "long" => {
183                    for value in values {
184                        Self::validate_immediate_range(
185                            value,
186                            i32::MIN as i64,
187                            u32::MAX as u64,
188                            vector_literal_span.clone(),
189                        )?;
190                    }
191                }
192                "quad" => {
193                    for value in values {
194                        Self::validate_immediate_range(
195                            value,
196                            i64::MIN,
197                            u64::MAX,
198                            vector_literal_span.clone(),
199                        )?;
200                    }
201                }
202                _ => {
203                    return Err(CompileError::InvalidRODataDirective {
204                        span: directive_span.clone(),
205                        custom_label: None,
206                    });
207                }
208            },
209            _ => {
210                return Err(CompileError::InvalidRodataDecl {
211                    span: self.span.clone(),
212                    custom_label: None,
213                });
214            }
215        }
216        Ok(())
217    }
218}
219
220impl ASTNode {
221    pub fn bytecode_with_debug_map(&self) -> Option<(Vec<u8>, HashMap<u64, DebugInfo>)> {
222        match self {
223            ASTNode::Instruction {
224                instruction,
225                offset,
226            } => {
227                // TODO: IMPLEMENT DEBUG INFO HANDLING AND DELETE THIS
228                let mut debug_map = HashMap::new();
229                let debug_info = DebugInfo::new(instruction.span.clone());
230
231                debug_map.insert(*offset, debug_info);
232
233                Some((instruction.to_bytes().unwrap(), debug_map))
234            }
235            ASTNode::ROData {
236                rodata: ROData { name: _, args, .. },
237                ..
238            } => {
239                let mut bytes = Vec::new();
240                let debug_map = HashMap::<u64, DebugInfo>::new();
241                match (&args[0], &args[1]) {
242                    (Token::Directive(_, _), Token::StringLiteral(str_literal, _)) => {
243                        let str_bytes = str_literal.as_bytes().to_vec();
244                        bytes.extend(str_bytes);
245                    }
246                    (Token::Directive(directive, _), Token::VectorLiteral(values, _)) => {
247                        if directive == "byte" {
248                            for value in values {
249                                let imm8 = match value {
250                                    Number::Int(val) => *val as i8,
251                                    Number::Addr(val) => *val as i8,
252                                };
253                                bytes.extend(imm8.to_le_bytes());
254                            }
255                        } else if directive == "short" || directive == "word" {
256                            for value in values {
257                                let imm16 = match value {
258                                    Number::Int(val) => *val as i16,
259                                    Number::Addr(val) => *val as i16,
260                                };
261                                bytes.extend(imm16.to_le_bytes());
262                            }
263                        } else if directive == "int" || directive == "long" {
264                            for value in values {
265                                let imm32 = match value {
266                                    Number::Int(val) => *val as i32,
267                                    Number::Addr(val) => *val as i32,
268                                };
269                                bytes.extend(imm32.to_le_bytes());
270                            }
271                        } else if directive == "quad" {
272                            for value in values {
273                                let imm64 = match value {
274                                    Number::Int(val) => *val,
275                                    Number::Addr(val) => *val,
276                                };
277                                bytes.extend(imm64.to_le_bytes());
278                            }
279                        } else {
280                            panic!("Invalid ROData declaration");
281                        }
282                    }
283
284                    _ => panic!("Invalid ROData declaration"),
285                }
286                Some((bytes, debug_map))
287            }
288            _ => None,
289        }
290    }
291
292    // Keep the old bytecode method for backward compatibility
293    pub fn bytecode(&self) -> Option<Vec<u8>> {
294        self.bytecode_with_debug_map().map(|(bytes, _)| bytes)
295    }
296}
297
#[cfg(test)]
mod tests {
    use {
        super::*,
        sbpf_common::{instruction::Instruction, opcode::Opcode},
    };

    /// Builds an `ROData` whose arguments are `directive` followed by `payload`.
    fn rodata(name: &str, directive: &str, payload: Token, span: Range<usize>) -> ROData {
        ROData {
            name: name.to_string(),
            args: vec![
                Token::Directive(directive.to_string(), 0..directive.len()),
                payload,
            ],
            span,
        }
    }

    /// Shorthand for a vector-literal token holding `Number::Int` values.
    fn ints(values: &[i64], span: Range<usize>) -> Token {
        Token::VectorLiteral(values.iter().map(|v| Number::Int(*v)).collect(), span)
    }

    /// Wraps `rodata` in an `ASTNode::ROData` at offset 0 and encodes it.
    fn rodata_bytes(rodata: ROData) -> Vec<u8> {
        let node = ASTNode::ROData { rodata, offset: 0 };
        let bytecode = node.bytecode();
        assert!(bytecode.is_some());
        bytecode.unwrap()
    }

    #[test]
    fn test_global_decl_get_entry_label() {
        let global = GlobalDecl {
            entry_label: "entrypoint".to_string(),
            span: 0..10,
        };
        assert_eq!(global.get_entry_label(), "entrypoint");
    }

    #[test]
    fn test_equ_decl_methods() {
        let equ = EquDecl {
            name: "MY_CONST".to_string(),
            value: Token::ImmediateValue(Number::Int(42), 5..7),
            span: 0..15,
        };
        assert_eq!(equ.get_name(), "MY_CONST");
        assert_eq!(equ.get_val(), Number::Int(42));
    }

    #[test]
    #[should_panic(expected = "Invalid Equ declaration")]
    fn test_equ_decl_invalid_value() {
        let equ = EquDecl {
            name: "INVALID".to_string(),
            value: Token::Identifier("not_a_number".to_string(), 0..5),
            span: 0..10,
        };
        let _ = equ.get_val(); // Should panic
    }

    #[test]
    fn test_rodata_get_size_ascii() {
        let item = rodata(
            "my_string",
            "ascii",
            Token::StringLiteral("Hello".to_string(), 6..13),
            0..13,
        );
        assert_eq!(item.get_size(), 5);
    }

    #[test]
    fn test_rodata_get_size_byte() {
        let item = rodata("my_bytes", "byte", ints(&[1, 2, 3], 5..14), 0..14);
        assert_eq!(item.get_size(), 3);
    }

    #[test]
    fn test_rodata_get_size_short() {
        let item = rodata("my_shorts", "short", ints(&[1, 2], 6..12), 0..12);
        assert_eq!(item.get_size(), 4); // 2 shorts * 2 bytes
    }

    #[test]
    fn test_rodata_get_size_int() {
        let item = rodata("my_ints", "int", ints(&[100], 4..7), 0..7);
        assert_eq!(item.get_size(), 4); // 1 int * 4 bytes
    }

    #[test]
    fn test_rodata_get_size_quad() {
        let item = rodata("my_quads", "quad", ints(&[1000], 5..9), 0..9);
        assert_eq!(item.get_size(), 8); // 1 quad * 8 bytes
    }

    #[test]
    fn test_rodata_verify_ascii() {
        let item = rodata(
            "str",
            "ascii",
            Token::StringLiteral("test".to_string(), 6..12),
            0..12,
        );
        assert!(item.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_string_requires_ascii() {
        // A string literal introduced by any directive other than `ascii` is rejected.
        let item = rodata(
            "str",
            "byte",
            Token::StringLiteral("test".to_string(), 5..11),
            0..11,
        );
        assert!(item.verify().is_err());
    }

    #[test]
    fn test_rodata_verify_byte_valid() {
        let item = rodata("bytes", "byte", ints(&[0, 127, -128], 5..15), 0..15);
        assert!(item.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_byte_out_of_range() {
        let item = rodata("bytes", "byte", ints(&[256], 5..10), 0..10);
        assert!(item.verify().is_err());
    }

    #[test]
    fn test_rodata_verify_byte_bounds() {
        // 255 is the largest accepted value; -129 is just below the smallest.
        let upper = rodata("bytes", "byte", ints(&[255], 5..10), 0..10);
        assert!(upper.verify().is_ok());

        let below = rodata("bytes", "byte", ints(&[-129], 5..11), 0..11);
        assert!(below.verify().is_err());
    }

    #[test]
    fn test_rodata_verify_short_valid() {
        let item = rodata("shorts", "short", ints(&[32767, -32768], 6..16), 0..16);
        assert!(item.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_int_valid() {
        let item = rodata("ints", "int", ints(&[2147483647], 4..14), 0..14);
        assert!(item.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_quad_valid() {
        let item = rodata("quads", "quad", ints(&[9223372036854775807], 5..20), 0..20);
        assert!(item.verify().is_ok());
    }

    #[test]
    fn test_rodata_verify_invalid_directive() {
        let item = rodata("invalid", "invalid", ints(&[1], 8..11), 0..11);
        assert!(item.verify().is_err());
    }

    #[test]
    fn test_astnode_instruction_bytecode() {
        let inst = Instruction {
            opcode: Opcode::Exit,
            dst: None,
            src: None,
            off: None,
            imm: None,
            span: 0..4,
        };
        let node = ASTNode::Instruction {
            instruction: inst,
            offset: 0,
        };

        let bytecode = node.bytecode();
        assert!(bytecode.is_some());
        assert_eq!(bytecode.unwrap().len(), 8);
    }

    #[test]
    fn test_astnode_rodata_bytecode_ascii() {
        let item = rodata(
            "msg",
            "ascii",
            Token::StringLiteral("Hi".to_string(), 6..10),
            0..10,
        );
        assert_eq!(rodata_bytes(item), b"Hi");
    }

    #[test]
    fn test_astnode_rodata_bytecode_byte() {
        let item = rodata("data", "byte", ints(&[0x42, 0x43], 5..13), 0..13);
        assert_eq!(rodata_bytes(item), vec![0x42u8, 0x43u8]);
    }

    #[test]
    fn test_astnode_rodata_bytecode_short() {
        let item = rodata("data", "short", ints(&[0x1234], 6..12), 0..12);
        let bytes = rodata_bytes(item);
        assert_eq!(bytes.len(), 2);
        assert_eq!(i16::from_le_bytes([bytes[0], bytes[1]]), 0x1234);
    }

    #[test]
    fn test_astnode_rodata_bytecode_int() {
        let item = rodata("data", "int", ints(&[0x12345678], 4..14), 0..14);
        let bytes = rodata_bytes(item);
        assert_eq!(bytes.len(), 4);
        assert_eq!(bytes, vec![0x78u8, 0x56, 0x34, 0x12]);
    }

    #[test]
    fn test_astnode_rodata_bytecode_quad() {
        let item = rodata("data", "quad", ints(&[0x123456789ABCDEF0], 5..21), 0..21);
        let bytes = rodata_bytes(item);
        assert_eq!(bytes.len(), 8);
        assert_eq!(bytes, 0x123456789ABCDEF0_i64.to_le_bytes());
    }

    #[test]
    fn test_astnode_label_no_bytecode() {
        let node = ASTNode::Label {
            label: Label {
                name: "loop".to_string(),
                span: 0..4,
            },
            offset: 0,
        };
        assert!(node.bytecode().is_none());
    }

    #[test]
    fn test_astnode_directive_no_bytecode() {
        let node = ASTNode::Directive {
            directive: Directive {
                name: "section".to_string(),
                args: vec![],
                span: 0..7,
            },
        };
        assert!(node.bytecode().is_none());
    }
}