use std::{pin::Pin, task};
use azure_core::http::{
headers::Headers,
pager::{PagerContinuation, PagerResult},
};
use azure_data_cosmos_driver::models::CosmosResponseHeaders;
use futures::stream::BoxStream;
use futures::Stream;
use serde::{de::DeserializeOwned, Deserialize};
use crate::{
constants,
models::{CosmosDiagnostics, CosmosResponse},
SessionToken,
};
/// One page of items from a Cosmos feed response, together with the response metadata it
/// arrived with.
#[derive(Debug)]
pub struct FeedPage<T> {
/// Items deserialized from the page body.
items: Vec<T>,
/// Continuation token for the next page; `None` when no continuation was recorded.
continuation: Option<String>,
/// Raw response headers, exposed through [`FeedPage::headers`].
raw_headers: Headers,
/// Parsed Cosmos response headers; read by `request_charge` and `session_token`.
headers: CosmosResponseHeaders,
/// Diagnostics attached to this page.
diagnostics: CosmosDiagnostics,
}
impl<T> FeedPage<T> {
    /// Assembles a page from parts already extracted from a response.
    pub(crate) fn new(
        items: Vec<T>,
        continuation: Option<String>,
        raw_headers: Headers,
        headers: CosmosResponseHeaders,
        diagnostics: CosmosDiagnostics,
    ) -> Self {
        Self {
            diagnostics,
            headers,
            raw_headers,
            continuation,
            items,
        }
    }

    /// Borrows the items contained in this page.
    pub fn items(&self) -> &[T] {
        self.items.as_slice()
    }

    /// Consumes the page and hands back its items.
    pub fn into_items(self) -> Vec<T> {
        let Self { items, .. } = self;
        items
    }

    /// Continuation token for the next page, or `None` when no token was returned.
    pub fn continuation(&self) -> Option<&str> {
        self.continuation.as_deref()
    }

    /// Raw headers of the response that produced this page.
    pub fn headers(&self) -> &Headers {
        &self.raw_headers
    }

    /// Request charge taken from the parsed response headers, if present.
    pub fn request_charge(&self) -> Option<f64> {
        let charge = self.headers.request_charge.as_ref()?;
        Some(charge.value())
    }

    /// Session token taken from the parsed response headers, if present.
    pub fn session_token(&self) -> Option<SessionToken> {
        let token = self.headers.session_token.as_ref()?;
        Some(SessionToken::from(token.as_str().to_string()))
    }

    /// Diagnostics attached to this page.
    pub fn diagnostics(&self) -> &CosmosDiagnostics {
        &self.diagnostics
    }
}
impl<T> From<FeedPage<T>> for PagerResult<FeedPage<T>> {
    /// Maps a page carrying a continuation token to `PagerResult::More` (forwarding the token);
    /// a page without one ends paging with `PagerResult::Done`.
    fn from(value: FeedPage<T>) -> Self {
        if let Some(token) = value.continuation.clone() {
            PagerResult::More {
                response: value,
                continuation: PagerContinuation::Token(token),
            }
        } else {
            PagerResult::Done { response: value }
        }
    }
}
impl<T: DeserializeOwned> FeedPage<T> {
#[allow(dead_code)] pub(crate) async fn from_response(
response: CosmosResponse<FeedBody<T>>,
) -> azure_core::Result<Self> {
let raw_headers = response.headers().clone();
let continuation = raw_headers.get_optional_string(&constants::CONTINUATION);
let cosmos_headers = response.cosmos_headers().clone();
let diagnostics = response.diagnostics().clone();
let body: FeedBody<T> = response.into_model()?;
Ok(Self::new(
body.items,
continuation,
raw_headers,
cosmos_headers,
diagnostics,
))
}
}
/// A [`FeedPage`] produced by a query, plus the optional index and query metrics taken from
/// the response headers.
#[derive(Debug)]
pub struct QueryFeedPage<T> {
/// The underlying page (items, continuation, headers, diagnostics).
page: FeedPage<T>,
/// Index metrics copied from the response headers, if present.
index_metrics: Option<String>,
/// Query metrics copied from the response headers, if present.
query_metrics: Option<String>,
}
impl<T> QueryFeedPage<T> {
    /// Borrows the query results contained in this page.
    pub fn items(&self) -> &[T] {
        &self.page.items
    }

