use std::collections::HashMap;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use crate::request_id::RequestId;
use crate::Response;
/// Method name of the request-cancellation message (`$/cancelRequest`).
pub const CANCEL_REQUEST_METHOD: &str = "$/cancelRequest";

/// Extracts the id of the request to cancel from cancellation `params`.
///
/// The id is read from the `id` field of the params object. A string id is returned as
/// [`RequestId::String`]. An integer id is returned as [`RequestId::Integer`] when it fits
/// in an `i32`.
///
/// Returns `None` in any of these cases:
/// - `params` is `None`;
/// - there is no `id` field;
/// - the id is a number that is not an integer, or does not fit in `i32`;
/// - the id is any other JSON type (bool, null, array, object).
#[must_use]
pub fn parse_cancel_params(params: &Option<serde_json::Value>) -> Option<RequestId> {
    match params.as_ref()?.get("id")? {
        serde_json::Value::String(text) => Some(RequestId::String(text.clone())),
        serde_json::Value::Number(num) => {
            // Go through i64 first so any integral JSON number is accepted, then narrow.
            let wide = num.as_i64()?;
            i32::try_from(wide).ok().map(RequestId::Integer)
        }
        _ => None,
    }
}
/// Tracks requests received from the peer that have not been completed yet.
///
/// Each entry is keyed by [`RequestId`]. It holds caller-supplied data of type `I` together
/// with the [`CancellationToken`] registered for that request.
#[derive(Debug)]
pub struct IncomingRequests<I> {
    pending: HashMap<RequestId, (I, CancellationToken)>,
}

impl<I> IncomingRequests<I> {
    /// Creates a tracker with no pending requests.
    #[must_use]
    pub fn new() -> Self {
        let pending = HashMap::new();
        Self { pending }
    }

    /// Starts tracking `id`, storing `data` alongside its cancellation `token`.
    ///
    /// If `id` is already tracked, the existing entry is replaced. This method does not
    /// cancel the displaced token.
    pub fn register(&mut self, id: RequestId, data: I, token: CancellationToken) {
        self.pending.insert(id, (data, token));
    }

    /// Stops tracking `id` and returns its stored data.
    ///
    /// Returns `None` if `id` was not pending.
    pub fn complete(&mut self, id: &RequestId) -> Option<I> {
        let (data, _token) = self.pending.remove(id)?;
        Some(data)
    }

    /// Returns whether `id` is currently tracked.
    #[must_use]
    pub fn is_pending(&self, id: &RequestId) -> bool {
        self.pending.contains_key(id)
    }

    /// Cancels the token registered for `id`.
    ///
    /// Returns `true` if `id` is tracked, even if its token was already cancelled. Returns
    /// `false` otherwise. The entry stays registered; use [`Self::complete`] to remove it.
    #[must_use]
    pub fn cancel(&self, id: &RequestId) -> bool {
        match self.pending.get(id) {
            Some((_, token)) => {
                token.cancel();
                true
            }
            None => false,
        }
    }

    /// Returns a clone of the cancellation token registered for `id`, if any.
    #[must_use]
    pub fn get_token(&self, id: &RequestId) -> Option<CancellationToken> {
        let (_, token) = self.pending.get(id)?;
        Some(token.clone())
    }

    /// Number of requests currently tracked.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }
}

impl<I> Default for IncomingRequests<I> {
    /// Equivalent to [`IncomingRequests::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Tracks requests sent to the peer whose responses have not arrived yet.
///
/// Each pending id owns the sending half of a [`oneshot`] channel. The receiving half is
/// handed out by [`OutgoingRequests::register`].
#[derive(Debug)]
pub struct OutgoingRequests {
    pending: HashMap<RequestId, oneshot::Sender<Response>>,
}

impl OutgoingRequests {
    /// Creates a tracker with no pending requests.
    #[must_use]
    pub fn new() -> Self {
        let pending = HashMap::new();
        Self { pending }
    }

    /// Marks `id` as awaiting a response and returns the receiver that will yield it.
    ///
    /// Registering an id that is already pending replaces (drops) the previous sender.
    pub fn register(&mut self, id: RequestId) -> oneshot::Receiver<Response> {
        let (sender, receiver) = oneshot::channel();
        self.pending.insert(id, sender);
        receiver
    }

