use ahash::AHashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc};
use tracing::{debug, info};
/// Tunable limits and defaults for the live-query subsystem.
#[derive(Debug, Clone)]
pub struct LiveQueryConfig {
    /// Default minimum interval between pushed updates per subscription, in ms.
    pub default_throttle_ms: u32,
    /// Maximum number of concurrent live queries per client connection.
    pub max_per_connection: u32,
    /// Default time-to-live for a live query, in seconds. 0 means "never
    /// expires" (see `ActiveLiveQuery::is_expired`).
    pub default_ttl_seconds: u32,
    /// Global cap on concurrently registered live queries across all connections.
    pub max_total: usize,
    /// Whether hash-based result diffing is enabled.
    /// NOTE(review): not consumed in this file — presumably gates the
    /// `HashDiff` strategy; confirm against the executor.
    pub enable_hash_diff: bool,
    /// Capacity of the broadcast channel that fans out invalidation events.
    pub invalidation_buffer_size: usize,
}
impl Default for LiveQueryConfig {
fn default() -> Self {
Self {
default_throttle_ms: 100,
max_per_connection: 10,
default_ttl_seconds: 0,
max_total: 10000,
enable_hash_diff: true,
invalidation_buffer_size: 1000,
}
}
}
impl LiveQueryConfig {
    /// Conservative preset for production: higher throttle, tighter limits,
    /// and a one-hour default TTL.
    pub fn production() -> Self {
        Self {
            default_throttle_ms: 200,
            max_per_connection: 5,
            default_ttl_seconds: 3600,
            max_total: 5000,
            invalidation_buffer_size: 500,
            ..Self::default()
        }
    }

    /// Generous preset for local development: low throttle, loose limits,
    /// no expiry.
    pub fn development() -> Self {
        Self {
            default_throttle_ms: 50,
            max_per_connection: 50,
            max_total: 50_000,
            invalidation_buffer_size: 2000,
            ..Self::default()
        }
    }
}
/// Strategy used to keep a live query's result fresh.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum LiveQueryStrategy {
    /// Driven by invalidation events (the default).
    #[default]
    Invalidation,
    /// Interval-based polling — presumably paced by `poll_interval_ms`; the
    /// polling loop itself lives outside this file.
    Polling,
    /// Hash-based result diffing — see `last_hash` / `enable_hash_diff`.
    HashDiff,
}
impl std::str::FromStr for LiveQueryStrategy {
    type Err = ();

    /// Parses a strategy name case-insensitively; both `"HASH_DIFF"` and
    /// `"HASHDIFF"` map to [`LiveQueryStrategy::HashDiff`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let upper = s.to_uppercase();
        if upper == "INVALIDATION" {
            Ok(Self::Invalidation)
        } else if upper == "POLLING" {
            Ok(Self::Polling)
        } else if upper == "HASH_DIFF" || upper == "HASHDIFF" {
            Ok(Self::HashDiff)
        } else {
            Err(())
        }
    }
}
/// A data-change notification broadcast to live queries.
#[derive(Debug, Clone)]
pub struct InvalidationEvent {
    /// Entity type that changed, e.g. `"User"`.
    pub type_name: String,
    /// What happened, e.g. `"update"`.
    pub action: String,
    /// Specific entity affected, when known.
    pub entity_id: Option<String>,
    /// When the event was created (set by the constructors).
    pub timestamp: Instant,
}
impl InvalidationEvent {
pub fn new(type_name: impl Into<String>, action: impl Into<String>) -> Self {
Self {
type_name: type_name.into(),
action: action.into(),
entity_id: None,
timestamp: Instant::now(),
}
}
pub fn for_entity(
type_name: impl Into<String>,
action: impl Into<String>,
entity_id: impl Into<String>,
) -> Self {
Self {
type_name: type_name.into(),
action: action.into(),
entity_id: Some(entity_id.into()),
timestamp: Instant::now(),
}
}
pub fn matches(&self, trigger: &str) -> bool {
let parts: Vec<&str> = trigger.split('.').collect();
if parts.len() != 2 {
return false;
}
let (trigger_type, trigger_action) = (parts[0], parts[1]);
let type_matches = trigger_type == "*" || trigger_type == self.type_name;
let action_matches = trigger_action == "*" || trigger_action == self.action;
type_matches && action_matches
}
}
/// A registered live query and the state needed to throttle/expire it.
#[derive(Debug, Clone)]
pub struct ActiveLiveQuery {
    /// Unique subscription id (see `generate_subscription_id`).
    pub id: String,
    /// GraphQL operation name.
    pub operation_name: String,
    /// The query text.
    pub query: String,
    /// Operation variables, if any; included in `cache_key`.
    pub variables: Option<serde_json::Value>,
    /// Trigger patterns (`"Type.action"`, `*` wildcards allowed) that cause
    /// re-evaluation; matched in `should_update`.
    pub triggers: Vec<String>,
    /// Minimum interval between pushed updates, in ms.
    pub throttle_ms: u32,
    /// Time-to-live in seconds; 0 disables expiry.
    pub ttl_seconds: u32,
    /// Freshness strategy for this query.
    pub strategy: LiveQueryStrategy,
    /// Poll interval for the `Polling` strategy, in ms.
    pub poll_interval_ms: u32,
    /// Last result hash — presumably used by the `HashDiff` strategy outside
    /// this file.
    pub last_hash: Option<String>,
    /// When an update was last pushed (bumped by `LiveQueryStore::send_update`).
    pub last_update: Instant,
    /// Registration time; basis for TTL expiry.
    pub created_at: Instant,
    /// Owning connection; used for per-connection limits and cleanup.
    pub connection_id: String,
}
impl ActiveLiveQuery {
    /// True when any of this query's trigger patterns matches `event`.
    pub fn should_update(&self, event: &InvalidationEvent) -> bool {
        self.triggers.iter().any(|t| event.matches(t))
    }

