use std::collections::VecDeque;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};
#[derive(Debug, Error)]
pub enum LandingZoneError {
#[error("Attempted to remove non-observed element")]
RemovingNonObservedElement,
}
/// Lock-protected interior of a `LandingZone`.
///
/// An item held by the zone sits in exactly one of the two deques: it starts in
/// `queue`, moves to `observed_items` when observed, and can be moved back by a
/// reset.
struct LandingZoneState<T> {
    /// Items not yet handed out by `observe`, oldest at the front.
    queue: VecDeque<T>,
    /// Items handed out by `observe` but not yet removed, oldest at the front.
    observed_items: VecDeque<T>,
}
/// A bounded staging area for in-flight items.
///
/// Items enter through `add`, which first acquires a semaphore permit, so at
/// most `max_inflight_requests` items (queued plus observed) are held at once.
/// `observe` hands out queued items while remembering them as "observed" until
/// they are removed or reset back into the queue.
///
/// The `T: Clone` requirement lives on the `impl` block, not on the type itself.
pub struct LandingZone<T> {
    /// Queued and observed items, guarded by a single lock.
    state: Arc<std::sync::Mutex<LandingZoneState<T>>>,
    /// Used to wake tasks waiting in `observe` for an item to arrive.
    new_item_notify: Arc<Notify>,
    /// Bounds the number of items held at once.
    semaphore: Arc<Semaphore>,
    /// One permit per held item; dropping a permit frees a slot for `add`.
    permits: std::sync::Mutex<VecDeque<OwnedSemaphorePermit>>,
}
impl<T: Clone> LandingZone<T> {
pub fn new(max_inflight_requests: usize) -> Self {
Self {
state: Arc::new(std::sync::Mutex::new(LandingZoneState {
queue: VecDeque::with_capacity(max_inflight_requests),
observed_items: VecDeque::with_capacity(max_inflight_requests),
})),
new_item_notify: Arc::new(Notify::new()),
semaphore: Arc::new(Semaphore::new(max_inflight_requests)),
permits: std::sync::Mutex::new(VecDeque::with_capacity(max_inflight_requests)),
}
}
pub fn remove_all(&self) -> Vec<T> {
let mut state = self.state.lock().expect("Lock poisoned");
let mut all_items = Vec::with_capacity(state.observed_items.len() + state.queue.len());
all_items.extend(state.observed_items.drain(..));
all_items.extend(state.queue.drain(..));
let mut permits = self.permits.lock().expect("Lock poisoned");
permits.clear();
all_items
}
pub async fn add(&self, request: T) {
let _permit = self
.semaphore
.clone()
.acquire_owned()
.await
.expect("Failed to acquire semaphore");
let mut state = self.state.lock().expect("Lock poisoned");
state.queue.push_back(request);
self.permits
.lock()
.expect("Lock poisoned")
.push_back(_permit);
self.new_item_notify.notify_one();
}
pub fn remove_observed(&self) -> Result<T, LandingZoneError> {
let mut state = self.state.lock().expect("Lock poisoned");
if let Some(item) = state.observed_items.pop_front() {
self.permits.lock().expect("Lock poisoned").pop_front();
Ok(item)
} else {
Err(LandingZoneError::RemovingNonObservedElement)
}
}
pub async fn observe(&self) -> T {
loop {
let notified = self.new_item_notify.notified();
{
let mut state = self.state.lock().expect("Lock poisoned");
if let Some(elem) = state.queue.pop_front() {
state.observed_items.push_back(elem.clone());
return elem;
}
}
notified.await;
}
}
pub fn reset_observe(&self) {
let mut state = self.state.lock().expect("Lock poisoned");
while let Some(observed_item) = state.observed_items.pop_back() {
state.queue.push_front(observed_item);
}
}
pub fn is_observed_empty(&self) -> bool {
let state = self.state.lock().expect("Lock poisoned");
state.observed_items.is_empty()
}
pub fn len(&self) -> usize {
let state = self.state.lock().expect("Lock poisoned");
state.queue.len() + state.observed_items.len()
}
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `LandingZone`.
    use std::sync::Arc;
    use tokio::time::{timeout, Duration};
    use super::{LandingZone, LandingZoneError};

    #[tokio::test]
    async fn test_add_and_observe() {
        let lz = Arc::new(LandingZone::new(10));
        lz.add("test_item".to_string()).await;
        let observed = lz.observe().await;
        assert_eq!(observed, "test_item");
    }

