use std::io::{Read, Write};
use bitflags::bitflags;
use bson::Document;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use super::{
header::{Header, OpCode},
util::CountReader,
};
use crate::{
cmap::conn::command::Command,
error::{ErrorKind, Result},
};
/// A wire-protocol message sent with `OpCode::Message` (MongoDB's OP_MSG format): a header,
/// a 32-bit flag word, a list of payload sections, and an optional trailing checksum.
#[derive(Debug)]
pub(crate) struct Message {
    /// Written into the header's `response_to` field; populated from the header when read.
    pub(crate) response_to: i32,
    /// Flag bits, written as a little-endian `u32` immediately after the header.
    pub(crate) flags: MessageFlags,
    /// Payload sections, in wire order.
    pub(crate) sections: Vec<MessageSection>,
    /// Optional 4-byte checksum written after the sections. `read_from` only sets this when
    /// `CHECKSUM_PRESENT` is set and exactly 4 bytes remain after the sections.
    pub(crate) checksum: Option<u32>,
    /// Request id for the header. When `None`, `write_to` generates one via
    /// `next_request_id`; messages produced by `read_from` always have `None` here.
    pub(crate) request_id: Option<i32>,
}
impl Message {
    /// Builds a message carrying `command` as a single document (payload type 0) section.
    ///
    /// The command's target database is inserted into the body under `$db`, and its read
    /// preference (if any) under `$readPreference`. When `request_id` is `None`, an id is
    /// generated at write time by [`Message::write_to`].
    pub(crate) fn with_command(mut command: Command, request_id: Option<i32>) -> Self {
        command.body.insert("$db", command.target_db);
        if let Some(read_pref) = command.read_pref {
            command
                .body
                .insert("$readPreference", read_pref.into_document());
        }

        Self {
            response_to: 0,
            flags: MessageFlags::empty(),
            sections: vec![MessageSection::Document(command.body)],
            checksum: None,
            request_id,
        }
    }

    /// Consumes the message and returns the first document of its first section.
    ///
    /// For a document-sequence section, this is the first document of the sequence.
    ///
    /// # Errors
    /// Returns a `ResponseError` if there are no sections, or if the first section is an
    /// empty sequence.
    pub(crate) fn single_document_response(self) -> Result<Document> {
        self.sections
            .into_iter()
            .next()
            .and_then(|section| match section {
                MessageSection::Document(doc) => Some(doc),
                MessageSection::Sequence { documents, .. } => documents.into_iter().next(),
            })
            .ok_or_else(|| {
                ErrorKind::ResponseError {
                    message: "no response received from server".into(),
                }
                .into()
            })
    }

    /// Consumes the message and returns every document from every section, in wire order.
    // NOTE(review): `dead_code` is presumably allowed because no caller currently uses this;
    // confirm before removing the attribute.
    #[allow(dead_code)]
    pub(crate) fn documents(self) -> Vec<Document> {
        let mut docs = Vec::new();
        for section in self.sections {
            match section {
                MessageSection::Document(doc) => docs.push(doc),
                MessageSection::Sequence { documents, .. } => docs.extend(documents),
            }
        }
        docs
    }

    /// Reads one complete message (header, flags, sections, optional checksum) from `reader`.
    ///
    /// The returned message always has `request_id: None`.
    ///
    /// # Errors
    /// Returns an error on I/O or section-decoding failure, or an `OperationError` when the
    /// bytes consumed do not match the length declared in the header.
    pub(crate) fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
        let header = Header::read_from(reader)?;
        // Bytes of this message still to be consumed after the header.
        let mut length_remaining = header.length - Header::LENGTH as i32;

        let flags = MessageFlags::from_bits_truncate(reader.read_u32::<LittleEndian>()?);
        length_remaining -= std::mem::size_of::<u32>() as i32;

        let mut sections = Vec::new();
        // Scope the counting reader so its borrow of `reader` ends before the checksum read.
        let sections_length = {
            let mut count_reader = CountReader::new(reader);
            // Four or fewer trailing bytes are treated as the checksum (or as an error below).
            while length_remaining - count_reader.bytes_read() as i32 > 4 {
                sections.push(MessageSection::read(&mut count_reader)?);
            }
            count_reader.bytes_read() as i32
        };
        length_remaining -= sections_length;

