use alloc::boxed::Box;
use alloc::collections::{BTreeMap, VecDeque};
use alloc::vec::Vec;
use zerodds_cdr::{BufferWriter, EncodeError, Endianness};
use zerodds_types::type_lookup::{
ContinuationPoint, GetTypeDependenciesReply, GetTypeDependenciesRequest, GetTypesReply,
GetTypesRequest,
};
use zerodds_types::{EquivalenceHash, TypeIdentifier};
/// Identifier that correlates a TypeLookup request with its reply.
///
/// Totally ordered (`Ord`) so it can serve as the key of the client's
/// `BTreeMap` of pending requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RequestId(pub u64);

impl RequestId {
    /// Wraps a raw `u64` as a [`RequestId`] without any validation.
    #[must_use]
    pub fn from_u64(raw: u64) -> Self {
        RequestId(raw)
    }
}
/// A reply handed to the [`ClientCallback`] of the matching pending request.
#[derive(Clone, Debug)]
pub enum TypeLookupReply {
    /// Carries a [`GetTypesReply`].
    Types(GetTypesReply),
    /// Carries a [`GetTypeDependenciesReply`].
    Dependencies(GetTypeDependenciesReply),
}
/// Boxed closure that receives the reply to one registered request.
///
/// The client invokes it at most once; it must be `Send` and `'static`.
pub type ClientCallback = Box<dyn FnMut(TypeLookupReply) + Send + 'static>;
/// Bookkeeping for a request that is still waiting for its reply.
struct Pending {
    /// Invoked with the reply once it is delivered to the client.
    callback: ClientCallback,
}

// The boxed closure has no `Debug` impl, so only the type name is printed.
impl core::fmt::Debug for Pending {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Pending")
    }
}
/// Client-side bookkeeping for outstanding TypeLookup requests.
///
/// Every registered request gets a [`RequestId`] and a callback; the callback
/// runs when a reply carrying that id is handed to the client. At most
/// `max_pending` requests are tracked; older ones are evicted first.
#[derive(Debug)]
pub struct TypeLookupClient {
    /// Callbacks of outstanding requests, keyed by request id.
    pending: BTreeMap<RequestId, Pending>,
    /// Outstanding ids in registration order (front = oldest), used for eviction.
    pending_order: VecDeque<RequestId>,
    /// Raw value of the next id to hand out.
    next_seq: u64,
    /// Upper bound on the number of outstanding requests (at least 1).
    max_pending: usize,
}
impl TypeLookupClient {
    /// Default upper bound on the number of outstanding requests.
    pub const DEFAULT_MAX_PENDING: usize = 256;