    #[tokio::test]
    async fn test_observe_blocks_until_item_available() {
        let lz = Arc::new(LandingZone::new(10));
        let lz_clone = lz.clone();
        let observe_task = tokio::spawn(async move { lz_clone.observe().await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        lz.add("delayed_item".to_string()).await;
        let result = timeout(Duration::from_millis(100), observe_task).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().unwrap(), "delayed_item");
    }

    #[tokio::test]
    async fn test_remove_observed() {
        let lz = Arc::new(LandingZone::new(10));
        lz.add("item1".to_string()).await;
        let _observed = lz.observe().await;
        let removed = lz.remove_observed().unwrap();
        assert_eq!(removed, "item1");
    }

    #[tokio::test]
    async fn test_remove_non_observed_fails() {
        let lz = Arc::new(LandingZone::<String>::new(10));
        let result = lz.remove_observed();
        assert!(matches!(
            result,
            Err(LandingZoneError::RemovingNonObservedElement)
        ));
    }

    #[tokio::test]
    async fn test_remove_all() {
        let lz = Arc::new(LandingZone::new(10));
        lz.add("item1".to_string()).await;
        lz.add("item2".to_string()).await;
        let _observed = lz.observe().await;
        let all_items = lz.remove_all();
        assert_eq!(all_items.len(), 2);
        assert!(all_items.contains(&"item1".to_string()));
        assert!(all_items.contains(&"item2".to_string()));
        assert_eq!(lz.len(), 0);
    }

    #[tokio::test]
    async fn test_semaphore_limits_capacity() {
        let lz = Arc::new(LandingZone::new(2));
        lz.add("item1".to_string()).await;
        lz.add("item2".to_string()).await;
        let mut add_task = tokio::spawn({
            let lz = lz.clone();
            async move {
                lz.add("item3".to_string()).await;
            }
        });
        tokio::select! {
            _ = &mut add_task => {
                panic!("add_task should not complete while semaphore is full");
            }
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
            }
        };
        let _observed = lz.observe().await;
        let _removed = lz.remove_observed().unwrap();
        add_task.await.unwrap();
        let all_items = lz.remove_all();
        assert_eq!(all_items.len(), 2);
        assert!(all_items.contains(&"item2".to_string()));
        assert!(all_items.contains(&"item3".to_string()));
    }

    #[tokio::test]
    async fn test_reset_observe_with_concurrent_add() {
        let lz = Arc::new(LandingZone::new(10));
        lz.add("item1".to_string()).await;
        lz.add("item2".to_string()).await;
        lz.add("item3".to_string()).await;
        let observed = lz.observe().await;
        assert_eq!(observed, "item1");
        let lz_clone = lz.clone();
        let add_task = tokio::spawn(async move {
            lz_clone.add("item4".to_string()).await;
        });
        add_task.await.unwrap();
        lz.reset_observe();
        assert_eq!(lz.observe().await, "item1");
        assert_eq!(lz.observe().await, "item2");
        assert_eq!(lz.observe().await, "item3");
        assert_eq!(lz.observe().await, "item4");
    }

    #[tokio::test]
    async fn test_semaphore_with_observe_reset() {
        let lz = Arc::new(LandingZone::new(2));
        lz.add("item1".to_string()).await;
        lz.add("item2".to_string()).await;
        let _observed = lz.observe().await;
        let add_task = tokio::spawn({
            let lz = lz.clone();
            async move {
                lz.add("item3".to_string()).await;
            }
        });
        // The timed-out handle is dropped, which detaches (not aborts) the task;
        // it completes once a permit is released below.
        let result = timeout(Duration::from_millis(50), add_task).await;
        assert!(result.is_err());
        lz.reset_observe();
        assert!(lz.remove_observed().is_err());
        let _observed_again = lz.observe().await;
        let _removed = lz.remove_observed().unwrap();
        let _observed_again_2 = lz.observe().await;
        let _removed_2 = lz.remove_observed().unwrap();
        lz.add("item4".to_string()).await;
    }

    #[tokio::test]
    async fn test_is_observed_empty() {
        let lz = Arc::new(LandingZone::new(16));
        assert!(lz.is_observed_empty());
        lz.add("item1".to_string()).await;
        assert!(lz.is_observed_empty());
        lz.observe().await;
        assert!(!lz.is_observed_empty());
        lz.remove_observed().unwrap();
        assert!(lz.is_observed_empty());
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let lz = Arc::new(LandingZone::new(100));
        let mut add_tasks = vec![];
        for i in 0..10 {
            let lz_clone = lz.clone();
            add_tasks.push(tokio::spawn(async move {
                lz_clone.add(format!("item{}", i)).await;
            }));
        }
        let mut observe_tasks = vec![];
        for _ in 0..10 {
            let lz_clone = lz.clone();
            // Return the observed value so the test can check what was received.
            observe_tasks.push(tokio::spawn(async move { lz_clone.observe().await }));
        }
        for task in add_tasks {
            task.await.unwrap();
        }
        let mut observed_items = Vec::with_capacity(observe_tasks.len());
        for task in observe_tasks {
            // A bounded wait turns a lost wakeup into a failure instead of a hang.
            let item = timeout(Duration::from_secs(1), task)
                .await
                .expect("observer did not receive an item in time")
                .expect("observer task panicked");
            observed_items.push(item);
        }
        // Every added item must be observed exactly once.
        observed_items.sort();
        let mut expected: Vec<String> = (0..10).map(|i| format!("item{}", i)).collect();
        expected.sort();
        assert_eq!(observed_items, expected);
        assert_eq!(lz.len(), 10);
        assert!(!lz.is_observed_empty());
    }
}