use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeSeq, Serializer};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;
/// A collection that always holds at least one element.
///
/// The head element is stored separately from the tail, so an empty value cannot be
/// constructed and [`OneOrMany::len`] is always at least 1.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OneOrMany<T> {
    /// The mandatory first element.
    first: T,
    /// Every element after `first`, in order; empty when there is only one element.
    rest: Vec<T>,
}
/// Error returned when a [`OneOrMany`] would have to be built from zero items,
/// e.g. by [`OneOrMany::many`] on an empty iterator or [`OneOrMany::merge`] with
/// no collections to merge.
#[derive(Debug, thiserror::Error)]
#[error("Cannot create OneOrMany with an empty vector.")]
pub struct EmptyListError;
// Methods that never clone an element are available for every `T`.
impl<T> OneOrMany<T> {
    /// Returns a reference to the first element without cloning it.
    pub fn first_ref(&self) -> &T {
        &self.first
    }

    /// Returns a mutable reference to the first element.
    pub fn first_mut(&mut self) -> &mut T {
        &mut self.first
    }

    /// Returns a reference to the last element.
    ///
    /// For a single-element collection this is the first element.
    pub fn last_ref(&self) -> &T {
        self.rest.last().unwrap_or(&self.first)
    }

    /// Returns a mutable reference to the last element.
    ///
    /// For a single-element collection this is the first element.
    pub fn last_mut(&mut self) -> &mut T {
        self.rest.last_mut().unwrap_or(&mut self.first)
    }

    /// Appends `item` after the current last element.
    pub fn push(&mut self, item: T) {
        self.rest.push(item);
    }

    /// Inserts `item` at position `index`, shifting all later elements right.
    ///
    /// Inserting at index `0` makes `item` the new head; the previous head
    /// becomes the second element.
    ///
    /// # Panics
    ///
    /// Panics if `index > self.len()`, matching [`Vec::insert`].
    pub fn insert(&mut self, index: usize, item: T) {
        if index == 0 {
            let old_first = std::mem::replace(&mut self.first, item);
            self.rest.insert(0, old_first);
        } else {
            // Overall position `index` maps to `index - 1` within the tail.
            self.rest.insert(index - 1, item);
        }
    }

    /// Returns the number of elements, which is always at least 1.
    pub fn len(&self) -> usize {
        1 + self.rest.len()
    }

    /// Always returns `false`: a `OneOrMany` can never be empty.
    ///
    /// Provided alongside [`OneOrMany::len`] for parity with std collections.
    pub fn is_empty(&self) -> bool {
        false
    }

    /// Creates a collection holding exactly one element.
    pub fn one(item: T) -> Self {
        OneOrMany {
            first: item,
            rest: vec![],
        }
    }

    /// Creates a collection from `items`, using the first yielded item as the head.
    ///
    /// # Errors
    ///
    /// Returns [`EmptyListError`] if `items` yields nothing.
    pub fn many<I>(items: I) -> Result<Self, EmptyListError>
    where
        I: IntoIterator<Item = T>,
    {
        let mut iter = items.into_iter();
        let first = iter.next().ok_or(EmptyListError)?;
        Ok(OneOrMany {
            first,
            rest: iter.collect(),
        })
    }

    /// Concatenates several collections into one, preserving element order.
    ///
    /// # Errors
    ///
    /// Returns [`EmptyListError`] if `one_or_many_items` yields no collections.
    pub fn merge<I>(one_or_many_items: I) -> Result<Self, EmptyListError>
    where
        I: IntoIterator<Item = OneOrMany<T>>,
    {
        // Flatten lazily straight into `many` instead of collecting an intermediate
        // `Vec`; taking the fields directly avoids needing any `Clone` bound.
        Self::many(
            one_or_many_items
                .into_iter()
                .flat_map(|group| std::iter::once(group.first).chain(group.rest)),
        )
    }

    /// Applies `op` to every element, head first, producing a new collection.
    pub(crate) fn map<U, F: FnMut(T) -> U>(self, mut op: F) -> OneOrMany<U> {
        OneOrMany {
            first: op(self.first),
            rest: self.rest.into_iter().map(op).collect(),
        }
    }

