1include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
2use rbs_encoding_type_t::RBS_ENCODING_UTF_8;
3use ruby_rbs_sys::bindings::*;
4use std::marker::PhantomData;
5use std::ptr::NonNull;
6
7pub fn parse(rbs_code: &[u8]) -> Result<SignatureNode<'_>, String> {
16 unsafe {
17 let start_ptr = rbs_code.as_ptr() as *const i8;
18 let end_ptr = start_ptr.add(rbs_code.len());
19
20 let raw_rbs_string_value = rbs_string_new(start_ptr, end_ptr);
21
22 let encoding_ptr = &rbs_encodings[RBS_ENCODING_UTF_8 as usize] as *const rbs_encoding_t;
23 let parser = rbs_parser_new(raw_rbs_string_value, encoding_ptr, 0, rbs_code.len() as i32);
24
25 let mut signature: *mut rbs_signature_t = std::ptr::null_mut();
26 let result = rbs_parse_signature(parser, &mut signature);
27
28 let signature_node = SignatureNode {
29 parser: NonNull::new_unchecked(parser),
30 pointer: signature,
31 marker: PhantomData,
32 };
33
34 if result {
35 Ok(signature_node)
36 } else {
37 Err(String::from("Failed to parse RBS signature"))
38 }
39 }
40}
41
42impl Drop for SignatureNode<'_> {
43 fn drop(&mut self) {
44 unsafe {
45 rbs_parser_free(self.parser.as_ptr());
46 }
47 }
48}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
52pub enum AttrIvarName {
53 Unspecified,
55 Empty,
57 Name(rbs_constant_id_t),
59}
60
61impl AttrIvarName {
62 #[must_use]
64 pub fn from_raw(raw: rbs_attr_ivar_name_t) -> Self {
65 match raw.tag {
66 rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_UNSPECIFIED => Self::Unspecified,
67 rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_EMPTY => Self::Empty,
68 rbs_attr_ivar_name_tag::RBS_ATTR_IVAR_NAME_TAG_NAME => Self::Name(raw.name),
69 _ => panic!("Unknown ivar_name_tag: {}", raw.tag),
70 }
71 }
72}
73
74pub struct NodeList<'a> {
75 parser: NonNull<rbs_parser_t>,
76 pointer: *mut rbs_node_list_t,
77 marker: PhantomData<&'a mut rbs_node_list_t>,
78}
79
80impl<'a> NodeList<'a> {
81 #[must_use]
82 pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_node_list_t) -> Self {
83 Self {
84 parser,
85 pointer,
86 marker: PhantomData,
87 }
88 }
89
90 #[must_use]
92 pub fn iter(&self) -> NodeListIter<'a> {
93 NodeListIter {
94 parser: self.parser,
95 current: unsafe { (*self.pointer).head },
96 marker: PhantomData,
97 }
98 }
99}
100
101pub struct NodeListIter<'a> {
102 parser: NonNull<rbs_parser_t>,
103 current: *mut rbs_node_list_node_t,
104 marker: PhantomData<&'a mut rbs_node_list_node_t>,
105}
106
107impl<'a> Iterator for NodeListIter<'a> {
108 type Item = Node<'a>;
109
110 fn next(&mut self) -> Option<Self::Item> {
111 if self.current.is_null() {
112 None
113 } else {
114 let pointer_data = unsafe { *self.current };
115 let node = Node::new(self.parser, pointer_data.node);
116 self.current = pointer_data.next;
117 Some(node)
118 }
119 }
120}
121
122pub struct RBSHash<'a> {
123 parser: NonNull<rbs_parser_t>,
124 pointer: *mut rbs_hash,
125 marker: PhantomData<&'a mut rbs_hash>,
126}
127
128impl<'a> RBSHash<'a> {
129 #[must_use]
130 pub fn new(parser: NonNull<rbs_parser_t>, pointer: *mut rbs_hash) -> Self {
131 Self {
132 parser,
133 pointer,
134 marker: PhantomData,
135 }
136 }
137
138 #[must_use]
140 pub fn iter(&self) -> RBSHashIter<'a> {
141 RBSHashIter {
142 parser: self.parser,
143 current: unsafe { (*self.pointer).head },
144 marker: PhantomData,
145 }
146 }
147}
148
149pub struct RBSHashIter<'a> {
150 parser: NonNull<rbs_parser_t>,
151 current: *mut rbs_hash_node_t,
152 marker: PhantomData<&'a mut rbs_hash_node_t>,
153}
154
155impl<'a> Iterator for RBSHashIter<'a> {
156 type Item = (Node<'a>, Node<'a>);
157
158 fn next(&mut self) -> Option<Self::Item> {
159 if self.current.is_null() {
160 None
161 } else {
162 let pointer_data = unsafe { *self.current };
163 let key = Node::new(self.parser, pointer_data.key);
164 let value = Node::new(self.parser, pointer_data.value);
165 self.current = pointer_data.next;
166 Some((key, value))
167 }
168 }
169}
170
171pub struct RBSLocationRange {
172 range: rbs_location_range,
173}
174
175impl RBSLocationRange {
176 #[must_use]
177 pub fn new(range: rbs_location_range) -> Self {
178 Self { range }
179 }
180
181 #[must_use]
182 pub fn start(&self) -> i32 {
183 self.range.start_byte
184 }
185
186 #[must_use]
187 pub fn end(&self) -> i32 {
188 self.range.end_byte
189 }
190}
191
192pub struct RBSLocationRangeList<'a> {
193 #[allow(dead_code)]
194 parser: NonNull<rbs_parser_t>,
195 pointer: *mut rbs_location_range_list_t,
196 marker: PhantomData<&'a mut rbs_location_range_list_t>,
197}
198
199impl<'a> RBSLocationRangeList<'a> {
200 #[must_use]
202 pub fn iter(&self) -> RBSLocationRangeListIter {
203 RBSLocationRangeListIter {
204 current: unsafe { (*self.pointer).head },
205 }
206 }
207}
208
209pub struct RBSLocationRangeListIter {
210 current: *mut rbs_location_range_list_node_t,
211}
212
213impl Iterator for RBSLocationRangeListIter {
214 type Item = RBSLocationRange;
215
216 fn next(&mut self) -> Option<Self::Item> {
217 if self.current.is_null() {
218 None
219 } else {
220 let pointer_data = unsafe { *self.current };
221 let range = RBSLocationRange::new(pointer_data.range);
222 self.current = pointer_data.next;
223 Some(range)
224 }
225 }
226}
227
228#[derive(Debug)]
229pub struct RBSString {
230 pointer: *const rbs_string_t,
231}
232
233impl RBSString {
234 #[must_use]
235 pub fn new(pointer: *const rbs_string_t) -> Self {
236 Self { pointer }
237 }
238
239 #[must_use]
240 pub fn as_bytes(&self) -> &[u8] {
241 unsafe {
242 let s = *self.pointer;
243 std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize)
244 }
245 }
246}
247
248impl SymbolNode<'_> {
249 #[must_use]
250 pub fn name(&self) -> &[u8] {
251 unsafe {
252 let constant_ptr = rbs_constant_pool_id_to_constant(
253 &(*self.parser.as_ptr()).constant_pool,
254 (*self.pointer).constant_id,
255 );
256 if constant_ptr.is_null() {
257 panic!("Constant ID for symbol is not present in the pool");
258 }
259
260 let constant = &*constant_ptr;
261 std::slice::from_raw_parts(constant.start, constant.length)
262 }
263 }
264}
265
266#[cfg(test)]
267mod tests {
268 use super::*;
269
270 #[test]
271 fn test_parse() {
272 let rbs_code = r#"type foo = "hello""#;
273 let signature = parse(rbs_code.as_bytes());
274 assert!(signature.is_ok(), "Failed to parse RBS signature");
275
276 let rbs_code2 = r#"class Foo end"#;
277 let signature2 = parse(rbs_code2.as_bytes());
278 assert!(signature2.is_ok(), "Failed to parse RBS signature");
279 }
280
281 #[test]
282 fn test_parse_integer() {
283 let rbs_code = r#"type foo = 1"#;
284 let signature = parse(rbs_code.as_bytes());
285 assert!(signature.is_ok(), "Failed to parse RBS signature");
286
287 let signature_node = signature.unwrap();
288 if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap()
289 && let Node::LiteralType(literal) = node.type_()
290 && let Node::Integer(integer) = literal.literal()
291 {
292 assert_eq!(
293 "1".to_string(),
294 String::from_utf8(integer.string_representation().as_bytes().to_vec()).unwrap()
295 );
296 } else {
297 panic!("No literal type node found");
298 }
299 }
300
301 #[test]
302 fn test_rbs_hash_via_record_type() {
303 let rbs_code = r#"type foo = { name: String, age: Integer }"#;
305 let signature = parse(rbs_code.as_bytes());
306 assert!(signature.is_ok(), "Failed to parse RBS signature");
307
308 let signature_node = signature.unwrap();
309 if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap()
310 && let Node::RecordType(record) = type_alias.type_()
311 {
312 let hash = record.all_fields();
313 let fields: Vec<_> = hash.iter().collect();
314 assert_eq!(fields.len(), 2, "Expected 2 fields in record");
315
316 let mut field_types: Vec<(String, String)> = Vec::new();
318 for (key, value) in &fields {
319 let Node::Symbol(sym) = key else {
320 panic!("Expected Symbol key");
321 };
322 let Node::RecordFieldType(field_type) = value else {
323 panic!("Expected RecordFieldType value");
324 };
325 let Node::ClassInstanceType(class_type) = field_type.type_() else {
326 panic!("Expected ClassInstanceType");
327 };
328
329 let key_name = String::from_utf8(sym.name().to_vec()).unwrap();
330 let type_name_node = class_type.name();
331 let type_name_sym = type_name_node.name();
332 let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap();
333 field_types.push((key_name, type_name));
334 }
335
336 assert!(
337 field_types.contains(&("name".to_string(), "String".to_string())),
338 "Expected 'name: String'"
339 );
340 assert!(
341 field_types.contains(&("age".to_string(), "Integer".to_string())),
342 "Expected 'age: Integer'"
343 );
344 } else {
345 panic!("Expected TypeAlias with RecordType");
346 }
347 }
348
349 #[test]
350 fn visitor_test() {
351 struct Visitor {
352 visited: Vec<String>,
353 }
354
355 impl Visit for Visitor {
356 fn visit_bool_type_node(&mut self, node: &BoolTypeNode) {
357 self.visited.push("type:bool".to_string());
358
359 crate::node::visit_bool_type_node(self, node);
360 }
361
362 fn visit_class_node(&mut self, node: &ClassNode) {
363 self.visited.push(format!(
364 "class:{}",
365 String::from_utf8(node.name().name().name().to_vec()).unwrap()
366 ));
367
368 crate::node::visit_class_node(self, node);
369 }
370
371 fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
372 self.visited.push(format!(
373 "type:{}",
374 String::from_utf8(node.name().name().name().to_vec()).unwrap()
375 ));
376
377 crate::node::visit_class_instance_type_node(self, node);
378 }
379
380 fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
381 self.visited.push(format!(
382 "super:{}",
383 String::from_utf8(node.name().name().name().to_vec()).unwrap()
384 ));
385
386 crate::node::visit_class_super_node(self, node);
387 }
388
389 fn visit_function_type_node(&mut self, node: &FunctionTypeNode) {
390 let count = node.required_positionals().iter().count();
391 self.visited
392 .push(format!("function:required_positionals:{count}"));
393
394 crate::node::visit_function_type_node(self, node);
395 }
396
397 fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
398 self.visited.push(format!(
399 "method:{}",
400 String::from_utf8(node.name().name().to_vec()).unwrap()
401 ));
402
403 crate::node::visit_method_definition_node(self, node);
404 }
405
406 fn visit_record_type_node(&mut self, node: &RecordTypeNode) {
407 self.visited.push("record".to_string());
408
409 crate::node::visit_record_type_node(self, node);
410 }
411
412 fn visit_symbol_node(&mut self, node: &SymbolNode) {
413 self.visited.push(format!(
414 "symbol:{}",
415 String::from_utf8(node.name().to_vec()).unwrap()
416 ));
417
418 crate::node::visit_symbol_node(self, node);
419 }
420 }
421
422 let rbs_code = r#"
423 class Foo < Bar
424 def process: ({ name: String, age: Integer }, bool) -> void
425 end
426 "#;
427
428 let signature = parse(rbs_code.as_bytes()).unwrap();
429
430 let mut visitor = Visitor {
431 visited: Vec::new(),
432 };
433
434 visitor.visit(&signature.as_node());
435
436 assert_eq!(
437 vec![
438 "class:Foo",
439 "symbol:Foo",
440 "super:Bar",
441 "symbol:Bar",
442 "method:process",
443 "symbol:process",
444 "function:required_positionals:2",
445 "record",
446 "symbol:name",
447 "type:String",
448 "symbol:String",
449 "symbol:age",
450 "type:Integer",
451 "symbol:Integer",
452 "type:bool",
453 ],
454 visitor.visited
455 );
456 }
457
458 #[test]
459 fn test_node_location_ranges() {
460 let rbs_code = r#"type foo = 1"#;
461 let signature = parse(rbs_code.as_bytes()).unwrap();
462
463 let declaration = signature.declarations().iter().next().unwrap();
464 let Node::TypeAlias(type_alias) = declaration else {
465 panic!("Expected TypeAlias");
466 };
467
468 let loc = type_alias.location();
470 assert_eq!(0, loc.start());
471 assert_eq!(12, loc.end());
472
473 let Node::LiteralType(literal) = type_alias.type_() else {
475 panic!("Expected LiteralType");
476 };
477 let Node::Integer(integer) = literal.literal() else {
478 panic!("Expected Integer");
479 };
480
481 let int_loc = integer.location();
482 assert_eq!(11, int_loc.start());
483 assert_eq!(12, int_loc.end());
484 }
485
486 #[test]
487 fn test_enum_types() {
488 let rbs_code = r#"
489 class Foo
490 attr_reader name: String
491 def self.process: () -> void
492 alias instance_method target_method
493 alias self.singleton_method self.target_method
494 end
495
496 class Bar[out T, in U, V]
497 end
498 "#;
499 let signature = parse(rbs_code.as_bytes()).unwrap();
500
501 let declarations: Vec<_> = signature.declarations().iter().collect();
502
503 let Node::Class(class_foo) = &declarations[0] else {
505 panic!("Expected Class");
506 };
507
508 let members: Vec<_> = class_foo.members().iter().collect();
509
510 if let Node::AttrReader(attr) = &members[0] {
512 assert_eq!(attr.kind(), AttributeKind::Instance);
513 assert_eq!(attr.visibility(), AttributeVisibility::Unspecified);
514 } else {
515 panic!("Expected AttrReader");
516 }
517
518 if let Node::MethodDefinition(method) = &members[1] {
520 assert_eq!(method.kind(), MethodDefinitionKind::Singleton);
521 assert_eq!(method.visibility(), MethodDefinitionVisibility::Unspecified);
522 } else {
523 panic!("Expected MethodDefinition");
524 }
525
526 if let Node::Alias(alias) = &members[2] {
528 assert_eq!(alias.kind(), AliasKind::Instance);
529 } else {
530 panic!("Expected Alias");
531 }
532
533 if let Node::Alias(alias) = &members[3] {
535 assert_eq!(alias.kind(), AliasKind::Singleton);
536 } else {
537 panic!("Expected Alias");
538 }
539
540 let Node::Class(class_bar) = &declarations[1] else {
542 panic!("Expected Class");
543 };
544
545 let type_params: Vec<_> = class_bar.type_params().iter().collect();
546 assert_eq!(type_params.len(), 3);
547
548 if let Node::TypeParam(param) = &type_params[0] {
550 assert_eq!(param.variance(), TypeParamVariance::Covariant);
551 } else {
552 panic!("Expected TypeParam");
553 }
554
555 if let Node::TypeParam(param) = &type_params[1] {
557 assert_eq!(param.variance(), TypeParamVariance::Contravariant);
558 } else {
559 panic!("Expected TypeParam");
560 }
561
562 if let Node::TypeParam(param) = &type_params[2] {
564 assert_eq!(param.variance(), TypeParamVariance::Invariant);
565 } else {
566 panic!("Expected TypeParam");
567 }
568 }
569
570 #[test]
571 fn test_ivar_name_enum() {
572 let rbs_code = r#"
573 class Foo
574 attr_reader name: String
575 attr_accessor age(): Integer
576 attr_writer email(@email): String
577 end
578 "#;
579 let signature = parse(rbs_code.as_bytes()).unwrap();
580
581 let Node::Class(class) = signature.declarations().iter().next().unwrap() else {
582 panic!("Expected Class");
583 };
584
585 let members: Vec<_> = class.members().iter().collect();
586
587 if let Node::AttrReader(attr) = &members[0] {
589 let ivar = attr.ivar_name();
590 assert_eq!(ivar, AttrIvarName::Unspecified);
591 } else {
592 panic!("Expected AttrReader");
593 }
594
595 if let Node::AttrAccessor(attr) = &members[1] {
597 let ivar = attr.ivar_name();
598 assert_eq!(ivar, AttrIvarName::Empty);
599 } else {
600 panic!("Expected AttrAccessor");
601 }
602
603 if let Node::AttrWriter(attr) = &members[2] {
605 let ivar = attr.ivar_name();
606 match ivar {
607 AttrIvarName::Name(id) => {
608 assert!(id > 0, "Expected valid constant ID");
609 }
610 _ => panic!("Expected AttrIvarName::Name, got {:?}", ivar),
611 }
612 } else {
613 panic!("Expected AttrWriter");
614 }
615 }
616}