// ruby_rbs/node/mod.rs

1include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
2use rbs_encoding_type_t::RBS_ENCODING_UTF_8;
3use ruby_rbs_sys::bindings::*;
4use std::marker::PhantomData;
5use std::ptr::NonNull;
6
7/// Parse RBS code into an AST.
8///
9/// ```rust
10/// use ruby_rbs::node::parse;
11/// let rbs_code = r#"type foo = "hello""#;
12/// let signature = parse(rbs_code.as_bytes());
13/// assert!(signature.is_ok(), "Failed to parse RBS signature");
14/// ```
15pub fn parse(rbs_code: &[u8]) -> Result<SignatureNode<'_>, String> {
16    unsafe {
17        let start_ptr = rbs_code.as_ptr() as *const i8;
18        let end_ptr = start_ptr.add(rbs_code.len());
19
20        let raw_rbs_string_value = rbs_string_new(start_ptr, end_ptr);
21
22        let encoding_ptr = &rbs_encodings[RBS_ENCODING_UTF_8 as usize] as *const rbs_encoding_t;
23        let parser = rbs_parser_new(raw_rbs_string_value, encoding_ptr, 0, rbs_code.len() as i32);
24
25        let mut signature: *mut rbs_signature_t = std::ptr::null_mut();
26        let result = rbs_parse_signature(parser, &mut signature);
27
28        let signature_node = SignatureNode {
29            parser: NonNull::new_unchecked(parser),
30            pointer: signature,
31            marker: PhantomData,
32        };
33
34        if result {
35            Ok(signature_node)
36        } else {
37            Err(String::from("Failed to parse RBS signature"))
38        }
39    }
40}
41
42impl Drop for SignatureNode<'_> {
43    fn drop(&mut self) {
44        unsafe {
45            rbs_parser_free(self.parser.as_ptr());
46        }
47    }
48}
49
50/// Instance variable name specification for attributes.
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub enum AttrIvarName {
53    /// The attribute has inferred instance variable (nil)
54    Unspecified,
55    /// The attribute has no instance variable (false)
56    Empty,
57    /// The attribute has instance variable with the given name
58    Name(rbs_constant_id_t),
59}
60
61impl AttrIvarName {
62    /// Converts the raw C struct to the Rust enum.
63    #[must_use]
64    pub fn from_raw(raw: rbs_attr_ivar_name_t) -> Self {
65        match raw.tag {
66            rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_UNSPECIFIED => Self::Unspecified,
67            rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_EMPTY => Self::Empty,
68            rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_NAME => Self::Name(raw.name),
69            _ => panic!("Unknown ivar_name_tag: {}", raw.tag),
70        }
71    }
72}
73
74pub struct NodeList<'a> {
75    parser: NonNull<rbs_parser_t>,
76    pointer: *mut rbs_node_list_t,
77    marker: PhantomData<&'a mut rbs_node_list_t>,
78}
79
80impl<'a> NodeList<'a> {
81    #[must_use]
82    pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_node_list_t) -> Self {
83        Self {
84            parser,
85            pointer,
86            marker: PhantomData,
87        }
88    }
89
90    /// Returns an iterator over the nodes.
91    #[must_use]
92    pub fn iter(&self) -> NodeListIter<'a> {
93        NodeListIter {
94            parser: self.parser,
95            current: unsafe { (*self.pointer).head },
96            marker: PhantomData,
97        }
98    }
99}
100
101pub struct NodeListIter<'a> {
102    parser: NonNull<rbs_parser_t>,
103    current: *mut rbs_node_list_node_t,
104    marker: PhantomData<&'a mut rbs_node_list_node_t>,
105}
106
107impl<'a> Iterator for NodeListIter<'a> {
108    type Item = Node<'a>;
109
110    fn next(&mut self) -> Option<Self::Item> {
111        if self.current.is_null() {
112            None
113        } else {
114            let pointer_data = unsafe { *self.current };
115            let node = Node::new(self.parser, pointer_data.node);
116            self.current = pointer_data.next;
117            Some(node)
118        }
119    }
120}
121
122pub struct RBSHash<'a> {
123    parser: NonNull<rbs_parser_t>,
124    pointer: *mut rbs_hash,
125    marker: PhantomData<&'a mut rbs_hash>,
126}
127
128impl<'a> RBSHash<'a> {
129    #[must_use]
130    pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_hash) -> Self {
131        Self {
132            parser,
133            pointer,
134            marker: PhantomData,
135        }
136    }
137
138    /// Returns an iterator over the key-value pairs.
139    #[must_use]
140    pub fn iter(&self) -> RBSHashIter<'a> {
141        RBSHashIter {
142            parser: self.parser,
143            current: unsafe { (*self.pointer).head },
144            marker: PhantomData,
145        }
146    }
147}
148
149pub struct RBSHashIter<'a> {
150    parser: NonNull<rbs_parser_t>,
151    current: *mut rbs_hash_node_t,
152    marker: PhantomData<&'a mut rbs_hash_node_t>,
153}
154
155impl<'a> Iterator for RBSHashIter<'a> {
156    type Item = (Node<'a>, Node<'a>);
157
158    fn next(&mut self) -> Option<Self::Item> {
159        if self.current.is_null() {
160            None
161        } else {
162            let pointer_data = unsafe { *self.current };
163            let key = Node::new(self.parser, pointer_data.key);
164            let value = Node::new(self.parser, pointer_data.value);
165            self.current = pointer_data.next;
166            Some((key, value))
167        }
168    }
169}
170
171pub struct RBSLocationRange {
172    range: rbs_location_range,
173}
174
175impl RBSLocationRange {
176    #[must_use]
177    pub fn new(range: rbs_location_range) -> Self {
178        Self { range }
179    }
180
181    #[must_use]
182    pub fn start(&self) -> i32 {
183        self.range.start_byte
184    }
185
186    #[must_use]
187    pub fn end(&self) -> i32 {
188        self.range.end_byte
189    }
190}
191
192pub struct RBSLocationRangeList<'a> {
193    #[allow(dead_code)]
194    parser: NonNull<rbs_parser_t>,
195    pointer: *mut rbs_location_range_list_t,
196    marker: PhantomData<&'a mut rbs_location_range_list_t>,
197}
198
199impl<'a> RBSLocationRangeList<'a> {
200    /// Returns an iterator over the location ranges.
201    #[must_use]
202    pub fn iter(&self) -> RBSLocationRangeListIter {
203        RBSLocationRangeListIter {
204            current: unsafe { (*self.pointer).head },
205        }
206    }
207}
208
209pub struct RBSLocationRangeListIter {
210    current: *mut rbs_location_range_list_node_t,
211}
212
213impl Iterator for RBSLocationRangeListIter {
214    type Item = RBSLocationRange;
215
216    fn next(&mut self) -> Option<Self::Item> {
217        if self.current.is_null() {
218            None
219        } else {
220            let pointer_data = unsafe { *self.current };
221            let range = RBSLocationRange::new(pointer_data.range);
222            self.current = pointer_data.next;
223            Some(range)
224        }
225    }
226}
227
228#[derive(Debug)]
229pub struct RBSString {
230    pointer: *const rbs_string_t,
231}
232
233impl RBSString {
234    #[must_use]
235    pub fn new(pointer: *const rbs_string_t) -> Self {
236        Self { pointer }
237    }
238
239    #[must_use]
240    pub fn as_bytes(&self) -> &[u8] {
241        unsafe {
242            let s = *self.pointer;
243            std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize)
244        }
245    }
246}
247
248impl SymbolNode<'_> {
249    #[must_use]
250    pub fn name(&self) -> &[u8] {
251        unsafe {
252            let constant_ptr = rbs_constant_pool_id_to_constant(
253                &(*self.parser.as_ptr()).constant_pool,
254                (*self.pointer).constant_id,
255            );
256            if constant_ptr.is_null() {
257                panic!("Constant ID for symbol is not present in the pool");
258            }
259
260            let constant = &*constant_ptr;
261            std::slice::from_raw_parts(constant.start, constant.length)
262        }
263    }
264}
265
266#[cfg(test)]
267mod tests {
268    use super::*;
269
270    #[test]
271    fn test_parse() {
272        let rbs_code = r#"type foo = "hello""#;
273        let signature = parse(rbs_code.as_bytes());
274        assert!(signature.is_ok(), "Failed to parse RBS signature");
275
276        let rbs_code2 = r#"class Foo end"#;
277        let signature2 = parse(rbs_code2.as_bytes());
278        assert!(signature2.is_ok(), "Failed to parse RBS signature");
279    }
280
281    #[test]
282    fn test_parse_integer() {
283        let rbs_code = r#"type foo = 1"#;
284        let signature = parse(rbs_code.as_bytes());
285        assert!(signature.is_ok(), "Failed to parse RBS signature");
286
287        let signature_node = signature.unwrap();
288        if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap()
289            && let Node::LiteralType(literal) = node.type_()
290            && let Node::Integer(integer) = literal.literal()
291        {
292            assert_eq!(
293                "1".to_string(),
294                String::from_utf8(integer.string_representation().as_bytes().to_vec()).unwrap()
295            );
296        } else {
297            panic!("No literal type node found");
298        }
299    }
300
301    #[test]
302    fn test_rbs_hash_via_record_type() {
303        // RecordType stores its fields in an RBSHash via all_fields()
304        let rbs_code = r#"type foo = { name: String, age: Integer }"#;
305        let signature = parse(rbs_code.as_bytes());
306        assert!(signature.is_ok(), "Failed to parse RBS signature");
307
308        let signature_node = signature.unwrap();
309        if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap()
310            && let Node::RecordType(record) = type_alias.type_()
311        {
312            let hash = record.all_fields();
313            let fields: Vec<_> = hash.iter().collect();
314            assert_eq!(fields.len(), 2, "Expected 2 fields in record");
315
316            // Build a map of field names to type names
317            let mut field_types: Vec<(String, String)> = Vec::new();
318            for (key, value) in &fields {
319                let Node::Symbol(sym) = key else {
320                    panic!("Expected Symbol key");
321                };
322                let Node::RecordFieldType(field_type) = value else {
323                    panic!("Expected RecordFieldType value");
324                };
325                let Node::ClassInstanceType(class_type) = field_type.type_() else {
326                    panic!("Expected ClassInstanceType");
327                };
328
329                let key_name = String::from_utf8(sym.name().to_vec()).unwrap();
330                let type_name_node = class_type.name();
331                let type_name_sym = type_name_node.name();
332                let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap();
333                field_types.push((key_name, type_name));
334            }
335
336            assert!(
337                field_types.contains(&("name".to_string(), "String".to_string())),
338                "Expected 'name: String'"
339            );
340            assert!(
341                field_types.contains(&("age".to_string(), "Integer".to_string())),
342                "Expected 'age: Integer'"
343            );
344        } else {
345            panic!("Expected TypeAlias with RecordType");
346        }
347    }
348
349    #[test]
350    fn visitor_test() {
351        struct Visitor {
352            visited: Vec<String>,
353        }
354
355        impl Visit for Visitor {
356            fn visit_bool_type_node(&mut self, node: &BoolTypeNode) {
357                self.visited.push("type:bool".to_string());
358
359                crate::node::visit_bool_type_node(self, node);
360            }
361
362            fn visit_class_node(&mut self, node: &ClassNode) {
363                self.visited.push(format!(
364                    "class:{}",
365                    String::from_utf8(node.name().name().name().to_vec()).unwrap()
366                ));
367
368                crate::node::visit_class_node(self, node);
369            }
370
371            fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
372                self.visited.push(format!(
373                    "type:{}",
374                    String::from_utf8(node.name().name().name().to_vec()).unwrap()
375                ));
376
377                crate::node::visit_class_instance_type_node(self, node);
378            }
379
380            fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
381                self.visited.push(format!(
382                    "super:{}",
383                    String::from_utf8(node.name().name().name().to_vec()).unwrap()
384                ));
385
386                crate::node::visit_class_super_node(self, node);
387            }
388
389            fn visit_function_type_node(&mut self, node: &FunctionTypeNode) {
390                let count = node.required_positionals().iter().count();
391                self.visited
392                    .push(format!("function:required_positionals:{count}"));
393
394                crate::node::visit_function_type_node(self, node);
395            }
396
397            fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
398                self.visited.push(format!(
399                    "method:{}",
400                    String::from_utf8(node.name().name().to_vec()).unwrap()
401                ));
402
403                crate::node::visit_method_definition_node(self, node);
404            }
405
406            fn visit_record_type_node(&mut self, node: &RecordTypeNode) {
407                self.visited.push("record".to_string());
408
409                crate::node::visit_record_type_node(self, node);
410            }
411
412            fn visit_symbol_node(&mut self, node: &SymbolNode) {
413                self.visited.push(format!(
414                    "symbol:{}",
415                    String::from_utf8(node.name().to_vec()).unwrap()
416                ));
417
418                crate::node::visit_symbol_node(self, node);
419            }
420        }
421
422        let rbs_code = r#"
423            class Foo < Bar
424                def process: ({ name: String, age: Integer }, bool) -> void
425            end
426        "#;
427
428        let signature = parse(rbs_code.as_bytes()).unwrap();
429
430        let mut visitor = Visitor {
431            visited: Vec::new(),
432        };
433
434        visitor.visit(&signature.as_node());
435
436        assert_eq!(
437            vec![
438                "class:Foo",
439                "symbol:Foo",
440                "super:Bar",
441                "symbol:Bar",
442                "method:process",
443                "symbol:process",
444                "function:required_positionals:2",
445                "record",
446                "symbol:name",
447                "type:String",
448                "symbol:String",
449                "symbol:age",
450                "type:Integer",
451                "symbol:Integer",
452                "type:bool",
453            ],
454            visitor.visited
455        );
456    }
457
458    #[test]
459    fn test_node_location_ranges() {
460        let rbs_code = r#"type foo = 1"#;
461        let signature = parse(rbs_code.as_bytes()).unwrap();
462
463        let declaration = signature.declarations().iter().next().unwrap();
464        let Node::TypeAlias(type_alias) = declaration else {
465            panic!("Expected TypeAlias");
466        };
467
468        // TypeAlias spans the entire declaration
469        let loc = type_alias.location();
470        assert_eq!(0, loc.start());
471        assert_eq!(12, loc.end());
472
473        // The literal "1" is at position 11-12
474        let Node::LiteralType(literal) = type_alias.type_() else {
475            panic!("Expected LiteralType");
476        };
477        let Node::Integer(integer) = literal.literal() else {
478            panic!("Expected Integer");
479        };
480
481        let int_loc = integer.location();
482        assert_eq!(11, int_loc.start());
483        assert_eq!(12, int_loc.end());
484    }
485
486    #[test]
487    fn test_enum_types() {
488        let rbs_code = r#"
489            class Foo
490                attr_reader name: String
491                def self.process: () -> void
492                alias instance_method target_method
493                alias self.singleton_method self.target_method
494            end
495
496            class Bar[out T, in U, V]
497            end
498        "#;
499        let signature = parse(rbs_code.as_bytes()).unwrap();
500
501        let declarations: Vec<_> = signature.declarations().iter().collect();
502
503        // Test class Foo
504        let Node::Class(class_foo) = &declarations[0] else {
505            panic!("Expected Class");
506        };
507
508        let members: Vec<_> = class_foo.members().iter().collect();
509
510        // attr_reader - should be instance with unspecified visibility (default)
511        if let Node::AttrReader(attr) = &members[0] {
512            assert_eq!(attr.kind(), AttributeKind::Instance);
513            assert_eq!(attr.visibility(), AttributeVisibility::Unspecified);
514        } else {
515            panic!("Expected AttrReader");
516        }
517
518        // def self.process - should be singleton method with unspecified visibility (default)
519        if let Node::MethodDefinition(method) = &members[1] {
520            assert_eq!(method.kind(), MethodDefinitionKind::Singleton);
521            assert_eq!(method.visibility(), MethodDefinitionVisibility::Unspecified);
522        } else {
523            panic!("Expected MethodDefinition");
524        }
525
526        // alias instance_method
527        if let Node::Alias(alias) = &members[2] {
528            assert_eq!(alias.kind(), AliasKind::Instance);
529        } else {
530            panic!("Expected Alias");
531        }
532
533        // alias self.singleton_method
534        if let Node::Alias(alias) = &members[3] {
535            assert_eq!(alias.kind(), AliasKind::Singleton);
536        } else {
537            panic!("Expected Alias");
538        }
539
540        // Test class Bar with type params
541        let Node::Class(class_bar) = &declarations[1] else {
542            panic!("Expected Class");
543        };
544
545        let type_params: Vec<_> = class_bar.type_params().iter().collect();
546        assert_eq!(type_params.len(), 3);
547
548        // out T - covariant
549        if let Node::TypeParam(param) = &type_params[0] {
550            assert_eq!(param.variance(), TypeParamVariance::Covariant);
551        } else {
552            panic!("Expected TypeParam");
553        }
554
555        // in U - contravariant
556        if let Node::TypeParam(param) = &type_params[1] {
557            assert_eq!(param.variance(), TypeParamVariance::Contravariant);
558        } else {
559            panic!("Expected TypeParam");
560        }
561
562        // V - invariant (default)
563        if let Node::TypeParam(param) = &type_params[2] {
564            assert_eq!(param.variance(), TypeParamVariance::Invariant);
565        } else {
566            panic!("Expected TypeParam");
567        }
568    }
569
570    #[test]
571    fn test_ivar_name_enum() {
572        let rbs_code = r#"
573            class Foo
574                attr_reader name: String
575                attr_accessor age(): Integer
576                attr_writer email(@email): String
577            end
578        "#;
579        let signature = parse(rbs_code.as_bytes()).unwrap();
580
581        let Node::Class(class) = signature.declarations().iter().next().unwrap() else {
582            panic!("Expected Class");
583        };
584
585        let members: Vec<_> = class.members().iter().collect();
586
587        // attr_reader name: String - should be Unspecified (inferred as @name)
588        if let Node::AttrReader(attr) = &members[0] {
589            let ivar = attr.ivar_name();
590            assert_eq!(ivar, AttrIvarName::Unspecified);
591        } else {
592            panic!("Expected AttrReader");
593        }
594
595        // attr_accessor age(): Integer - should be Empty (no ivar)
596        if let Node::AttrAccessor(attr) = &members[1] {
597            let ivar = attr.ivar_name();
598            assert_eq!(ivar, AttrIvarName::Empty);
599        } else {
600            panic!("Expected AttrAccessor");
601        }
602
603        // attr_writer email(@email): String - should be Name with constant ID
604        if let Node::AttrWriter(attr) = &members[2] {
605            let ivar = attr.ivar_name();
606            match ivar {
607                AttrIvarName::Name(id) => {
608                    assert!(id > 0, "Expected valid constant ID");
609                }
610                _ => panic!("Expected AttrIvarName::Name, got {:?}", ivar),
611            }
612        } else {
613            panic!("Expected AttrWriter");
614        }
615    }
616}