        let mut checksum = None;
        if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) {
            checksum = Some(reader.read_u32::<LittleEndian>()?);
        } else if length_remaining != 0 {
            return Err(ErrorKind::OperationError {
                message: format!(
                    "The server indicated that the reply would be {} bytes long, but it instead \
                     was {}",
                    header.length,
                    // `length_remaining` already excludes the section bytes, so this is exactly
                    // the number of bytes consumed (header + flags + sections).
                    header.length - length_remaining,
                ),
            }
            .into());
        }

        Ok(Self {
            response_to: header.response_to,
            flags,
            sections,
            checksum,
            request_id: None,
        })
    }

    /// Serializes the message (header, flags, sections, optional checksum) to `writer`, then
    /// flushes it.
    ///
    /// The header length is computed from the encoded sections. When `request_id` is `None`,
    /// a fresh id is obtained from `super::util::next_request_id`.
    pub(crate) fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
        // Encode the sections first so the total length is known before the header is written.
        let mut sections_bytes = Vec::new();
        for section in &self.sections {
            section.write(&mut sections_bytes)?;
        }

        let total_length = Header::LENGTH
            + std::mem::size_of::<u32>()
            + sections_bytes.len()
            + self
                .checksum
                .as_ref()
                .map(std::mem::size_of_val)
                .unwrap_or(0);

        let header = Header {
            length: total_length as i32,
            request_id: self.request_id.unwrap_or_else(super::util::next_request_id),
            response_to: self.response_to,
            op_code: OpCode::Message,
        };

        header.write_to(writer)?;
        writer.write_u32::<LittleEndian>(self.flags.bits())?;
        writer.write_all(&sections_bytes)?;
        if let Some(checksum) = self.checksum {
            writer.write_u32::<LittleEndian>(checksum)?;
        }
        writer.flush()?;

        Ok(())
    }
}
bitflags! {
    /// Bits of the 32-bit flag word that follows the message header.
    pub(crate) struct MessageFlags: u32 {
        /// Bit 0: a 4-byte checksum follows the sections; `Message::read_from` reads it
        /// only when this bit is set.
        const CHECKSUM_PRESENT = 1 << 0;
        /// Bit 1 (`moreToCome` in the OP_MSG specification).
        const MORE_TO_COME = 1 << 1;
        /// Bit 16 (`exhaustAllowed` in the OP_MSG specification).
        const EXHAUST_ALLOWED = 1 << 16;
    }
}
/// One payload section of a [`Message`].
#[derive(Debug)]
pub(crate) enum MessageSection {
    /// A single BSON document (written with payload type byte 0).
    Document(Document),
    /// A sequence of BSON documents (written with payload type byte 1).
    Sequence {
        /// The section's size field. When read, it is taken to include its own 4 bytes, the
        /// identifier and the documents; `write` emits it verbatim without recomputing it.
        size: i32,
        /// Identifier written as a C string (via `write_cstring`) right after `size`.
        identifier: String,
        /// Documents of the sequence, in wire order.
        documents: Vec<Document>,
    },
}
impl MessageSection {
    /// Reads one section, dispatching on its leading payload-type byte.
    ///
    /// Type 0 is a single BSON document. Type 1 is a document sequence: an `i32` size (which
    /// counts its own 4 bytes), a NUL-terminated identifier, then BSON documents filling the
    /// rest of `size`.
    ///
    /// # Errors
    /// Returns an error on I/O or BSON decoding failure, on an unknown payload type, on an
    /// identifier that is unterminated within the section or not valid UTF-8, or when the
    /// documents overrun the declared `size`.
    fn read<R: Read>(reader: &mut R) -> Result<Self> {
        let payload_type = reader.read_u8()?;
        match payload_type {
            0 => return Ok(MessageSection::Document(bson::decode_document(reader)?)),
            1 => {}
            other => {
                return Err(ErrorKind::OperationError {
                    message: format!("invalid OP_MSG section payload type: {}", other),
                }
                .into());
            }
        }

        let size = reader.read_i32::<LittleEndian>()?;
        let mut length_remaining = size - std::mem::size_of::<i32>() as i32;

        // The identifier is a C string: read only up to and including its NUL terminator.
        // (`Read::read_to_string` would consume everything up to EOF, documents included.)
        let (identifier, identifier_length) = Self::read_cstring(reader, length_remaining)?;
        length_remaining -= identifier_length;

        let mut documents = Vec::new();
        let mut count_reader = CountReader::new(reader);
        while length_remaining - count_reader.bytes_read() as i32 > 0 {
            documents.push(bson::decode_document(&mut count_reader)?);
        }

        if length_remaining - count_reader.bytes_read() as i32 != 0 {
            return Err(ErrorKind::OperationError {
                message: format!(
                    "The server indicated that the reply would be {} bytes long, but it instead \
                     was {}",
                    size,
                    size - length_remaining + count_reader.bytes_read() as i32,
                ),
            }
            .into());
        }

        Ok(MessageSection::Sequence {
            size,
            identifier,
            documents,
        })
    }

    /// Reads a NUL-terminated UTF-8 string, consuming at most `limit` bytes (the terminator
    /// included). Returns the string and the number of bytes consumed.
    fn read_cstring<R: Read>(reader: &mut R, limit: i32) -> Result<(String, i32)> {
        let mut bytes = Vec::new();
        loop {
            // Reading one more byte would exceed the enclosing section: refuse to overrun it.
            if bytes.len() as i32 >= limit {
                return Err(ErrorKind::OperationError {
                    message: "document sequence identifier is not NUL-terminated within its \
                              section"
                        .to_string(),
                }
                .into());
            }
            match reader.read_u8()? {
                0 => break,
                byte => bytes.push(byte),
            }
        }

        // +1 accounts for the NUL terminator, which is consumed but not stored.
        let consumed = bytes.len() as i32 + 1;
        match String::from_utf8(bytes) {
            Ok(identifier) => Ok((identifier, consumed)),
            Err(err) => Err(ErrorKind::OperationError {
                message: format!("document sequence identifier is not valid UTF-8: {}", err),
            }
            .into()),
        }
    }

    /// Writes the section: a payload-type byte followed by its payload.
    ///
    /// For sequences, `size` is written as stored (not recomputed), followed by the
    /// identifier as a C string and then each document.
    fn write<W: Write>(&self, writer: &mut W) -> Result<()> {
        match self {
            Self::Document(doc) => {
                writer.write_u8(0)?;
                bson::encode_document(writer, doc)?;
            }
            Self::Sequence {
                size,
                identifier,
                documents,
            } => {
                writer.write_u8(1)?;
                writer.write_i32::<LittleEndian>(*size)?;
                super::util::write_cstring(writer, identifier)?;
                for doc in documents {
                    bson::encode_document(writer, doc)?;
                }
            }
        }

        Ok(())
    }
}