ruby_prism/
lib.rs

1//! # ruby-prism
2//!
3//! Rustified version of Ruby's prism parser.
4//!
5#![warn(clippy::all, clippy::nursery, clippy::pedantic, future_incompatible, missing_docs, nonstandard_style, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unreachable_pub, unused_qualifications)]
6
// Most of the code in this module is generated, so sometimes it contains code
// that doesn't follow the clippy rules. We don't want to see those warnings,
// which is why the two lints below are allowed for the whole module.
#[allow(clippy::too_many_lines, clippy::use_self)]
mod bindings {
    // In `build.rs`, we generate bindings based on the config.yml file. Here is
    // where we pull in those bindings and make them part of our library. The
    // generated file is read from the directory named by the `OUT_DIR`
    // environment variable at compile time.
    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
/// A range in the source file.
///
/// The range is stored as a pair of raw pointers together with the parser the
/// range was created for.
pub struct Location<'pr> {
    /// The parser this location belongs to. `join` compares it to decide
    /// whether two locations are compatible, and the offset accessors read its
    /// `start` pointer.
    parser: NonNull<pm_parser_t>,
    /// Pointer to the first byte of the range.
    pub(crate) start: *const u8,
    /// Pointer one past the last byte of the range.
    pub(crate) end: *const u8,
    /// Ties the location to the `'pr` lifetime of the bytes it points at.
    marker: PhantomData<&'pr [u8]>,
}
31
32impl<'pr> Location<'pr> {
33    /// Returns a byte slice for the range.
34    #[must_use]
35    pub fn as_slice(&self) -> &'pr [u8] {
36        unsafe {
37            let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
38            std::slice::from_raw_parts(self.start, len)
39        }
40    }
41
42    /// Return a Location from the given `pm_location_t`.
43    #[must_use]
44    pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Location<'pr> {
45        Location {
46            parser,
47            start: loc.start,
48            end: loc.end,
49            marker: PhantomData,
50        }
51    }
52
53    /// Return a Location starting at self and ending at the end of other.
54    /// Returns None if both locations did not originate from the same parser,
55    /// or if self starts after other.
56    #[must_use]
57    pub fn join(&self, other: &Location<'pr>) -> Option<Location<'pr>> {
58        if self.parser != other.parser || self.start > other.start {
59            None
60        } else {
61            Some(Location {
62                parser: self.parser,
63                start: self.start,
64                end: other.end,
65                marker: PhantomData,
66            })
67        }
68    }
69
70    /// Return the start offset from the beginning of the parsed source.
71    #[must_use]
72    pub fn start_offset(&self) -> usize {
73        unsafe {
74            let parser_start = (*self.parser.as_ptr()).start;
75            usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
76        }
77    }
78
79    /// Return the end offset from the beginning of the parsed source.
80    #[must_use]
81    pub fn end_offset(&self) -> usize {
82        unsafe {
83            let parser_start = (*self.parser.as_ptr()).start;
84            usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
85        }
86    }
87}
88
89impl std::fmt::Debug for Location<'_> {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        let slice: &[u8] = self.as_slice();
92
93        let mut visible = String::new();
94        visible.push('"');
95
96        for &byte in slice {
97            let part: Vec<u8> = std::ascii::escape_default(byte).collect();
98            visible.push_str(std::str::from_utf8(&part).unwrap());
99        }
100
101        visible.push('"');
102        write!(f, "{visible}")
103    }
104}
105
/// An iterator over the nodes in a list.
///
/// Created by `NodeList::iter`.
pub struct NodeListIter<'pr> {
    /// The parser handed to every `Node` this iterator yields.
    parser: NonNull<pm_parser_t>,
    /// The underlying `pm_node_list` being walked.
    pointer: NonNull<pm_node_list>,
    /// Index of the next node to yield.
    index: usize,
    /// Ties the iterator to the `'pr` lifetime.
    marker: PhantomData<&'pr mut pm_node_list>,
}
113
114impl<'pr> Iterator for NodeListIter<'pr> {
115    type Item = Node<'pr>;
116
117    fn next(&mut self) -> Option<Self::Item> {
118        if self.index >= unsafe { self.pointer.as_ref().size } {
119            None
120        } else {
121            let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
122            self.index += 1;
123            Some(Node::new(self.parser, node))
124        }
125    }
126}
127
/// A list of nodes.
///
/// Use `iter` to walk the nodes it contains.
pub struct NodeList<'pr> {
    /// The parser passed on to the iterator (and from there to each node).
    parser: NonNull<pm_parser_t>,
    /// The underlying `pm_node_list`.
    pointer: NonNull<pm_node_list>,
    /// Ties the list to the `'pr` lifetime.
    marker: PhantomData<&'pr mut pm_node_list>,
}
134
135impl<'pr> NodeList<'pr> {
136    /// Returns an iterator over the nodes.
137    #[must_use]
138    pub fn iter(&self) -> NodeListIter<'pr> {
139        NodeListIter {
140            parser: self.parser,
141            pointer: self.pointer,
142            index: 0,
143            marker: PhantomData,
144        }
145    }
146}
147
148impl std::fmt::Debug for NodeList<'_> {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
151    }
152}
153
/// A handle for a constant ID.
///
/// The bytes of the constant are looked up in the parser's constant pool by
/// `as_slice`.
pub struct ConstantId<'pr> {
    /// The parser whose `constant_pool` the ID is resolved against.
    parser: NonNull<pm_parser_t>,
    /// The raw constant ID.
    id: pm_constant_id_t,
    /// Ties the handle to the `'pr` lifetime.
    marker: PhantomData<&'pr mut pm_constant_id_t>,
}
160
161impl<'pr> ConstantId<'pr> {
162    fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
163        ConstantId { parser, id, marker: PhantomData }
164    }
165
166    /// Returns a byte slice for the constant ID.
167    ///
168    /// # Panics
169    ///
170    /// Panics if the constant ID is not found in the constant pool.
171    #[must_use]
172    pub fn as_slice(&self) -> &'pr [u8] {
173        unsafe {
174            let pool = &(*self.parser.as_ptr()).constant_pool;
175            let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
176            std::slice::from_raw_parts(constant.start, constant.length)
177        }
178    }
179}
180
181impl std::fmt::Debug for ConstantId<'_> {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(f, "{:?}", self.id)
184    }
185}
186
/// An iterator over the constants in a list.
///
/// Created by `ConstantList::iter`.
pub struct ConstantListIter<'pr> {
    /// The parser handed to every `ConstantId` this iterator yields.
    parser: NonNull<pm_parser_t>,
    /// The underlying `pm_constant_id_list_t` being walked.
    pointer: NonNull<pm_constant_id_list_t>,
    /// Index of the next constant ID to yield.
    index: usize,
    /// Ties the iterator to the `'pr` lifetime.
    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
}
194
195impl<'pr> Iterator for ConstantListIter<'pr> {
196    type Item = ConstantId<'pr>;
197
198    fn next(&mut self) -> Option<Self::Item> {
199        if self.index >= unsafe { self.pointer.as_ref().size } {
200            None
201        } else {
202            let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
203            self.index += 1;
204            Some(ConstantId::new(self.parser, constant_id))
205        }
206    }
207}
208
/// A list of constants.
///
/// Use `iter` to walk the constant IDs it contains.
pub struct ConstantList<'pr> {
    /// The raw pointer to the parser where this list came from. It is passed
    /// on to the iterator and from there to each `ConstantId`.
    parser: NonNull<pm_parser_t>,

    /// The raw pointer to the list allocated by prism.
    pointer: NonNull<pm_constant_id_list_t>,

    /// The marker to indicate the lifetime of the pointer.
    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
}
220
221impl<'pr> ConstantList<'pr> {
222    /// Returns an iterator over the constants in the list.
223    #[must_use]
224    pub fn iter(&self) -> ConstantListIter<'pr> {
225        ConstantListIter {
226            parser: self.parser,
227            pointer: self.pointer,
228            index: 0,
229            marker: PhantomData,
230        }
231    }
232}
233
234impl std::fmt::Debug for ConstantList<'_> {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
237    }
238}
239
240/// A handle for an arbitarily-sized integer.
241pub struct Integer<'pr> {
242    /// The raw pointer to the integer allocated by prism.
243    pointer: *const pm_integer_t,
244
245    /// The marker to indicate the lifetime of the pointer.
246    marker: PhantomData<&'pr mut pm_constant_id_t>,
247}
248
249impl<'pr> Integer<'pr> {
250    fn new(pointer: *const pm_integer_t) -> Self {
251        Integer { pointer, marker: PhantomData }
252    }
253}
254
255impl std::fmt::Debug for Integer<'_> {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        write!(f, "{:?}", self.pointer)
258    }
259}
260
261impl TryInto<i32> for Integer<'_> {
262    type Error = ();
263
264    fn try_into(self) -> Result<i32, Self::Error> {
265        let negative = unsafe { (*self.pointer).negative };
266        let length = unsafe { (*self.pointer).length };
267
268        if length == 0 {
269            i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
270        } else {
271            Err(())
272        }
273    }
274}
275
/// A diagnostic message that came back from the parser.
///
/// Exposes the message text and the source location of the diagnostic.
#[derive(Debug)]
pub struct Diagnostic<'pr> {
    /// The underlying `pm_diagnostic_t` whose message and location are read.
    diagnostic: NonNull<pm_diagnostic_t>,
    /// The parser passed along to the `Location` built by `location`.
    parser: NonNull<pm_parser_t>,
    /// Ties the handle to the `'pr` lifetime.
    marker: PhantomData<&'pr pm_diagnostic_t>,
}
283
284impl<'pr> Diagnostic<'pr> {
285    /// Returns the message associated with the diagnostic.
286    ///
287    /// # Panics
288    ///
289    /// Panics if the message is not able to be converted into a `CStr`.
290    ///
291    #[must_use]
292    pub fn message(&self) -> &str {
293        unsafe {
294            let message: *mut c_char = self.diagnostic.as_ref().message.cast_mut();
295            CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
296        }
297    }
298
299    /// The location of the diagnostic in the source.
300    #[must_use]
301    pub fn location(&self) -> Location<'pr> {
302        Location::new(self.parser, unsafe { &self.diagnostic.as_ref().location })
303    }
304}
305
/// A comment that was found during parsing.
///
/// Exposes the comment's bytes and its location in the source.
#[derive(Debug)]
pub struct Comment<'pr> {
    /// The underlying `pm_comment_t` whose location is read.
    comment: NonNull<pm_comment_t>,
    /// The parser passed along to the `Location` built by `location`.
    parser: NonNull<pm_parser_t>,
    /// Ties the handle to the `'pr` lifetime.
    marker: PhantomData<&'pr pm_comment_t>,
}
313
314impl<'pr> Comment<'pr> {
315    /// Returns the text of the comment.
316    ///
317    /// # Panics
318    /// Panics if the end offset is not greater than the start offset.
319    #[must_use]
320    pub fn text(&self) -> &[u8] {
321        self.location().as_slice()
322    }
323
324    /// The location of the comment in the source.
325    #[must_use]
326    pub fn location(&self) -> Location<'pr> {
327        Location::new(self.parser, unsafe { &self.comment.as_ref().location })
328    }
329}
330
/// A magic comment that was found during parsing.
///
/// Exposes the comment's key and value as byte slices.
#[derive(Debug)]
pub struct MagicComment<'pr> {
    /// The underlying `pm_magic_comment_t` whose key and value are read.
    comment: NonNull<pm_magic_comment_t>,
    /// Ties the handle to the `'pr` lifetime.
    marker: PhantomData<&'pr pm_magic_comment_t>,
}
337
338impl<'pr> MagicComment<'pr> {
339    /// Returns the text of the comment's key.
340    #[must_use]
341    pub fn key(&self) -> &[u8] {
342        unsafe {
343            let start = self.comment.as_ref().key_start;
344            let len = self.comment.as_ref().key_length as usize;
345            std::slice::from_raw_parts(start, len)
346        }
347    }
348
349    /// Returns the text of the comment's value.
350    #[must_use]
351    pub fn value(&self) -> &[u8] {
352        unsafe {
353            let start = self.comment.as_ref().value_start;
354            let len = self.comment.as_ref().value_length as usize;
355            std::slice::from_raw_parts(start, len)
356        }
357    }
358}
359
/// A struct created by the `errors` or `warnings` methods on `ParseResult`. It
/// can be used to iterate over the diagnostics in the parse result.
pub struct Diagnostics<'pr> {
    /// The next diagnostic to yield; null once the list is exhausted.
    diagnostic: *mut pm_diagnostic_t,
    /// The parser handed to every `Diagnostic` this iterator yields.
    parser: NonNull<pm_parser_t>,
    /// Ties the iterator to the `'pr` lifetime.
    marker: PhantomData<&'pr pm_diagnostic_t>,
}
367
368impl<'pr> Iterator for Diagnostics<'pr> {
369    type Item = Diagnostic<'pr>;
370
371    fn next(&mut self) -> Option<Self::Item> {
372        if let Some(diagnostic) = NonNull::new(self.diagnostic) {
373            let current = Diagnostic { diagnostic, parser: self.parser, marker: PhantomData };
374            self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
375            Some(current)
376        } else {
377            None
378        }
379    }
380}
381
/// A struct created by the `comments` method on `ParseResult`. It can be used
/// to iterate over the comments in the parse result.
pub struct Comments<'pr> {
    /// The next comment to yield; null once the list is exhausted.
    comment: *mut pm_comment_t,
    /// The parser handed to every `Comment` this iterator yields.
    parser: NonNull<pm_parser_t>,
    /// Ties the iterator to the `'pr` lifetime.
    marker: PhantomData<&'pr pm_comment_t>,
}
389
390impl<'pr> Iterator for Comments<'pr> {
391    type Item = Comment<'pr>;
392
393    fn next(&mut self) -> Option<Self::Item> {
394        if let Some(comment) = NonNull::new(self.comment) {
395            let current = Comment { comment, parser: self.parser, marker: PhantomData };
396            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
397            Some(current)
398        } else {
399            None
400        }
401    }
402}
403
/// A struct created by the `magic_comments` method on `ParseResult`. It can be used
/// to iterate over the magic comments in the parse result.
pub struct MagicComments<'pr> {
    /// The next magic comment to yield; null once the list is exhausted.
    comment: *mut pm_magic_comment_t,
    /// Ties the iterator to the `'pr` lifetime.
    marker: PhantomData<&'pr pm_magic_comment_t>,
}
410
411impl<'pr> Iterator for MagicComments<'pr> {
412    type Item = MagicComment<'pr>;
413
414    fn next(&mut self) -> Option<Self::Item> {
415        if let Some(comment) = NonNull::new(self.comment) {
416            let current = MagicComment { comment, marker: PhantomData };
417            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
418            Some(current)
419        } else {
420            None
421        }
422    }
423}
424
/// The result of parsing a source string.
///
/// Dropping the result destroys the root node and frees the parser.
#[derive(Debug)]
pub struct ParseResult<'pr> {
    /// The source bytes that were parsed.
    source: &'pr [u8],
    /// The parser, heap-allocated by `parse` and released on drop.
    parser: NonNull<pm_parser_t>,
    /// The root node returned by `pm_parse`.
    node: NonNull<pm_node_t>,
}
432
433impl<'pr> ParseResult<'pr> {
434    /// Returns the source string that was parsed.
435    #[must_use]
436    pub const fn source(&self) -> &'pr [u8] {
437        self.source
438    }
439
440    /// Returns whether we found a `frozen_string_literal` magic comment with a true value.
441    #[must_use]
442    pub fn frozen_string_literals(&self) -> bool {
443        unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
444    }
445
446    /// Returns a slice of the source string that was parsed using the given
447    /// location range.
448    ///
449    /// # Panics
450    /// Panics if start offset or end offset are not valid offsets from the root.
451    #[must_use]
452    pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
453        let root = self.source.as_ptr();
454
455        let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
456        let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
457
458        &self.source[start..end]
459    }
460
461    /// Returns an iterator that can be used to iterate over the errors in the
462    /// parse result.
463    #[must_use]
464    pub fn errors(&self) -> Diagnostics<'_> {
465        unsafe {
466            let list = &mut (*self.parser.as_ptr()).error_list;
467            Diagnostics {
468                diagnostic: list.head.cast::<pm_diagnostic_t>(),
469                parser: self.parser,
470                marker: PhantomData,
471            }
472        }
473    }
474
475    /// Returns an iterator that can be used to iterate over the warnings in the
476    /// parse result.
477    #[must_use]
478    pub fn warnings(&self) -> Diagnostics<'_> {
479        unsafe {
480            let list = &mut (*self.parser.as_ptr()).warning_list;
481            Diagnostics {
482                diagnostic: list.head.cast::<pm_diagnostic_t>(),
483                parser: self.parser,
484                marker: PhantomData,
485            }
486        }
487    }
488
489    /// Returns an iterator that can be used to iterate over the comments in the
490    /// parse result.
491    #[must_use]
492    pub fn comments(&self) -> Comments<'_> {
493        unsafe {
494            let list = &mut (*self.parser.as_ptr()).comment_list;
495            Comments {
496                comment: list.head.cast::<pm_comment_t>(),
497                parser: self.parser,
498                marker: PhantomData,
499            }
500        }
501    }
502
503    /// Returns an iterator that can be used to iterate over the magic comments in the
504    /// parse result.
505    #[must_use]
506    pub fn magic_comments(&self) -> MagicComments<'_> {
507        unsafe {
508            let list = &mut (*self.parser.as_ptr()).magic_comment_list;
509            MagicComments {
510                comment: list.head.cast::<pm_magic_comment_t>(),
511                marker: PhantomData,
512            }
513        }
514    }
515
516    /// Returns an optional location of the __END__ marker and the rest of the content of the file.
517    #[must_use]
518    pub fn data_loc(&self) -> Option<Location<'_>> {
519        let location = unsafe { &(*self.parser.as_ptr()).data_loc };
520        if location.start.is_null() {
521            None
522        } else {
523            Some(Location::new(self.parser, location))
524        }
525    }
526
527    /// Returns the root node of the parse result.
528    #[must_use]
529    pub fn node(&self) -> Node<'_> {
530        Node::new(self.parser, self.node.as_ptr())
531    }
532}
533
534impl<'pr> Drop for ParseResult<'pr> {
535    fn drop(&mut self) {
536        unsafe {
537            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
538            pm_parser_free(self.parser.as_ptr());
539            drop(Box::from_raw(self.parser.as_ptr()));
540        }
541    }
542}
543
544/// Parses the given source string and returns a parse result.
545///
546/// # Panics
547///
548/// Panics if the parser fails to initialize.
549///
550#[must_use]
551pub fn parse(source: &[u8]) -> ParseResult<'_> {
552    unsafe {
553        let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
554        let uninit = Box::into_raw(uninit);
555
556        pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
557
558        let parser = (*uninit).assume_init_mut();
559        let parser = NonNull::new_unchecked(parser);
560
561        let node = pm_parse(parser.as_ptr());
562        let node = NonNull::new_unchecked(node);
563
564        ParseResult { source, parser, node }
565    }
566}
567
568#[cfg(test)]
569mod tests {
570    use super::parse;
571
572    #[test]
573    fn comments_test() {
574        let source = "# comment 1\n# comment 2\n# comment 3\n";
575        let result = parse(source.as_ref());
576
577        for comment in result.comments() {
578            let text = std::str::from_utf8(comment.text()).unwrap();
579            assert!(text.starts_with("# comment"));
580        }
581    }
582
583    #[test]
584    fn magic_comments_test() {
585        use crate::MagicComment;
586
587        let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
588        let result = parse(source.as_ref());
589
590        let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
591        assert_eq!(3, comments.len());
592
593        assert_eq!(b"typed", comments[0].key());
594        assert_eq!(b"ignore", comments[0].value());
595
596        assert_eq!(b"typed", comments[1].key());
597        assert_eq!(b"true", comments[1].value());
598
599        assert_eq!(b"typed", comments[2].key());
600        assert_eq!(b"strict", comments[2].value());
601    }
602
603    #[test]
604    fn data_loc_test() {
605        let source = "1";
606        let result = parse(source.as_ref());
607        let data_loc = result.data_loc();
608        assert!(data_loc.is_none());
609
610        let source = "__END__\nabc\n";
611        let result = parse(source.as_ref());
612        let data_loc = result.data_loc().unwrap();
613        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
614        assert_eq!(slice, "__END__\nabc\n");
615
616        let source = "1\n2\n3\n__END__\nabc\ndef\n";
617        let result = parse(source.as_ref());
618        let data_loc = result.data_loc().unwrap();
619        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
620        assert_eq!(slice, "__END__\nabc\ndef\n");
621    }
622
623    #[test]
624    fn location_test() {
625        let source = "111 + 222 + 333";
626        let result = parse(source.as_ref());
627
628        let node = result.node();
629        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
630        let node = node.as_call_node().unwrap().receiver().unwrap();
631        let plus = node.as_call_node().unwrap();
632        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
633
634        let location = node.as_integer_node().unwrap().location();
635        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
636
637        assert_eq!(slice, "222");
638        assert_eq!(6, location.start_offset());
639        assert_eq!(9, location.end_offset());
640
641        let recv_loc = plus.receiver().unwrap().location();
642        assert_eq!(recv_loc.as_slice(), b"111");
643        assert_eq!(0, recv_loc.start_offset());
644        assert_eq!(3, recv_loc.end_offset());
645
646        let joined = recv_loc.join(&location).unwrap();
647        assert_eq!(joined.as_slice(), b"111 + 222");
648
649        let not_joined = location.join(&recv_loc);
650        assert!(not_joined.is_none());
651
652        {
653            let result = parse(source.as_ref());
654            let node = result.node();
655            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
656            let node = node.as_call_node().unwrap().receiver().unwrap();
657            let plus = node.as_call_node().unwrap();
658            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
659
660            let location = node.as_integer_node().unwrap().location();
661            let not_joined = recv_loc.join(&location);
662            assert!(not_joined.is_none());
663
664            let not_joined = location.join(&recv_loc);
665            assert!(not_joined.is_none());
666        }
667
668        let location = node.location();
669        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
670
671        assert_eq!(slice, "222");
672
673        let slice = std::str::from_utf8(location.as_slice()).unwrap();
674
675        assert_eq!(slice, "222");
676    }
677
678    #[test]
679    fn visitor_test() {
680        use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
681
682        struct RegularExpressionVisitor {
683            count: usize,
684        }
685
686        impl Visit<'_> for RegularExpressionVisitor {
687            fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
688                self.count += 1;
689                visit_interpolated_regular_expression_node(self, node);
690            }
691
692            fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
693                self.count += 1;
694                visit_regular_expression_node(self, node);
695            }
696        }
697
698        let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
699        let result = parse(source.as_ref());
700
701        let mut visitor = RegularExpressionVisitor { count: 0 };
702        visitor.visit(&result.node());
703
704        assert_eq!(visitor.count, 2);
705    }
706
707    #[test]
708    fn node_upcast_test() {
709        use super::Node;
710
711        let source = "module Foo; end";
712        let result = parse(source.as_ref());
713
714        let node = result.node();
715        let upcast_node = node.as_program_node().unwrap().as_node();
716        assert!(matches!(upcast_node, Node::ProgramNode { .. }));
717
718        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
719        let upcast_node = node.as_module_node().unwrap().as_node();
720        assert!(matches!(upcast_node, Node::ModuleNode { .. }));
721    }
722
723    #[test]
724    fn constant_id_test() {
725        let source = "module Foo; x = 1; end";
726        let result = parse(source.as_ref());
727
728        let node = result.node();
729        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
730        let module = module.as_module_node().unwrap();
731        let locals = module.locals().iter().collect::<Vec<_>>();
732
733        assert_eq!(locals.len(), 1);
734
735        assert_eq!(locals[0].as_slice(), b"x");
736    }
737
738    #[test]
739    fn optional_loc_test() {
740        let source = r#"
741module Example
742  x = call_func(3, 4)
743  y = x.call_func 5, 6
744end
745"#;
746        let result = parse(source.as_ref());
747
748        let node = result.node();
749        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
750        let module = module.as_module_node().unwrap();
751        let body = module.body();
752        let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
753        assert_eq!(writes.len(), 2);
754
755        let asgn = &writes[0];
756        let call = asgn.as_local_variable_write_node().unwrap().value();
757        let call = call.as_call_node().unwrap();
758
759        let call_operator_loc = call.call_operator_loc();
760        assert!(call_operator_loc.is_none());
761        let closing_loc = call.closing_loc();
762        assert!(closing_loc.is_some());
763
764        let asgn = &writes[1];
765        let call = asgn.as_local_variable_write_node().unwrap().value();
766        let call = call.as_call_node().unwrap();
767
768        let call_operator_loc = call.call_operator_loc();
769        assert!(call_operator_loc.is_some());
770        let closing_loc = call.closing_loc();
771        assert!(closing_loc.is_none());
772    }
773
774    #[test]
775    fn frozen_strings_test() {
776        let source = r#"
777# frozen_string_literal: true
778"foo"
779"#;
780        let result = parse(source.as_ref());
781        assert!(result.frozen_string_literals());
782
783        let source = "3";
784        let result = parse(source.as_ref());
785        assert!(!result.frozen_string_literals());
786    }
787
788    #[test]
789    fn string_flags_test() {
790        let source = r#"
791# frozen_string_literal: true
792"foo"
793"#;
794        let result = parse(source.as_ref());
795
796        let node = result.node();
797        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
798        let string = string.as_string_node().unwrap();
799        assert!(string.is_frozen());
800
801        let source = r#"
802"foo"
803"#;
804        let result = parse(source.as_ref());
805
806        let node = result.node();
807        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
808        let string = string.as_string_node().unwrap();
809        assert!(!string.is_frozen());
810    }
811
812    #[test]
813    fn call_flags_test() {
814        let source = r#"
815x
816"#;
817        let result = parse(source.as_ref());
818
819        let node = result.node();
820        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
821        let call = call.as_call_node().unwrap();
822        assert!(call.is_variable_call());
823
824        let source = r#"
825x&.foo
826"#;
827        let result = parse(source.as_ref());
828
829        let node = result.node();
830        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
831        let call = call.as_call_node().unwrap();
832        assert!(call.is_safe_navigation());
833    }
834
835    #[test]
836    fn integer_flags_test() {
837        let source = r#"
8380b1
839"#;
840        let result = parse(source.as_ref());
841
842        let node = result.node();
843        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
844        let i = i.as_integer_node().unwrap();
845        assert!(i.is_binary());
846        assert!(!i.is_decimal());
847        assert!(!i.is_octal());
848        assert!(!i.is_hexadecimal());
849
850        let source = r#"
8511
852"#;
853        let result = parse(source.as_ref());
854
855        let node = result.node();
856        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
857        let i = i.as_integer_node().unwrap();
858        assert!(!i.is_binary());
859        assert!(i.is_decimal());
860        assert!(!i.is_octal());
861        assert!(!i.is_hexadecimal());
862
863        let source = r#"
8640o1
865"#;
866        let result = parse(source.as_ref());
867
868        let node = result.node();
869        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
870        let i = i.as_integer_node().unwrap();
871        assert!(!i.is_binary());
872        assert!(!i.is_decimal());
873        assert!(i.is_octal());
874        assert!(!i.is_hexadecimal());
875
876        let source = r#"
8770x1
878"#;
879        let result = parse(source.as_ref());
880
881        let node = result.node();
882        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
883        let i = i.as_integer_node().unwrap();
884        assert!(!i.is_binary());
885        assert!(!i.is_decimal());
886        assert!(!i.is_octal());
887        assert!(i.is_hexadecimal());
888    }
889
890    #[test]
891    fn range_flags_test() {
892        let source = r#"
8930..1
894"#;
895        let result = parse(source.as_ref());
896
897        let node = result.node();
898        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
899        let range = range.as_range_node().unwrap();
900        assert!(!range.is_exclude_end());
901
902        let source = r#"
9030...1
904"#;
905        let result = parse(source.as_ref());
906
907        let node = result.node();
908        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
909        let range = range.as_range_node().unwrap();
910        assert!(range.is_exclude_end());
911    }
912
913    #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
914    #[test]
915    fn regex_flags_test() {
916        let source = r#"
917/a/i
918"#;
919        let result = parse(source.as_ref());
920
921        let node = result.node();
922        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
923        let regex = regex.as_regular_expression_node().unwrap();
924        assert!(regex.is_ignore_case());
925        assert!(!regex.is_extended());
926        assert!(!regex.is_multi_line());
927        assert!(!regex.is_euc_jp());
928        assert!(!regex.is_ascii_8bit());
929        assert!(!regex.is_windows_31j());
930        assert!(!regex.is_utf_8());
931        assert!(!regex.is_once());
932
933        let source = r#"
934/a/x
935"#;
936        let result = parse(source.as_ref());
937
938        let node = result.node();
939        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
940        let regex = regex.as_regular_expression_node().unwrap();
941        assert!(!regex.is_ignore_case());
942        assert!(regex.is_extended());
943        assert!(!regex.is_multi_line());
944        assert!(!regex.is_euc_jp());
945        assert!(!regex.is_ascii_8bit());
946        assert!(!regex.is_windows_31j());
947        assert!(!regex.is_utf_8());
948        assert!(!regex.is_once());
949
950        let source = r#"
951/a/m
952"#;
953        let result = parse(source.as_ref());
954
955        let node = result.node();
956        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
957        let regex = regex.as_regular_expression_node().unwrap();
958        assert!(!regex.is_ignore_case());
959        assert!(!regex.is_extended());
960        assert!(regex.is_multi_line());
961        assert!(!regex.is_euc_jp());
962        assert!(!regex.is_ascii_8bit());
963        assert!(!regex.is_windows_31j());
964        assert!(!regex.is_utf_8());
965        assert!(!regex.is_once());
966
967        let source = r#"
968/a/e
969"#;
970        let result = parse(source.as_ref());
971
972        let node = result.node();
973        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
974        let regex = regex.as_regular_expression_node().unwrap();
975        assert!(!regex.is_ignore_case());
976        assert!(!regex.is_extended());
977        assert!(!regex.is_multi_line());
978        assert!(regex.is_euc_jp());
979        assert!(!regex.is_ascii_8bit());
980        assert!(!regex.is_windows_31j());
981        assert!(!regex.is_utf_8());
982        assert!(!regex.is_once());
983
984        let source = r#"
985/a/n
986"#;
987        let result = parse(source.as_ref());
988
989        let node = result.node();
990        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
991        let regex = regex.as_regular_expression_node().unwrap();
992        assert!(!regex.is_ignore_case());
993        assert!(!regex.is_extended());
994        assert!(!regex.is_multi_line());
995        assert!(!regex.is_euc_jp());
996        assert!(regex.is_ascii_8bit());
997        assert!(!regex.is_windows_31j());
998        assert!(!regex.is_utf_8());
999        assert!(!regex.is_once());
1000
1001        let source = r#"
1002/a/s
1003"#;
1004        let result = parse(source.as_ref());
1005
1006        let node = result.node();
1007        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1008        let regex = regex.as_regular_expression_node().unwrap();
1009        assert!(!regex.is_ignore_case());
1010        assert!(!regex.is_extended());
1011        assert!(!regex.is_multi_line());
1012        assert!(!regex.is_euc_jp());
1013        assert!(!regex.is_ascii_8bit());
1014        assert!(regex.is_windows_31j());
1015        assert!(!regex.is_utf_8());
1016        assert!(!regex.is_once());
1017
1018        let source = r#"
1019/a/u
1020"#;
1021        let result = parse(source.as_ref());
1022
1023        let node = result.node();
1024        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1025        let regex = regex.as_regular_expression_node().unwrap();
1026        assert!(!regex.is_ignore_case());
1027        assert!(!regex.is_extended());
1028        assert!(!regex.is_multi_line());
1029        assert!(!regex.is_euc_jp());
1030        assert!(!regex.is_ascii_8bit());
1031        assert!(!regex.is_windows_31j());
1032        assert!(regex.is_utf_8());
1033        assert!(!regex.is_once());
1034
1035        let source = r#"
1036/a/o
1037"#;
1038        let result = parse(source.as_ref());
1039
1040        let node = result.node();
1041        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1042        let regex = regex.as_regular_expression_node().unwrap();
1043        assert!(!regex.is_ignore_case());
1044        assert!(!regex.is_extended());
1045        assert!(!regex.is_multi_line());
1046        assert!(!regex.is_euc_jp());
1047        assert!(!regex.is_ascii_8bit());
1048        assert!(!regex.is_windows_31j());
1049        assert!(!regex.is_utf_8());
1050        assert!(regex.is_once());
1051    }
1052
1053    #[test]
1054    fn visitor_traversal_test() {
1055        use crate::{Node, Visit};
1056
1057        #[derive(Default)]
1058        struct NodeCounts {
1059            pre_parent: usize,
1060            post_parent: usize,
1061            pre_leaf: usize,
1062            post_leaf: usize,
1063        }
1064
1065        #[derive(Default)]
1066        struct CountingVisitor {
1067            counts: NodeCounts,
1068        }
1069
1070        impl<'pr> Visit<'pr> for CountingVisitor {
1071            fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1072                self.counts.pre_parent += 1;
1073            }
1074
1075            fn visit_branch_node_leave(&mut self) {
1076                self.counts.post_parent += 1;
1077            }
1078
1079            fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1080                self.counts.pre_leaf += 1;
1081            }
1082
1083            fn visit_leaf_node_leave(&mut self) {
1084                self.counts.post_leaf += 1;
1085            }
1086        }
1087
1088        let source = r#"
1089module Example
1090  x = call_func(3, 4)
1091  y = x.call_func 5, 6
1092end
1093"#;
1094        let result = parse(source.as_ref());
1095        let node = result.node();
1096        let mut visitor = CountingVisitor::default();
1097        visitor.visit(&node);
1098
1099        assert_eq!(7, visitor.counts.pre_parent);
1100        assert_eq!(7, visitor.counts.post_parent);
1101        assert_eq!(6, visitor.counts.pre_leaf);
1102        assert_eq!(6, visitor.counts.post_leaf);
1103    }
1104
1105    #[test]
1106    fn visitor_lifetime_test() {
1107        use crate::{Node, Visit};
1108
1109        #[derive(Default)]
1110        struct StackingNodeVisitor<'a> {
1111            stack: Vec<Node<'a>>,
1112            max_depth: usize,
1113        }
1114
1115        impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1116            fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1117                self.stack.push(node);
1118            }
1119
1120            fn visit_branch_node_leave(&mut self) {
1121                self.stack.pop();
1122            }
1123
1124            fn visit_leaf_node_leave(&mut self) {
1125                self.max_depth = self.max_depth.max(self.stack.len());
1126            }
1127        }
1128
1129        let source = r#"
1130module Example
1131  x = call_func(3, 4)
1132  y = x.call_func 5, 6
1133end
1134"#;
1135        let result = parse(source.as_ref());
1136        let node = result.node();
1137        let mut visitor = StackingNodeVisitor::default();
1138        visitor.visit(&node);
1139
1140        assert_eq!(0, visitor.stack.len());
1141        assert_eq!(5, visitor.max_depth);
1142    }
1143
1144    #[test]
1145    fn integer_value_test() {
1146        let result = parse("0xA".as_ref());
1147        let value: i32 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value().try_into().unwrap();
1148
1149        assert_eq!(value, 10);
1150    }
1151
1152    #[test]
1153    fn float_value_test() {
1154        let result = parse("1.0".as_ref());
1155        let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1156
1157        assert!((value - 1.0).abs() < f64::EPSILON);
1158    }
1159
1160    #[test]
1161    fn node_field_lifetime_test() {
1162        // The code below wouldn't typecheck prior to https://github.com/ruby/prism/pull/2519,
1163        // but we need to stop clippy from complaining about it.
1164        #![allow(clippy::needless_pass_by_value)]
1165
1166        use crate::Node;
1167
1168        #[derive(Default)]
1169        struct Extract<'pr> {
1170            scopes: Vec<crate::ConstantId<'pr>>,
1171        }
1172
1173        impl<'pr> Extract<'pr> {
1174            fn push_scope(&mut self, path: Node<'pr>) {
1175                if let Some(cread) = path.as_constant_read_node() {
1176                    self.scopes.push(cread.name());
1177                } else if let Some(cpath) = path.as_constant_path_node() {
1178                    if let Some(parent) = cpath.parent() {
1179                        self.push_scope(parent);
1180                    }
1181                    self.scopes.push(cpath.name().unwrap());
1182                } else {
1183                    panic!("Wrong node kind!");
1184                }
1185            }
1186        }
1187
1188        let source = "Some::Random::Constant";
1189        let result = parse(source.as_ref());
1190        let node = result.node();
1191        let mut extractor = Extract::default();
1192        extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1193        assert_eq!(3, extractor.scopes.len());
1194    }
1195}