use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use super::ids::{ExportId, IdAllocator, ImportId};
use crate::RpcTarget;
/// Shared slot holding the sending half of an exported promise's watch channel.
///
/// Whoever settles the promise `take()`s the sender out of the `Option`, so it is used at
/// most once; the `Arc` lets the export table and callers of `ExportTable::get` share the
/// same slot.
type PromiseSender = Arc<
    tokio::sync::Mutex<
        Option<tokio::sync::watch::Sender<Option<Result<Value, Value>>>>,
    >,
>;
/// Dynamically-typed value exchanged over RPC.
///
/// Variants `Null` through `Error` are plain data handled by the derived serde impls.
/// `Stub` and `Promise` carry local references and are marked `#[serde(skip)]`, so the
/// derived serde impls do not encode or decode them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Value {
Null,
Bool(bool),
/// Numeric value, stored as a `serde_json::Number`.
Number(serde_json::Number),
String(String),
Array(Vec<Value>),
/// String-keyed map of values.
Object(std::collections::HashMap<String, Box<Value>>),
/// Timestamp as an `f64`.
/// NOTE(review): the unit/epoch is not shown in this file (presumably milliseconds since
/// the Unix epoch, as in JavaScript) — confirm.
Date(f64),
/// Error value: an error-type string, a message, and an optional stack trace.
Error {
error_type: String,
message: String,
stack: Option<String>,
},
/// Reference to a local RPC target; not handled by the derived serde impls.
#[serde(skip)]
Stub(StubReference),
/// Reference to a promise; not handled by the derived serde impls.
#[serde(skip)]
Promise(PromiseReference),
}
/// A local RPC target wrapped together with a string identity.
///
/// Cloning shares the same target (`Arc`) and copies the `id`, so clones compare equal.
#[derive(Debug, Clone)]
pub struct StubReference {
/// Identity compared by this type's `PartialEq` impl.
pub id: String,
/// The wrapped target, handed out by `StubReference::get`.
// NOTE(review): `stub` is read by `get()`, so this allow may be stale — confirm nothing
// warns without it (e.g. if the module is private and `get` unused) and remove it.
#[allow(dead_code)]
stub: Arc<dyn RpcTarget>,
}
impl StubReference {
    /// Wraps `stub` in a new reference whose `id` is a freshly generated v4 UUID string.
    pub fn new(stub: Arc<dyn RpcTarget>) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        Self { id, stub }
    }

    /// Returns another shared handle to the wrapped target.
    pub fn get(&self) -> Arc<dyn RpcTarget> {
        Arc::clone(&self.stub)
    }
}
impl PartialEq for StubReference {
    /// Two references are equal exactly when their `id` strings match; the wrapped targets
    /// themselves are never compared.
    fn eq(&self, other: &Self) -> bool {
        self.id.as_str() == other.id.as_str()
    }
}
/// Identifies a promise by a string ID.
///
/// NOTE(review): nothing in this file shows how `id` is generated or looked up — confirm
/// against the code that creates these.
#[derive(Debug, Clone, PartialEq)]
pub struct PromiseReference {
pub id: String,
}
/// State of a promise: still pending, resolved, or rejected.
///
/// NOTE(review): not used elsewhere in this file; the `Pending` receiver presumably has the
/// same shape as the one returned by `ExportTable::export_promise` — confirm.
#[derive(Debug)]
pub enum PromiseState {
/// Not yet settled; holds a watch receiver over `Option<Result<Value, Value>>`.
Pending(tokio::sync::watch::Receiver<Option<Result<Value, Value>>>),
/// Resolved with this value.
Resolved(Value),
/// Rejected with this value.
Rejected(Value),
}
/// One slot in an `ImportTable`: the imported value plus its reference count.
#[derive(Debug)]
pub struct ImportEntry {
pub value: ImportValue,
/// Set to 1 by `ImportTable::insert`, raised by `add_ref`, lowered by `release`
/// (which removes the entry once it reaches 0).
pub refcount: AtomicU32,
}
/// What an import-table entry holds.
#[derive(Debug, Clone)]
pub enum ImportValue {
/// A stub reference.
Stub(StubReference),
/// An unresolved promise; `ImportTable::resolve_promise` replaces it with `Value`.
Promise(PromiseReference),
/// A plain value, including the result of a resolved promise.
Value(Value),
}
/// One slot in an `ExportTable`: the exported value plus its export count.
#[derive(Debug)]
pub struct ExportEntry {
pub value: ExportValue,
/// Set to 1 by `ExportTable::insert`, raised by `add_export`, lowered by `release`
/// (which removes the entry once it reaches 0).
pub export_count: AtomicU32,
}
/// What an export-table entry holds.
#[derive(Debug)]
pub enum ExportValue {
/// An exported local RPC target.
Stub(StubReference),
/// A pending promise; its sender is taken when `ExportTable::resolve`/`reject` settles it.
Promise(PromiseSender),
/// A promise settled by `ExportTable::resolve`.
Resolved(Value),
/// A promise settled by `ExportTable::reject`.
Rejected(Value),
}
/// Owned copy of an `ExportValue`, returned by `ExportTable::get` (the table keeps the
/// original).
///
/// Variants mirror `ExportValue` one-to-one. `Promise` holds a clone of the same `Arc`, so it
/// refers to the same sender slot as the table entry.
#[derive(Debug)]
pub enum ExportValueRef {
Stub(StubReference),
Promise(PromiseSender),
Resolved(Value),
Rejected(Value),
}
/// Concurrent table of imports keyed by `ImportId`, with per-entry reference counts.
#[derive(Debug)]
pub struct ImportTable {
/// Source of locally allocated import IDs; taken as an `Arc`, so it may be shared.
allocator: Arc<IdAllocator>,
entries: DashMap<ImportId, ImportEntry>,
}
impl ImportTable {
    /// Creates an empty import table that draws local IDs from `allocator`.
    pub fn new(allocator: Arc<IdAllocator>) -> Self {
        Self {
            allocator,
            entries: DashMap::new(),
        }
    }

