1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeSeq, Serializer};
3use serde::{Deserialize, Serialize};
4use std::convert::Infallible;
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
/// A collection that is guaranteed to hold at least one element.
///
/// The leading element is stored in its own field, separately from the
/// (possibly empty) remainder, so an empty value cannot be constructed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OneOrMany<T> {
    /// The element that is always present.
    first: T,
    /// Every element after `first`, in order; may be empty.
    rest: Vec<T>,
}
21
22#[derive(Debug, thiserror::Error)]
24#[error("Cannot create OneOrMany with an empty vector.")]
25pub struct EmptyListError;
26
27impl<T: Clone> OneOrMany<T> {
28 pub fn first(&self) -> T {
30 self.first.clone()
31 }
32
33 pub fn rest(&self) -> Vec<T> {
35 self.rest.clone()
36 }
37
38 pub fn push(&mut self, item: T) {
40 self.rest.push(item);
41 }
42
43 pub fn insert(&mut self, index: usize, item: T) {
45 if index == 0 {
46 let old_first = std::mem::replace(&mut self.first, item);
47 self.rest.insert(0, old_first);
48 } else {
49 self.rest.insert(index - 1, item);
50 }
51 }
52
53 pub fn len(&self) -> usize {
55 1 + self.rest.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
61 false
62 }
63
64 pub fn one(item: T) -> Self {
66 OneOrMany {
67 first: item,
68 rest: vec![],
69 }
70 }
71
72 pub fn many<I>(items: I) -> Result<Self, EmptyListError>
74 where
75 I: IntoIterator<Item = T>,
76 {
77 let mut iter = items.into_iter();
78 Ok(OneOrMany {
79 first: match iter.next() {
80 Some(item) => item,
81 None => return Err(EmptyListError),
82 },
83 rest: iter.collect(),
84 })
85 }
86
87 pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
89 where
90 I: IntoIterator<Item = OneOrMany<T>>,
91 {
92 let items = one_or_many_items
93 .into_iter()
94 .flat_map(|one_or_many| one_or_many.into_iter())
95 .collect::<Vec<_>>();
96
97 OneOrMany::many(items)
98 }
99
100 pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
106 OneOrMany {
107 first: op(self.first),
108 rest: self.rest.into_iter().map(op).collect(),
109 }
110 }
111
112 pub(crate) fn try_map<U, E, F: FnMut(T) -> Result<U, E>>(
116 self,
117 mut op: F,
118 ) -> Result<OneOrMany<U>, E> {
119 Ok(OneOrMany {
120 first: op(self.first)?,
121 rest: self
122 .rest
123 .into_iter()
124 .map(op)
125 .collect::<Result<Vec<_>, E>>()?,
126 })
127 }
128
129 pub fn iter(&self) -> Iter<T> {
130 Iter {
131 first: Some(&self.first),
132 rest: self.rest.iter(),
133 }
134 }
135
136 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
137 IterMut {
138 first: Some(&mut self.first),
139 rest: self.rest.iter_mut(),
140 }
141 }
142}
143
/// Borrowing iterator over a [`OneOrMany`], created by [`OneOrMany::iter`].
///
/// Yields the first element, then each remaining element in order.
pub struct Iter<'a, T> {
    /// The first element; `None` once it has been yielded.
    first: Option<&'a T>,
    /// Iterator over the elements after the first.
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// Exact count of remaining elements, which lets `collect`/`extend`
    /// preallocate and enables `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        (remaining, Some(remaining))
    }
}

impl<T> ExactSizeIterator for Iter<'_, T> {}
171
172pub struct IntoIter<T> {
174 first: Option<T>,
176 rest: std::vec::IntoIter<T>,
177}
178
179impl<T: Clone> IntoIterator for OneOrMany<T> {
181 type Item = T;
182 type IntoIter = IntoIter<T>;
183
184 fn into_iter(self) -> Self::IntoIter {
185 IntoIter {
186 first: Some(self.first),
187 rest: self.rest.into_iter(),
188 }
189 }
190}
191
192impl<T: Clone> Iterator for IntoIter<T> {
195 type Item = T;
196
197 fn next(&mut self) -> Option<Self::Item> {
198 if let Some(first) = self.first.take() {
199 Some(first)
200 } else {
201 self.rest.next()
202 }
203 }
204}
205
/// Mutably borrowing iterator over a [`OneOrMany`], created by [`OneOrMany::iter_mut`].
///
/// Yields the first element, then each remaining element in order.
pub struct IterMut<'a, T> {
    /// The first element; `None` once it has been yielded.
    first: Option<&'a mut T>,
    /// Iterator over the elements after the first.
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        self.first.take().or_else(|| self.rest.next())
    }

    /// Exact count of remaining elements, which lets `collect`/`extend`
    /// preallocate and enables `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.first.is_some()) + self.rest.len();
        (remaining, Some(remaining))
    }
}