    /// Creates a client that tracks at most [`Self::DEFAULT_MAX_PENDING`] requests.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_MAX_PENDING)
    }

    /// Creates a client that tracks at most `max_pending` outstanding requests.
    ///
    /// `max_pending` is clamped to at least 1. This is a *limit*, not a
    /// pre-allocation hint: once it is reached, registering another request
    /// evicts the oldest pending entry, dropping its callback without invoking it.
    #[must_use]
    pub fn with_capacity(max_pending: usize) -> Self {
        Self {
            pending: BTreeMap::new(),
            pending_order: VecDeque::new(),
            next_seq: 1,
            max_pending: max_pending.max(1),
        }
    }

    /// Number of requests still waiting for a reply.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Registers a `GetTypes` request and returns the id to correlate its reply.
    ///
    /// Only `callback` is stored; `_ids` is not used by the client itself (the
    /// wire payload is built separately, e.g. with [`request_types_payload`]).
    pub fn request_types(
        &mut self,
        _ids: Vec<TypeIdentifier>,
        callback: ClientCallback,
    ) -> RequestId {
        self.alloc_pending(callback)
    }

    /// Registers a `GetTypeDependencies` request and returns the id to correlate
    /// its reply.
    ///
    /// Only `callback` is stored; `_ids` and `_continuation_point` are not used by
    /// the client itself (see [`request_dependencies_payload`] for encoding).
    pub fn request_type_dependencies(
        &mut self,
        _ids: Vec<TypeIdentifier>,
        _continuation_point: ContinuationPoint,
        callback: ClientCallback,
    ) -> RequestId {
        self.alloc_pending(callback)
    }

    /// Allocates a fresh id and stores `callback` under it, first evicting the
    /// oldest pending entries if the limit has been reached.
    fn alloc_pending(&mut self, callback: ClientCallback) -> RequestId {
        let id = self.next_request_id();
        while self.pending.len() >= self.max_pending {
            let Some(oldest) = self.pending_order.pop_front() else {
                break;
            };
            self.pending.remove(&oldest);
        }
        self.pending.insert(id, Pending { callback });
        self.pending_order.push_back(id);
        id
    }

    /// Returns the next id that does not collide with a still-pending request.
    ///
    /// Ids count up from 1. After `u64::MAX` the counter wraps back to 1 (0 is
    /// never handed out), and ids that are still pending are skipped.
    fn next_request_id(&mut self) -> RequestId {
        // Terminates: `pending` can never hold anywhere near 2^64 entries, so a
        // free id is always found.
        loop {
            let candidate = RequestId(self.next_seq);
            // Wrap instead of saturating: a saturated counter would reuse
            // `u64::MAX` for every later request, silently overwriting the pending
            // callback and leaving duplicate ids in `pending_order`.
            self.next_seq = self.next_seq.checked_add(1).unwrap_or(1);
            if !self.pending.contains_key(&candidate) {
                return candidate;
            }
        }
    }

    /// Delivers `reply` to the callback registered under `request_id`.
    ///
    /// Returns `true` if a pending request matched: its entry is removed and its
    /// callback invoked. Returns `false` if the id is unknown (never issued,
    /// already answered, evicted, or cleared). The reply variant is not checked
    /// against the kind of request that was registered.
    pub fn handle_reply(&mut self, request_id: RequestId, reply: TypeLookupReply) -> bool {
        let Some(mut entry) = self.pending.remove(&request_id) else {
            return false;
        };
        if let Some(pos) = self.pending_order.iter().position(|x| *x == request_id) {
            self.pending_order.remove(pos);
        }
        // The entry is removed before the callback runs, so a duplicate reply for
        // the same id is ignored.
        (entry.callback)(reply);
        true
    }

    /// Drops every pending request without invoking its callback.
    pub fn clear(&mut self) {
        self.pending.clear();
        self.pending_order.clear();
    }
}
impl Default for TypeLookupClient {
fn default() -> Self {
Self::new()
}
}
/// Encodes a `GetTypes` request for `ids` with a little-endian [`BufferWriter`].
///
/// # Errors
///
/// Propagates any [`EncodeError`] returned while encoding the request.
pub fn request_types_payload(ids: &[TypeIdentifier]) -> Result<Vec<u8>, EncodeError> {
    let mut writer = BufferWriter::new(Endianness::Little);
    let request = GetTypesRequest {
        type_ids: ids.to_vec(),
    };
    request.encode_into(&mut writer)?;
    Ok(writer.into_bytes())
}
/// Encodes a `GetTypeDependencies` request for `ids`, resuming at
/// `continuation_point`, with a little-endian [`BufferWriter`].
///
/// # Errors
///
/// Propagates any [`EncodeError`] returned while encoding the request.
pub fn request_dependencies_payload(
    ids: &[TypeIdentifier],
    continuation_point: ContinuationPoint,
) -> Result<Vec<u8>, EncodeError> {
    let mut writer = BufferWriter::new(Endianness::Little);
    let request = GetTypeDependenciesRequest {
        type_ids: ids.to_vec(),
        continuation_point,
    };
    request.encode_into(&mut writer)?;
    Ok(writer.into_bytes())
}
/// Wraps each hash as a [`TypeIdentifier::EquivalenceHashMinimal`], preserving
/// input order.
#[must_use]
pub fn hashes_to_minimal_ids(hashes: &[EquivalenceHash]) -> Vec<TypeIdentifier> {
    hashes
        .iter()
        .copied()
        .map(TypeIdentifier::EquivalenceHashMinimal)
        .collect()
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    extern crate std;
    use std::sync::Arc;
    use std::sync::Mutex;

    #[test]
    fn request_id_unique_and_monotone() {
        let mut c = TypeLookupClient::new();
        let id1 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id2 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id3 = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    #[test]
    fn handle_reply_unknown_id_is_ignored() {
        let mut c = TypeLookupClient::new();
        let consumed = c.handle_reply(
            RequestId(99),
            TypeLookupReply::Types(GetTypesReply::default()),
        );
        assert!(!consumed);
    }

    #[test]
    fn handle_reply_invokes_callback() {
        let calls = Arc::new(Mutex::new(0u32));
        let calls_clone = Arc::clone(&calls);
        let mut c = TypeLookupClient::new();
        let id = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *calls_clone.lock().unwrap() += 1;
            }),
        );
        assert_eq!(*calls.lock().unwrap(), 0);
        let consumed = c.handle_reply(id, TypeLookupReply::Types(GetTypesReply::default()));
        assert!(consumed);
        assert_eq!(*calls.lock().unwrap(), 1);
        assert_eq!(c.pending_count(), 0);
    }

    #[test]
    fn double_reply_runs_callback_only_once() {
        let calls = Arc::new(Mutex::new(0u32));
        let calls_clone = Arc::clone(&calls);
        let mut c = TypeLookupClient::new();
        let id = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *calls_clone.lock().unwrap() += 1;
            }),
        );
        c.handle_reply(id, TypeLookupReply::Types(GetTypesReply::default()));
        c.handle_reply(id, TypeLookupReply::Types(GetTypesReply::default()));
        assert_eq!(*calls.lock().unwrap(), 1);
    }

    #[test]
    fn pending_cap_evicts_oldest() {
        let mut c = TypeLookupClient::with_capacity(2);
        let _id1 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id2 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id3 = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert_eq!(c.pending_count(), 2);
        assert!(c.pending.contains_key(&id2));
        assert!(c.pending.contains_key(&id3));
    }

    #[test]
    fn evicted_request_is_not_dispatched() {
        let calls = Arc::new(Mutex::new(0u32));
        let calls_clone = Arc::clone(&calls);
        let mut c = TypeLookupClient::with_capacity(1);
        let first = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *calls_clone.lock().unwrap() += 1;
            }),
        );
        let _second = c.request_types(alloc::vec![], Box::new(|_| {}));
        let consumed = c.handle_reply(first, TypeLookupReply::Types(GetTypesReply::default()));
        assert!(!consumed);
        assert_eq!(*calls.lock().unwrap(), 0);
        assert_eq!(c.pending_count(), 1);
    }

    #[test]
    fn zero_capacity_is_clamped_to_one() {
        let mut c = TypeLookupClient::with_capacity(0);
        c.request_types(alloc::vec![], Box::new(|_| {}));
        c.request_types(alloc::vec![], Box::new(|_| {}));
        assert_eq!(c.pending_count(), 1);
    }

    #[test]
    fn request_type_dependencies_registers_pending() {
        let mut c = TypeLookupClient::new();
        let id = c.request_type_dependencies(
            alloc::vec![],
            ContinuationPoint(alloc::vec![]),
            Box::new(|_| {}),
        );
        assert_eq!(c.pending_count(), 1);
        assert!(c.pending.contains_key(&id));
    }

    #[test]
    fn clear_drops_all_pending() {
        let mut c = TypeLookupClient::new();
        c.request_types(alloc::vec![], Box::new(|_| {}));
        c.request_types(alloc::vec![], Box::new(|_| {}));
        assert_eq!(c.pending_count(), 2);
        c.clear();
        assert_eq!(c.pending_count(), 0);
    }

    #[test]
    fn request_types_payload_roundtrips() {
        let ids = alloc::vec![
            TypeIdentifier::EquivalenceHashMinimal(EquivalenceHash([0x55; 14])),
            TypeIdentifier::Primitive(zerodds_types::PrimitiveKind::Int32),
        ];
        let bytes = request_types_payload(&ids).unwrap();
        assert!(bytes.len() >= 4);
        // Different id lists must not encode to the same payload.
        let single = request_types_payload(&ids[..1]).unwrap();
        assert_ne!(bytes, single);
    }

    #[test]
    fn dependencies_payload_carries_continuation() {
        let ids = alloc::vec![TypeIdentifier::EquivalenceHashMinimal(EquivalenceHash(
            [0x77; 14]
        ))];
        let with_cp =
            request_dependencies_payload(&ids, ContinuationPoint(alloc::vec![1, 2, 3])).unwrap();
        let without_cp =
            request_dependencies_payload(&ids, ContinuationPoint(alloc::vec![])).unwrap();
        assert!(!with_cp.is_empty());
        // The continuation point must actually influence the encoded request.
        assert_ne!(with_cp, without_cp);
    }

    #[test]
    fn hashes_to_minimal_ids_maps_each() {
        let hashes = alloc::vec![EquivalenceHash([1; 14]), EquivalenceHash([2; 14])];
        let ids = hashes_to_minimal_ids(&hashes);
        assert_eq!(ids.len(), 2);
        assert!(matches!(ids[0], TypeIdentifier::EquivalenceHashMinimal(_)));
    }

    #[test]
    fn callback_can_mutate_via_arc_mutex() {
        // The callback writes what it received into shared state, which the test
        // then observes after dispatch.
        let received = Arc::new(Mutex::new(None::<bool>));
        let sink = Arc::clone(&received);
        let mut c = TypeLookupClient::new();
        let id = c.request_types(
            alloc::vec![],
            Box::new(move |reply| {
                *sink.lock().unwrap() = Some(matches!(reply, TypeLookupReply::Types(_)));
            }),
        );
        assert_eq!(*received.lock().unwrap(), None);
        assert!(c.handle_reply(id, TypeLookupReply::Types(GetTypesReply::default())));
        assert_eq!(*received.lock().unwrap(), Some(true));
    }
}