1use std::{
5 cell::RefCell,
6 cmp,
7 cmp::Ordering,
8 collections::HashMap,
9 fmt,
10 fmt::{Display, Formatter},
11 ops::Deref,
12 sync::Arc,
13};
14
15use serde::{
16 Deserialize, Serialize,
17 de::{self, EnumAccess, MapAccess, VariantAccess, Visitor},
18};
19
/// Upper bound on how many distinct strings each thread's intern table will remember.
const INTERN_CAP: usize = 4096;

thread_local! {
    // Per-thread intern table; only the keys matter, the `()` values are placeholders.
    static INTERN: RefCell<HashMap<Arc<str>, ()>> = RefCell::new(HashMap::new());
}

/// Returns an `Arc<str>` holding `text`, shared with earlier calls on this thread when possible.
///
/// A hit in the calling thread's table returns a clone of the stored `Arc`, so equal strings
/// share one allocation. On a miss a fresh `Arc` is allocated; it is recorded for later lookups
/// only while the table holds fewer than [`INTERN_CAP`] entries. Past that cap, new strings
/// are still returned but are no longer remembered.
fn intern(text: &str) -> Arc<str> {
    INTERN.with(|cell| {
        let mut table = cell.borrow_mut();
        match table.get_key_value(text) {
            Some((shared, _)) => Arc::clone(shared),
            None => {
                let fresh: Arc<str> = Arc::from(text);
                if table.len() < INTERN_CAP {
                    table.insert(Arc::clone(&fresh), ());
                }
                fresh
            }
        }
    })
}
40
41#[repr(transparent)]
42#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
43pub struct StatementColumn(pub u32);
44
45impl Deref for StatementColumn {
46 type Target = u32;
47
48 fn deref(&self) -> &Self::Target {
49 &self.0
50 }
51}
52
53impl PartialEq<i32> for StatementColumn {
54 fn eq(&self, other: &i32) -> bool {
55 self.0 == *other as u32
56 }
57}
58
59#[repr(transparent)]
60#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
61pub struct StatementLine(pub u32);
62
63impl Deref for StatementLine {
64 type Target = u32;
65
66 fn deref(&self) -> &Self::Target {
67 &self.0
68 }
69}
70
71impl PartialEq<i32> for StatementLine {
72 fn eq(&self, other: &i32) -> bool {
73 self.0 == *other as u32
74 }
75}
76
/// A piece of text, optionally tied to a position in a statement.
///
/// NOTE: the hand-written `Deserialize` impl maps variant indices 0, 1, 2 to `None`,
/// `Statement`, `Internal`. It also reads `Statement`'s fields positionally as `text`, `line`,
/// `column`. Reordering variants or fields here would silently break decoding of
/// index-based formats.
#[derive(Debug, Clone, PartialEq, Hash, Serialize, Default)]
pub enum Fragment {
    /// No text at all; this is the `Default` variant and its `text()` is empty.
    #[default]
    None,

    /// Text from a statement together with its line and column.
    /// The text is not interned.
    Statement {
        text: Arc<str>,
        line: StatementLine,
        column: StatementColumn,
    },

