use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, oneshot};
use crate::control::messages::ControlResponse;
/// Map from request id to the one-shot sender that delivers that request's response.
type WaiterMap = HashMap<String, oneshot::Sender<ControlResponse>>;

/// Registry of requests that are still waiting for a [`ControlResponse`].
///
/// The map lives behind an `Arc<Mutex<..>>`, so cloning a `PendingRequests`
/// yields another handle to the *same* underlying map rather than a copy of it.
#[derive(Clone)]
pub struct PendingRequests {
    /// Shared, mutex-guarded table of waiters keyed by request id.
    inner: Arc<Mutex<WaiterMap>>,
}
impl PendingRequests {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers `sender` as the waiter for the request identified by `id`.
    ///
    /// If a sender is already registered under the same `id`, it is replaced.
    /// The previous sender is dropped, which closes its channel, so the
    /// receiver paired with it will observe a receive error instead of a
    /// response.
    pub async fn insert(&self, id: String, sender: oneshot::Sender<ControlResponse>) {
        // Bind the displaced sender so the mutex guard (a temporary of this
        // statement) is released before the old sender is dropped; dropping a
        // sender notifies its receiver, which does not need to happen under
        // the lock.
        let replaced = self.inner.lock().await.insert(id, sender);
        drop(replaced);
    }

    /// Delivers `response` to the waiter registered under `id` and removes
    /// the entry.
    ///
    /// Returns `true` only if a waiter was registered *and* its receiver
    /// accepted the response. Returns `false` when no entry exists for `id`
    /// or when the receiving half has already been dropped. In both cases the
    /// entry (if any) is removed.
    pub async fn complete(&self, id: &str, response: ControlResponse) -> bool {
        // Take the sender out in its own statement. Writing this as
        // `if let Some(..) = self.inner.lock().await.remove(id) { .. }` would
        // keep the guard temporary alive for the whole `if let` body, i.e. the
        // lock would be held while sending (and while dropping `response` on
        // a failed send).
        let waiter = self.inner.lock().await.remove(id);

        match waiter {
            Some(sender) => sender.send(response).is_ok(),
            None => false,
        }
    }

    /// Removes the waiter registered under `id`, if any, without sending a
    /// response.
    ///
    /// The removed sender is dropped, so the paired receiver will observe a
    /// receive error. Cancelling an unknown `id` is a no-op.
    pub async fn cancel(&self, id: &str) {
        // Same lock-scoping rationale as in `complete`: release the guard
        // first, then drop the sender.
        let removed = self.inner.lock().await.remove(id);
        drop(removed);
    }

    /// Number of requests currently registered (test-only helper).
    #[cfg(test)]
    pub async fn len(&self) -> usize {
        self.inner.lock().await.len()
    }

    /// Whether no requests are currently registered (test-only helper).
    #[cfg(test)]
    pub async fn is_empty(&self) -> bool {
        self.inner.lock().await.is_empty()
    }
}
impl Default for PendingRequests {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A registered request is delivered to its receiver and then removed.
    #[tokio::test]
    async fn test_insert_and_complete() {
        let pending = PendingRequests::new();
        assert!(pending.is_empty().await);

        let (tx, rx) = oneshot::channel();
        pending.insert("req_1".to_string(), tx).await;
        assert_eq!(pending.len().await, 1);
        assert!(!pending.is_empty().await);

        let response = ControlResponse::Success {
            data: json!({ "result": "ok" }),
        };
        let sent = pending.complete("req_1", response.clone()).await;
        assert!(sent);
        assert!(pending.is_empty().await);

        let received = rx.await.unwrap();
        assert_eq!(received, response);
    }

    /// Completing an id that was never registered reports failure.
    #[tokio::test]
    async fn test_complete_nonexistent() {
        let pending = PendingRequests::new();
        let response = ControlResponse::Success {
            data: json!({ "result": "ok" }),
        };
        let sent = pending.complete("nonexistent", response).await;
        assert!(!sent);
        assert!(pending.is_empty().await);
    }

    /// Cancelling removes the entry and the waiting receiver sees an error.
    #[tokio::test]
    async fn test_cancel() {
        let pending = PendingRequests::new();
        let (tx, rx) = oneshot::channel::<ControlResponse>();
        pending.insert("req_1".to_string(), tx).await;
        assert_eq!(pending.len().await, 1);

        pending.cancel("req_1").await;
        assert!(pending.is_empty().await);

        // The sender was dropped by `cancel`, so the receiver must not hang
        // and must not yield a response.
        assert!(rx.await.is_err());
    }

    /// Cancelling an unknown id is a harmless no-op.
    #[tokio::test]
    async fn test_cancel_nonexistent() {
        let pending = PendingRequests::new();
        // Must neither panic nor create an entry.
        pending.cancel("nonexistent").await;
        assert!(pending.is_empty().await);
    }

    /// Completing one request leaves the other registered requests untouched.
    #[tokio::test]
    async fn test_multiple_pending() {
        let pending = PendingRequests::new();
        let (tx1, mut rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        let (tx3, mut rx3) = oneshot::channel();
        pending.insert("req_1".to_string(), tx1).await;
        pending.insert("req_2".to_string(), tx2).await;
        pending.insert("req_3".to_string(), tx3).await;
        assert_eq!(pending.len().await, 3);

        let response2 = ControlResponse::Success {
            data: json!({ "id": 2 }),
        };
        assert!(pending.complete("req_2", response2.clone()).await);
        assert_eq!(pending.len().await, 2);

        let received2 = rx2.await.unwrap();
        assert_eq!(received2, response2);

        // The other two requests are still pending: nothing has been sent.
        assert!(rx1.try_recv().is_err());
        assert!(rx3.try_recv().is_err());
    }

    /// If the receiver is gone, `complete` reports failure but still removes
    /// the entry.
    #[tokio::test]
    async fn test_complete_after_receiver_dropped() {
        let pending = PendingRequests::new();
        let (tx, rx) = oneshot::channel();
        pending.insert("req_1".to_string(), tx).await;
        drop(rx);

        let response = ControlResponse::Success {
            data: json!({ "result": "ok" }),
        };
        let sent = pending.complete("req_1", response).await;
        assert!(!sent);
        assert!(pending.is_empty().await);
    }

    /// Clones share one map, so concurrent tasks can insert and complete
    /// independently without losing or mixing up responses.
    #[tokio::test]
    async fn test_concurrent_access() {
        let pending = PendingRequests::new();
        let mut handles = vec![];

        for i in 0..10 {
            let pending_clone = pending.clone();
            let handle = tokio::spawn(async move {
                let (tx, rx) = oneshot::channel();
                let id = format!("req_{}", i);
                pending_clone.insert(id.clone(), tx).await;

                let response = ControlResponse::Success {
                    data: json!({ "id": i }),
                };
                // `response` is not needed afterwards, so move it instead of
                // cloning.
                assert!(pending_clone.complete(&id, response).await);
                rx.await.unwrap()
            });
            handles.push(handle);
        }

        // Handles were pushed in spawn order, so index `i` corresponds to the
        // task that used id `i`.
        for (i, handle) in handles.into_iter().enumerate() {
            let response = handle.await.unwrap();
            match response {
                ControlResponse::Success { data } => {
                    assert_eq!(data["id"], i);
                }
                _ => panic!("Wrong response type"),
            }
        }

        assert!(pending.is_empty().await);
    }
}