    /// Applies the fallible `op` to every element, head first.
    ///
    /// # Errors
    ///
    /// Returns the first `Err` produced by `op`; elements after it are not visited.
    pub(crate) fn try_map<U, E, F>(self, mut op: F) -> Result<OneOrMany<U>, E>
    where
        F: FnMut(T) -> Result<U, E>,
    {
        Ok(OneOrMany {
            first: op(self.first)?,
            rest: self
                .rest
                .into_iter()
                .map(op)
                .collect::<Result<Vec<_>, E>>()?,
        })
    }

    /// Returns an iterator over shared references to the elements, head first.
    pub fn iter(&self) -> Iter<'_, T> {
        Iter {
            first: Some(&self.first),
            rest: self.rest.iter(),
        }
    }

    /// Returns an iterator over mutable references to the elements, head first.
    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
        IterMut {
            first: Some(&mut self.first),
            rest: self.rest.iter_mut(),
        }
    }
}

// Accessors that return owned copies and therefore need `T: Clone`.
impl<T: Clone> OneOrMany<T> {
    /// Returns a clone of the first element; see [`OneOrMany::first_ref`] to borrow instead.
    pub fn first(&self) -> T {
        self.first.clone()
    }

    /// Returns a clone of the last element (the first element when there is only one).
    pub fn last(&self) -> T {
        self.last_ref().clone()
    }

    /// Returns clones of every element after the first; empty for a single element.
    pub fn rest(&self) -> Vec<T> {
        self.rest.clone()
    }
}
/// Borrowing iterator over a [`OneOrMany`], yielding `&T`.
///
/// Created by [`OneOrMany::iter`]; produces the head element first, then the tail in order.
pub struct Iter<'a, T> {
    first: Option<&'a T>,
    rest: std::slice::Iter<'a, T>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // `take` leaves `None` behind, so the head is produced exactly once.
        self.first.take().or_else(|| self.rest.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.rest.len() + usize::from(self.first.is_some());
        // NOTE(review): the upper bound is exact, but the lower bound is reported as
        // at most 1; the existing size_hint test pins this behavior.
        (remaining.min(1), Some(remaining))
    }
}
/// Owning iterator over a [`OneOrMany`], yielding `T`.
///
/// Created by `OneOrMany::into_iter`; produces the head element first, then the tail in order.
pub struct IntoIter<T> {
    first: Option<T>,
    rest: std::vec::IntoIter<T>,
}

// No `T: Clone` bound: elements are moved out, never cloned.
impl<T> IntoIterator for OneOrMany<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIter {
            first: Some(self.first),
            rest: self.rest.into_iter(),
        }
    }
}

