1include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
2use rbs_encoding_type_t::RBS_ENCODING_UTF_8;
3use ruby_rbs_sys::bindings::*;
4use std::marker::PhantomData;
5use std::ptr::NonNull;
6
7pub fn parse(rbs_code: &str) -> Result<SignatureNode<'_>, String> {
16 unsafe {
17 let start_ptr = rbs_code.as_ptr().cast::<std::os::raw::c_char>();
18 let end_ptr = start_ptr.add(rbs_code.len());
19 let bytes = rbs_code.len() as i32;
20
21 let raw_rbs_string_value = rbs_string_new(start_ptr, end_ptr);
22
23 let encoding_ptr = &rbs_encodings[RBS_ENCODING_UTF_8 as usize] as *const rbs_encoding_t;
24 let parser = rbs_parser_new(raw_rbs_string_value, encoding_ptr, 0, bytes);
25
26 let mut signature: *mut rbs_signature_t = std::ptr::null_mut();
27 let result = rbs_parse_signature(parser, &mut signature);
28
29 let signature_node = SignatureNode {
30 parser: NonNull::new_unchecked(parser),
31 pointer: signature,
32 marker: PhantomData,
33 };
34
35 if result {
36 Ok(signature_node)
37 } else {
38 let error_message = (*parser)
39 .error
40 .as_ref()
41 .filter(|error| !error.message.is_null())
42 .map(|error| {
43 std::ffi::CStr::from_ptr(error.message)
44 .to_string_lossy()
45 .into_owned()
46 })
47 .unwrap_or_else(|| String::from("Failed to parse RBS signature"));
48
49 Err(error_message)
50 }
51 }
52}
53
54impl Drop for SignatureNode<'_> {
55 fn drop(&mut self) {
56 unsafe {
57 rbs_parser_free(self.parser.as_ptr());
58 }
59 }
60}
61
62#[derive(Debug, Clone, PartialEq, Eq)]
64pub enum AttrIvarName {
65 Unspecified,
67 Empty,
69 Name(rbs_constant_id_t),
71}
72
73impl AttrIvarName {
74 #[must_use]
76 pub fn from_raw(raw: rbs_attr_ivar_name_t) -> Self {
77 match raw.tag {
78 rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_UNSPECIFIED => Self::Unspecified,
79 rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_EMPTY => Self::Empty,
80 rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_NAME => Self::Name(raw.name),
81 _ => panic!("Unknown ivar_name_tag: {}", raw.tag),
82 }
83 }
84}
85
86pub struct NodeList<'a> {
87 parser: NonNull<rbs_parser_t>,
88 pointer: *mut rbs_node_list_t,
89 marker: PhantomData<&'a mut rbs_node_list_t>,
90}
91
92impl<'a> NodeList<'a> {
93 #[must_use]
94 pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_node_list_t) -> Self {
95 Self {
96 parser,
97 pointer,
98 marker: PhantomData,
99 }
100 }
101
102 #[must_use]
104 pub fn iter(&self) -> NodeListIter<'a> {
105 NodeListIter {
106 parser: self.parser,
107 current: unsafe { (*self.pointer).head },
108 marker: PhantomData,
109 }
110 }
111}
112
113pub struct NodeListIter<'a> {
114 parser: NonNull<rbs_parser_t>,
115 current: *mut rbs_node_list_node_t,
116 marker: PhantomData<&'a mut rbs_node_list_node_t>,
117}
118
119impl<'a> Iterator for NodeListIter<'a> {
120 type Item = Node<'a>;
121
122 fn next(&mut self) -> Option<Self::Item> {
123 if self.current.is_null() {
124 None
125 } else {
126 let pointer_data = unsafe { *self.current };
127 let node = Node::new(self.parser, pointer_data.node);
128 self.current = pointer_data.next;
129 Some(node)
130 }
131 }
132}
133
134pub struct RBSHash<'a> {
135 parser: NonNull<rbs_parser_t>,
136 pointer: *mut rbs_hash,
137 marker: PhantomData<&'a mut rbs_hash>,
138}
139
140impl<'a> RBSHash<'a> {
141 #[must_use]
142 pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_hash) -> Self {
143 Self {
144 parser,
145 pointer,
146 marker: PhantomData,
147 }
148 }
149
150 #[must_use]
152 pub fn iter(&self) -> RBSHashIter<'a> {
153 RBSHashIter {
154 parser: self.parser,
155 current: unsafe { (*self.pointer).head },
156 marker: PhantomData,
157 }
158 }
159}
160
161pub struct RBSHashIter<'a> {
162 parser: NonNull<rbs_parser_t>,
163 current: *mut rbs_hash_node_t,
164 marker: PhantomData<&'a mut rbs_hash_node_t>,
165}
166
167impl<'a> Iterator for RBSHashIter<'a> {
168 type Item = (Node<'a>, Node<'a>);
169
170 fn next(&mut self) -> Option<Self::Item> {
171 if self.current.is_null() {
172 None
173 } else {
174 let pointer_data = unsafe { *self.current };
175 let key = Node::new(self.parser, pointer_data.key);
176 let value = Node::new(self.parser, pointer_data.value);
177 self.current = pointer_data.next;
178 Some((key, value))
179 }
180 }
181}
182
183pub struct RBSLocationRange {
184 range: rbs_location_range,
185}
186
187impl RBSLocationRange {
188 #[must_use]
189 pub fn new(range: rbs_location_range) -> Self {
190 Self { range }
191 }
192
193 #[must_use]
194 pub fn start(&self) -> i32 {
195 self.range.start_byte
196 }
197
198 #[must_use]
199 pub fn end(&self) -> i32 {
200 self.range.end_byte
201 }
202}
203
204pub struct RBSLocationRangeList<'a> {
205 #[allow(dead_code)]
206 parser: NonNull<rbs_parser_t>,
207 pointer: *mut rbs_location_range_list_t,
208 marker: PhantomData<&'a mut rbs_location_range_list_t>,
209}
210
211impl<'a> RBSLocationRangeList<'a> {
212 #[must_use]
214 pub fn iter(&self) -> RBSLocationRangeListIter {
215 RBSLocationRangeListIter {
216 current: unsafe { (*self.pointer).head },
217 }
218 }
219}
220
221pub struct RBSLocationRangeListIter {
222 current: *mut rbs_location_range_list_node_t,
223}
224
225impl Iterator for RBSLocationRangeListIter {
226 type Item = RBSLocationRange;
227
228 fn next(&mut self) -> Option<Self::Item> {
229 if self.current.is_null() {
230 None
231 } else {
232 let pointer_data = unsafe { *self.current };
233 let range = RBSLocationRange::new(pointer_data.range);
234 self.current = pointer_data.next;
235 Some(range)
236 }
237 }
238}
239
240#[derive(Debug)]
241pub struct RBSString<'a> {
242 pointer: *const rbs_string_t,
243 marker: PhantomData<&'a rbs_string_t>,
244}
245
246impl<'a> RBSString<'a> {
247 #[must_use]
248 pub fn new(pointer: *const rbs_string_t) -> Self {
249 Self {
250 pointer,
251 marker: PhantomData,
252 }
253 }
254
255 #[must_use]
256 #[allow(clippy::unnecessary_cast)]
257 pub fn as_bytes(&self) -> &[u8] {
258 unsafe {
259 let s = *self.pointer;
260 std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize)
261 }
262 }
263
264 #[must_use]
265 pub fn as_str(&self) -> &str {
266 unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
267 }
268}
269
270impl std::fmt::Display for RBSString<'_> {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 f.write_str(self.as_str())
273 }
274}
275
276impl SymbolNode<'_> {
277 #[must_use]
278 pub fn as_bytes(&self) -> &[u8] {
279 unsafe {
280 let constant_ptr = rbs_constant_pool_id_to_constant(
281 &(*self.parser.as_ptr()).constant_pool,
282 (*self.pointer).constant_id,
283 );
284 if constant_ptr.is_null() {
285 panic!("Constant ID for symbol is not present in the pool");
286 }
287
288 let constant = &*constant_ptr;
289 std::slice::from_raw_parts(constant.start, constant.length)
290 }
291 }
292
293 #[must_use]
294 pub fn as_str(&self) -> &str {
295 unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
296 }
297}
298
299impl std::fmt::Display for SymbolNode<'_> {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 f.write_str(self.as_str())
302 }
303}
304
305#[cfg(test)]
306mod tests {
307 use super::*;
308
309 #[test]
310 fn test_parse_error_contains_actual_message() {
311 let rbs_code = "class { end";
312 let result = parse(rbs_code);
313 let error_message = result.unwrap_err();
314 assert_eq!(error_message, "expected one of class/module/constant name");
315 }
316
317 #[test]
318 fn test_parse() {
319 let rbs_code = r#"type foo = "hello""#;
320 let signature = parse(rbs_code);
321 assert!(signature.is_ok(), "Failed to parse RBS signature");
322
323 let rbs_code2 = r#"class Foo end"#;
324 let signature2 = parse(rbs_code2);
325 assert!(signature2.is_ok(), "Failed to parse RBS signature");
326 }
327
328 #[test]
329 fn test_parse_integer() {
330 let rbs_code = r#"type foo = 1"#;
331 let signature = parse(rbs_code);
332 assert!(signature.is_ok(), "Failed to parse RBS signature");
333
334 let signature_node = signature.unwrap();
335 if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap()
336 && let Node::LiteralType(literal) = node.type_()
337 && let Node::Integer(integer) = literal.literal()
338 {
339 assert_eq!(integer.string_representation().as_str(), "1");
340 } else {
341 panic!("No literal type node found");
342 }
343 }
344
345 #[test]
346 fn test_rbs_hash_via_record_type() {
347 let rbs_code = r#"type foo = { name: String, age: Integer }"#;
349 let signature = parse(rbs_code);
350 assert!(signature.is_ok(), "Failed to parse RBS signature");
351
352 let signature_node = signature.unwrap();
353 if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap()
354 && let Node::RecordType(record) = type_alias.type_()
355 {
356 let hash = record.all_fields();
357 let fields: Vec<_> = hash.iter().collect();
358 assert_eq!(fields.len(), 2, "Expected 2 fields in record");
359
360 let mut field_types: Vec<(String, String)> = Vec::new();
362 for (key, value) in &fields {
363 let Node::Symbol(sym) = key else {
364 panic!("Expected Symbol key");
365 };
366 let Node::RecordFieldType(field_type) = value else {
367 panic!("Expected RecordFieldType value");
368 };
369 let Node::ClassInstanceType(class_type) = field_type.type_() else {
370 panic!("Expected ClassInstanceType");
371 };
372
373 let key_name = sym.to_string();
374 let type_name_node = class_type.name();
375 let type_name_sym = type_name_node.name();
376 let type_name = type_name_sym.to_string();
377 field_types.push((key_name, type_name));
378 }
379
380 assert!(
381 field_types.contains(&("name".to_string(), "String".to_string())),
382 "Expected 'name: String'"
383 );
384 assert!(
385 field_types.contains(&("age".to_string(), "Integer".to_string())),
386 "Expected 'age: Integer'"
387 );
388 } else {
389 panic!("Expected TypeAlias with RecordType");
390 }
391 }
392
393 #[test]
394 fn visitor_test() {
395 struct Visitor {
396 visited: Vec<String>,
397 }
398
399 impl Visit for Visitor {
400 fn visit_bool_type_node(&mut self, node: &BoolTypeNode) {
401 self.visited.push("type:bool".to_string());
402
403 crate::node::visit_bool_type_node(self, node);
404 }
405
406 fn visit_class_node(&mut self, node: &ClassNode) {
407 self.visited.push(format!("class:{}", node.name().name()));
408
409 crate::node::visit_class_node(self, node);
410 }
411
412 fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
413 self.visited.push(format!("type:{}", node.name().name()));
414
415 crate::node::visit_class_instance_type_node(self, node);
416 }
417
418 fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
419 self.visited.push(format!("super:{}", node.name().name()));
420
421 crate::node::visit_class_super_node(self, node);
422 }
423
424 fn visit_function_type_node(&mut self, node: &FunctionTypeNode) {
425 let count = node.required_positionals().iter().count();
426 self.visited
427 .push(format!("function:required_positionals:{count}"));
428
429 crate::node::visit_function_type_node(self, node);
430 }
431
432 fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
433 self.visited.push(format!("method:{}", node.name()));
434
435 crate::node::visit_method_definition_node(self, node);
436 }
437
438 fn visit_record_type_node(&mut self, node: &RecordTypeNode) {
439 self.visited.push("record".to_string());
440
441 crate::node::visit_record_type_node(self, node);
442 }
443
444 fn visit_symbol_node(&mut self, node: &SymbolNode) {
445 self.visited.push(format!("symbol:{node}"));
446
447 crate::node::visit_symbol_node(self, node);
448 }
449 }
450
451 let rbs_code = r#"
452 class Foo < Bar
453 def process: ({ name: String, age: Integer }, bool) -> void
454 end
455 "#;
456
457 let signature = parse(rbs_code).unwrap();
458
459 let mut visitor = Visitor {
460 visited: Vec::new(),
461 };
462
463 visitor.visit(&signature.as_node());
464
465 assert_eq!(
466 vec![
467 "class:Foo",
468 "symbol:Foo",
469 "super:Bar",
470 "symbol:Bar",
471 "method:process",
472 "symbol:process",
473 "function:required_positionals:2",
474 "record",
475 "symbol:name",
476 "type:String",
477 "symbol:String",
478 "symbol:age",
479 "type:Integer",
480 "symbol:Integer",
481 "type:bool",
482 ],
483 visitor.visited
484 );
485 }
486
487 #[test]
488 fn test_node_location_ranges() {
489 let rbs_code = r#"type foo = 1"#;
490 let signature = parse(rbs_code).unwrap();
491
492 let declaration = signature.declarations().iter().next().unwrap();
493 let Node::TypeAlias(type_alias) = declaration else {
494 panic!("Expected TypeAlias");
495 };
496
497 let loc = type_alias.location();
499 assert_eq!(0, loc.start());
500 assert_eq!(12, loc.end());
501
502 let Node::LiteralType(literal) = type_alias.type_() else {
504 panic!("Expected LiteralType");
505 };
506 let Node::Integer(integer) = literal.literal() else {
507 panic!("Expected Integer");
508 };
509
510 let int_loc = integer.location();
511 assert_eq!(11, int_loc.start());
512 assert_eq!(12, int_loc.end());
513 }
514
515 #[test]
516 fn test_sub_locations() {
517 let rbs_code = r#"class Foo < Bar end"#;
518 let signature = parse(rbs_code).unwrap();
519
520 let declaration = signature.declarations().iter().next().unwrap();
521 let Node::Class(class) = declaration else {
522 panic!("Expected Class");
523 };
524
525 let keyword_loc = class.keyword_location();
527 assert_eq!(0, keyword_loc.start());
528 assert_eq!(5, keyword_loc.end());
529
530 let name_loc = class.name_location();
531 assert_eq!(6, name_loc.start());
532 assert_eq!(9, name_loc.end());
533
534 let end_loc = class.end_location();
535 assert_eq!(16, end_loc.start());
536 assert_eq!(19, end_loc.end());
537
538 let lt_loc = class.lt_location();
540 assert!(lt_loc.is_some());
541 let lt = lt_loc.unwrap();
542 assert_eq!(10, lt.start());
543 assert_eq!(11, lt.end());
544
545 let type_params_loc = class.type_params_location();
547 assert!(type_params_loc.is_none());
548 }
549
550 #[test]
551 fn test_type_alias_sub_locations() {
552 let rbs_code = r#"type foo = String"#;
553 let signature = parse(rbs_code).unwrap();
554
555 let declaration = signature.declarations().iter().next().unwrap();
556 let Node::TypeAlias(type_alias) = declaration else {
557 panic!("Expected TypeAlias");
558 };
559
560 let keyword_loc = type_alias.keyword_location();
562 assert_eq!(0, keyword_loc.start());
563 assert_eq!(4, keyword_loc.end());
564
565 let name_loc = type_alias.name_location();
566 assert_eq!(5, name_loc.start());
567 assert_eq!(8, name_loc.end());
568
569 let eq_loc = type_alias.eq_location();
570 assert_eq!(9, eq_loc.start());
571 assert_eq!(10, eq_loc.end());
572
573 let type_params_loc = type_alias.type_params_location();
575 assert!(type_params_loc.is_none());
576 }
577
578 #[test]
579 fn test_module_sub_locations() {
580 let rbs_code = r#"module Foo[T] : Bar end"#;
581 let signature = parse(rbs_code).unwrap();
582
583 let declaration = signature.declarations().iter().next().unwrap();
584 let Node::Module(module) = declaration else {
585 panic!("Expected Module");
586 };
587
588 let keyword_loc = module.keyword_location();
590 assert_eq!(0, keyword_loc.start());
591 assert_eq!(6, keyword_loc.end());
592
593 let name_loc = module.name_location();
594 assert_eq!(7, name_loc.start());
595 assert_eq!(10, name_loc.end());
596
597 let end_loc = module.end_location();
598 assert_eq!(20, end_loc.start());
599 assert_eq!(23, end_loc.end());
600
601 let type_params_loc = module.type_params_location();
603 assert!(type_params_loc.is_some());
604 let tp = type_params_loc.unwrap();
605 assert_eq!(10, tp.start());
606 assert_eq!(13, tp.end());
607
608 let colon_loc = module.colon_location();
609 assert!(colon_loc.is_some());
610 let colon = colon_loc.unwrap();
611 assert_eq!(14, colon.start());
612 assert_eq!(15, colon.end());
613
614 let self_types_loc = module.self_types_location();
615 assert!(self_types_loc.is_some());
616 let st = self_types_loc.unwrap();
617 assert_eq!(16, st.start());
618 assert_eq!(19, st.end());
619 }
620
621 #[test]
622 fn test_enum_types() {
623 let rbs_code = r#"
624 class Foo
625 attr_reader name: String
626 def self.process: () -> void
627 alias instance_method target_method
628 alias self.singleton_method self.target_method
629 end
630
631 class Bar[out T, in U, V]
632 end
633 "#;
634 let signature = parse(rbs_code).unwrap();
635
636 let declarations: Vec<_> = signature.declarations().iter().collect();
637
638 let Node::Class(class_foo) = &declarations[0] else {
640 panic!("Expected Class");
641 };
642
643 let members: Vec<_> = class_foo.members().iter().collect();
644
645 if let Node::AttrReader(attr) = &members[0] {
647 assert_eq!(attr.kind(), AttributeKind::Instance);
648 assert_eq!(attr.visibility(), AttributeVisibility::Unspecified);
649 } else {
650 panic!("Expected AttrReader");
651 }
652
653 if let Node::MethodDefinition(method) = &members[1] {
655 assert_eq!(method.kind(), MethodDefinitionKind::Singleton);
656 assert_eq!(method.visibility(), MethodDefinitionVisibility::Unspecified);
657 } else {
658 panic!("Expected MethodDefinition");
659 }
660
661 if let Node::Alias(alias) = &members[2] {
663 assert_eq!(alias.kind(), AliasKind::Instance);
664 } else {
665 panic!("Expected Alias");
666 }
667
668 if let Node::Alias(alias) = &members[3] {
670 assert_eq!(alias.kind(), AliasKind::Singleton);
671 } else {
672 panic!("Expected Alias");
673 }
674
675 let Node::Class(class_bar) = &declarations[1] else {
677 panic!("Expected Class");
678 };
679
680 let type_params: Vec<_> = class_bar.type_params().iter().collect();
681 assert_eq!(type_params.len(), 3);
682
683 if let Node::TypeParam(param) = &type_params[0] {
685 assert_eq!(param.variance(), TypeParamVariance::Covariant);
686 } else {
687 panic!("Expected TypeParam");
688 }
689
690 if let Node::TypeParam(param) = &type_params[1] {
692 assert_eq!(param.variance(), TypeParamVariance::Contravariant);
693 } else {
694 panic!("Expected TypeParam");
695 }
696
697 if let Node::TypeParam(param) = &type_params[2] {
699 assert_eq!(param.variance(), TypeParamVariance::Invariant);
700 } else {
701 panic!("Expected TypeParam");
702 }
703 }
704
705 #[test]
706 fn test_ivar_name_enum() {
707 let rbs_code = r#"
708 class Foo
709 attr_reader name: String
710 attr_accessor age(): Integer
711 attr_writer email(@email): String
712 end
713 "#;
714 let signature = parse(rbs_code).unwrap();
715
716 let Node::Class(class) = signature.declarations().iter().next().unwrap() else {
717 panic!("Expected Class");
718 };
719
720 let members: Vec<_> = class.members().iter().collect();
721
722 if let Node::AttrReader(attr) = &members[0] {
724 let ivar = attr.ivar_name();
725 assert_eq!(ivar, AttrIvarName::Unspecified);
726 } else {
727 panic!("Expected AttrReader");
728 }
729
730 if let Node::AttrAccessor(attr) = &members[1] {
732 let ivar = attr.ivar_name();
733 assert_eq!(ivar, AttrIvarName::Empty);
734 } else {
735 panic!("Expected AttrAccessor");
736 }
737
738 if let Node::AttrWriter(attr) = &members[2] {
740 let ivar = attr.ivar_name();
741 match ivar {
742 AttrIvarName::Name(id) => {
743 assert!(id > 0, "Expected valid constant ID");
744 }
745 _ => panic!("Expected AttrIvarName::Name, got {:?}", ivar),
746 }
747 } else {
748 panic!("Expected AttrWriter");
749 }
750 }
751}