// rig/one_or_many.rs

use std::convert::Infallible;
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;

use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeSeq, Serializer};
use serde::{Deserialize, Serialize};

/// A non-empty list: either a single item or several items of type `T`.
///
/// The first item is stored in `first` and every following item in `rest`, so a
/// single-item list has an empty `rest`. Because `first` is always present, a
/// `OneOrMany` can never be empty.
///
/// The fields are private; in this module values are built with
/// [`OneOrMany::one`], [`OneOrMany::many`] (which rejects empty input with
/// [`EmptyListError`]), [`OneOrMany::merge`], or by deserializing a non-empty
/// sequence.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OneOrMany<T> {
    /// First item in the list.
    first: T,
    /// Remaining items in the list, in order (empty for a single-item list).
    rest: Vec<T>,
}

22/// Error type for when trying to create a OneOrMany object with an empty vector.
23#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28    /// Get the first item in the list.
29    pub fn first(&self) -> T {
30        self.first.clone()
31    }
32
33    /// Get the rest of the items in the list (excluding the first one).
34    pub fn rest(&self) -> Vec<T> {
35        self.rest.clone()
36    }
37
38    /// After `OneOrMany<T>` is created, add an item of type T to the `rest`.
39    pub fn push(&mut self, item: T) {
40        self.rest.push(item);
41    }
42
43    /// After `OneOrMany<T>` is created, insert an item of type T at an index.
44    pub fn insert(&mut self, index: usize, item: T) {
45        if index == 0 {
46            let old_first = std::mem::replace(&mut self.first, item);
47            self.rest.insert(0, old_first);
48        } else {
49            self.rest.insert(index - 1, item);
50        }
51    }
52
53    /// Length of all items in `OneOrMany<T>`.
54    pub fn len(&self) -> usize {
55        1 + self.rest.len()
56    }
57
58    /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`.
59    /// This method is required when the method `len` exists.
60    pub fn is_empty(&self) -> bool {
61        false
62    }
63
64    /// Create a `OneOrMany` object with a single item of any type.
65    pub fn one(item: T) -> Self {
66        OneOrMany {
67            first: item,
68            rest: vec![],
69        }
70    }
71
72    /// Create a `OneOrMany` object with a vector of items of any type.
73    pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74    where
75        I: IntoIterator<Item = T>,
76    {
77        let mut iter = items.into_iter();
78        Ok(OneOrMany {
79            first: match iter.next() {
80                Some(item) => item,
81                None => return Err(EmptyListError),
82            },
83            rest: iter.collect(),
84        })
85    }
86
87    /// Merge a list of OneOrMany items into a single OneOrMany item.
88    pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89    where
90        I: IntoIterator<Item = OneOrMany<T>>,
91    {
92        let items = one_or_many_items
93            .into_iter()
94            .flat_map(|one_or_many| one_or_many.into_iter())
95            .collect::<Vec<_>>();
96
97        OneOrMany::many(items)
98    }
99
100    /// Specialized map function for OneOrMany objects.
101    ///
102    /// Since OneOrMany objects have *atleast* 1 item, using `.collect::<Vec<_>>()` and
103    /// `OneOrMany::many()` is fallible resulting in unergonomic uses of `.expect` or `.unwrap`.
104    /// This function bypasses those hurdles by directly constructing the `OneOrMany` struct.
105    pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106        OneOrMany {
107            first: op(self.first),
108            rest: self.rest.into_iter().map(op).collect(),
109        }
110    }
111
112    /// Specialized try map function for OneOrMany objects.
113    ///
114    /// Same as `OneOrMany::map` but fallible.
115    pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
116    where
117        F: FnMut(T) -> Result<U, E>,
118    {
119        Ok(OneOrMany {
120            first: op(self.first)?,
121            rest: self
122                .rest
123                .into_iter()
124                .map(op)
125                .collect::<Result<Vec<_>, E>>()?,
126        })
127    }
128
129    pub fn iter(&self) -> Iter<'_, T> {
130        Iter {
131            first: Some(&self.first),
132            rest: self.rest.iter(),
133        }
134    }
135
136    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137        IterMut {
138            first: Some(&mut self.first),
139            rest: self.rest.iter_mut(),
140        }
141    }
142}
143
// ================================================================
// Iterator implementations for OneOrMany
//   - OneOrMany<T>::iter()      -> iterates over `&T`
//   - OneOrMany<T>::into_iter() -> iterates over owned `T`
//   - OneOrMany<T>::iter_mut()  -> iterates over `&mut T`
// ================================================================

/// Borrowing iterator over a `OneOrMany<T>`, returned by `OneOrMany::iter()`.
///
/// Yields `&T`: the first item, then each remaining item in order.
pub struct Iter<'a, T> {
    /// Reference to the first item; taken (left as `None`) once it has been yielded.
    first: Option<&'a T>,
    /// Iterator over the remaining items.
    rest: std::slice::Iter<'a, T>,
}

/// `Iterator` for `Iter<T>`; each item is a shared reference to a `T`.
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        // Hand out the first item once, then defer to the slice iterator.
        self.first.take().or_else(|| self.rest.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Upper bound is the exact remaining count; the lower bound is reported
        // as 1 (not the exact count) whenever anything remains.
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        match remaining {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}

182/// Struct returned by call to `OneOrMany::into_iter()`.
183pub struct IntoIter<T> {
184    // Owned.
185    first: Option<T>,
186    rest: std::vec::IntoIter<T>,
187}
188
189/// Implement `Iterator` for `IntoIter<T>`.
190impl<T> IntoIterator for OneOrMany<T>
191where
192    T: Clone,
193{
194    type Item = T;
195    type IntoIter = IntoIter<T>;
196
197    fn into_iter(self) -> Self::IntoIter {
198        IntoIter {
199            first: Some(self.first),
200            rest: self.rest.into_iter(),
201        }
202    }
203}
204
205/// Implement `Iterator` for `IntoIter<T>`.
206/// The Item type of the `Iterator` trait is an owned `T`.
207impl<T> Iterator for IntoIter<T>
208where
209    T: Clone,
210{
211    type Item = T;
212
213    fn next(&mut self) -> Option<Self::Item> {
214        match self.first.take() {
215            Some(first) => Some(first),
216            _ => self.rest.next(),
217        }
218    }
219
220    fn size_hint(&self) -> (usize, Option<usize>) {
221        let first = if self.first.is_some() { 1 } else { 0 };
222        let max = self.rest.size_hint().1.unwrap_or(0) + first;
223        if max > 0 {
224            (1, Some(max))
225        } else {
226            (0, Some(0))
227        }
228    }
229}
230
/// Mutably borrowing iterator over a `OneOrMany<T>`, returned by `OneOrMany::iter_mut()`.
///
/// Yields `&mut T`: the first item, then each remaining item in order.
pub struct IterMut<'a, T> {
    /// Mutable reference to the first item; taken once it has been yielded.
    first: Option<&'a mut T>,
    /// Iterator over mutable references to the remaining items.
    rest: std::slice::IterMut<'a, T>,
}

// `Iterator` for `IterMut<T>`: each item is a mutable reference to a `T`.
impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<&'a mut T> {
        match self.first.take() {
            Some(head) => Some(head),
            None => self.rest.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Upper bound is exact; lower bound is 1 while anything is left.
        let pending = self.rest.len() + if self.first.is_some() { 1 } else { 0 };
        if pending == 0 {
            (0, Some(0))
        } else {
            (1, Some(pending))
        }
    }
}

262// Serialize `OneOrMany<T>` into a json sequence (akin to `Vec<T>`)
263impl<T> Serialize for OneOrMany<T>
264where
265    T: Serialize + Clone,
266{
267    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268    where
269        S: Serializer,
270    {
271        // Create a sequence serializer with the length of the OneOrMany object.
272        let mut seq = serializer.serialize_seq(Some(self.len()))?;
273        // Serialize each element in the OneOrMany object.
274        for e in self.iter() {
275            seq.serialize_element(e)?;
276        }
277        // End the sequence serialization.
278        seq.end()
279    }
280}
281
282// Deserialize a json sequence into `OneOrMany<T>` (akin to `Vec<T>`).
283// Additionally, deserialize a single element (of type `T`) into `OneOrMany<T>` using
284// `OneOrMany::one`, which is helpful to avoid `Either<T, OneOrMany<T>>` typing in serde structs.
285impl<'de, T> Deserialize<'de> for OneOrMany<T>
286where
287    T: Deserialize<'de> + Clone,
288{
289    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
290    where
291        D: Deserializer<'de>,
292    {
293        // Visitor struct to handle deserialization.
294        struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
295
296        impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
297        where
298            T: Deserialize<'de> + Clone,
299        {
300            type Value = OneOrMany<T>;
301
302            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
303                formatter.write_str("a sequence of at least one element")
304            }
305
306            // Visit a sequence and deserialize it into OneOrMany.
307            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
308            where
309                A: SeqAccess<'de>,
310            {
311                // Get the first element.
312                let first = seq
313                    .next_element()?
314                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
315
316                // Collect the rest of the elements.
317                let mut rest = Vec::new();
318                while let Some(value) = seq.next_element()? {
319                    rest.push(value);
320                }
321
322                // Return the deserialized OneOrMany object.
323                Ok(OneOrMany { first, rest })
324            }
325        }
326
327        // Deserialize any type into OneOrMany using the visitor.
328        deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
329    }
330}
331
332// A special deserialize_with function for fields with `OneOrMany<T: FromStr>`
333//
334// Usage:
335// #[derive(Deserialize)]
336// struct MyStruct {
337//     #[serde(deserialize_with = "string_or_one_or_many")]
338//     field: OneOrMany<String>,
339// }
340pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
341where
342    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
343    D: Deserializer<'de>,
344{
345    struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
346
347    impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
348    where
349        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
350    {
351        type Value = OneOrMany<T>;
352
353        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
354            formatter.write_str("a string or sequence")
355        }
356
357        fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
358        where
359            E: de::Error,
360        {
361            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
362            Ok(OneOrMany::one(item))
363        }
364
365        fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
366        where
367            A: SeqAccess<'de>,
368        {
369            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
370        }
371
372        fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
373        where
374            M: MapAccess<'de>,
375        {
376            let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
377            Ok(OneOrMany::one(item))
378        }
379    }
380
381    deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
382}
383
384// A variant of the `string_or_one_or_many` function that returns an `Option<OneOrMany<T>>`.
385//
386// Usage:
387// #[derive(Deserialize)]
388// struct MyStruct {
389//     #[serde(deserialize_with = "string_or_option_one_or_many")]
390//     field: Option<OneOrMany<String>>,
391// }
392pub fn string_or_option_one_or_many<'de, T, D>(
393    deserializer: D,
394) -> Result<Option<OneOrMany<T>>, D::Error>
395where
396    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
397    D: Deserializer<'de>,
398{
399    struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
400
401    impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
402    where
403        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
404    {
405        type Value = Option<OneOrMany<T>>;
406
407        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
408            formatter.write_str("null, a string, or a sequence")
409        }
410
411        fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
412        where
413            E: de::Error,
414        {
415            Ok(None)
416        }
417
418        fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
419        where
420            E: de::Error,
421        {
422            Ok(None)
423        }
424
425        fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
426        where
427            D: Deserializer<'de>,
428        {
429            string_or_one_or_many(deserializer).map(Some)
430        }
431    }
432
433    deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
434}
435
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(
            one_or_many.iter().map(String::as_str).collect::<Vec<_>>(),
            vec!["hello"]
        );
    }

    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(
            one_or_many.iter().map(String::as_str).collect::<Vec<_>>(),
            vec!["hello", "word"]
        );
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.iter().count(), 3);
        assert_eq!(
            merged.iter().map(String::as_str).collect::<Vec<_>>(),
            vec!["hello", "word", "sup"]
        );
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        assert!(OneOrMany::<String>::merge(Vec::new()).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many.iter_mut().for_each(|item| item.push('!'));
        assert_eq!(one_or_many.first(), "hello!");
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        if let Some(first) = one_or_many.iter_mut().next() {
            first.push_str(" world");
        }

        assert_eq!(one_or_many.first(), "hello world");
        assert_eq!(one_or_many.rest(), vec!["word".to_string()]);
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_insert_and_push() {
        let mut one_or_many = OneOrMany::one(2);
        one_or_many.push(4);
        one_or_many.insert(0, 1);
        one_or_many.insert(2, 3);
        one_or_many.insert(4, 5);

        assert_eq!(one_or_many.len(), 5);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3, 4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_insert_out_of_bounds_panics() {
        let mut one_or_many = OneOrMany::one(1);
        one_or_many.insert(2, 2);
    }

    #[test]
    fn test_map_and_try_map() {
        let one_or_many = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let doubled = one_or_many.clone().map(|x| x * 2);
        assert_eq!(doubled, OneOrMany::many(vec![2, 4, 6]).unwrap());

        let converted: Result<OneOrMany<u8>, String> = one_or_many
            .clone()
            .try_map(|x| u8::try_from(x).map_err(|e| e.to_string()));
        assert_eq!(converted, Ok(OneOrMany::many(vec![1u8, 2, 3]).unwrap()));

        let failed: Result<OneOrMany<i32>, &str> =
            one_or_many.try_map(|x| if x == 2 { Err("two") } else { Ok(x) });
        assert_eq!(failed, Err("two"));
    }

    // Testing serialization
    #[test]
    fn test_serialize_round_trip() {
        let one_or_many = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let value = serde_json::to_value(&one_or_many).unwrap();
        assert_eq!(value, json!([1, 2, 3]));

        let back: OneOrMany<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(back, one_or_many);

        assert_eq!(
            serde_json::to_value(OneOrMany::one("solo")).unwrap(),
            json!(["solo"])
        );
    }

    // Testing deserialization
    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_rejects_empty_and_scalar() {
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!([])).is_err());
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!(5)).is_err());
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}