use crate::types::{CapacityResponse, TokenCapacity, VAEA_API_URL};
use crate::local_builder::update_registry_from_capacity;
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
/// Keeps a locally cached copy of the capacity snapshot served at
/// `{api_url}/v1/capacity`, refreshed by a background task once started.
pub struct WarmCache {
    /// Base URL of the capacity API; `/v1/capacity` is appended per request.
    api_url: String,
    /// Polling period for the background task, in milliseconds.
    refresh_ms: u64,
    /// State shared between this handle and the background polling task.
    inner: Arc<Mutex<WarmCacheInner>>,
    /// Handle of the background polling task; `None` until `start` is called
    /// and again after `stop` takes it.
    task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
/// Mutable state guarded by `WarmCache::inner`.
struct WarmCacheInner {
    /// Most recent successfully decoded snapshot; `None` until the first
    /// successful fetch.
    capacity: Option<CapacityResponse>,
    /// Callbacks invoked with each newly fetched snapshot. They are called
    /// while this struct's mutex is held.
    listeners: Vec<Box<dyn Fn(&CapacityResponse) + Send + 'static>>,
}
impl WarmCache {
pub fn new(api_url: Option<&str>, refresh_ms: Option<u64>) -> Self {
Self {
api_url: api_url.unwrap_or(VAEA_API_URL).to_string(),
refresh_ms: refresh_ms.unwrap_or(2000),
inner: Arc::new(Mutex::new(WarmCacheInner {
capacity: None,
listeners: Vec::new(),
})),
task_handle: Arc::new(Mutex::new(None)),
}
}
pub async fn start(&self) {
self.refresh().await;
let api_url = self.api_url.clone();
let refresh_ms = self.refresh_ms;
let inner = Arc::clone(&self.inner);
let handle = tokio::spawn(async move {
let client = reqwest::Client::new();
let mut interval = tokio::time::interval(std::time::Duration::from_millis(refresh_ms));
loop {
interval.tick().await;
if let Ok(res) = client.get(&format!("{}/v1/capacity", api_url)).send().await {
if let Ok(capacity) = res.json::<CapacityResponse>().await {
update_registry_from_capacity(&capacity.tokens);
let mut guard = inner.lock().unwrap();
guard.capacity = Some(capacity.clone());
for listener in &guard.listeners {
listener(&capacity);
}
}
}
}
});
*self.task_handle.lock().unwrap() = Some(handle);
}
pub fn stop(&self) {
if let Some(handle) = self.task_handle.lock().unwrap().take() {
handle.abort();
}
}
pub fn on_update<F: Fn(&CapacityResponse) + Send + 'static>(&self, handler: F) {
self.inner.lock().unwrap().listeners.push(Box::new(handler));
}
pub fn get_capacity(&self) -> Option<CapacityResponse> {
self.inner.lock().unwrap().capacity.clone()
}
pub fn get_token_capacity(&self, symbol: &str) -> Option<TokenCapacity> {
let guard = self.inner.lock().unwrap();
guard.capacity.as_ref()?.tokens.iter()
.find(|t| t.symbol.eq_ignore_ascii_case(symbol))
.cloned()
}
pub fn is_warm(&self) -> bool {
self.inner.lock().unwrap().capacity.is_some()
}
async fn refresh(&self) {
let client = reqwest::Client::new();
if let Ok(res) = client.get(&format!("{}/v1/capacity", self.api_url)).send().await {
if let Ok(capacity) = res.json::<CapacityResponse>().await {
update_registry_from_capacity(&capacity.tokens);
let mut guard = self.inner.lock().unwrap();
guard.capacity = Some(capacity.clone());
for listener in &guard.listeners {
listener(&capacity);
}
}
}
}
}