use std::{marker::PhantomData, pin::Pin, sync::Arc, task};
use azure_data_cosmos_driver::{
models::{ContainerReference, CosmosResponse as DriverResponse, CosmosResponseHeaders},
options::OperationOptions,
CosmosDriver, OperationPlan,
};
use futures::future::BoxFuture;
use futures::Stream;
use serde::{de::DeserializeOwned, Deserialize};
use crate::{
driver_bridge,
models::{CosmosResponse, DiagnosticsContext, ResponseHeaders},
ContinuationToken,
};
/// A single page of items from a feed response, stored together with the
/// response headers and diagnostics that were supplied when the page was built.
#[derive(Debug)]
pub struct FeedPage<T> {
/// Items contained in this page.
items: Vec<T>,
/// Response headers stored with this page.
headers: ResponseHeaders,
/// Shared diagnostics context stored with this page.
diagnostics: Arc<DiagnosticsContext>,
}
impl<T> FeedPage<T> {
pub(crate) fn new(
items: Vec<T>,
headers: ResponseHeaders,
diagnostics: Arc<DiagnosticsContext>,
) -> Self {
Self {
items,
headers,
diagnostics,
}
}
pub fn items(&self) -> &[T] {
&self.items
}
pub fn into_items(self) -> Vec<T> {
self.items
}
pub fn headers(&self) -> &ResponseHeaders {
&self.headers
}
pub fn diagnostics(&self) -> Arc<DiagnosticsContext> {
Arc::clone(&self.diagnostics)
}
}
/// A [`FeedPage`] produced by a query, extended with the optional index and
/// query metrics strings that came with the response.
#[derive(Debug)]
pub struct QueryFeedPage<T> {
/// The underlying page (items, headers, diagnostics).
page: FeedPage<T>,
/// Index metrics for this page, if any were present.
index_metrics: Option<String>,
/// Query metrics for this page, if any were present.
query_metrics: Option<String>,
}
impl<T> QueryFeedPage<T> {
    /// Borrows the underlying [`FeedPage`].
    pub fn as_feed_page(&self) -> &FeedPage<T> {
        let Self { page, .. } = self;
        page
    }

    /// Borrows the items on this page as a slice.
    pub fn items(&self) -> &[T] {
        self.as_feed_page().items()
    }

    /// Consumes the page and returns its items; everything else is dropped.
    pub fn into_items(self) -> Vec<T> {
        let Self { page, .. } = self;
        page.into_items()
    }

    /// Borrows the response headers stored with this page.
    pub fn headers(&self) -> &ResponseHeaders {
        self.as_feed_page().headers()
    }

    /// Returns another handle to the shared diagnostics context.
    pub fn diagnostics(&self) -> Arc<DiagnosticsContext> {
        self.as_feed_page().diagnostics()
    }

    /// Index metrics for this page, or `None` when none were recorded.
    pub fn index_metrics(&self) -> Option<&str> {
        let metrics = &self.index_metrics;
        metrics.as_deref()
    }