    /// Text with no statement position; `line()` / `column()` report fixed defaults for it.
    /// The text is interned when built via `Fragment::internal` or deserialization, but not
    /// when built via the `From` impls or `with_text`.
    Internal {
        text: Arc<str>,
    },
}
92
93enum FragmentVariant {
94 None,
95 Statement,
96 Internal,
97}
98
99impl<'de> Deserialize<'de> for FragmentVariant {
100 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
101 where
102 D: de::Deserializer<'de>,
103 {
104 struct VariantVisitor;
105
106 impl<'de> Visitor<'de> for VariantVisitor {
107 type Value = FragmentVariant;
108
109 fn expecting(&self, f: &mut Formatter) -> fmt::Result {
110 f.write_str("variant identifier")
111 }
112
113 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
114 where
115 E: de::Error,
116 {
117 match value {
118 0 => Ok(FragmentVariant::None),
119 1 => Ok(FragmentVariant::Statement),
120 2 => Ok(FragmentVariant::Internal),
121 _ => Err(de::Error::invalid_value(
122 de::Unexpected::Unsigned(value),
123 &"variant index 0 <= i < 3",
124 )),
125 }
126 }
127
128 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
129 where
130 E: de::Error,
131 {
132 match value {
133 "None" => Ok(FragmentVariant::None),
134 "Statement" => Ok(FragmentVariant::Statement),
135 "Internal" => Ok(FragmentVariant::Internal),
136 _ => Err(de::Error::unknown_variant(value, VARIANTS)),
137 }
138 }
139
140 fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
141 where
142 E: de::Error,
143 {
144 match value {
145 b"None" => Ok(FragmentVariant::None),
146 b"Statement" => Ok(FragmentVariant::Statement),
147 b"Internal" => Ok(FragmentVariant::Internal),
148 _ => Err(de::Error::unknown_variant(&String::from_utf8_lossy(value), VARIANTS)),
149 }
150 }
151 }
152
153 deserializer.deserialize_identifier(VariantVisitor)
154 }
155}
156
157enum StatementField {
158 Text,
159 Line,
160 Column,
161}
162
163impl<'de> Deserialize<'de> for StatementField {
164 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165 where
166 D: de::Deserializer<'de>,
167 {
168 struct FieldVisitor;
169
170 impl<'de> Visitor<'de> for FieldVisitor {
171 type Value = StatementField;
172
173 fn expecting(&self, f: &mut Formatter) -> fmt::Result {
174 f.write_str("field identifier")
175 }
176
177 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
178 where
179 E: de::Error,
180 {
181 match value {
182 0 => Ok(StatementField::Text),
183 1 => Ok(StatementField::Line),
184 2 => Ok(StatementField::Column),
185 _ => Err(de::Error::invalid_value(
186 de::Unexpected::Unsigned(value),
187 &"field index 0 <= i < 3",
188 )),
189 }
190 }
191
192 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
193 where
194 E: de::Error,
195 {
196 match value {
197 "text" => Ok(StatementField::Text),
198 "line" => Ok(StatementField::Line),
199 "column" => Ok(StatementField::Column),
200 _ => Err(de::Error::unknown_field(value, STATEMENT_FIELDS)),
201 }
202 }
203 }
204
205 deserializer.deserialize_identifier(FieldVisitor)
206 }
207}
208
209enum InternalField {
210 Text,
211}
212
213impl<'de> Deserialize<'de> for InternalField {
214 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215 where
216 D: de::Deserializer<'de>,
217 {
218 struct FieldVisitor;
219
220 impl<'de> Visitor<'de> for FieldVisitor {
221 type Value = InternalField;
222
223 fn expecting(&self, f: &mut Formatter) -> fmt::Result {
224 f.write_str("field identifier")
225 }
226
227 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
228 where
229 E: de::Error,
230 {
231 match value {
232 0 => Ok(InternalField::Text),
233 _ => Err(de::Error::invalid_value(
234 de::Unexpected::Unsigned(value),
235 &"field index 0 <= i < 1",
236 )),
237 }
238 }
239
240 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
241 where
242 E: de::Error,
243 {
244 match value {
245 "text" => Ok(InternalField::Text),
246 _ => Err(de::Error::unknown_field(value, INTERNAL_FIELDS)),
247 }
248 }
249 }
250
251 deserializer.deserialize_identifier(FieldVisitor)
252 }
253}
254
255struct InternedText(Arc<str>);
256
257impl<'de> Deserialize<'de> for InternedText {
258 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
259 where
260 D: de::Deserializer<'de>,
261 {
262 struct InternedVisitor;
263
264 impl<'de> Visitor<'de> for InternedVisitor {
265 type Value = InternedText;
266
267 fn expecting(&self, f: &mut Formatter) -> fmt::Result {
268 f.write_str("a string")
269 }
270
271 fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Self::Value, E>
272 where
273 E: de::Error,
274 {
275 Ok(InternedText(intern(value)))
276 }
277
278 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
279 where
280 E: de::Error,
281 {
282 Ok(InternedText(intern(value)))
283 }
284
285 fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
286 where
287 E: de::Error,
288 {
289 Ok(InternedText(intern(&value)))
290 }
291 }
292
293 deserializer.deserialize_str(InternedVisitor)
294 }
295}
296
/// Variant names of [`Fragment`] in declaration order (index `i` names variant `i`).
const VARIANTS: &[&str] = &["None", "Statement", "Internal"];
/// Field names of `Fragment::Statement`, in the order its sequence form is read.
const STATEMENT_FIELDS: &[&str] = &["text", "line", "column"];
/// Field names of `Fragment::Internal`.
const INTERNAL_FIELDS: &[&str] = &["text"];
300
/// Hand-written `Deserialize` for [`Fragment`].
///
/// Decodes an enum named `"Fragment"` whose variant identifier is read by `FragmentVariant`
/// (by index or by name). `Statement` and `Internal` are struct variants that may arrive as a
/// sequence (fields in declaration order) or as a map (fields by name, each exactly once).
///
/// The two text-carrying variants deliberately differ:
/// - `Internal` text is decoded through `InternedText`, so it is shared via `intern`.
/// - `Statement` text is decoded as a plain `Arc<str>` and is never interned.
impl<'de> Deserialize<'de> for Fragment {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        // Top-level visitor: reads the variant tag, then dispatches to a per-variant visitor.
        struct FragmentVisitor;

