// rig/one_or_many.rs

1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A non-empty list: either a single item or several items of type `T`.
///
/// The first item is always stored in `first`. If more items are present, they are stored,
/// in order, in `rest`; a single-item list therefore has an empty `rest`.
///
/// IMPORTANT: a `OneOrMany` can never be empty. Because the fields are private, values can
/// only be created through this module's API — `OneOrMany::one`, `OneOrMany::many`,
/// `OneOrMany::merge`, or deserialization from a non-empty sequence. `OneOrMany::many` and
/// `OneOrMany::merge` return [`EmptyListError`] when given no items.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OneOrMany<T> {
    /// First item in the list.
    first: T,
    /// Remaining items in the list (everything after `first`), in order.
    rest: Vec<T>,
}
21
22/// Error type for when trying to create a OneOrMany object with an empty vector.
23#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28    /// Get the first item in the list.
29    pub fn first(&self) -> T {
30        self.first.clone()
31    }
32
33    /// Get a reference to the first item in the list.
34    pub fn first_ref(&self) -> &T {
35        &self.first
36    }
37
38    /// Get a mutable reference to the first item in the list.
39    pub fn first_mut(&mut self) -> &mut T {
40        &mut self.first
41    }
42
43    /// Get the last item in the list.
44    pub fn last(&self) -> T {
45        self.rest
46            .last()
47            .cloned()
48            .unwrap_or_else(|| self.first.clone())
49    }
50
51    /// Get a reference to the last item in the list.
52    pub fn last_ref(&self) -> &T {
53        self.rest.last().unwrap_or(&self.first)
54    }
55
56    /// Get a mutable reference to the last item in the list.
57    pub fn last_mut(&mut self) -> &mut T {
58        self.rest.last_mut().unwrap_or(&mut self.first)
59    }
60
61    /// Get the rest of the items in the list (excluding the first one).
62    pub fn rest(&self) -> Vec<T> {
63        self.rest.clone()
64    }
65
66    /// After `OneOrMany<T>` is created, add an item of type T to the `rest`.
67    pub fn push(&mut self, item: T) {
68        self.rest.push(item);
69    }
70
71    /// After `OneOrMany<T>` is created, insert an item of type T at an index.
72    pub fn insert(&mut self, index: usize, item: T) {
73        if index == 0 {
74            let old_first = std::mem::replace(&mut self.first, item);
75            self.rest.insert(0, old_first);
76        } else {
77            self.rest.insert(index - 1, item);
78        }
79    }
80
81    /// Length of all items in `OneOrMany<T>`.
82    pub fn len(&self) -> usize {
83        1 + self.rest.len()
84    }
85
86    /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`.
87    /// This method is required when the method `len` exists.
88    pub fn is_empty(&self) -> bool {
89        false
90    }
91
92    /// Create a `OneOrMany` object with a single item of any type.
93    pub fn one(item: T) -> Self {
94        OneOrMany {
95            first: item,
96            rest: vec![],
97        }
98    }
99
100    /// Create a `OneOrMany` object with a vector of items of any type.
101    pub fn many<I>(items: I) -> Result<Self, EmptyListError>
102    where
103        I: IntoIterator<Item = T>,
104    {
105        let mut iter = items.into_iter();
106        Ok(OneOrMany {
107            first: match iter.next() {
108                Some(item) => item,
109                None => return Err(EmptyListError),
110            },
111            rest: iter.collect(),
112        })
113    }
114
115    /// Merge a list of OneOrMany items into a single OneOrMany item.
116    pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
117    where
118        I: IntoIterator<Item = OneOrMany<T>>,
119    {
120        let items = one_or_many_items
121            .into_iter()
122            .flat_map(|one_or_many| one_or_many.into_iter())
123            .collect::<Vec<_>>();
124
125        OneOrMany::many(items)
126    }
127
128    /// Specialized map function for OneOrMany objects.
129    ///
130    /// Since OneOrMany objects have *atleast* 1 item, using `.collect::<Vec<_>>()` and
131    /// `OneOrMany::many()` is fallible resulting in unergonomic uses of `.expect` or `.unwrap`.
132    /// This function bypasses those hurdles by directly constructing the `OneOrMany` struct.
133    pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
134        OneOrMany {
135            first: op(self.first),
136            rest: self.rest.into_iter().map(op).collect(),
137        }
138    }
139
140    /// Specialized try map function for OneOrMany objects.
141    ///
142    /// Same as `OneOrMany::map` but fallible.
143    pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
144    where
145        F: FnMut(T) -> Result<U, E>,
146    {
147        Ok(OneOrMany {
148            first: op(self.first)?,
149            rest: self
150                .rest
151                .into_iter()
152                .map(op)
153                .collect::<Result<Vec<_>, E>>()?,
154        })
155    }
156
157    pub fn iter(&self) -> Iter<'_, T> {
158        Iter {
159            first: Some(&self.first),
160            rest: self.rest.iter(),
161        }
162    }
163
164    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
165        IterMut {
166            first: Some(&mut self.first),
167            rest: self.rest.iter_mut(),
168        }
169    }
170}
171
// ================================================================
// Iterator implementations for OneOrMany<T>
//   - OneOrMany<T>::iter()      -> iterates over shared references (&T)
//   - OneOrMany<T>::into_iter() -> iterates over owned items (T)
//   - OneOrMany<T>::iter_mut()  -> iterates over mutable references (&mut T)
// ================================================================
178
/// Borrowing iterator over the items of a `OneOrMany<T>`, created by `OneOrMany::iter()`.
///
/// Yields a shared reference to the first item, then references to the remaining items.
pub struct Iter<'a, T> {
    // Shared references; `first` becomes `None` once it has been handed out.
    first: Option<&'a T>,
    rest: std::slice::Iter<'a, T>,
}

/// `Iterator` implementation for `Iter`, producing `&'a T` items.
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Hand out the head exactly once, then fall through to the tail.
        self.first.take().or_else(|| self.rest.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let head = usize::from(self.first.is_some());
        let remaining = self.rest.size_hint().1.unwrap_or(0) + head;
        // The lower bound is reported as 1 (not the exact count) while anything remains.
        match remaining {
            0 => (0, Some(0)),
            upper => (1, Some(upper)),
        }
    }
}
209
210/// Struct returned by call to `OneOrMany::into_iter()`.
211pub struct IntoIter<T> {
212    // Owned.
213    first: Option<T>,
214    rest: std::vec::IntoIter<T>,
215}
216
217/// Implement `Iterator` for `IntoIter<T>`.
218impl<T> IntoIterator for OneOrMany<T>
219where
220    T: Clone,
221{
222    type Item = T;
223    type IntoIter = IntoIter<T>;
224
225    fn into_iter(self) -> Self::IntoIter {
226        IntoIter {
227            first: Some(self.first),
228            rest: self.rest.into_iter(),
229        }
230    }
231}
232
233/// Implement `Iterator` for `IntoIter<T>`.
234/// The Item type of the `Iterator` trait is an owned `T`.
235impl<T> Iterator for IntoIter<T>
236where
237    T: Clone,
238{
239    type Item = T;
240
241    fn next(&mut self) -> Option<Self::Item> {
242        match self.first.take() {
243            Some(first) => Some(first),
244            _ => self.rest.next(),
245        }
246    }
247
248    fn size_hint(&self) -> (usize, Option<usize>) {
249        let first = if self.first.is_some() { 1 } else { 0 };
250        let max = self.rest.size_hint().1.unwrap_or(0) + first;
251        if max > 0 {
252            (1, Some(max))
253        } else {
254            (0, Some(0))
255        }
256    }
257}
258
/// Iterator over mutable references to the items of a `OneOrMany<T>`, returned by
/// `OneOrMany::iter_mut()`.
pub struct IterMut<'a, T> {
    // Mutable references; `first` becomes `None` once it has been yielded.
    first: Option<&'a mut T>,
    rest: std::slice::IterMut<'a, T>,
}

