1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A collection that is guaranteed to be non-empty.
///
/// It always holds one `first` element plus zero or more additional elements in
/// `rest`. Because `first` is stored outside the vector, an empty value cannot
/// be constructed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OneOrMany<T> {
    /// The leading element; always present.
    first: T,
    /// Every element after the leading one, in order; may be empty.
    rest: Vec<T>,
}
21
22#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28 pub fn first(&self) -> T {
30 self.first.clone()
31 }
32
33 pub fn rest(&self) -> Vec<T> {
35 self.rest.clone()
36 }
37
38 pub fn push(&mut self, item: T) {
40 self.rest.push(item);
41 }
42
43 pub fn insert(&mut self, index: usize, item: T) {
45 if index == 0 {
46 let old_first = std::mem::replace(&mut self.first, item);
47 self.rest.insert(0, old_first);
48 } else {
49 self.rest.insert(index - 1, item);
50 }
51 }
52
53 pub fn len(&self) -> usize {
55 1 + self.rest.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
61 false
62 }
63
64 pub fn one(item: T) -> Self {
66 OneOrMany {
67 first: item,
68 rest: vec![],
69 }
70 }
71
72 pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74 where
75 I: IntoIterator<Item = T>,
76 {
77 let mut iter = items.into_iter();
78 Ok(OneOrMany {
79 first: match iter.next() {
80 Some(item) => item,
81 None => return Err(EmptyListError),
82 },
83 rest: iter.collect(),
84 })
85 }
86
87 pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89 where
90 I: IntoIterator<Item = OneOrMany<T>>,
91 {
92 let items = one_or_many_items
93 .into_iter()
94 .flat_map(|one_or_many| one_or_many.into_iter())
95 .collect::<Vec<_>>();
96
97 OneOrMany::many(items)
98 }
99
100 pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106 OneOrMany {
107 first: op(self.first),
108 rest: self.rest.into_iter().map(op).collect(),
109 }
110 }
111
112 pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
116 where
117 F: FnMut(T) -> Result<U, E>,
118 {
119 Ok(OneOrMany {
120 first: op(self.first)?,
121 rest: self
122 .rest
123 .into_iter()
124 .map(op)
125 .collect::<Result<Vec<_>, E>>()?,
126 })
127 }
128
129 pub fn iter(&self) -> Iter<'_, T> {
130 Iter {
131 first: Some(&self.first),
132 rest: self.rest.iter(),
133 }
134 }
135
136 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137 IterMut {
138 first: Some(&mut self.first),
139 rest: self.rest.iter_mut(),
140 }
141 }
142}
143
/// Borrowing iterator over a [`OneOrMany`], yielding `&T` with the first element first.
pub struct Iter<'a, T> {
    /// The leading element, until it has been yielded.
    first: Option<&'a T>,
    /// Iterator over the remaining elements.
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Hand out the stored leading element once, then defer to the slice iterator.
        self.first.take().or_else(|| self.rest.next())
    }

    /// Upper bound: exactly the number of elements still to come.
    /// Lower bound: 1 while anything remains, 0 once exhausted.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_first = usize::from(self.first.is_some());
        let remaining = pending_first + self.rest.size_hint().1.unwrap_or(0);
        match remaining {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}