    /// True once at least `throttle_ms` has passed since the last push.
    pub fn throttle_elapsed(&self) -> bool {
        let throttle = Duration::from_millis(u64::from(self.throttle_ms));
        self.last_update.elapsed() >= throttle
    }

    /// True when the TTL has elapsed; a TTL of zero means "never expires".
    pub fn is_expired(&self) -> bool {
        match self.ttl_seconds {
            0 => false,
            ttl => self.created_at.elapsed() >= Duration::from_secs(u64::from(ttl)),
        }
    }

    /// SHA-256 over the query text plus (when present) the serialized
    /// variables, as a lowercase hex string. Identical query+variables pairs
    /// therefore share a cache key.
    pub fn cache_key(&self) -> String {
        let mut hasher = Sha256::new();
        hasher.update(self.query.as_bytes());
        if let Some(vars) = self.variables.as_ref() {
            hasher.update(vars.to_string().as_bytes());
        }
        format!("{:x}", hasher.finalize())
    }
}
/// One update pushed to a live-query subscriber.
#[derive(Debug, Clone, Serialize)]
pub struct LiveQueryUpdate {
    /// Subscription id this update belongs to.
    pub id: String,
    /// The (re-)executed query result.
    pub data: serde_json::Value,
    /// True for the first result after registration.
    pub is_initial: bool,
    /// Monotonically increasing revision, global across all subscriptions
    /// (taken from `LiveQueryStore::update_counter`).
    pub revision: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
    /// Paths of fields that changed, when field-level diffing was performed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub changed_fields: Option<Vec<String>>,
    /// Marks updates that were coalesced into a batch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub batched: Option<bool>,
    /// Unix timestamp (seconds) when the update was produced.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<u64>,
}
/// HTTP-style cache policy attached to an update.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheControl {
    /// Freshness lifetime in seconds; 0 means uncacheable.
    pub max_age: u32,
    /// Whether shared caches may store the response.
    pub public: bool,
    /// Whether stale entries must be revalidated before reuse.
    pub must_revalidate: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub etag: Option<String>,
}
impl Default for CacheControl {
fn default() -> Self {
Self {
max_age: 0, public: false,
must_revalidate: true,
etag: None,
}
}
}
/// A single field-level difference reported by `detect_field_changes`.
#[derive(Debug, Clone, Serialize)]
pub struct FieldChange {
    /// Dotted/indexed path, e.g. `"user.tags[0]"`.
    pub field_path: String,
    /// Previous value; `None` when the field was newly added.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub old_value: Option<serde_json::Value>,
    /// New value; `Null` when the field was removed.
    pub new_value: serde_json::Value,
}
/// Settings for coalescing bursts of invalidation events.
/// NOTE(review): not consumed in this file — the batching loop presumably
/// lives elsewhere; confirm field semantics against it.
#[derive(Debug, Clone)]
pub struct BatchInvalidationConfig {
    /// Master switch for batching.
    pub enabled: bool,
    /// Quiet period before a batch is flushed, in ms.
    pub debounce_ms: u64,
    /// Flush once this many events have accumulated.
    pub max_batch_size: usize,
    /// Upper bound on how long an event may be held back, in ms.
    pub max_wait_ms: u64,
}
impl Default for BatchInvalidationConfig {
fn default() -> Self {
Self {
enabled: true,
debounce_ms: 50, max_batch_size: 100,
max_wait_ms: 500, }
}
}
/// How frequently a piece of data is expected to change; mapped to a
/// `max_age` by `generate_cache_control` (VeryHigh → 0s … VeryLow → 3600s).
#[derive(Debug, Clone, Copy)]
pub enum DataVolatility {
    VeryHigh,
    High,
    Medium,
    Low,
    VeryLow,
}
/// Snapshot of store activity, produced by `LiveQueryStore::stats`.
#[derive(Debug, Clone, Default)]
pub struct LiveQueryStats {
    /// Live queries currently registered.
    pub active_count: usize,
    /// Total updates pushed since the store was created.
    pub total_updates: u64,
    /// Total invalidation events processed.
    pub total_invalidations: u64,
    /// NOTE(review): never populated — `stats()` leaves this at 0.
    pub expired_count: u64,
    /// NOTE(review): never populated — `stats()` leaves this at 0.
    pub cancelled_count: u64,
}
/// Central registry of active live queries.
///
/// Besides the primary `queries` map it maintains three secondary indexes
/// (by trigger pattern, by connection, and per-subscription update channels)
/// that must be kept in sync; all mutation goes through the store's methods.
pub struct LiveQueryStore {
    config: LiveQueryConfig,
    /// subscription id -> query.
    queries: RwLock<AHashMap<String, ActiveLiveQuery>>,
    /// trigger pattern ("Type.action") -> subscription ids using it.
    trigger_index: RwLock<AHashMap<String, HashSet<String>>>,
    /// connection id -> subscription ids owned by that connection.
    connection_index: RwLock<AHashMap<String, HashSet<String>>>,
    /// subscription id -> channel used to push updates to the client.
    update_senders: RwLock<AHashMap<String, mpsc::Sender<LiveQueryUpdate>>>,
    /// Fan-out of invalidation events to interested listeners.
    invalidation_tx: broadcast::Sender<InvalidationEvent>,
    // Never written after construction; `stats()` builds a fresh snapshot
    // from the atomic counters instead.
    #[allow(dead_code)]
    stats: LiveQueryStats,
    /// Global monotonic revision counter for pushed updates.
    update_counter: AtomicU64,
    /// Count of invalidation events processed.
    invalidation_counter: AtomicU64,
}
impl LiveQueryStore {
    /// Creates a store with `LiveQueryConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(LiveQueryConfig::default())
    }

    /// Creates a store with an explicit configuration. The invalidation
    /// broadcast channel is sized from `config.invalidation_buffer_size`.
    pub fn with_config(config: LiveQueryConfig) -> Self {
        let (invalidation_tx, _) = broadcast::channel(config.invalidation_buffer_size);
        Self {
            config,
            queries: RwLock::new(AHashMap::new()),
            trigger_index: RwLock::new(AHashMap::new()),
            connection_index: RwLock::new(AHashMap::new()),
            update_senders: RwLock::new(AHashMap::new()),
            invalidation_tx,
            stats: LiveQueryStats::default(),
            update_counter: AtomicU64::new(0),
            invalidation_counter: AtomicU64::new(0),
        }
    }

    /// Registers a live query together with the channel used to push updates.
    ///
    /// Enforces the global (`max_total`) and per-connection
    /// (`max_per_connection`) limits. The limit checks and the insertion run
    /// under the same write locks, so concurrent registrations cannot race
    /// past the limits (the previous implementation checked under read locks,
    /// released them, then inserted — a TOCTOU window).
    ///
    /// # Errors
    /// `TooManyQueries` or `TooManyQueriesPerConnection` when a limit is hit.
    pub fn register(
        &self,
        query: ActiveLiveQuery,
        sender: mpsc::Sender<LiveQueryUpdate>,
    ) -> Result<(), LiveQueryError> {
        let id = query.id.clone();
        let connection_id = query.connection_id.clone();
        let triggers = query.triggers.clone();
        {
            // Lock order: queries, then connection_index. No other method
            // holds these two locks simultaneously, so no cycle is possible.
            let mut queries = self.queries.write();
            if queries.len() >= self.config.max_total {
                return Err(LiveQueryError::TooManyQueries {
                    current: queries.len(),
                    max: self.config.max_total,
                });
            }
            let mut connection_index = self.connection_index.write();
            if let Some(conn_queries) = connection_index.get(&connection_id) {
                if conn_queries.len() >= self.config.max_per_connection as usize {
                    return Err(LiveQueryError::TooManyQueriesPerConnection {
                        current: conn_queries.len(),
                        max: self.config.max_per_connection as usize,
                    });
                }
            }
            queries.insert(id.clone(), query);
            connection_index
                .entry(connection_id)
                .or_default()
                .insert(id.clone());
        }
        {
            let mut trigger_index = self.trigger_index.write();
            for trigger in &triggers {
                trigger_index
                    .entry(trigger.clone())
                    .or_default()
                    .insert(id.clone());
            }
        }
        self.update_senders.write().insert(id.clone(), sender);
        debug!(subscription_id = %id, "Registered live query");
        Ok(())
    }

    /// Removes a live query and all its index entries; returns the removed
    /// query, or `None` if the id was unknown. Each lock is held one at a
    /// time, so partial states are observable but never inconsistent.
    pub fn unregister(&self, subscription_id: &str) -> Option<ActiveLiveQuery> {
        let query = self.queries.write().remove(subscription_id)?;
        {
            let mut trigger_index = self.trigger_index.write();
            for trigger in &query.triggers {
                if let Some(ids) = trigger_index.get_mut(trigger) {
                    ids.remove(subscription_id);
                    // Drop empty sets so the index doesn't accumulate
                    // stale trigger keys.
                    if ids.is_empty() {
                        trigger_index.remove(trigger);
                    }
                }
            }
        }
        {
            let mut connection_index = self.connection_index.write();
            if let Some(ids) = connection_index.get_mut(&query.connection_id) {
                ids.remove(subscription_id);
                if ids.is_empty() {
                    connection_index.remove(&query.connection_id);
                }
            }
        }
        self.update_senders.write().remove(subscription_id);
        debug!(subscription_id = %subscription_id, "Unregistered live query");
        Some(query)
    }

    /// Removes every live query owned by `connection_id` (e.g. on
    /// disconnect) and returns them.
    pub fn unregister_connection(&self, connection_id: &str) -> Vec<ActiveLiveQuery> {
        let ids: Vec<String> = {
            let connection_index = self.connection_index.read();
            connection_index
                .get(connection_id)
                .map(|ids| ids.iter().cloned().collect())
                .unwrap_or_default()
        };
        ids.iter().filter_map(|id| self.unregister(id)).collect()
    }

    /// Broadcasts an invalidation event and returns how many registered
    /// queries have a trigger matching it (exact pattern plus the three
    /// wildcard forms). The count is informational; delivery happens through
    /// the broadcast channel, and send errors (no subscribers) are ignored.
    pub fn invalidate(&self, event: InvalidationEvent) -> usize {
        self.invalidation_counter.fetch_add(1, Ordering::Relaxed);
        let mut affected = HashSet::new();
        let exact = format!("{}.{}", event.type_name, event.action);
        let any_type = format!("*.{}", event.action);
        let any_action = format!("{}.*", event.type_name);
        let any = "*.*".to_string();
        let trigger_index = self.trigger_index.read();
        for pattern in [exact, any_type, any_action, any] {
            if let Some(ids) = trigger_index.get(&pattern) {
                affected.extend(ids.iter().cloned());
            }
        }
        drop(trigger_index);
        let affected_count = affected.len();
        debug!(
            type_name = %event.type_name,
            action = %event.action,
            affected_count = affected_count,
            "Broadcasting invalidation event"
        );
        let _ = self.invalidation_tx.send(event);
        affected_count
    }

    /// Pushes a result to one subscriber, stamping it with the next global
    /// revision and the current Unix time, and bumps the query's
    /// `last_update` (before the send, so throttling applies even if the
    /// send later fails).
    ///
    /// # Errors
    /// `SubscriptionNotFound` for an unknown id, `ChannelClosed` when the
    /// receiver has been dropped.
    pub async fn send_update(
        &self,
        subscription_id: &str,
        data: serde_json::Value,
        is_initial: bool,
    ) -> Result<(), LiveQueryError> {
        // Clone the sender out of the map so no lock is held across .await.
        let sender = {
            let senders = self.update_senders.read();
            senders.get(subscription_id).cloned()
        };
        let sender = sender.ok_or_else(|| LiveQueryError::SubscriptionNotFound {
            id: subscription_id.to_string(),
        })?;
        {
            let mut queries = self.queries.write();
            if let Some(query) = queries.get_mut(subscription_id) {
                query.last_update = Instant::now();
            }
        }
        let revision = self.update_counter.fetch_add(1, Ordering::Relaxed);
        let update = LiveQueryUpdate {
            id: subscription_id.to_string(),
            data,
            is_initial,
            revision,
            cache_control: None,
            changed_fields: None,
            batched: None,
            timestamp: Some(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs(),
            ),
        };
        sender
            .send(update)
            .await
            .map_err(|_| LiveQueryError::ChannelClosed)
    }

    /// Returns a new receiver for the invalidation broadcast stream.
    pub fn subscribe_invalidations(&self) -> broadcast::Receiver<InvalidationEvent> {
        self.invalidation_tx.subscribe()
    }

    /// Returns a clone of the query registered under `subscription_id`.
    pub fn get(&self, subscription_id: &str) -> Option<ActiveLiveQuery> {
        self.queries.read().get(subscription_id).cloned()
    }

    /// Returns clones of all queries owned by `connection_id`.
    pub fn get_for_connection(&self, connection_id: &str) -> Vec<ActiveLiveQuery> {
        // Copy the id set first so connection_index and queries are never
        // held at the same time (the old version nested them in the
        // opposite order to `register`, an inversion that could deadlock).
        let ids: Vec<String> = {
            let connection_index = self.connection_index.read();
            connection_index
                .get(connection_id)
                .map(|ids| ids.iter().cloned().collect())
                .unwrap_or_default()
        };
        let queries = self.queries.read();
        ids.iter()
            .filter_map(|id| queries.get(id).cloned())
            .collect()
    }

    /// True when `operation_name` has a live-query configuration entry.
    /// (Takes `&self` only for call-site symmetry; no store state is read.)
    pub fn is_live_enabled(
        &self,
        operation_name: &str,
        live_query_configs: &AHashMap<String, LiveQueryOperationConfig>,
    ) -> bool {
        live_query_configs.contains_key(operation_name)
    }

    /// Looks up the per-operation configuration for `operation_name`.
    pub fn get_operation_config<'a>(
        &self,
        operation_name: &str,
        live_query_configs: &'a AHashMap<String, LiveQueryOperationConfig>,
    ) -> Option<&'a LiveQueryOperationConfig> {
        live_query_configs.get(operation_name)
    }

    /// Builds an activity snapshot. `expired_count` and `cancelled_count`
    /// are not tracked and are always zero.
    pub fn stats(&self) -> LiveQueryStats {
        let active_count = self.queries.read().len();
        LiveQueryStats {
            active_count,
            total_updates: self.update_counter.load(Ordering::Relaxed),
            total_invalidations: self.invalidation_counter.load(Ordering::Relaxed),
            ..Default::default()
        }
    }

    /// Unregisters every query whose TTL has elapsed and returns their ids.
    /// Intended to be called periodically by a maintenance task.
    pub fn prune_expired(&self) -> Vec<String> {
        // Collect ids under the read lock, then unregister outside it,
        // since unregister takes write locks.
        let expired_ids: Vec<String> = {
            let queries = self.queries.read();
            queries
                .iter()
                .filter(|(_, q)| q.is_expired())
                .map(|(id, _)| id.clone())
                .collect()
        };
        for id in &expired_ids {
            self.unregister(id);
        }
        if !expired_ids.is_empty() {
            info!(count = expired_ids.len(), "Pruned expired live queries");
        }
        expired_ids
    }
}
impl Default for LiveQueryStore {
fn default() -> Self {
Self::new()
}
}
/// Shared, thread-safe handle to a [`LiveQueryStore`].
pub type SharedLiveQueryStore = Arc<LiveQueryStore>;