    /// Creates an empty import table with its own, unshared `IdAllocator`.
    pub fn with_default_allocator() -> Self {
        Self::new(Arc::new(IdAllocator::new()))
    }

    /// Allocates a fresh import ID from the table's allocator.
    ///
    /// The ID is only reserved; nothing is inserted into the table.
    pub fn allocate_local(&self) -> ImportId {
        self.allocator.allocate_import()
    }

    /// Inserts `value` under `id` with an initial refcount of 1.
    ///
    /// # Errors
    ///
    /// Returns [`TableError::DuplicateImport`] if `id` is already present. The existing entry
    /// is left untouched and `value` is dropped.
    pub fn insert(&self, id: ImportId, value: ImportValue) -> Result<(), TableError> {
        let mut inserted = false;
        // Check-and-insert through the entry API under one shard lock. A plain `insert`
        // would replace (and drop) the existing entry before the duplicate could be reported.
        self.entries.entry(id).or_insert_with(|| {
            inserted = true;
            ImportEntry {
                value,
                refcount: AtomicU32::new(1),
            }
        });
        if inserted {
            Ok(())
        } else {
            Err(TableError::DuplicateImport(id))
        }
    }

    /// Returns a clone of the value stored under `id`, or `None` if absent.
    pub fn get(&self, id: ImportId) -> Option<ImportValue> {
        self.entries.get(&id).map(|entry| entry.value.clone())
    }

    /// Increments the refcount of `id` by one.
    ///
    /// # Errors
    ///
    /// Returns [`TableError::UnknownImport`] if `id` is not in the table.
    pub fn add_ref(&self, id: ImportId) -> Result<(), TableError> {
        self.entries
            .get(&id)
            .map(|entry| {
                entry.refcount.fetch_add(1, Ordering::SeqCst);
            })
            .ok_or(TableError::UnknownImport(id))
    }

