1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A collection that always holds at least one `T`.
///
/// The first element is stored in its own field, separate from the remainder, so the
/// "never empty" invariant is guaranteed by the type's shape rather than by runtime checks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OneOrMany<T> {
    /// The first element; always present.
    first: T,
    /// Every element after the first, in order (empty when there is only one element).
    rest: Vec<T>,
}
21
22#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28 pub fn first(&self) -> T {
30 self.first.clone()
31 }
32
33 pub fn first_ref(&self) -> &T {
35 &self.first
36 }
37
38 pub fn first_mut(&mut self) -> &mut T {
40 &mut self.first
41 }
42
43 pub fn last(&self) -> T {
45 self.rest
46 .last()
47 .cloned()
48 .unwrap_or_else(|| self.first.clone())
49 }
50
51 pub fn last_ref(&self) -> &T {
53 self.rest.last().unwrap_or(&self.first)
54 }
55
56 pub fn last_mut(&mut self) -> &mut T {
58 self.rest.last_mut().unwrap_or(&mut self.first)
59 }
60
61 pub fn rest(&self) -> Vec<T> {
63 self.rest.clone()
64 }
65
66 pub fn push(&mut self, item: T) {
68 self.rest.push(item);
69 }
70
71 pub fn insert(&mut self, index: usize, item: T) {
73 if index == 0 {
74 let old_first = std::mem::replace(&mut self.first, item);
75 self.rest.insert(0, old_first);
76 } else {
77 self.rest.insert(index - 1, item);
78 }
79 }
80
81 pub fn len(&self) -> usize {
83 1 + self.rest.len()
84 }
85
86 pub fn is_empty(&self) -> bool {
89 false
90 }
91
92 pub fn one(item: T) -> Self {
94 OneOrMany {
95 first: item,
96 rest: vec![],
97 }
98 }
99
100 pub fn many<I>(items: I) -> Result<Self, EmptyListError>
102 where
103 I: IntoIterator<Item = T>,
104 {
105 let mut iter = items.into_iter();
106 Ok(OneOrMany {
107 first: match iter.next() {
108 Some(item) => item,
109 None => return Err(EmptyListError),
110 },
111 rest: iter.collect(),
112 })
113 }
114
115 pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
117 where
118 I: IntoIterator<Item = OneOrMany<T>>,
119 {
120 let items = one_or_many_items
121 .into_iter()
122 .flat_map(|one_or_many| one_or_many.into_iter())
123 .collect::<Vec<_>>();
124
125 OneOrMany::many(items)
126 }
127
128 pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
134 OneOrMany {
135 first: op(self.first),
136 rest: self.rest.into_iter().map(op).collect(),
137 }
138 }
139
140 pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
144 where
145 F: FnMut(T) -> Result<U, E>,
146 {
147 Ok(OneOrMany {
148 first: op(self.first)?,
149 rest: self
150 .rest
151 .into_iter()
152 .map(op)
153 .collect::<Result<Vec<_>, E>>()?,
154 })
155 }
156
157 pub fn iter(&self) -> Iter<'_, T> {
158 Iter {
159 first: Some(&self.first),
160 rest: self.rest.iter(),
161 }
162 }
163
164 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
165 IterMut {
166 first: Some(&mut self.first),
167 rest: self.rest.iter_mut(),
168 }
169 }
170}
171
/// Borrowing iterator over a [`OneOrMany`], created by `OneOrMany::iter`.
///
/// Yields the first item, then each remaining item in order.
pub struct Iter<'a, T> {
    /// The first item, until it has been yielded.
    first: Option<&'a T>,
    /// Iterator over the items after the first.
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// The upper bound is the exact remaining count. The lower bound is reported as 1
    /// whenever anything remains, and 0 once exhausted; this is conservative but valid.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_first = usize::from(self.first.is_some());
        let remaining = pending_first + self.rest.size_hint().1.unwrap_or(0);
        match remaining {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}
209
210pub struct IntoIter<T> {
212 first: Option<T>,
214 rest: std::vec::IntoIter<T>,
215}
216
217impl<T> IntoIterator for OneOrMany<T>
219where
220 T: Clone,
221{
222 type Item = T;
223 type IntoIter = IntoIter<T>;
224
225 fn into_iter(self) -> Self::IntoIter {
226 IntoIter {
227 first: Some(self.first),
228 rest: self.rest.into_iter(),
229 }
230 }
231}
232
233impl<T> Iterator for IntoIter<T>
236where
237 T: Clone,
238{
239 type Item = T;
240
241 fn next(&mut self) -> Option<Self::Item> {
242 match self.first.take() {
243 Some(first) => Some(first),
244 _ => self.rest.next(),
245 }
246 }
247
248 fn size_hint(&self) -> (usize, Option<usize>) {
249 let first = if self.first.is_some() { 1 } else { 0 };
250 let max = self.rest.size_hint().1.unwrap_or(0) + first;
251 if max > 0 {
252 (1, Some(max))
253 } else {
254 (0, Some(0))
255 }
256 }
257}
258
/// Mutably-borrowing iterator over a [`OneOrMany`], created by `OneOrMany::iter_mut`.
///
/// Yields the first item, then each remaining item in order.
pub struct IterMut<'a, T> {
    /// The first item, until it has been yielded.
    first: Option<&'a mut T>,
    /// Iterator over the items after the first.
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// The upper bound is the exact remaining count. The lower bound is reported as 1
    /// whenever anything remains, and 0 once exhausted; this is conservative but valid.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_first = usize::from(self.first.is_some());
        let remaining = pending_first + self.rest.size_hint().1.unwrap_or(0);
        match remaining {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}
