use crate::error::Result;
use crate::runtime::oneshot;
use crate::runtime::{self, Mutex};
use crate::types::{JSONRPCResponse, RequestId};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Boxed progress callback invoked as `(current, total)`, where `total` is optional.
///
/// The `Sync + Send` bounds let the callback be shared with and called from other threads.
pub type ProgressCallback = Box<dyn Fn(u64, Option<u64>) + Sync + Send>;
/// Configuration carried by a `Protocol` instance.
///
/// `Default` yields `enforce_strict_capabilities == false` and no debounced methods.
#[derive(Clone, Debug, Default)]
pub struct ProtocolOptions {
    /// Whether capability checks should be strict.
    /// NOTE(review): not read anywhere in this file; presumably consumed by callers.
    pub enforce_strict_capabilities: bool,
    /// Names of notification methods that should be debounced.
    /// NOTE(review): debouncing itself is not implemented in this file.
    pub debounced_notification_methods: Vec<String>,
}
/// Per-request settings supplied by the caller of a request.
///
/// Only `Default` is derived: the boxed progress closure is not `Debug`/`Clone`,
/// so `Debug` is implemented by hand and the type is not clonable.
#[derive(Default)]
pub struct RequestOptions {
    /// Optional timeout for the request; this type only stores it, enforcement happens elsewhere.
    pub timeout: Option<Duration>,
    /// Optional callback to receive progress updates for the request.
    pub on_progress: Option<ProgressCallback>,
}
impl std::fmt::Debug for RequestOptions {
    /// Formats the options, printing a fixed `"<callback>"` placeholder in place of the
    /// progress closure (closures do not implement `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the presence of a callback is observable; its contents are opaque.
        let progress_marker = if self.on_progress.is_some() {
            Some("<callback>")
        } else {
            None
        };
        let mut builder = f.debug_struct("RequestOptions");
        builder.field("timeout", &self.timeout);
        builder.field("on_progress", &progress_marker);
        builder.finish()
    }
}
/// Opaque identifier of the transport a request was registered on.
///
/// Backed by `Arc<str>`, so cloning only bumps a reference count.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TransportId(Arc<str>);
impl TransportId {
pub fn new() -> Self {
Self(Arc::from(uuid::Uuid::new_v4().to_string()))
}
pub fn from_string(s: String) -> Self {
Self(Arc::from(s))
}
}
impl Default for TransportId {
fn default() -> Self {
Self::new()
}
}
/// Bookkeeping for one in-flight request.
#[derive(Clone, Debug)]
struct RequestContext {
    /// Transport the request was registered on.
    transport_id: TransportId,
    /// One-shot sender for the response, wrapped so it can be taken (via `Option::take`)
    /// at most once from inside an async task; the mutex is locked with `.await`.
    sender: Arc<Mutex<Option<oneshot::Sender<JSONRPCResponse>>>>,
}
/// Tracks requests awaiting a response, keyed by request id, for a single transport.
#[derive(Debug)]
pub struct Protocol {
    /// Options supplied at construction time.
    options: ProtocolOptions,
    /// In-flight requests, each with the sender used to deliver its response.
    pending_requests: HashMap<RequestId, RequestContext>,
    /// Identity of the transport this protocol registers requests on.
    transport_id: TransportId,
}
impl Protocol {
pub fn new(options: ProtocolOptions) -> Self {
Self {
options,
pending_requests: HashMap::new(),
transport_id: TransportId::new(),
}
}
pub fn with_transport_id(options: ProtocolOptions, transport_id: TransportId) -> Self {
Self {
options,
pending_requests: HashMap::new(),
transport_id,
}
}
pub fn options(&self) -> &ProtocolOptions {
&self.options
}
pub fn transport_id(&self) -> &TransportId {
&self.transport_id
}
pub fn register_request(&mut self, id: RequestId) -> oneshot::Receiver<JSONRPCResponse> {
let (tx, rx) = oneshot::channel();
let context = RequestContext {
transport_id: self.transport_id.clone(),
sender: Arc::new(Mutex::new(Some(tx))),
};
self.pending_requests.insert(id, context);
rx
}
pub fn complete_request(&mut self, id: &RequestId, response: JSONRPCResponse) -> Result<()> {
if let Some(context) = self.pending_requests.remove(id) {
if context.transport_id == self.transport_id {
let sender = context.sender;
runtime::spawn(async move {
let tx_option = sender.lock().await.take();
if let Some(tx) = tx_option {
let _ = tx.send(response);
}
});
} else {
self.pending_requests.insert(id.clone(), context);
}
}
Ok(())
}
pub fn complete_request_for_transport(
&mut self,
id: &RequestId,
response: JSONRPCResponse,
transport_id: &TransportId,
) -> Result<bool> {
if let Some(context) = self.pending_requests.get(id) {
if &context.transport_id == transport_id {
if let Some(context) = self.pending_requests.remove(id) {
let sender = context.sender;
runtime::spawn(async move {
let tx_option = sender.lock().await.take();
if let Some(tx) = tx_option {
let _ = tx.send(response);
}
});
return Ok(true);
}
}
}
Ok(false)
}
pub fn cancel_request(&mut self, id: &RequestId) {
self.pending_requests.remove(id);
}
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Upper bound on how long a test waits for a spawned delivery task to hand over a response.
    const DELIVERY_TIMEOUT: Duration = Duration::from_secs(1);