    /// Consumes the page, yielding its query results.
    pub fn into_items(self) -> Vec<T> {
        self.page.items
    }

    /// Continuation token for the next page of results, if one was returned.
    pub fn continuation(&self) -> Option<&str> {
        self.page.continuation.as_deref()
    }

    /// Raw headers of the response that produced this page.
    pub fn headers(&self) -> &Headers {
        &self.page.raw_headers
    }

    /// Request charge from the response headers; see [`FeedPage::request_charge`].
    pub fn request_charge(&self) -> Option<f64> {
        self.page.request_charge()
    }

    /// Session token from the response headers; see [`FeedPage::session_token`].
    pub fn session_token(&self) -> Option<SessionToken> {
        self.page.session_token()
    }

    /// Diagnostics attached to this page.
    pub fn diagnostics(&self) -> &CosmosDiagnostics {
        &self.page.diagnostics
    }

    /// Index metrics from the response headers, when they were present.
    pub fn index_metrics(&self) -> Option<&str> {
        self.index_metrics.as_deref()
    }

    /// Query metrics from the response headers, when they were present.
    pub fn query_metrics(&self) -> Option<&str> {
        self.query_metrics.as_deref()
    }
}
impl<T> From<QueryFeedPage<T>> for PagerResult<QueryFeedPage<T>> {
    /// A query page without a continuation token finishes paging (`Done`); otherwise the token
    /// is forwarded so the pager can request the next page (`More`).
    fn from(value: QueryFeedPage<T>) -> Self {
        match value.continuation().map(str::to_owned) {
            None => PagerResult::Done { response: value },
            Some(token) => PagerResult::More {
                response: value,
                continuation: PagerContinuation::Token(token),
            },
        }
    }
}
/// Deserialized body of a feed response.
///
/// The item array is read from an `items` key or from any of the aliases below. This lets one
/// type cover bodies whose array key is `Documents`, `DocumentCollections`, `Databases`, or
/// `Offers`.
#[derive(Deserialize)]
pub(crate) struct FeedBody<T> {
    /// The items carried by this page of the feed.
    #[serde(
        alias = "Documents",
        alias = "DocumentCollections",
        alias = "Databases",
        alias = "Offers"
    )]
    pub(crate) items: Vec<T>,
}
impl<T: DeserializeOwned> QueryFeedPage<T> {
    /// Builds a query page from a feed response.
    ///
    /// The index and query metrics are copied out of the Cosmos response headers first. The rest
    /// of the page (items, continuation, headers, diagnostics) is assembled by
    /// [`FeedPage::from_response`], so both page types share a single extraction path.
    ///
    /// # Errors
    /// Propagates any error returned by [`FeedPage::from_response`].
    pub(crate) async fn from_response(
        response: CosmosResponse<FeedBody<T>>,
    ) -> azure_core::Result<Self> {
        let cosmos_headers = response.cosmos_headers();
        let index_metrics = cosmos_headers.index_metrics.clone();
        let query_metrics = cosmos_headers.query_metrics.clone();

        let page = FeedPage::from_response(response).await?;
        Ok(Self {
            page,
            index_metrics,
            query_metrics,
        })
    }
}
/// A [`Stream`] that flattens a stream of [`QueryFeedPage`]s into the individual items they
/// contain.
#[pin_project::pin_project]
pub struct FeedItemIterator<T: Send> {
/// Source of pages still to be consumed.
#[pin]
pages: BoxStream<'static, azure_core::Result<QueryFeedPage<T>>>,
/// Remaining items of the page currently being drained. This is `None` before the first page
/// arrives and again after a page is exhausted, until the next page arrives.
current: Option<std::vec::IntoIter<T>>,
}
impl<T: Send> FeedItemIterator<T> {
pub(crate) fn new(
stream: impl Stream<Item = azure_core::Result<QueryFeedPage<T>>> + Send + 'static,
) -> Self {
Self {
pages: Box::pin(stream),
current: None,
}
}
pub fn into_pages(self) -> FeedPageIterator<T> {
FeedPageIterator(self.pages)
}
}
impl<T: Send> Stream for FeedItemIterator<T> {
    type Item = azure_core::Result<T>;