289
290impl<T> Serialize for OneOrMany<T>
292where
293 T: Serialize + Clone,
294{
295 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
296 where
297 S: Serializer,
298 {
299 let mut seq = serializer.serialize_seq(Some(self.len()))?;
301 for e in self.iter() {
303 seq.serialize_element(e)?;
304 }
305 seq.end()
307 }
308}
309
310impl<'de, T> Deserialize<'de> for OneOrMany<T>
314where
315 T: Deserialize<'de> + Clone,
316{
317 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
318 where
319 D: Deserializer<'de>,
320 {
321 struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
323
324 impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
325 where
326 T: Deserialize<'de> + Clone,
327 {
328 type Value = OneOrMany<T>;
329
330 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
331 formatter.write_str("a sequence of at least one element")
332 }
333
334 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
336 where
337 A: SeqAccess<'de>,
338 {
339 let first = seq
341 .next_element()?
342 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
343
344 let mut rest = Vec::new();
346 while let Some(value) = seq.next_element()? {
347 rest.push(value);
348 }
349
350 Ok(OneOrMany { first, rest })
352 }
353 }
354
355 deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
357 }
358}
359
360pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
369where
370 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
371 D: Deserializer<'de>,
372{
373 struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
374
375 impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
376 where
377 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
378 {
379 type Value = OneOrMany<T>;
380
381 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
382 formatter.write_str("a string or sequence")
383 }
384
385 fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
386 where
387 E: de::Error,
388 {
389 let item = FromStr::from_str(value).map_err(de::Error::custom)?;
390 Ok(OneOrMany::one(item))
391 }
392
393 fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
394 where
395 A: SeqAccess<'de>,
396 {
397 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
398 }
399
400 fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
401 where
402 M: MapAccess<'de>,
403 {
404 let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
405 Ok(OneOrMany::one(item))
406 }
407 }
408
409 deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
410}
411
412pub fn string_or_option_one_or_many<'de, T, D>(
421 deserializer: D,
422) -> Result<Option<OneOrMany<T>>, D::Error>
423where
424 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
425 D: Deserializer<'de>,
426{
427 struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
428
429 impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
430 where
431 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
432 {
433 type Value = Option<OneOrMany<T>>;
434
435 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
436 formatter.write_str("null, a string, or a sequence")
437 }
438
439 fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
440 where
441 E: de::Error,
442 {
443 Ok(None)
444 }
445
446 fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
447 where
448 E: de::Error,
449 {
450 Ok(None)
451 }
452
453 fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
454 where
455 D: Deserializer<'de>,
456 {
457 string_or_one_or_many(deserializer).map(Some)
458 }
459 }
460
461 deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
462}
463
#[cfg(test)]
mod test {
    use serde::Deserialize;
    use serde_json::json;

    use super::*;

    /// Collects the items of a `OneOrMany<String>` as `&str`s for whole-collection asserts.
    fn strs(one_or_many: &OneOrMany<String>) -> Vec<&str> {
        one_or_many.iter().map(String::as_str).collect()
    }

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(strs(&one_or_many), vec!["hello"]);
    }

    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(strs(&one_or_many), vec!["hello", "word"]);
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        let items: Vec<String> = one_or_many.into_iter().collect();
        assert_eq!(items, vec!["hello".to_string()]);
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        let items: Vec<String> = one_or_many.into_iter().collect();
        assert_eq!(items, vec!["hello".to_string(), "word".to_string()]);
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.len(), 3);
        assert_eq!(strs(&merged), vec!["hello", "word", "sup"]);
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        let groups: Vec<OneOrMany<String>> = Vec::new();

        assert!(OneOrMany::merge(groups).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);
        one_or_many.iter_mut().for_each(|item| item.push('!'));
        assert_eq!(strs(&one_or_many), vec!["hello!"]);
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        if let Some(first) = one_or_many.iter_mut().next() {
            first.push_str(" world");
        }
        assert_eq!(strs(&one_or_many), vec!["hello world", "word"]);
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_insert_and_last() {
        let mut one_or_many = OneOrMany::one(2);
        one_or_many.insert(0, 1); // new first item: [1, 2]
        one_or_many.insert(2, 4); // append at len: [1, 2, 4]
        one_or_many.insert(2, 3); // middle: [1, 2, 3, 4]

        assert_eq!(one_or_many.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3, 4]);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.last(), 4);

        *one_or_many.last_mut() = 5;
        assert_eq!(*one_or_many.last_ref(), 5);
    }

    #[test]
    fn test_last_single() {
        let mut one_or_many = OneOrMany::one(7);

        assert_eq!(one_or_many.last(), 7);
        *one_or_many.last_mut() += 1;
        assert_eq!(one_or_many.first(), 8);
    }

    #[test]
    fn test_serialize_round_trip() {
        let one = OneOrMany::one(1);
        assert_eq!(serde_json::to_value(&one).unwrap(), json!([1]));

        let many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        let value = serde_json::to_value(&many).unwrap();
        assert_eq!(value, json!([1, 2, 3]));

        let back: OneOrMany<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(back, many);
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_empty_list() {
        let result: Result<OneOrMany<i32>, _> = serde_json::from_value(json!([]));

        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}