use std::{borrow::Cow, convert::TryFrom, fmt::Formatter, ops::Range};
use serde::{
de::{DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor},
Deserializer,
};
use crate::{
raw::{write_string, RAW_BSON_NEWTYPE},
spec::{BinarySubtype, ElementType},
RawBson,
RawBsonRef,
};
use super::{CowStr, MapParse, OwnedOrBorrowedRawBson, OwnedOrBorrowedRawBsonVisitor};
/// A copy-on-write byte buffer.
///
/// The wrapped value is `None` until something is written. The first write decides whether the
/// buffer starts out borrowing from the input (`Cow::Borrowed`) or owning its bytes
/// (`Cow::Owned`); any later mutation of a borrowed buffer converts it to an owned one.
pub(crate) struct CowByteBuffer<'de>(pub(crate) Option<Cow<'de, [u8]>>);

impl<'de> CowByteBuffer<'de> {
    /// Creates a buffer that does not hold any bytes yet.
    pub(crate) fn new() -> Self {
        CowByteBuffer(None)
    }

    /// Number of bytes currently held; `0` while nothing has been written.
    fn len(&self) -> usize {
        self.0.as_ref().map_or(0, |bytes| bytes.len())
    }

    /// Returns the buffer as a mutable `Vec<u8>`.
    ///
    /// An unset buffer becomes an empty owned vector, and a borrowed buffer is copied into an
    /// owned one by `Cow::to_mut`.
    fn get_owned_buffer(&mut self) -> &mut Vec<u8> {
        let storage = self.0.get_or_insert_with(|| Cow::Owned(Vec::new()));
        storage.to_mut()
    }

    /// Appends one byte.
    fn push_byte(&mut self, byte: u8) {
        self.get_owned_buffer().push(byte);
    }

    /// Appends a copy of `bytes`.
    fn append_bytes(&mut self, bytes: &[u8]) {
        self.get_owned_buffer().extend_from_slice(bytes);
    }

    /// Appends bytes that live as long as the deserializer input.
    ///
    /// If nothing has been written yet the slice is stored by reference, without copying;
    /// otherwise it is copied onto the end of the existing contents.
    fn append_borrowed_bytes(&mut self, bytes: &'de [u8]) {
        if let Some(existing) = self.0.as_mut() {
            existing.to_mut().extend_from_slice(bytes);
        } else {
            self.0 = Some(Cow::Borrowed(bytes));
        }
    }

    /// Overwrites the bytes at `range` with `slice`.
    ///
    /// # Panics
    ///
    /// Panics if `range` is out of bounds or its length differs from `slice.len()`.
    fn copy_from_slice(&mut self, range: Range<usize>, slice: &[u8]) {
        self.get_owned_buffer()[range].copy_from_slice(slice);
    }