    #[test]
    fn test_protocol_options() {
        let options = ProtocolOptions {
            enforce_strict_capabilities: true,
            debounced_notification_methods: vec!["test".to_string()],
        };
        assert!(options.enforce_strict_capabilities);
        assert_eq!(options.debounced_notification_methods, vec!["test"]);
        let default_options = ProtocolOptions::default();
        assert!(!default_options.enforce_strict_capabilities);
        assert!(default_options.debounced_notification_methods.is_empty());
    }

    #[test]
    fn test_request_options() {
        let options = RequestOptions {
            timeout: Some(Duration::from_secs(30)),
            on_progress: None,
        };
        assert_eq!(options.timeout, Some(Duration::from_secs(30)));
        assert!(options.on_progress.is_none());
        let debug_str = format!("{:?}", options);
        assert!(debug_str.contains("timeout: Some"));
    }

    #[test]
    fn test_protocol_creation() {
        let options = ProtocolOptions::default();
        let protocol = Protocol::new(options);
        assert!(!protocol.options().enforce_strict_capabilities);
        assert_eq!(protocol.pending_requests.len(), 0);
    }

    #[test]
    fn test_protocol_with_transport_id() {
        let transport = TransportId::from_string("fixed-transport".to_string());
        let protocol = Protocol::with_transport_id(ProtocolOptions::default(), transport.clone());
        assert_eq!(protocol.transport_id(), &transport);
        assert!(protocol.pending_requests.is_empty());
    }

    #[tokio::test]
    async fn test_register_and_complete_request() {
        let mut protocol = Protocol::new(ProtocolOptions::default());
        let id = RequestId::Number(42);
        let rx = protocol.register_request(id.clone());
        assert_eq!(protocol.pending_requests.len(), 1);
        let response = JSONRPCResponse::success(id.clone(), serde_json::json!("success"));
        protocol.complete_request(&id, response).unwrap();
        assert!(protocol.pending_requests.is_empty());
        // Await the receiver rather than sleeping a fixed interval: delivery happens on a
        // spawned task, and a sleep-based wait is both slower and timing-dependent.
        let received = tokio::time::timeout(DELIVERY_TIMEOUT, rx)
            .await
            .expect("response was not delivered in time")
            .expect("sender dropped without sending a response");
        assert_eq!(received.result(), Some(&serde_json::json!("success")));
    }

