use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tracing::{debug, trace};
#[derive(Debug, thiserror::Error)]
pub enum LatchError {
#[error("Timeout waiting for completion")]
Timeout,
#[error("Latch was cancelled")]
Cancelled,
#[error("Indexing failed: {0}")]
IndexingFailed(String),
}
/// Internal state published through the latch's watch channel.
///
/// Every latch starts out `Pending`. The trigger methods only act on a latch
/// that is still `Pending`, moving it to `Completed` or `Failed`.
#[derive(Default, Debug, Clone)]
enum LatchState {
    /// Not resolved yet; this is the initial state.
    #[default]
    Pending,
    /// Resolved successfully.
    Completed,
    /// Resolved with a failure; holds the failure message.
    Failed(String),
}
/// A one-shot latch that lets any number of tasks wait for a single outcome.
///
/// The outcome is either success or a failure message. Cloning gives another
/// handle to the same latch, because all clones share one watch channel.
#[derive(Clone)]
pub struct IndexLatch {
    /// Publishes state transitions; shared between clones through an `Arc`.
    state_tx: Arc<watch::Sender<LatchState>>,
    /// Reads the current state; each waiter clones its own copy of it.
    state_rx: watch::Receiver<LatchState>,
}
impl IndexLatch {
pub fn new() -> Self {
debug!("Creating new IndexLatch");
let (state_tx, state_rx) = watch::channel(LatchState::default());
Self {
state_tx: Arc::new(state_tx),
state_rx,
}
}
pub async fn wait(&self, timeout: Duration) -> Result<(), LatchError> {
let mut rx = self.state_rx.clone();
match *rx.borrow() {
LatchState::Completed => {
debug!("IndexLatch: Already completed, returning immediately");
return Ok(());
}
LatchState::Failed(ref error) => {
debug!("IndexLatch: Already failed, returning error immediately");
return Err(LatchError::IndexingFailed(error.clone()));
}
LatchState::Pending => {
trace!("IndexLatch: Waiting for completion");
}
}
let result = tokio::time::timeout(timeout, rx.changed()).await;
match result {
Ok(Ok(())) => {
match *rx.borrow() {
LatchState::Completed => {
debug!("IndexLatch: Completed successfully");
Ok(())
}
LatchState::Failed(ref error) => {
debug!("IndexLatch: Failed with error: {}", error);
Err(LatchError::IndexingFailed(error.clone()))
}
LatchState::Pending => {
trace!("IndexLatch: Got change notification but still pending");
Err(LatchError::Cancelled)
}
}
}
Ok(Err(_)) => {
debug!("IndexLatch: Sender dropped");
Err(LatchError::Cancelled)
}
Err(_) => {
debug!("IndexLatch: Timeout after {:?}", timeout);
Err(LatchError::Timeout)
}
}
}
pub async fn trigger_success(&self) {
if matches!(*self.state_rx.borrow(), LatchState::Pending) {
let _ = self.state_tx.send(LatchState::Completed);
debug!("IndexLatch: Triggered success");
} else {
trace!("IndexLatch: Already triggered, ignoring success trigger");
}
}
pub async fn trigger_failure(&self, error: String) {
if matches!(*self.state_rx.borrow(), LatchState::Pending) {
let _ = self.state_tx.send(LatchState::Failed(error.clone()));
debug!("IndexLatch: Triggered failure: {}", error);
} else {
trace!("IndexLatch: Already triggered, ignoring failure trigger");
}
}
}
impl Default for IndexLatch {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for IndexLatch {
    /// Shows the latch's current state (`Pending`, `Completed`, or
    /// `Failed(..)`) and leaves out the channel internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IndexLatch")
            .field("state", &*self.state_rx.borrow())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    #[tokio::test]
    async fn test_latch_creation() {
        let latch = IndexLatch::new();
        assert!(
            matches!(*latch.state_rx.borrow(), LatchState::Pending),
            "Expected Pending state on creation"
        );
    }

    #[tokio::test]
    async fn test_trigger_success() {
        let latch = IndexLatch::new();
        latch.trigger_success().await;
        assert!(
            matches!(*latch.state_rx.borrow(), LatchState::Completed),
            "Expected Completed state after trigger_success"
        );
    }

    #[tokio::test]
    async fn test_trigger_failure() {
        let latch = IndexLatch::new();
        let error_msg = "test error".to_string();
        latch.trigger_failure(error_msg.clone()).await;
        match &*latch.state_rx.borrow() {
            LatchState::Failed(msg) => assert_eq!(msg, &error_msg),
            _ => panic!("Expected Failed state after trigger_failure"),
        }
    }