        impl<'de> Visitor<'de> for FragmentVisitor {
            type Value = Fragment;

            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                f.write_str("enum Fragment")
            }

            fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
            where
                A: EnumAccess<'de>,
            {
                let (variant, access) = data.variant::<FragmentVariant>()?;
                match variant {
                    FragmentVariant::None => {
                        // `None` has no payload, but the format must still confirm a unit variant.
                        access.unit_variant()?;
                        Ok(Fragment::None)
                    }
                    FragmentVariant::Statement => {
                        access.struct_variant(STATEMENT_FIELDS, StatementVisitor)
                    }
                    FragmentVariant::Internal => {
                        access.struct_variant(INTERNAL_FIELDS, InternalVisitor)
                    }
                }
            }
        }

        // Payload visitor for `Fragment::Statement`; its text is NOT interned.
        struct StatementVisitor;

        impl<'de> Visitor<'de> for StatementVisitor {
            type Value = Fragment;

            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                f.write_str("struct variant Fragment::Statement")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                // Sequence form: elements are read positionally as text, line, column.
                // A missing element reports its index via `invalid_length`.
                let text: Arc<str> =
                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(0, &self))?;
                let line: StatementLine =
                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(1, &self))?;
                let column: StatementColumn =
                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(2, &self))?;
                Ok(Fragment::Statement {
                    text,
                    line,
                    column,
                })
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Map form: keys may come in any order. A repeated key is an error, unknown
                // keys are rejected by `StatementField`, and every field is required.
                let mut text: Option<Arc<str>> = None;
                let mut line: Option<StatementLine> = None;
                let mut column: Option<StatementColumn> = None;
                while let Some(key) = map.next_key::<StatementField>()? {
                    match key {
                        StatementField::Text => {
                            if text.is_some() {
                                return Err(de::Error::duplicate_field("text"));
                            }
                            text = Some(map.next_value()?);
                        }
                        StatementField::Line => {
                            if line.is_some() {
                                return Err(de::Error::duplicate_field("line"));
                            }
                            line = Some(map.next_value()?);
                        }
                        StatementField::Column => {
                            if column.is_some() {
                                return Err(de::Error::duplicate_field("column"));
                            }
                            column = Some(map.next_value()?);
                        }
                    }
                }
                Ok(Fragment::Statement {
                    text: text.ok_or_else(|| de::Error::missing_field("text"))?,
                    line: line.ok_or_else(|| de::Error::missing_field("line"))?,
                    column: column.ok_or_else(|| de::Error::missing_field("column"))?,
                })
            }
        }

        // Payload visitor for `Fragment::Internal`; its text is interned via `InternedText`.
        struct InternalVisitor;

        impl<'de> Visitor<'de> for InternalVisitor {
            type Value = Fragment;

            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                f.write_str("struct variant Fragment::Internal")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                let text: InternedText =
                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(0, &self))?;
                Ok(Fragment::Internal {
                    text: text.0,
                })
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Same rules as the Statement map form: no duplicates, `text` is required.
                let mut text: Option<Arc<str>> = None;
                while let Some(key) = map.next_key::<InternalField>()? {
                    match key {
                        InternalField::Text => {
                            if text.is_some() {
                                return Err(de::Error::duplicate_field("text"));
                            }
                            let value: InternedText = map.next_value()?;
                            text = Some(value.0);
                        }
                    }
                }
                Ok(Fragment::Internal {
                    text: text.ok_or_else(|| de::Error::missing_field("text"))?,
                })
            }
        }

        deserializer.deserialize_enum("Fragment", VARIANTS, FragmentVisitor)
    }
}
443
444impl Fragment {
445 pub fn text(&self) -> &str {
446 match self {
447 Fragment::None => "",
448 Fragment::Statement {
449 text,
450 ..
451 }
452 | Fragment::Internal {
453 text,
454 ..
455 } => text,
456 }
457 }
458
459 pub fn line(&self) -> StatementLine {
460 match self {
461 Fragment::Statement {
462 line,
463 ..
464 } => *line,
465 _ => StatementLine(1),
466 }
467 }
468
469 pub fn column(&self) -> StatementColumn {
470 match self {
471 Fragment::Statement {
472 column,
473 ..
474 } => *column,
475 _ => StatementColumn(0),
476 }
477 }
478
479 pub fn sub_fragment(&self, offset: usize, length: usize) -> Fragment {
480 let text = self.text();
481 let end = cmp::min(offset + length, text.len());
482 let sub_text = if offset < text.len() {
483 &text[offset..end]
484 } else {
485 ""
486 };
487
488 match self {
489 Fragment::None => Fragment::None,
490 Fragment::Statement {
491 line,
492 column,
493 ..
494 } => Fragment::Statement {
495 text: Arc::from(sub_text),
496 line: *line,
497 column: StatementColumn(column.0 + offset as u32),
498 },
499 Fragment::Internal {
500 ..
501 } => Fragment::Internal {
502 text: Arc::from(sub_text),
503 },
504 }
505 }
506
507 pub fn with_text(&self, text: impl AsRef<str>) -> Fragment {
508 let text = Arc::from(text.as_ref());
509 match self {
510 Fragment::Statement {
511 line,
512 column,
513 ..
514 } => Fragment::Statement {
515 text,
516 line: *line,
517 column: *column,
518 },
519 Fragment::Internal {
520 ..
521 } => Fragment::Internal {
522 text,
523 },
524 Fragment::None => Fragment::Internal {
525 text,
526 },
527 }
528 }
529}
530
531impl Fragment {
532 pub fn internal(text: impl AsRef<str>) -> Self {
533 Fragment::Internal {
534 text: intern(text.as_ref()),
535 }
536 }
537
538 pub fn testing(text: impl AsRef<str>) -> Self {
539 Fragment::Statement {
540 text: Arc::from(text.as_ref()),
541 line: StatementLine(1),
542 column: StatementColumn(0),
543 }
544 }
545
546 pub fn testing_empty() -> Self {
547 Self::testing("")
548 }
549
550 pub fn merge_all(fragments: impl IntoIterator<Item = Fragment>) -> Fragment {
551 let mut fragments: Vec<Fragment> = fragments.into_iter().collect();
552 assert!(!fragments.is_empty());
553
554 fragments.sort();
555
556 let first = fragments.first().unwrap();
557
558 let mut text = String::with_capacity(fragments.iter().map(|f| f.text().len()).sum());
559 for fragment in &fragments {
560 text.push_str(fragment.text());
561 }
562
563 match first {
564 Fragment::None => Fragment::None,
565 Fragment::Statement {
566 line,
567 column,
568 ..
569 } => Fragment::Statement {
570 text: Arc::from(text),
571 line: *line,
572 column: *column,
573 },
574 Fragment::Internal {
575 ..
576 } => Fragment::Internal {
577 text: Arc::from(text),
578 },
579 }
580 }
581
582 pub fn fragment(&self) -> &str {
583 self.text()
584 }
585}
586
587impl AsRef<str> for Fragment {
588 fn as_ref(&self) -> &str {
589 self.text()
590 }
591}
592
593impl Display for Fragment {
594 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
595 Display::fmt(self.text(), f)
596 }
597}
598
599impl PartialOrd for Fragment {
600 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
601 Some(self.cmp(other))
602 }
603}
604
605impl Ord for Fragment {
606 fn cmp(&self, other: &Self) -> Ordering {
607 self.column().cmp(&other.column()).then(self.line().cmp(&other.line()))
608 }
609}
610
611impl Eq for Fragment {}
612
613impl From<String> for Fragment {
614 fn from(s: String) -> Self {
615 Fragment::Internal {
616 text: Arc::from(s),
617 }
618 }
619}
620
621impl From<&str> for Fragment {
622 fn from(s: &str) -> Self {
623 Fragment::Internal {
624 text: Arc::from(s),
625 }
626 }
627}
628
629impl Fragment {
630 pub fn statement(text: impl AsRef<str>, line: u32, column: u32) -> Self {
631 Fragment::Statement {
632 text: Arc::from(text.as_ref()),
633 line: StatementLine(line),
634 column: StatementColumn(column),
635 }
636 }
637
638 pub fn none() -> Self {
639 Fragment::None
640 }
641}
642
643impl PartialEq<str> for Fragment {
644 fn eq(&self, other: &str) -> bool {
645 self.text() == other
646 }
647}
648
649impl PartialEq<&str> for Fragment {
650 fn eq(&self, other: &&str) -> bool {
651 self.text() == *other
652 }
653}
654
655impl PartialEq<String> for Fragment {
656 fn eq(&self, other: &String) -> bool {
657 self.text() == other.as_str()
658 }
659}
660
661impl PartialEq<String> for &Fragment {
662 fn eq(&self, other: &String) -> bool {
663 self.text() == other.as_str()
664 }
665}
666
667pub trait LazyFragment {
668 fn fragment(&self) -> Fragment;
669}
670
671impl<F> LazyFragment for F
672where
673 F: Fn() -> Fragment,
674{
675 fn fragment(&self) -> Fragment {
676 self()
677 }
678}
679
680impl LazyFragment for &Fragment {
681 fn fragment(&self) -> Fragment {
682 (*self).clone()
683 }
684}
685
686impl LazyFragment for Fragment {
687 fn fragment(&self) -> Fragment {
688 self.clone()
689 }
690}
691
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use postcard::{from_bytes, to_allocvec};

    use super::*;

    /// Serializes `fragment` with postcard and decodes it back.
    fn round_trip(fragment: &Fragment) -> Fragment {
        let bytes = to_allocvec(fragment).unwrap();
        from_bytes(&bytes).unwrap()
    }

    /// Borrows the text of an `Internal` fragment, panicking on any other variant.
    fn internal_text(fragment: &Fragment) -> &Arc<str> {
        if let Fragment::Internal {
            text,
        } = fragment
        {
            text
        } else {
            panic!("expected Internal fragment, got {fragment:?}")
        }
    }

    /// Borrows the text of a `Statement` fragment, panicking on any other variant.
    fn statement_text(fragment: &Fragment) -> &Arc<str> {
        if let Fragment::Statement {
            text,
            ..
        } = fragment
        {
            text
        } else {
            panic!("expected Statement fragment, got {fragment:?}")
        }
    }

    #[test]
    fn two_internal_constructions_share_storage() {
        let first = Fragment::internal("price_share_test_a");
        let second = Fragment::internal("price_share_test_a");
        assert!(Arc::ptr_eq(internal_text(&first), internal_text(&second)));
    }

    #[test]
    fn deserialize_internal_shares_storage_with_construction() {
        let constructed = Fragment::internal("vwap_share_test");
        let decoded = round_trip(&constructed);
        assert_eq!(constructed, decoded);
        assert!(Arc::ptr_eq(internal_text(&constructed), internal_text(&decoded)));
    }

    #[test]
    fn statement_text_is_not_interned() {
        let first = Fragment::statement("select arbitrary rql text", 1, 0);
        let second = Fragment::statement("select arbitrary rql text", 1, 0);
        assert!(!Arc::ptr_eq(statement_text(&first), statement_text(&second)));

        // Decode the same bytes twice; each decode must allocate its own text.
        let bytes = to_allocvec(&first).unwrap();
        let decoded_one: Fragment = from_bytes(&bytes).unwrap();
        let decoded_two: Fragment = from_bytes(&bytes).unwrap();
        assert!(!Arc::ptr_eq(statement_text(&decoded_one), statement_text(&decoded_two)));
    }

    #[test]
    fn round_trip_preserves_value_for_each_variant() {
        [
            Fragment::None,
            Fragment::statement("from foo map { a }", 7, 3),
            Fragment::internal("round_trip_internal"),
        ]
        .iter()
        .for_each(|fragment| assert_eq!(*fragment, round_trip(fragment)));
    }

    #[test]
    fn round_trip_preserves_statement_line_and_column() {
        let decoded = round_trip(&Fragment::statement("xy", 42, 99));
        assert_eq!(decoded.line(), StatementLine(42));
        assert_eq!(decoded.column(), StatementColumn(99));
        assert_eq!(decoded.text(), "xy");
    }
}