/// Generates a batch response type together with its small helper API.
///
/// Invocation shape: `[attributes] Name => field: ValueType`.
///
/// The generated struct holds two maps keyed by symbol:
/// * `field`  - the successfully fetched values (`HashMap<String, ValueType>`)
/// * `errors` - an error message for every symbol that failed
///
/// It derives `Debug`, `Clone`, `serde::Serialize` and `serde::Deserialize`,
/// is serialized with camelCase keys, and is marked `#[non_exhaustive]`.
/// Any attributes (including doc comments) written before `Name` are
/// forwarded onto the struct.
macro_rules! define_batch_response {
    (
        $(#[$attr:meta])*
        $response:ident => $payload:ident : $item:ty
    ) => {
        $(#[$attr])*
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
        #[serde(rename_all = "camelCase")]
        #[non_exhaustive]
        pub struct $response {
            /// Successfully fetched data, keyed by symbol
            pub $payload: std::collections::HashMap<String, $item>,
            /// Symbols that failed to fetch, with error messages
            pub errors: std::collections::HashMap<String, String>,
        }

        impl $response {
            /// Creates an empty response whose two maps are both pre-sized
            /// for `capacity` entries.
            pub(crate) fn with_capacity(capacity: usize) -> Self {
                Self {
                    errors: std::collections::HashMap::with_capacity(capacity),
                    $payload: std::collections::HashMap::with_capacity(capacity),
                }
            }

            /// Number of successfully fetched items
            pub fn success_count(&self) -> usize {
                self.$payload.len()
            }

            /// Number of failed symbols
            pub fn error_count(&self) -> usize {
                self.errors.len()
            }

            /// Check if all symbols were successful
            pub fn all_successful(&self) -> bool {
                self.error_count() == 0
            }
        }
    };
}
/// Expands to the body of a cache-aware, concurrency-limited batch fetch.
///
/// The expansion uses `.await` and `return Ok(..)`, so it must be placed
/// inside an `async` function whose success type is `$resp_ty`. The block's
/// final expression is `Ok(response)`.
///
/// # Invocation
///
/// ```ignore
/// batch_fetch_cached!(self;
///     cache: <cache_field>,
///     guard: map(<guard_field>, <guard_key_expr>),   // or: guard: simple(<guard_field>),
///     key: |sym| <cache key expression>,
///     response: <ResponseType>.<payload_field>,
///     fetch: |client, symbol| <expression yielding crate::error::Result<_>>,
/// )
/// ```
///
/// * `guard: map(field, key)` obtains a lock via
///   `Self::get_fetch_guard(&self.field, key).await` and then locks it.
/// * `guard: simple(field)` locks `self.field` directly.
///
/// # Requirements on the receiver
///
/// The expansion reads `symbols` (iterated as `&Arc<str>`), `client`
/// (cloned with `Arc::clone`), `max_concurrency` and `cache_ttl`, and calls
/// the methods `all_cached` and `cache_insert`. `$resp_ty` must provide
/// `with_capacity` plus the `$resp_field` and `errors` maps.
///
/// # Flow
///
/// 1. Under a read lock, return immediately if every symbol is cached.
/// 2. Acquire the fetch guard, then repeat the cache check (another task may
///    have filled the cache while this one waited for the guard).
/// 3. Run the fetch expression for every symbol, at most `max_concurrency`
///    at a time.
/// 4. Record failures as `symbol -> error.to_string()` in `errors`.
/// 5. If `cache_ttl` is set, store successful values in the cache.
/// 6. Move successful values into the response.
///
/// NOTE(review): `.buffer_unordered(..)` is a `futures::StreamExt` method, so
/// call sites presumably need that trait in scope - confirm.
/// NOTE(review): in `futures` 0.3 a limit of `0` never polls any future;
/// confirm `max_concurrency` is validated to be at least 1 by its setter.
macro_rules! batch_fetch_cached {
    // ------------------------------------------------------------------
    // Internal rules come first so an `@...` invocation is matched here and
    // is never offered to the `$self:expr` matcher of the public rules.
    // ------------------------------------------------------------------

    // Internal: if every requested symbol is present in the cache, build the
    // response from cached values and `return` it from the enclosing
    // function. Otherwise expands to a block that does nothing. The read lock
    // is released at the end of the block in either case.
    (@serve_from_cache $self:expr;
        cache: $cache_field:ident,
        key_fn: $key_fn:ident,
        response: $resp_ty:ident . $resp_field:ident
    ) => {
        {
            let cache = $self.$cache_field.read().await;
            if $self.all_cached(&cache, $self.symbols.iter().map(&$key_fn)) {
                let mut response = $resp_ty::with_capacity($self.symbols.len());
                for symbol in &$self.symbols {
                    if let Some(entry) = cache.get(&$key_fn(symbol)) {
                        response.$resp_field.insert(symbol.to_string(), entry.value.clone());
                    }
                }
                return Ok(response);
            }
        }
    };

    // Internal: the shared implementation. `acquire_guard` carries the
    // statements produced by one of the public rules below.
    (@impl $self:expr;
        cache: $cache_field:ident,
        acquire_guard: { $($guard_code:tt)* },
        key: |$ksym:ident| $key_expr:expr,
        response: $resp_ty:ident . $resp_field:ident,
        fetch: |$client:ident, $symbol:ident| $fetch_expr:expr $(,)?
    ) => {{
        let cache_key_fn = |$ksym: &std::sync::Arc<str>| $key_expr;

        // Fast path: no guard needed when everything is already cached.
        batch_fetch_cached!(@serve_from_cache $self;
            cache: $cache_field,
            key_fn: cache_key_fn,
            response: $resp_ty . $resp_field
        );

        // The guard bindings introduced here stay alive until the end of
        // this block, i.e. across the fetch and the cache write below.
        $($guard_code)*

        // Re-check after acquiring the guard: the cache may have been
        // populated while this task was waiting.
        batch_fetch_cached!(@serve_from_cache $self;
            cache: $cache_field,
            key_fn: cache_key_fn,
            response: $resp_ty . $resp_field
        );

        // One future per symbol; each owns its own handle to the client and
        // to the symbol so it can be driven independently.
        let pending: Vec<_> = $self.symbols.iter().map(|sym_ref| {
            #[allow(unused_variables)]
            let $client = std::sync::Arc::clone(&$self.client);
            let $symbol = std::sync::Arc::clone(sym_ref);
            async move {
                let result: crate::error::Result<_> = (async { $fetch_expr }).await;
                ($symbol, result)
            }
        }).collect();

        let results: Vec<_> = futures::stream::iter(pending)
            .buffer_unordered($self.max_concurrency)
            .collect()
            .await;

        let mut response = $resp_ty::with_capacity($self.symbols.len());

        // Split outcomes: failures go straight into the response, successes
        // are kept aside so they can be cached before being moved.
        let mut parsed = Vec::with_capacity(results.len());
        for (sym, result) in results {
            match result {
                Ok(value) => parsed.push((sym, value)),
                Err(e) => { response.errors.insert(sym.to_string(), e.to_string()); }
            }
        }

        // Caching is skipped entirely when no TTL is configured. The write
        // lock is confined to this `if` block.
        if $self.cache_ttl.is_some() {
            let mut cache = $self.$cache_field.write().await;
            for (sym, value) in &parsed {
                $self.cache_insert(&mut cache, cache_key_fn(sym), value.clone());
            }
        }

        for (sym, value) in parsed {
            response.$resp_field.insert(sym.to_string(), value);
        }

        Ok(response)
    }};

    // ------------------------------------------------------------------
    // Public rules: translate the `guard:` form into guard-acquisition
    // statements and forward everything else to `@impl` untouched.
    // ------------------------------------------------------------------

    // Keyed guard: look up the guard for `$guard_key`, then lock it.
    ($self:expr;
        cache: $cache_field:ident,
        guard: map($guard_field:ident, $guard_key:expr),
        $($rest:tt)*
    ) => {
        batch_fetch_cached!(@impl $self;
            cache: $cache_field,
            acquire_guard: {
                let __fetch_guard = Self::get_fetch_guard(&$self.$guard_field, $guard_key).await;
                let _guard = __fetch_guard.lock().await;
            },
            $($rest)*
        )
    };

    // Single guard: lock the given field directly.
    ($self:expr;
        cache: $cache_field:ident,
        guard: simple($guard_field:ident),
        $($rest:tt)*
    ) => {
        batch_fetch_cached!(@impl $self;
            cache: $cache_field,
            acquire_guard: {
                let _guard = $self.$guard_field.lock().await;
            },
            $($rest)*
        )
    };
}
pub(crate) use batch_fetch_cached;
pub(crate) use define_batch_response;