1extern crate alloc;
6
7use alloc::string::{String, ToString};
8use alloc::vec::Vec;
9
10use crate::document::{EureDocument, NodeId};
11use crate::identifier::Identifier;
12
13use super::{DocumentParser, ParseError, ParseErrorKind};
14
/// Extension key under which a union node may name the variant it selects.
///
/// `UnionParser` looks this key up in the node's extensions and, when present
/// and parseable as a string, only tries the registered variant of that name.
pub const VARIANT: Identifier = Identifier::new_unchecked("variant");
17
18pub struct UnionParser<'doc, T> {
45 doc: &'doc EureDocument,
46 node_id: NodeId,
47 variant_name: Option<String>,
49 variant_result: Option<Result<T, ParseError>>,
51 priority_result: Option<T>,
53 other_results: Vec<(String, T)>,
55 other_failures: Vec<(String, ParseError)>,
57}
58
59impl<'doc, T> UnionParser<'doc, T> {
60 pub(crate) fn new(doc: &'doc EureDocument, node_id: NodeId) -> Self {
62 let variant_name = Self::extract_variant(doc, node_id);
63
64 Self {
65 doc,
66 node_id,
67 variant_name,
68 variant_result: None,
69 priority_result: None,
70 other_results: Vec::new(),
71 other_failures: Vec::new(),
72 }
73 }
74
75 fn extract_variant(doc: &EureDocument, node_id: NodeId) -> Option<String> {
77 let node = doc.node(node_id);
78 let variant_node_id = node.extensions.get(&VARIANT)?;
79 doc.parse::<&str>(*variant_node_id).ok().map(String::from)
80 }
81
82 pub fn variant<P>(mut self, name: &str, parser: P) -> Self
87 where
88 P: DocumentParser<'doc, Output = T> + 'doc,
89 {
90 if let Some(ref vn) = self.variant_name {
91 if vn == name && self.variant_result.is_none() {
93 self.variant_result = Some(parser.parse(self.doc, self.node_id));
94 }
95 } else if self.priority_result.is_none()
96 && let Ok(value) = parser.parse(self.doc, self.node_id)
97 {
98 self.priority_result = Some(value);
100 }
101 self
102 }
103
104 pub fn other<P>(mut self, name: &str, parser: P) -> Self
109 where
110 P: DocumentParser<'doc, Output = T> + 'doc,
111 {
112 if let Some(ref vn) = self.variant_name {
113 if vn == name && self.variant_result.is_none() {
115 self.variant_result = Some(parser.parse(self.doc, self.node_id));
116 }
117 } else {
118 if self.priority_result.is_none() {
120 match parser.parse(self.doc, self.node_id) {
121 Ok(value) => self.other_results.push((name.to_string(), value)),
122 Err(e) => self.other_failures.push((name.to_string(), e)),
123 }
124 }
125 }
126 self
127 }
128
129 pub fn parse(self) -> Result<T, ParseError> {
137 let node_id = self.node_id;
138
139 if let Some(variant_name) = self.variant_name {
141 return self.variant_result.unwrap_or_else(|| {
142 Err(ParseError {
143 node_id,
144 kind: ParseErrorKind::UnknownVariant(variant_name),
145 })
146 });
147 }
148
149 if let Some(value) = self.priority_result {
151 return Ok(value);
152 }
153
154 match self.other_results.len() {
156 0 => Err(Self::no_match_error(node_id, self.other_failures)),
157 1 => Ok(self.other_results.into_iter().next().unwrap().1),
158 _ => Err(ParseError {
159 node_id,
160 kind: ParseErrorKind::AmbiguousUnion(
161 self.other_results
162 .into_iter()
163 .map(|(name, _)| name)
164 .collect(),
165 ),
166 }),
167 }
168 }
169
170 fn no_match_error(node_id: NodeId, failures: Vec<(String, ParseError)>) -> ParseError {
172 failures
175 .into_iter()
176 .next()
177 .map(|(_, e)| e)
178 .unwrap_or(ParseError {
179 node_id,
180 kind: ParseErrorKind::NoMatchingVariant,
181 })
182 }
183}
184
185impl EureDocument {
186 pub fn parse_union<T>(&self, node_id: NodeId) -> UnionParser<'_, T> {
197 UnionParser::new(self, node_id)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::document::node::NodeValue;
    use crate::text::Text;
    use crate::value::PrimitiveValue;

    fn identifier(s: &str) -> Identifier {
        s.parse().unwrap()
    }

    #[derive(Debug, PartialEq)]
    enum TestEnum {
        Foo,
        Bar,
    }

    /// Builds a document whose root is the plain-text value `text`.
    fn create_text_doc(text: &str) -> EureDocument {
        let mut doc = EureDocument::new();
        let root_id = doc.get_root_id();
        doc.node_mut(root_id).content =
            NodeValue::Primitive(PrimitiveValue::Text(Text::plaintext(text.to_string())));
        doc
    }

    /// Builds a document whose root holds the text `content` and carries a
    /// `variant` extension whose text is `variant`.
    fn create_doc_with_variant(content: &str, variant: &str) -> EureDocument {
        let mut doc = EureDocument::new();
        let root_id = doc.get_root_id();

        doc.node_mut(root_id).content =
            NodeValue::Primitive(PrimitiveValue::Text(Text::plaintext(content.to_string())));

        let variant_node_id = doc
            .add_extension(identifier("variant"), root_id)
            .unwrap()
            .node_id;
        doc.node_mut(variant_node_id).content =
            NodeValue::Primitive(PrimitiveValue::Text(Text::plaintext(variant.to_string())));

        doc
    }

    /// Test parser: yields `value` when the node is the text `expected`.
    /// Otherwise it fails with `UnknownVariant(<actual text>)`.
    fn match_text(
        doc: &EureDocument,
        id: NodeId,
        expected: &str,
        value: TestEnum,
    ) -> Result<TestEnum, ParseError> {
        let s: &str = doc.parse(id)?;
        if s == expected {
            Ok(value)
        } else {
            Err(ParseError {
                node_id: id,
                kind: ParseErrorKind::UnknownVariant(s.to_string()),
            })
        }
    }

    #[test]
    fn test_union_single_match() {
        let doc = create_text_doc("foo");
        let root_id = doc.get_root_id();

        let result: TestEnum = doc
            .parse_union(root_id)
            .variant("foo", |doc: &EureDocument, id| match_text(doc, id, "foo", TestEnum::Foo))
            .variant("bar", |doc: &EureDocument, id| match_text(doc, id, "bar", TestEnum::Bar))
            .parse()
            .unwrap();

        assert_eq!(result, TestEnum::Foo);
    }

    #[test]
    fn test_union_priority_short_circuit() {
        let doc = create_text_doc("value");
        let root_id = doc.get_root_id();

        // Both variants would succeed; the first priority success wins.
        let result: String = doc
            .parse_union(root_id)
            .variant("first", |doc: &EureDocument, id| doc.parse::<String>(id))
            .variant("second", |doc: &EureDocument, id| doc.parse::<String>(id))
            .parse()
            .unwrap();

        assert_eq!(result, "value");
    }

    #[test]
    fn test_union_no_match() {
        let doc = create_text_doc("baz");
        let root_id = doc.get_root_id();

        let err = doc
            .parse_union::<TestEnum>(root_id)
            .variant("foo", |doc: &EureDocument, id| match_text(doc, id, "foo", TestEnum::Foo))
            .parse()
            .unwrap_err();

        // Priority failures are not recorded, so the generic error is reported.
        assert_eq!(err.node_id, root_id);
        assert_eq!(err.kind, ParseErrorKind::NoMatchingVariant);
    }

    #[test]
    fn test_union_no_match_reports_first_other_failure() {
        let doc = create_text_doc("baz");
        let root_id = doc.get_root_id();

        let err = doc
            .parse_union::<TestEnum>(root_id)
            .variant("foo", |doc: &EureDocument, id| match_text(doc, id, "foo", TestEnum::Foo))
            .other("bar", |doc: &EureDocument, id| match_text(doc, id, "bar", TestEnum::Bar))
            .parse()
            .unwrap_err();

        assert_eq!(err.kind, ParseErrorKind::UnknownVariant("baz".to_string()));
    }

    #[test]
    fn test_union_single_other_match() {
        let doc = create_text_doc("foo");
        let root_id = doc.get_root_id();

        let result: TestEnum = doc
            .parse_union(root_id)
            .other("foo", |doc: &EureDocument, id| match_text(doc, id, "foo", TestEnum::Foo))
            .other("bar", |doc: &EureDocument, id| match_text(doc, id, "bar", TestEnum::Bar))
            .parse()
            .unwrap();

        assert_eq!(result, TestEnum::Foo);
    }

    #[test]
    fn test_union_priority_beats_other() {
        let doc = create_text_doc("value");
        let root_id = doc.get_root_id();

        // `other` is registered first, but a priority success still wins.
        let result: TestEnum = doc
            .parse_union(root_id)
            .other("bar", |_, _| Ok(TestEnum::Bar))
            .variant("foo", |_, _| Ok(TestEnum::Foo))
            .parse()
            .unwrap();

        assert_eq!(result, TestEnum::Foo);
    }

    #[test]
    fn test_union_ambiguous_others() {
        let doc = create_text_doc("value");
        let root_id = doc.get_root_id();

        let err = doc
            .parse_union::<TestEnum>(root_id)
            .other("foo", |_, _| Ok(TestEnum::Foo))
            .other("bar", |_, _| Ok(TestEnum::Bar))
            .parse()
            .unwrap_err();

        assert_eq!(err.node_id, root_id);
        assert_eq!(
            err.kind,
            ParseErrorKind::AmbiguousUnion(["foo", "bar"].iter().map(|s| s.to_string()).collect())
        );
    }

    #[test]
    fn test_union_with_borrowed_str_fn_pointer() {
        fn parse_str(doc: &EureDocument, id: NodeId) -> Result<&str, ParseError> {
            doc.parse(id)
        }

        let doc = create_text_doc("hello");
        let root_id = doc.get_root_id();

        let result: &str = doc
            .parse_union(root_id)
            .variant("str", parse_str)
            .parse()
            .unwrap();

        assert_eq!(result, "hello");
    }

    #[test]
    fn test_union_with_borrowed_str_closure() {
        let doc = create_text_doc("world");
        let root_id = doc.get_root_id();

        let result: &str = doc
            .parse_union(root_id)
            .variant(
                "str",
                (|doc, id| doc.parse(id)) as fn(&EureDocument, NodeId) -> Result<&str, ParseError>,
            )
            .parse()
            .unwrap();

        assert_eq!(result, "world");
    }

    #[test]
    fn test_variant_extension_match_success() {
        let doc = create_doc_with_variant("anything", "baz");
        let root_id = doc.get_root_id();

        let result: TestEnum = doc
            .parse_union(root_id)
            .variant("foo", |_, _| Ok(TestEnum::Foo))
            .other("baz", |_, _| Ok(TestEnum::Bar))
            .parse()
            .unwrap();

        assert_eq!(result, TestEnum::Bar);
    }

    #[test]
    fn test_variant_extension_unknown() {
        let doc = create_doc_with_variant("anything", "unknown");
        let root_id = doc.get_root_id();

        let err = doc
            .parse_union::<TestEnum>(root_id)
            .variant("foo", |_, _| Ok(TestEnum::Foo))
            .other("baz", |_, _| Ok(TestEnum::Bar))
            .parse()
            .unwrap_err();

        assert_eq!(err.node_id, root_id);
        assert_eq!(
            err.kind,
            ParseErrorKind::UnknownVariant("unknown".to_string())
        );
    }

    #[test]
    fn test_variant_extension_match_parse_failure() {
        let doc = create_doc_with_variant("anything", "baz");
        let root_id = doc.get_root_id();

        let err = doc
            .parse_union::<TestEnum>(root_id)
            .variant("foo", |_, _| Ok(TestEnum::Foo))
            .other("baz", |_, id| {
                Err(ParseError {
                    node_id: id,
                    kind: ParseErrorKind::MissingField("test".to_string()),
                })
            })
            .parse()
            .unwrap_err();

        // The tagged alternative's own error is surfaced unchanged.
        assert_eq!(err.node_id, root_id);
        assert_eq!(err.kind, ParseErrorKind::MissingField("test".to_string()));
    }
}