Skip to main content

ruby_rbs/node/
mod.rs

1include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
2use rbs_encoding_type_t::RBS_ENCODING_UTF_8;
3use ruby_rbs_sys::bindings::*;
4use std::marker::PhantomData;
5use std::ptr::NonNull;
6
7/// Parse RBS code into an AST.
8///
9/// ```rust
10/// use ruby_rbs::node::parse;
11/// let rbs_code = r#"type foo = "hello""#;
12/// let signature = parse(rbs_code);
13/// assert!(signature.is_ok(), "Failed to parse RBS signature");
14/// ```
15pub fn parse(rbs_code: &str) -> Result<SignatureNode<'_>, String> {
16    unsafe {
17        let start_ptr = rbs_code.as_ptr().cast::<std::os::raw::c_char>();
18        let end_ptr = start_ptr.add(rbs_code.len());
19        let bytes = rbs_code.len() as i32;
20
21        let raw_rbs_string_value = rbs_string_new(start_ptr, end_ptr);
22
23        let encoding_ptr = &rbs_encodings[RBS_ENCODING_UTF_8 as usize] as *const rbs_encoding_t;
24        let parser = rbs_parser_new(raw_rbs_string_value, encoding_ptr, 0, bytes);
25
26        let mut signature: *mut rbs_signature_t = std::ptr::null_mut();
27        let result = rbs_parse_signature(parser, &mut signature);
28
29        let signature_node = SignatureNode {
30            parser: NonNull::new_unchecked(parser),
31            pointer: signature,
32            marker: PhantomData,
33        };
34
35        if result {
36            Ok(signature_node)
37        } else {
38            let error_message = (*parser)
39                .error
40                .as_ref()
41                .filter(|error| !error.message.is_null())
42                .map(|error| {
43                    std::ffi::CStr::from_ptr(error.message)
44                        .to_string_lossy()
45                        .into_owned()
46                })
47                .unwrap_or_else(|| String::from("Failed to parse RBS signature"));
48
49            Err(error_message)
50        }
51    }
52}
53
54impl Drop for SignatureNode<'_> {
55    fn drop(&mut self) {
56        unsafe {
57            rbs_parser_free(self.parser.as_ptr());
58        }
59    }
60}
61
62/// Instance variable name specification for attributes.
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub enum AttrIvarName {
65    /// The attribute has inferred instance variable (nil)
66    Unspecified,
67    /// The attribute has no instance variable (false)
68    Empty,
69    /// The attribute has instance variable with the given name
70    Name(rbs_constant_id_t),
71}
72
73impl AttrIvarName {
74    /// Converts the raw C struct to the Rust enum.
75    #[must_use]
76    pub fn from_raw(raw: rbs_attr_ivar_name_t) -> Self {
77        match raw.tag {
78            rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_UNSPECIFIED => Self::Unspecified,
79            rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_EMPTY => Self::Empty,
80            rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_NAME => Self::Name(raw.name),
81            _ => panic!("Unknown ivar_name_tag: {}", raw.tag),
82        }
83    }
84}
85
86pub struct NodeList<'a> {
87    parser: NonNull<rbs_parser_t>,
88    pointer: *mut rbs_node_list_t,
89    marker: PhantomData<&'a mut rbs_node_list_t>,
90}
91
92impl<'a> NodeList<'a> {
93    #[must_use]
94    pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_node_list_t) -> Self {
95        Self {
96            parser,
97            pointer,
98            marker: PhantomData,
99        }
100    }
101
102    /// Returns an iterator over the nodes.
103    #[must_use]
104    pub fn iter(&self) -> NodeListIter<'a> {
105        NodeListIter {
106            parser: self.parser,
107            current: unsafe { (*self.pointer).head },
108            marker: PhantomData,
109        }
110    }
111}
112
113pub struct NodeListIter<'a> {
114    parser: NonNull<rbs_parser_t>,
115    current: *mut rbs_node_list_node_t,
116    marker: PhantomData<&'a mut rbs_node_list_node_t>,
117}
118
119impl<'a> Iterator for NodeListIter<'a> {
120    type Item = Node<'a>;
121
122    fn next(&mut self) -> Option<Self::Item> {
123        if self.current.is_null() {
124            None
125        } else {
126            let pointer_data = unsafe { *self.current };
127            let node = Node::new(self.parser, pointer_data.node);
128            self.current = pointer_data.next;
129            Some(node)
130        }
131    }
132}
133
134pub struct RBSHash<'a> {
135    parser: NonNull<rbs_parser_t>,
136    pointer: *mut rbs_hash,
137    marker: PhantomData<&'a mut rbs_hash>,
138}
139
140impl<'a> RBSHash<'a> {
141    #[must_use]
142    pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_hash) -> Self {
143        Self {
144            parser,
145            pointer,
146            marker: PhantomData,
147        }
148    }
149
150    /// Returns an iterator over the key-value pairs.
151    #[must_use]
152    pub fn iter(&self) -> RBSHashIter<'a> {
153        RBSHashIter {
154            parser: self.parser,
155            current: unsafe { (*self.pointer).head },
156            marker: PhantomData,
157        }
158    }
159}
160
161pub struct RBSHashIter<'a> {
162    parser: NonNull<rbs_parser_t>,
163    current: *mut rbs_hash_node_t,
164    marker: PhantomData<&'a mut rbs_hash_node_t>,
165}
166
167impl<'a> Iterator for RBSHashIter<'a> {
168    type Item = (Node<'a>, Node<'a>);
169
170    fn next(&mut self) -> Option<Self::Item> {
171        if self.current.is_null() {
172            None
173        } else {
174            let pointer_data = unsafe { *self.current };
175            let key = Node::new(self.parser, pointer_data.key);
176            let value = Node::new(self.parser, pointer_data.value);
177            self.current = pointer_data.next;
178            Some((key, value))
179        }
180    }
181}
182
183pub struct RBSLocationRange {
184    range: rbs_location_range,
185}
186
187impl RBSLocationRange {
188    #[must_use]
189    pub fn new(range: rbs_location_range) -> Self {
190        Self { range }
191    }
192
193    #[must_use]
194    pub fn start(&self) -> i32 {
195        self.range.start_byte
196    }
197
198    #[must_use]
199    pub fn end(&self) -> i32 {
200        self.range.end_byte
201    }
202}
203
204pub struct RBSLocationRangeList<'a> {
205    #[allow(dead_code)]
206    parser: NonNull<rbs_parser_t>,
207    pointer: *mut rbs_location_range_list_t,
208    marker: PhantomData<&'a mut rbs_location_range_list_t>,
209}
210
211impl<'a> RBSLocationRangeList<'a> {
212    /// Returns an iterator over the location ranges.
213    #[must_use]
214    pub fn iter(&self) -> RBSLocationRangeListIter {
215        RBSLocationRangeListIter {
216            current: unsafe { (*self.pointer).head },
217        }
218    }
219}
220
221pub struct RBSLocationRangeListIter {
222    current: *mut rbs_location_range_list_node_t,
223}
224
225impl Iterator for RBSLocationRangeListIter {
226    type Item = RBSLocationRange;
227
228    fn next(&mut self) -> Option<Self::Item> {
229        if self.current.is_null() {
230            None
231        } else {
232            let pointer_data = unsafe { *self.current };
233            let range = RBSLocationRange::new(pointer_data.range);
234            self.current = pointer_data.next;
235            Some(range)
236        }
237    }
238}
239
240#[derive(Debug)]
241pub struct RBSString<'a> {
242    pointer: *const rbs_string_t,
243    marker: PhantomData<&'a rbs_string_t>,
244}
245
246impl<'a> RBSString<'a> {
247    #[must_use]
248    pub fn new(pointer: *const rbs_string_t) -> Self {
249        Self {
250            pointer,
251            marker: PhantomData,
252        }
253    }
254
255    #[must_use]
256    #[allow(clippy::unnecessary_cast)]
257    pub fn as_bytes(&self) -> &[u8] {
258        unsafe {
259            let s = *self.pointer;
260            std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize)
261        }
262    }
263
264    #[must_use]
265    pub fn as_str(&self) -> &str {
266        unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
267    }
268}
269
270impl std::fmt::Display for RBSString<'_> {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        f.write_str(self.as_str())
273    }
274}
275
276impl SymbolNode<'_> {
277    #[must_use]
278    pub fn as_bytes(&self) -> &[u8] {
279        unsafe {
280            let constant_ptr = rbs_constant_pool_id_to_constant(
281                &(*self.parser.as_ptr()).constant_pool,
282                (*self.pointer).constant_id,
283            );
284            if constant_ptr.is_null() {
285                panic!("Constant ID for symbol is not present in the pool");
286            }
287
288            let constant = &*constant_ptr;
289            std::slice::from_raw_parts(constant.start, constant.length)
290        }
291    }
292
293    #[must_use]
294    pub fn as_str(&self) -> &str {
295        unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
296    }
297}
298
299impl std::fmt::Display for SymbolNode<'_> {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        f.write_str(self.as_str())
302    }
303}
304
305#[cfg(test)]
306mod tests {
307    use super::*;
308
309    #[test]
310    fn test_parse_error_contains_actual_message() {
311        let rbs_code = "class { end";
312        let result = parse(rbs_code);
313        let error_message = result.unwrap_err();
314        assert_eq!(error_message, "expected one of class/module/constant name");
315    }
316
317    #[test]
318    fn test_parse() {
319        let rbs_code = r#"type foo = "hello""#;
320        let signature = parse(rbs_code);
321        assert!(signature.is_ok(), "Failed to parse RBS signature");
322
323        let rbs_code2 = r#"class Foo end"#;
324        let signature2 = parse(rbs_code2);
325        assert!(signature2.is_ok(), "Failed to parse RBS signature");
326    }
327
328    #[test]
329    fn test_parse_integer() {
330        let rbs_code = r#"type foo = 1"#;
331        let signature = parse(rbs_code);
332        assert!(signature.is_ok(), "Failed to parse RBS signature");
333
334        let signature_node = signature.unwrap();
335        if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap()
336            && let Node::LiteralType(literal) = node.type_()
337            && let Node::Integer(integer) = literal.literal()
338        {
339            assert_eq!(integer.string_representation().as_str(), "1");
340        } else {
341            panic!("No literal type node found");
342        }
343    }
344
345    #[test]
346    fn test_rbs_hash_via_record_type() {
347        // RecordType stores its fields in an RBSHash via all_fields()
348        let rbs_code = r#"type foo = { name: String, age: Integer }"#;
349        let signature = parse(rbs_code);
350        assert!(signature.is_ok(), "Failed to parse RBS signature");
351
352        let signature_node = signature.unwrap();
353        if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap()
354            && let Node::RecordType(record) = type_alias.type_()
355        {
356            let hash = record.all_fields();
357            let fields: Vec<_> = hash.iter().collect();
358            assert_eq!(fields.len(), 2, "Expected 2 fields in record");
359
360            // Build a map of field names to type names
361            let mut field_types: Vec<(String, String)> = Vec::new();
362            for (key, value) in &fields {
363                let Node::Symbol(sym) = key else {
364                    panic!("Expected Symbol key");
365                };
366                let Node::RecordFieldType(field_type) = value else {
367                    panic!("Expected RecordFieldType value");
368                };
369                let Node::ClassInstanceType(class_type) = field_type.type_() else {
370                    panic!("Expected ClassInstanceType");
371                };
372
373                let key_name = sym.to_string();
374                let type_name_node = class_type.name();
375                let type_name_sym = type_name_node.name();
376                let type_name = type_name_sym.to_string();
377                field_types.push((key_name, type_name));
378            }
379
380            assert!(
381                field_types.contains(&("name".to_string(), "String".to_string())),
382                "Expected 'name: String'"
383            );
384            assert!(
385                field_types.contains(&("age".to_string(), "Integer".to_string())),
386                "Expected 'age: Integer'"
387            );
388        } else {
389            panic!("Expected TypeAlias with RecordType");
390        }
391    }
392
393    #[test]
394    fn visitor_test() {
395        struct Visitor {
396            visited: Vec<String>,
397        }
398
399        impl Visit for Visitor {
400            fn visit_bool_type_node(&mut self, node: &BoolTypeNode) {
401                self.visited.push("type:bool".to_string());
402
403                crate::node::visit_bool_type_node(self, node);
404            }
405
406            fn visit_class_node(&mut self, node: &ClassNode) {
407                self.visited.push(format!("class:{}", node.name().name()));
408
409                crate::node::visit_class_node(self, node);
410            }
411
412            fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
413                self.visited.push(format!("type:{}", node.name().name()));
414
415                crate::node::visit_class_instance_type_node(self, node);
416            }
417
418            fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
419                self.visited.push(format!("super:{}", node.name().name()));
420
421                crate::node::visit_class_super_node(self, node);
422            }
423
424            fn visit_function_type_node(&mut self, node: &FunctionTypeNode) {
425                let count = node.required_positionals().iter().count();
426                self.visited
427                    .push(format!("function:required_positionals:{count}"));
428
429                crate::node::visit_function_type_node(self, node);
430            }
431
432            fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
433                self.visited.push(format!("method:{}", node.name()));
434
435                crate::node::visit_method_definition_node(self, node);
436            }
437
438            fn visit_record_type_node(&mut self, node: &RecordTypeNode) {
439                self.visited.push("record".to_string());
440
441                crate::node::visit_record_type_node(self, node);
442            }
443
444            fn visit_symbol_node(&mut self, node: &SymbolNode) {
445                self.visited.push(format!("symbol:{node}"));
446
447                crate::node::visit_symbol_node(self, node);
448            }
449        }
450
451        let rbs_code = r#"
452            class Foo < Bar
453                def process: ({ name: String, age: Integer }, bool) -> void
454            end
455        "#;
456
457        let signature = parse(rbs_code).unwrap();
458
459        let mut visitor = Visitor {
460            visited: Vec::new(),
461        };
462
463        visitor.visit(&signature.as_node());
464
465        assert_eq!(
466            vec![
467                "class:Foo",
468                "symbol:Foo",
469                "super:Bar",
470                "symbol:Bar",
471                "method:process",
472                "symbol:process",
473                "function:required_positionals:2",
474                "record",
475                "symbol:name",
476                "type:String",
477                "symbol:String",
478                "symbol:age",
479                "type:Integer",
480                "symbol:Integer",
481                "type:bool",
482            ],
483            visitor.visited
484        );
485    }
486
487    #[test]
488    fn test_node_location_ranges() {
489        let rbs_code = r#"type foo = 1"#;
490        let signature = parse(rbs_code).unwrap();
491
492        let declaration = signature.declarations().iter().next().unwrap();
493        let Node::TypeAlias(type_alias) = declaration else {
494            panic!("Expected TypeAlias");
495        };
496
497        // TypeAlias spans the entire declaration
498        let loc = type_alias.location();
499        assert_eq!(0, loc.start());
500        assert_eq!(12, loc.end());
501
502        // The literal "1" is at position 11-12
503        let Node::LiteralType(literal) = type_alias.type_() else {
504            panic!("Expected LiteralType");
505        };
506        let Node::Integer(integer) = literal.literal() else {
507            panic!("Expected Integer");
508        };
509
510        let int_loc = integer.location();
511        assert_eq!(11, int_loc.start());
512        assert_eq!(12, int_loc.end());
513    }
514
515    #[test]
516    fn test_sub_locations() {
517        let rbs_code = r#"class Foo < Bar end"#;
518        let signature = parse(rbs_code).unwrap();
519
520        let declaration = signature.declarations().iter().next().unwrap();
521        let Node::Class(class) = declaration else {
522            panic!("Expected Class");
523        };
524
525        // Test required sub-locations
526        let keyword_loc = class.keyword_location();
527        assert_eq!(0, keyword_loc.start());
528        assert_eq!(5, keyword_loc.end());
529
530        let name_loc = class.name_location();
531        assert_eq!(6, name_loc.start());
532        assert_eq!(9, name_loc.end());
533
534        let end_loc = class.end_location();
535        assert_eq!(16, end_loc.start());
536        assert_eq!(19, end_loc.end());
537
538        // Test optional sub-location that's present
539        let lt_loc = class.lt_location();
540        assert!(lt_loc.is_some());
541        let lt = lt_loc.unwrap();
542        assert_eq!(10, lt.start());
543        assert_eq!(11, lt.end());
544
545        // Test optional sub-location that's not present (no type params in this class)
546        let type_params_loc = class.type_params_location();
547        assert!(type_params_loc.is_none());
548    }
549
550    #[test]
551    fn test_type_alias_sub_locations() {
552        let rbs_code = r#"type foo = String"#;
553        let signature = parse(rbs_code).unwrap();
554
555        let declaration = signature.declarations().iter().next().unwrap();
556        let Node::TypeAlias(type_alias) = declaration else {
557            panic!("Expected TypeAlias");
558        };
559
560        // Test required sub-locations
561        let keyword_loc = type_alias.keyword_location();
562        assert_eq!(0, keyword_loc.start());
563        assert_eq!(4, keyword_loc.end());
564
565        let name_loc = type_alias.name_location();
566        assert_eq!(5, name_loc.start());
567        assert_eq!(8, name_loc.end());
568
569        let eq_loc = type_alias.eq_location();
570        assert_eq!(9, eq_loc.start());
571        assert_eq!(10, eq_loc.end());
572
573        // Test optional sub-location that's not present (no type params)
574        let type_params_loc = type_alias.type_params_location();
575        assert!(type_params_loc.is_none());
576    }
577
578    #[test]
579    fn test_module_sub_locations() {
580        let rbs_code = r#"module Foo[T] : Bar end"#;
581        let signature = parse(rbs_code).unwrap();
582
583        let declaration = signature.declarations().iter().next().unwrap();
584        let Node::Module(module) = declaration else {
585            panic!("Expected Module");
586        };
587
588        // Test required sub-locations
589        let keyword_loc = module.keyword_location();
590        assert_eq!(0, keyword_loc.start());
591        assert_eq!(6, keyword_loc.end());
592
593        let name_loc = module.name_location();
594        assert_eq!(7, name_loc.start());
595        assert_eq!(10, name_loc.end());
596
597        let end_loc = module.end_location();
598        assert_eq!(20, end_loc.start());
599        assert_eq!(23, end_loc.end());
600
601        // Test optional sub-locations that are present
602        let type_params_loc = module.type_params_location();
603        assert!(type_params_loc.is_some());
604        let tp = type_params_loc.unwrap();
605        assert_eq!(10, tp.start());
606        assert_eq!(13, tp.end());
607
608        let colon_loc = module.colon_location();
609        assert!(colon_loc.is_some());
610        let colon = colon_loc.unwrap();
611        assert_eq!(14, colon.start());
612        assert_eq!(15, colon.end());
613
614        let self_types_loc = module.self_types_location();
615        assert!(self_types_loc.is_some());
616        let st = self_types_loc.unwrap();
617        assert_eq!(16, st.start());
618        assert_eq!(19, st.end());
619    }
620
621    #[test]
622    fn test_enum_types() {
623        let rbs_code = r#"
624            class Foo
625                attr_reader name: String
626                def self.process: () -> void
627                alias instance_method target_method
628                alias self.singleton_method self.target_method
629            end
630
631            class Bar[out T, in U, V]
632            end
633        "#;
634        let signature = parse(rbs_code).unwrap();
635
636        let declarations: Vec<_> = signature.declarations().iter().collect();
637
638        // Test class Foo
639        let Node::Class(class_foo) = &declarations[0] else {
640            panic!("Expected Class");
641        };
642
643        let members: Vec<_> = class_foo.members().iter().collect();
644
645        // attr_reader - should be instance with unspecified visibility (default)
646        if let Node::AttrReader(attr) = &members[0] {
647            assert_eq!(attr.kind(), AttributeKind::Instance);
648            assert_eq!(attr.visibility(), AttributeVisibility::Unspecified);
649        } else {
650            panic!("Expected AttrReader");
651        }
652
653        // def self.process - should be singleton method with unspecified visibility (default)
654        if let Node::MethodDefinition(method) = &members[1] {
655            assert_eq!(method.kind(), MethodDefinitionKind::Singleton);
656            assert_eq!(method.visibility(), MethodDefinitionVisibility::Unspecified);
657        } else {
658            panic!("Expected MethodDefinition");
659        }
660
661        // alias instance_method
662        if let Node::Alias(alias) = &members[2] {
663            assert_eq!(alias.kind(), AliasKind::Instance);
664        } else {
665            panic!("Expected Alias");
666        }
667
668        // alias self.singleton_method
669        if let Node::Alias(alias) = &members[3] {
670            assert_eq!(alias.kind(), AliasKind::Singleton);
671        } else {
672            panic!("Expected Alias");
673        }
674
675        // Test class Bar with type params
676        let Node::Class(class_bar) = &declarations[1] else {
677            panic!("Expected Class");
678        };
679
680        let type_params: Vec<_> = class_bar.type_params().iter().collect();
681        assert_eq!(type_params.len(), 3);
682
683        // out T - covariant
684        if let Node::TypeParam(param) = &type_params[0] {
685            assert_eq!(param.variance(), TypeParamVariance::Covariant);
686        } else {
687            panic!("Expected TypeParam");
688        }
689
690        // in U - contravariant
691        if let Node::TypeParam(param) = &type_params[1] {
692            assert_eq!(param.variance(), TypeParamVariance::Contravariant);
693        } else {
694            panic!("Expected TypeParam");
695        }
696
697        // V - invariant (default)
698        if let Node::TypeParam(param) = &type_params[2] {
699            assert_eq!(param.variance(), TypeParamVariance::Invariant);
700        } else {
701            panic!("Expected TypeParam");
702        }
703    }
704
705    #[test]
706    fn test_ivar_name_enum() {
707        let rbs_code = r#"
708            class Foo
709                attr_reader name: String
710                attr_accessor age(): Integer
711                attr_writer email(@email): String
712            end
713        "#;
714        let signature = parse(rbs_code).unwrap();
715
716        let Node::Class(class) = signature.declarations().iter().next().unwrap() else {
717            panic!("Expected Class");
718        };
719
720        let members: Vec<_> = class.members().iter().collect();
721
722        // attr_reader name: String - should be Unspecified (inferred as @name)
723        if let Node::AttrReader(attr) = &members[0] {
724            let ivar = attr.ivar_name();
725            assert_eq!(ivar, AttrIvarName::Unspecified);
726        } else {
727            panic!("Expected AttrReader");
728        }
729
730        // attr_accessor age(): Integer - should be Empty (no ivar)
731        if let Node::AttrAccessor(attr) = &members[1] {
732            let ivar = attr.ivar_name();
733            assert_eq!(ivar, AttrIvarName::Empty);
734        } else {
735            panic!("Expected AttrAccessor");
736        }
737
738        // attr_writer email(@email): String - should be Name with constant ID
739        if let Node::AttrWriter(attr) = &members[2] {
740            let ivar = attr.ivar_name();
741            match ivar {
742                AttrIvarName::Name(id) => {
743                    assert!(id > 0, "Expected valid constant ID");
744                }
745                _ => panic!("Expected AttrIvarName::Name, got {:?}", ivar),
746            }
747        } else {
748            panic!("Expected AttrWriter");
749        }
750    }
751}