use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::{mpsc, oneshot};
use url::Url;
use super::AvisoClient;
use crate::ClientError;
use crate::state::ResumeKey;
use crate::watch::{
CHANNEL_CAPACITY, NotificationStream, WatchRequest, WireWatchRequest, run_supervisor,
};
impl AvisoClient {
    /// Starts watching `request` and returns the stream its notifications arrive on.
    ///
    /// The request is first converted to its wire form (only to validate it; the converted
    /// value is discarded), and a resume key is computed for it. The key is then registered in
    /// the client's active-key tracker, and `run_supervisor` is spawned on the current Tokio
    /// runtime. The supervisor receives the sending halves of the notification, cancel and done
    /// channels, whose receiving/sending counterparts back the returned [`NotificationStream`].
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Config`] when no Tokio runtime is current or when the resume key
    /// cannot be computed. Errors from `WireWatchRequest::from_public` are propagated as-is.
    pub fn watch(&self, request: WatchRequest) -> crate::Result<NotificationStream> {
        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            return Err(ClientError::Config(
                "AvisoClient::watch requires a Tokio runtime".into(),
            ));
        };

        // Converting to the wire representation validates the request up front; the
        // converted value itself is not needed here.
        let _ = WireWatchRequest::from_public(&request)?;
        let resume_key = compute_resume_key(&self.base_url, &request)?;

        // NOTE(review): with a state store configured the channel holds only one item,
        // presumably so the supervisor cannot run far ahead of what the consumer has
        // processed. Confirm against run_supervisor.
        let capacity = if self.state_store.is_none() {
            CHANNEL_CAPACITY
        } else {
            1
        };

        let (notification_tx, notification_rx) = mpsc::channel(capacity);
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let (done_tx, done_rx) = oneshot::channel();
        let parent_cancel = self.parent_drop.subscribe();

        // Count this watch under its resume key before the supervisor is spawned, so that a
        // concurrent watch on the same key is reported as a collision.
        increment_active_key(&self.active_resume_keys, &resume_key, request.event_type());

        runtime.spawn(run_supervisor(
            request,
            self.http.clone(),
            self.base_url.clone(),
            self.auth.clone(),
            self.heartbeat_interval,
            self.state_store.clone(),
            resume_key,
            notification_tx,
            cancel_rx,
            parent_cancel,
            self.active_resume_keys.clone(),
            self.flush_cursor_on_exit,
            done_tx,
        ));

        Ok(NotificationStream::new(notification_rx, cancel_tx, done_rx))
    }

    /// Starts [`AvisoClient::watch`] and awaits `handler` on each notification, one at a time,
    /// in the order the stream yields them.
    ///
    /// Returns `Ok(())` when the stream ends.
    ///
    /// # Errors
    ///
    /// Returns the first error from starting the watch, from the stream itself, or from
    /// `handler`. In each case the function returns immediately and the stream is dropped.
    pub async fn watch_with_handler<F, Fut>(
        &self,
        request: WatchRequest,
        mut handler: F,
    ) -> crate::Result<()>
    where
        F: FnMut(crate::Notification) -> Fut + Send,
        Fut: std::future::Future<Output = crate::Result<()>> + Send,
    {
        let mut notifications = self.watch(request)?;
        while let Some(item) = notifications.recv().await {
            handler(item?).await?;
        }
        Ok(())
    }
}
/// Records one more active watch for `key` in the shared refcount map.
///
/// If `key` already had a non-zero count, a `client.resume.collision` warning is logged,
/// because concurrent watches sharing a resume key interleave their checkpoint commits.
///
/// If the mutex is poisoned, the map is recovered from the poison error and the count is
/// still incremented. In that case a `client.resume.collision.poisoned` warning is logged
/// instead, and the collision warning is skipped for this call.
pub(crate) fn increment_active_key(
    active: &Arc<Mutex<HashMap<ResumeKey, usize>>>,
    key: &ResumeKey,
    event_type: &str,
) {
    let (mut guard, poisoned) = match active.lock() {
        Ok(g) => (g, false),
        Err(poison) => {
            tracing::warn!(
                event.name = "client.resume.collision.poisoned",
                "resume-key collision tracker mutex is poisoned; refcount continues but \
                 collision WARN is suppressed for this call"
            );
            (poison.into_inner(), true)
        }
    };
    // A single lookup through the entry API replaces the earlier get-then-insert pair.
    // The previous count decides whether this call is a collision.
    let count = guard.entry(key.clone()).or_insert(0);
    let prior = *count;
    *count += 1;
    // Release the lock before logging.
    drop(guard);
    if prior > 0 && !poisoned {
        tracing::warn!(
            event.name = "client.resume.collision",
            resume_key = %key.as_hex(),
            event_type = event_type,
            "multiple concurrent watch() calls share the same resume key; checkpoint \
             advancement is racy and the affected watches will interleave commits"
        );
    }
}
/// Releases one active watch for `key` in the shared refcount map.
///
/// The count never goes below zero. When it reaches zero the entry is removed from the map.
/// A key that is not present is ignored. A poisoned mutex is recovered and the update
/// proceeds normally.
pub(crate) fn decrement_active_key(
    active: &Arc<Mutex<HashMap<ResumeKey, usize>>>,
    key: &ResumeKey,
) {
    let mut map = active
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let Some(remaining) = map.get_mut(key) else {
        return;
    };
    *remaining = remaining.saturating_sub(1);
    if *remaining == 0 {
        map.remove(key);
    }
}
/// Computes the resume key identifying `request` against `base_url`.
///
/// The key is built by `ResumeKey::new` from the base URL, the request's event type, and the
/// request's filter entries copied into a JSON object.
///
/// # Errors
///
/// Returns [`ClientError::Config`] wrapping the underlying message if `ResumeKey::new` fails.
pub(crate) fn compute_resume_key(
    base_url: &Url,
    request: &WatchRequest,
) -> crate::Result<ResumeKey> {
    let mut filter_map = serde_json::Map::new();
    for (field, value) in request.filter().iter() {
        filter_map.insert(field.clone(), value.clone());
    }
    let filter_value = serde_json::Value::Object(filter_map);

    ResumeKey::new(base_url, request.event_type(), &filter_value, None).map_err(|err| {
        ClientError::Config(format!("compute resume key for watch request: {err}"))
    })
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    reason = "test code: unwrap on constructor success and mutex guards is the expected diagnostic"
)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use super::{decrement_active_key, increment_active_key};
    use crate::state::ResumeKey;

    /// Shape of the shared tracker passed to the refcount helpers.
    type Tracker = Arc<Mutex<HashMap<ResumeKey, usize>>>;

    /// Current refcount for `key`; a missing entry counts as zero.
    fn refcount(tracker: &Tracker, key: &ResumeKey) -> usize {
        tracker.lock().unwrap().get(key).copied().unwrap_or(0)
    }

    /// A fixed resume key: example.com base URL, event type "mars", empty filter.
    fn sample_key() -> ResumeKey {
        let base = url::Url::parse("http://example.com/").unwrap();
        let empty_filter = serde_json::Value::Object(serde_json::Map::default());
        ResumeKey::new(&base, "mars", &empty_filter, None).unwrap()
    }

    #[tokio::test]
    async fn resume_key_collision_refcount_semantics() {
        let tracker: Tracker = Arc::new(Mutex::new(HashMap::new()));
        let key = sample_key();

        // Two overlapping registrations raise the count to 2.
        increment_active_key(&tracker, &key, "mars");
        assert_eq!(refcount(&tracker, &key), 1);
        increment_active_key(&tracker, &key, "mars");
        assert_eq!(refcount(&tracker, &key), 2);

        // Releasing one and re-registering brings the count back to 2.
        decrement_active_key(&tracker, &key);
        assert_eq!(refcount(&tracker, &key), 1);
        increment_active_key(&tracker, &key, "mars");
        assert_eq!(refcount(&tracker, &key), 2);

        // Draining to zero must drop the entry entirely.
        decrement_active_key(&tracker, &key);
        decrement_active_key(&tracker, &key);
        assert_eq!(refcount(&tracker, &key), 0);
        assert!(
            !tracker.lock().unwrap().contains_key(&key),
            "entry must be removed at zero refcount"
        );

        // A fresh registration after removal starts from 1 again.
        increment_active_key(&tracker, &key, "mars");
        assert_eq!(refcount(&tracker, &key), 1);
        decrement_active_key(&tracker, &key);
    }
}