use std::convert::TryFrom;
use std::fmt;
use std::iter::FusedIterator;
use std::marker::PhantomData;
use std::str::{from_utf8, FromStr};
use circular::Buffer;
use crate::{
client::Connection,
error::{self, Error},
parse,
query::Query,
};
mod queue;
use self::queue::Queue;
/// A queue of queries sent over a [`Connection`] whose responses are read back in turn.
///
/// Queries are added with [`Pipeline::push`]; responses are read with [`Pipeline::pop`]
/// or [`Pipeline::responses`]. Dropping the pipeline reads and discards any responses
/// that are still outstanding.
pub struct Pipeline<'a> {
    /// Connection that queries are sent on and responses are read from.
    conn: &'a mut Connection,
    /// Read buffer holding bytes received from `conn` that have not been parsed yet.
    buf: Buffer,
    /// Queries that have been pushed but whose responses have not been popped yet.
    queue: Queue,
}
impl<'a> Pipeline<'a> {
    /// Create an empty pipeline on `conn` with a read buffer of `capacity` bytes.
    #[tracing::instrument(level = "debug")]
    pub(crate) fn new(conn: &'a mut Connection, capacity: usize) -> Self {
        let buf = Buffer::with_capacity(capacity);
        let queue = Queue::default();
        Self { conn, buf, queue }
    }

