1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A list that always holds at least one item.
///
/// The mandatory item lives in `first`; any further items are kept, in order,
/// in `rest`. Because `first` is not optional, an empty value cannot be
/// represented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OneOrMany<T> {
    /// The item that is always present.
    first: T,
    /// Every item after the first one; may be empty.
    rest: Vec<T>,
}
21
22#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28 pub fn first(&self) -> T {
30 self.first.clone()
31 }
32
33 pub fn first_ref(&self) -> &T {
35 &self.first
36 }
37
38 pub fn first_mut(&mut self) -> &mut T {
40 &mut self.first
41 }
42
43 pub fn last(&self) -> T {
45 self.rest
46 .last()
47 .cloned()
48 .unwrap_or_else(|| self.first.clone())
49 }
50
51 pub fn last_ref(&self) -> &T {
53 self.rest.last().unwrap_or(&self.first)
54 }
55
56 pub fn last_mut(&mut self) -> &mut T {
58 self.rest.last_mut().unwrap_or(&mut self.first)
59 }
60
61 pub fn rest(&self) -> Vec<T> {
63 self.rest.clone()
64 }
65
66 pub fn push(&mut self, item: T) {
68 self.rest.push(item);
69 }
70
71 pub fn insert(&mut self, index: usize, item: T) {
73 if index == 0 {
74 let old_first = std::mem::replace(&mut self.first, item);
75 self.rest.insert(0, old_first);
76 } else {
77 self.rest.insert(index - 1, item);
78 }
79 }
80
81 pub fn len(&self) -> usize {
83 1 + self.rest.len()
84 }
85
86 pub fn is_empty(&self) -> bool {
89 false
90 }
91
92 pub fn one(item: T) -> Self {
94 OneOrMany {
95 first: item,
96 rest: vec![],
97 }
98 }
99
100 pub fn many<I>(items: I) -> Result<Self, EmptyListError>
102 where
103 I: IntoIterator<Item = T>,
104 {
105 let mut iter = items.into_iter();
106 Ok(OneOrMany {
107 first: match iter.next() {
108 Some(item) => item,
109 None => return Err(EmptyListError),
110 },
111 rest: iter.collect(),
112 })
113 }
114
115 pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
117 where
118 I: IntoIterator<Item = OneOrMany<T>>,
119 {
120 let items = one_or_many_items
121 .into_iter()
122 .flat_map(|one_or_many| one_or_many.into_iter())
123 .collect::<Vec<_>>();
124
125 OneOrMany::many(items)
126 }
127
128 pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
134 OneOrMany {
135 first: op(self.first),
136 rest: self.rest.into_iter().map(op).collect(),
137 }
138 }
139
140 pub(crate) fn from_iter_optional<I>(items: I) -> Option<Self>
142 where
143 I: IntoIterator<Item = T>,
144 {
145 let mut iter = items.into_iter();
146 let first = iter.next()?;
147 Some(OneOrMany {
148 first,
149 rest: iter.collect(),
150 })
151 }
152
153 pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
157 where
158 F: FnMut(T) -> Result<U, E>,
159 {
160 Ok(OneOrMany {
161 first: op(self.first)?,
162 rest: self
163 .rest
164 .into_iter()
165 .map(op)
166 .collect::<Result<Vec<_>, E>>()?,
167 })
168 }
169
170 pub fn iter(&self) -> Iter<'_, T> {
171 Iter {
172 first: Some(&self.first),
173 rest: self.rest.iter(),
174 }
175 }
176
177 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
178 IterMut {
179 first: Some(&mut self.first),
180 rest: self.rest.iter_mut(),
181 }
182 }
183}
184
/// Borrowing iterator that yields the first item and then the remaining ones.
pub struct Iter<'a, T> {
    /// The leading item; becomes `None` once it has been yielded.
    first: Option<&'a T>,
    /// The items that follow the leading one.
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Yields the leading item first, then falls through to the remaining items.
    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// Reports the exact number of items left as the upper bound; the lower
    /// bound is `1` while anything is left and `0` otherwise.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        match remaining {
            0 => (0, Some(0)),
            upper => (1, Some(upper)),
        }
    }
}
222
223pub struct IntoIter<T> {
225 first: Option<T>,
227 rest: std::vec::IntoIter<T>,
228}
229
230impl<T> IntoIterator for OneOrMany<T>
232where
233 T: Clone,
234{
235 type Item = T;
236 type IntoIter = IntoIter<T>;
237
238 fn into_iter(self) -> Self::IntoIter {
239 IntoIter {
240 first: Some(self.first),
241 rest: self.rest.into_iter(),
242 }
243 }
244}
245
246impl<T> Iterator for IntoIter<T>
249where
250 T: Clone,
251{
252 type Item = T;
253
254 fn next(&mut self) -> Option<Self::Item> {
255 match self.first.take() {
256 Some(first) => Some(first),
257 _ => self.rest.next(),
258 }
259 }
260
261 fn size_hint(&self) -> (usize, Option<usize>) {
262 let first = if self.first.is_some() { 1 } else { 0 };
263 let max = self.rest.size_hint().1.unwrap_or(0) + first;
264 if max > 0 {
265 (1, Some(max))
266 } else {
267 (0, Some(0))
268 }
269 }
270}
271
/// Mutably borrowing iterator that yields the first item and then the remaining ones.
pub struct IterMut<'a, T> {
    /// The leading item; becomes `None` once it has been yielded.
    first: Option<&'a mut T>,
    /// The items that follow the leading one.
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    /// Yields the leading item first, then falls through to the remaining items.
    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// Reports the exact number of items left as the upper bound; the lower
    /// bound is `1` while anything is left and `0` otherwise.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        match remaining {
            0 => (0, Some(0)),
            upper => (1, Some(upper)),
        }
    }
}
302
303impl<T> Serialize for OneOrMany<T>
305where
306 T: Serialize + Clone,
307{
308 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
309 where
310 S: Serializer,
311 {
312 let mut seq = serializer.serialize_seq(Some(self.len()))?;
314 for e in self.iter() {
316 seq.serialize_element(e)?;
317 }
318 seq.end()
320 }
321}
322
323impl<'de, T> Deserialize<'de> for OneOrMany<T>
327where
328 T: Deserialize<'de> + Clone,
329{
330 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
331 where
332 D: Deserializer<'de>,
333 {
334 struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
336
337 impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
338 where
339 T: Deserialize<'de> + Clone,
340 {
341 type Value = OneOrMany<T>;
342
343 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
344 formatter.write_str("a sequence of at least one element")
345 }
346
347 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
349 where
350 A: SeqAccess<'de>,
351 {
352 let first = seq
354 .next_element()?
355 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
356
357 let mut rest = Vec::new();
359 while let Some(value) = seq.next_element()? {
360 rest.push(value);
361 }
362
363 Ok(OneOrMany { first, rest })
365 }
366 }
367
368 deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
370 }
371}
372
373pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
382where
383 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
384 D: Deserializer<'de>,
385{
386 struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
387
388 impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
389 where
390 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
391 {
392 type Value = OneOrMany<T>;
393
394 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
395 formatter.write_str("a string or sequence")
396 }
397
398 fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
399 where
400 E: de::Error,
401 {
402 let item = FromStr::from_str(value).map_err(de::Error::custom)?;
403 Ok(OneOrMany::one(item))
404 }
405
406 fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
407 where
408 A: SeqAccess<'de>,
409 {
410 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
411 }
412
413 fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
414 where
415 M: MapAccess<'de>,
416 {
417 let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
418 Ok(OneOrMany::one(item))
419 }
420 }
421
422 deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
423}
424
425pub fn string_or_option_one_or_many<'de, T, D>(
434 deserializer: D,
435) -> Result<Option<OneOrMany<T>>, D::Error>
436where
437 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
438 D: Deserializer<'de>,
439{
440 struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
441
442 impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
443 where
444 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
445 {
446 type Value = Option<OneOrMany<T>>;
447
448 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
449 formatter.write_str("null, a string, or a sequence")
450 }
451
452 fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
453 where
454 E: de::Error,
455 {
456 Ok(None)
457 }
458
459 fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
460 where
461 E: de::Error,
462 {
463 Ok(None)
464 }
465
466 fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
467 where
468 D: Deserializer<'de>,
469 {
470 string_or_one_or_many(deserializer).map(Some)
471 }
472 }
473
474 deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
475}
476
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    /// Single-item fixture used by the `*_single` tests.
    fn hello() -> OneOrMany<String> {
        OneOrMany::one("hello".to_string())
    }

    /// Two-item fixture shared by the multi-item tests.
    fn hello_word() -> OneOrMany<String> {
        OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap()
    }

    #[test]
    fn test_single() {
        let single = hello();

        assert_eq!(single.iter().count(), 1);

        for item in single.iter() {
            assert_eq!(item, "hello");
        }
    }

    #[test]
    fn test() {
        let pair = hello_word();

        assert_eq!(pair.iter().count(), 2);

        for (index, item) in pair.iter().enumerate() {
            match index {
                0 => assert_eq!(item, "hello"),
                1 => assert_eq!(item, "word"),
                _ => {}
            }
        }
    }

    #[test]
    fn test_size_hint() {
        let single = OneOrMany::one("bar".to_string());
        let (lower, upper) = single.iter().size_hint();
        assert_eq!(lower, 1);
        assert_eq!(upper, Some(1));

        let words = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut triple = OneOrMany::many(words).expect("this should never fail");

        let (lower, upper) = triple.iter().size_hint();
        assert_eq!(lower, 1);
        assert_eq!(upper, Some(3));

        let (lower, upper) = triple.clone().into_iter().size_hint();
        assert_eq!(lower, 1);
        assert_eq!(upper, Some(3));

        let (lower, upper) = triple.iter_mut().size_hint();
        assert_eq!(lower, 1);
        assert_eq!(upper, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let single = hello();

        assert_eq!(single.clone().into_iter().count(), 1);

        for item in single {
            assert_eq!(item, "hello".to_string());
        }
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let pair = hello_word();

        assert_eq!(pair.clone().into_iter().count(), 2);

        for (index, item) in pair.into_iter().enumerate() {
            match index {
                0 => assert_eq!(item, "hello".to_string()),
                1 => assert_eq!(item, "word".to_string()),
                _ => {}
            }
        }
    }

    #[test]
    fn test_one_or_many_merge() {
        let pair = hello_word();
        let single = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![pair, single]).unwrap();

        assert_eq!(merged.iter().count(), 3);

        for (index, item) in merged.iter().enumerate() {
            match index {
                0 => assert_eq!(item, "hello"),
                1 => assert_eq!(item, "word"),
                2 => assert_eq!(item, "sup"),
                _ => {}
            }
        }
    }

    #[test]
    fn test_mut_single() {
        let mut single = hello();

        assert_eq!(single.iter_mut().count(), 1);

        for item in single.iter_mut() {
            assert_eq!(item, "hello");
        }
    }

    #[test]
    fn test_mut() {
        let mut pair = hello_word();

        assert_eq!(pair.iter_mut().count(), 2);

        for (index, item) in pair.iter_mut().enumerate() {
            match index {
                0 => {
                    // Mutation through the iterator must be visible on the item.
                    item.push_str(" world");
                    assert_eq!(item, "hello world");
                }
                1 => assert_eq!(item, "word"),
                _ => {}
            }
        }
    }

    #[test]
    fn test_one_or_many_error() {
        let result = OneOrMany::<String>::many(Vec::new());

        assert!(result.is_err())
    }

    #[test]
    fn test_len_single() {
        assert_eq!(hello().len(), 1);
    }

    #[test]
    fn test_len_many() {
        assert_eq!(hello_word().len(), 2);
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let field = json_data["field"].clone();
        let one_or_many: OneOrMany<i32> = serde_json::from_value(field).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let field = json_data["field"].clone();
        let one_or_many: OneOrMany<serde_json::Value> = serde_json::from_value(field).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    /// Struct whose field accepts a string, a sequence, or a map.
    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    /// Struct whose field additionally accepts `null`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    /// Item type that can be built from a bare string or from a map.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    /// Internally tagged enum, used to exercise the `null` path.
    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
            {
                "role": "assistant",
                "content": null
            }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}