    /// Subtracts `refcount` from the entry's count and removes the entry when it hits zero.
    ///
    /// Returns `Ok(true)` if this call removed the entry. Returns `Ok(false)` otherwise,
    /// including when `id` is unknown, or when `refcount` exceeds the current count (the
    /// count is then left unchanged). Currently never returns `Err`.
    pub fn release(&self, id: ImportId, refcount: u32) -> Result<bool, TableError> {
        // `remove_if` holds the shard's write lock for the whole closure, so the decrement
        // and the removal are one atomic step. Doing them separately would let a concurrent
        // `add_ref` land in between and be deleted along with the entry.
        let removed = self.entries.remove_if(&id, |_key, entry| {
            let current = entry.refcount.load(Ordering::SeqCst);
            if current < refcount {
                // Over-release: ignored, leaving the count as it was.
                return false;
            }
            let remaining = current - refcount;
            entry.refcount.store(remaining, Ordering::SeqCst);
            remaining == 0
        });
        Ok(removed.is_some())
    }

    /// Replaces a pending `ImportValue::Promise` under `id` with `ImportValue::Value(value)`.
    ///
    /// Unknown IDs and entries that are not promises are left unchanged. Currently always
    /// returns `Ok(())`.
    pub fn resolve_promise(&self, id: ImportId, value: Value) -> Result<(), TableError> {
        if let Some(mut entry) = self.entries.get_mut(&id) {
            if matches!(entry.value, ImportValue::Promise(_)) {
                entry.value = ImportValue::Value(value);
            }
        }
        Ok(())
    }
}
/// Concurrent table of exports keyed by `ExportId`, with per-entry export counts.
#[derive(Debug)]
pub struct ExportTable {
/// Source of locally allocated export IDs; taken as an `Arc`, so it may be shared.
allocator: Arc<IdAllocator>,
entries: DashMap<ExportId, ExportEntry>,
}
impl ExportTable {
    /// Creates an empty export table that draws local IDs from `allocator`.
    pub fn new(allocator: Arc<IdAllocator>) -> Self {
        Self {
            allocator,
            entries: DashMap::new(),
        }
    }

    /// Creates an empty export table with its own, unshared `IdAllocator`.
    pub fn with_default_allocator() -> Self {
        Self::new(Arc::new(IdAllocator::new()))
    }

    /// Allocates a fresh export ID from the table's allocator without inserting anything.
    pub fn allocate_local(&self) -> ExportId {
        self.allocator.allocate_export()
    }

    /// Inserts `value` under `id` with an initial export count of 1.
    ///
    /// # Errors
    ///
    /// Returns [`TableError::DuplicateExport`] if `id` is already present. The existing entry
    /// is left untouched and `value` is dropped.
    pub fn insert(&self, id: ExportId, value: ExportValue) -> Result<(), TableError> {
        let mut inserted = false;
        // Check-and-insert through the entry API under one shard lock. A plain `insert`
        // would replace (and drop) the existing entry before the duplicate could be reported.
        self.entries.entry(id).or_insert_with(|| {
            inserted = true;
            ExportEntry {
                value,
                export_count: AtomicU32::new(1),
            }
        });
        if inserted {
            Ok(())
        } else {
            Err(TableError::DuplicateExport(id))
        }
    }

    /// Exports a local RPC target under a newly allocated ID and returns that ID.
    pub fn export_stub(&self, stub: Arc<dyn RpcTarget>) -> ExportId {
        let id = self.allocate_local();
        // The ID was just allocated, so a duplicate is not expected and the result is
        // ignored. NOTE(review): assumes the allocator never hands out a live ID.
        let _ = self.insert(id, ExportValue::Stub(StubReference::new(stub)));
        id
    }

    /// Exports a new pending promise.
    ///
    /// Returns the allocated ID and a watch receiver that starts at `None`. The receiver is
    /// set to `Some(Ok(..))` by `resolve` or `Some(Err(..))` by `reject`.
    pub fn export_promise(
        &self,
    ) -> (
        ExportId,
        tokio::sync::watch::Receiver<Option<Result<Value, Value>>>,
    ) {
        let id = self.allocate_local();
        let (tx, rx) = tokio::sync::watch::channel(None);
        // Freshly allocated ID: a duplicate is not expected, so the result is ignored.
        let _ = self.insert(
            id,
            ExportValue::Promise(Arc::new(tokio::sync::Mutex::new(Some(tx)))),
        );
        (id, rx)
    }