    /// Query metrics for this page, or `None` when none were recorded.
    pub fn query_metrics(&self) -> Option<&str> {
        let metrics = &self.query_metrics;
        metrics.as_deref()
    }
}
/// Deserialization shape for a feed response body: a single list of items.
#[derive(Deserialize)]
pub(crate) struct FeedBody<T> {
/// The item list. It is read from a field named `items` or from any of the
/// alternative field names declared as aliases below.
#[serde(alias = "Documents")]
#[serde(alias = "DocumentCollections")]
#[serde(alias = "Databases")]
#[serde(alias = "Offers")]
pub(crate) items: Vec<T>,
}
impl<T: DeserializeOwned> QueryFeedPage<T> {
pub(crate) fn from_response(response: CosmosResponse) -> crate::Result<Self> {
let cosmos_headers: CosmosResponseHeaders =
crate::models::response_headers::into_driver_headers(response.cosmos_headers().clone());
let index_metrics = cosmos_headers.index_metrics.clone();
let query_metrics = cosmos_headers.query_metrics.clone();
let diagnostics = response.diagnostics();
let body: FeedBody<T> = response.into_model()?;
Ok(Self {
page: FeedPage::new(
body.items,
ResponseHeaders::from(cosmos_headers),
diagnostics,
),
index_metrics,
query_metrics,
})
}
}
/// Boxed future for one page fetch. It resolves to the `OperationPlan` paired
/// with the fetch result, where `Ok(None)` is treated by the poller as
/// "no more pages".
type DriverPageFuture = BoxFuture<'static, (OperationPlan, crate::Result<Option<DriverResponse>>)>;
/// State for paging through results by repeatedly executing an
/// `OperationPlan` against a `CosmosDriver`.
#[pin_project::pin_project]
struct LiveState {
/// Driver used to execute the plan.
driver: Arc<CosmosDriver>,
/// Optional container reference, cloned into every fetch.
container: Option<ContainerReference>,
/// Operation options, cloned into every fetch.
options: OperationOptions,
/// The plan. It is `None` while a fetch is in flight, because the plan is
/// moved into the fetch future and put back when that future completes.
plan: Option<OperationPlan>,
/// The fetch currently being polled, if any.
in_flight: Option<DriverPageFuture>,
/// Set once the driver reports no further pages or a fetch/decoding error
/// occurs; afterwards polling yields `None` without contacting the driver.
exhausted: bool,
}
impl LiveState {
fn new(
driver: Arc<CosmosDriver>,
container: Option<ContainerReference>,
plan: OperationPlan,
options: OperationOptions,
) -> Self {
Self {
driver,
container,
options,
plan: Some(plan),
in_flight: None,
exhausted: false,
}
}
fn poll_next_page<T: DeserializeOwned + Send + 'static>(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Option<crate::Result<QueryFeedPage<T>>>> {
let this = self.project();
if *this.exhausted {
return task::Poll::Ready(None);
}
let in_flight = match this.in_flight.as_mut() {
Some(fut) => fut,
None => {
let mut plan = this
.plan
.take()
.expect("plan must be present between polls");
let driver = Arc::clone(this.driver);
let container = this.container.clone();
let options = this.options.clone();
let fut: DriverPageFuture = Box::pin(async move {
let result = driver
.execute_plan(&mut plan, container, options)
.await
.map_err(Into::into);
(plan, result)
});
this.in_flight.insert(fut)
}
};
let (plan, result) = match in_flight.as_mut().poll(cx) {
task::Poll::Pending => return task::Poll::Pending,
task::Poll::Ready(out) => out,
};
this.in_flight.take();
*this.plan = Some(plan);
match result {
Ok(None) => {
*this.exhausted = true;
task::Poll::Ready(None)
}
Err(err) => {
*this.exhausted = true;
task::Poll::Ready(Some(Err(err)))
}
Ok(Some(driver_response)) => {
let response = driver_bridge::driver_response_to_cosmos_response(driver_response);
match QueryFeedPage::from_response(response) {
Ok(page) => task::Poll::Ready(Some(Ok(page))),
Err(err) => {
*this.exhausted = true;
task::Poll::Ready(Some(Err(err)))
}
}
}
}
}
fn to_continuation_token(&self) -> crate::Result<ContinuationToken> {
let plan = self.plan.as_ref().ok_or_else(|| {
crate::DriverCosmosError::builder()
.with_status(crate::CosmosStatus::CLIENT_CONTINUATION_TOKEN_FETCH_IN_FLIGHT)
.with_message("to_continuation_token called while a page fetch is in flight")
.build()
})?;
plan.to_continuation_token().map_err(Into::into)
}
}
/// Where pages come from: a live driver-backed state, or (in test builds only)
/// a pre-built queue of page results.
#[pin_project::pin_project(project = PageSourceProj)]
enum PageSource<T: Send> {
/// Pages fetched through a `LiveState`.
Live(Pin<Box<LiveState>>),
/// Test-only source that hands out queued page results front to back.
#[cfg(test)]
Synthetic(std::collections::VecDeque<crate::Result<QueryFeedPage<T>>>),
/// Non-test placeholder: without it no variant would mention `T` in
/// non-test builds. Nothing in this file constructs it.
#[cfg(not(test))]
// Never constructed, so the dead-code lint is silenced on purpose.
#[allow(dead_code)]
_Phantom(PhantomData<fn() -> T>),
}
impl<T: Send + DeserializeOwned + 'static> PageSource<T> {
fn poll_next_page(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Option<crate::Result<QueryFeedPage<T>>>> {
match self.project() {
PageSourceProj::Live(state) => state.as_mut().poll_next_page::<T>(cx),
#[cfg(test)]
PageSourceProj::Synthetic(pages) => task::Poll::Ready(pages.pop_front()),
#[cfg(not(test))]
PageSourceProj::_Phantom(_) => task::Poll::Ready(None),
}
}
}
/// Stream of individual query items. Pages are pulled from the source one at a
/// time and their items are handed out in order.
#[pin_project::pin_project]
pub struct QueryItemIterator<T: Send> {
/// Where pages come from.
#[pin]
source: PageSource<T>,
/// Items of the most recently fetched page that have not been yielded yet.
current: Option<std::vec::IntoIter<T>>,
/// Zero-sized marker carrying the item type `T`.
_marker: PhantomData<fn() -> T>,
}
impl<T: Send + DeserializeOwned + 'static> QueryItemIterator<T> {
pub(crate) fn new(
driver: Arc<CosmosDriver>,
container: Option<ContainerReference>,
plan: OperationPlan,
options: OperationOptions,
) -> Self {
Self {
source: PageSource::Live(Box::pin(LiveState::new(driver, container, plan, options))),
current: None,
_marker: PhantomData,
}
}
pub fn into_pages(self) -> QueryPageIterator<T> {
QueryPageIterator {
source: self.source,
_marker: PhantomData,
}
}
}
impl<T: Send + DeserializeOwned + 'static> Stream for QueryItemIterator<T> {
    type Item = crate::Result<T>;

    /// Yields the next buffered item, fetching further pages as needed.
    ///
    /// Empty pages are skipped. A page-level error is yielded as an item, and
    /// the stream ends when the source reports no more pages.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            // Serve from the buffered page first.
            if let Some(item) = this.current.as_mut().and_then(Iterator::next) {
                return task::Poll::Ready(Some(Ok(item)));
            }
            // Buffer is absent or drained; release it before fetching more.
            *this.current = None;

            match task::ready!(this.source.as_mut().poll_next_page(cx)) {
                Some(Ok(page)) => *this.current = Some(page.into_items().into_iter()),
                Some(Err(err)) => return task::Poll::Ready(Some(Err(err))),
                None => return task::Poll::Ready(None),
            }
        }
    }
}
/// Stream of whole query result pages.
#[pin_project::pin_project]
pub struct QueryPageIterator<T: Send> {
/// Where pages come from.
#[pin]
source: PageSource<T>,
/// Zero-sized marker carrying the item type `T`.
_marker: PhantomData<fn() -> T>,
}
impl<T: Send + DeserializeOwned + 'static> QueryPageIterator<T> {
    /// Produces a continuation token for the current position of a live
    /// source.
    ///
    /// # Errors
    ///
    /// Forwards the error from the live state, which includes the case where a
    /// page fetch is currently in flight. In test builds, a synthetic source
    /// always fails with a `BadRequest` status.
    ///
    /// # Panics
    ///
    /// In non-test builds, panics if the source is the `_Phantom` placeholder
    /// variant.
    pub fn to_continuation_token(&self) -> crate::Result<ContinuationToken> {
        match &self.source {
            PageSource::Live(live) => live.to_continuation_token(),
            #[cfg(test)]
            PageSource::Synthetic(_) => {
                let unsupported = crate::DriverCosmosError::builder()
                    .with_status(crate::CosmosStatus::new(
                        azure_core::http::StatusCode::BadRequest,
                    ))
                    .with_message("synthetic test iterator does not support to_continuation_token")
                    .build();
                Err(unsupported.into())
            }
            #[cfg(not(test))]
            PageSource::_Phantom(_) => unreachable!(),
        }
    }
}
impl<T: Send + DeserializeOwned + 'static> Stream for QueryPageIterator<T> {
type Item = crate::Result<QueryFeedPage<T>>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Option<Self::Item>> {
let this = self.project();
this.source.poll_next_page(cx)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use azure_data_cosmos_driver::models::ActivityId;
    use futures::StreamExt;

    /// Wraps `items` in a page with default headers, a fresh testing
    /// diagnostics context, and no metrics.
    fn create_test_page<T>(items: Vec<T>) -> QueryFeedPage<T> {
        let diagnostics = Arc::new(DiagnosticsContext::for_testing(ActivityId::new_uuid()));
        let page = FeedPage::new(items, ResponseHeaders::default(), diagnostics);
        QueryFeedPage {
            page,
            index_metrics: None,
            query_metrics: None,
        }
    }

    /// Builds an item iterator that replays the given page results in order.
    fn synthetic_item_iter<T: Send + DeserializeOwned + 'static>(
        pages: Vec<crate::Result<QueryFeedPage<T>>>,
    ) -> QueryItemIterator<T> {
        let queue = std::collections::VecDeque::from(pages);
        QueryItemIterator {
            source: PageSource::Synthetic(queue),
            current: None,
            _marker: PhantomData,
        }
    }

    #[tokio::test]
    async fn item_iterator_yields_all_items_from_multiple_pages() {
        let item_iter = synthetic_item_iter(vec![
            Ok(create_test_page(vec![1, 2, 3])),
            Ok(create_test_page(vec![4, 5])),
            Ok(create_test_page(vec![6])),
        ]);

        let items = item_iter.map(Result::unwrap).collect::<Vec<_>>().await;

        assert_eq!(items, vec![1, 2, 3, 4, 5, 6]);
    }

    #[tokio::test]
    async fn page_iterator_yields_all_pages() {
        let page_iter = synthetic_item_iter(vec![
            Ok(create_test_page(vec![1, 2])),
            Ok(create_test_page(vec![3])),
        ])
        .into_pages();

        let page_items = page_iter
            .map(|page| page.unwrap().into_items())
            .collect::<Vec<_>>()
            .await;

        assert_eq!(page_items, vec![vec![1, 2], vec![3]]);
    }

    #[tokio::test]
    async fn item_iterator_propagates_errors() {
        let mut item_iter = synthetic_item_iter(vec![
            Ok(create_test_page(vec![1, 2])),
            Err(crate::DriverCosmosError::builder()
                .with_status(crate::CosmosStatus::new(
                    azure_core::http::StatusCode::BadRequest,
                ))
                .with_message("test error")
                .build()
                .into()),
        ]);

        // Items of the good page come through before the error surfaces.
        for expected in [1, 2] {
            assert_eq!(item_iter.next().await.unwrap().unwrap(), expected);
        }
        assert!(item_iter.next().await.unwrap().is_err());
    }

    #[tokio::test]
    async fn item_iterator_handles_empty_pages() {
        let item_iter = synthetic_item_iter(vec![
            Ok(create_test_page(vec![1])),
            Ok(create_test_page(vec![])),
            Ok(create_test_page(vec![2])),
        ]);

        let items = item_iter.map(Result::unwrap).collect::<Vec<_>>().await;

        assert_eq!(items, vec![1, 2]);
    }
}