use std::io::Read;
use bitflags::bitflags;
use futures_io::AsyncWrite;
use futures_util::{
io::{BufReader, BufWriter},
AsyncReadExt,
AsyncWriteExt,
};
use super::header::{Header, OpCode};
use crate::{
bson_util,
cmap::{
conn::{command::RawCommand, wire::util::SyncCountReader},
Command,
},
error::{Error, ErrorKind, Result},
runtime::{AsyncLittleEndianWrite, AsyncStream, SyncLittleEndianRead},
};
/// In-memory form of a single wire-protocol message: the header fields this module tracks,
/// the flag bits, the payload sections, and an optional trailing checksum.
#[derive(Debug)]
pub(crate) struct Message {
// Copied into the header's `response_to` field when writing and taken from the header when
// reading; messages built by `with_raw_command` use 0.
pub(crate) response_to: i32,
// Flag bits, serialized as a little-endian `u32` immediately after the header.
pub(crate) flags: MessageFlags,
// Payload sections, written and read in order.
pub(crate) sections: Vec<MessageSection>,
// Trailing checksum; written after the sections only when it is `Some`.
pub(crate) checksum: Option<u32>,
// Request id to put in the header when writing. When `None`, `write_to` obtains one from
// `super::util::next_request_id`. Always `None` on messages produced by `read_from`.
pub(crate) request_id: Option<i32>,
}
impl Message {
/// Builds a message from a typed `Command`.
///
/// The command is serialized with `bson::to_vec` and then handed to
/// [`Message::with_raw_command`] together with its target database and name.
///
/// # Errors
///
/// Returns an error if BSON serialization of `command` fails.
pub(crate) fn with_command(command: Command, request_id: Option<i32>) -> Result<Self> {
    // `bytes` is evaluated first and only borrows `command`; the remaining fields are then
    // moved out of it.
    let raw = RawCommand {
        bytes: bson::to_vec(&command)?,
        target_db: command.target_db,
        name: command.name,
    };
    Ok(Self::with_raw_command(raw, request_id))
}
/// Builds a message whose only section is the already-serialized command document.
///
/// The result has no flags set, no checksum, and a `response_to` of 0; `request_id` is stored
/// as given.
pub(crate) fn with_raw_command(command: RawCommand, request_id: Option<i32>) -> Self {
    let body = MessageSection::Document(command.bytes);
    Self {
        request_id,
        response_to: 0,
        flags: MessageFlags::empty(),
        checksum: None,
        sections: vec![body],
    }
}
pub(crate) fn single_document_response(self) -> Result<Vec<u8>> {
let section = self.sections.into_iter().next().ok_or_else(|| {
Error::new(
ErrorKind::InvalidResponse {
message: "no response received from server".into(),
},
Option::<Vec<String>>::None,
)
})?;
match section {
MessageSection::Document(doc) => Some(doc),
MessageSection::Sequence { documents, .. } => documents.into_iter().next(),
}
.ok_or_else(|| {
Error::from(ErrorKind::InvalidResponse {
message: "no message received from the server".to_string(),
})
})
}
/// Reads one complete message from `reader`.
///
/// The header is read first. The rest of the message (`header.length` minus the header size)
/// is then pulled into memory with a single `read_exact` and parsed from that buffer: the
/// flag bits, then sections, then an optional trailing checksum.
///
/// The returned message always has `request_id` set to `None`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidResponse` if the declared length is too small to hold the
/// header plus the flag bits, or if the parsed sections do not account for the declared
/// length. I/O errors and section-parsing errors are propagated.
pub(crate) async fn read_from(reader: &mut AsyncStream) -> Result<Self> {
    const FLAGS_LENGTH: i32 = std::mem::size_of::<u32>() as i32;
    const CHECKSUM_LENGTH: i32 = std::mem::size_of::<u32>() as i32;

    let mut reader = BufReader::new(reader);
    let header = Header::read_from(&mut reader).await?;

    // The length is supplied by the peer. Without this check a value smaller than the header
    // would go negative below and turn into an enormous allocation once cast to `usize`.
    let header_length = Header::LENGTH as i32;
    let min_length = header_length + FLAGS_LENGTH;
    if header.length < min_length {
        return Err(ErrorKind::InvalidResponse {
            message: format!(
                "The server indicated that the reply would be {} bytes long, but a message \
                 must be at least {} bytes long",
                header.length, min_length,
            ),
        }
        .into());
    }

    let mut length_remaining = header.length - header_length;

    // Read the whole body up front so the rest of the parsing is synchronous.
    let mut buf = vec![0u8; length_remaining as usize];
    reader.read_exact(&mut buf).await?;
    let mut body = buf.as_slice();

    let flags = MessageFlags::from_bits_truncate(body.read_u32()?);
    length_remaining -= FLAGS_LENGTH;

    // Keep reading sections while more than a checksum's worth of bytes is left.
    let mut count_reader = SyncCountReader::new(&mut body);
    let mut sections = Vec::new();
    while length_remaining - count_reader.bytes_read() as i32 > CHECKSUM_LENGTH {
        sections.push(MessageSection::read(&mut count_reader)?);
    }
    length_remaining -= count_reader.bytes_read() as i32;

    let checksum = if length_remaining == CHECKSUM_LENGTH
        && flags.contains(MessageFlags::CHECKSUM_PRESENT)
    {
        Some(body.read_u32()?)
    } else if length_remaining != 0 {
        return Err(ErrorKind::InvalidResponse {
            message: format!(
                "The server indicated that the reply would be {} bytes long, but it instead \
                 was {}",
                header.length,
                // `length_remaining` is what the parsed header, flags and sections left
                // unaccounted for, so the difference is the number of bytes actually parsed.
                header.length - length_remaining,
            ),
        }
        .into());
    } else {
        None
    };

    Ok(Self {
        response_to: header.response_to,
        flags,
        sections,
        checksum,
        request_id: None,
    })
}
/// Serializes this message and writes it to `stream`.
///
/// The sections are first serialized into a temporary buffer so that the total length is
/// known before the header is written. The output order is: header, flag bits, section
/// bytes, then the checksum if one is set. The stream is wrapped in a `BufWriter` that is
/// flushed before returning.
///
/// If `request_id` is `None`, an id is obtained from `super::util::next_request_id`.
///
/// # Errors
///
/// Propagates any error from serializing a section or from writing to the stream.
pub(crate) async fn write_to(&self, stream: &mut AsyncStream) -> Result<()> {
    let mut writer = BufWriter::new(stream);

    let mut sections_bytes = Vec::new();
    for section in &self.sections {
        section.write(&mut sections_bytes).await?;
    }

    // The checksum only contributes to the length when it is present.
    let checksum_length = self
        .checksum
        .as_ref()
        .map(std::mem::size_of_val)
        .unwrap_or(0);
    let total_length =
        Header::LENGTH + std::mem::size_of::<u32>() + sections_bytes.len() + checksum_length;

    let header = Header {
        length: total_length as i32,
        request_id: self.request_id.unwrap_or_else(super::util::next_request_id),
        response_to: self.response_to,
        op_code: OpCode::Message,
    };

    header.write_to(&mut writer).await?;
    writer.write_u32(self.flags.bits()).await?;
    writer.write_all(&sections_bytes).await?;
    if let Some(checksum) = self.checksum {
        writer.write_u32(checksum).await?;
    }
    writer.flush().await?;

    Ok(())
}
}
bitflags! {
/// Flag bits carried in the `u32` that follows the message header.
///
/// `Message::read_from` parses this value with `from_bits_truncate`, so bits not listed
/// here are dropped on read.
pub(crate) struct MessageFlags: u32 {
/// Bit 0. `Message::read_from` reads a trailing `u32` checksum only when this is set and
/// exactly four bytes remain after the sections.
const CHECKSUM_PRESENT = 0b_0000_0000_0000_0000_0000_0000_0000_0001;
/// Bit 1. Not consulted in this file.
/// NOTE(review): presumably the wire protocol's "more to come" flag (sender will send
/// further messages without waiting for a reply) -- confirm against the protocol spec.
const MORE_TO_COME = 0b_0000_0000_0000_0000_0000_0000_0000_0010;
/// Bit 16. Not consulted in this file.
/// NOTE(review): presumably the wire protocol's "exhaust allowed" flag (sender accepts
/// multiple replies to one request) -- confirm against the protocol spec.
const EXHAUST_ALLOWED = 0b_0000_0000_0000_0001_0000_0000_0000_0000;
}
}
/// One payload section of a `Message`.
///
/// On the wire every section starts with a one-byte payload type: 0 for `Document`, 1 for
/// `Sequence` (see `MessageSection::write`).
#[derive(Debug)]
pub(crate) enum MessageSection {
// A single document, held as its raw serialized bytes.
Document(Vec<u8>),
// A named run of documents.
Sequence {
// Size field as carried on the wire. `write` emits this value verbatim rather than
// recomputing it; `read` checks it against the bytes it actually parsed.
size: i32,
// Sequence identifier, written as a cstring.
identifier: String,
// Raw bytes of each document, written back-to-back after the identifier.
documents: Vec<Vec<u8>>,
},
}
impl MessageSection {
fn read<R: Read>(reader: &mut R) -> Result<Self> {
let payload_type = reader.read_u8()?;
if payload_type == 0 {
return Ok(MessageSection::Document(bson_util::read_document_bytes(
reader,
)?));
}
let size = reader.read_i32()?;
let mut length_remaining = size - std::mem::size_of::<i32>() as i32;
let mut identifier = String::new();
length_remaining -= reader.read_to_string(&mut identifier)? as i32;
let mut documents = Vec::new();
let mut count_reader = SyncCountReader::new(reader);
while length_remaining > count_reader.bytes_read() as i32 {
documents.push(bson_util::read_document_bytes(&mut count_reader)?);
}
if length_remaining != count_reader.bytes_read() as i32 {
return Err(ErrorKind::InvalidResponse {
message: format!(
"The server indicated that the reply would be {} bytes long, but it instead \
was {}",
size,
length_remaining + count_reader.bytes_read() as i32,
),
}
.into());
}
Ok(MessageSection::Sequence {
size,
identifier,
documents,
})
}
/// Serializes this section into `writer`.
///
/// A one-byte payload type is written first (0 for `Document`, 1 for `Sequence`). A document
/// section then writes its raw bytes; a sequence section writes its `size`, its identifier
/// as a cstring, and each document's raw bytes in order.
///
/// # Errors
///
/// Propagates any error returned by the underlying writes.
async fn write<W: AsyncWrite + Unpin + Send>(&self, writer: &mut W) -> Result<()> {
    let payload_type: u8 = match self {
        Self::Document(_) => 0,
        Self::Sequence { .. } => 1,
    };
    writer.write_u8(payload_type).await?;

    match self {
        Self::Document(bytes) => writer.write_all(bytes.as_slice()).await?,
        Self::Sequence {
            size,
            identifier,
            documents,
        } => {
            writer.write_i32(*size).await?;
            super::util::write_cstring(writer, identifier).await?;
            for bytes in documents {
                writer.write_all(bytes.as_slice()).await?;
            }
        }
    }

    Ok(())
}
}