prqlc_ast/
span.rs

1use std::fmt::{self, Debug, Formatter};
2use std::ops::{Add, Range, Sub};
3
4use chumsky::Stream;
5use serde::de::Visitor;
6use serde::{Deserialize, Serialize};
7
/// A region of source text: the half-open range `start..end` inside the source
/// identified by `source_id`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct Span {
    /// Offset of the first position covered by the span.
    pub start: usize,
    /// Offset one past the last position covered by the span.
    pub end: usize,

    /// A key representing the path of the source. Value is stored in prqlc's SourceTree::source_ids.
    pub source_id: u16,
}
16
17impl From<Span> for Range<usize> {
18    fn from(a: Span) -> Self {
19        a.start..a.end
20    }
21}
22
23impl Debug for Span {
24    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
25        write!(f, "{}:{}-{}", self.source_id, self.start, self.end)
26    }
27}
28
29impl Serialize for Span {
30    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
31    where
32        S: serde::Serializer,
33    {
34        let str = format!("{self:?}");
35        serializer.serialize_str(&str)
36    }
37}
38
39impl PartialOrd for Span {
40    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
41        // We could expand this to compare source_id too, starting with minimum surprise
42        match other.source_id.partial_cmp(&self.source_id) {
43            Some(std::cmp::Ordering::Equal) => {
44                debug_assert!((self.start <= other.start) == (self.end <= other.end));
45                self.start.partial_cmp(&other.start)
46            }
47            _ => None,
48        }
49    }
50}
51
52impl<'de> Deserialize<'de> for Span {
53    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
54    where
55        D: serde::Deserializer<'de>,
56    {
57        struct SpanVisitor {}
58
59        impl<'de> Visitor<'de> for SpanVisitor {
60            type Value = Span;
61
62            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
63                write!(f, "A span string of form `file_id:x-y`")
64            }
65
66            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
67            where
68                E: serde::de::Error,
69            {
70                use serde::de;
71
72                if let Some((file_id, char_span)) = v.split_once(':') {
73                    let file_id = file_id
74                        .parse::<u16>()
75                        .map_err(|e| de::Error::custom(e.to_string()))?;
76
77                    if let Some((start, end)) = char_span.split_once('-') {
78                        let start = start
79                            .parse::<usize>()
80                            .map_err(|e| de::Error::custom(e.to_string()))?;
81                        let end = end
82                            .parse::<usize>()
83                            .map_err(|e| de::Error::custom(e.to_string()))?;
84
85                        return Ok(Span {
86                            start,
87                            end,
88                            source_id: file_id,
89                        });
90                    }
91                }
92
93                Err(de::Error::custom("malformed span"))
94            }
95
96            fn visit_string<E>(self, v: String) -> std::result::Result<Self::Value, E>
97            where
98                E: serde::de::Error,
99            {
100                self.visit_str(&v)
101            }
102        }
103
104        deserializer.deserialize_string(SpanVisitor {})
105    }
106}
107
108impl chumsky::Span for Span {
109    type Context = u16;
110
111    type Offset = usize;
112
113    fn new(context: Self::Context, range: std::ops::Range<Self::Offset>) -> Self {
114        Self {
115            start: range.start,
116            end: range.end,
117            source_id: context,
118        }
119    }
120
121    fn context(&self) -> Self::Context {
122        self.source_id
123    }
124
125    fn start(&self) -> Self::Offset {
126        self.start
127    }
128
129    fn end(&self) -> Self::Offset {
130        self.end
131    }
132}
133
134impl Add<usize> for Span {
135    type Output = Span;
136
137    fn add(self, rhs: usize) -> Span {
138        Self {
139            start: self.start + rhs,
140            end: self.end + rhs,
141            source_id: self.source_id,
142        }
143    }
144}
145
146impl Sub<usize> for Span {
147    type Output = Span;
148
149    fn sub(self, rhs: usize) -> Span {
150        Self {
151            start: self.start - rhs,
152            end: self.end - rhs,
153            source_id: self.source_id,
154        }
155    }
156}
157
158pub fn string_stream<'a>(
159    s: String,
160    span_base: Span,
161) -> Stream<'a, char, Span, Box<dyn Iterator<Item = (char, Span)>>> {
162    let chars = s.chars().collect::<Vec<_>>();
163
164    Stream::from_iter(
165        Span {
166            start: span_base.start + chars.len(),
167            end: span_base.start + chars.len(),
168            source_id: span_base.source_id,
169        },
170        Box::new(chars.into_iter().enumerate().map(move |(i, c)| {
171            (
172                c,
173                Span {
174                    start: span_base.start + i,
175                    end: span_base.start + i + 1,
176                    source_id: span_base.source_id,
177                },
178            )
179        })),
180    )
181}
182
#[cfg(test)]
mod test {
    use super::*;

    /// Compact fixture constructor: source id first, then the offsets.
    fn span(source_id: u16, start: usize, end: usize) -> Span {
        Span {
            start,
            end,
            source_id,
        }
    }

    #[test]
    fn test_span_serde() {
        let original = span(45, 12, 15);
        let json = serde_json::to_string(&original).unwrap();
        insta::assert_snapshot!(json, @r###""45:12-15""###);
        let round_tripped: Span = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_span_partial_cmp() {
        use std::cmp::Ordering::{Greater, Less};

        let first = span(1, 10, 20);
        let second = span(1, 15, 25);
        let other_source = span(2, 5, 15);

        // Same source_id: ordered by position.
        assert_eq!(first.partial_cmp(&second), Some(Less));
        assert_eq!(second.partial_cmp(&first), Some(Greater));

        // Different source_id: the spans are incomparable, so the result is None.
        assert_eq!(first.partial_cmp(&other_source), None);
        assert_eq!(other_source.partial_cmp(&first), None);
    }
}