use crate::{
clients::offers_client,
models::{CosmosResponse, ResourceResponse, ThroughputProperties},
};
use azure_core::http::StatusCode;
use azure_core::time::Duration;
use azure_data_cosmos_driver::models::AccountReference;
use azure_data_cosmos_driver::CosmosDriver;
use futures::{stream::BoxStream, Stream, StreamExt};
use std::{
future::{Future, IntoFuture},
pin::Pin,
sync::Arc,
task,
};
/// Delay slept before each re-read of the offer while a throughput replacement is still
/// reported as pending (used by `ThroughputPoller::pending`).
const DEFAULT_POLLING_INTERVAL: Duration = Duration::seconds(5);
/// Tracks a throughput (offer) replacement until the service stops reporting it as pending.
///
/// A response counts as pending when its `offer_replace_pending` header is `true` or its
/// status is `202 Accepted`.
///
/// Used as a [`Stream`], the poller yields every response it observes:
/// - First, the initial response.
/// - Then, while the replacement is still pending, one response per re-read of the offer.
///   Each re-read happens after sleeping for `DEFAULT_POLLING_INTERVAL`.
///
/// The stream ends after the first non-pending response or after the first error.
///
/// Awaited directly (via [`IntoFuture`]), the poller drives that stream to completion. It
/// resolves to the last response, or to the first error encountered.
pub struct ThroughputPoller {
    // Untyped responses; they are wrapped into `ResourceResponse<ThroughputProperties>`
    // by the `Stream` and `IntoFuture` impls.
    stream: BoxStream<'static, crate::Result<CosmosResponse>>,
}
impl ThroughputPoller {
    /// Builds a poller from the response of a throughput (offer) operation.
    ///
    /// If `initial_response` is not pending (see `is_offer_replace_pending`), the poller
    /// yields it once and ends. Otherwise the poller behaves as follows:
    /// - It yields `initial_response` first.
    /// - It then repeatedly sleeps and re-reads the offer identified by `offer_id` through
    ///   `driver` / `account`.
    /// - It stops after the first non-pending read or the first failed read.
    pub(crate) fn new(
        initial_response: CosmosResponse,
        driver: Arc<CosmosDriver>,
        account: AccountReference,
        offer_id: String,
    ) -> Self {
        if !is_offer_replace_pending(&initial_response) {
            return Self::completed(initial_response);
        }
        Self::pending(initial_response, driver, account, offer_id)
    }

    /// Poller that yields exactly one item, `Ok(response)`, and then ends.
    fn completed(response: CosmosResponse) -> Self {
        let single = futures::stream::once(async { Ok(response) });
        Self {
            stream: single.boxed(),
        }
    }

