prqlc_parser/
span.rs

1use std::fmt::{self, Debug, Formatter};
2use std::ops::{Add, Range, Sub};
3
4use chumsky;
5use schemars::JsonSchema;
6use serde::de::Visitor;
7use serde::{Deserialize, Serialize};
8
9#[derive(Clone, PartialEq, Eq, Copy, JsonSchema)]
10pub struct Span {
11    pub start: usize,
12    pub end: usize,
13
14    /// A key representing the path of the source. Value is stored in prqlc's SourceTree::source_ids.
15    pub source_id: u16,
16}
17
18// Implement Chumsky's Span trait for our custom Span
19impl chumsky::span::Span for Span {
20    type Context = u16; // source_id
21    type Offset = usize;
22
23    fn new(context: Self::Context, range: Range<Self::Offset>) -> Self {
24        Span {
25            start: range.start,
26            end: range.end,
27            source_id: context,
28        }
29    }
30
31    fn context(&self) -> Self::Context {
32        self.source_id
33    }
34
35    fn start(&self) -> Self::Offset {
36        self.start
37    }
38
39    fn end(&self) -> Self::Offset {
40        self.end
41    }
42
43    fn to_end(&self) -> Self {
44        Span {
45            start: self.end,
46            end: self.end,
47            source_id: self.source_id,
48        }
49    }
50}
51
52impl From<Span> for Range<usize> {
53    fn from(a: Span) -> Self {
54        a.start..a.end
55    }
56}
57
58impl Debug for Span {
59    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60        write!(f, "{}:{}-{}", self.source_id, self.start, self.end)
61    }
62}
63
64impl Serialize for Span {
65    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
66    where
67        S: serde::Serializer,
68    {
69        let str = format!("{self:?}");
70        serializer.serialize_str(&str)
71    }
72}
73
74impl PartialOrd for Span {
75    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
76        // We could expand this to compare source_id too, starting with minimum surprise
77        match other.source_id.partial_cmp(&self.source_id) {
78            Some(std::cmp::Ordering::Equal) => {
79                debug_assert!((self.start <= other.start) == (self.end <= other.end));
80                self.start.partial_cmp(&other.start)
81            }
82            _ => None,
83        }
84    }
85}
86
87impl<'de> Deserialize<'de> for Span {
88    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
89    where
90        D: serde::Deserializer<'de>,
91    {
92        struct SpanVisitor {}
93
94        impl Visitor<'_> for SpanVisitor {
95            type Value = Span;
96
97            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
98                write!(f, "A span string of form `file_id:x-y`")
99            }
100
101            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
102            where
103                E: serde::de::Error,
104            {
105                use serde::de;
106
107                if let Some((file_id, char_span)) = v.split_once(':') {
108                    let file_id = file_id
109                        .parse::<u16>()
110                        .map_err(|e| de::Error::custom(e.to_string()))?;
111
112                    if let Some((start, end)) = char_span.split_once('-') {
113                        let start = start
114                            .parse::<usize>()
115                            .map_err(|e| de::Error::custom(e.to_string()))?;
116                        let end = end
117                            .parse::<usize>()
118                            .map_err(|e| de::Error::custom(e.to_string()))?;
119
120                        return Ok(Span {
121                            start,
122                            end,
123                            source_id: file_id,
124                        });
125                    }
126                }
127
128                Err(de::Error::custom("malformed span"))
129            }
130
131            fn visit_string<E>(self, v: String) -> std::result::Result<Self::Value, E>
132            where
133                E: serde::de::Error,
134            {
135                self.visit_str(&v)
136            }
137        }
138
139        deserializer.deserialize_string(SpanVisitor {})
140    }
141}
142
143impl Add<usize> for Span {
144    type Output = Span;
145
146    fn add(self, rhs: usize) -> Span {
147        Self {
148            start: self.start + rhs,
149            end: self.end + rhs,
150            source_id: self.source_id,
151        }
152    }
153}
154
155impl Sub<usize> for Span {
156    type Output = Span;
157
158    fn sub(self, rhs: usize) -> Span {
159        Self {
160            start: self.start - rhs,
161            end: self.end - rhs,
162            source_id: self.source_id,
163        }
164    }
165}
166
#[cfg(test)]
mod test {
    use std::cmp::Ordering;

    use super::*;

    /// Shorthand constructor so each test case fits on one line.
    fn span(source_id: u16, start: usize, end: usize) -> Span {
        Span {
            start,
            end,
            source_id,
        }
    }

    #[test]
    fn test_span_serde() {
        let original = span(45, 12, 15);

        let json = serde_json::to_string(&original).unwrap();
        insta::assert_snapshot!(json, @r#""45:12-15""#);

        let round_tripped: Span = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_span_partial_cmp() {
        let earlier = span(1, 10, 20);
        let later = span(1, 15, 25);
        let elsewhere = span(2, 5, 15);

        // Same source_id: ordered by start offset.
        assert_eq!(earlier.partial_cmp(&later), Some(Ordering::Less));
        assert_eq!(later.partial_cmp(&earlier), Some(Ordering::Greater));

        // Different source_id: incomparable in both directions.
        assert_eq!(earlier.partial_cmp(&elsewhere), None);
        assert_eq!(elsewhere.partial_cmp(&earlier), None);
    }
}