    #[tokio::test]
    async fn test_wait_already_completed() {
        let latch = IndexLatch::new();
        latch.trigger_success().await;
        let result = latch.wait(Duration::from_millis(100)).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_wait_already_failed() {
        let latch = IndexLatch::new();
        let error_msg = "test error".to_string();
        latch.trigger_failure(error_msg.clone()).await;
        let result = latch.wait(Duration::from_millis(100)).await;
        match result {
            Err(LatchError::IndexingFailed(msg)) => assert_eq!(msg, error_msg),
            other => panic!("Expected IndexingFailed error, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_multiple_waiters_sequential() {
        let latch = IndexLatch::new();
        latch.trigger_success().await;
        let result1 = latch.wait(Duration::from_millis(100)).await;
        assert!(result1.is_ok(), "First waiter should succeed");
        let result2 = latch.wait(Duration::from_millis(100)).await;
        assert!(result2.is_ok(), "Second waiter should succeed");
        let result3 = latch.wait(Duration::from_millis(100)).await;
        assert!(result3.is_ok(), "Third waiter should succeed");
    }

    #[tokio::test]
    async fn test_timeout() {
        let latch = IndexLatch::new();
        let result = latch.wait(Duration::from_millis(50)).await;
        assert!(
            matches!(result, Err(LatchError::Timeout)),
            "Expected Timeout error, got: {:?}",
            result
        );
        assert!(
            matches!(*latch.state_rx.borrow(), LatchState::Pending),
            "Expected Pending state after timeout"
        );
    }

    #[tokio::test]
    async fn test_trigger_then_wait() {
        let latch = IndexLatch::new();
        latch.trigger_success().await;
        let result = latch.wait(Duration::from_millis(100)).await;
        assert!(
            result.is_ok(),
            "Wait should succeed immediately after trigger"
        );
    }

    #[tokio::test]
    async fn test_multiple_waiters_after_failure() {
        let latch = IndexLatch::new();
        let error_msg = "test failure".to_string();
        latch.trigger_failure(error_msg.clone()).await;
        for _ in 0..2 {
            match latch.wait(Duration::from_millis(100)).await {
                Err(LatchError::IndexingFailed(msg)) => assert_eq!(msg, error_msg),
                other => panic!("Expected IndexingFailed error, got: {:?}", other),
            }
        }
    }

    #[tokio::test]
    async fn test_double_trigger_ignored() {
        let latch = IndexLatch::new();
        latch.trigger_success().await;
        assert!(
            matches!(*latch.state_rx.borrow(), LatchState::Completed),
            "Expected Completed state after first trigger"
        );
        latch.trigger_failure("error".to_string()).await;
        assert!(
            matches!(*latch.state_rx.borrow(), LatchState::Completed),
            "State should remain Completed after second trigger"
        );
    }

    /// Covers the wake-up path: the waiter registers before the trigger.
    #[tokio::test]
    async fn test_wait_then_trigger_success() {
        let latch = IndexLatch::new();
        let (result, ()) = tokio::join!(latch.wait(Duration::from_secs(5)), async {
            // Let the waiter get polled and register before resolving.
            tokio::task::yield_now().await;
            latch.trigger_success().await;
        });
        assert!(
            result.is_ok(),
            "Waiter should be woken by trigger_success, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_wait_then_trigger_failure() {
        let latch = IndexLatch::new();
        let (result, ()) = tokio::join!(latch.wait(Duration::from_secs(5)), async {
            tokio::task::yield_now().await;
            latch.trigger_failure("late failure".to_string()).await;
        });
        match result {
            Err(LatchError::IndexingFailed(msg)) => assert_eq!(msg, "late failure"),
            other => panic!("Expected IndexingFailed error, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_concurrent_waiters_all_woken() {
        let latch = IndexLatch::new();
        let (first, second, ()) = tokio::join!(
            latch.wait(Duration::from_secs(5)),
            latch.wait(Duration::from_secs(5)),
            async {
                tokio::task::yield_now().await;
                latch.trigger_success().await;
            }
        );
        assert!(first.is_ok(), "First waiter should be woken, got: {:?}", first);
        assert!(second.is_ok(), "Second waiter should be woken, got: {:?}", second);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let latch = IndexLatch::new();
        let handle = latch.clone();
        handle.trigger_success().await;
        assert!(latch.wait(Duration::from_millis(100)).await.is_ok());
        // A later trigger through either handle is ignored.
        latch.trigger_failure("ignored".to_string()).await;
        assert!(handle.wait(Duration::from_millis(100)).await.is_ok());
    }

    #[tokio::test]
    async fn test_wait_after_timeout_then_trigger() {
        let latch = IndexLatch::new();
        let first = latch.wait(Duration::from_millis(10)).await;
        assert!(
            matches!(first, Err(LatchError::Timeout)),
            "Expected Timeout error, got: {:?}",
            first
        );
        latch.trigger_success().await;
        assert!(latch.wait(Duration::from_millis(100)).await.is_ok());
    }
}