    /// Build a pipeline seeded by `initial`, enqueueing follow-up queries derived from
    /// its response.
    ///
    /// `initial` is pushed onto a fresh pipeline and its response is read to completion.
    /// Every item (or item-level error) is passed to `f`. Each query in the iterable that
    /// `f` returns is then pushed onto the pipeline, in the order produced.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `initial` cannot be enqueued;
    /// - no response to `initial` can be dequeued, or its status is an error;
    /// - any follow-up query cannot be enqueued.
    #[tracing::instrument(skip(conn, f), fields(initial = initial.cmd()), level = "debug")]
    pub(crate) fn from_initial<'b, T, F, I>(
        conn: &'a mut Connection,
        initial: Query,
        f: F,
    ) -> Result<Self, Error>
    where
        'a: 'b,
        T: FromStr + fmt::Debug,
        T::Err: std::error::Error + Send + Sync + 'static,
        F: FnMut(Result<ResponseItem<T>, Error>) -> Option<I>,
        I: IntoIterator<Item = Query>,
    {
        let mut pipeline = conn.pipeline();
        // The `Response` to `initial` holds an exclusive borrow of `pipeline` until it is
        // exhausted. Pushing follow-up queries while it is alive would need a second,
        // aliasing `&mut` to the same pipeline, which is undefined behaviour. Collect the
        // queries first, then push them once that borrow has ended.
        let follow_up: Vec<Query> = pipeline
            .push(initial)?
            .pop()
            .unwrap_or_else(|| Err(Error::Dequeue))?
            .filter_map(f)
            .flatten()
            .collect();
        for query in follow_up {
            if let Err(err) = pipeline.push(query) {
                tracing::error!("error enqueing query: {}", err);
                return Err(err);
            }
        }
        Ok(pipeline)
    }

    /// Add `query` to the pipeline and flush pending queries to the connection.
    ///
    /// Returns `self` so that pushes can be chained.
    ///
    /// # Errors
    ///
    /// Returns an error if flushing the queue to the connection fails.
    #[tracing::instrument(skip(self), level = "debug")]
    pub fn push(&mut self, query: Query) -> Result<&mut Self, Error> {
        tracing::debug!("pushing new query");
        self.queue.push(query);
        self.flush()?;
        Ok(self)
    }

    /// Hand queued queries to `Queue::flush`, which sends them via `Connection::send`.
    #[tracing::instrument(level = "trace")]
    fn flush(&mut self) -> Result<(), Error> {
        self.queue.flush(|query| self.conn.send(&query.cmd()))
    }

    /// Read the response to the next queued query.
    ///
    /// Returns `None` if no query is queued. Otherwise returns a [`Response`] that iterates
    /// over the items of that query's response, parsed as `T`.
    ///
    /// # Errors
    ///
    /// The inner `Result` is an error in any of these cases:
    /// - flushing pending queries fails;
    /// - the response status cannot be read or reports an error;
    /// - a query that expects no data receives a non-empty response.
    #[tracing::instrument(skip(self), level = "debug")]
    pub fn pop<'b, T>(&'b mut self) -> Option<Result<Response<'a, 'b, T>, Error>>
    where
        T: FromStr + fmt::Debug,
        T::Err: std::error::Error + Send + Sync + 'static,
    {
        self.pop_wrapped()
            .map(|wrapped| wrapped.map_err(error::Wrapper::take_inner))
    }

    /// Like [`Pipeline::pop`], but on failure returns an `error::Wrapper`.
    ///
    /// The wrapper still carries this pipeline, so callers such as [`Responses`] can keep
    /// using it after an error.
    #[tracing::instrument(level = "trace")]
    fn pop_wrapped<'b, T>(
        &'b mut self,
    ) -> Option<Result<Response<'a, 'b, T>, error::Wrapper<'a, 'b>>>
    where
        'a: 'b,
        T: FromStr + fmt::Debug,
        T::Err: std::error::Error + Send + Sync + 'static,
    {
        // Make sure every queued query has been sent before waiting on a response.
        // On failure nothing is dequeued.
        match self.flush() {
            Ok(()) => {}
            Err(err) => return Some(Err(error::Wrapper::new(Some(self), err))),
        };
        #[allow(clippy::cognitive_complexity)]
        self.queue.pop().map(move |query| {
            tracing::debug!(?query, "popped query response");
            // Parse the response status, fetching more data until it is complete.
            // `Some(len)` announces `len` bytes of data; `None` means an empty response.
            // NOTE(review): if `fetch` keeps returning `Ok(0)` (peer closed, or no free
            // buffer space) this loop never terminates — confirm `Connection::read`.
            let expect = loop {
                tracing::trace!(?self);
                match parse::response_status(self.buf.data()) {
                    Ok((_, (consumed, response_result))) => {
                        _ = self.buf.consume(consumed);
                        match response_result {
                            Ok(Some(len)) => break len,
                            Ok(None) => break 0,
                            Err(err) => {
                                return Err(error::Wrapper::new(
                                    Some(self),
                                    Error::ResponseErr(query, err),
                                ))
                            }
                        }
                    }
                    Err(nom::Err::Incomplete(_)) => {
                        tracing::trace!("incomplete parse, trying to fetch more data");
                        if let Err(err) = self.fetch() {
                            return Err(error::Wrapper::new(Some(self), err));
                        };
                    }
                    Err(err) => {
                        let inner_err = err.into();
                        return Err(error::Wrapper::new(Some(self), inner_err));
                    }
                }
            };
            // A data-bearing query may get an empty response (only warned about); a query
            // that expects no data must get a zero-length response.
            if query.expect_data() {
                if expect == 0 {
                    tracing::warn!("unexpected zero length response for query {query:?}");
                }
                tracing::debug!("expecting response length {} bytes", expect);
                Ok(Response::new(query, self, expect))
            } else if expect == 0 {
                tracing::debug!("found expected zero-length response");
                Ok(Response::new(query, self, expect))
            } else {
                Err(error::Wrapper::new(
                    Some(self),
                    Error::UnexpectedData(query, expect),
                ))
            }
        })
    }

    /// Iterate over the items of every outstanding response, one response after another.
    #[tracing::instrument(skip(self), level = "trace")]
    pub fn responses<'b, T>(&'b mut self) -> Responses<'a, 'b, T>
    where
        'a: 'b,
        T: FromStr + fmt::Debug,
        T::Err: std::error::Error + Send + Sync + 'static,
    {
        Responses {
            pipeline: Some(self),
            current_reponse: None,
        }
    }

    /// Read more bytes from the connection into the free space of the read buffer.
    ///
    /// Unread data is shifted to the front of the buffer first. Returns the number of
    /// bytes added to the buffer.
    ///
    /// NOTE(review): this returns `Ok(0)` whenever `read` does. That presumably happens at
    /// end of stream or when the buffer has no free space. Callers retry on incomplete
    /// parses, so they would spin — consider turning `Ok(0)` into an error.
    #[tracing::instrument(skip(self), level = "trace")]
    fn fetch(&mut self) -> Result<usize, Error> {
        self.buf.shift();
        let space = self.buf.space();
        tracing::trace!("trying to fetch up to {} bytes", space.len());
        let fetched = self.conn.read(space)?;
        tracing::trace!("fetched {} bytes", fetched);
        let filled = self.buf.fill(fetched);
        Ok(filled)
    }

    /// Read and discard every outstanding response, parsing items as `String`.
    ///
    /// Items and errors encountered along the way are logged at debug level rather than
    /// returned.
    #[tracing::instrument(level = "trace")]
    pub fn clear(&mut self) -> &mut Self {
        self.responses::<String>().consume();
        self
    }
}
impl Drop for Pipeline<'_> {
    /// Read and discard any responses still outstanding before the pipeline goes away.
    ///
    /// Delegates to [`Pipeline::clear`]; its `&mut Self` return value is not needed here.
    fn drop(&mut self) {
        let _ = self.clear();
    }
}
impl Extend<Query> for Pipeline<'_> {
    /// Push each query from `iter` onto the pipeline, in order.
    ///
    /// `Extend` cannot report failure. A query that fails to enqueue is logged at error
    /// level, and the remaining queries are still attempted.
    #[tracing::instrument(skip(self, iter), level = "debug")]
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = Query>,
    {
        for query in iter {
            match self.push(query) {
                Ok(_) => {}
                Err(err) => {
                    tracing::error!("error enqueuing query: {}", err);
                }
            }
        }
    }
}
impl fmt::Debug for Pipeline<'_> {
    /// Format the pipeline for diagnostics, previewing at most the first 100 buffered bytes.
    ///
    /// Buffered bytes are decoded lossily as UTF-8 and escaped. When the buffer holds more
    /// than the preview limit, `" ..."` is appended inside the quotes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PREVIEW_LEN: usize = 100;
        let pending = self.buf.data();
        let overflow = self.buf.available_data() > PREVIEW_LEN;
        let (shown, ellipsis) = if overflow {
            (&pending[..PREVIEW_LEN], " ...")
        } else {
            (pending, "")
        };
        let preview = String::from_utf8_lossy(shown);
        let mut out = f.debug_struct("Pipeline");
        out.field("conn", &self.conn);
        out.field(
            "buf",
            &format_args!("'{}{}'", preview.escape_debug(), ellipsis),
        );
        out.field("queue", &self.queue);
        out.finish()
    }
}
#[derive(Debug)]
pub struct Responses<'a, 'b, T>
where
T: FromStr + fmt::Debug,
T::Err: std::error::Error + Send + Sync + 'static,
{
pipeline: Option<&'b mut Pipeline<'a>>,
current_reponse: Option<Response<'a, 'b, T>>,
}
impl<T> Responses<'_, '_, T>
where
T: FromStr + fmt::Debug,
T::Err: std::error::Error + Send + Sync + 'static,
{
#[tracing::instrument(skip(self), level = "debug")]
fn consume(&mut self) {
for item in self {
tracing::debug!(?item, "consuming unused response item");
}
}
}
/// Yields every item of every outstanding response, one response at a time.
///
/// Item-level errors and response errors are yielded as `Err` items. Iteration ends once
/// the pipeline's queue is empty, or once the pipeline has been lost to an error.
impl<T> Iterator for Responses<'_, '_, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    type Item = Result<ResponseItem<T>, Error>;
    #[tracing::instrument(level = "trace")]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Continue reading the response currently in progress, if there is one.
            if let Some(ref mut current) = self.current_reponse {
                match current.next_or_yield() {
                    Ok(ItemOrYield::Item(item)) => return Some(item),
                    // The response is complete and has handed the pipeline back. Keep it
                    // and fall through to pop the next response.
                    Ok(ItemOrYield::Yield(pipeline)) => {
                        self.pipeline = Some(pipeline);
                        self.current_reponse = None;
                    }
                    // The response failed (e.g. length under/overrun, or it no longer
                    // holds the pipeline). Recover the pipeline if the error carries it,
                    // drop the response, and report the error.
                    Err(err) => {
                        let (pipeline, inner_err) = err.split();
                        tracing::warn!("error while extracting response item: {inner_err}");
                        self.pipeline = pipeline;
                        self.current_reponse = None;
                        return Some(Err(inner_err));
                    }
                    // In `next_or_yield`, every path that fuses a response returns `Yield`
                    // or `Err`. Both clear `current_reponse` above, so a finished response
                    // is never polled here.
                    Ok(ItemOrYield::Finished) => {
                        unreachable!("current_reponse has already finished")
                    }
                }
            }
            // No response in progress: pop the next one. When the queue is empty,
            // `pop_wrapped` returns `None` and the pipeline is not put back, so the next
            // loop iteration takes the `else` branch and ends iteration.
            if let Some(pipeline) = self.pipeline.take() {
                if let Some(next_response) = pipeline.pop_wrapped() {
                    match next_response {
                        Ok(response) => {
                            self.current_reponse = Some(response);
                        }
                        // `pop_wrapped` returns the pipeline alongside the error; keep it
                        // so later calls can move on to the next queued response.
                        // NOTE(review): when `flush` fails, `pop_wrapped` errors without
                        // dequeuing anything. If that failure persists, a caller draining
                        // this iterator (e.g. `consume`) may never see `None`. Confirm
                        // how `Queue::flush` behaves.
                        Err(err) => {
                            let (pipeline, inner_err) = err.split();
                            self.pipeline = pipeline;
                            return Some(Err(inner_err));
                        }
                    }
                }
            } else {
                tracing::debug!("response queue empty");
                return None;
            }
        }
    }
}
/// Once `next` returns `None`, the pipeline has been given up, so every later call
/// returns `None` as well.
impl<T> FusedIterator for Responses<'_, '_, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
}
/// Iterator over the items in the response to a single query.
///
/// Obtained from [`Pipeline::pop`]. The response mutably borrows the pipeline while it is
/// being read; any items left unread are consumed when it is dropped.
#[derive(Debug)]
pub struct Response<'a, 'b, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    /// The query this response answers.
    query: Query,
    /// The pipeline being read from; `None` once it has been handed back or dropped.
    pipeline: Option<&'b mut Pipeline<'a>>,
    /// Length in bytes announced by the response status.
    expect: usize,
    /// Number of item bytes consumed from the buffer so far.
    seen: usize,
    /// Set once the response has been fully read or has failed.
    finished: bool,
    /// Marker for the item content type `T`; no `T` is stored here.
    content_type: PhantomData<T>,
}
impl<'a, 'b, T> Response<'a, 'b, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    /// Start reading the response to `query` from `pipeline`.
    ///
    /// `expect` is the length in bytes announced by the response status.
    pub(crate) fn new(query: Query, pipeline: &'b mut Pipeline<'a>, expect: usize) -> Self {
        Self {
            query,
            pipeline: Some(pipeline),
            expect,
            seen: 0,
            finished: false,
            content_type: PhantomData,
        }
    }
    /// The query this response answers.
    #[must_use]
    pub const fn query(&self) -> &Query {
        &self.query
    }
    /// Mark the response as finished; later steps return `ItemOrYield::Finished`.
    fn fuse(&mut self) {
        self.finished = true;
    }
    /// Advance the response by one step.
    ///
    /// Returns one of:
    /// - `Item`: the next parsed item, or an item-level error;
    /// - `Yield`: the response is complete (or carries no data), and the pipeline is
    ///   handed back;
    /// - `Finished`: the response was already fused;
    /// - `Err`: a length mismatch (the wrapper carries the pipeline), or
    ///   `ConsumedResponse` with no pipeline when this response no longer holds one.
    #[tracing::instrument(level = "trace")]
    fn next_or_yield(&mut self) -> Result<ItemOrYield<'a, 'b, T>, error::Wrapper<'a, 'b>> {
        if self.finished {
            tracing::trace!("response fully consumed");
            return Ok(ItemOrYield::Finished);
        }
        // Take the pipeline for this step. Each path below either puts it back, hands it
        // to the caller, or drops it.
        if let Some(pipeline) = self.pipeline.take() {
            if self.query.expect_data() {
                if self.expect == 0 {
                    // Empty response: there is nothing to read.
                    self.fuse();
                    Ok(ItemOrYield::Yield(pipeline))
                } else {
                    loop {
                        // End-of-response marker found, so the response is complete.
                        // NOTE(review): the check is `expect == seen + 1`, i.e. the
                        // announced length appears to include one byte not counted as item
                        // data (presumably a trailing newline) — confirm against `parse`.
                        // Any mismatch, including `seen + 1 > expect`, is reported as an
                        // underrun.
                        if let Ok((_, consumed)) = parse::end_of_response(pipeline.buf.data()) {
                            _ = pipeline.buf.consume(consumed);
                            self.fuse();
                            break if self.expect == self.seen + 1 {
                                Ok(ItemOrYield::Yield(pipeline))
                            } else {
                                let err = Error::ResponseDataUnderrun(self.seen, self.expect);
                                tracing::error!(%err);
                                Err(error::Wrapper::new(Some(pipeline), err))
                            };
                        }
                        // More item bytes were consumed than the status announced.
                        if self.seen > self.expect {
                            self.fuse();
                            let err = Error::ResponseDataOverrun(self.seen, self.expect);
                            tracing::error!(%err);
                            break Err(error::Wrapper::new(Some(pipeline), err));
                        }
                        match self.query.parse_item(pipeline.buf.data()) {
                            // Parsed one item: consume its bytes, return the pipeline to
                            // this response, and yield the item.
                            Ok((consumed, item)) => {
                                let item_result = Ok(ResponseItem(item, self.query.clone()));
                                _ = pipeline.buf.consume(consumed);
                                self.seen += consumed;
                                self.pipeline = Some(pipeline);
                                break Ok(ItemOrYield::Item(item_result));
                            }
                            // Not enough data buffered: read more and retry.
                            // NOTE(review): if `fetch` keeps returning `Ok(0)` this loop
                            // never terminates. On a fetch error the pipeline is neither
                            // restored nor the response fused, so the next call reports
                            // `ConsumedResponse`.
                            Err(Error::Incomplete | Error::ParseErr) => {
                                if let Err(err) = pipeline.fetch() {
                                    break Ok(ItemOrYield::Item(Err(err)));
                                }
                            }
                            // The item's content failed to parse. Skip the bytes it
                            // occupied and keep the pipeline so later items can be read.
                            Err(err @ Error::ParseItem(_, _)) => {
                                tracing::error!("error parsing content from response item: {err}");
                                if let Error::ParseItem(_, consumed) = err {
                                    _ = pipeline.buf.consume(consumed);
                                    self.seen += consumed;
                                }
                                self.pipeline = Some(pipeline);
                                break Ok(ItemOrYield::Item(Err(err)));
                            }
                            // Any other error: the pipeline is dropped here (not restored),
                            // so the next call reports `ConsumedResponse`.
                            Err(err) => {
                                tracing::error!("error parsing word from buffer: {err}");
                                break Ok(ItemOrYield::Item(Err(err)));
                            }
                        }
                    }
                }
            } else {
                // The query expects no data: the response is already complete.
                self.fuse();
                Ok(ItemOrYield::Yield(pipeline))
            }
        } else {
            // The pipeline was handed back or dropped by an earlier step.
            self.fuse();
            Err(error::Wrapper::new(None, Error::ConsumedResponse))
        }
    }
    /// Read and discard all remaining items, logging each at debug level.
    fn consume(&mut self) {
        for item in self {
            tracing::debug!(?item, "consuming unused response item");
        }
    }
}
impl<T> Drop for Response<'_, '_, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    /// Read and discard whatever items the caller left unread before the response goes away.
    fn drop(&mut self) {
        Self::consume(self);
    }
}
impl<T> Iterator for Response<'_, '_, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    type Item = Result<ResponseItem<T>, Error>;

    /// Yield the next item of this response, or `None` once it has been fully read.
    ///
    /// Errors from `next_or_yield` are converted into [`Error`] and yielded as items.
    /// Handing the pipeline back, or finding the response already finished, ends
    /// iteration.
    fn next(&mut self) -> Option<Self::Item> {
        let step = match self.next_or_yield() {
            Err(wrapped) => return Some(Err(wrapped.into())),
            Ok(step) => step,
        };
        if let ItemOrYield::Item(item) = step {
            Some(item)
        } else {
            None
        }
    }
}
/// After `next` returns `None` the response is finished, so later calls return `None`.
impl<T> FusedIterator for Response<'_, '_, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
}
/// Outcome of a single step of `Response::next_or_yield`.
enum ItemOrYield<'a, 'b, T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    /// The next item of the response, or an item-level error.
    Item(Result<ResponseItem<T>, Error>),
    /// The response is complete; the borrowed pipeline is handed back to the caller.
    Yield(&'b mut Pipeline<'a>),
    /// The response had already been marked finished.
    Finished,
}
/// A single parsed item from a response, together with the query that produced it.
#[derive(Debug)]
pub struct ResponseItem<T>(ResponseContent<T>, Query)
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static;
impl<T> ResponseItem<T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    /// Borrow the parsed content of this item.
    pub const fn content(&self) -> &T {
        let Self(parsed, _) = self;
        parsed.content()
    }
    /// Consume the item, returning its parsed content and discarding the query.
    pub fn into_content(self) -> T {
        let Self(parsed, _) = self;
        parsed.into_content()
    }
    /// Borrow the query whose response contained this item.
    pub const fn query(&self) -> &Query {
        let Self(_, origin) = self;
        origin
    }
}
/// Parsed content of a single response item.
///
/// Built from raw item bytes via `TryFrom<&[u8]>`. The bytes must be valid UTF-8, and are
/// then parsed with `T`'s [`FromStr`] implementation.
#[derive(Debug)]
pub(crate) struct ResponseContent<T>(T)
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static;
impl<T> ResponseContent<T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    /// Borrow the parsed value.
    const fn content(&self) -> &T {
        let Self(value) = self;
        value
    }
    /// Unwrap the parsed value.
    // Clippy's `missing_const_for_fn` suggestion is deliberately silenced for this method.
    #[allow(clippy::missing_const_for_fn)]
    fn into_content(self) -> T {
        let Self(value) = self;
        value
    }
}
impl<T> TryFrom<&[u8]> for ResponseContent<T>
where
    T: FromStr + fmt::Debug,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    /// Decode `buf` as UTF-8 and parse the resulting text into `T`.
    ///
    /// # Errors
    ///
    /// Returns a boxed UTF-8 error if `buf` is not valid UTF-8. Returns `T::Err`, boxed,
    /// if parsing the text fails.
    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
        let text = from_utf8(buf)?;
        let value: T = text.parse()?;
        Ok(Self(value))
    }
}