use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use k2db_api_contract::{
AggregateRequest, CountRequest, CountResult, CreateIndexesRequest, CreateResult, HealthOk,
MessageResponse, PatchCollectionRequest, ProblemDetailsPayload, ReadyNotOk, ReadyOk,
RestoreRequest, RestoreResult, SearchRequest, UpdateResult, VersionInfo,
VersionedUpdateRequest, VersionedUpdateResult,
};
use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
use reqwest::{Method, StatusCode};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};
use crate::error::K2DbApiClientError;
/// Construction options for `K2DbApiClient::new`.
///
/// Implements `Default` so callers only spell out what they need, e.g.
/// `K2DbApiClientOptions { base_url: url, ..Default::default() }`. The default `base_url` is
/// empty, which `K2DbApiClient::new` rejects, so it must always be supplied.
#[derive(Debug, Clone, Default)]
pub struct K2DbApiClientOptions {
    /// Server root URL; trailing `/` characters are stripped by `K2DbApiClient::new`.
    pub base_url: String,
    /// Default API key, sent as `Authorization: ApiKey <key>` (a value already starting with
    /// `ApiKey ` is sent as-is). A per-request key in `RequestOptions` takes precedence.
    pub api_key: Option<String>,
    /// Extra `(name, value)` headers attached to every request.
    pub headers: Vec<(String, String)>,
    /// Read-cache TTL in milliseconds; `None` means 60 000 ms and `Some(0)` disables caching.
    pub read_cache_ttl_ms: Option<u64>,
    /// Upper bound on concurrent HTTP requests; `None` or `Some(0)` means unlimited.
    pub max_concurrent_requests: Option<usize>,
}
/// Per-call overrides accepted by every client endpoint method.
///
/// The `Default` value means "no overrides": the client-level API key, no scope, no extra
/// headers and the client-level read-cache TTL.
#[derive(Default, Debug, Clone)]
pub struct RequestOptions {
    /// Used instead of the client's API key for this call when set.
    pub api_key: Option<String>,
    /// Sent as the `x-scope` header; it is also part of the read-cache key.
    pub scope: Option<String>,
    /// Additional `(name, value)` headers for this call only.
    pub headers: Vec<(String, String)>,
    /// Read-cache TTL for this call in milliseconds; `Some(0)` bypasses the cache.
    pub cache_ttl_ms: Option<u64>,
}
/// HTTP client for the k2db API with an optional read cache, request collapsing and a
/// concurrency limit.
///
/// Cloning is cheap-ish and clones share the cache, the in-flight map and the semaphore
/// (all behind `Arc`).
#[derive(Debug, Clone)]
pub struct K2DbApiClient {
    /// Server root without trailing `/`.
    base_url: String,
    /// Client-level API key used when a request does not supply its own.
    api_key: Option<String>,
    /// Headers attached to every request (always includes `Accept: application/json`).
    default_headers: HeaderMap,
    /// Underlying reqwest client.
    http: reqwest::Client,
    /// Default read-cache TTL in milliseconds; `0` disables caching.
    read_cache_ttl_ms: u64,
    /// Cached successful read payloads, keyed by request signature.
    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
    /// Reads currently on the wire, keyed by request signature, so identical reads share one
    /// HTTP round trip.
    inflight: Arc<Mutex<HashMap<String, Arc<InFlightRequest>>>>,
    /// Concurrency limiter; `None` means unlimited.
    semaphore: Option<Arc<Semaphore>>,
}
/// One cached read payload together with the instant after which it is stale.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The entry is served only while `Instant::now()` is before this point.
    expires_at: Instant,
    /// The decoded JSON payload of the successful response.
    value: Value,
}
/// Shared slot for one collapsed read: the owning caller publishes the outcome here and every
/// other caller with the same signature waits on it.
#[derive(Debug)]
struct InFlightRequest {
    /// Wakes waiters once `result` has been filled in.
    notify: Notify,
    /// `None` until the owner publishes the outcome.
    result: Mutex<Option<Result<Value, SharedClientError>>>,
}
/// `Clone`-able mirror of `K2DbApiClientError`, used so a single outcome can be handed to
/// every caller sharing an in-flight request.
#[derive(Debug, Clone)]
enum SharedClientError {
    /// Network failure or non-problem error response, as text.
    Transport(String),
    /// Structured problem-details error body returned by the server.
    Problem(ProblemDetailsPayload),
    /// Invalid client configuration (header names/values, etc.).
    Configuration(String),
    /// JSON encoding or decoding failure, as text.
    Serialization(String),
}
impl From<SharedClientError> for K2DbApiClientError {
fn from(value: SharedClientError) -> Self {
match value {
SharedClientError::Transport(message) => Self::Transport(message),
SharedClientError::Problem(problem) => Self::Problem(problem),
SharedClientError::Configuration(message) => Self::Configuration(message),
SharedClientError::Serialization(message) => Self::Serialization(message),
}
}
}
impl InFlightRequest {
    /// Creates a slot with no outcome published yet.
    fn new() -> Self {
        Self {
            notify: Notify::new(),
            result: Mutex::new(None),
        }
    }