impl<T> ExactSizeIterator for IterMut<'_, T> {}
226
227impl<T: Clone> Serialize for OneOrMany<T>
229where
230 T: Serialize,
231{
232 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233 where
234 S: Serializer,
235 {
236 let mut seq = serializer.serialize_seq(Some(self.len()))?;
238 for e in self.iter() {
240 seq.serialize_element(e)?;
241 }
242 seq.end()
244 }
245}
246
247impl<'de, T> Deserialize<'de> for OneOrMany<T>
251where
252 T: Deserialize<'de> + Clone,
253{
254 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
255 where
256 D: Deserializer<'de>,
257 {
258 struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
260
261 impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
262 where
263 T: Deserialize<'de> + Clone,
264 {
265 type Value = OneOrMany<T>;
266
267 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
268 formatter.write_str("a sequence of at least one element")
269 }
270
271 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
273 where
274 A: SeqAccess<'de>,
275 {
276 let first = seq
278 .next_element()?
279 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
280
281 let mut rest = Vec::new();
283 while let Some(value) = seq.next_element()? {
284 rest.push(value);
285 }
286
287 Ok(OneOrMany { first, rest })
289 }
290 }
291
292 deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
294 }
295}
296
297pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
306where
307 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
308 D: Deserializer<'de>,
309{
310 struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);
311
312 impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
313 where
314 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
315 {
316 type Value = OneOrMany<T>;
317
318 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
319 formatter.write_str("a string or sequence")
320 }
321
322 fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
323 where
324 E: de::Error,
325 {
326 let item = FromStr::from_str(value).map_err(de::Error::custom)?;
327 Ok(OneOrMany::one(item))
328 }
329
330 fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
331 where
332 A: SeqAccess<'de>,
333 {
334 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
335 }
336
337 fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
338 where
339 M: MapAccess<'de>,
340 {
341 let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
342 Ok(OneOrMany::one(item))
343 }
344 }
345
346 deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
347}
348
349pub fn string_or_option_one_or_many<'de, T, D>(
358 deserializer: D,
359) -> Result<Option<OneOrMany<T>>, D::Error>
360where
361 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
362 D: Deserializer<'de>,
363{
364 struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);
365
366 impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
367 where
368 T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
369 {
370 type Value = Option<OneOrMany<T>>;
371
372 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
373 formatter.write_str("null, a string, or a sequence")
374 }
375
376 fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
377 where
378 E: de::Error,
379 {
380 Ok(None)
381 }
382
383 fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
384 where
385 E: de::Error,
386 {
387 Ok(None)
388 }
389
390 fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
391 where
392 D: Deserializer<'de>,
393 {
394 string_or_one_or_many(deserializer).map(Some)
395 }
396 }
397
398 deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
399}
400
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    /// Clones every element of `one_or_many` into a `Vec`, in iteration order,
    /// so the whole contents can be checked with a single `assert_eq!`.
    fn items<T: Clone>(one_or_many: &OneOrMany<T>) -> Vec<T> {
        one_or_many.iter().cloned().collect()
    }

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(items(&one_or_many), vec!["hello"]);
    }

    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(items(&one_or_many), vec!["hello", "word"]);
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(one_or_many.into_iter().collect::<Vec<_>>(), vec!["hello"]);
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello", "word"]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());

        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();

        assert_eq!(merged.iter().count(), 3);
        assert_eq!(items(&merged), vec!["hello", "word", "sup"]);
    }

    #[test]
    fn test_one_or_many_merge_empty() {
        let nothing: Vec<OneOrMany<i32>> = Vec::new();
        assert!(OneOrMany::merge(nothing).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.iter_mut().count(), 1);

        one_or_many.iter_mut().for_each(|item| item.push('!'));
        // The mutation must be visible through the collection afterwards.
        assert_eq!(one_or_many.first(), "hello!");
    }

    #[test]
    fn test_mut() {
        let mut one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.iter_mut().count(), 2);

        let first = one_or_many
            .iter_mut()
            .next()
            .expect("a OneOrMany always has a first element");
        first.push_str(" world");
        assert_eq!(items(&one_or_many), vec!["hello world", "word"]);

        // Every element, not just the first, must be reachable mutably.
        for item in one_or_many.iter_mut() {
            item.make_ascii_uppercase();
        }
        assert_eq!(items(&one_or_many), vec!["HELLO WORLD", "WORD"]);
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());

        assert_eq!(one_or_many.len(), 1);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();

        assert_eq!(one_or_many.len(), 2);
    }

    #[test]
    fn test_push_and_insert() {
        let mut one_or_many = OneOrMany::one(2);
        one_or_many.push(5);
        one_or_many.insert(0, 1); // new first element
        one_or_many.insert(2, 3); // middle
        one_or_many.insert(3, 4); // middle
        one_or_many.insert(5, 6); // index == len appends

        assert_eq!(items(&one_or_many), vec![1, 2, 3, 4, 5, 6]);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3, 4, 5, 6]);
    }

    #[test]
    #[should_panic]
    fn test_insert_out_of_bounds_panics() {
        let mut one_or_many = OneOrMany::one(1);
        one_or_many.insert(3, 2);
    }

    #[test]
    fn test_map_and_try_map() {
        let numbers = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(items(&numbers.map(|n| n * 10)), vec![10, 20, 30]);

        let parsed = OneOrMany::many(vec!["1", "2"])
            .unwrap()
            .try_map(|s| s.parse::<i32>())
            .unwrap();
        assert_eq!(items(&parsed), vec![1, 2]);

        let failed = OneOrMany::many(vec!["1", "x"])
            .unwrap()
            .try_map(|s| s.parse::<i32>());
        assert!(failed.is_err());
    }

    #[test]
    fn test_serialize_round_trip() {
        let original = OneOrMany::many(vec![1, 2, 3]).unwrap();

        let value = serde_json::to_value(&original).unwrap();
        assert_eq!(value, json!([1, 2, 3]));

        let restored: OneOrMany<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn test_deserialize_rejects_empty_and_non_sequences() {
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!([])).is_err());
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!(1)).is_err());
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();

        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();

        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();

        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();

        assert!(dummy.field.is_none());
    }
}