    /// Yields the next buffered item, pulling a fresh page from the inner stream whenever the
    /// buffer runs dry.
    ///
    /// Behavior:
    /// - A page-level error is surfaced as an `Err` item.
    /// - Empty pages are skipped.
    /// - The stream ends when the inner page stream ends.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        let mut state = self.project();
        loop {
            if let Some(item) = state.current.as_mut().and_then(|buffered| buffered.next()) {
                return task::Poll::Ready(Some(Ok(item)));
            }
            // Nothing buffered (or the buffer is exhausted): release it before fetching more.
            *state.current = None;

            let polled = match state.pages.as_mut().poll_next(cx) {
                task::Poll::Pending => return task::Poll::Pending,
                task::Poll::Ready(polled) => polled,
            };
            match polled {
                None => return task::Poll::Ready(None),
                Some(Err(err)) => return task::Poll::Ready(Some(Err(err))),
                // An empty page just sends us around the loop to fetch the next one.
                Some(Ok(page)) => *state.current = Some(page.page.items.into_iter()),
            }
        }
    }
}
/// A [`Stream`] of whole [`QueryFeedPage`]s, produced by [`FeedItemIterator::into_pages`].
pub struct FeedPageIterator<T: Send>(BoxStream<'static, azure_core::Result<QueryFeedPage<T>>>);
impl<T: Send> Stream for FeedPageIterator<T> {
    type Item = azure_core::Result<QueryFeedPage<T>>;

    /// Forwards polling straight to the boxed page stream.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        // The only field is a boxed stream, so `Self: Unpin` and unpinning is safe.
        let inner = &mut self.get_mut().0;
        inner.as_mut().poll_next(cx)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the page types, their `PagerResult` conversions, and the item/page
    // stream adapters.
    use super::*;
    use futures::StreamExt;

    /// Builds a query page with empty headers/diagnostics and no metrics.
    fn create_test_page<T>(items: Vec<T>, continuation: Option<String>) -> QueryFeedPage<T> {
        QueryFeedPage {
            page: FeedPage::new(
                items,
                continuation,
                Headers::new(),
                CosmosResponseHeaders::default(),
                CosmosDiagnostics::default(),
            ),
            index_metrics: None,
            query_metrics: None,
        }
    }

    #[tokio::test]
    async fn item_iterator_yields_all_items_from_multiple_pages() {
        let pages = vec![
            Ok(create_test_page(vec![1, 2, 3], Some("token1".to_string()))),
            Ok(create_test_page(vec![4, 5], Some("token2".to_string()))),
            Ok(create_test_page(vec![6], None)),
        ];
        let stream = futures::stream::iter(pages);
        let item_iter = FeedItemIterator::new(stream);
        let items: Vec<_> = item_iter
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();
        assert_eq!(items, vec![1, 2, 3, 4, 5, 6]);
    }

    #[tokio::test]
    async fn page_iterator_yields_all_pages() {
        let pages = vec![
            Ok(create_test_page(vec![1, 2], Some("token1".to_string()))),
            Ok(create_test_page(vec![3], None)),
        ];
        let stream = futures::stream::iter(pages);
        let page_iter = FeedItemIterator::new(stream).into_pages();
        let page_items: Vec<_> = page_iter
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .map(|r| r.unwrap().into_items())
            .collect();
        assert_eq!(page_items, vec![vec![1, 2], vec![3]]);
    }

    #[tokio::test]
    async fn item_iterator_propagates_errors() {
        let pages = vec![
            Ok(create_test_page(vec![1, 2], Some("token".to_string()))),
            Err(azure_core::Error::new(
                azure_core::error::ErrorKind::Other,
                "test error",
            )),
        ];
        let stream = futures::stream::iter(pages);
        let mut item_iter = FeedItemIterator::new(stream);
        assert_eq!(item_iter.next().await.unwrap().unwrap(), 1);
        assert_eq!(item_iter.next().await.unwrap().unwrap(), 2);
        assert!(item_iter.next().await.unwrap().is_err());
    }

    #[tokio::test]
    async fn item_iterator_handles_empty_pages() {
        let pages = vec![
            Ok(create_test_page(vec![1], Some("token1".to_string()))),
            Ok(create_test_page(vec![], Some("token2".to_string()))),
            Ok(create_test_page(vec![2], None)),
        ];
        let stream = futures::stream::iter(pages);
        let item_iter = FeedItemIterator::new(stream);
        let items: Vec<_> = item_iter
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();
        assert_eq!(items, vec![1, 2]);
    }

    #[tokio::test]
    async fn item_iterator_ends_immediately_on_empty_stream() {
        let pages: Vec<azure_core::Result<QueryFeedPage<i32>>> = Vec::new();
        let mut item_iter = FeedItemIterator::new(futures::stream::iter(pages));
        assert!(item_iter.next().await.is_none());
    }

    #[tokio::test]
    async fn item_iterator_skips_trailing_empty_pages() {
        let pages = vec![
            Ok(create_test_page(vec![1], Some("token1".to_string()))),
            Ok(create_test_page(vec![], None)),
        ];
        let mut item_iter = FeedItemIterator::new(futures::stream::iter(pages));
        assert_eq!(item_iter.next().await.unwrap().unwrap(), 1);
        assert!(item_iter.next().await.is_none());
    }

    #[tokio::test]
    async fn page_iterator_propagates_errors() {
        let pages = vec![
            Ok(create_test_page(vec![1], Some("token".to_string()))),
            Err(azure_core::Error::new(
                azure_core::error::ErrorKind::Other,
                "test error",
            )),
        ];
        let mut page_iter = FeedItemIterator::new(futures::stream::iter(pages)).into_pages();
        assert_eq!(page_iter.next().await.unwrap().unwrap().into_items(), vec![1]);
        assert!(page_iter.next().await.unwrap().is_err());
        assert!(page_iter.next().await.is_none());
    }

    #[test]
    fn query_page_with_continuation_converts_to_more() {
        let result: PagerResult<QueryFeedPage<i32>> =
            create_test_page(vec![1], Some("next".to_string())).into();
        match result {
            PagerResult::More {
                response,
                continuation: PagerContinuation::Token(token),
            } => {
                assert_eq!(token, "next");
                assert_eq!(response.into_items(), vec![1]);
            }
            _ => panic!("expected PagerResult::More with a token continuation"),
        }
    }

    #[test]
    fn query_page_without_continuation_converts_to_done() {
        let result: PagerResult<QueryFeedPage<i32>> = create_test_page(vec![1], None).into();
        assert!(matches!(result, PagerResult::Done { .. }));
    }

    #[test]
    fn feed_page_conversion_follows_continuation() {
        let with_token = FeedPage::new(
            vec![1],
            Some("next".to_string()),
            Headers::new(),
            CosmosResponseHeaders::default(),
            CosmosDiagnostics::default(),
        );
        assert_eq!(with_token.continuation(), Some("next"));
        let more: PagerResult<FeedPage<i32>> = with_token.into();
        assert!(matches!(more, PagerResult::More { .. }));

        let without_token = FeedPage::new(
            vec![1],
            None,
            Headers::new(),
            CosmosResponseHeaders::default(),
            CosmosDiagnostics::default(),
        );
        let done: PagerResult<FeedPage<i32>> = without_token.into();
        assert!(matches!(done, PagerResult::Done { .. }));
    }
}