use std::{collections::HashMap, sync::Arc};
use azure_core::http::headers::{HeaderName, HeaderValue};
use azure_data_cosmos_driver::{
models::{CosmosOperation, SessionToken},
options::OperationOptions as DriverOperationOptions,
CosmosDriver,
};
use serde::de::DeserializeOwned;
use crate::{constants, driver_bridge, feed::FeedBody, Query, QueryFeedPage};
/// Executes a Cosmos DB query through a [`CosmosDriver`], one result page at a time.
///
/// Each page is fetched with a freshly built operation. The executor remembers the
/// continuation token returned by the previous page and marks itself complete once a
/// page arrives without one.
pub struct QueryExecutor<T>
where
    T: DeserializeOwned + Send,
{
    /// Driver used to execute every page request.
    driver: Arc<CosmosDriver>,
    /// Builds a new operation for each page request.
    operation_factory: Box<dyn Fn() -> CosmosOperation + Send>,
    /// Options cloned into every page request.
    base_options: DriverOperationOptions,
    /// Custom headers sent with every page request (the continuation header is added per page).
    base_headers: HashMap<HeaderName, HeaderValue>,
    /// Session token attached to every page request, if any.
    session_token: Option<SessionToken>,
    /// The query being executed; serialized to JSON as the request body.
    query: Query,
    /// JSON bytes of `query`, cached after the first page so it is serialized only once.
    query_body: Option<Vec<u8>>,
    /// Continuation token from the most recent page; `None` before the first page.
    continuation: Option<String>,
    /// Set once a page has been returned without a continuation token.
    complete: bool,
    /// Ties `T` to the type without storing a `T`; the `fn() -> T` form keeps the
    /// auto traits of the executor independent of `T`.
    phantom: std::marker::PhantomData<fn() -> T>,
}
impl<T: DeserializeOwned + Send + 'static> QueryExecutor<T> {
    /// Creates an executor that has not fetched any page yet.
    ///
    /// The headers sent with every page start from `base_options`' custom headers (or an
    /// empty map). Two headers are then inserted, replacing any caller-supplied values under
    /// the same names:
    /// - the `constants::QUERY` header, set to `"True"`;
    /// - `Content-Type`, set to `application/query+json`.
    ///
    /// `operation_factory` is called once per page to build a fresh operation. When
    /// `session_token` is `Some`, it is attached to every page request.
    pub(crate) fn new(
        driver: Arc<CosmosDriver>,
        operation_factory: impl Fn() -> CosmosOperation + Send + 'static,
        query: Query,
        base_options: DriverOperationOptions,
        session_token: Option<SessionToken>,
    ) -> Self {
        let mut base_headers = base_options.custom_headers().cloned().unwrap_or_default();
        base_headers.insert(constants::QUERY.clone(), HeaderValue::from_static("True"));
        base_headers.insert(
            azure_core::http::headers::CONTENT_TYPE,
            HeaderValue::from_static("application/query+json"),
        );
        Self {
            driver,
            operation_factory: Box::new(operation_factory),
            query,
            query_body: None,
            base_options,
            base_headers,
            session_token,
            continuation: None,
            complete: false,
            phantom: std::marker::PhantomData,
        }
    }

    /// Converts the executor into a [`crate::FeedItemIterator`] that pulls pages on demand.
    ///
    /// Each step calls [`Self::next_page`]. The stream ends when `next_page` returns
    /// `Ok(None)`. Because the executor state is only handed back on success, an error
    /// also ends the stream after it is yielded.
    ///
    /// # Errors
    /// The visible implementation always returns `Ok`; the `Result` is part of the
    /// public signature.
    pub fn into_stream(self) -> azure_core::Result<crate::FeedItemIterator<T>> {
        Ok(crate::FeedItemIterator::new(futures::stream::try_unfold(
            self,
            |mut state| async move {
                let val = state.next_page().await?;
                Ok(val.map(|item| (item, state)))
            },
        )))
    }

    /// Fetches the next page of query results.
    ///
    /// Returns `Ok(None)` once a previous page has arrived without a continuation token.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the query cannot be serialized to JSON;
    /// - the driver fails to execute the request;
    /// - the response cannot be turned into a [`QueryFeedPage`].
    ///
    /// The continuation and completion state are updated only after a page has been
    /// produced, so a failed call leaves the executor's position unchanged.
    pub async fn next_page(&mut self) -> azure_core::Result<Option<QueryFeedPage<T>>> {
        if self.complete {
            return Ok(None);
        }

        // Serialize the query on the first page only; later pages reuse the cached bytes.
        let body = if let Some(body) = &self.query_body {
            body.clone()
        } else {
            let body = serde_json::to_vec(&self.query)?;
            self.query_body = Some(body.clone());
            body
        };

        let mut operation = (self.operation_factory)().with_body(body);
        if let Some(session_token) = &self.session_token {
            operation = operation.with_session_token(session_token.clone());
        }

        // Resume from where the previous page left off, if there was one.
        let mut headers = self.base_headers.clone();
        if let Some(continuation) = &self.continuation {
            headers.insert(
                constants::CONTINUATION.clone(),
                HeaderValue::from(continuation.clone()),
            );
        }
        let op_options = self.base_options.clone().with_custom_headers(headers);

        let driver_response = self.driver.execute_operation(operation, op_options).await?;
        let cosmos_response =
            driver_bridge::driver_response_to_cosmos_response::<FeedBody<T>>(driver_response);
        let page = QueryFeedPage::<T>::from_response(cosmos_response).await?;

        // A page without a continuation token is the last one.
        match page.continuation() {
            Some(token) => self.continuation = Some(token.to_string()),
            None => self.complete = true,
        }
        Ok(Some(page))
    }
}