//! `strike48-connector` 0.3.9 — Rust SDK for the Strike48 Connector Framework.
//!
//! Shared HTTP/2 channel for `MultiConnectorRunner` (gRPC mode).
//!
//! Wraps one or more `tonic::Channel`s and hands out new bidi `Connect`
//! streams. The headline guarantee: until the soft cap
//! [`crate::MultiTransportOptions::max_streams_per_channel`] is reached, every
//! new stream is multiplexed onto the same TCP+HTTP/2 connection. Once the
//! cap is hit a new channel is opened lazily and subsequent streams use it
//! (overflow path, lands in a later phase).
//!
//! WebSocket transport is handled separately — `SharedChannel` is gRPC-only.

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use tokio::sync::{Mutex, mpsc};
use tonic::Streaming;
use tonic::transport::{Channel, Endpoint};

use crate::error::{ConnectorError, Result};
use crate::multi::MultiTransportOptions;
use strike48_proto::proto::{StreamMessage, connector_service_client::ConnectorServiceClient};

/// One underlying `tonic::Channel` and the count of streams currently using it.
/// One underlying `tonic::Channel` and the count of streams currently using it.
struct ChannelEntry {
    /// Cloneable handle to the underlying HTTP/2 connection.
    channel: Channel,
    /// Streams currently multiplexed on `channel`; compared against
    /// `max_streams_per_channel` by `pick_under_cap` and decremented by
    /// `ChannelStreamGuard::drop` when a stream ends.
    active_streams: Arc<AtomicUsize>,
}

/// Pool of `tonic::Channel`s. A fresh `MultiConnectorRunner` starts with no
/// channels at all; the first `open_stream` call lazily connects one. When
/// the per-channel stream count reaches `opts.max_streams_per_channel` a new
/// channel is opened and used for subsequent streams.
/// Pool of `tonic::Channel`s. A fresh `MultiConnectorRunner` starts with no
/// channels at all; the first `open_stream` call lazily connects one. When
/// the per-channel stream count reaches `opts.max_streams_per_channel` a new
/// channel is opened and used for subsequent streams.
pub(crate) struct SharedChannel {
    /// Connection settings: host, TLS flag, connect timeout, and the
    /// per-channel stream soft cap.
    opts: MultiTransportOptions,
    /// All currently-open channels, scanned in insertion order when
    /// looking for spare capacity.
    channels: Mutex<Vec<ChannelEntry>>,
    /// Serializes "open a brand-new channel" so a burst of concurrent
    /// callers that each see an empty (or fully-saturated) pool dedupe to
    /// a single dial. The pool lock is NOT held across the dial; only this
    /// dial mutex is, which means callers that find an under-cap channel
    /// in the fast path are never blocked on a slow handshake.
    dial_lock: Mutex<()>,
}

impl SharedChannel {
    /// Create an empty pool; nothing is dialled until the first
    /// [`Self::open_stream`] call.
    pub(crate) fn new(opts: MultiTransportOptions) -> Self {
        Self {
            opts,
            channels: Mutex::new(Vec::new()),
            dial_lock: Mutex::new(()),
        }
    }

    /// Number of underlying TCP connections currently held.
    /// Useful for tests; not part of the public API.
    #[cfg(test)]
    pub(crate) async fn channel_count(&self) -> usize {
        self.channels.lock().await.len()
    }