    /// Returns an owned snapshot of the entry under `id`, or `None` if absent.
    ///
    /// For a pending promise, the returned `Promise` shares the table's sender slot.
    pub fn get(&self, id: ExportId) -> Option<ExportValueRef> {
        self.entries.get(&id).map(|entry| match &entry.value {
            ExportValue::Stub(stub) => ExportValueRef::Stub(stub.clone()),
            ExportValue::Promise(promise) => ExportValueRef::Promise(promise.clone()),
            ExportValue::Resolved(val) => ExportValueRef::Resolved(val.clone()),
            ExportValue::Rejected(val) => ExportValueRef::Rejected(val.clone()),
        })
    }

    /// Resolves the pending promise under `id` with `value`.
    ///
    /// Sends `Some(Ok(value))` to the promise's watchers (if its sender is still present)
    /// and records the entry as `ExportValue::Resolved`. Unknown IDs, stubs, and
    /// already-settled promises are left unchanged. Currently always returns `Ok(())`.
    pub async fn resolve(&self, id: ExportId, value: Value) -> Result<(), TableError> {
        self.settle(id, Ok(value)).await
    }

    /// Rejects the pending promise under `id` with `error`.
    ///
    /// Sends `Some(Err(error))` to the promise's watchers (if its sender is still present)
    /// and records the entry as `ExportValue::Rejected`. Unknown IDs, stubs, and
    /// already-settled promises are left unchanged. Currently always returns `Ok(())`.
    pub async fn reject(&self, id: ExportId, error: Value) -> Result<(), TableError> {
        self.settle(id, Err(error)).await
    }

    /// Shared implementation of `resolve` (`Ok`) and `reject` (`Err`).
    async fn settle(&self, id: ExportId, outcome: Result<Value, Value>) -> Result<(), TableError> {
        // Clone the sender slot out of the map and release the DashMap guard *before*
        // awaiting. Holding a shard lock across `.await` blocks every other task that
        // touches that shard and can deadlock the runtime.
        let promise_sender = match self.entries.get(&id) {
            Some(entry) => match &entry.value {
                ExportValue::Promise(sender) => Arc::clone(sender),
                // Stubs and already-settled exports are left unchanged.
                _ => return Ok(()),
            },
            None => return Ok(()),
        };

        // Holding the promise mutex serializes concurrent settlers of this promise.
        let mut slot = promise_sender.lock().await;

        // Re-check under the map guard. While we waited, another settler may have won, or the
        // export may have been released (and possibly the ID reused).
        let mut entry = match self.entries.get_mut(&id) {
            Some(entry) => entry,
            None => return Ok(()),
        };
        let still_pending = matches!(
            &entry.value,
            ExportValue::Promise(current) if Arc::ptr_eq(current, &promise_sender)
        );
        if !still_pending {
            return Ok(());
        }

        if let Some(sender) = slot.take() {
            // Having no receivers left is fine; the outcome is still recorded below.
            let _ = sender.send(Some(outcome.clone()));
        }
        entry.value = match outcome {
            Ok(value) => ExportValue::Resolved(value),
            Err(error) => ExportValue::Rejected(error),
        };
        Ok(())
    }

    /// Increments the export count of `id` by one.
    ///
    /// # Errors
    ///
    /// Returns [`TableError::UnknownExport`] if `id` is not in the table.
    pub fn add_export(&self, id: ExportId) -> Result<(), TableError> {
        self.entries
            .get(&id)
            .map(|entry| {
                entry.export_count.fetch_add(1, Ordering::SeqCst);
            })
            .ok_or(TableError::UnknownExport(id))
    }

