rig/
one_or_many.rs

1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A non-empty list: either a single item or a list of items of type `T`.
///
/// If a single item is present, `first` holds it and `rest` is empty.
/// If multiple items are present, `first` holds the first item and `rest` holds the others,
/// in order.
///
/// IMPORTANT: a `OneOrMany` can never be empty. Its fields are private, so values can only be
/// built through this module: `OneOrMany::one`, the fallible `OneOrMany::many` /
/// `OneOrMany::merge` (which return `EmptyListError` for empty input), or by deserializing a
/// non-empty sequence.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OneOrMany<T> {
    /// First item in the list (always present).
    first: T,
    /// Remaining items after `first`; empty for a single-item list.
    rest: Vec<T>,
}
21
22/// Error type for when trying to create a OneOrMany object with an empty vector.
23#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28    /// Get the first item in the list.
29    pub fn first(&self) -> T {
30        self.first.clone()
31    }
32
33    /// Get the rest of the items in the list (excluding the first one).
34    pub fn rest(&self) -> Vec<T> {
35        self.rest.clone()
36    }
37
38    /// After `OneOrMany<T>` is created, add an item of type T to the `rest`.
39    pub fn push(&mut self, item: T) {
40        self.rest.push(item);
41    }
42
43    /// After `OneOrMany<T>` is created, insert an item of type T at an index.
44    pub fn insert(&mut self, index: usize, item: T) {
45        if index == 0 {
46            let old_first = std::mem::replace(&mut self.first, item);
47            self.rest.insert(0, old_first);
48        } else {
49            self.rest.insert(index - 1, item);
50        }
51    }
52
53    /// Length of all items in `OneOrMany<T>`.
54    pub fn len(&self) -> usize {
55        1 + self.rest.len()
56    }
57
58    /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`.
59    /// This method is required when the method `len` exists.
60    pub fn is_empty(&self) -> bool {
61        false
62    }
63
64    /// Create a `OneOrMany` object with a single item of any type.
65    pub fn one(item: T) -> Self {
66        OneOrMany {
67            first: item,
68            rest: vec![],
69        }
70    }
71
72    /// Create a `OneOrMany` object with a vector of items of any type.
73    pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74    where
75        I: IntoIterator<Item = T>,
76    {
77        let mut iter = items.into_iter();
78        Ok(OneOrMany {
79            first: match iter.next() {
80                Some(item) => item,
81                None => return Err(EmptyListError),
82            },
83            rest: iter.collect(),
84        })
85    }
86
87    /// Merge a list of OneOrMany items into a single OneOrMany item.
88    pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89    where
90        I: IntoIterator<Item = OneOrMany<T>>,
91    {
92        let items = one_or_many_items
93            .into_iter()
94            .flat_map(|one_or_many| one_or_many.into_iter())
95            .collect::<Vec<_>>();
96
97        OneOrMany::many(items)
98    }
99
100    /// Specialized map function for OneOrMany objects.
101    ///
102    /// Since OneOrMany objects have *atleast* 1 item, using `.collect::<Vec<_>>()` and
103    /// `OneOrMany::many()` is fallible resulting in unergonomic uses of `.expect` or `.unwrap`.
104    /// This function bypasses those hurdles by directly constructing the `OneOrMany` struct.
105    pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106        OneOrMany {
107            first: op(self.first),
108            rest: self.rest.into_iter().map(op).collect(),
109        }
110    }
111
112    /// Specialized try map function for OneOrMany objects.
113    ///
114    /// Same as `OneOrMany::map` but fallible.
115    pub(crate) fn try_map<U, E, F: FnMut(T) -> Result<U, E>>(
116        self,
117        mut op: F,
118    ) -> Result<OneOrMany<U>, E> {
119        Ok(OneOrMany {
120            first: op(self.first)?,
121            rest: self
122                .rest
123                .into_iter()
124                .map(op)
125                .collect::<Result<Vec<_>, E>>()?,
126        })
127    }
128
129    pub fn iter(&self) -> Iter<T> {
130        Iter {
131            first: Some(&self.first),
132            rest: self.rest.iter(),
133        }
134    }
135
136    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137        IterMut {
138            first: Some(&mut self.first),
139            rest: self.rest.iter_mut(),
140        }
141    }
142}
143
144// ================================================================
145// Implementations of Iterator for OneOrMany
146//   - OneOrMany<T>::iter() -> iterate over references of T objects
147//   - OneOrMany<T>::into_iter() -> iterate over owned T objects
148//   - OneOrMany<T>::iter_mut() -> iterate over mutable references of T objects
149// ================================================================
150
/// Borrowing iterator over a `OneOrMany<T>`, returned by `OneOrMany::iter()`.
///
/// Yields `&T`: first the `first` item, then each element of `rest`.
#[derive(Debug)]
pub struct Iter<'a, T> {
    /// The first item, until it has been yielded (then `None`).
    first: Option<&'a T>,
    /// Iterator over the remaining items.
    rest: std::slice::Iter<'a, T>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a needless
// `T: Clone` bound, while copying shared references never requires cloning `T`.
impl<T> Clone for Iter<'_, T> {
    fn clone(&self) -> Self {
        Iter {
            first: self.first,
            rest: self.rest.clone(),
        }
    }
}
157
158/// Implement `Iterator` for `Iter<T>`.
159/// The Item type of the `Iterator` trait is a reference of `T`.
160impl<'a, T> Iterator for Iter<'a, T> {
161    type Item = &'a T;
162
163    fn next(&mut self) -> Option<Self::Item> {
164        if let Some(first) = self.first.take() {
165            Some(first)
166        } else {
167            self.rest.next()
168        }
169    }
170
171    fn size_hint(&self) -> (usize, Option<usize>) {
172        let first = if self.first.is_some() { 1 } else { 0 };
173        let max = self.rest.size_hint().1.unwrap_or(0) + first;
174        if max > 0 {
175            (1, Some(max))
176        } else {
177            (0, Some(0))
178        }
179    }
180}
181
182/// Struct returned by call to `OneOrMany::into_iter()`.
183pub struct IntoIter<T> {
184    // Owned.
185    first: Option<T>,
186    rest: std::vec::IntoIter<T>,
187}
188
189/// Implement `Iterator` for `IntoIter<T>`.
190impl<T: Clone> IntoIterator for OneOrMany<T> {
191    type Item = T;
192    type IntoIter = IntoIter<T>;
193
194    fn into_iter(self) -> Self::IntoIter {
195        IntoIter {
196            first: Some(self.first),
197            rest: self.rest.into_iter(),
198        }
199    }
200}
201
202/// Implement `Iterator` for `IntoIter<T>`.
203/// The Item type of the `Iterator` trait is an owned `T`.
204impl<T: Clone> Iterator for IntoIter<T> {
205    type Item = T;
206
207    fn next(&mut self) -> Option<Self::Item> {
208        match self.first.take() {
209            Some(first) => Some(first),
210            _ => self.rest.next(),
211        }
212    }
213
214    fn size_hint(&self) -> (usize, Option<usize>) {
215        let first = if self.first.is_some() { 1 } else { 0 };
216        let max = self.rest.size_hint().1.unwrap_or(0) + first;
217        if max > 0 {
218            (1, Some(max))
219        } else {
220            (0, Some(0))
221        }
222    }
223}
224
/// Mutably borrowing iterator over a `OneOrMany<T>`, returned by `OneOrMany::iter_mut()`.
///
/// Yields `&mut T`: first the `first` item, then each element of `rest`.
#[derive(Debug)]
pub struct IterMut<'a, T> {
    /// The first item, until it has been yielded (then `None`).
    first: Option<&'a mut T>,
    /// Iterator over the remaining items.
    rest: std::slice::IterMut<'a, T>,
}
231
232// Implement `Iterator` for `IterMut<T>`.
233// The Item type of the `Iterator` trait is a mutable reference of `OneOrMany<T>`.
234impl<'a, T> Iterator for IterMut<'a, T> {
235    type Item = &'a mut T;
236
237    fn next(&mut self) -> Option<Self::Item> {
238        if let Some(first) = self.first.take() {
239            Some(first)
240        } else {
241            self.rest.next()
242        }
243    }
244
245    fn size_hint(&self) -> (usize, Option<usize>) {
246        let first = if self.first.is_some() { 1 } else { 0 };
247        let max = self.rest.size_hint().1.unwrap_or(0) + first;
248        if max > 0 {
249            (1, Some(max))
250        } else {
251            (0, Some(0))
252        }
253    }
254}
255
256// Serialize `OneOrMany<T>` into a json sequence (akin to `Vec<T>`)
257impl<T: Clone> Serialize for OneOrMany<T>
258where
259    T: Serialize,
260{
261    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
262    where
263        S: Serializer,
264    {
265        // Create a sequence serializer with the length of the OneOrMany object.
266        let mut seq = serializer.serialize_seq(Some(self.len()))?;
267        // Serialize each element in the OneOrMany object.
268        for e in self.iter() {
269            seq.serialize_element(e)?;
270        }
271        // End the sequence serialization.
272        seq.end()
273    }
274}
275
276// Deserialize a json sequence into `OneOrMany<T>` (akin to `Vec<T>`).
277// Additionally, deserialize a single element (of type `T`) into `OneOrMany<T>` using
278// `OneOrMany::one`, which is helpful to avoid `Either<T, OneOrMany<T>>` typing in serde structs.
279impl<'de, T> Deserialize<'de> for OneOrMany<T>
280where
281    T: Deserialize<'de> + Clone,
282{
283    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
284    where
285        D: Deserializer<'de>,
286    {
287        // Visitor struct to handle deserialization.
288        struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
289
290        impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
291        where
292            T: Deserialize<'de> + Clone,
293        {
294            type Value = OneOrMany<T>;
295
296            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
297                formatter.write_str("a sequence of at least one element")
298            }
299
300            // Visit a sequence and deserialize it into OneOrMany.
301            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
302            where
303                A: SeqAccess<'de>,
304            {
305                // Get the first element.
306                let first = seq
307                    .next_element()?
308                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
309
310                // Collect the rest of the elements.
311                let mut rest = Vec::new();
312                while let Some(value) = seq.next_element()? {
313                    rest.push(value);
314                }
315
316                // Return the deserialized OneOrMany object.
317                Ok(OneOrMany { first, rest })
318            }
319        }
320
321        // Deserialize any type into OneOrMany using the visitor.
322        deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
323    }
324}
325
326// A special deserialize_with function for fields with `OneOrMany<T: FromStr>`
327//
328// Usage:
329// #[derive(Deserialize)]
330// struct MyStruct {
331//     #[serde(deserialize_with = "string_or_one_or_many")]
332//     field: OneOrMany<String>,
333// }
334pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
335where
336    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
337    D: Deserializer<'de>,
338{
339    struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
340
341    impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
342    where
343        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
344    {
345        type Value = OneOrMany<T>;
346
347        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
348            formatter.write_str("a string or sequence")
349        }
350
351        fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
352        where
353            E: de::Error,
354        {
355            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
356            Ok(OneOrMany::one(item))
357        }
358
359        fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
360        where
361            A: SeqAccess<'de>,
362        {
363            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
364        }
365
366        fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
367        where
368            M: MapAccess<'de>,
369        {
370            let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
371            Ok(OneOrMany::one(item))
372        }
373    }
374
375    deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
376}
377
378// A variant of the `string_or_one_or_many` function that returns an `Option<OneOrMany<T>>`.
379//
380// Usage:
381// #[derive(Deserialize)]
382// struct MyStruct {
383//     #[serde(deserialize_with = "string_or_option_one_or_many")]
384//     field: Option<OneOrMany<String>>,
385// }
386pub fn string_or_option_one_or_many<'de, T, D>(
387    deserializer: D,
388) -> Result<Option<OneOrMany<T>>, D::Error>
389where
390    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
391    D: Deserializer<'de>,
392{
393    struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
394
395    impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
396    where
397        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
398    {
399        type Value = Option<OneOrMany<T>>;
400
401        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
402            formatter.write_str("null, a string, or a sequence")
403        }
404
405        fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
406        where
407            E: de::Error,
408        {
409            Ok(None)
410        }
411
412        fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
413        where
414            E: de::Error,
415        {
416            Ok(None)
417        }
418
419        fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
420        where
421            D: Deserializer<'de>,
422        {
423            string_or_one_or_many(deserializer).map(Some)
424        }
425    }
426
427    deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
428}
429
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    /// Collect every item of `one_or_many`, in iteration order, so tests can compare whole
    /// lists instead of checking items one index at a time (which passes vacuously when
    /// items are missing).
    fn items<T: Clone>(one_or_many: &OneOrMany<T>) -> Vec<T> {
        one_or_many.iter().cloned().collect()
    }

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(items(&one_or_many), vec!["hello"]);
    }

    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(items(&one_or_many), vec!["hello", "word"]);
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(one_or_many.into_iter().collect::<Vec<_>>(), vec!["hello"]);
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello", "word"]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.iter().count(), 3);
        assert_eq!(items(&merged), vec!["hello", "word", "sup"]);
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        let merged = OneOrMany::<String>::merge(Vec::new());
        assert!(matches!(merged, Err(EmptyListError)));
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many.iter_mut().for_each(|item| item.push('!'));
        assert_eq!(items(&one_or_many), vec!["hello!"]);
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        if let Some(first) = one_or_many.iter_mut().next() {
            first.push_str(" world");
        }
        assert_eq!(items(&one_or_many), vec!["hello world", "word"]);
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_push_and_insert() {
        let mut one_or_many = OneOrMany::one(2);
        one_or_many.push(4);
        one_or_many.insert(0, 1);
        one_or_many.insert(2, 3);
        one_or_many.insert(4, 5);

        assert_eq!(one_or_many.len(), 5);
        assert!(!one_or_many.is_empty());
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3, 4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_insert_out_of_bounds_panics() {
        let mut one_or_many = OneOrMany::one(1);
        one_or_many.insert(2, 3);
    }

    #[test]
    fn test_map_and_try_map() {
        let one_or_many = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let doubled = one_or_many.clone().map(|x| x * 2);
        assert_eq!(doubled.first(), 2);
        assert_eq!(doubled.rest(), vec![4, 6]);

        let stringified: Result<OneOrMany<String>, String> =
            one_or_many.clone().try_map(|x| Ok(x.to_string()));
        assert_eq!(
            stringified.unwrap().rest(),
            vec!["2".to_string(), "3".to_string()]
        );

        let failed: Result<OneOrMany<i32>, String> = one_or_many.try_map(|x| {
            if x == 2 {
                Err(format!("bad item {}", x))
            } else {
                Ok(x)
            }
        });
        assert_eq!(failed, Err("bad item 2".to_string()));
    }

    // Testing serialization
    #[test]
    fn test_serialize() {
        let one = OneOrMany::one(1);
        assert_eq!(serde_json::to_value(&one).unwrap(), json!([1]));

        let many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        let encoded = serde_json::to_string(&many).unwrap();
        assert_eq!(encoded, "[1,2,3]");

        let decoded: OneOrMany<i32> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, many);
    }

    // Testing deserialization
    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[test]
    fn test_deserialize_rejects_empty_and_bare_values() {
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!([])).is_err());
        assert!(serde_json::from_str::<OneOrMany<i32>>("[]").is_err());
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!(1)).is_err());
        assert!(serde_json::from_value::<DummyStruct>(json!({"field": []})).is_err());
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_map_as_single_item() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}