    /// Removes the bytes at `range`, shifting any following bytes down.
    ///
    /// # Panics
    ///
    /// Panics if `range` is out of bounds.
    fn drain(&mut self, range: Range<usize>) {
        self.get_owned_buffer().drain(range);
    }
}
/// A serde seed/visitor that encodes the values it visits into a shared [`CowByteBuffer`].
///
/// Each visit appends the encoded value to `buffer` and yields the `ElementType` of what was
/// written, so that the caller can fill in the element-type byte that precedes the value.
pub(crate) struct SeededVisitor<'a, 'de> {
/// Destination for the encoded bytes; shared with nested visitors created during recursion.
buffer: &'a mut CowByteBuffer<'de>,
}
impl<'de> DeserializeSeed<'de> for SeededVisitor<'_, 'de> {
type Value = ElementType;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_newtype_struct(RAW_BSON_NEWTYPE, self)
}
}
impl<'de> DeserializeSeed<'de> for &mut SeededVisitor<'_, 'de> {
type Value = ElementType;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_newtype_struct(
RAW_BSON_NEWTYPE,
SeededVisitor {
buffer: self.buffer,
},
)
}
}
impl<'a, 'de> SeededVisitor<'a, 'de> {
pub(crate) fn new(buffer: &'a mut CowByteBuffer<'de>) -> Self {
Self { buffer }
}
fn append_cstring(&mut self, key: &str) -> Result<(), String> {
crate::raw::CStr::from_str(key)
.map_err(|e| e.to_string())?
.append_to(self.buffer.get_owned_buffer());
Ok(())
}
fn append_string(&mut self, s: &str) {
write_string(self.buffer.get_owned_buffer(), s)
}
fn append_length_bytes(&mut self, length: i32) {
self.buffer.append_bytes(&length.to_le_bytes());
}
fn append_owned_binary(&mut self, bytes: Vec<u8>, subtype: u8) {
match &mut self.buffer.0 {
Some(_) => self.append_embedded_binary(&bytes, subtype),
None => self.buffer.0 = Some(Cow::Owned(bytes)),
}
}
fn append_borrowed_binary(&mut self, bytes: &'de [u8], subtype: u8) {
match &self.buffer.0 {
Some(_) => self.append_embedded_binary(bytes, subtype),
None => self.buffer.0 = Some(Cow::Borrowed(bytes)),
}
}
fn append_embedded_binary(&mut self, bytes: &[u8], subtype: impl Into<u8>) {
self.append_length_bytes(bytes.len() as i32);
self.buffer.push_byte(subtype.into());
self.buffer.append_bytes(bytes);
}
fn pad_element_type(&mut self) -> usize {
let index = self.buffer.len();
self.buffer.push_byte(0);
index
}
fn write_element_type(&mut self, element_type: ElementType, index: usize) {
self.buffer
.copy_from_slice(index..index + 1, &[element_type as u8]);
}
fn pad_document_length(&mut self) -> usize {
let index = self.buffer.len();
self.buffer.append_bytes(&[0; 4]);
index
}
fn finish_document(&mut self, length_index: usize) -> Result<(), String> {
self.buffer.push_byte(0);
let length_bytes = match i32::try_from(self.buffer.len() - length_index) {
Ok(length) => length.to_le_bytes(),
Err(_) => return Err("value exceeds maximum length".to_string()),
};
self.buffer
.copy_from_slice(length_index..length_index + 4, &length_bytes);
Ok(())
}
pub(crate) fn iterate_map<A>(mut self, first_key: CowStr, mut map: A) -> Result<(), A::Error>
where
A: MapAccess<'de>,
{
let length_index = self.pad_document_length();
let mut current_key = first_key;
loop {
let element_type_index = self.pad_element_type();
self.append_cstring(current_key.0.as_ref())
.map_err(SerdeError::custom)?;
let element_type = map.next_value_seed(&mut self)?;
self.write_element_type(element_type, element_type_index);
match map.next_key::<CowStr>()? {
Some(next_key) => current_key = next_key,
None => break,
}
}
self.finish_document(length_index)
.map_err(SerdeError::custom)?;
Ok(())
}
}
impl<'de> Visitor<'de> for SeededVisitor<'_, 'de> {
type Value = ElementType;
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter.write_str("a raw BSON value")
}
fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(self)
}
fn visit_map<A>(mut self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
match OwnedOrBorrowedRawBsonVisitor::parse_map(&mut map)? {
MapParse::Leaf(bson) => {
match bson {
OwnedOrBorrowedRawBson::Borrowed(RawBsonRef::Binary(b)) => {
self.append_borrowed_binary(b.bytes, b.subtype.into());
Ok(ElementType::Binary)
}
OwnedOrBorrowedRawBson::Owned(RawBson::Binary(b)) => {
self.append_owned_binary(b.bytes, b.subtype.into());
Ok(ElementType::Binary)
}
OwnedOrBorrowedRawBson::Borrowed(RawBsonRef::Document(doc)) => {
self.buffer.append_borrowed_bytes(doc.as_bytes());
Ok(ElementType::EmbeddedDocument)
}
OwnedOrBorrowedRawBson::Borrowed(RawBsonRef::Array(arr)) => {
self.buffer.append_borrowed_bytes(arr.as_bytes());
Ok(ElementType::Array)
}
_ => {
let bson = bson.as_ref();
bson.append_to(self.buffer.get_owned_buffer());
Ok(bson.element_type())
}
}
}
MapParse::Aggregate(first_key) => {
self.iterate_map(first_key, map)?;
Ok(ElementType::EmbeddedDocument)
}
}
}
fn visit_seq<A>(mut self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let length_index = self.pad_document_length();
let mut i = 0u32;
loop {
let element_type_index = self.pad_element_type();
let key = i.to_string();
self.append_cstring(&key).map_err(SerdeError::custom)?;
let element_type = match seq.next_element_seed(&mut self)? {
Some(element_type) => element_type,
None => {
self.buffer.drain(element_type_index..self.buffer.len());
break;
}
};
self.write_element_type(element_type, element_type_index);
i += 1;
}
self.finish_document(length_index)
.map_err(SerdeError::custom)?;
Ok(ElementType::Array)
}
fn visit_str<E>(mut self, s: &str) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.append_string(s);
Ok(ElementType::String)
}
fn visit_bool<E>(self, b: bool) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.push_byte(b as u8);
Ok(ElementType::Boolean)
}
fn visit_i8<E>(self, n: i8) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&(n as i32).to_le_bytes());
Ok(ElementType::Int32)
}
fn visit_i16<E>(self, n: i16) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&(n as i32).to_le_bytes());
Ok(ElementType::Int32)
}
fn visit_i32<E>(self, n: i32) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&n.to_le_bytes());
Ok(ElementType::Int32)
}
fn visit_i64<E>(self, n: i64) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&n.to_le_bytes());
Ok(ElementType::Int64)
}
fn visit_u8<E>(self, n: u8) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&(n as i32).to_le_bytes());
Ok(ElementType::Int32)
}
fn visit_u16<E>(self, n: u16) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&(n as i32).to_le_bytes());
Ok(ElementType::Int32)
}
fn visit_u32<E>(self, n: u32) -> Result<Self::Value, E>
where
E: SerdeError,
{
match i32::try_from(n) {
Ok(n) => {
self.buffer.append_bytes(&n.to_le_bytes());
Ok(ElementType::Int32)
}
Err(_) => {
self.buffer.append_bytes(&(n as i64).to_le_bytes());
Ok(ElementType::Int64)
}
}
}
fn visit_u64<E>(self, n: u64) -> Result<Self::Value, E>
where
E: SerdeError,
{
if let Ok(n) = i32::try_from(n) {
self.buffer.append_bytes(&n.to_le_bytes());
Ok(ElementType::Int32)
} else if let Ok(n) = i64::try_from(n) {
self.buffer.append_bytes(&n.to_le_bytes());
Ok(ElementType::Int64)
} else {
Err(SerdeError::custom(format!(
"number is too large for BSON: {}",
n
)))
}
}
fn visit_f64<E>(self, n: f64) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.buffer.append_bytes(&n.to_le_bytes());
Ok(ElementType::Double)
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: SerdeError,
{
Ok(ElementType::Null)
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: SerdeError,
{
Ok(ElementType::Null)
}
fn visit_bytes<E>(mut self, bytes: &[u8]) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.append_owned_binary(bytes.to_owned(), BinarySubtype::Generic.into());
Ok(ElementType::Binary)
}
fn visit_borrowed_bytes<E>(mut self, bytes: &'de [u8]) -> Result<Self::Value, E>
where
E: SerdeError,
{
self.append_borrowed_binary(bytes, BinarySubtype::Generic.into());
Ok(ElementType::Binary)
}
}