    /// Decrements the export count of `id` by one and removes the entry when it hits zero.
    ///
    /// Returns `Ok(true)` if this call removed the entry and `Ok(false)` otherwise, including
    /// for unknown IDs. Currently never returns `Err`.
    pub fn release(&self, id: ExportId) -> Result<bool, TableError> {
        // `remove_if` holds the shard's write lock for the whole closure, so the decrement
        // and the removal are one atomic step. Doing them separately would let a concurrent
        // `add_export` land in between and be deleted along with the entry.
        let removed = self.entries.remove_if(&id, |_key, entry| {
            let current = entry.export_count.load(Ordering::SeqCst);
            if current == 0 {
                return false;
            }
            let remaining = current - 1;
            entry.export_count.store(remaining, Ordering::SeqCst);
            remaining == 0
        });
        Ok(removed.is_some())
    }
}
/// Errors returned by `ImportTable` and `ExportTable` operations.
#[derive(Debug, thiserror::Error)]
pub enum TableError {
/// `ImportTable::insert` found the ID already present.
#[error("Duplicate import ID: {0}")]
DuplicateImport(ImportId),
/// `ExportTable::insert` found the ID already present.
#[error("Duplicate export ID: {0}")]
DuplicateExport(ExportId),
/// `ImportTable::add_ref` was given an ID that is not in the table.
#[error("Unknown import ID: {0}")]
UnknownImport(ImportId),
/// `ExportTable::add_export` was given an ID that is not in the table.
#[error("Unknown export ID: {0}")]
UnknownExport(ExportId),
/// NOTE(review): not returned by any code in this file (resolving a non-promise export is
/// currently a silent no-op) — confirm whether it is used elsewhere or should be.
#[error("Cannot resolve non-promise export")]
NotAPromise,
/// NOTE(review): not returned by any code in this file (settling an already-settled
/// export is currently a silent no-op) — confirm whether it is used elsewhere or should be.
#[error("Export already resolved")]
AlreadyResolved,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert a stub, take an extra reference, then release both; the entry must only
    /// disappear after the final release.
    #[tokio::test]
    async fn test_import_table() {
        let allocator = Arc::new(IdAllocator::new());
        let table = ImportTable::new(Arc::clone(&allocator));

        let id = table.allocate_local();
        assert_eq!(id, ImportId(1));

        let target = Arc::new(crate::MockRpcTarget::new());
        table
            .insert(id, ImportValue::Stub(StubReference::new(target)))
            .unwrap();
        assert!(
            matches!(table.get(id).unwrap(), ImportValue::Stub(_)),
            "Expected stub"
        );

        table.add_ref(id).unwrap();
        assert!(!table.release(id, 1).unwrap());
        assert!(table.release(id, 1).unwrap());
        assert!(table.get(id).is_none());
    }

    /// Resolving an exported promise must both notify its watcher and record the value.
    #[tokio::test]
    async fn test_export_table() {
        let allocator = Arc::new(IdAllocator::new());
        let table = ExportTable::new(Arc::clone(&allocator));

        let (id, mut rx) = table.export_promise();
        assert_eq!(id, ExportId(-1));

        let payload = Value::String("result".to_string());
        table.resolve(id, payload).await.unwrap();

        rx.changed().await.unwrap();
        if let Ok(Value::String(s)) = rx.borrow().as_ref().unwrap() {
            assert_eq!(s, "result");
        } else {
            panic!("Expected resolved string");
        }

        if let ExportValueRef::Resolved(Value::String(s)) = table.get(id).unwrap() {
            assert_eq!(s, "result");
        } else {
            panic!("Expected resolved export");
        }
    }

    /// Exporting a stub yields an entry that reads back as a stub.
    #[test]
    fn test_stub_export() {
        let allocator = Arc::new(IdAllocator::new());
        let table = ExportTable::new(Arc::clone(&allocator));
        let target = Arc::new(crate::MockRpcTarget::new());
        let id = table.export_stub(target.clone());
        assert!(
            matches!(table.get(id).unwrap(), ExportValueRef::Stub(_)),
            "Expected stub export"
        );
    }
}