rig/
one_or_many.rs

1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A non-empty collection holding either a single item or a list of items of type `T`.
///
/// The first item is stored in `first` and every remaining item in `rest`, so a
/// `OneOrMany<T>` always contains at least one element. When only one item is present,
/// `rest` is empty.
///
/// IMPORTANT: a `OneOrMany` can never be empty. Construct one with [`OneOrMany::one`]
/// (infallible, single item) or [`OneOrMany::many`] (fallible: returns
/// [`EmptyListError`] when given no items).
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OneOrMany<T> {
    /// First item in the list.
    first: T,
    /// Remaining items in the list, in order, after `first`.
    rest: Vec<T>,
}
21
22/// Error type for when trying to create a OneOrMany object with an empty vector.
23#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28    /// Get the first item in the list.
29    pub fn first(&self) -> T {
30        self.first.clone()
31    }
32
33    /// Get the rest of the items in the list (excluding the first one).
34    pub fn rest(&self) -> Vec<T> {
35        self.rest.clone()
36    }
37
38    /// After `OneOrMany<T>` is created, add an item of type T to the `rest`.
39    pub fn push(&mut self, item: T) {
40        self.rest.push(item);
41    }
42
43    /// After `OneOrMany<T>` is created, insert an item of type T at an index.
44    pub fn insert(&mut self, index: usize, item: T) {
45        if index == 0 {
46            let old_first = std::mem::replace(&mut self.first, item);
47            self.rest.insert(0, old_first);
48        } else {
49            self.rest.insert(index - 1, item);
50        }
51    }
52
53    /// Length of all items in `OneOrMany<T>`.
54    pub fn len(&self) -> usize {
55        1 + self.rest.len()
56    }
57
58    /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`.
59    /// This method is required when the method `len` exists.
60    pub fn is_empty(&self) -> bool {
61        false
62    }
63
64    /// Create a `OneOrMany` object with a single item of any type.
65    pub fn one(item: T) -> Self {
66        OneOrMany {
67            first: item,
68            rest: vec![],
69        }
70    }
71
72    /// Create a `OneOrMany` object with a vector of items of any type.
73    pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74    where
75        I: IntoIterator<Item = T>,
76    {
77        let mut iter = items.into_iter();
78        Ok(OneOrMany {
79            first: match iter.next() {
80                Some(item) => item,
81                None => return Err(EmptyListError),
82            },
83            rest: iter.collect(),
84        })
85    }
86
87    /// Merge a list of OneOrMany items into a single OneOrMany item.
88    pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89    where
90        I: IntoIterator<Item = OneOrMany<T>>,
91    {
92        let items = one_or_many_items
93            .into_iter()
94            .flat_map(|one_or_many| one_or_many.into_iter())
95            .collect::<Vec<_>>();
96
97        OneOrMany::many(items)
98    }
99
100    /// Specialized map function for OneOrMany objects.
101    ///
102    /// Since OneOrMany objects have *atleast* 1 item, using `.collect::<Vec<_>>()` and
103    /// `OneOrMany::many()` is fallible resulting in unergonomic uses of `.expect` or `.unwrap`.
104    /// This function bypasses those hurdles by directly constructing the `OneOrMany` struct.
105    pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106        OneOrMany {
107            first: op(self.first),
108            rest: self.rest.into_iter().map(op).collect(),
109        }
110    }
111
112    /// Specialized try map function for OneOrMany objects.
113    ///
114    /// Same as `OneOrMany::map` but fallible.
115    pub(crate) fn try_map<U, E, F: FnMut(T) -> Result<U, E>>(
116        self,
117        mut op: F,
118    ) -> Result<OneOrMany<U>, E> {
119        Ok(OneOrMany {
120            first: op(self.first)?,
121            rest: self
122                .rest
123                .into_iter()
124                .map(op)
125                .collect::<Result<Vec<_>, E>>()?,
126        })
127    }
128
129    pub fn iter(&self) -> Iter<T> {
130        Iter {
131            first: Some(&self.first),
132            rest: self.rest.iter(),
133        }
134    }
135
136    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137        IterMut {
138            first: Some(&mut self.first),
139            rest: self.rest.iter_mut(),
140        }
141    }
142}
143
// ================================================================
// Iterator implementations for OneOrMany
//   - OneOrMany<T>::iter()      -> iterates over `&T` references
//   - OneOrMany<T>::into_iter() -> iterates over owned `T` values
//   - OneOrMany<T>::iter_mut()  -> iterates over `&mut T` references
// ================================================================
150
/// Borrowing iterator over the items of a [`OneOrMany`], yielding `&T`.
///
/// Returned by `OneOrMany::iter()`.
pub struct Iter<'a, T> {
    // The first item, held until it has been yielded.
    first: Option<&'a T>,
    // Borrowing iterator over the remaining items.
    rest: std::slice::Iter<'a, T>,
}
157
158/// Implement `Iterator` for `Iter<T>`.
159/// The Item type of the `Iterator` trait is a reference of `T`.
160impl<'a, T> Iterator for Iter<'a, T> {
161    type Item = &'a T;
162
163    fn next(&mut self) -> Option<Self::Item> {
164        if let Some(first) = self.first.take() {
165            Some(first)
166        } else {
167            self.rest.next()
168        }
169    }
170}
171
/// Owning iterator over the items of a [`OneOrMany`], yielding `T` by value.
///
/// Returned by `OneOrMany::into_iter()`.
pub struct IntoIter<T> {
    // The first item, held until it has been yielded.
    first: Option<T>,
    // Owning iterator over the remaining items.
    rest: std::vec::IntoIter<T>,
}
178
179/// Implement `Iterator` for `IntoIter<T>`.
180impl<T: Clone> IntoIterator for OneOrMany<T> {
181    type Item = T;
182    type IntoIter = IntoIter<T>;
183
184    fn into_iter(self) -> Self::IntoIter {
185        IntoIter {
186            first: Some(self.first),
187            rest: self.rest.into_iter(),
188        }
189    }
190}
191
192/// Implement `Iterator` for `IntoIter<T>`.
193/// The Item type of the `Iterator` trait is an owned `T`.
194impl<T: Clone> Iterator for IntoIter<T> {
195    type Item = T;
196
197    fn next(&mut self) -> Option<Self::Item> {
198        if let Some(first) = self.first.take() {
199            Some(first)
200        } else {
201            self.rest.next()
202        }
203    }
204}
205
/// Mutably borrowing iterator over the items of a [`OneOrMany`], yielding `&mut T`.
///
/// Returned by `OneOrMany::iter_mut()`.
pub struct IterMut<'a, T> {
    // Exclusive borrow of the first item, held until it has been yielded.
    first: Option<&'a mut T>,
    // Mutable iterator over the remaining items.
    rest: std::slice::IterMut<'a, T>,
}
212
213// Implement `Iterator` for `IterMut<T>`.
214// The Item type of the `Iterator` trait is a mutable reference of `OneOrMany<T>`.
215impl<'a, T> Iterator for IterMut<'a, T> {
216    type Item = &'a mut T;
217
218    fn next(&mut self) -> Option<Self::Item> {
219        if let Some(first) = self.first.take() {
220            Some(first)
221        } else {
222            self.rest.next()
223        }
224    }
225}
226
227// Serialize `OneOrMany<T>` into a json sequence (akin to `Vec<T>`)
228impl<T: Clone> Serialize for OneOrMany<T>
229where
230    T: Serialize,
231{
232    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233    where
234        S: Serializer,
235    {
236        // Create a sequence serializer with the length of the OneOrMany object.
237        let mut seq = serializer.serialize_seq(Some(self.len()))?;
238        // Serialize each element in the OneOrMany object.
239        for e in self.iter() {
240            seq.serialize_element(e)?;
241        }
242        // End the sequence serialization.
243        seq.end()
244    }
245}
246
247// Deserialize a json sequence into `OneOrMany<T>` (akin to `Vec<T>`).
248// Additionally, deserialize a single element (of type `T`) into `OneOrMany<T>` using
249// `OneOrMany::one`, which is helpful to avoid `Either<T, OneOrMany<T>>` typing in serde structs.
250impl<'de, T> Deserialize<'de> for OneOrMany<T>
251where
252    T: Deserialize<'de> + Clone,
253{
254    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
255    where
256        D: Deserializer<'de>,
257    {
258        // Visitor struct to handle deserialization.
259        struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
260
261        impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
262        where
263            T: Deserialize<'de> + Clone,
264        {
265            type Value = OneOrMany<T>;
266
267            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
268                formatter.write_str("a sequence of at least one element")
269            }
270
271            // Visit a sequence and deserialize it into OneOrMany.
272            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
273            where
274                A: SeqAccess<'de>,
275            {
276                // Get the first element.
277                let first = seq
278                    .next_element()?
279                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
280
281                // Collect the rest of the elements.
282                let mut rest = Vec::new();
283                while let Some(value) = seq.next_element()? {
284                    rest.push(value);
285                }
286
287                // Return the deserialized OneOrMany object.
288                Ok(OneOrMany { first, rest })
289            }
290        }
291
292        // Deserialize any type into OneOrMany using the visitor.
293        deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
294    }
295}
296
297// A special deserialize_with function for fields with `OneOrMany<T: FromStr>`
298//
299// Usage:
300// #[derive(Deserialize)]
301// struct MyStruct {
302//     #[serde(deserialize_with = "string_or_one_or_many")]
303//     field: OneOrMany<String>,
304// }
305pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
306where
307    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
308    D: Deserializer<'de>,
309{
310    struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
311
312    impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
313    where
314        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
315    {
316        type Value = OneOrMany<T>;
317
318        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
319            formatter.write_str("a string or sequence")
320        }
321
322        fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
323        where
324            E: de::Error,
325        {
326            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
327            Ok(OneOrMany::one(item))
328        }
329
330        fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
331        where
332            A: SeqAccess<'de>,
333        {
334            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
335        }
336
337        fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
338        where
339            M: MapAccess<'de>,
340        {
341            let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
342            Ok(OneOrMany::one(item))
343        }
344    }
345
346    deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
347}
348
349// A variant of the `string_or_one_or_many` function that returns an `Option<OneOrMany<T>>`.
350//
351// Usage:
352// #[derive(Deserialize)]
353// struct MyStruct {
354//     #[serde(deserialize_with = "string_or_option_one_or_many")]
355//     field: Option<OneOrMany<String>>,
356// }
357pub fn string_or_option_one_or_many<'de, T, D>(
358    deserializer: D,
359) -> Result<Option<OneOrMany<T>>, D::Error>
360where
361    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
362    D: Deserializer<'de>,
363{
364    struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
365
366    impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
367    where
368        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
369    {
370        type Value = Option<OneOrMany<T>>;
371
372        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
373            formatter.write_str("null, a string, or a sequence")
374        }
375
376        fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
377        where
378            E: de::Error,
379        {
380            Ok(None)
381        }
382
383        fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
384        where
385            E: de::Error,
386        {
387            Ok(None)
388        }
389
390        fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
391        where
392            D: Deserializer<'de>,
393        {
394            string_or_one_or_many(deserializer).map(Some)
395        }
396    }
397
398    deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
399}
400
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    // ---------------------------------------------------------------
    // Construction, accessors and iteration
    // ---------------------------------------------------------------

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(
            one_or_many.iter().map(String::as_str).collect::<Vec<_>>(),
            vec!["hello"]
        );
    }

    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(
            one_or_many.iter().map(String::as_str).collect::<Vec<_>>(),
            vec!["hello", "word"]
        );
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.len(), 3);
        assert_eq!(merged.first(), "hello");
        assert_eq!(
            merged.rest(),
            vec!["word".to_string(), "sup".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        assert!(OneOrMany::<String>::merge(Vec::<OneOrMany<String>>::new()).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        for item in one_or_many.iter_mut() {
            item.push('!');
        }
        assert_eq!(one_or_many.first(), "hello!");
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        // Mutate only the first item, then check the change is visible through the container.
        if let Some(first) = one_or_many.iter_mut().next() {
            first.push_str(" world");
        }

        assert_eq!(one_or_many.first(), "hello world");
        assert_eq!(one_or_many.rest(), vec!["word".to_string()]);
    }

    #[test]
    fn test_push() {
        let mut one_or_many = OneOrMany::one(1);
        one_or_many.push(2);
        one_or_many.push(3);

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_insert() {
        let mut one_or_many = OneOrMany::many(vec![1, 2, 3]).unwrap();

        one_or_many.insert(0, 0);
        assert_eq!(one_or_many.first(), 0);
        assert_eq!(one_or_many.rest(), vec![1, 2, 3]);

        let len = one_or_many.len();
        one_or_many.insert(len, 4);
        assert_eq!(one_or_many.rest(), vec![1, 2, 3, 4]);

        one_or_many.insert(2, 9);
        assert_eq!(one_or_many.first(), 0);
        assert_eq!(one_or_many.rest(), vec![1, 9, 2, 3, 4]);
    }

    #[test]
    #[should_panic]
    fn test_insert_out_of_bounds_panics() {
        let mut one_or_many = OneOrMany::one(1);
        one_or_many.insert(2, 2);
    }

    #[test]
    fn test_map() {
        let mut seen = Vec::new();
        let mapped = OneOrMany::many(vec![1, 2, 3]).unwrap().map(|x| {
            seen.push(x);
            x * 2
        });

        assert_eq!(seen, vec![1, 2, 3]);
        assert_eq!(mapped.first(), 2);
        assert_eq!(mapped.rest(), vec![4, 6]);
    }

    #[test]
    fn test_try_map() {
        let parsed = OneOrMany::many(vec!["1", "2", "3"])
            .unwrap()
            .try_map(|s| s.parse::<i32>())
            .unwrap();
        assert_eq!(parsed.first(), 1);
        assert_eq!(parsed.rest(), vec![2, 3]);

        assert!(OneOrMany::many(vec!["1", "x"])
            .unwrap()
            .try_map(|s| s.parse::<i32>())
            .is_err());
        assert!(OneOrMany::one("x").try_map(|s| s.parse::<i32>()).is_err());
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    // ---------------------------------------------------------------
    // Serialization
    // ---------------------------------------------------------------

    #[test]
    fn test_serialize() {
        let one = OneOrMany::one(1);
        assert_eq!(serde_json::to_value(&one).unwrap(), json!([1]));

        let many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(serde_json::to_value(&many).unwrap(), json!([1, 2, 3]));
    }

    #[test]
    fn test_serialize_roundtrip() {
        let original = OneOrMany::many(vec!["a".to_string(), "b".to_string()]).unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: OneOrMany<String> = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded, original);
    }

    // ---------------------------------------------------------------
    // Deserialization
    // ---------------------------------------------------------------

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[test]
    fn test_deserialize_empty_list_is_rejected() {
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!([])).is_err());
    }

    #[test]
    fn test_deserialize_bare_value_is_rejected() {
        // Without a `deserialize_with` helper, only sequences are accepted.
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!(1)).is_err());
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_tagged_enum_string_content() {
        let raw_json = r#"{"role": "assistant", "content": "hi"}"#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(
            dummy,
            DummyMessage::Assistant {
                content: Some(OneOrMany::one(DummyString::from_str("hi").unwrap()))
            }
        );
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_list_with_helper() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 2);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(
            dummy.field.rest(),
            vec![DummyString::from_str("world").unwrap()]
        );
    }

    #[test]
    fn test_deserialize_empty_list_with_helper_is_rejected() {
        let json_data = json!({"field": []});

        assert!(serde_json::from_value::<DummyStruct>(json_data).is_err());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}