    /// Delivers `response` to the waiter registered for `id` and stops tracking it.
    ///
    /// Returns `true` if `id` was pending and `false` otherwise. Delivery is best-effort:
    /// if the receiver was already dropped, the response is discarded and `true` is still
    /// returned.
    pub fn complete(&mut self, id: &RequestId, response: Response) -> bool {
        match self.pending.remove(id) {
            Some(sender) => {
                // A vanished waiter is not an error for the tracker itself.
                let _ = sender.send(response);
                true
            }
            None => false,
        }
    }

    /// Stops tracking `id` without delivering a response.
    ///
    /// The sender is dropped, so the matching receiver resolves to an error. Returns
    /// whether `id` was pending.
    pub fn cancel(&mut self, id: &RequestId) -> bool {
        let removed = self.pending.remove(id);
        removed.is_some()
    }

    /// Returns whether `id` is still awaiting a response.
    #[must_use]
    pub fn is_pending(&self, id: &RequestId) -> bool {
        self.pending.contains_key(id)
    }

    /// Number of requests still awaiting a response.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }
}

impl Default for OutgoingRequests {
    /// Equivalent to [`OutgoingRequests::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Pairs an incoming-request tracker with an outgoing-request tracker.
#[derive(Debug)]
pub struct RequestQueue<I> {
    /// Requests received from the peer, each carrying caller data of type `I`.
    pub incoming: IncomingRequests<I>,
    /// Requests sent to the peer that are awaiting a response.
    pub outgoing: OutgoingRequests,
}

impl<I> RequestQueue<I> {
    /// Creates a queue with no pending requests in either direction.
    #[must_use]
    pub fn new() -> Self {
        let incoming = IncomingRequests::new();
        let outgoing = OutgoingRequests::new();
        Self { incoming, outgoing }
    }
}

impl<I> Default for RequestQueue<I> {
    /// Equivalent to [`RequestQueue::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_util::sync::CancellationToken;

    // ---- IncomingRequests ----

    #[test]
    fn incoming_register_and_complete() {
        let mut incoming: IncomingRequests<String> = IncomingRequests::new();
        let token = CancellationToken::new();
        incoming.register(1.into(), "metadata".to_string(), token);
        let data = incoming.complete(&1.into());
        assert_eq!(data, Some("metadata".to_string()));
        assert!(!incoming.is_pending(&1.into()));
    }

    #[test]
    fn incoming_complete_unknown_returns_none() {
        let mut incoming: IncomingRequests<String> = IncomingRequests::new();
        let data = incoming.complete(&999.into());
        assert_eq!(data, None);
    }

    #[test]
    fn incoming_is_pending() {
        let mut incoming: IncomingRequests<()> = IncomingRequests::new();
        assert!(!incoming.is_pending(&1.into()));
        let token = CancellationToken::new();
        incoming.register(1.into(), (), token);
        assert!(incoming.is_pending(&1.into()));
        incoming.complete(&1.into());
        assert!(!incoming.is_pending(&1.into()));
    }

    #[test]
    fn incoming_pending_count() {
        let mut incoming: IncomingRequests<i32> = IncomingRequests::new();
        assert_eq!(incoming.pending_count(), 0);
        let token1 = CancellationToken::new();
        incoming.register(1.into(), 100, token1);
        assert_eq!(incoming.pending_count(), 1);
        let token2 = CancellationToken::new();
        incoming.register(2.into(), 200, token2);
        assert_eq!(incoming.pending_count(), 2);
        incoming.complete(&1.into());
        assert_eq!(incoming.pending_count(), 1);
        incoming.complete(&2.into());
        assert_eq!(incoming.pending_count(), 0);
    }

    #[test]
    fn incoming_default() {
        let incoming: IncomingRequests<()> = IncomingRequests::default();
        assert_eq!(incoming.pending_count(), 0);
    }

    #[test]
    fn incoming_cancel_triggers_token() {
        let mut incoming: IncomingRequests<String> = IncomingRequests::new();
        let token = CancellationToken::new();
        let token_clone = token.clone();
        incoming.register(1.into(), "data".to_string(), token);
        assert!(incoming.cancel(&1.into()));
        assert!(token_clone.is_cancelled());
    }

