1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A collection that always holds at least one item.
///
/// The mandatory leading item lives in `first`; any further items follow in `rest`.
/// Because `first` is not optional, a value of this type can never be empty.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OneOrMany<T> {
    /// The leading item, which is always present.
    first: T,
    /// Every item after the leading one, in order; may be empty.
    rest: Vec<T>,
}
21
/// Error returned when a `OneOrMany` is requested from a source that yields no items
/// (see `OneOrMany::many` and `OneOrMany::merge`).
#[derive(Debug, thiserror::Error)]
#[error("Cannot create OneOrMany with an empty vector.")]
pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28 pub fn first(&self) -> T {
30 self.first.clone()
31 }
32
33 pub fn rest(&self) -> Vec<T> {
35 self.rest.clone()
36 }
37
38 pub fn push(&mut self, item: T) {
40 self.rest.push(item);
41 }
42
43 pub fn insert(&mut self, index: usize, item: T) {
45 if index == 0 {
46 let old_first = std::mem::replace(&mut self.first, item);
47 self.rest.insert(0, old_first);
48 } else {
49 self.rest.insert(index - 1, item);
50 }
51 }
52
53 pub fn len(&self) -> usize {
55 1 + self.rest.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
61 false
62 }
63
64 pub fn one(item: T) -> Self {
66 OneOrMany {
67 first: item,
68 rest: vec![],
69 }
70 }
71
72 pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74 where
75 I: IntoIterator<Item = T>,
76 {
77 let mut iter = items.into_iter();
78 Ok(OneOrMany {
79 first: match iter.next() {
80 Some(item) => item,
81 None => return Err(EmptyListError),
82 },
83 rest: iter.collect(),
84 })
85 }
86
87 pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89 where
90 I: IntoIterator<Item = OneOrMany<T>>,
91 {
92 let items = one_or_many_items
93 .into_iter()
94 .flat_map(|one_or_many| one_or_many.into_iter())
95 .collect::<Vec<_>>();
96
97 OneOrMany::many(items)
98 }
99
100 pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106 OneOrMany {
107 first: op(self.first),
108 rest: self.rest.into_iter().map(op).collect(),
109 }
110 }
111
112 pub(crate) fn try_map<U, E, F: FnMut(T) -> Result<U, E>>(
116 self,
117 mut op: F,
118 ) -> Result<OneOrMany<U>, E> {
119 Ok(OneOrMany {
120 first: op(self.first)?,
121 rest: self
122 .rest
123 .into_iter()
124 .map(op)
125 .collect::<Result<Vec<_>, E>>()?,
126 })
127 }
128
129 pub fn iter(&self) -> Iter<T> {
130 Iter {
131 first: Some(&self.first),
132 rest: self.rest.iter(),
133 }
134 }
135
136 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137 IterMut {
138 first: Some(&mut self.first),
139 rest: self.rest.iter_mut(),
140 }
141 }
142}
143
/// Iterator over shared references, as produced by `OneOrMany::iter`.
pub struct Iter<'a, T> {
    /// The leading item, until it has been yielded.
    first: Option<&'a T>,
    /// The items that follow the leading one.
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Yields the leading item once, then the remaining items in order.
    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// Reports a lower bound of 1 whenever at least one item remains (0 otherwise) and
    /// the exact number of remaining items as the upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_first = usize::from(self.first.is_some());
        let upper = self.rest.size_hint().1.unwrap_or(0) + pending_first;
        (upper.min(1), Some(upper))
    }
}
181
182pub struct IntoIter<T> {
184 first: Option<T>,
186 rest: std::vec::IntoIter<T>,
187}
188
189impl<T: Clone> IntoIterator for OneOrMany<T> {
191 type Item = T;
192 type IntoIter = IntoIter<T>;
193
194 fn into_iter(self) -> Self::IntoIter {
195 IntoIter {
196 first: Some(self.first),
197 rest: self.rest.into_iter(),
198 }
199 }
200}
201
202impl<T: Clone> Iterator for IntoIter<T> {
205 type Item = T;
206
207 fn next(&mut self) -> Option<Self::Item> {
208 match self.first.take() {
209 Some(first) => Some(first),
210 _ => self.rest.next(),
211 }
212 }
213
214 fn size_hint(&self) -> (usize, Option<usize>) {
215 let first = if self.first.is_some() { 1 } else { 0 };
216 let max = self.rest.size_hint().1.unwrap_or(0) + first;
217 if max > 0 {
218 (1, Some(max))
219 } else {
220 (0, Some(0))
221 }
222 }
223}
224
/// Iterator over mutable references, as produced by `OneOrMany::iter_mut`.
pub struct IterMut<'a, T> {
    /// The leading item, until it has been yielded.
    first: Option<&'a mut T>,
    /// The items that follow the leading one.
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    /// Yields the leading item once, then the remaining items in order.
    fn next(&mut self) -> Option<Self::Item> {
        match self.first.take() {
            None => self.rest.next(),
            head => head,
        }
    }

    /// Reports a lower bound of 1 whenever at least one item remains (0 otherwise) and
    /// the exact number of remaining items as the upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.rest.size_hint().1.unwrap_or(0) + usize::from(self.first.is_some());
        match remaining {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}