/// Implement `Iterator` for `IterMut<T>`.
/// The Item type of the `Iterator` trait is a mutable reference to an item `T` of the list
/// (not to the `OneOrMany<T>` itself).
impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        // Yield the first item exactly once, then the remaining items.
        if let Some(first) = self.first.take() {
            Some(first)
        } else {
            self.rest.next()
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The upper bound is exact; the lower bound is reported as 1 while items remain.
        let first = if self.first.is_some() { 1 } else { 0 };
        let max = self.rest.size_hint().1.unwrap_or(0) + first;
        if max > 0 {
            (1, Some(max))
        } else {
            (0, Some(0))
        }
    }
}
289
290// Serialize `OneOrMany<T>` into a json sequence (akin to `Vec<T>`)
291impl<T> Serialize for OneOrMany<T>
292where
293    T: Serialize + Clone,
294{
295    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
296    where
297        S: Serializer,
298    {
299        // Create a sequence serializer with the length of the OneOrMany object.
300        let mut seq = serializer.serialize_seq(Some(self.len()))?;
301        // Serialize each element in the OneOrMany object.
302        for e in self.iter() {
303            seq.serialize_element(e)?;
304        }
305        // End the sequence serialization.
306        seq.end()
307    }
308}
309
310// Deserialize a json sequence into `OneOrMany<T>` (akin to `Vec<T>`).
311// Additionally, deserialize a single element (of type `T`) into `OneOrMany<T>` using
312// `OneOrMany::one`, which is helpful to avoid `Either<T, OneOrMany<T>>` typing in serde structs.
313impl<'de, T> Deserialize<'de> for OneOrMany<T>
314where
315    T: Deserialize<'de> + Clone,
316{
317    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
318    where
319        D: Deserializer<'de>,
320    {
321        // Visitor struct to handle deserialization.
322        struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
323
324        impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
325        where
326            T: Deserialize<'de> + Clone,
327        {
328            type Value = OneOrMany<T>;
329
330            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
331                formatter.write_str("a sequence of at least one element")
332            }
333
334            // Visit a sequence and deserialize it into OneOrMany.
335            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
336            where
337                A: SeqAccess<'de>,
338            {
339                // Get the first element.
340                let first = seq
341                    .next_element()?
342                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
343
344                // Collect the rest of the elements.
345                let mut rest = Vec::new();
346                while let Some(value) = seq.next_element()? {
347                    rest.push(value);
348                }
349
350                // Return the deserialized OneOrMany object.
351                Ok(OneOrMany { first, rest })
352            }
353        }
354
355        // Deserialize any type into OneOrMany using the visitor.
356        deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
357    }
358}
359
360// A special deserialize_with function for fields with `OneOrMany<T: FromStr>`
361//
362// Usage:
363// #[derive(Deserialize)]
364// struct MyStruct {
365//     #[serde(deserialize_with = "string_or_one_or_many")]
366//     field: OneOrMany<String>,
367// }
368pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
369where
370    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
371    D: Deserializer<'de>,
372{
373    struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
374
375    impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
376    where
377        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
378    {
379        type Value = OneOrMany<T>;
380
381        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
382            formatter.write_str("a string or sequence")
383        }
384
385        fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
386        where
387            E: de::Error,
388        {
389            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
390            Ok(OneOrMany::one(item))
391        }
392
393        fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
394        where
395            A: SeqAccess<'de>,
396        {
397            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
398        }
399
400        fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
401        where
402            M: MapAccess<'de>,
403        {
404            let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
405            Ok(OneOrMany::one(item))
406        }
407    }
408
409    deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
410}
411
412// A variant of the `string_or_one_or_many` function that returns an `Option<OneOrMany<T>>`.
413//
414// Usage:
415// #[derive(Deserialize)]
416// struct MyStruct {
417//     #[serde(deserialize_with = "string_or_option_one_or_many")]
418//     field: Option<OneOrMany<String>>,
419// }
420pub fn string_or_option_one_or_many<'de, T, D>(
421    deserializer: D,
422) -> Result<Option<OneOrMany<T>>, D::Error>
423where
424    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
425    D: Deserializer<'de>,
426{
427    struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
428
429    impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
430    where
431        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
432    {
433        type Value = Option<OneOrMany<T>>;
434
435        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
436            formatter.write_str("null, a string, or a sequence")
437        }
438
439        fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
440        where
441            E: de::Error,
442        {
443            Ok(None)
444        }
445
446        fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
447        where
448            E: de::Error,
449        {
450            Ok(None)
451        }
452
453        fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
454        where
455            D: Deserializer<'de>,
456        {
457            string_or_one_or_many(deserializer).map(Some)
458        }
459    }
460
461    deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
462}
463
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(one_or_many.iter().collect::<Vec<_>>(), vec!["hello"]);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(one_or_many.iter().collect::<Vec<_>>(), vec!["hello", "word"]);
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.iter().count(), 3);
        assert_eq!(merged.iter().collect::<Vec<_>>(), vec!["hello", "word", "sup"]);
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        assert!(OneOrMany::<String>::merge(Vec::new()).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many.iter_mut().for_each(|item| item.push('!'));
        assert_eq!(one_or_many.first(), "hello!");
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        if let Some(first) = one_or_many.iter_mut().next() {
            first.push_str(" world");
        }
        assert_eq!(
            one_or_many.iter().collect::<Vec<_>>(),
            vec!["hello world", "word"]
        );
    }

    #[test]
    fn test_one_or_many_error() {
        let err = OneOrMany::<String>::many(vec![]).unwrap_err();
        assert_eq!(err.to_string(), "Cannot create OneOrMany with an empty vector.");
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_insert() {
        let mut one_or_many = OneOrMany::one(2);
        // Inserting at 0 replaces `first`; inserting at `len()` appends.
        one_or_many.insert(0, 1);
        one_or_many.insert(2, 3);

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_last() {
        let mut single = OneOrMany::one(1);
        assert_eq!(single.last(), 1);
        *single.last_mut() = 10;
        assert_eq!(single.first(), 10);

        let mut many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(*many.last_ref(), 3);
        *many.last_mut() = 30;
        assert_eq!(many.rest(), vec![2, 30]);
    }

    #[test]
    fn test_map_and_try_map() {
        let one_or_many = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let doubled = one_or_many.clone().map(|x| x * 2);
        assert_eq!(doubled, OneOrMany::many(vec![2, 4, 6]).unwrap());

        let mapped: Result<OneOrMany<i32>, &str> = one_or_many.clone().try_map(Ok);
        assert_eq!(mapped, Ok(one_or_many.clone()));

        let failed = one_or_many.try_map(|x| if x == 2 { Err("two") } else { Ok(x) });
        assert_eq!(failed, Err("two"));
    }

    // Testing serialization
    #[test]
    fn test_serialize() {
        let many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(serde_json::to_value(&many).unwrap(), json!([1, 2, 3]));

        let single = OneOrMany::one("hello".to_string());
        assert_eq!(serde_json::to_value(&single).unwrap(), json!(["hello"]));
    }

    #[test]
    fn test_serialize_round_trip() {
        let original = OneOrMany::many(vec!["a".to_string(), "b".to_string()]).unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: OneOrMany<String> = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded, original);
    }

    // Testing deserialization
    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[test]
    fn test_deserialize_rejects_empty_and_scalar() {
        let empty: Result<OneOrMany<i32>, _> = serde_json::from_value(json!([]));
        assert!(empty.is_err());

        // Plain `OneOrMany<T>` only accepts sequences; a bare element is an error.
        let scalar: Result<OneOrMany<i32>, _> = serde_json::from_value(json!(1));
        assert!(scalar.is_err());
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_list_via_helper() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 2);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(dummy.field.last(), DummyString::from_str("world").unwrap());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}