impl<T> Iterator for IntoIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // `take` leaves `None` behind, so the head is produced exactly once.
        self.first.take().or_else(|| self.rest.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.rest.len() + usize::from(self.first.is_some());
        // NOTE(review): lower bound intentionally kept at 1 while elements remain, matching
        // the sibling iterators and the existing size_hint test.
        (remaining.min(1), Some(remaining))
    }
}
/// Mutably borrowing iterator over a [`OneOrMany`], yielding `&mut T`.
///
/// Created by [`OneOrMany::iter_mut`]; produces the head element first, then the tail in order.
pub struct IterMut<'a, T> {
    first: Option<&'a mut T>,
    rest: std::slice::IterMut<'a, T>,
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        match self.first.take() {
            Some(head) => Some(head),
            None => self.rest.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let head = usize::from(self.first.is_some());
        let remaining = head + self.rest.len();
        // Lower bound is 1 whenever anything remains (not the exact count).
        let lower = if remaining == 0 { 0 } else { 1 };
        (lower, Some(remaining))
    }
}
/// Serializes as a plain sequence of every element, head first.
///
/// Only `T: Serialize` is required: elements are serialized by reference, never cloned.
impl<T> Serialize for OneOrMany<T>
where
    T: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Read the fields directly so this impl does not depend on `Clone`-bounded methods.
        let mut seq = serializer.serialize_seq(Some(1 + self.rest.len()))?;
        for element in std::iter::once(&self.first).chain(&self.rest) {
            seq.serialize_element(element)?;
        }
        seq.end()
    }
}
impl<'de, T> Deserialize<'de> for OneOrMany<T>
where
T: Deserialize<'de> + Clone,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct OneOrManyVisitor<T>(std::marker::PhantomData<T>);
impl<'de, T> Visitor<'de> for OneOrManyVisitor<T>
where
T: Deserialize<'de> + Clone,
{
type Value = OneOrMany<T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence of at least one element")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let first = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let mut rest = Vec::new();
while let Some(value) = seq.next_element()? {
rest.push(value);
}
Ok(OneOrMany { first, rest })
}
}
deserializer.deserialize_any(OneOrManyVisitor(std::marker::PhantomData))
}
}
/// Deserializes a [`OneOrMany<T>`] from a string, a single map, or a sequence.
///
/// Intended for `#[serde(deserialize_with = "string_or_one_or_many")]`:
/// - a string is converted with `T::from_str` into a single element;
/// - a map is deserialized as a single `T`;
/// - a sequence is deserialized with `OneOrMany`'s own `Deserialize` impl, so it must
///   contain at least one element.
///
/// Because it dispatches on the input's shape via `deserialize_any`, this needs a
/// self-describing format such as JSON.
///
/// # Errors
///
/// Returns `D::Error` for any other input shape, for an empty sequence, or when an
/// element fails to deserialize as `T`.
pub fn string_or_one_or_many<'de, T, D>(deserializer: D) -> Result<OneOrMany<T>, D::Error>
where
    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
    D: Deserializer<'de>,
{
    struct StringOrOneOrMany<T>(PhantomData<fn() -> T>);

    impl<'de, T> Visitor<'de> for StringOrOneOrMany<T>
    where
        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
    {
        type Value = OneOrMany<T>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            // Must list every shape accepted below (visit_str / visit_seq / visit_map).
            formatter.write_str("a string, a sequence, or a map")
        }

        fn visit_str<E>(self, value: &str) -> Result<OneOrMany<T>, E>
        where
            E: de::Error,
        {
            // `T::Err` is `Infallible`, so this conversion cannot actually fail.
            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
            Ok(OneOrMany::one(item))
        }

        fn visit_seq<A>(self, seq: A) -> Result<OneOrMany<T>, A::Error>
        where
            A: SeqAccess<'de>,
        {
            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
        }

        fn visit_map<M>(self, map: M) -> Result<OneOrMany<T>, M::Error>
        where
            M: MapAccess<'de>,
        {
            let item = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
            Ok(OneOrMany::one(item))
        }
    }

    deserializer.deserialize_any(StringOrOneOrMany(PhantomData))
}
/// Optional variant of `string_or_one_or_many`.
///
/// `null` / unit deserializes to `None`. Any present value is handed to
/// `string_or_one_or_many` and wrapped in `Some`, so it may be a string, a single map,
/// or a non-empty sequence.
///
/// # Errors
///
/// Propagates any error from `string_or_one_or_many` for a present value.
pub fn string_or_option_one_or_many<'de, T, D>(
    deserializer: D,
) -> Result<Option<OneOrMany<T>>, D::Error>
where
    T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
    D: Deserializer<'de>,
{
    struct StringOrOptionOneOrMany<T>(PhantomData<fn() -> T>);

    impl<'de, T> Visitor<'de> for StringOrOptionOneOrMany<T>
    where
        T: Deserialize<'de> + FromStr<Err = Infallible> + Clone,
    {
        type Value = Option<OneOrMany<T>>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            // Present values go through `string_or_one_or_many`, which also accepts a map.
            formatter.write_str("null, a string, a sequence, or a map")
        }

        fn visit_none<E>(self) -> Result<Option<OneOrMany<T>>, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        // Some deserializers (e.g. buffered content inside tagged enums) report `null` as unit.
        fn visit_unit<E>(self) -> Result<Option<OneOrMany<T>>, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        fn visit_some<D>(self, deserializer: D) -> Result<Option<OneOrMany<T>>, D::Error>
        where
            D: Deserializer<'de>,
        {
            string_or_one_or_many(deserializer).map(Some)
        }
    }

    deserializer.deserialize_option(StringOrOptionOneOrMany(PhantomData))
}
#[cfg(test)]
mod test {
    use serde::{self, Deserialize};
    use serde_json::json;

    use super::*;

