prqlc_parser/
span.rs

1use std::fmt::{self, Debug, Formatter};
2use std::ops::{Add, Range, Sub};
3
4use chumsky::Stream;
5use schemars::JsonSchema;
6use serde::de::Visitor;
7use serde::{Deserialize, Serialize};
8
// A span of positions `start..end` inside one source, tagged with that source's id.
//
// `end` is exclusive: `string_stream` gives the character at index `i` the span `i..i + 1`,
// and `From<Span> for Range<usize>` converts to `start..end`.
//
// These are plain `//` comments rather than `///` on purpose: schemars copies doc comments
// into the generated JSON schema, so adding `///` docs here would change the schema output.
//
// NOTE(review): the derived `JsonSchema` describes an object with `start`/`end`/`source_id`
// fields, but `Serialize` below emits a string like `"45:12-15"`. The schema therefore does
// not match the serialized form. Confirm whether a custom schema is needed.
#[derive(Clone, PartialEq, Eq, Copy, JsonSchema)]
pub struct Span {
    // First position covered by the span.
    pub start: usize,
    // One past the last position covered by the span.
    pub end: usize,

    /// A key representing the path of the source. Value is stored in prqlc's SourceTree::source_ids.
    pub source_id: u16,
}
17
18impl From<Span> for Range<usize> {
19    fn from(a: Span) -> Self {
20        a.start..a.end
21    }
22}
23
24impl Debug for Span {
25    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
26        write!(f, "{}:{}-{}", self.source_id, self.start, self.end)
27    }
28}
29
30impl Serialize for Span {
31    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
32    where
33        S: serde::Serializer,
34    {
35        let str = format!("{self:?}");
36        serializer.serialize_str(&str)
37    }
38}
39
40impl PartialOrd for Span {
41    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42        // We could expand this to compare source_id too, starting with minimum surprise
43        match other.source_id.partial_cmp(&self.source_id) {
44            Some(std::cmp::Ordering::Equal) => {
45                debug_assert!((self.start <= other.start) == (self.end <= other.end));
46                self.start.partial_cmp(&other.start)
47            }
48            _ => None,
49        }
50    }
51}
52
53impl<'de> Deserialize<'de> for Span {
54    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
55    where
56        D: serde::Deserializer<'de>,
57    {
58        struct SpanVisitor {}
59
60        impl Visitor<'_> for SpanVisitor {
61            type Value = Span;
62
63            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
64                write!(f, "A span string of form `file_id:x-y`")
65            }
66
67            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
68            where
69                E: serde::de::Error,
70            {
71                use serde::de;
72
73                if let Some((file_id, char_span)) = v.split_once(':') {
74                    let file_id = file_id
75                        .parse::<u16>()
76                        .map_err(|e| de::Error::custom(e.to_string()))?;
77
78                    if let Some((start, end)) = char_span.split_once('-') {
79                        let start = start
80                            .parse::<usize>()
81                            .map_err(|e| de::Error::custom(e.to_string()))?;
82                        let end = end
83                            .parse::<usize>()
84                            .map_err(|e| de::Error::custom(e.to_string()))?;
85
86                        return Ok(Span {
87                            start,
88                            end,
89                            source_id: file_id,
90                        });
91                    }
92                }
93
94                Err(de::Error::custom("malformed span"))
95            }
96
97            fn visit_string<E>(self, v: String) -> std::result::Result<Self::Value, E>
98            where
99                E: serde::de::Error,
100            {
101                self.visit_str(&v)
102            }
103        }
104
105        deserializer.deserialize_string(SpanVisitor {})
106    }
107}
108
109impl chumsky::Span for Span {
110    type Context = u16;
111
112    type Offset = usize;
113
114    fn new(context: Self::Context, range: std::ops::Range<Self::Offset>) -> Self {
115        Self {
116            start: range.start,
117            end: range.end,
118            source_id: context,
119        }
120    }
121
122    fn context(&self) -> Self::Context {
123        self.source_id
124    }
125
126    fn start(&self) -> Self::Offset {
127        self.start
128    }
129
130    fn end(&self) -> Self::Offset {
131        self.end
132    }
133}
134
135impl Add<usize> for Span {
136    type Output = Span;
137
138    fn add(self, rhs: usize) -> Span {
139        Self {
140            start: self.start + rhs,
141            end: self.end + rhs,
142            source_id: self.source_id,
143        }
144    }
145}
146
147impl Sub<usize> for Span {
148    type Output = Span;
149
150    fn sub(self, rhs: usize) -> Span {
151        Self {
152            start: self.start - rhs,
153            end: self.end - rhs,
154            source_id: self.source_id,
155        }
156    }
157}
158
159pub(crate) fn string_stream<'a>(
160    s: String,
161    span_base: Span,
162) -> Stream<'a, char, Span, Box<dyn Iterator<Item = (char, Span)>>> {
163    let chars = s.chars().collect::<Vec<_>>();
164
165    Stream::from_iter(
166        Span {
167            start: span_base.start + chars.len(),
168            end: span_base.start + chars.len(),
169            source_id: span_base.source_id,
170        },
171        Box::new(chars.into_iter().enumerate().map(move |(i, c)| {
172            (
173                c,
174                Span {
175                    start: span_base.start + i,
176                    end: span_base.start + i + 1,
177                    source_id: span_base.source_id,
178                },
179            )
180        })),
181    )
182}
183
#[cfg(test)]
mod test {
    use super::*;
    use std::cmp::Ordering;

    /// Shorthand constructor to keep the test bodies compact.
    fn span(source_id: u16, start: usize, end: usize) -> Span {
        Span {
            start,
            end,
            source_id,
        }
    }

    #[test]
    fn test_span_serde() {
        let original = span(45, 12, 15);

        let json = serde_json::to_string(&original).unwrap();
        insta::assert_snapshot!(json, @r#""45:12-15""#);

        let round_tripped: Span = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_span_partial_cmp() {
        let first = span(1, 10, 20);
        let second = span(1, 15, 25);
        let elsewhere = span(2, 5, 15);

        // Same source_id: the spans are ordered by where they start.
        assert_eq!(first.partial_cmp(&second), Some(Ordering::Less));
        assert_eq!(second.partial_cmp(&first), Some(Ordering::Greater));

        // Different source_id: spans from different sources are incomparable.
        assert_eq!(first.partial_cmp(&elsewhere), None);
        assert_eq!(elsewhere.partial_cmp(&first), None);
    }
}