use cloud_terrastodon_azure_types::prelude::uuid::Uuid;
use cloud_terrastodon_command::CommandBuilder;
use cloud_terrastodon_command::CommandKind;
use eyre::Result;
use eyre::bail;
use http::Method;
use itertools::Itertools;
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use tracing::debug;
/// A group of request entries that are submitted together by
/// [`invoke_batch_request`].
///
/// `T` is the type of the optional `content` carried by each entry.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct BatchRequest<T> {
    /// The entries to submit, kept in submission order.
    pub requests: Vec<BatchRequestEntry<T>>,
}
impl<T> BatchRequest<T>
where
T: Default,
{
pub fn new() -> Self {
BatchRequest::<T>::default()
}
}
/// A single request inside a [`BatchRequest`].
///
/// `T` is the type of the optional `content` field.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchRequestEntry<T> {
    /// HTTP method of this entry, stored under the `httpMethod` key.
    ///
    /// Written with the `serialize_using_asref_str` helper and read back with
    /// the `deserialize_using_from_str` helper.
    #[serde(rename = "httpMethod")]
    #[serde(
        deserialize_with = "cloud_terrastodon_azure_types::serde_helpers::deserialize_using_from_str",
        serialize_with = "cloud_terrastodon_azure_types::serde_helpers::serialize_using_asref_str"
    )]
    pub http_method: Method,
    /// Identifier of this entry; `invoke_batch_request` compares it with the
    /// `name` of the response found at the same position.
    pub name: Uuid,
    /// URL the entry targets.
    pub url: String,
    /// Optional content sent along with the entry.
    pub content: Option<T>,
}
impl BatchRequestEntry<()> {
    /// Builds a `GET` entry for `url` with no content.
    ///
    /// The entry receives a freshly generated random (v4) UUID as its name.
    pub fn new_get(url: String) -> Self {
        // The generic constructor fills in the random name; a GET entry only
        // pins the method and leaves the content empty.
        Self::new(Method::GET, url, None)
    }
}
impl<T> BatchRequestEntry<T> {
    /// Builds an entry from an HTTP method, a target URL and optional content.
    ///
    /// The entry receives a freshly generated random (v4) UUID as its name.
    pub fn new(http_method: Method, url: String, content: Option<T>) -> Self {
        let name = Uuid::new_v4();
        Self {
            http_method,
            name,
            url,
            content,
        }
    }
}
/// The responses produced for a [`BatchRequest`].
///
/// `T` is the type each response's `content` is deserialized into.
#[derive(Debug, Deserialize, Serialize)]
pub struct BatchResponse<T> {
    /// One entry per response received.
    pub responses: Vec<BatchResponseEntry<T>>,
}
/// A single response inside a [`BatchResponse`].
///
/// `T` is the type the `content` field is deserialized into.
#[derive(Debug, Deserialize, Serialize)]
pub struct BatchResponseEntry<T> {
    /// Identifier of the response; `invoke_batch_request` compares it with
    /// the `name` of the request found at the same position.
    pub name: Uuid,
    /// Status code of this individual response, stored under the
    /// `httpStatusCode` key.
    #[serde(rename = "httpStatusCode")]
    pub http_status_code: u16,
    /// Headers attached to this individual response.
    pub headers: HashMap<String, String>,
    /// Content of this individual response.
    pub content: T,
    /// Value of the `contentLength` key.
    // NOTE(review): presumably the size of `content` in bytes; confirm
    // against the service's behaviour.
    #[serde(rename = "contentLength")]
    pub content_length: u64,
}
/// Submits `request` to the Azure management batch endpoint through an Azure
/// CLI `rest` command and gathers every response into one [`BatchResponse`].
///
/// The entries are split into chunks of at most 20. The chunks are submitted
/// one after another and their responses are appended in the same order.
///
/// # Errors
///
/// Returns an error when:
/// - a chunk cannot be serialized to JSON,
/// - running the command fails, or its validator finds a response whose
///   status code is not 200,
/// - the total number of responses differs from the number of requests,
/// - a response's `name` differs from the `name` of the request at the same
///   position.
pub async fn invoke_batch_request<REQ, RESP>(
    request: &BatchRequest<REQ>,
) -> Result<BatchResponse<RESP>>
where
    REQ: Serialize + Clone,
    RESP: DeserializeOwned,
{
    // Upper bound on the number of entries submitted per command invocation.
    const MAX_ENTRIES_PER_CALL: usize = 20;

    let url = "https://management.azure.com/batch?api-version=2020-06-01";
    let mut cmd_base = CommandBuilder::new(CommandKind::AzureCLI);
    cmd_base.args(["rest", "--method", "POST", "--url", url, "--body"]);

    // Rejects a chunk's response as a whole when any entry is not a 200.
    // NOTE(review): other 2xx codes are treated as failures too; confirm that
    // this is intended for entries that use methods other than GET.
    let validator = |response: BatchResponse<RESP>| {
        let failures = response
            .responses
            .iter()
            .filter(|resp| resp.http_status_code != 200)
            .count();
        if failures > 0 {
            bail!("There were {} requests with non-200 status codes", failures)
        }
        Ok(response)
    };

    // One response is expected per request, so reserve the space up front.
    let mut rtn = BatchResponse {
        responses: Vec::with_capacity(request.requests.len()),
    };

    let chunks = request.requests.chunks(MAX_ENTRIES_PER_CALL);
    let num_chunks = chunks.len();
    for (i, chunk) in chunks.enumerate() {
        // Every chunk gets its own copy of the base command plus a body file
        // holding just that chunk's entries.
        let mut cmd = cmd_base.clone();
        cmd.azure_file_arg(
            "body.json",
            serde_json::to_string_pretty(&BatchRequest {
                requests: chunk.iter().cloned().collect_vec(),
            })?,
        );
        debug!(
            batch_index = i,
            total_batches = num_chunks,
            "Performing batch request"
        );
        let response = cmd.run_with_validator(validator).await?;
        rtn.responses.extend(response.responses);
    }

    // The responses come from a remote service, so a mismatch is reported as
    // an error to the caller instead of panicking.
    if request.requests.len() != rtn.responses.len() {
        bail!(
            "Expected {} batch responses but received {}",
            request.requests.len(),
            rtn.responses.len()
        );
    }
    for (index, (sent, received)) in request
        .requests
        .iter()
        .zip(rtn.responses.iter())
        .enumerate()
    {
        if sent.name != received.name {
            bail!(
                "Batch response at index {} is named {} but the request at that index is named {}",
                index,
                received.name,
                sent.name
            );
        }
    }
    Ok(rtn)
}
impl<T> BatchRequest<T>
where
    T: Serialize + Clone,
{
    /// Submits this batch request and returns the collected responses.
    ///
    /// Method form of [`invoke_batch_request`]; `RESP` is the type each
    /// response's content is deserialized into.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`invoke_batch_request`] returns.
    pub async fn invoke<RESP: DeserializeOwned>(&self) -> eyre::Result<BatchResponse<RESP>> {
        invoke_batch_request::<T, RESP>(self).await
    }
}