lazy_static::lazy_static! {
    // Process-wide singleton store, built with default config on first use.
    // NOTE(review): std::sync::LazyLock could replace lazy_static on
    // Rust >= 1.80 — confirm MSRV before switching.
    static ref GLOBAL_LIVE_QUERY_STORE: SharedLiveQueryStore = Arc::new(LiveQueryStore::new());
}

/// Returns a clone of the process-wide store handle.
pub fn global_live_query_store() -> SharedLiveQueryStore {
    GLOBAL_LIVE_QUERY_STORE.clone()
}

/// Alias for [`global_live_query_store`] — despite the name this does NOT
/// create a fresh store; it hands back the global singleton.
pub fn create_live_query_store() -> SharedLiveQueryStore {
    global_live_query_store()
}

/// Creates an independent (non-global) store with the given configuration.
pub fn create_live_query_store_with_config(config: LiveQueryConfig) -> SharedLiveQueryStore {
    Arc::new(LiveQueryStore::with_config(config))
}
/// Compile-time (static) description of a live-query-enabled operation.
/// NOTE(review): not consumed in this file — presumably generated tooling
/// emits these and converts them into `LiveQueryOperationConfig`; confirm.
#[derive(Debug, Clone)]
pub struct LiveQueryConfigInfo {
    pub operation_name: &'static str,
    pub throttle_ms: u32,
    pub triggers: &'static [&'static str],
    pub max_connections: u32,
    pub ttl_seconds: u32,
    /// Strategy name; parseable via `LiveQueryStrategy::from_str`.
    pub strategy: &'static str,
    pub poll_interval_ms: u32,
    pub depends_on: &'static [&'static str],
}
/// Runtime per-operation live-query settings, looked up by operation name
/// (see `LiveQueryStore::is_live_enabled` / `get_operation_config`).
#[derive(Debug, Clone)]
pub struct LiveQueryOperationConfig {
    pub operation_name: String,
    /// Whether live execution is enabled for this operation.
    pub enabled: bool,
    /// Per-operation throttle override, in ms.
    pub throttle_ms: u32,
    /// Trigger patterns ("Type.action") that refresh this operation.
    pub triggers: Vec<String>,
    pub max_connections: u32,
    /// TTL override in seconds; 0 means no expiry.
    pub ttl_seconds: u32,
    pub strategy: LiveQueryStrategy,
    /// Poll interval for the `Polling` strategy, in ms.
    pub poll_interval_ms: u32,
    /// Other operations this one depends on — not consumed in this file.
    pub depends_on: Vec<String>,
}
/// Errors produced by the live-query subsystem.
#[derive(Debug, thiserror::Error)]
pub enum LiveQueryError {
    /// The global `max_total` limit was reached during registration.
    #[error("Too many live queries ({current}/{max})")]
    TooManyQueries { current: usize, max: usize },
    /// The `max_per_connection` limit was reached during registration.
    #[error("Too many live queries per connection ({current}/{max})")]
    TooManyQueriesPerConnection { current: usize, max: usize },
    /// No subscription is registered under the given id.
    #[error("Subscription not found: {id}")]
    SubscriptionNotFound { id: String },
    /// The subscriber's update channel receiver was dropped.
    #[error("Channel closed")]
    ChannelClosed,
    /// The operation has no live-query support.
    #[error("Operation not supported for live queries: {operation}")]
    NotSupported { operation: String },
    /// Query (re-)execution failed.
    #[error("Query execution failed: {message}")]
    ExecutionFailed { message: String },
}
/// Returns true when `query` contains a `@live` directive outside of
/// `#`-comments.
///
/// GraphQL comments (`#` to end of line) are stripped first so a
/// commented-out `@live` does not count. A directive is recognized when
/// `@live` is followed by a non-identifier character (space, `(`, `{`, …)
/// or the end of the query, so names like `@liveX` are not false positives.
/// This replaces the old fixed pattern list, which included a dead
/// `"@live\n"` pattern — newlines can never survive the `join(" ")` below.
///
/// NOTE(review): a `#` inside a string literal is also treated as a comment
/// start — acceptable for typical queries; confirm if string arguments may
/// contain `#`.
pub fn has_live_directive(query: &str) -> bool {
    // Drop everything from '#' to end-of-line, then flatten to one line.
    let normalized = query
        .lines()
        .map(|line| line.split('#').next().unwrap_or(line))
        .collect::<Vec<_>>()
        .join(" ");
    normalized.match_indices("@live").any(|(idx, matched)| {
        match normalized[idx + matched.len()..].chars().next() {
            // "@live" at the very end of the query.
            None => true,
            // Followed by anything that cannot continue a directive name.
            Some(c) => !(c.is_ascii_alphanumeric() || c == '_'),
        }
    })
}
/// Removes the `@live` directive from `query`, collapsing the space the
/// directive occupied so `"query @live { foo }"` becomes `"query { foo }"`.
///
/// Bug fix: the old implementation left a double space behind
/// (`"query  { foo }"`), which failed this module's own
/// `test_strip_live_directive` expectations.
///
/// When the directive has an argument list, nested parentheses are honored
/// (`@live(a: (b))`) and only the first such occurrence is removed; bare
/// `@live` occurrences are all removed, as before.
pub fn strip_live_directive(query: &str) -> String {
    if let Some(start) = query.find("@live(") {
        // Walk the argument list, tracking nested parentheses.
        let after_open = start + "@live(".len();
        let mut depth = 1usize;
        let mut end = after_open;
        for (i, ch) in query[after_open..].char_indices() {
            match ch {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth == 0 {
                        end = after_open + i + 1;
                        break;
                    }
                }
                _ => {}
            }
        }
        // Drop one of the two spaces adjacent to the removed span so the
        // result does not contain a double space.
        let mut head = &query[..start];
        let tail = &query[end..];
        if head.ends_with(' ') && tail.starts_with(' ') {
            head = &head[..head.len() - 1];
        }
        format!("{}{}", head, tail)
    } else {
        // Prefer removing the preceding space together with the directive;
        // fall back to bare "@live" (e.g. at the start of the query).
        query.replace(" @live", "").replace("@live", "")
    }
}
/// Recursively diffs two JSON values and returns the leaf-level changes.
///
/// For objects, keys present in either side are compared: added keys produce
/// a change with `old_value: None`; removed keys produce `new_value: Null`.
/// Arrays are compared positionally up to the longer length. When a nested
/// comparison yields no leaf changes but the values differ (e.g. an
/// object/array replaced by a scalar), the parent path itself is reported.
/// Recursion stops at `max_depth`; deeper differences are silently ignored.
///
/// NOTE(review): object keys are collected into a `HashSet`, so the order of
/// returned changes is nondeterministic — confirm callers don't rely on it.
pub fn detect_field_changes(
    old: &serde_json::Value,
    new: &serde_json::Value,
    path: &str,
    depth: usize,
    max_depth: usize,
) -> Vec<FieldChange> {
    // Depth guard: give up rather than recurse unboundedly.
    if depth >= max_depth {
        return Vec::new();
    }
    let mut changes = Vec::new();
    match (old, new) {
        (serde_json::Value::Object(old_map), serde_json::Value::Object(new_map)) => {
            use std::collections::HashSet;
            // Union of keys from both sides so additions AND removals show up.
            let mut all_keys: HashSet<String> = old_map.keys().cloned().collect();
            all_keys.extend(new_map.keys().cloned());
            for key in all_keys {
                let field_path = if path.is_empty() {
                    key.clone()
                } else {
                    format!("{}.{}", path, key)
                };
                let old_val = old_map.get(&key);
                let new_val = new_map.get(&key);
                match (old_val, new_val) {
                    (Some(old_v), Some(new_v)) => {
                        if old_v != new_v {
                            let nested = detect_field_changes(
                                old_v,
                                new_v,
                                &field_path,
                                depth + 1,
                                max_depth,
                            );
                            // No nested leaf changes but values differ:
                            // report the change at this path instead.
                            if nested.is_empty() {
                                changes.push(FieldChange {
                                    field_path: field_path.clone(),
                                    old_value: Some(old_v.clone()),
                                    new_value: new_v.clone(),
                                });
                            } else {
                                changes.extend(nested);
                            }
                        }
                    }
                    // Key added.
                    (None, Some(new_v)) => {
                        changes.push(FieldChange {
                            field_path: field_path.clone(),
                            old_value: None,
                            new_value: new_v.clone(),
                        });
                    }
                    // Key removed — represented as a change to Null.
                    (Some(old_v), None) => {
                        changes.push(FieldChange {
                            field_path: field_path.clone(),
                            old_value: Some(old_v.clone()),
                            new_value: serde_json::Value::Null,
                        });
                    }
                    (None, None) => {}
                }
            }
        }
        (serde_json::Value::Array(old_arr), serde_json::Value::Array(new_arr)) => {
            // Positional comparison; index paths use bracket notation.
            let max_len = old_arr.len().max(new_arr.len());
            for i in 0..max_len {
                let field_path = format!("{}[{}]", path, i);
                let old_val = old_arr.get(i);
                let new_val = new_arr.get(i);
                match (old_val, new_val) {
                    (Some(old_v), Some(new_v)) if old_v != new_v => {
                        let nested =
                            detect_field_changes(old_v, new_v, &field_path, depth + 1, max_depth);
                        if nested.is_empty() {
                            changes.push(FieldChange {
                                field_path: field_path.clone(),
                                old_value: Some(old_v.clone()),
                                new_value: new_v.clone(),
                            });
                        } else {
                            changes.extend(nested);
                        }
                    }
                    // Element appended.
                    (None, Some(new_v)) => {
                        changes.push(FieldChange {
                            field_path: field_path.clone(),
                            old_value: None,
                            new_value: new_v.clone(),
                        });
                    }
                    // Element removed — represented as a change to Null.
                    (Some(old_v), None) => {
                        changes.push(FieldChange {
                            field_path: field_path.clone(),
                            old_value: Some(old_v.clone()),
                            new_value: serde_json::Value::Null,
                        });
                    }
                    _ => {}
                }
            }
        }
        // Scalars (or mismatched kinds): direct comparison at this path.
        _ => {
            if old != new {
                changes.push(FieldChange {
                    field_path: path.to_string(),
                    old_value: Some(old.clone()),
                    new_value: new.clone(),
                });
            }
        }
    }
    changes
}
/// Maps a data-volatility class to an HTTP-style cache policy: the more
/// volatile the data, the shorter the `max_age`. Responses are always
/// private (`public: false`) and must be revalidated.
pub fn generate_cache_control(volatility: DataVolatility, etag: Option<String>) -> CacheControl {
    let max_age = match volatility {
        DataVolatility::VeryHigh => 0,
        DataVolatility::High => 5,
        DataVolatility::Medium => 30,
        DataVolatility::Low => 300,
        DataVolatility::VeryLow => 3600,
    };
    CacheControl {
        max_age,
        public: false,
        must_revalidate: true,
        etag,
    }
}
/// Extracts `name: value` pairs from the first parenthesized group in
/// `query`, e.g. `user(id: 5, name: bob)` yields `{"id": "5", "name": "bob"}`.
/// Keys and values are trimmed but otherwise kept verbatim (including quotes).
///
/// Values may themselves contain `:` (e.g. `time: "10:30"`): only the first
/// colon separates name from value. The previous implementation split on
/// every colon and silently dropped such pairs.
///
/// NOTE(review): this is a shallow scanner, not a GraphQL parser — values
/// containing `,` or `)` are not handled; confirm that callers only feed it
/// simple argument lists.
pub fn parse_query_arguments(query: &str) -> AHashMap<String, String> {
    let mut args = AHashMap::new();
    if let Some(start) = query.find('(') {
        if let Some(end) = query[start..].find(')') {
            for pair in query[start + 1..start + end].split(',') {
                if let Some((key, value)) = pair.split_once(':') {
                    args.insert(key.trim().to_string(), value.trim().to_string());
                }
            }
        }
    }
    args
}
/// Returns true when every entry of `filter` matches the corresponding
/// top-level field of `entity_data`, comparing values as strings. An empty
/// filter matches anything; a non-object entity never matches a non-empty
/// filter.
pub fn matches_filter(filter: &AHashMap<String, String>, entity_data: &serde_json::Value) -> bool {
    if filter.is_empty() {
        return true;
    }
    match entity_data {
        serde_json::Value::Object(entity_map) => filter.iter().all(|(key, expected)| {
            entity_map.get(key).map_or(false, |actual| {
                // Stringify scalars without JSON quoting; everything else
                // falls back to its JSON representation.
                let actual_str = match actual {
                    serde_json::Value::String(s) => s.clone(),
                    serde_json::Value::Number(n) => n.to_string(),
                    serde_json::Value::Bool(b) => b.to_string(),
                    other => other.to_string(),
                };
                actual_str == *expected
            })
        }),
        _ => false,
    }
}
/// Generates a fresh random (v4) UUID string to identify a subscription.
pub fn generate_subscription_id() -> String {
    format!("{}", uuid::Uuid::new_v4())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Baseline query for tests that only care about a couple of fields;
    /// callers override via struct-update syntax.
    fn default_query() -> ActiveLiveQuery {
        ActiveLiveQuery {
            id: "default".to_string(),
            operation_name: "op".to_string(),
            query: "query".to_string(),
            variables: None,
            triggers: vec![],
            throttle_ms: 100,
            ttl_seconds: 0,
            strategy: LiveQueryStrategy::Invalidation,
            poll_interval_ms: 0,
            last_hash: None,
            last_update: Instant::now(),
            created_at: Instant::now(),
            connection_id: "default_conn".to_string(),
        }
    }

    // Wildcards on either side of "Type.action" must match; patterns with
    // the wrong segment count must not.
    #[test]
    fn test_invalidation_event_matching() {
        let event = InvalidationEvent::new("User", "update");
        assert!(event.matches("User.update"));
        assert!(event.matches("User.*"));
        assert!(event.matches("*.update"));
        assert!(event.matches("*.*"));
        assert!(!event.matches("Product.update"));
        assert!(!event.matches("User.delete"));
        assert!(!event.matches("User"));
    }

    // @live must be detected across whitespace/argument forms, but a
    // commented-out @live must not count.
    #[test]
    fn test_has_live_directive() {
        assert!(has_live_directive("query @live { user { name } }"));
        assert!(has_live_directive(
            "query GetUser @live { user(id: 1) { name } }"
        ));
        assert!(has_live_directive("query @live\n{ user { name } }"));
        assert!(has_live_directive("query @live(throttle: 100) { user }"));
        assert!(!has_live_directive("query { user { name } }"));
        assert!(!has_live_directive("# @live\nquery { user { name } }"));
    }

    #[test]
    fn test_live_query_config_defaults() {
        let config = LiveQueryConfig::default();
        assert_eq!(config.default_throttle_ms, 100);
        assert_eq!(config.max_per_connection, 10);
        let prod = LiveQueryConfig::production();
        assert_eq!(prod.max_total, 5000);
        let dev = LiveQueryConfig::development();
        assert_eq!(dev.max_total, 50000);
        assert_eq!(dev.default_throttle_ms, 50);
    }

    // FromStr is case-insensitive and accepts both HASH_DIFF spellings.
    #[test]
    fn test_strategy_parsing() {
        assert_eq!(
            LiveQueryStrategy::from_str("INVALIDATION"),
            Ok(LiveQueryStrategy::Invalidation)
        );
        assert_eq!(
            LiveQueryStrategy::from_str("POLLING"),
            Ok(LiveQueryStrategy::Polling)
        );
        assert_eq!(
            LiveQueryStrategy::from_str("HASH_DIFF"),
            Ok(LiveQueryStrategy::HashDiff)
        );
        assert_eq!(
            LiveQueryStrategy::from_str("HashDiff"),
            Ok(LiveQueryStrategy::HashDiff)
        );
        assert!(LiveQueryStrategy::from_str("Unknown").is_err());
    }

    // Stripping must not leave a double space behind and must honor nested
    // parentheses in the argument list.
    #[test]
    fn test_strip_live_directive() {
        assert_eq!(
            strip_live_directive("query @live { foo }"),
            "query { foo }"
        );
        assert_eq!(
            strip_live_directive("query @live(throttle: 100) { foo }"),
            "query { foo }"
        );
        assert_eq!(
            strip_live_directive("query @live(a: (b)) { foo }"),
            "query { foo }"
        );
        let stripped = strip_live_directive("query @live\n{ foo }");
        assert!(!stripped.contains("@live"));
    }

    // Exercises cache_key stability, trigger matching, throttling, and TTL
    // expiry on a single query instance.
    #[test]
    fn test_active_live_query_methods() {
        let mut query = ActiveLiveQuery {
            id: "1".to_string(),
            operation_name: "op".to_string(),
            query: "query".to_string(),
            variables: Some(serde_json::json!({"a": 1})),
            triggers: vec!["Trigger.A".to_string()],
            throttle_ms: 100,
            ttl_seconds: 1,
            strategy: LiveQueryStrategy::Invalidation,
            poll_interval_ms: 0,
            last_hash: None,
            // Backdated so the throttle has already elapsed.
            last_update: Instant::now() - Duration::from_millis(200),
            // Backdated past the 1s TTL.
            created_at: Instant::now() - Duration::from_secs(2),
            connection_id: "c1".to_string(),
        };
        let key1 = query.cache_key();
        let key2 = query.cache_key();
        assert_eq!(key1, key2);
        // Changing the variables must change the cache key.
        query.variables = None;
        assert_ne!(key1, query.cache_key());
        assert!(query.should_update(&InvalidationEvent::new("Trigger", "A")));
        assert!(!query.should_update(&InvalidationEvent::new("Trigger", "B")));
        assert!(query.throttle_elapsed());
        query.last_update = Instant::now();
        assert!(!query.throttle_elapsed());
        assert!(query.is_expired());
        query.created_at = Instant::now();
        assert!(!query.is_expired());
    }

    // End-to-end: register, receive an update, invalidate, unregister, and
    // verify sends fail afterwards.
    #[tokio::test]
    async fn test_live_query_store_lifecycle() {
        let store = LiveQueryStore::new();
        let (tx, mut rx) = mpsc::channel(10);
        let query = ActiveLiveQuery {
            id: "sub-1".to_string(),
            operation_name: "user".to_string(),
            query: "query @live { user { name } }".to_string(),
            variables: None,
            triggers: vec!["User.update".to_string()],
            throttle_ms: 0,
            ttl_seconds: 0,
            strategy: LiveQueryStrategy::Invalidation,
            poll_interval_ms: 0,
            last_hash: None,
            last_update: Instant::now(),
            created_at: Instant::now(),
            connection_id: "conn-1".to_string(),
        };
        store.register(query.clone(), tx).unwrap();
        assert_eq!(store.stats().active_count, 1);
        assert_eq!(store.get("sub-1").unwrap().connection_id, "conn-1");
        store
            .send_update("sub-1", serde_json::json!({"data": "test"}), true)
            .await
            .unwrap();
        let update = rx.recv().await.unwrap();
        assert_eq!(update.id, "sub-1");
        assert!(update.is_initial);
        let event = InvalidationEvent::new("User", "update");
        let affected = store.invalidate(event);
        assert_eq!(affected, 1);
        assert_eq!(store.stats().total_invalidations, 1);
        let removed = store.unregister("sub-1");
        assert!(removed.is_some());
        assert_eq!(store.stats().active_count, 0);
        // Sending to an unregistered subscription must fail.
        assert!(store
            .send_update("sub-1", serde_json::json!({}), false)
            .await
            .is_err());
    }

    // Both the per-connection and the global limit must reject with the
    // matching error variant and accurate counts.
    #[tokio::test]
    async fn test_store_limits() {
        let config = LiveQueryConfig {
            max_per_connection: 2,
            max_total: 3,
            ..Default::default()
        };
        let store = LiveQueryStore::with_config(config);
        let (tx, _rx) = mpsc::channel(1);
        for i in 0..2 {
            let q = ActiveLiveQuery {
                id: format!("q{}", i),
                connection_id: "c1".to_string(),
                ..default_query()
            };
            store.register(q, tx.clone()).unwrap();
        }
        // Third query on c1 exceeds max_per_connection.
        let q_fail = ActiveLiveQuery {
            id: "q_fail".to_string(),
            connection_id: "c1".to_string(),
            ..default_query()
        };
        match store.register(q_fail, tx.clone()) {
            Err(LiveQueryError::TooManyQueriesPerConnection { current, max }) => {
                assert_eq!(current, 2);
                assert_eq!(max, 2);
            }
            _ => panic!("Expected TooManyQueriesPerConnection"),
        }
        // A different connection can still register (total now 3).
        let q_c2 = ActiveLiveQuery {
            id: "q_c2".to_string(),
            connection_id: "c2".to_string(),
            ..default_query()
        };
        store.register(q_c2, tx.clone()).unwrap();
        // Any further registration exceeds max_total.
        let q_global_fail = ActiveLiveQuery {
            id: "q_global_fail".to_string(),
            connection_id: "c3".to_string(),
            ..default_query()
        };
        match store.register(q_global_fail, tx.clone()) {
            Err(LiveQueryError::TooManyQueries { current, max }) => {
                assert_eq!(current, 3);
                assert_eq!(max, 3);
            }
            _ => panic!("Expected TooManyQueries error"),
        }
    }

    // unregister_connection must remove exactly that connection's queries
    // and leave others untouched.
    #[test]
    fn test_unregister_connection() {
        let store = LiveQueryStore::new();
        let (tx, _rx) = mpsc::channel(1);
        store
            .register(
                ActiveLiveQuery {
                    id: "1".into(),
                    connection_id: "c1".into(),
                    ..default_query()
                },
                tx.clone(),
            )
            .unwrap();
        store
            .register(
                ActiveLiveQuery {
                    id: "2".into(),
                    connection_id: "c1".into(),
                    ..default_query()
                },
                tx.clone(),
            )
            .unwrap();
        store
            .register(
                ActiveLiveQuery {
                    id: "3".into(),
                    connection_id: "c2".into(),
                    ..default_query()
                },
                tx.clone(),
            )
            .unwrap();
        assert_eq!(store.get_for_connection("c1").len(), 2);
        let removed = store.unregister_connection("c1");
        assert_eq!(removed.len(), 2);
        assert_eq!(store.stats().active_count, 1);
        assert!(store.get("1").is_none());
    }

    // prune_expired must remove only queries whose TTL has elapsed.
    #[test]
    fn test_prune_expired() {
        let store = LiveQueryStore::new();
        let (tx, _rx) = mpsc::channel(1);
        let mut q_expired = default_query();
        q_expired.id = "exp".to_string();
        q_expired.ttl_seconds = 1;
        // Backdated past its 1s TTL.
        q_expired.created_at = Instant::now() - Duration::from_secs(2);
        let mut q_active = default_query();
        q_active.id = "act".to_string();
        q_active.ttl_seconds = 100;
        store.register(q_expired, tx.clone()).unwrap();
        store.register(q_active, tx.clone()).unwrap();
        let pruned = store.prune_expired();
        assert_eq!(pruned.len(), 1);
        assert_eq!(pruned[0], "exp");
        assert_eq!(store.stats().active_count, 1);
    }
}