1use std::fmt::{self, Debug, Formatter};
2use std::ops::{Add, Range, Sub};
3
4use chumsky::Stream;
5use schemars::JsonSchema;
6use serde::de::Visitor;
7use serde::{Deserialize, Serialize};
8
// A region of text inside one source, identified by `source_id`.
//
// NOTE: these are deliberately plain `//` comments. `///` doc comments on a
// type deriving `JsonSchema` are picked up as schema descriptions and would
// change the generated schema.
#[derive(Clone, PartialEq, Eq, Copy, JsonSchema)]
pub struct Span {
    // Offset at which the span begins.
    pub start: usize,
    // Offset at which the span stops.
    // NOTE(review): `string_stream` builds one-character spans as
    // `start..start + 1`, so this looks like an exclusive bound — confirm.
    pub end: usize,

    // Identifier of the source that the offsets refer to (the `file_id` part
    // of the serialized `file_id:x-y` form).
    pub source_id: u16,
}
17
18impl From<Span> for Range<usize> {
19 fn from(a: Span) -> Self {
20 a.start..a.end
21 }
22}
23
24impl Debug for Span {
25 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
26 write!(f, "{}:{}-{}", self.source_id, self.start, self.end)
27 }
28}
29
30impl Serialize for Span {
31 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
32 where
33 S: serde::Serializer,
34 {
35 let str = format!("{self:?}");
36 serializer.serialize_str(&str)
37 }
38}
39
40impl PartialOrd for Span {
41 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42 match other.source_id.partial_cmp(&self.source_id) {
44 Some(std::cmp::Ordering::Equal) => {
45 debug_assert!((self.start <= other.start) == (self.end <= other.end));
46 self.start.partial_cmp(&other.start)
47 }
48 _ => None,
49 }
50 }
51}
52
53impl<'de> Deserialize<'de> for Span {
54 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
55 where
56 D: serde::Deserializer<'de>,
57 {
58 struct SpanVisitor {}
59
60 impl Visitor<'_> for SpanVisitor {
61 type Value = Span;
62
63 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
64 write!(f, "A span string of form `file_id:x-y`")
65 }
66
67 fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
68 where
69 E: serde::de::Error,
70 {
71 use serde::de;
72
73 if let Some((file_id, char_span)) = v.split_once(':') {
74 let file_id = file_id
75 .parse::<u16>()
76 .map_err(|e| de::Error::custom(e.to_string()))?;
77
78 if let Some((start, end)) = char_span.split_once('-') {
79 let start = start
80 .parse::<usize>()
81 .map_err(|e| de::Error::custom(e.to_string()))?;
82 let end = end
83 .parse::<usize>()
84 .map_err(|e| de::Error::custom(e.to_string()))?;
85
86 return Ok(Span {
87 start,
88 end,
89 source_id: file_id,
90 });
91 }
92 }
93
94 Err(de::Error::custom("malformed span"))
95 }
96
97 fn visit_string<E>(self, v: String) -> std::result::Result<Self::Value, E>
98 where
99 E: serde::de::Error,
100 {
101 self.visit_str(&v)
102 }
103 }
104
105 deserializer.deserialize_string(SpanVisitor {})
106 }
107}
108
109impl chumsky::Span for Span {
110 type Context = u16;
111
112 type Offset = usize;
113
114 fn new(context: Self::Context, range: std::ops::Range<Self::Offset>) -> Self {
115 Self {
116 start: range.start,
117 end: range.end,
118 source_id: context,
119 }
120 }
121
122 fn context(&self) -> Self::Context {
123 self.source_id
124 }
125
126 fn start(&self) -> Self::Offset {
127 self.start
128 }
129
130 fn end(&self) -> Self::Offset {
131 self.end
132 }
133}
134
135impl Add<usize> for Span {
136 type Output = Span;
137
138 fn add(self, rhs: usize) -> Span {
139 Self {
140 start: self.start + rhs,
141 end: self.end + rhs,
142 source_id: self.source_id,
143 }
144 }
145}
146
147impl Sub<usize> for Span {
148 type Output = Span;
149
150 fn sub(self, rhs: usize) -> Span {
151 Self {
152 start: self.start - rhs,
153 end: self.end - rhs,
154 source_id: self.source_id,
155 }
156 }
157}
158
159pub(crate) fn string_stream<'a>(
160 s: String,
161 span_base: Span,
162) -> Stream<'a, char, Span, Box<dyn Iterator<Item = (char, Span)>>> {
163 let chars = s.chars().collect::<Vec<_>>();
164
165 Stream::from_iter(
166 Span {
167 start: span_base.start + chars.len(),
168 end: span_base.start + chars.len(),
169 source_id: span_base.source_id,
170 },
171 Box::new(chars.into_iter().enumerate().map(move |(i, c)| {
172 (
173 c,
174 Span {
175 start: span_base.start + i,
176 end: span_base.start + i + 1,
177 source_id: span_base.source_id,
178 },
179 )
180 })),
181 )
182}
183
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand constructor that keeps the test cases compact.
    fn span(source_id: u16, start: usize, end: usize) -> Span {
        Span {
            start,
            end,
            source_id,
        }
    }

    /// A span serializes to `source_id:start-end` and survives a round trip.
    #[test]
    fn test_span_serde() {
        let original = span(45, 12, 15);

        let json = serde_json::to_string(&original).unwrap();
        insta::assert_snapshot!(json, @r#""45:12-15""#);

        let round_tripped: Span = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// Spans are ordered within one source and incomparable across sources.
    #[test]
    fn test_span_partial_cmp() {
        use std::cmp::Ordering;

        let earlier = span(1, 10, 20);
        let later = span(1, 15, 25);
        let elsewhere = span(2, 5, 15);

        // Same source: ordered by position.
        assert_eq!(earlier.partial_cmp(&later), Some(Ordering::Less));
        assert_eq!(later.partial_cmp(&earlier), Some(Ordering::Greater));

        // Different sources: no ordering in either direction.
        assert_eq!(earlier.partial_cmp(&elsewhere), None);
        assert_eq!(elsewhere.partial_cmp(&earlier), None);
    }
}