use ruma::time::SystemTime;
use serde::{Deserialize, Serialize};
/// A value paired with the time it was last fetched, so callers can tell when it is stale.
///
/// When (de)serialized, the fields of `data` are flattened into the same object as
/// `last_fetch_ts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtlValue<T> {
    /// The wrapped value.
    #[serde(flatten)]
    data: T,
    /// Milliseconds since the Unix epoch at which `data` was fetched, or `None` if the value
    /// should never be reported as expired.
    ///
    /// If the field is absent when deserializing, it defaults to `Some(0.0)` (the epoch), so
    /// values stored without a timestamp are treated as expired.
    #[serde(default = "default_timestamp")]
    last_fetch_ts: Option<f64>,
}
impl<T> TtlValue<T> {
    /// How long a fetched value stays fresh, in milliseconds (one day).
    pub const STALE_THRESHOLD: f64 = 86_400_000.0;

    /// Wraps `data` and records the current time as its fetch time.
    pub fn new(data: T) -> Self {
        let last_fetch_ts = Some(now_timestamp_ms());
        Self { data, last_fetch_ts }
    }

    /// Wraps `data` without a fetch time.
    ///
    /// [`has_expired`](Self::has_expired) never reports such a value as stale unless
    /// [`expire`](Self::expire) is called on it.
    pub fn without_expiry(data: T) -> Self {
        Self { last_fetch_ts: None, data }
    }

    /// Borrows the wrapped value and keeps the same fetch time.
    pub fn as_ref(&self) -> TtlValue<&T> {
        let Self { data, last_fetch_ts } = self;
        TtlValue { data, last_fetch_ts: *last_fetch_ts }
    }

    /// Transforms the wrapped value with `f` and carries the fetch time over unchanged.
    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> TtlValue<U> {
        let Self { data, last_fetch_ts } = self;
        TtlValue { data: f(data), last_fetch_ts }
    }

    /// Returns `true` once at least [`STALE_THRESHOLD`](Self::STALE_THRESHOLD) milliseconds
    /// have passed since the fetch time.
    ///
    /// Values without a fetch time never expire.
    pub fn has_expired(&self) -> bool {
        match self.last_fetch_ts {
            // The clock is only read when there is a timestamp to compare against.
            Some(fetched_at) => now_timestamp_ms() - fetched_at >= Self::STALE_THRESHOLD,
            None => false,
        }
    }

    /// Forces the value to be considered stale by setting its fetch time to the Unix epoch.
    ///
    /// This also applies to values created with [`without_expiry`](Self::without_expiry).
    pub fn expire(&mut self) {
        self.last_fetch_ts = Some(0.0);
    }

    /// Returns a reference to the wrapped value.
    pub fn data(&self) -> &T {
        &self.data
    }

    /// Consumes the wrapper and returns the wrapped value, discarding the fetch time.
    pub fn into_data(self) -> T {
        let Self { data, .. } = self;
        data
    }
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn now_timestamp_ms() -> f64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("System clock was before 1970.");
    since_epoch.as_secs_f64() * 1000.0
}
/// Serde default for `TtlValue::last_fetch_ts` when the field is missing.
///
/// Returns the Unix epoch (`0.0` ms), so such values are considered already expired.
fn default_timestamp() -> Option<f64> {
    let epoch_ms: f64 = 0.0;
    Some(epoch_ms)
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use super::{TtlValue, now_timestamp_ms};

    /// Asserts that `actual` holds a timestamp within a small tolerance of `expected`.
    ///
    /// The absolute difference is what matters: a signed difference would accept any
    /// `actual` smaller than `expected`.
    fn assert_ts_close(actual: Option<f64>, expected: f64) {
        let actual = actual.expect("expected a timestamp, found None");
        assert!(
            (actual - expected).abs() < 0.0001,
            "timestamp {actual} is not close to {expected}"
        );
    }

    #[test]
    fn test_ttl_value_expiry() {
        let ttl_value = TtlValue {
            data: (),
            last_fetch_ts: Some(now_timestamp_ms() - TtlValue::<()>::STALE_THRESHOLD - 1.0),
        };
        assert!(ttl_value.has_expired());

        let ttl_value = TtlValue::new(());
        assert!(!ttl_value.has_expired());

        let ttl_value = TtlValue::without_expiry(());
        assert!(!ttl_value.has_expired());
    }

    #[test]
    fn test_ttl_value_expire() {
        // Forcing expiry works on fresh values and on values created without expiry.
        let mut ttl_value = TtlValue::new(());
        ttl_value.expire();
        assert!(ttl_value.has_expired());

        let mut ttl_value = TtlValue::without_expiry(());
        ttl_value.expire();
        assert!(ttl_value.has_expired());
    }

    #[test]
    fn test_ttl_value_serialize_roundtrip() {
        #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct Data {
            foo: String,
        }

        let data = Data { foo: "bar".to_owned() };

        // A timestamped value flattens `data` next to `last_fetch_ts`.
        let ttl_value = TtlValue { data: data.clone(), last_fetch_ts: Some(1000.0) };
        let json = json!({
            "foo": "bar",
            "last_fetch_ts": 1000.0,
        });
        assert_eq!(serde_json::to_value(&ttl_value).unwrap(), json);

        let deserialized = serde_json::from_value::<TtlValue<Data>>(json).unwrap();
        assert_eq!(deserialized.data, data);
        assert_ts_close(deserialized.last_fetch_ts, 1000.0);

        // A missing timestamp falls back to the epoch, so the value is already expired.
        let json = json!({
            "foo": "bar",
        });
        let deserialized = serde_json::from_value::<TtlValue<Data>>(json).unwrap();
        assert_eq!(deserialized.data, data);
        assert_ts_close(deserialized.last_fetch_ts, 0.0);
        assert!(deserialized.has_expired());

        // A value without expiry serializes its timestamp as `null` and round-trips as `None`.
        let ttl_value = TtlValue::without_expiry(data.clone());
        let json = json!({
            "foo": "bar",
            "last_fetch_ts": null,
        });
        assert_eq!(serde_json::to_value(&ttl_value).unwrap(), json);

        let deserialized = serde_json::from_value::<TtlValue<Data>>(json).unwrap();
        assert_eq!(deserialized.data, data);
        assert_eq!(deserialized.last_fetch_ts, None);
        assert!(!deserialized.has_expired());
    }
}