181
182pub struct IntoIter<T> {
184 first: Option<T>,
186 rest: std::vec::IntoIter<T>,
187}
188
189impl<T> IntoIterator for OneOrMany<T>
191where
192 T: Clone,
193{
194 type Item = T;
195 type IntoIter = IntoIter<T>;
196
197 fn into_iter(self) -> Self::IntoIter {
198 IntoIter {
199 first: Some(self.first),
200 rest: self.rest.into_iter(),
201 }
202 }
203}
204
205impl<T> Iterator for IntoIter<T>
208where
209 T: Clone,
210{
211 type Item = T;
212
213 fn next(&mut self) -> Option<Self::Item> {
214 match self.first.take() {
215 Some(first) => Some(first),
216 _ => self.rest.next(),
217 }
218 }
219
220 fn size_hint(&self) -> (usize, Option<usize>) {
221 let first = if self.first.is_some() { 1 } else { 0 };
222 let max = self.rest.size_hint().1.unwrap_or(0) + first;
223 if max > 0 {
224 (1, Some(max))
225 } else {
226 (0, Some(0))
227 }
228 }
229}
230
/// Mutably borrowing iterator over a [`OneOrMany`], yielding `&mut T` with the first element first.
pub struct IterMut<'a, T> {
    /// The leading element, until it has been yielded.
    first: Option<&'a mut T>,
    /// Iterator over the remaining elements.
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        // Move the leading reference out exactly once, then continue with the slice.
        self.first.take().or_else(|| self.rest.next())
    }

    /// Upper bound: exactly the number of elements still to come.
    /// Lower bound: 1 while anything remains, 0 once exhausted.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_first = usize::from(self.first.is_some());
        let remaining = pending_first + self.rest.size_hint().1.unwrap_or(0);
        match remaining {
            0 => (0, Some(0)),
            n => (1, Some(n)),
        }
    }
}
261
262impl<T> Serialize for OneOrMany<T>
264where
265 T: Serialize + Clone,
266{
267 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268 where
269 S: Serializer,
270 {
271 let mut seq = serializer.serialize_seq(Some(self.len()))?;
273 for e in self.iter() {
275 seq.serialize_element(e)?;
276 }
277 seq.end()
279 }
280}
281
282impl<'de, T> Deserialize<'de> for OneOrMany<T>
286where
287 T: Deserialize<'de> + Clone,
288{
289 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
290 where
291 D: Deserializer<'de>,
292 {
293 struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
295
296 impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
297 where
298 T: Deserialize<'de> + Clone,
299 {
300 type Value = OneOrMany<T>;
301
302 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
303 formatter.write_str("a sequence of at least one element")
304 }
305
306 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
308 where
309 A: SeqAccess<'de>,
310 {
311 let first = seq
313 .next_element()?
314 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
315
316 let mut rest = Vec::new();
318 while let Some(value) = seq.next_element()? {
319 rest.push(value);
320 }
321
322 Ok(OneOrMany { first, rest })
324 }
325 }
326
327 deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
329 }
330}
331
332pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
341where
342 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
343 D: Deserializer<'de>,
344{
345 struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
346
347 impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
348 where
349 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
350 {
351 type Value = OneOrMany<T>;
352
353 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
354 formatter.write_str("a string or sequence")
355 }
356
357 fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
358 where
359 E: de::Error,
360 {
361 let item = FromStr::from_str(value).map_err(de::Error::custom)?;
362 Ok(OneOrMany::one(item))
363 }
364
365 fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
366 where
367 A: SeqAccess<'de>,
368 {
369 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
370 }
371
372 fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
373 where
374 M: MapAccess<'de>,
375 {
376 let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
377 Ok(OneOrMany::one(item))
378 }
379 }
380
381 deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
382}
383
384pub fn string_or_option_one_or_many<'de, T, D>(
393 deserializer: D,
394) -> Result<Option<OneOrMany<T>>, D::Error>
395where
396 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
397 D: Deserializer<'de>,
398{
399 struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
400
401 impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
402 where
403 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
404 {
405 type Value = Option<OneOrMany<T>>;
406
407 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
408 formatter.write_str("null, a string, or a sequence")
409 }
410
411 fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
412 where
413 E: de::Error,
414 {
415 Ok(None)
416 }
417
418 fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
419 where
420 E: de::Error,
421 {
422 Ok(None)
423 }
424
425 fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
426 where
427 D: Deserializer<'de>,
428 {
429 string_or_one_or_many(deserializer).map(Some)
430 }
431 }
432
433 deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
434}
435
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    /// Clones the elements of `items` into a `Vec`, in iteration order, so whole
    /// sequences can be checked with a single assertion.
    fn to_vec<T: Clone>(items: &OneOrMany<T>) -> Vec<T> {
        items.iter().cloned().collect()
    }

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(to_vec(&one_or_many), vec!["hello".to_string()]);
    }

    #[test]
    fn test_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(
            to_vec(&one_or_many),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.len(), 3);
        assert_eq!(
            to_vec(&merged),
            vec!["hello".to_string(), "word".to_string(), "sup".to_string()]
        );
    }

    #[test]
    fn test_merge_empty() {
        assert!(OneOrMany::<String>::merge(Vec::new()).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many
            .iter_mut()
            .for_each(|item| item.push_str(" world"));
        assert_eq!(one_or_many.first(), "hello world");
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        for item in one_or_many.iter_mut() {
            item.push('!');
        }
        assert_eq!(
            to_vec(&one_or_many),
            vec!["hello!".to_string(), "word!".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_insert() {
        let mut items = OneOrMany::one(2);
        items.insert(0, 1);
        items.insert(2, 4);
        items.insert(2, 3);
        items.push(5);

        assert_eq!(to_vec(&items), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_insert_past_end_panics() {
        let mut items = OneOrMany::one(1);
        items.insert(2, 3);
    }

    #[test]
    fn test_map_and_try_map() {
        let numbers = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let doubled = numbers.clone().map(|n| n * 2);
        assert_eq!(to_vec(&doubled), vec![2, 4, 6]);

        let unchanged: Result<OneOrMany<i32>, String> = numbers.clone().try_map(Ok);
        assert_eq!(unchanged, Ok(numbers.clone()));

        let failed = numbers.try_map(|n| if n == 2 { Err(format!("bad {}", n)) } else { Ok(n) });
        assert_eq!(failed, Err("bad 2".to_string()));
    }

    #[test]
    fn test_serialize_round_trip() {
        let original = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let value = serde_json::to_value(&original).unwrap();
        assert_eq!(value, json!([1, 2, 3]));

        let back: OneOrMany<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(back, original);

        // A single element is still written as a sequence.
        assert_eq!(
            serde_json::to_value(OneOrMany::one("x")).unwrap(),
            json!(["x"])
        );
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_empty_list_fails() {
        let result = serde_json::from_value::<OneOrMany<i32>>(json!([]));

        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_invalid_type_fails() {
        let result = serde_json::from_value::<DummyStruct>(json!({"field": 42}));

        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}