    /// Publishes `result` and wakes every task currently blocked in `wait`.
    ///
    /// The result is stored before notifying, so any waiter woken by the notification
    /// observes it.
    fn finish(&self, result: Result<Value, SharedClientError>) {
        *self.result.lock().expect("inflight result lock") = Some(result);
        self.notify.notify_waiters();
    }

    /// Waits until `finish` has been called, then returns a clone of the published outcome.
    async fn wait(&self) -> Result<Value, SharedClientError> {
        loop {
            // Register interest BEFORE inspecting the result. `notify_waiters` only wakes
            // `Notified` futures that already exist; checking first would miss a `finish`
            // that lands between the check and the registration, and then wait forever.
            let notified = self.notify.notified();
            let published = self
                .result
                .lock()
                .expect("inflight result lock clone")
                .clone();
            if let Some(result) = published {
                return result;
            }
            notified.await;
        }
    }
}
impl K2DbApiClient {
pub fn new(options: K2DbApiClientOptions) -> Result<Self, K2DbApiClientError> {
let base_url = options.base_url.trim_end_matches('/').to_owned();
if base_url.is_empty() {
return Err(K2DbApiClientError::Configuration(
"base_url is required".to_owned(),
));
}
let mut default_headers = HeaderMap::new();
default_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
for (name, value) in options.headers {
let header_name = HeaderName::try_from(name.as_str()).map_err(|_| {
K2DbApiClientError::Configuration(format!("invalid header name: {name}"))
})?;
let header_value = HeaderValue::try_from(value.as_str()).map_err(|_| {
K2DbApiClientError::Configuration(format!("invalid header value for {name}"))
})?;
default_headers.insert(header_name, header_value);
}
let read_cache_ttl_ms = options.read_cache_ttl_ms.unwrap_or(60_000);
let max_concurrent_requests = options.max_concurrent_requests.unwrap_or(0);
Ok(Self {
base_url,
api_key: options.api_key,
default_headers,
http: reqwest::Client::new(),
read_cache_ttl_ms,
cache: Arc::new(Mutex::new(HashMap::new())),
inflight: Arc::new(Mutex::new(HashMap::new())),
semaphore: (max_concurrent_requests > 0)
.then(|| Arc::new(Semaphore::new(max_concurrent_requests))),
})
}
pub async fn health(&self) -> Result<HealthOk, K2DbApiClientError> {
self.get_json("/health", &RequestOptions::default(), false).await
}
pub async fn ready(&self) -> Result<Result<ReadyOk, ReadyNotOk>, K2DbApiClientError> {
let (status, payload) = self
.request_value_with_status(
Method::GET,
"/ready",
&RequestOptions::default(),
Option::<&()>::None,
None,
false,
)
.await?;
match status {
status if status.is_success() => Ok(Ok(from_value(payload)?)),
StatusCode::SERVICE_UNAVAILABLE => Ok(Err(from_value(payload)?)),
_ => Err(error_from_status_payload(status, payload)),
}
}
pub async fn create<T: Serialize>(
&self,
collection: &str,
document: &T,
options: &RequestOptions,
) -> Result<CreateResult, K2DbApiClientError> {
let result = self
.request_json(
Method::POST,
&format!("/v1/{}", encode_segment(collection)),
options,
Some(document),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn get_by_id<T: DeserializeOwned>(
&self,
collection: &str,
id: &str,
options: &RequestOptions,
) -> Result<T, K2DbApiClientError> {
self.request_json(
Method::GET,
&format!("/v1/{}/{}", encode_segment(collection), encode_segment(id)),
options,
Option::<&()>::None,
None,
true,
)
.await
}
pub async fn patch_by_id<T: Serialize>(
&self,
collection: &str,
id: &str,
document: &T,
options: &RequestOptions,
) -> Result<UpdateResult, K2DbApiClientError> {
let result = self
.request_json(
Method::PATCH,
&format!("/v1/{}/{}", encode_segment(collection), encode_segment(id)),
options,
Some(document),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn delete_by_id(
&self,
collection: &str,
id: &str,
options: &RequestOptions,
) -> Result<(), K2DbApiClientError> {
self.request_empty(
Method::DELETE,
&format!("/v1/{}/{}", encode_segment(collection), encode_segment(id)),
options,
Option::<&()>::None,
None,
true,
)
.await?;
self.invalidate_cache();
Ok(())
}
pub async fn patch_collection(
&self,
collection: &str,
payload: &PatchCollectionRequest,
options: &RequestOptions,
) -> Result<UpdateResult, K2DbApiClientError> {
let result = self
.request_json(
Method::PATCH,
&format!("/v1/{}", encode_segment(collection)),
options,
Some(payload),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn search<T: DeserializeOwned>(
&self,
collection: &str,
payload: &SearchRequest,
options: &RequestOptions,
) -> Result<Vec<T>, K2DbApiClientError> {
self.request_json(
Method::POST,
&format!("/v1/{}/search", encode_segment(collection)),
options,
Some(payload),
None,
true,
)
.await
}
pub async fn aggregate<T: DeserializeOwned>(
&self,
collection: &str,
payload: &AggregateRequest,
options: &RequestOptions,
) -> Result<Vec<T>, K2DbApiClientError> {
self.request_json(
Method::POST,
&format!("/v1/{}/aggregate", encode_segment(collection)),
options,
Some(payload),
None,
true,
)
.await
}
pub async fn count(
&self,
collection: &str,
payload: &CountRequest,
options: &RequestOptions,
) -> Result<CountResult, K2DbApiClientError> {
self.request_json(
Method::POST,
&format!("/v1/{}/count", encode_segment(collection)),
options,
Some(payload),
None,
true,
)
.await
}
pub async fn restore(
&self,
collection: &str,
payload: &RestoreRequest,
options: &RequestOptions,
) -> Result<RestoreResult, K2DbApiClientError> {
let result = self
.request_json(
Method::POST,
&format!("/v1/{}/restore", encode_segment(collection)),
options,
Some(payload),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn get_versions(
&self,
collection: &str,
id: &str,
skip: Option<u64>,
limit: Option<u64>,
options: &RequestOptions,
) -> Result<Vec<VersionInfo>, K2DbApiClientError> {
let query = query_pairs(skip, limit);
self.request_json(
Method::GET,
&format!("/v1/{}/{}/versions", encode_segment(collection), encode_segment(id)),
options,
Option::<&()>::None,
Some(&query),
true,
)
.await
}
pub async fn patch_versions(
&self,
collection: &str,
id: &str,
payload: &VersionedUpdateRequest,
options: &RequestOptions,
) -> Result<Vec<VersionedUpdateResult>, K2DbApiClientError> {
let result = self
.request_json(
Method::PATCH,
&format!("/v1/{}/{}/versions", encode_segment(collection), encode_segment(id)),
options,
Some(payload),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn revert_version(
&self,
collection: &str,
id: &str,
version: u64,
options: &RequestOptions,
) -> Result<UpdateResult, K2DbApiClientError> {
let result = self
.request_json(
Method::POST,
&format!(
"/v1/{}/{}/versions/{}/revert",
encode_segment(collection),
encode_segment(id),
version
),
options,
Option::<&()>::None,
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn admin_delete_collection(
&self,
collection: &str,
options: &RequestOptions,
) -> Result<(), K2DbApiClientError> {
self.request_empty(
Method::DELETE,
&format!("/v1/admin/{}", encode_segment(collection)),
options,
Option::<&()>::None,
None,
true,
)
.await?;
self.invalidate_cache();
Ok(())
}
pub async fn admin_delete_by_id(
&self,
collection: &str,
id: &str,
options: &RequestOptions,
) -> Result<(), K2DbApiClientError> {
self.request_empty(
Method::DELETE,
&format!("/v1/admin/{}/{}", encode_segment(collection), encode_segment(id)),
options,
Option::<&()>::None,
None,
true,
)
.await?;
self.invalidate_cache();
Ok(())
}
pub async fn admin_create_indexes(
&self,
collection: &str,
payload: &CreateIndexesRequest,
options: &RequestOptions,
) -> Result<MessageResponse, K2DbApiClientError> {
let result = self
.request_json(
Method::POST,
&format!("/v1/admin/{}/indexes", encode_segment(collection)),
options,
Some(payload),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
pub async fn admin_create_history_indexes(
&self,
collection: &str,
options: &RequestOptions,
) -> Result<MessageResponse, K2DbApiClientError> {
let result = self
.request_json(
Method::POST,
&format!("/v1/admin/{}/history-indexes", encode_segment(collection)),
options,
Some(&serde_json::json!({})),
None,
true,
)
.await?;
self.invalidate_cache();
Ok(result)
}
async fn get_json<T: DeserializeOwned>(
&self,
path: &str,
options: &RequestOptions,
include_auth: bool,
) -> Result<T, K2DbApiClientError> {
let value = self
.request_value(
Method::GET,
path,
options,
Option::<&()>::None,
None,
include_auth,
)
.await?;
from_value(value)
}
async fn request_json<T, B>(
&self,
method: Method,
path: &str,
options: &RequestOptions,
body: Option<&B>,
query: Option<&[(String, String)]>,
include_auth: bool,
) -> Result<T, K2DbApiClientError>
where
T: DeserializeOwned,
B: Serialize + ?Sized,
{
let value = self
.request_value(method, path, options, body, query, include_auth)
.await?;
from_value(value)
}
async fn request_empty<B>(
&self,
method: Method,
path: &str,
options: &RequestOptions,
body: Option<&B>,
query: Option<&[(String, String)]>,
include_auth: bool,
) -> Result<(), K2DbApiClientError>
where
B: Serialize + ?Sized,
{
let url = self.url(path, query);
let headers = self.headers(options, include_auth, body.is_some())?;
let body_bytes = self.serialize_body_bytes(body)?;
let (status, payload) = self
.perform_request(method, url, headers, body_bytes)
.await
.map_err(K2DbApiClientError::from)?;
if status.is_success() {
return Ok(());
}
Err(error_from_status_payload(status, payload))
}
async fn request_value<B>(
&self,
method: Method,
path: &str,
options: &RequestOptions,
body: Option<&B>,
query: Option<&[(String, String)]>,
include_auth: bool,
) -> Result<Value, K2DbApiClientError>
where
B: Serialize + ?Sized,
{
let (status, payload) = self
.request_value_with_status(method, path, options, body, query, include_auth)
.await?;
if status.is_success() {
return Ok(payload);
}
Err(error_from_status_payload(status, payload))
}
async fn request_value_with_status<B>(
&self,
method: Method,
path: &str,
options: &RequestOptions,
body: Option<&B>,
query: Option<&[(String, String)]>,
include_auth: bool,
) -> Result<(StatusCode, Value), K2DbApiClientError>
where
B: Serialize + ?Sized,
{
let url = self.url(path, query);
let headers = self.headers(options, include_auth, body.is_some())?;
let body_value = self.serialize_body_value(body)?;
let body_bytes = self.serialize_body_bytes(body)?;
let ttl_ms = self.cache_ttl_ms(method.as_str(), path, options);
if ttl_ms > 0 {
let signature = self.request_signature(method.as_str(), &url, &headers, body_value.as_ref())?;
if let Some(value) = self.cache_get(&signature) {
return Ok((StatusCode::OK, value));
}
let (request, owner) = {
let mut inflight = self.inflight.lock().expect("inflight map lock");
if let Some(existing) = inflight.get(&signature).cloned() {
(existing, false)
} else {
let created = Arc::new(InFlightRequest::new());
inflight.insert(signature.clone(), created.clone());
(created, true)
}
};
if owner {
let result = self
.perform_request(method, url, headers, body_bytes)
.await;
let shared = result
.map_err(SharedClientError::from)
.and_then(|(status, payload)| {
if status.is_success() {
Ok(payload)
} else {
Err(SharedClientError::from(error_from_status_payload(status, payload)))
}
});
if let Ok(value) = &shared {
self.cache_set(&signature, ttl_ms, value.clone());
}
request.finish(shared.clone());
self.inflight
.lock()
.expect("inflight map lock remove")
.remove(&signature);
return shared.map(|value| (StatusCode::OK, value)).map_err(K2DbApiClientError::from);
}
return request
.wait()
.await
.map(|value| (StatusCode::OK, value))
.map_err(K2DbApiClientError::from);
}
self.perform_request(method, url, headers, body_bytes)
.await
.map_err(K2DbApiClientError::from)
}
async fn perform_request(
&self,
method: Method,
url: String,
headers: HeaderMap,
body: Option<Vec<u8>>,
) -> Result<(StatusCode, Value), SharedClientError> {
let mut request = self.http.request(method, url).headers(headers);
if let Some(body) = body {
request = request.body(body);
}
let _permit = self.acquire_request_slot().await?;
let response = request
.send()
.await
.map_err(|error| SharedClientError::Transport(error.to_string()))?;
let status = response.status();
if status == StatusCode::NO_CONTENT {
return Ok((status, Value::Null));
}
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.unwrap_or_default()
.to_owned();
let text = response
.text()
.await
.map_err(|error| SharedClientError::Transport(error.to_string()))?;
if content_type.contains("application/json") || content_type.contains("application/problem+json") {
let payload = serde_json::from_str::<Value>(&text)
.map_err(|error| SharedClientError::Serialization(error.to_string()))?;
return Ok((status, payload));
}
Ok((status, Value::String(text)))
}
async fn acquire_request_slot(&self) -> Result<Option<OwnedSemaphorePermit>, SharedClientError> {
match &self.semaphore {
Some(semaphore) => semaphore
.clone()
.acquire_owned()
.await
.map(Some)
.map_err(|error| SharedClientError::Transport(error.to_string())),
None => Ok(None),
}
}
fn invalidate_cache(&self) {
self.cache.lock().expect("cache lock clear").clear();
}
fn cache_ttl_ms(&self, method: &str, path: &str, options: &RequestOptions) -> u64 {
if self.is_cacheable_read(method, path) {
options.cache_ttl_ms.unwrap_or(self.read_cache_ttl_ms)
} else {
0
}
}
fn is_cacheable_read(&self, method: &str, path: &str) -> bool {
if method == "GET" {
return true;
}
method == "POST"
&& (path.ends_with("/search") || path.ends_with("/aggregate") || path.ends_with("/count"))
}
fn cache_get(&self, signature: &str) -> Option<Value> {
let mut cache = self.cache.lock().expect("cache lock get");
let entry = cache.get(signature)?.clone();
if entry.expires_at > Instant::now() {
Some(entry.value)
} else {
cache.remove(signature);
None
}
}
fn cache_set(&self, signature: &str, ttl_ms: u64, value: Value) {
self.cache.lock().expect("cache lock set").insert(
signature.to_owned(),
CacheEntry {
expires_at: Instant::now() + Duration::from_millis(ttl_ms),
value,
},
);
}
fn request_signature(
&self,
method: &str,
url: &str,
headers: &HeaderMap,
body: Option<&Value>,
) -> Result<String, K2DbApiClientError> {
let scope = headers
.get("x-scope")
.and_then(|value| value.to_str().ok())
.unwrap_or_default();
let auth = headers
.get(AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.unwrap_or_default();
let auth_key = if auth.is_empty() { String::new() } else { fnv1a64(auth) };
let body_key = body.map(stable_stringify).unwrap_or_default();
let raw = format!("{method} {url}\nscope={scope}\nauth={auth_key}\nbody={body_key}");
Ok(fnv1a64(&raw))
}
fn serialize_body_value<B: Serialize + ?Sized>(
&self,
body: Option<&B>,
) -> Result<Option<Value>, K2DbApiClientError> {
body.map(|value| {
serde_json::to_value(value)
.map_err(|error| K2DbApiClientError::Serialization(error.to_string()))
})
.transpose()
}
fn serialize_body_bytes<B: Serialize + ?Sized>(
&self,
body: Option<&B>,
) -> Result<Option<Vec<u8>>, K2DbApiClientError> {
body.map(|value| {
serde_json::to_vec(value)
.map_err(|error| K2DbApiClientError::Serialization(error.to_string()))
})
.transpose()
}
fn url(&self, path: &str, query: Option<&[(String, String)]>) -> String {
let mut url = format!("{}{}", self.base_url, path);
if let Some(query) = query {
if !query.is_empty() {
let mut first = true;
for (key, value) in query {
url.push(if first { '?' } else { '&' });
first = false;
url.push_str(&urlencoding::encode(key));
url.push('=');
url.push_str(&urlencoding::encode(value));
}
}
}
url
}
fn headers(
&self,
options: &RequestOptions,
include_auth: bool,
include_content_type: bool,
) -> Result<HeaderMap, K2DbApiClientError> {
let mut headers = self.default_headers.clone();
for (name, value) in &options.headers {
let header_name = HeaderName::try_from(name.as_str()).map_err(|_| {
K2DbApiClientError::Configuration(format!("invalid header name: {name}"))
})?;
let header_value = HeaderValue::try_from(value.as_str()).map_err(|_| {
K2DbApiClientError::Configuration(format!("invalid header value for {name}"))
})?;
headers.insert(header_name, header_value);
}
if let Some(scope) = &options.scope {
headers.insert(
HeaderName::from_static("x-scope"),
HeaderValue::try_from(scope.as_str()).map_err(|_| {
K2DbApiClientError::Configuration("invalid x-scope header value".to_owned())
})?,
);
}
if include_auth {
let auth = options.api_key.as_ref().or(self.api_key.as_ref());
if let Some(auth) = auth {
let value = if auth.starts_with("ApiKey ") {
auth.clone()
} else {
format!("ApiKey {auth}")
};
headers.insert(
AUTHORIZATION,
HeaderValue::try_from(value.as_str()).map_err(|_| {
K2DbApiClientError::Configuration(
"invalid authorization header value".to_owned(),
)
})?,
);
}
}
if include_content_type {
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
}
Ok(headers)
}
}
impl From<K2DbApiClientError> for SharedClientError {
fn from(value: K2DbApiClientError) -> Self {
match value {
K2DbApiClientError::Http(error) => Self::Transport(error.to_string()),
K2DbApiClientError::Transport(message) => Self::Transport(message),
K2DbApiClientError::Problem(problem) => Self::Problem(problem),
K2DbApiClientError::Configuration(message) => Self::Configuration(message),
K2DbApiClientError::Serialization(message) => Self::Serialization(message),
}
}
}
fn from_value<T: DeserializeOwned>(value: Value) -> Result<T, K2DbApiClientError> {
serde_json::from_value(value).map_err(|error| K2DbApiClientError::Serialization(error.to_string()))
}
/// Converts a non-success response into a client error.
///
/// A payload that decodes as `ProblemDetailsPayload` becomes `Problem`. Otherwise the result
/// is `Transport`, carrying a non-empty text body or an object's string `detail` field when
/// present, and `request failed: <status>` in all other cases.
fn error_from_status_payload(status: StatusCode, payload: Value) -> K2DbApiClientError {
    match serde_json::from_value::<ProblemDetailsPayload>(payload.clone()) {
        Ok(problem) => K2DbApiClientError::Problem(problem),
        Err(_) => {
            let detail = match payload {
                Value::String(text) if !text.is_empty() => Some(text),
                Value::Object(map) => map
                    .get("detail")
                    .and_then(Value::as_str)
                    .map(str::to_owned),
                _ => None,
            };
            K2DbApiClientError::Transport(
                detail.unwrap_or_else(|| format!("request failed: {status}")),
            )
        }
    }
}
/// Renders JSON deterministically: object keys are emitted in sorted order at every depth,
/// with no whitespace, so equal values always produce the same string.
fn stable_stringify(value: &Value) -> String {
    /// JSON-quotes `text`, falling back to an empty JSON string if encoding fails.
    fn quote(text: &str) -> String {
        serde_json::to_string(text).unwrap_or_else(|_| "\"\"".to_owned())
    }

    match value {
        Value::Null => "null".to_owned(),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(number) => number.to_string(),
        Value::String(text) => quote(text),
        Value::Array(items) => {
            let parts: Vec<String> = items.iter().map(stable_stringify).collect();
            format!("[{}]", parts.join(","))
        }
        Value::Object(map) => {
            let mut entries: Vec<(&String, &Value)> = map.iter().collect();
            entries.sort_by(|left, right| left.0.cmp(right.0));
            let parts: Vec<String> = entries
                .into_iter()
                .map(|(key, nested)| format!("{}:{}", quote(key), stable_stringify(nested)))
                .collect();
            format!("{{{}}}", parts.join(","))
        }
    }
}
/// 64-bit FNV-1a hash of the UTF-8 bytes of `value`, as 16 lowercase hex digits.
fn fnv1a64(value: &str) -> String {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    let digest = value
        .bytes()
        .fold(OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME));
    format!("{digest:016x}")
}
/// Percent-encodes one URL path segment (so `/` inside an id cannot split the path).
fn encode_segment(value: &str) -> String {
    String::from(urlencoding::encode(value))
}
/// Builds the `skip`/`limit` query pairs, in that order, omitting any that are `None`.
fn query_pairs(skip: Option<u64>, limit: Option<u64>) -> Vec<(String, String)> {
    [("skip", skip), ("limit", limit)]
        .iter()
        .filter_map(|&(name, maybe)| maybe.map(|number| (name.to_owned(), number.to_string())))
        .collect()
}
#[cfg(test)]
// Integration tests: each test starts a local axum server on an ephemeral port and drives the
// client against it over real HTTP.
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use axum::Json;
use axum::extract::Path;
use axum::routing::{get, patch};
use axum::{Router, serve};
use serde_json::json;
use tokio::net::TcpListener;
use tokio::time::sleep;
use super::*;
// Binds 127.0.0.1:0, serves `router` on a background task and returns the `http://host:port`
// base URL. The spawned task is detached and lives for the rest of the test runtime.
async fn start_server(router: Router) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let address = listener.local_addr().expect("local addr");
tokio::spawn(async move {
serve(listener, router).await.expect("serve");
});
format!("http://{}", address)
}
// Two concurrent `health()` calls must collapse into one HTTP hit, and a later call within
// the 1 s TTL must be served from the cache (still one hit).
#[tokio::test]
async fn cache_and_inflight_collapse_share_single_read() {
let hits = Arc::new(AtomicUsize::new(0));
let router = {
let hits = hits.clone();
Router::new().route(
"/health",
get(move || {
let hits = hits.clone();
async move {
hits.fetch_add(1, Ordering::SeqCst);
// Keeps the first request on the wire long enough for the second to join it.
sleep(Duration::from_millis(50)).await;
Json(json!({ "status": "ok" }))
}
}),
)
};
let base_url = start_server(router).await;
let client = K2DbApiClient::new(K2DbApiClientOptions {
base_url,
api_key: None,
headers: Vec::new(),
read_cache_ttl_ms: Some(1_000),
max_concurrent_requests: None,
})
.expect("client");
let (left, right) = tokio::join!(client.health(), client.health());
assert_eq!(left.expect("left").status, "ok");
assert_eq!(right.expect("right").status, "ok");
assert_eq!(hits.load(Ordering::SeqCst), 1);
let third = client.health().await.expect("third");
assert_eq!(third.status, "ok");
assert_eq!(hits.load(Ordering::SeqCst), 1);
}
// A repeated GET is cached (one server read), and a PATCH through the client clears the
// cache, so the next GET reaches the server again (two reads).
#[tokio::test]
async fn mutating_requests_invalidate_read_cache() {
let reads = Arc::new(AtomicUsize::new(0));
let router = {
let reads_get = reads.clone();
Router::new()
.route(
"/v1/widgets/alpha",
get(move || {
let reads = reads_get.clone();
async move {
reads.fetch_add(1, Ordering::SeqCst);
Json(json!({ "name": "before" }))
}
})
.patch(|| async { Json(json!({ "updated": 1 })) }),
)
};
let base_url = start_server(router).await;
let client = K2DbApiClient::new(K2DbApiClientOptions {
base_url,
api_key: Some("demo.secret".to_owned()),
headers: Vec::new(),
read_cache_ttl_ms: Some(1_000),
max_concurrent_requests: None,
})
.expect("client");
let options = RequestOptions {
scope: Some("owner:demo".to_owned()),
..RequestOptions::default()
};
let _: Value = client.get_by_id("widgets", "alpha", &options).await.expect("first read");
let _: Value = client.get_by_id("widgets", "alpha", &options).await.expect("cached read");
assert_eq!(reads.load(Ordering::SeqCst), 1);
client
.patch_by_id("widgets", "alpha", &json!({ "name": "after" }), &options)
.await
.expect("patch");
let _: Value = client.get_by_id("widgets", "alpha", &options).await.expect("after patch read");
assert_eq!(reads.load(Ordering::SeqCst), 2);
}
// With `max_concurrent_requests: Some(1)` and caching disabled, two concurrent GETs to
// different ids must never overlap on the server (peak concurrency observed is 1).
#[tokio::test]
async fn concurrency_limit_serializes_requests() {
let active = Arc::new(AtomicUsize::new(0));
let max_seen = Arc::new(AtomicUsize::new(0));
let router = {
let active = active.clone();
let max_seen = max_seen.clone();
Router::new().route(
"/v1/widgets/{id}",
get(move |Path(_id): Path<String>| {
let active = active.clone();
let max_seen = max_seen.clone();
async move {
let current = active.fetch_add(1, Ordering::SeqCst) + 1;
let _ = max_seen.fetch_max(current, Ordering::SeqCst);
// Holds the handler open so an unthrottled second request would overlap.
sleep(Duration::from_millis(40)).await;
active.fetch_sub(1, Ordering::SeqCst);
Json(json!({ "ok": true }))
}
}),
)
};
let base_url = start_server(router).await;
let client = K2DbApiClient::new(K2DbApiClientOptions {
base_url,
api_key: Some("demo.secret".to_owned()),
headers: Vec::new(),
read_cache_ttl_ms: Some(0),
max_concurrent_requests: Some(1),
})
.expect("client");
let options = RequestOptions {
scope: Some("owner:demo".to_owned()),
..RequestOptions::default()
};
let (left, right) = tokio::join!(
client.get_by_id::<Value>("widgets", "a", &options),
client.get_by_id::<Value>("widgets", "b", &options)
);
left.expect("left request");
right.expect("right request");
assert_eq!(max_seen.load(Ordering::SeqCst), 1);
}
// `patch_collection` must hit `PATCH /v1/{collection}` and decode the server's
// `{ "updated": n }` response.
#[tokio::test]
async fn patch_collection_calls_collection_patch_endpoint() {
let router = Router::new().route(
"/v1/widgets",
patch(|| async { Json(json!({ "updated": 3 })) }),
);
let base_url = start_server(router).await;
let client = K2DbApiClient::new(K2DbApiClientOptions {
base_url,
api_key: Some("demo.secret".to_owned()),
headers: Vec::new(),
read_cache_ttl_ms: Some(0),
max_concurrent_requests: None,
})
.expect("client");
let result = client
.patch_collection(
"widgets",
&PatchCollectionRequest {
criteria: json!({ "kind": "demo" }),
values: json!({ "name": "updated" }),
},
&RequestOptions {
scope: Some("owner:demo".to_owned()),
..RequestOptions::default()
},
)
.await
.expect("patch collection");
assert_eq!(result.updated, 3);
}
}