    /// Poller for an offer whose replacement is still pending.
    ///
    /// The first item is `initial_response`, yielded without sleeping. Every later item comes
    /// from sleeping `DEFAULT_POLLING_INTERVAL` and then calling
    /// `offers_client::read_offer_by_id`. The stream finishes once a read is no longer
    /// pending, or right after yielding the first read error.
    fn pending(
        initial_response: CosmosResponse,
        driver: Arc<CosmosDriver>,
        account: AccountReference,
        offer_id: String,
    ) -> Self {
        let interval = DEFAULT_POLLING_INTERVAL;
        let seed = Some(PollState::Initial(Box::new(initial_response)));
        let polls = futures::stream::unfold(seed, move |state| {
            // Each step's future must own what it touches, so clone the captured handles.
            let driver = Arc::clone(&driver);
            let account = account.clone();
            let offer_id = offer_id.clone();
            async move {
                // A `None` state means a previous step decided the stream is finished.
                match state? {
                    PollState::Initial(first) => Some((Ok(*first), Some(PollState::Polling))),
                    PollState::Polling => {
                        azure_core::sleep::sleep(interval).await;
                        let read =
                            offers_client::read_offer_by_id(&driver, &account, &offer_id).await;
                        let step = match read {
                            Ok(latest) => {
                                // Keep polling only while the service still reports pending.
                                let next = is_offer_replace_pending(&latest)
                                    .then_some(PollState::Polling);
                                (Ok(latest), next)
                            }
                            // Surface the failure once, then end the stream.
                            Err(err) => (Err(err), None),
                        };
                        Some(step)
                    }
                }
            }
        });
        Self {
            stream: polls.boxed(),
        }
    }
}
/// State threaded through the `futures::stream::unfold` stream built by
/// `ThroughputPoller::pending`.
///
/// The unfold state is `Option<PollState>`; `None` ends the stream.
enum PollState {
    /// The response that started polling. It is yielded first, without sleeping.
    // NOTE(review): presumably boxed to keep this enum small next to the data-less
    // `Polling` variant — confirm.
    Initial(Box<CosmosResponse>),
    /// Sleep for the polling interval, then re-read the offer.
    Polling,
}
/// Streams every response observed by the poller. Each `Ok` response is wrapped as a
/// `ResourceResponse<ThroughputProperties>`; errors pass through unchanged.
impl Stream for ThroughputPoller {
    type Item = crate::Result<ResourceResponse<ThroughputProperties>>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        match self.stream.poll_next_unpin(cx) {
            task::Poll::Pending => task::Poll::Pending,
            task::Poll::Ready(None) => task::Poll::Ready(None),
            task::Poll::Ready(Some(raw)) => {
                task::Poll::Ready(Some(raw.map(ResourceResponse::new)))
            }
        }
    }
}
impl IntoFuture for ThroughputPoller {
type Output = crate::Result<ResourceResponse<ThroughputProperties>>;
type IntoFuture =
Pin<Box<dyn Future<Output = crate::Result<ResourceResponse<ThroughputProperties>>> + Send>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let mut stream = self.stream;
let mut last_response = None;
while let Some(result) = stream.next().await {
last_response = Some(result?);
}
last_response.map(ResourceResponse::new).ok_or_else(|| {
crate::CosmosError::from(
crate::DriverCosmosError::builder()
.with_status(crate::CosmosStatus::CLIENT_THROUGHPUT_POLLER_INCOMPLETE)
.with_message("throughput poller stream ended without yielding a response")
.build(),
)
})
})
}
}
/// Reports whether `response` says a throughput (offer) replacement is still in progress.
///
/// The response is pending if either of these holds:
/// - The parsed `offer_replace_pending` header is `Some(true)`.
/// - The status code is `202 Accepted`.
///
/// An explicit `false` header does not override a 202 status.
fn is_offer_replace_pending(response: &CosmosResponse) -> bool {
    let flagged_by_header =
        matches!(response.cosmos_headers().offer_replace_pending(), Some(true));
    flagged_by_header || response.status().status_code() == StatusCode::Accepted
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::TryStreamExt;

    #[test]
    fn is_offer_replace_pending_returns_false_for_ok() {
        let response = create_mock_response(StatusCode::Ok, None);
        assert!(!is_offer_replace_pending(&response));
    }

    #[test]
    fn is_offer_replace_pending_returns_true_for_accepted() {
        let response = create_mock_response(StatusCode::Accepted, None);
        assert!(is_offer_replace_pending(&response));
    }

    #[test]
    fn is_offer_replace_pending_returns_true_for_header() {
        let response = create_mock_response(StatusCode::Ok, Some("true"));
        assert!(is_offer_replace_pending(&response));
    }

    #[test]
    fn is_offer_replace_pending_returns_false_for_header_false() {
        let response = create_mock_response(StatusCode::Ok, Some("false"));
        assert!(!is_offer_replace_pending(&response));
    }

    /// A 202 status marks the replacement as pending even if the header explicitly says
    /// `false`: the header check and the status check are OR-ed together.
    #[test]
    fn is_offer_replace_pending_returns_true_for_accepted_with_header_false() {
        let response = create_mock_response(StatusCode::Accepted, Some("false"));
        assert!(is_offer_replace_pending(&response));
    }

    #[test]
    fn is_offer_replace_pending_handles_pascal_case_true_from_headers() {
        use azure_core::http::headers::Headers;
        use azure_data_cosmos_driver::models::CosmosResponseHeaders;
        for raw in ["True", "TRUE", "tRuE"] {
            let mut wire_headers = Headers::new();
            wire_headers.insert(crate::constants::OFFER_REPLACE_PENDING, raw.to_owned());
            let parsed = CosmosResponseHeaders::from_headers(&wire_headers);
            assert_eq!(
                parsed.offer_replace_pending,
                Some(true),
                "{raw:?} should parse as Some(true) from the wire"
            );
            let response = build_mock_response_from_parsed_headers(StatusCode::Ok, parsed);
            assert!(
                is_offer_replace_pending(&response),
                "{raw:?} must keep the poller marked as pending"
            );
        }
    }

    /// Builds a `CosmosResponse` with a `{}` body, the given status, the given
    /// already-parsed headers, and fresh test diagnostics.
    ///
    /// Every mock response in this module is built here, so all tests share one
    /// construction path.
    fn build_mock_response_from_parsed_headers(
        status: StatusCode,
        cosmos_headers: azure_data_cosmos_driver::models::CosmosResponseHeaders,
    ) -> CosmosResponse {
        use crate::DiagnosticsContext;
        use azure_data_cosmos_driver::models::{ActivityId, CosmosStatus, ResponseBody};
        use std::sync::Arc;
        let body = ResponseBody::from_bytes(azure_core::Bytes::from_static(b"{}"));
        let cosmos_status = CosmosStatus::new(status);
        let diagnostics = Arc::new(DiagnosticsContext::for_testing(ActivityId::new_uuid()));
        CosmosResponse::from_driver_parts(
            body.into(),
            cosmos_headers.into(),
            cosmos_status,
            diagnostics,
        )
    }

    #[tokio::test]
    async fn completed_poller_yields_one_item() {
        let response = create_mock_response(StatusCode::Ok, None);
        let mut poller = ThroughputPoller::completed(response);
        let first = poller.try_next().await.expect("should yield Ok");
        assert!(first.is_some(), "should yield one item");
        let second = poller.try_next().await.expect("should yield Ok");
        assert!(second.is_none(), "should end after one item");
    }

    /// The `Stream` impl wraps each raw response without altering its status.
    #[tokio::test]
    async fn completed_poller_stream_preserves_status() {
        let response = create_mock_response(StatusCode::Accepted, None);
        let mut poller = ThroughputPoller::completed(response);
        let item = poller
            .try_next()
            .await
            .expect("should yield Ok")
            .expect("should yield one item");
        assert_eq!(item.status().status_code(), StatusCode::Accepted);
    }

    #[tokio::test]
    async fn completed_poller_into_future_returns_response() {
        let response = create_mock_response(StatusCode::Ok, None);
        let poller = ThroughputPoller::completed(response);
        let result = poller.await;
        assert!(result.is_ok(), "into_future should return Ok");
        assert_eq!(result.unwrap().status().status_code(), StatusCode::Ok);
    }

    /// Builds a mock response from a status and an optional raw header value.
    ///
    /// The `offer_replace_pending` header value is parsed with `str::parse::<bool>`. Only
    /// the exact strings `"true"` and `"false"` are recognized; anything else leaves the
    /// header unset. Case-insensitive wire parsing is exercised separately through
    /// `CosmosResponseHeaders::from_headers`.
    fn create_mock_response(
        status: StatusCode,
        offer_replace_pending: Option<&str>,
    ) -> CosmosResponse {
        use azure_data_cosmos_driver::models::CosmosResponseHeaders;
        let mut cosmos_headers = CosmosResponseHeaders::default();
        if let Some(value) = offer_replace_pending {
            cosmos_headers.offer_replace_pending = value.parse::<bool>().ok();
        }
        build_mock_response_from_parsed_headers(status, cosmos_headers)
    }
}