    #[tokio::test]
    async fn test_complete_request_for_transport() {
        let mut protocol = Protocol::new(ProtocolOptions::default());
        let id = RequestId::Number(7);
        let rx = protocol.register_request(id.clone());

        // A response routed to a different transport must leave the pending entry in place.
        let foreign = TransportId::from_string("other-transport".to_string());
        let response = JSONRPCResponse::success(id.clone(), serde_json::json!("foreign"));
        assert!(!protocol
            .complete_request_for_transport(&id, response, &foreign)
            .unwrap());
        assert!(protocol.pending_requests.contains_key(&id));

        let own = protocol.transport_id().clone();
        let response = JSONRPCResponse::success(id.clone(), serde_json::json!("mine"));
        assert!(protocol
            .complete_request_for_transport(&id, response, &own)
            .unwrap());
        assert!(protocol.pending_requests.is_empty());

        let received = tokio::time::timeout(DELIVERY_TIMEOUT, rx)
            .await
            .expect("response was not delivered in time")
            .expect("sender dropped without sending a response");
        assert_eq!(received.result(), Some(&serde_json::json!("mine")));
    }

    #[test]
    fn test_complete_request_ignores_foreign_transport() {
        let mut protocol = Protocol::new(ProtocolOptions::default());
        let id = RequestId::Number(9);
        let (tx, _rx) = oneshot::channel();
        protocol.pending_requests.insert(
            id.clone(),
            RequestContext {
                transport_id: TransportId::from_string("foreign-transport".to_string()),
                sender: Arc::new(Mutex::new(Some(tx))),
            },
        );
        let response = JSONRPCResponse::success(id.clone(), serde_json::json!("ignored"));
        protocol.complete_request(&id, response).unwrap();
        assert!(protocol.pending_requests.contains_key(&id));
    }

    #[test]
    fn test_cancel_request() {
        let mut protocol = Protocol::new(ProtocolOptions::default());
        let id1 = RequestId::Number(1);
        let id2 = RequestId::String("req-2".to_string());
        let _rx1 = protocol.register_request(id1.clone());
        let _rx2 = protocol.register_request(id2.clone());
        assert_eq!(protocol.pending_requests.len(), 2);
        protocol.cancel_request(&id1);
        assert_eq!(protocol.pending_requests.len(), 1);
        assert!(!protocol.pending_requests.contains_key(&id1));
        assert!(protocol.pending_requests.contains_key(&id2));
        protocol.cancel_request(&RequestId::Number(999));
        assert_eq!(protocol.pending_requests.len(), 1);
    }

    #[tokio::test]
    async fn test_complete_non_existent_request() {
        let mut protocol = Protocol::new(ProtocolOptions::default());
        let id = RequestId::String("non-existent".to_string());
        let response = JSONRPCResponse::success(id.clone(), serde_json::json!("test"));
        let result = protocol.complete_request(&id, response);
        assert!(result.is_ok());
        assert!(protocol.pending_requests.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_pending_requests() {
        let mut protocol = Protocol::new(ProtocolOptions::default());
        let ids: Vec<_> = (0..5).map(RequestId::Number).collect();
        let _receivers: Vec<_> = ids
            .iter()
            .map(|id| protocol.register_request(id.clone()))
            .collect();
        assert_eq!(protocol.pending_requests.len(), 5);
        for (i, id) in ids.iter().enumerate().rev() {
            let response = JSONRPCResponse::success(id.clone(), serde_json::json!(i));
            protocol.complete_request(id, response).unwrap();
        }
        assert_eq!(protocol.pending_requests.len(), 0);
    }

    #[test]
    fn test_request_options_with_progress() {
        let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let called_clone = called.clone();
        let options = RequestOptions {
            timeout: Some(Duration::from_millis(100)),
            on_progress: Some(Box::new(move |current, total| {
                called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
                assert_eq!(current, 50);
                assert_eq!(total, Some(100));
            })),
        };
        if let Some(cb) = &options.on_progress {
            cb(50, Some(100));
        }
        assert!(called.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn test_protocol_with_enforced_capabilities() {
        let options = ProtocolOptions {
            enforce_strict_capabilities: true,
            debounced_notification_methods: vec![
                "notifications/progress".to_string(),
                "notifications/cancelled".to_string(),
            ],
        };
        let protocol = Protocol::new(options);
        assert!(protocol.options().enforce_strict_capabilities);
        assert_eq!(protocol.options().debounced_notification_methods.len(), 2);
    }
}