use std::sync::atomic::AtomicI64;
use crate::schema::{schema_utils::McpMessage, RequestId};
use async_trait::async_trait;
/// Strategy for producing [`RequestId`]s attached to outgoing MCP messages.
///
/// Implementors must be `Send + Sync`, so a single generator can be shared.
#[async_trait]
pub trait RequestIdGen: Send + Sync {
    /// Produces a fresh request id.
    fn next_request_id(&self) -> RequestId;

    /// Returns the last request id the implementor recorded, if it tracks one.
    // May have no callers within this crate; silence the unused lint.
    #[allow(unused)]
    fn last_request_id(&self) -> Option<RequestId>;

    /// Repositions the generator based on `id`; exact semantics are
    /// implementor-defined.
    // May have no callers within this crate; silence the unused lint.
    #[allow(unused)]
    fn reset_to(&self, id: u64);

    /// Chooses the request id to attach to `message`.
    ///
    /// * Requests (`message.is_request()`) must arrive without an id; a new one
    ///   is drawn from [`RequestIdGen::next_request_id`].
    /// * Notifications carry no id, so `None` is returned.
    /// * Every other message (presumably responses/errors) must already carry
    ///   an id, which is passed through unchanged.
    ///
    /// # Panics
    ///
    /// Panics if a request is given `Some` id, or if a message that is neither
    /// a request nor a notification is given `None`.
    fn request_id_for_message(
        &self,
        message: &dyn McpMessage,
        request_id: Option<RequestId>,
    ) -> Option<RequestId> {
        if message.is_request() {
            // The generator owns request ids; a caller-supplied one is a bug.
            assert!(request_id.is_none());
            return Some(self.next_request_id());
        }

        if message.is_notification() {
            return None;
        }

        // Non-request, non-notification messages echo an existing id.
        assert!(request_id.is_some());
        request_id
    }
}
/// Generator of sequential integer request ids, backed by atomic counters so
/// that ids can be issued through a shared `&self`.
pub struct RequestIdGenNumeric {
    /// Value that the next generated id will take.
    message_id_counter: AtomicI64,
    /// Most recently issued id, or `-1` when no id has been issued yet.
    last_message_id: AtomicI64,
}

impl RequestIdGenNumeric {
    /// Creates a generator whose first issued id is `initial_id`, or `0` when
    /// `initial_id` is `None`.
    ///
    /// # Panics
    ///
    /// Panics if `initial_id` is greater than `i64::MAX`. Ids are stored as
    /// `i64`. A plain `as` cast would silently wrap such values to negative
    /// numbers, and `u64::MAX` would become `-1`, which collides with the
    /// "nothing issued yet" marker in `last_message_id`.
    pub fn new(initial_id: Option<u64>) -> Self {
        let start = initial_id.unwrap_or(0);
        assert!(
            start <= i64::MAX as u64,
            "initial request id {} does not fit in i64",
            start
        );
        Self {
            // Lossless: `start` was range-checked above.
            message_id_counter: AtomicI64::new(start as i64),
            // Sentinel meaning "no id issued yet".
            last_message_id: AtomicI64::new(-1),
        }
    }
}
impl RequestIdGen for RequestIdGenNumeric {
    /// Issues the current counter value as an integer id and advances the
    /// counter by one.
    ///
    /// The increment (`fetch_add`) and the bookkeeping `store` are two separate
    /// relaxed atomic operations. With concurrent callers, `last_message_id`
    /// can therefore end up holding an earlier id than the one most recently
    /// handed out.
    fn next_request_id(&self) -> RequestId {
        let issued = self
            .message_id_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        self.last_message_id
            .store(issued, std::sync::atomic::Ordering::Relaxed);
        RequestId::Integer(issued)
    }

    /// Returns the id recorded by the most recent `next_request_id` call, or
    /// `None` if no id has been issued yet.
    fn last_request_id(&self) -> Option<RequestId> {
        // `-1` is the "nothing issued yet" sentinel set at construction.
        match self
            .last_message_id
            .load(std::sync::atomic::Ordering::Relaxed)
        {
            -1 => None,
            recorded => Some(RequestId::Integer(recorded)),
        }
    }

    /// Makes `id` the value of the next issued request id.
    ///
    /// This does not change what `last_request_id` reports.
    ///
    /// # Panics
    ///
    /// Panics if `id` is greater than `i64::MAX`. A plain `as` cast would wrap
    /// it to a negative value, and `u64::MAX` would become `-1`, the
    /// "nothing issued yet" sentinel.
    fn reset_to(&self, id: u64) {
        assert!(
            id <= i64::MAX as u64,
            "request id {} does not fit in i64",
            id
        );
        // Lossless: `id` was range-checked above.
        self.message_id_counter
            .store(id as i64, std::sync::atomic::Ordering::Relaxed);
    }
}