    #[test]
    fn incoming_cancel_unknown_returns_false() {
        let incoming: IncomingRequests<()> = IncomingRequests::new();
        assert!(!incoming.cancel(&999.into()));
    }

    #[test]
    fn incoming_cancel_idempotent() {
        let mut incoming: IncomingRequests<()> = IncomingRequests::new();
        let token = CancellationToken::new();
        let token_clone = token.clone();
        incoming.register(1.into(), (), token);
        assert!(incoming.cancel(&1.into()));
        // A second cancel still finds the entry and leaves the token cancelled.
        assert!(incoming.cancel(&1.into()));
        assert!(token_clone.is_cancelled());
    }

    #[test]
    fn incoming_cancel_keeps_entry_pending() {
        let mut incoming: IncomingRequests<()> = IncomingRequests::new();
        incoming.register(1.into(), (), CancellationToken::new());
        assert!(incoming.cancel(&1.into()));
        // Cancellation only signals the token; removal is the job of `complete`.
        assert!(incoming.is_pending(&1.into()));
        assert_eq!(incoming.pending_count(), 1);
    }

    #[test]
    fn incoming_get_token_returns_clone() {
        let mut incoming: IncomingRequests<String> = IncomingRequests::new();
        let original_token = CancellationToken::new();
        incoming.register(1.into(), "data".to_string(), original_token.clone());
        let retrieved = incoming.get_token(&1.into());
        assert!(retrieved.is_some());
        retrieved.unwrap().cancel();
        assert!(original_token.is_cancelled());
    }

    #[test]
    fn incoming_get_token_unknown_returns_none() {
        let incoming: IncomingRequests<()> = IncomingRequests::new();
        assert!(incoming.get_token(&999.into()).is_none());
    }

    #[test]
    fn incoming_complete_after_cancel_returns_data() {
        let mut incoming: IncomingRequests<String> = IncomingRequests::new();
        let token = CancellationToken::new();
        incoming.register(1.into(), "cancelled_data".to_string(), token);
        let _ = incoming.cancel(&1.into());
        let data = incoming.complete(&1.into());
        assert_eq!(data, Some("cancelled_data".to_string()));
    }

    // ---- OutgoingRequests ----

    #[tokio::test]
    async fn outgoing_register_and_complete() {
        let mut outgoing = OutgoingRequests::new();
        let rx = outgoing.register(1.into());
        let response = Response::ok(1, serde_json::json!("response"));
        assert!(outgoing.complete(&1.into(), response.clone()));
        assert_eq!(rx.await.unwrap().id, response.id);
    }

    #[test]
    fn outgoing_complete_unknown_returns_false() {
        let mut outgoing = OutgoingRequests::new();
        let result = outgoing.complete(&999.into(), Response::ok(999, serde_json::json!(null)));
        assert!(!result);
    }

    #[test]
    fn outgoing_complete_after_receiver_dropped_returns_true() {
        let mut outgoing = OutgoingRequests::new();
        let rx = outgoing.register(1.into());
        drop(rx);
        // Delivery is best-effort: the id was pending, so `complete` still reports true.
        assert!(outgoing.complete(&1.into(), Response::ok(1, serde_json::json!(null))));
        assert!(!outgoing.is_pending(&1.into()));
    }

    #[tokio::test]
    async fn outgoing_cancel_drops_sender() {
        let mut outgoing = OutgoingRequests::new();
        let rx = outgoing.register(1.into());
        assert!(outgoing.cancel(&1.into()));
        assert!(!outgoing.is_pending(&1.into()));
        assert!(rx.await.is_err());
    }

    #[test]
    fn outgoing_cancel_unknown_returns_false() {
        let mut outgoing = OutgoingRequests::new();
        assert!(!outgoing.cancel(&999.into()));
    }

    #[test]
    fn outgoing_is_pending() {
        let mut outgoing = OutgoingRequests::new();
        assert!(!outgoing.is_pending(&1.into()));
        let _rx = outgoing.register(1.into());
        assert!(outgoing.is_pending(&1.into()));
        outgoing.complete(&1.into(), Response::ok(1, serde_json::json!(null)));
        assert!(!outgoing.is_pending(&1.into()));
    }