255
256impl<T: Clone> Serialize for OneOrMany<T>
258where
259 T: Serialize,
260{
261 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
262 where
263 S: Serializer,
264 {
265 let mut seq = serializer.serialize_seq(Some(self.len()))?;
267 for e in self.iter() {
269 seq.serialize_element(e)?;
270 }
271 seq.end()
273 }
274}
275
276impl<'de, T> Deserialize<'de> for OneOrMany<T>
280where
281 T: Deserialize<'de> + Clone,
282{
283 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
284 where
285 D: Deserializer<'de>,
286 {
287 struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
289
290 impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
291 where
292 T: Deserialize<'de> + Clone,
293 {
294 type Value = OneOrMany<T>;
295
296 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
297 formatter.write_str("a sequence of at least one element")
298 }
299
300 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
302 where
303 A: SeqAccess<'de>,
304 {
305 let first = seq
307 .next_element()?
308 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
309
310 let mut rest = Vec::new();
312 while let Some(value) = seq.next_element()? {
313 rest.push(value);
314 }
315
316 Ok(OneOrMany { first, rest })
318 }
319 }
320
321 deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
323 }
324}
325
326pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
335where
336 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
337 D: Deserializer<'de>,
338{
339 struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
340
341 impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
342 where
343 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
344 {
345 type Value = OneOrMany<T>;
346
347 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
348 formatter.write_str("a string or sequence")
349 }
350
351 fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
352 where
353 E: de::Error,
354 {
355 let item = FromStr::from_str(value).map_err(de::Error::custom)?;
356 Ok(OneOrMany::one(item))
357 }
358
359 fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
360 where
361 A: SeqAccess<'de>,
362 {
363 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
364 }
365
366 fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
367 where
368 M: MapAccess<'de>,
369 {
370 let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
371 Ok(OneOrMany::one(item))
372 }
373 }
374
375 deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
376}
377
378pub fn string_or_option_one_or_many<'de, T, D>(
387 deserializer: D,
388) -> Result<Option<OneOrMany<T>>, D::Error>
389where
390 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
391 D: Deserializer<'de>,
392{
393 struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
394
395 impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
396 where
397 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
398 {
399 type Value = Option<OneOrMany<T>>;
400
401 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
402 formatter.write_str("null, a string, or a sequence")
403 }
404
405 fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
406 where
407 E: de::Error,
408 {
409 Ok(None)
410 }
411
412 fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
413 where
414 E: de::Error,
415 {
416 Ok(None)
417 }
418
419 fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
420 where
421 D: Deserializer<'de>,
422 {
423 string_or_one_or_many(deserializer).map(Some)
424 }
425 }
426
427 deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
428}
429
// Unit tests for `OneOrMany`: construction, the three iterator flavours, merging,
// length, and the serde deserialization helpers.
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    // A single-item collection iterates exactly once, yielding that item.
    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);

        one_or_many.iter().for_each(|i| {
            assert_eq!(i, "hello");
        });
    }

    // A two-item collection iterates in insertion order.
    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);

        one_or_many.iter().enumerate().for_each(|(i, item)| {
            if i == 0 {
                assert_eq!(item, "hello");
            }
            if i == 1 {
                assert_eq!(item, "word");
            }
        });
    }

    // Pins the size hint of all three iterators: the lower bound is 1 for a non-empty
    // iterator and the upper bound is the exact item count.
    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    // By-value iteration over a single-item collection.
    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);

        one_or_many.into_iter().for_each(|i| {
            assert_eq!(i, "hello".to_string());
        });
    }

    // By-value iteration over a two-item collection keeps insertion order.
    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);

        one_or_many.into_iter().enumerate().for_each(|(i, item)| {
            if i == 0 {
                assert_eq!(item, "hello".to_string());
            }
            if i == 1 {
                assert_eq!(item, "word".to_string());
            }
        });
    }

    // Merging concatenates the inputs in the order they are given.
    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.iter().count(), 3);

        merged.iter().enumerate().for_each(|(i, item)| {
            if i == 0 {
                assert_eq!(item, "hello");
            }
            if i == 1 {
                assert_eq!(item, "word");
            }
            if i == 2 {
                assert_eq!(item, "sup");
            }
        });
    }

    // Mutable iteration over a single-item collection.
    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many.iter_mut().for_each(|i| {
            assert_eq!(i, "hello");
        });
    }

    // Mutable iteration allows items to be modified in place.
    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        one_or_many.iter_mut().enumerate().for_each(|(i, item)| {
            if i == 0 {
                item.push_str(" world");
                assert_eq!(item, "hello world");
            }
            if i == 1 {
                assert_eq!(item, "word");
            }
        });
    }

    // Building from an empty vector is an error.
    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    // `len` counts the leading item.
    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
    }

    // `len` counts the leading item plus the rest.
    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    // A JSON array deserializes directly through the `Deserialize` impl.
    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    // A JSON array of objects deserializes element by element.
    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    // Fixture: a required field deserialized through `string_or_one_or_many`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    // Fixture: a nullable field deserialized through `string_or_option_one_or_many`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    // Fixture: an item type that can be built either from a map or from a bare string.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    // Fixture: an internally tagged enum (`tag = "role"`) with a nullable content field.
    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    // `null` content inside an internally tagged enum deserializes to `None`.
    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    // A bare string becomes a single item parsed through `FromStr`.
    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    // A bare string in an optional field becomes `Some` with a single item.
    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    // A list of objects in an optional field becomes `Some` with every item, in order.
    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    // `null` in an optional field becomes `None`.
    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}