    /// Open a new bidi `Connect` stream multiplexed on the next available
    /// channel (creating one if needed).
    ///
    /// The `initial_message` is folded into the outbound stream BEFORE
    /// `tonic::client.connect` awaits the response headers. This matches
    /// `transport::grpc::start_stream` and is required because matrix's
    /// elixir-grpc handler does not emit response headers until it receives
    /// at least one message — so calling `connect(empty_stream).await` would
    /// deadlock with the SDK's "send register on the channel after connect
    /// returns" pattern.
    pub(crate) async fn open_stream(
        &self,
        initial_message: StreamMessage,
        outbound_capacity: usize,
    ) -> Result<SharedStream> {
        // `acquire_channel` has already reserved one stream slot on the
        // returned entry (the counter was bumped while the pool lock was
        // held). Wrap the reservation in the guard immediately so every
        // early-return below — including a failed `connect` — releases it.
        let entry = self.acquire_channel().await?;
        let release = ChannelStreamGuard {
            counter: entry.active_streams.clone(),
        };

        let mut client = ConnectorServiceClient::new(entry.channel.clone());

        let (tx, mut rx) = mpsc::channel::<StreamMessage>(outbound_capacity);

        let outbound = async_stream::stream! {
            yield initial_message;
            while let Some(msg) = rx.recv().await {
                yield msg;
            }
        };

        let response = client.connect(outbound).await.map_err(|status| {
            ConnectorError::ConnectionError(format!(
                "failed to open shared-channel stream: {status}"
            ))
        })?;

        Ok(SharedStream {
            tx,
            inbound: response.into_inner(),
            _release: release,
        })
    }

    /// Pick (or lazily dial) a channel with spare capacity and reserve one
    /// stream slot on it.
    ///
    /// The reservation (the `active_streams` increment) happens while the
    /// pool lock is still held. Previously the increment was deferred until
    /// after the stream's `connect` succeeded, so a burst of N concurrent
    /// callers could all scan the pool before any increment landed, all pick
    /// the same channel, and overshoot `max_streams_per_channel` without
    /// bound. Incrementing under the lock makes the scan-then-reserve step
    /// atomic with respect to other acquirers (guard drops only ever FREE
    /// slots, so racing decrements outside the lock are harmless).
    ///
    /// The caller MUST hand the slot to a [`ChannelStreamGuard`] so it is
    /// released on drop, including on the error path.
    async fn acquire_channel(&self) -> Result<ChannelEntry> {
        // Fast path: a channel is already below the cap. The pool lock is
        // held only for the (sync) scan + reservation, then released.
        {
            let channels = self.channels.lock().await;
            if let Some(hit) = pick_under_cap(&channels, self.opts.max_streams_per_channel) {
                hit.active_streams.fetch_add(1, Ordering::SeqCst);
                return Ok(hit);
            }
        }

        // Slow path: serialize concurrent dials so a burst of N callers
        // creates one channel, not N. The pool lock is NOT held across
        // `endpoint.connect().await` — only the dial lock is — so callers
        // with an under-cap channel still see them in the fast path even
        // while a handshake is in flight.
        let _dial_guard = self.dial_lock.lock().await;

        // Re-check after acquiring the dial lock: another caller may have
        // dialled while we waited.
        {
            let channels = self.channels.lock().await;
            if let Some(hit) = pick_under_cap(&channels, self.opts.max_streams_per_channel) {
                hit.active_streams.fetch_add(1, Ordering::SeqCst);
                return Ok(hit);
            }
        }

        let endpoint = build_endpoint(&self.opts)?;
        let channel = endpoint.connect().await.map_err(|e| {
            ConnectorError::ConnectionError(format!(
                "failed to connect tonic channel to {}: {e}",
                self.opts.host
            ))
        })?;

        // Final re-check: between releasing the pool lock above and
        // finishing the dial, yet another caller may have produced an
        // under-cap channel. Prefer reuse to keep TCP usage minimal (the
        // just-dialled `channel` is simply dropped in that case).
        let mut channels = self.channels.lock().await;
        if let Some(hit) = pick_under_cap(&channels, self.opts.max_streams_per_channel) {
            hit.active_streams.fetch_add(1, Ordering::SeqCst);
            return Ok(hit);
        }

        // A new channel enters the pool with its first slot already
        // reserved for this caller.
        let entry = ChannelEntry {
            channel,
            active_streams: Arc::new(AtomicUsize::new(1)),
        };
        let cloned = ChannelEntry {
            channel: entry.channel.clone(),
            active_streams: entry.active_streams.clone(),
        };
        channels.push(entry);
        Ok(cloned)
    }
}

