use crate::{cursors::RawCursor, error::Result, query_summary::QuerySummary, response::Response};
use bytes::{Buf, Bytes, BytesMut};
use futures_util::TryFutureExt;
use std::{
io::Result as IoResult,
pin::Pin,
task::{Context, Poll, ready},
};
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
use tracing::Instrument;
/// Cursor over the raw response of a query, exposed either as a sequence of
/// [`Bytes`] chunks (`next`/`collect`, and `Stream` with the `futures03`
/// feature) or as a byte reader (`AsyncRead`/`AsyncBufRead`).
///
/// The chunk-based and reader-based APIs must not be mixed: the chunk-based
/// methods panic if the reader API has left buffered bytes behind.
pub struct BytesCursor {
    /// Underlying cursor that yields the response chunk by chunk.
    raw: RawCursor,
    /// Unconsumed remainder of the current chunk; only populated by the
    /// `AsyncRead`/`AsyncBufRead` implementations and empty otherwise.
    bytes: Bytes,
    /// Span entered while polling `raw`; byte counters are recorded on it on drop.
    span: tracing::Span,
}
impl BytesCursor {
pub(crate) fn new(response: Response, span: tracing::Span) -> Self {
Self {
raw: RawCursor::new(response),
bytes: Bytes::default(),
span,
}
}
pub async fn next(&mut self) -> Result<Option<Bytes>> {
assert!(
self.bytes.is_empty(),
"mixing `BytesCursor::next()` and `AsyncRead` API methods is not allowed"
);
self.raw
.next()
.inspect_err(|e| tracing::debug!(error=?e, "error from BytesCursor::next()"))
.instrument(self.span.clone())
.await
}
pub async fn collect(&mut self) -> Result<Bytes> {
let mut chunks = Vec::new();
let mut total_len = 0;
while let Some(chunk) = self.next().await? {
total_len += chunk.len();
chunks.push(chunk);
}
if chunks.len() == 1 {
return Ok(chunks.pop().unwrap());
}
let mut collected = BytesMut::with_capacity(total_len);
for chunk in chunks {
collected.extend_from_slice(&chunk);
}
debug_assert_eq!(collected.capacity(), total_len);
Ok(collected.freeze())
}
#[cold]
fn poll_refill(&mut self, cx: &mut Context<'_>) -> Poll<IoResult<bool>> {
debug_assert_eq!(self.bytes.len(), 0);
let _guard = self.span.enter();
while self.bytes.is_empty() {
match ready!(self.raw.poll_next(cx)) {
Ok(Some(chunk)) => self.bytes = chunk,
Ok(None) => return Poll::Ready(Ok(false)),
Err(e) => {
tracing::debug!(error=?e, "error reading from cursor");
return Poll::Ready(Err(e.into()));
}
}
}
Poll::Ready(Ok(true))
}
#[inline]
pub fn received_bytes(&self) -> u64 {
self.raw.received_bytes()
}
#[inline]
pub fn decoded_bytes(&self) -> u64 {
self.raw.decoded_bytes()
}
#[inline]
pub fn summary(&self) -> Option<&QuerySummary> {
self.raw.summary()
}
}
impl AsyncRead for BytesCursor {
    /// Copies buffered response bytes into `buf`, pulling further chunks from
    /// the underlying cursor while `buf` still has room.
    ///
    /// Returns `Ready(Ok(()))` without writing anything at end of stream.
    /// `Pending` is returned only when nothing has been written into `buf`
    /// during this call. The `AsyncRead` contract treats `Pending` as "no data
    /// read", and callers such as `AsyncReadExt::read` wrap their slice in a
    /// fresh `ReadBuf` on every poll. Bytes written before a `Pending` would
    /// therefore be lost, so a partial read is reported as `Ready` instead.
    ///
    /// # Errors
    /// Propagates errors from the underlying cursor, converted to `std::io::Error`.
    #[inline]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        let filled_before = buf.filled().len();