    #[test]
    fn outgoing_pending_count() {
        let mut outgoing = OutgoingRequests::new();
        assert_eq!(outgoing.pending_count(), 0);
        let _rx1 = outgoing.register(1.into());
        assert_eq!(outgoing.pending_count(), 1);
        let _rx2 = outgoing.register(2.into());
        assert_eq!(outgoing.pending_count(), 2);
        outgoing.complete(&1.into(), Response::ok(1, serde_json::json!(null)));
        assert_eq!(outgoing.pending_count(), 1);
        outgoing.cancel(&2.into());
        assert_eq!(outgoing.pending_count(), 0);
    }

    #[test]
    fn outgoing_default() {
        let outgoing = OutgoingRequests::default();
        assert_eq!(outgoing.pending_count(), 0);
    }

    // ---- RequestQueue ----

    #[test]
    fn queue_new_creates_empty() {
        let queue: RequestQueue<()> = RequestQueue::new();
        assert_eq!(queue.incoming.pending_count(), 0);
        assert_eq!(queue.outgoing.pending_count(), 0);
    }

    #[test]
    fn queue_incoming_outgoing_independent() {
        let mut queue: RequestQueue<String> = RequestQueue::new();
        let token = CancellationToken::new();
        queue
            .incoming
            .register(1.into(), "incoming".to_string(), token);
        assert_eq!(queue.incoming.pending_count(), 1);
        assert_eq!(queue.outgoing.pending_count(), 0);
        let _rx = queue.outgoing.register(2.into());
        assert_eq!(queue.incoming.pending_count(), 1);
        assert_eq!(queue.outgoing.pending_count(), 1);
        queue.incoming.complete(&1.into());
        assert_eq!(queue.incoming.pending_count(), 0);
        assert_eq!(queue.outgoing.pending_count(), 1);
    }

    #[test]
    fn queue_default() {
        let queue: RequestQueue<()> = RequestQueue::default();
        assert_eq!(queue.incoming.pending_count(), 0);
        assert_eq!(queue.outgoing.pending_count(), 0);
    }

    #[test]
    fn queue_with_string_request_id() {
        let mut queue: RequestQueue<i32> = RequestQueue::new();
        let str_id: RequestId = "abc-123".into();
        let token = CancellationToken::new();
        queue.incoming.register(str_id.clone(), 42, token);
        assert!(queue.incoming.is_pending(&str_id));
        assert_eq!(queue.incoming.complete(&str_id), Some(42));
    }

    // ---- parse_cancel_params ----

    #[test]
    fn parse_cancel_params_integer_id() {
        let params = Some(serde_json::json!({"id": 42}));
        let id = parse_cancel_params(&params);
        assert_eq!(id, Some(RequestId::Integer(42)));
    }

    #[test]
    fn parse_cancel_params_negative_integer_id() {
        let params = Some(serde_json::json!({"id": -7}));
        let id = parse_cancel_params(&params);
        assert_eq!(id, Some(RequestId::Integer(-7)));
    }

    #[test]
    fn parse_cancel_params_out_of_i32_range_id() {
        let params = Some(serde_json::json!({"id": 4_294_967_296_i64}));
        let id = parse_cancel_params(&params);
        assert!(id.is_none());
    }

    #[test]
    fn parse_cancel_params_fractional_id() {
        let params = Some(serde_json::json!({"id": 1.5}));
        let id = parse_cancel_params(&params);
        assert!(id.is_none());
    }

    #[test]
    fn parse_cancel_params_string_id() {
        let params = Some(serde_json::json!({"id": "request-123"}));
        let id = parse_cancel_params(&params);
        assert_eq!(id, Some(RequestId::String("request-123".to_string())));
    }

    #[test]
    fn parse_cancel_params_missing_params() {
        let id = parse_cancel_params(&None);
        assert!(id.is_none());
    }

    #[test]
    fn parse_cancel_params_missing_id_field() {
        let params = Some(serde_json::json!({"other": "field"}));
        let id = parse_cancel_params(&params);
        assert!(id.is_none());
    }

    #[test]
    fn parse_cancel_params_invalid_id_type() {
        let params = Some(serde_json::json!({"id": true}));
        let id = parse_cancel_params(&params);
        assert!(id.is_none());
    }

    #[test]
    fn parse_cancel_params_null_id() {
        let params = Some(serde_json::json!({"id": null}));
        let id = parse_cancel_params(&params);
        assert!(id.is_none());
    }
}