1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A collection that always holds at least one item.
///
/// The mandatory item is stored in `first`; every additional item follows, in
/// order, in `rest`. Because `first` is not optional, an empty value cannot be
/// represented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OneOrMany<T> {
    /// The item that is always present.
    first: T,
    /// Every item after the first, in order. May be empty.
    rest: Vec<T>,
}
21
22#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28 pub fn first(&self) -> T {
30 self.first.clone()
31 }
32
33 pub fn rest(&self) -> Vec<T> {
35 self.rest.clone()
36 }
37
38 pub fn push(&mut self, item: T) {
40 self.rest.push(item);
41 }
42
43 pub fn insert(&mut self, index: usize, item: T) {
45 if index == 0 {
46 let old_first = std::mem::replace(&mut self.first, item);
47 self.rest.insert(0, old_first);
48 } else {
49 self.rest.insert(index - 1, item);
50 }
51 }
52
53 pub fn len(&self) -> usize {
55 1 + self.rest.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
61 false
62 }
63
64 pub fn one(item: T) -> Self {
66 OneOrMany {
67 first: item,
68 rest: vec![],
69 }
70 }
71
72 pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74 where
75 I: IntoIterator<Item = T>,
76 {
77 let mut iter = items.into_iter();
78 Ok(OneOrMany {
79 first: match iter.next() {
80 Some(item) => item,
81 None => return Err(EmptyListError),
82 },
83 rest: iter.collect(),
84 })
85 }
86
87 pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89 where
90 I: IntoIterator<Item = OneOrMany<T>>,
91 {
92 let items = one_or_many_items
93 .into_iter()
94 .flat_map(|one_or_many| one_or_many.into_iter())
95 .collect::<Vec<_>>();
96
97 OneOrMany::many(items)
98 }
99
100 pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106 OneOrMany {
107 first: op(self.first),
108 rest: self.rest.into_iter().map(op).collect(),
109 }
110 }
111
112 pub(crate) fn try_map<U, E, F: FnMut(T) -> Result<U, E>>(
116 self,
117 mut op: F,
118 ) -> Result<OneOrMany<U>, E> {
119 Ok(OneOrMany {
120 first: op(self.first)?,
121 rest: self
122 .rest
123 .into_iter()
124 .map(op)
125 .collect::<Result<Vec<_>, E>>()?,
126 })
127 }
128
129 pub fn iter(&self) -> Iter<T> {
130 Iter {
131 first: Some(&self.first),
132 rest: self.rest.iter(),
133 }
134 }
135
136 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137 IterMut {
138 first: Some(&mut self.first),
139 rest: self.rest.iter_mut(),
140 }
141 }
142}
143
/// Borrowing iterator that yields the first item and then each remaining item.
pub struct Iter<'a, T> {
    /// The first item; `None` once it has been yielded.
    first: Option<&'a T>,
    /// Iterator over the items that follow the first.
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // `take` leaves `None` behind, so the first item is yielded only once.
        self.first.take().or_else(|| self.rest.next())
    }

    /// Reports the exact number of items left, so consumers such as `collect`
    /// can size their buffers up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        (remaining, Some(remaining))
    }
}
171
172pub struct IntoIter<T> {
174 first: Option<T>,
176 rest: std::vec::IntoIter<T>,
177}
178
179impl<T: Clone> IntoIterator for OneOrMany<T> {
181 type Item = T;
182 type IntoIter = IntoIter<T>;
183
184 fn into_iter(self) -> Self::IntoIter {
185 IntoIter {
186 first: Some(self.first),
187 rest: self.rest.into_iter(),
188 }
189 }
190}
191
192impl<T: Clone> Iterator for IntoIter<T> {
195 type Item = T;
196
197 fn next(&mut self) -> Option<Self::Item> {
198 match self.first.take() {
199 Some(first) => Some(first),
200 _ => self.rest.next(),
201 }
202 }
203}
204
/// Mutably borrowing iterator that yields the first item and then each
/// remaining item.
pub struct IterMut<'a, T> {
    /// The first item; `None` once it has been yielded.
    first: Option<&'a mut T>,
    /// Iterator over the items that follow the first.
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        // `take` leaves `None` behind, so the first item is yielded only once.
        self.first.take().or_else(|| self.rest.next())
    }

    /// Reports the exact number of items left, so consumers such as `collect`
    /// can size their buffers up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        (remaining, Some(remaining))
    }
}
225
226impl<T: Clone> Serialize for OneOrMany<T>
228where
229 T: Serialize,
230{
231 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
232 where
233 S: Serializer,
234 {
235 let mut seq = serializer.serialize_seq(Some(self.len()))?;
237 for e in self.iter() {
239 seq.serialize_element(e)?;
240 }
241 seq.end()
243 }
244}
245
246impl<'de, T> Deserialize<'de> for OneOrMany<T>
250where
251 T: Deserialize<'de> + Clone,
252{
253 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
254 where
255 D: Deserializer<'de>,
256 {
257 struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
259
260 impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
261 where
262 T: Deserialize<'de> + Clone,
263 {
264 type Value = OneOrMany<T>;
265
266 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
267 formatter.write_str("a sequence of at least one element")
268 }
269
270 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
272 where
273 A: SeqAccess<'de>,
274 {
275 let first = seq
277 .next_element()?
278 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
279
280 let mut rest = Vec::new();
282 while let Some(value) = seq.next_element()? {
283 rest.push(value);
284 }
285
286 Ok(OneOrMany { first, rest })
288 }
289 }
290
291 deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
293 }
294}
295
296pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
305where
306 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
307 D: Deserializer<'de>,
308{
309 struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
310
311 impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
312 where
313 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
314 {
315 type Value = OneOrMany<T>;
316
317 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
318 formatter.write_str("a string or sequence")
319 }
320
321 fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
322 where
323 E: de::Error,
324 {
325 let item = FromStr::from_str(value).map_err(de::Error::custom)?;
326 Ok(OneOrMany::one(item))
327 }
328
329 fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
330 where
331 A: SeqAccess<'de>,
332 {
333 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
334 }
335
336 fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
337 where
338 M: MapAccess<'de>,
339 {
340 let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
341 Ok(OneOrMany::one(item))
342 }
343 }
344
345 deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
346}
347
348pub fn string_or_option_one_or_many<'de, T, D>(
357 deserializer: D,
358) -> Result<Option<OneOrMany<T>>, D::Error>
359where
360 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
361 D: Deserializer<'de>,
362{
363 struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
364
365 impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
366 where
367 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
368 {
369 type Value = Option<OneOrMany<T>>;
370
371 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
372 formatter.write_str("null, a string, or a sequence")
373 }
374
375 fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
376 where
377 E: de::Error,
378 {
379 Ok(None)
380 }
381
382 fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
383 where
384 E: de::Error,
385 {
386 Ok(None)
387 }
388
389 fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
390 where
391 D: Deserializer<'de>,
392 {
393 string_or_one_or_many(deserializer).map(Some)
394 }
395 }
396
397 deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
398}
399
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    /// Two-item fixture shared by several tests.
    fn hello_word() -> OneOrMany<String> {
        OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap()
    }

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);

        for item in one_or_many.iter() {
            assert_eq!(item, "hello");
        }
    }

    #[test]
    fn test() {
        let one_or_many = hello_word();

        assert_eq!(one_or_many.iter().count(), 2);

        let items: Vec<&String> = one_or_many.iter().collect();
        assert_eq!(items, vec!["hello", "word"]);
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);

        for item in one_or_many {
            assert_eq!(item, "hello".to_string());
        }
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = hello_word();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);

        let items: Vec<String> = one_or_many.into_iter().collect();
        assert_eq!(items, vec!["hello".to_string(), "word".to_string()]);
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = hello_word();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.iter().count(), 3);

        let items: Vec<&String> = merged.iter().collect();
        assert_eq!(items, vec!["hello", "word", "sup"]);
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        for item in one_or_many.iter_mut() {
            assert_eq!(item, "hello");
        }
    }

    #[test]
    fn test_mut() {
        let mut one_or_many = hello_word();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        let mut items = one_or_many.iter_mut();

        // The first item can be edited in place through the iterator.
        let head = items.next().unwrap();
        head.push_str(" world");
        assert_eq!(head, "hello world");

        // The second item is untouched.
        let second = items.next().unwrap();
        assert_eq!(second, "word");
    }

    #[test]
    fn test_one_or_many_error() {
        let result = OneOrMany::<String>::many(vec![]);
        assert!(result.is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
    }

    #[test]
    fn test_len_many() {
        let one_or_many = hello_word();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let field = json_data["field"].clone();
        let one_or_many: OneOrMany<i32> = serde_json::from_value(field).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let field = json_data["field"].clone();
        let one_or_many: OneOrMany<serde_json::Value> = serde_json::from_value(field).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    /// Struct whose field goes through `string_or_one_or_many`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    /// Struct whose field goes through `string_or_option_one_or_many`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    /// Item type that can be built from a bare string or from a map.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    /// Internally tagged enum whose field goes through
    /// `string_or_option_one_or_many`.
    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
            {
                "role": "assistant",
                "content": null
            }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}