/// Pick the first channel whose active stream count is strictly below
/// `cap`, returning a clone of the entry. Pure helper, no awaits.
/// Scan `channels` in order and return a clone of the first entry whose
/// active stream count is strictly below `cap`. Pure helper, no awaits;
/// callers are expected to hold the pool lock while calling it.
fn pick_under_cap(channels: &[ChannelEntry], cap: usize) -> Option<ChannelEntry> {
    for entry in channels {
        let in_use = entry.active_streams.load(Ordering::SeqCst);
        if in_use < cap {
            return Some(ChannelEntry {
                channel: entry.channel.clone(),
                active_streams: entry.active_streams.clone(),
            });
        }
    }
    None
}

/// A single bidi stream borrowed from a [`SharedChannel`].
///
/// The initial register message is folded into the outbound stream by
/// [`SharedChannel::open_stream`]; subsequent outbound traffic flows through
/// `tx`. Inbound messages arrive via `inbound`. The `_release` field
/// decrements the per-channel stream count when the stream is dropped.
/// A single bidi stream borrowed from a [`SharedChannel`].
///
/// The initial register message is folded into the outbound stream by
/// [`SharedChannel::open_stream`]; subsequent outbound traffic flows through
/// `tx`. Inbound messages arrive via `inbound`. The `_release` field
/// decrements the per-channel stream count when the stream is dropped.
pub(crate) struct SharedStream {
    /// Sender for follow-up messages (heartbeats, execute responses, ...).
    /// Marked `allow(dead_code)` until the dispatch loop wires execute /
    /// heartbeat handling on top of `RegistrationRunner`.
    #[allow(dead_code)]
    pub tx: mpsc::Sender<StreamMessage>,
    /// Server-to-client half of the bidi stream returned by `connect`.
    pub inbound: Streaming<StreamMessage>,
    /// Drop guard: gives this stream's slot back to the owning channel.
    _release: ChannelStreamGuard,
}

/// Drop guard for one stream slot on a channel: holds the channel's shared
/// `active_streams` counter and decrements it exactly once when dropped.
/// Created by `SharedChannel::open_stream` after the matching increment.
struct ChannelStreamGuard {
    counter: Arc<AtomicUsize>,
}

impl Drop for ChannelStreamGuard {
    fn drop(&mut self) {
        // Each guard pairs with exactly one earlier fetch_add, so this
        // decrement cannot underflow.
        self.counter.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Build a `tonic::Endpoint` for `opts.host`, choosing the URL scheme from
/// `opts.use_tls` and applying the pool's keep-alive and connect-timeout
/// settings.
fn build_endpoint(opts: &MultiTransportOptions) -> Result<Endpoint> {
    let url = if opts.use_tls {
        format!("https://{}", opts.host)
    } else {
        format!("http://{}", opts.host)
    };

    let endpoint = Endpoint::from_shared(url.clone())
        .map_err(|e| ConnectorError::InvalidConfig(format!("invalid endpoint url '{url}': {e}")))?;

    // Keep-alive pings hold the multiplexed connection open between streams.
    Ok(endpoint
        .keep_alive_while_idle(true)
        .http2_keep_alive_interval(Duration::from_secs(30))
        .keep_alive_timeout(Duration::from_secs(10))
        .connect_timeout(Duration::from_millis(opts.connect_timeout_ms)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::TransportType;

    /// A freshly-built pool holds zero TCP connections; the first
    /// `open_stream` call is what dials one lazily.
    #[tokio::test]
    async fn shared_channel_starts_with_no_channels() {
        let opts = MultiTransportOptions {
            transport_type: TransportType::Grpc,
            host: "localhost:50061".into(),
            ..Default::default()
        };
        assert_eq!(SharedChannel::new(opts).channel_count().await, 0);
    }
}
}