// rig/one_or_many.rs

1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A non-empty list: either a single item or several items of type `T`.
///
/// If a single item is present, `first` holds it and `rest` is empty.
/// If multiple items are present, `first` holds the first item and `rest` holds the others,
/// in order.
///
/// IMPORTANT: a `OneOrMany` can never be empty. Construct one with `OneOrMany::one`,
/// `OneOrMany::many` / `OneOrMany::merge` (which return `EmptyListError` for empty input),
/// or by deserializing a non-empty sequence.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OneOrMany<T> {
    /// First item in the list (always present).
    first: T,
    /// Remaining items in the list, after `first`; empty for a single-item list.
    rest: Vec<T>,
}
21
/// Error returned when a `OneOrMany` would have to be built from an empty collection.
///
/// It carries no payload and has no underlying source; `Display` renders a fixed message.
#[derive(Debug)]
pub struct EmptyListError;

impl std::fmt::Display for EmptyListError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Cannot create OneOrMany with an empty vector.")
    }
}

impl std::error::Error for EmptyListError {}
26
27impl<T: Clone> OneOrMany<T> {
28    /// Get the first item in the list.
29    pub fn first(&self) -> T {
30        self.first.clone()
31    }
32
33    /// Get a reference to the first item in the list.
34    pub fn first_ref(&self) -> &T {
35        &self.first
36    }
37
38    /// Get a mutable reference to the first item in the list.
39    pub fn first_mut(&mut self) -> &mut T {
40        &mut self.first
41    }
42
43    /// Get the last item in the list.
44    pub fn last(&self) -> T {
45        self.rest
46            .last()
47            .cloned()
48            .unwrap_or_else(|| self.first.clone())
49    }
50
51    /// Get a reference to the last item in the list.
52    pub fn last_ref(&self) -> &T {
53        self.rest.last().unwrap_or(&self.first)
54    }
55
56    /// Get a mutable reference to the last item in the list.
57    pub fn last_mut(&mut self) -> &mut T {
58        self.rest.last_mut().unwrap_or(&mut self.first)
59    }
60
61    /// Get the rest of the items in the list (excluding the first one).
62    pub fn rest(&self) -> Vec<T> {
63        self.rest.clone()
64    }
65
66    /// After `OneOrMany<T>` is created, add an item of type T to the `rest`.
67    pub fn push(&mut self, item: T) {
68        self.rest.push(item);
69    }
70
71    /// After `OneOrMany<T>` is created, insert an item of type T at an index.
72    pub fn insert(&mut self, index: usize, item: T) {
73        if index == 0 {
74            let old_first = std::mem::replace(&mut self.first, item);
75            self.rest.insert(0, old_first);
76        } else {
77            self.rest.insert(index - 1, item);
78        }
79    }
80
81    /// Length of all items in `OneOrMany<T>`.
82    pub fn len(&self) -> usize {
83        1 + self.rest.len()
84    }
85
86    /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`.
87    /// This method is required when the method `len` exists.
88    pub fn is_empty(&self) -> bool {
89        false
90    }
91
92    /// Create a `OneOrMany` object with a single item of any type.
93    pub fn one(item: T) -> Self {
94        OneOrMany {
95            first: item,
96            rest: vec![],
97        }
98    }
99
100    /// Create a `OneOrMany` object with a vector of items of any type.
101    pub fn many<I>(items: I) -> Result<Self, EmptyListError>
102    where
103        I: IntoIterator<Item = T>,
104    {
105        let mut iter = items.into_iter();
106        Ok(OneOrMany {
107            first: match iter.next() {
108                Some(item) => item,
109                None => return Err(EmptyListError),
110            },
111            rest: iter.collect(),
112        })
113    }
114
115    /// Merge a list of OneOrMany items into a single OneOrMany item.
116    pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
117    where
118        I: IntoIterator<Item = OneOrMany<T>>,
119    {
120        let items = one_or_many_items
121            .into_iter()
122            .flat_map(|one_or_many| one_or_many.into_iter())
123            .collect::<Vec<_>>();
124
125        OneOrMany::many(items)
126    }
127
128    /// Specialized map function for OneOrMany objects.
129    ///
130    /// Since OneOrMany objects have *atleast* 1 item, using `.collect::<Vec<_>>()` and
131    /// `OneOrMany::many()` is fallible resulting in unergonomic uses of `.expect` or `.unwrap`.
132    /// This function bypasses those hurdles by directly constructing the `OneOrMany` struct.
133    pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
134        OneOrMany {
135            first: op(self.first),
136            rest: self.rest.into_iter().map(op).collect(),
137        }
138    }
139
140    /// Build a `OneOrMany` from an iterator when the caller can naturally handle an empty input.
141    pub(crate) fn from_iter_optional<I>(items: I) -> Option<Self>
142    where
143        I: IntoIterator<Item = T>,
144    {
145        let mut iter = items.into_iter();
146        let first = iter.next()?;
147        Some(OneOrMany {
148            first,
149            rest: iter.collect(),
150        })
151    }
152
153    /// Specialized try map function for OneOrMany objects.
154    ///
155    /// Same as `OneOrMany::map` but fallible.
156    pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
157    where
158        F: FnMut(T) -> Result<U, E>,
159    {
160        Ok(OneOrMany {
161            first: op(self.first)?,
162            rest: self
163                .rest
164                .into_iter()
165                .map(op)
166                .collect::<Result<Vec<_>, E>>()?,
167        })
168    }
169
170    pub fn iter(&self) -> Iter<'_, T> {
171        Iter {
172            first: Some(&self.first),
173            rest: self.rest.iter(),
174        }
175    }
176
177    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
178        IterMut {
179            first: Some(&mut self.first),
180            rest: self.rest.iter_mut(),
181        }
182    }
183}
184
// ================================================================
// Iterator types for OneOrMany
//   - OneOrMany<T>::iter()      -> iterates over `&T`
//   - OneOrMany<T>::into_iter() -> iterates over owned `T`
//   - OneOrMany<T>::iter_mut()  -> iterates over `&mut T`
// ================================================================
191
/// Borrowing iterator over a `OneOrMany<T>`, returned by `OneOrMany::iter()`.
///
/// Yields the first item, then each item of `rest` in order.
pub struct Iter<'a, T> {
    // Head element still to be yielded; becomes `None` once handed out.
    first: Option<&'a T>,
    // Iterator over the tail elements.
    rest: std::slice::Iter<'a, T>,
}

/// `Iterator` implementation for `Iter`, yielding `&'a T`.
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Hand out the head first; once it is gone, defer to the tail.
        self.first.take().or_else(|| self.rest.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_head = usize::from(self.first.is_some());
        let upper = pending_head + self.rest.size_hint().1.unwrap_or(0);
        // The upper bound is exact; the lower bound is reported as 1 whenever anything remains.
        match upper {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}
222
223/// Struct returned by call to `OneOrMany::into_iter()`.
224pub struct IntoIter<T> {
225    // Owned.
226    first: Option<T>,
227    rest: std::vec::IntoIter<T>,
228}
229
230/// Implement `Iterator` for `IntoIter<T>`.
231impl<T> IntoIterator for OneOrMany<T>
232where
233    T: Clone,
234{
235    type Item = T;
236    type IntoIter = IntoIter<T>;
237
238    fn into_iter(self) -> Self::IntoIter {
239        IntoIter {
240            first: Some(self.first),
241            rest: self.rest.into_iter(),
242        }
243    }
244}
245
246/// Implement `Iterator` for `IntoIter<T>`.
247/// The Item type of the `Iterator` trait is an owned `T`.
248impl<T> Iterator for IntoIter<T>
249where
250    T: Clone,
251{
252    type Item = T;
253
254    fn next(&mut self) -> Option<Self::Item> {
255        match self.first.take() {
256            Some(first) => Some(first),
257            _ => self.rest.next(),
258        }
259    }
260
261    fn size_hint(&self) -> (usize, Option<usize>) {
262        let first = if self.first.is_some() { 1 } else { 0 };
263        let max = self.rest.size_hint().1.unwrap_or(0) + first;
264        if max > 0 {
265            (1, Some(max))
266        } else {
267            (0, Some(0))
268        }
269    }
270}
271
/// Mutably-borrowing iterator over a `OneOrMany<T>`, returned by `OneOrMany::iter_mut()`.
///
/// Yields the first item, then each item of `rest` in order.
pub struct IterMut<'a, T> {
    // Head element still to be yielded; becomes `None` once handed out.
    first: Option<&'a mut T>,
    // Iterator over the tail elements.
    rest: std::slice::IterMut<'a, T>,
}

/// Implement `Iterator` for `IterMut<T>`.
/// The Item type of the `Iterator` trait is a mutable reference to `T`.
impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        // Hand out the head first; once it is gone, defer to the tail.
        self.first.take().or_else(|| self.rest.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The upper bound is exact; the lower bound is reported as 1 whenever anything remains.
        let first = usize::from(self.first.is_some());
        let max = self.rest.size_hint().1.unwrap_or(0) + first;
        if max > 0 {
            (1, Some(max))
        } else {
            (0, Some(0))
        }
    }
}
302
303// Serialize `OneOrMany<T>` into a json sequence (akin to `Vec<T>`)
304impl<T> Serialize for OneOrMany<T>
305where
306    T: Serialize + Clone,
307{
308    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
309    where
310        S: Serializer,
311    {
312        // Create a sequence serializer with the length of the OneOrMany object.
313        let mut seq = serializer.serialize_seq(Some(self.len()))?;
314        // Serialize each element in the OneOrMany object.
315        for e in self.iter() {
316            seq.serialize_element(e)?;
317        }
318        // End the sequence serialization.
319        seq.end()
320    }
321}
322
323// Deserialize a json sequence into `OneOrMany<T>` (akin to `Vec<T>`).
324// Additionally, deserialize a single element (of type `T`) into `OneOrMany<T>` using
325// `OneOrMany::one`, which is helpful to avoid `Either<T, OneOrMany<T>>` typing in serde structs.
326impl<'de, T> Deserialize<'de> for OneOrMany<T>
327where
328    T: Deserialize<'de> + Clone,
329{
330    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
331    where
332        D: Deserializer<'de>,
333    {
334        // Visitor struct to handle deserialization.
335        struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
336
337        impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
338        where
339            T: Deserialize<'de> + Clone,
340        {
341            type Value = OneOrMany<T>;
342
343            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
344                formatter.write_str("a sequence of at least one element")
345            }
346
347            // Visit a sequence and deserialize it into OneOrMany.
348            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
349            where
350                A: SeqAccess<'de>,
351            {
352                // Get the first element.
353                let first = seq
354                    .next_element()?
355                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
356
357                // Collect the rest of the elements.
358                let mut rest = Vec::new();
359                while let Some(value) = seq.next_element()? {
360                    rest.push(value);
361                }
362
363                // Return the deserialized OneOrMany object.
364                Ok(OneOrMany { first, rest })
365            }
366        }
367
368        // Deserialize any type into OneOrMany using the visitor.
369        deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
370    }
371}
372
373// A special deserialize_with function for fields with `OneOrMany<T: FromStr>`
374//
375// Usage:
376// #[derive(Deserialize)]
377// struct MyStruct {
378//     #[serde(deserialize_with = "string_or_one_or_many")]
379//     field: OneOrMany<String>,
380// }
381pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
382where
383    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
384    D: Deserializer<'de>,
385{
386    struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
387
388    impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
389    where
390        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
391    {
392        type Value = OneOrMany<T>;
393
394        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
395            formatter.write_str("a string or sequence")
396        }
397
398        fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
399        where
400            E: de::Error,
401        {
402            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
403            Ok(OneOrMany::one(item))
404        }
405
406        fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
407        where
408            A: SeqAccess<'de>,
409        {
410            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
411        }
412
413        fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
414        where
415            M: MapAccess<'de>,
416        {
417            let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
418            Ok(OneOrMany::one(item))
419        }
420    }
421
422    deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
423}
424
425// A variant of the `string_or_one_or_many` function that returns an `Option<OneOrMany<T>>`.
426//
427// Usage:
428// #[derive(Deserialize)]
429// struct MyStruct {
430//     #[serde(deserialize_with = "string_or_option_one_or_many")]
431//     field: Option<OneOrMany<String>>,
432// }
433pub fn string_or_option_one_or_many<'de, T, D>(
434    deserializer: D,
435) -> Result<Option<OneOrMany<T>>, D::Error>
436where
437    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
438    D: Deserializer<'de>,
439{
440    struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
441
442    impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
443    where
444        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
445    {
446        type Value = Option<OneOrMany<T>>;
447
448        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
449            formatter.write_str("null, a string, or a sequence")
450        }
451
452        fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
453        where
454            E: de::Error,
455        {
456            Ok(None)
457        }
458
459        fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
460        where
461            E: de::Error,
462        {
463            Ok(None)
464        }
465
466        fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
467        where
468            D: Deserializer<'de>,
469        {
470            string_or_one_or_many(deserializer).map(Some)
471        }
472    }
473
474    deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
475}
476
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    // ---------------------------------------------------------------
    // Construction and iteration
    // ---------------------------------------------------------------

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(one_or_many.iter().collect::<Vec<_>>(), vec!["hello"]);
    }

    #[test]
    fn test_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(one_or_many.iter().collect::<Vec<_>>(), vec!["hello", "word"]);
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.len(), 3);
        assert_eq!(merged.iter().collect::<Vec<_>>(), vec!["hello", "word", "sup"]);
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        assert!(OneOrMany::<String>::merge(Vec::new()).is_err());
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_from_iter_optional() {
        assert!(OneOrMany::<i32>::from_iter_optional(Vec::new()).is_none());
        assert_eq!(
            OneOrMany::from_iter_optional(1..=3),
            Some(OneOrMany::many(vec![1, 2, 3]).unwrap())
        );
    }

    // ---------------------------------------------------------------
    // Accessors and mutation
    // ---------------------------------------------------------------

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many.iter_mut().for_each(|item| item.push('!'));
        // The mutation must be visible through the container afterwards.
        assert_eq!(one_or_many.first(), "hello!");
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        for (i, item) in one_or_many.iter_mut().enumerate() {
            if i == 0 {
                item.push_str(" world");
            }
        }
        assert_eq!(one_or_many.first(), "hello world");
        assert_eq!(one_or_many.rest(), vec!["word".to_string()]);
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_first_and_last() {
        let single = OneOrMany::one(1);
        assert_eq!(single.first(), 1);
        assert_eq!(single.last(), 1);
        assert_eq!(single.last_ref(), &1);
        assert!(single.rest().is_empty());

        let mut many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(*many.first_ref(), 1);
        assert_eq!(many.last(), 3);

        *many.first_mut() = 10;
        *many.last_mut() = 30;
        assert_eq!(many.iter().copied().collect::<Vec<_>>(), vec![10, 2, 30]);
    }

    #[test]
    fn test_push_and_insert() {
        let mut items = OneOrMany::one("b");
        items.insert(0, "a"); // new head
        items.push("d");
        items.insert(2, "c"); // middle
        items.insert(4, "e"); // index == len appends

        assert_eq!(items.len(), 5);
        assert_eq!(
            items.into_iter().collect::<Vec<_>>(),
            vec!["a", "b", "c", "d", "e"]
        );
    }

    #[test]
    #[should_panic]
    fn test_insert_out_of_bounds_panics() {
        let mut items = OneOrMany::one(1);
        items.insert(2, 5);
    }

    #[test]
    fn test_map_and_try_map() {
        let nums = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(
            nums.map(|n| n * 2),
            OneOrMany::many(vec![2, 4, 6]).unwrap()
        );

        let parsed = OneOrMany::many(vec!["1", "2"])
            .unwrap()
            .try_map(|s| s.parse::<i32>());
        assert_eq!(parsed.unwrap(), OneOrMany::many(vec![1, 2]).unwrap());

        let failed = OneOrMany::many(vec!["1", "x"])
            .unwrap()
            .try_map(|s| s.parse::<i32>());
        assert!(failed.is_err());
    }

    // ---------------------------------------------------------------
    // Serialization / deserialization
    // ---------------------------------------------------------------

    #[test]
    fn test_serialize_roundtrip() {
        let original = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let value = serde_json::to_value(&original).unwrap();
        assert_eq!(value, json!([1, 2, 3]));

        let back: OneOrMany<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(back, original);
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_empty_list_fails() {
        let result: Result<OneOrMany<i32>, _> = serde_json::from_value(json!([]));
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_bare_value_fails() {
        // The plain `Deserialize` impl only accepts sequences.
        let result: Result<OneOrMany<i32>, _> = serde_json::from_value(json!(5));
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        let field = dummy.field.expect("field should be present");
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        let field = dummy.field.expect("field should be present");
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}