    /// Clones every element, in iteration order, into a `Vec` for whole-sequence asserts.
    fn to_vec<T: Clone>(one_or_many: &OneOrMany<T>) -> Vec<T> {
        one_or_many.iter().cloned().collect()
    }

    #[test]
    fn test_single() {
        let one_or_many = OneOrMany::one("hello".to_string());
        assert_eq!(one_or_many.iter().count(), 1);
        assert_eq!(to_vec(&one_or_many), vec!["hello".to_string()]);
    }

    #[test]
    fn test() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        assert_eq!(one_or_many.iter().count(), 2);
        assert_eq!(
            to_vec(&one_or_many),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_size_hint() {
        let foo = "bar".to_string();
        let one_or_many = OneOrMany::one(foo);
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(1));

        let vec = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut one_or_many = OneOrMany::many(vec).expect("this should never fail");
        let size_hint = one_or_many.iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.clone().into_iter().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));

        let size_hint = one_or_many.iter_mut().size_hint();
        assert_eq!(size_hint.0, 1);
        assert_eq!(size_hint.1, Some(3));
    }

    #[test]
    fn test_one_or_many_into_iter_single() {
        let one_or_many = OneOrMany::one("hello".to_string());
        assert_eq!(one_or_many.clone().into_iter().count(), 1);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_into_iter() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        assert_eq!(one_or_many.clone().into_iter().count(), 2);
        assert_eq!(
            one_or_many.into_iter().collect::<Vec<_>>(),
            vec!["hello".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge() {
        let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        let one_or_many_2 = OneOrMany::one("sup".to_string());
        let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap();
        assert_eq!(merged.len(), 3);
        assert_eq!(
            to_vec(&merged),
            vec!["hello".to_string(), "word".to_string(), "sup".to_string()]
        );
    }

    #[test]
    fn test_one_or_many_merge_nothing_fails() {
        assert!(OneOrMany::<String>::merge(Vec::new()).is_err());
    }

    #[test]
    fn test_mut_single() {
        let mut one_or_many = OneOrMany::one("hello".to_string());
        assert_eq!(one_or_many.iter_mut().count(), 1);
        one_or_many.iter_mut().for_each(|i| {
            assert_eq!(i, "hello");
        });
    }

    #[test]
    fn test_mut() {
        let mut one_or_many =
            OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        assert_eq!(one_or_many.iter_mut().count(), 2);
        for item in one_or_many.iter_mut().take(1) {
            item.push_str(" world");
        }
        // The mutation must persist in the collection, and only the head changes.
        assert_eq!(
            to_vec(&one_or_many),
            vec!["hello world".to_string(), "word".to_string()]
        );
    }

    #[test]
    fn test_insert_and_push() {
        let mut one_or_many = OneOrMany::one(2);
        one_or_many.insert(0, 1);
        one_or_many.push(4);
        one_or_many.insert(2, 3);
        assert_eq!(to_vec(&one_or_many), vec![1, 2, 3, 4]);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.last(), 4);

        let mut at_end = OneOrMany::one(1);
        at_end.insert(1, 2);
        assert_eq!(to_vec(&at_end), vec![1, 2]);
    }

    #[test]
    fn test_last_single_falls_back_to_first() {
        let mut one_or_many = OneOrMany::one(1);
        assert_eq!(one_or_many.last(), 1);
        assert_eq!(*one_or_many.last_ref(), 1);
        *one_or_many.last_mut() = 5;
        assert_eq!(one_or_many.first(), 5);
        one_or_many.push(7);
        assert_eq!(*one_or_many.last_ref(), 7);
        assert_eq!(*one_or_many.first_ref(), 5);
        assert_eq!(one_or_many.rest(), vec![7]);
    }

    #[test]
    fn test_map_and_try_map() {
        let numbers = OneOrMany::many(vec![1, 2, 3]).unwrap();
        assert_eq!(
            numbers.map(|x| x * 2),
            OneOrMany::many(vec![2, 4, 6]).unwrap()
        );

        let parsed = OneOrMany::many(vec!["1", "2"])
            .unwrap()
            .try_map(|s| s.parse::<i32>());
        assert_eq!(parsed.unwrap(), OneOrMany::many(vec![1, 2]).unwrap());

        let failed = OneOrMany::many(vec!["1", "x"])
            .unwrap()
            .try_map(|s| s.parse::<i32>());
        assert!(failed.is_err());
    }

    #[test]
    fn test_one_or_many_error() {
        assert!(OneOrMany::<String>::many(vec![]).is_err())
    }

    #[test]
    fn test_len_single() {
        let one_or_many = OneOrMany::one("hello".to_string());
        assert_eq!(one_or_many.len(), 1);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_len_many() {
        let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap();
        assert_eq!(one_or_many.len(), 2);
        assert!(!one_or_many.is_empty());
    }

    #[test]
    fn test_serialize_round_trip() {
        let one_or_many = OneOrMany::many(vec![1, 2, 3]).unwrap();
        let value = serde_json::to_value(&one_or_many).unwrap();
        assert_eq!(value, json!([1, 2, 3]));
        let back: OneOrMany<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(back, one_or_many);
    }

    #[test]
    fn test_deserialize_empty_list_fails() {
        assert!(serde_json::from_value::<OneOrMany<i32>>(json!([])).is_err());
    }

    #[test]
    fn test_deserialize_plain_string_fails_without_helper() {
        assert!(serde_json::from_value::<OneOrMany<String>>(json!("hello")).is_err());
    }

    #[test]
    fn test_deserialize_list() {
        let json_data = json!({"field": [1, 2, 3]});
        let one_or_many: OneOrMany<i32> =
            serde_json::from_value(json_data["field"].clone()).unwrap();
        assert_eq!(one_or_many.len(), 3);
        assert_eq!(one_or_many.first(), 1);
        assert_eq!(one_or_many.rest(), vec![2, 3]);
    }

    #[test]
    fn test_deserialize_list_of_maps() {
        let json_data = json!({"field": [{"key": "value1"}, {"key": "value2"}]});
        let one_or_many: OneOrMany<serde_json::Value> =
            serde_json::from_value(json_data["field"].clone()).unwrap();
        assert_eq!(one_or_many.len(), 2);
        assert_eq!(one_or_many.first(), json!({"key": "value1"}));
        assert_eq!(one_or_many.rest(), vec![json!({"key": "value2"})]);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStruct {
        #[serde(deserialize_with = "string_or_one_or_many")]
        field: OneOrMany<DummyString>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DummyStructOption {
        #[serde(deserialize_with = "string_or_option_one_or_many")]
        field: Option<OneOrMany<DummyString>>,
    }

    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct DummyString {
        pub string: String,
    }

    impl FromStr for DummyString {
        type Err = Infallible;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(DummyString {
                string: s.to_string(),
            })
        }
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(tag = "role", rename_all = "lowercase")]
    enum DummyMessage {
        Assistant {
            #[serde(deserialize_with = "string_or_option_one_or_many")]
            content: Option<OneOrMany<DummyString>>,
        },
    }

    #[test]
    fn test_deserialize_unit() {
        let raw_json = r#"
        {
            "role": "assistant",
            "content": null
        }
        "#;
        let dummy: DummyMessage = serde_json::from_str(raw_json).unwrap();
        assert_eq!(dummy, DummyMessage::Assistant { content: None });
    }

    #[test]
    fn test_deserialize_string() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();
        assert_eq!(dummy.field.len(), 1);
        assert_eq!(dummy.field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_single_map() {
        let json_data = json!({"field": {"string": "hello"}});
        let dummy: DummyStruct = serde_json::from_value(json_data).unwrap();
        assert_eq!(
            dummy.field,
            OneOrMany::one(DummyString::from_str("hello").unwrap())
        );
    }

    #[test]
    fn test_deserialize_string_option() {
        let json_data = json!({"field": "hello"});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();
        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 1);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
    }

    #[test]
    fn test_deserialize_list_option() {
        let json_data = json!({"field": [{"string": "hello"}, {"string": "world"}]});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();
        assert!(dummy.field.is_some());
        let field = dummy.field.unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.first(), DummyString::from_str("hello").unwrap());
        assert_eq!(field.rest(), vec![DummyString::from_str("world").unwrap()]);
    }

    #[test]
    fn test_deserialize_null_option() {
        let json_data = json!({"field": null});
        let dummy: DummyStructOption = serde_json::from_value(json_data).unwrap();
        assert!(dummy.field.is_none());
    }
}