        while buf.remaining() > 0 {
            if self.bytes.is_empty() {
                match self.poll_refill(cx) {
                    Poll::Ready(Ok(true)) => {}
                    // End of stream: report whatever has been copied so far.
                    Poll::Ready(Ok(false)) => break,
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    // Data was already copied in this call; returning `Pending`
                    // now would make the caller discard it.
                    Poll::Pending if buf.filled().len() > filled_before => break,
                    Poll::Pending => return Poll::Pending,
                }
            }

            let len = self.bytes.len().min(buf.remaining());
            buf.put_slice(&self.bytes[..len]);
            self.bytes.advance(len);
        }

        Poll::Ready(Ok(()))
    }
}
impl AsyncBufRead for BytesCursor {
    /// Exposes the currently buffered bytes, refilling from the underlying
    /// cursor first if the buffer is empty.
    ///
    /// At end of stream the buffer stays empty, so an empty slice is returned.
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<&[u8]>> {
        let this = self.get_mut();
        if this.bytes.is_empty() {
            // `Ok(false)` (EOF) leaves `bytes` empty, which callers read as end-of-stream.
            ready!(this.poll_refill(cx))?;
        }
        Poll::Ready(Ok(&this.bytes[..]))
    }

    /// Marks `amt` buffered bytes as consumed.
    ///
    /// # Panics
    /// Panics if `amt` exceeds the number of buffered bytes.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        assert!(
            amt <= this.bytes.len(),
            "invalid `AsyncBufRead::consume` usage"
        );
        this.bytes.advance(amt);
    }
}
impl Drop for BytesCursor {
    /// Records the final byte counters on the cursor's span and logs that the
    /// query finished, all while the span is entered.
    fn drop(&mut self) {
        let _entered = self.span.enter();
        let received = self.received_bytes();
        let decoded = self.decoded_bytes();
        tracing::record_all!(
            self.span,
            clickhouse.response.received_bytes = received,
            clickhouse.response.decoded_bytes = decoded,
        );
        tracing::debug!("finished raw query");
    }
}
#[cfg(feature = "futures03")]
impl futures_util::AsyncRead for BytesCursor {
    /// Adapts the tokio `AsyncRead` implementation to the slice-based
    /// `futures` API, resolving to the number of bytes written into `buf`.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        let mut read_buf = ReadBuf::new(buf);
        AsyncRead::poll_read(self, cx, &mut read_buf).map_ok(|()| read_buf.filled().len())
    }
}
#[cfg(feature = "futures03")]
impl futures_util::AsyncBufRead for BytesCursor {
    /// Delegates to the tokio `AsyncBufRead` implementation.
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<&[u8]>> {
        <Self as AsyncBufRead>::poll_fill_buf(self, cx)
    }

    /// Delegates to the tokio `AsyncBufRead` implementation.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        <Self as AsyncBufRead>::consume(self, amt)
    }
}
#[cfg(feature = "futures03")]
impl futures_util::stream::Stream for BytesCursor {
    type Item = Result<Bytes>;

    /// Yields the next response chunk, polling the underlying cursor inside
    /// the cursor's span.
    ///
    /// # Panics
    /// Panics if the `AsyncRead`/`AsyncBufRead` API left unconsumed bytes buffered.
    #[inline]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self { raw, bytes, span } = self.get_mut();
        assert!(
            bytes.is_empty(),
            "mixing `Stream` and `AsyncRead` API methods is not allowed"
        );
        let _entered = span.enter();
        raw.poll_next(cx).map(|res| res.transpose())
    }
}
#[cfg(feature = "futures03")]
impl futures_util::stream::FusedStream for BytesCursor {
    /// The stream is terminated once no bytes remain buffered and the
    /// underlying cursor reports termination.
    #[inline]
    fn is_terminated(&self) -> bool {
        if !self.bytes.is_empty() {
            return false;
        }
        self.raw.is_terminated()
    }
}