xidl-jsonrpc 0.28.0

An IDL codegen.
Documentation
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, LazyLock, Mutex};

use dashmap::DashMap;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};

use super::{Listener, Stream, loopback_peer_addr};

/// Byte stream used for both ends of an in-process connection: `connect_inproc`
/// keeps one half of a `tokio::io::duplex` pair and hands the other to a listener.
type InprocStream = tokio::io::DuplexStream;

/// The listener currently bound to an endpoint.
struct BoundSlot {
    /// Id of the owning `InprocListener`; its `Drop` only clears a slot whose id matches.
    listener_id: u64,
    /// Sending side of that listener's accept channel; new server-side streams are sent here.
    tx: UnboundedSender<InprocStream>,
}

/// Per-endpoint state stored in the registry.
#[derive(Default)]
struct EndpointEntry {
    /// The listener bound to this endpoint, if any.
    bound: Option<BoundSlot>,
    /// Server-side streams of connections that no listener could take yet;
    /// drained front-to-back into the channel of the next successful `bind`.
    pending: VecDeque<InprocStream>,
}

/// Returns a fresh listener id; ids start at 1 and each call yields a new value.
fn next_listener_id() -> u64 {
    // Ids are only compared for equality, so no ordering with other memory is needed.
    static LISTENER_IDS: AtomicU64 = AtomicU64::new(1);
    LISTENER_IDS.fetch_add(1, Ordering::Relaxed)
}

/// Shared, lockable state for one endpoint.
type EndpointState = Arc<Mutex<EndpointEntry>>;
/// Map from endpoint name (exactly as passed to `bind` / `connect_inproc`) to its state.
type Registry = DashMap<String, EndpointState>;
/// Process-wide registry of in-process endpoints, created lazily on first access.
static REGISTRY: LazyLock<Registry> = LazyLock::new(DashMap::new);

/// Listener for an in-process (`inproc://`) endpoint.
///
/// Created with `InprocListener::bind`. Clients connect with `connect_inproc`,
/// and the resulting streams are returned by the `Listener::accept` impl.
/// Dropping the listener releases its binding on the endpoint.
pub struct InprocListener {
    /// Unique id used to recognise this listener's slot in the registry.
    listener_id: u64,
    /// Endpoint name as passed to `bind` (without the `inproc://` prefix that
    /// `Listener::endpoint` adds).
    endpoint: String,
    /// Receives server-side streams for incoming connections. The async mutex lets
    /// `accept(&self)` hold the receiver across `recv().await`.
    rx: tokio::sync::Mutex<UnboundedReceiver<InprocStream>>,
}

impl InprocListener {
    pub fn bind(endpoint: impl Into<String>) -> std::io::Result<Self> {
        let endpoint = endpoint.into();
        let listener_id = next_listener_id();
        let (tx, rx) = unbounded_channel();

        let entry = REGISTRY
            .entry(endpoint.clone())
            .or_insert_with(|| Arc::new(Mutex::new(EndpointEntry::default())))
            .clone();

        let mut guard = entry
            .lock()
            .map_err(|err| std::io::Error::other(err.to_string()))?;
        if guard.bound.is_some() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::AddrInUse,
                format!("inproc endpoint already in use: {endpoint}"),
            ));
        }

        guard.bound = Some(BoundSlot {
            listener_id,
            tx: tx.clone(),
        });

        while let Some(stream) = guard.pending.pop_front() {
            if let Err(err) = tx.send(stream) {
                let failed = err.0;
                guard.bound = None;
                guard.pending.push_front(failed);
                return Err(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("inproc listener channel closed for endpoint: {endpoint}"),
                ));
            }
        }
        drop(guard);

        Ok(Self {
            listener_id,
            endpoint,
            rx: tokio::sync::Mutex::new(rx),
        })
    }
}

impl Drop for InprocListener {
    fn drop(&mut self) {
        if let Some(state) = REGISTRY.get(&self.endpoint).map(|entry| entry.clone()) {
            let mut should_remove = false;
            if let Ok(mut entry) = state.lock() {
                if entry
                    .bound
                    .as_ref()
                    .map(|slot| slot.listener_id == self.listener_id)
                    .unwrap_or(false)
                {
                    entry.bound = None;
                }
                should_remove = entry.bound.is_none() && entry.pending.is_empty();
            }
            if should_remove {
                REGISTRY.remove(&self.endpoint);
            }
        }
    }
}

pub fn connect_inproc(endpoint: &str) -> std::io::Result<InprocStream> {
    let (client, server) = tokio::io::duplex(64 * 1024);
    let state = REGISTRY
        .entry(endpoint.to_string())
        .or_insert_with(|| Arc::new(Mutex::new(EndpointEntry::default())))
        .clone();
    let mut entry = state
        .lock()
        .map_err(|err| std::io::Error::other(err.to_string()))?;
    if let Some(bound) = entry.bound.as_ref() {
        if let Err(err) = bound.tx.send(server) {
            let failed = err.0;
            entry.bound = None;
            entry.pending.push_back(failed);
        }
    } else {
        entry.pending.push_back(server);
    }
    Ok(client)
}

#[async_trait::async_trait]
impl Listener for InprocListener {
    /// Waits for the next in-process connection on this endpoint.
    ///
    /// The peer address reported for every accepted stream is
    /// `loopback_peer_addr()`. Fails with `BrokenPipe` once the accept channel
    /// is closed and drained.
    async fn accept(
        &self,
    ) -> std::io::Result<(Box<dyn Stream + Unpin + Send + 'static>, SocketAddr)> {
        let mut receiver = self.rx.lock().await;
        match receiver.recv().await {
            Some(server_half) => Ok((Box::new(server_half), loopback_peer_addr())),
            None => Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "inproc listener closed",
            )),
        }
    }

    /// Returns this listener's address in `inproc://<name>` form.
    fn endpoint(&self) -> Option<String> {
        let mut uri = String::from("inproc://");
        uri.push_str(&self.endpoint);
        Some(uri)
    }
}