use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tonic::Streaming;
use tonic::transport::{Channel, Endpoint};
use crate::error::{ConnectorError, Result};
use crate::multi::MultiTransportOptions;
use strike48_proto::proto::{StreamMessage, connector_service_client::ConnectorServiceClient};
/// A pooled tonic channel paired with the number of streams currently counted against it.
///
/// Deriving `Clone` replaces hand-written field-by-field copies; a clone shares the same
/// `active_streams` counter (it is an `Arc`), so reservations made through any clone are
/// visible to all of them.
#[derive(Clone)]
struct ChannelEntry {
    /// Handle to the underlying connection that streams are opened on.
    channel: Channel,
    /// Streams counted against `channel`; compared with `max_streams_per_channel` when
    /// choosing a channel to reuse.
    active_streams: Arc<AtomicUsize>,
}
/// Pool of tonic channels that `connect` streams are multiplexed over, with at most
/// `opts.max_streams_per_channel` streams counted against each channel.
pub(crate) struct SharedChannel {
    /// Connection settings (host, TLS flag, connect timeout) plus the per-channel stream cap.
    opts: MultiTransportOptions,
    /// Channels dialed so far; this type only ever appends to the list.
    channels: Mutex<Vec<ChannelEntry>>,
    /// Held while dialing so that concurrent callers do not each open a new channel.
    dial_lock: Mutex<()>,
}
impl SharedChannel {
    /// Creates an empty pool; no channel is dialed until the first
    /// [`SharedChannel::open_stream`] call needs one.
    pub(crate) fn new(opts: MultiTransportOptions) -> Self {
        Self {
            opts,
            channels: Mutex::new(Vec::new()),
            dial_lock: Mutex::new(()),
        }
    }

    /// Number of channels currently held by the pool (test-only).
    #[cfg(test)]
    pub(crate) async fn channel_count(&self) -> usize {
        self.channels.lock().await.len()
    }

    /// Opens a bidirectional `connect` stream on a pooled channel.
    ///
    /// The server receives `initial_message` first, followed by every message sent through
    /// the returned stream's `tx`, whose buffer holds `outbound_capacity` messages. The
    /// stream's slot on its channel is released when the returned [`SharedStream`] is dropped
    /// (or immediately, if opening the stream fails).
    ///
    /// # Errors
    ///
    /// - [`ConnectorError::InvalidConfig`] if `outbound_capacity` is zero or the configured
    ///   host does not form a valid endpoint URL.
    /// - [`ConnectorError::ConnectionError`] if a new channel cannot be dialed or the server
    ///   rejects the stream.
    pub(crate) async fn open_stream(
        &self,
        initial_message: StreamMessage,
        outbound_capacity: usize,
    ) -> Result<SharedStream> {
        // `tokio::sync::mpsc::channel` panics on a zero capacity; report it instead.
        if outbound_capacity == 0 {
            return Err(ConnectorError::InvalidConfig(
                "outbound_capacity must be greater than zero".to_string(),
            ));
        }

        // `acquire_channel` returns a channel with this stream's slot already counted.
        // Wrap the counter in the guard right away so any early return below (notably
        // `connect` failing) gives the slot back.
        let entry = self.acquire_channel().await?;
        let release = ChannelStreamGuard {
            counter: entry.active_streams,
        };
        let mut client = ConnectorServiceClient::new(entry.channel);

        let (tx, mut rx) = mpsc::channel::<StreamMessage>(outbound_capacity);
        // The outbound stream ends once every sender for `rx` has been dropped.
        let outbound = async_stream::stream! {
            yield initial_message;
            while let Some(msg) = rx.recv().await {
                yield msg;
            }
        };
        let response = client.connect(outbound).await.map_err(|status| {
            ConnectorError::ConnectionError(format!(
                "failed to open shared-channel stream: {status}"
            ))
        })?;

        Ok(SharedStream {
            tx,
            inbound: response.into_inner(),
            _release: release,
        })
    }

    /// Returns a channel with one stream slot already reserved for the caller.
    ///
    /// Reuses the first pooled channel below `max_streams_per_channel`; otherwise dials a
    /// new channel and adds it to the pool. The caller must release the reservation by
    /// wrapping the returned `active_streams` counter in a `ChannelStreamGuard`.
    async fn acquire_channel(&self) -> Result<ChannelEntry> {
        let cap = self.opts.max_streams_per_channel;

        // Fast path: reuse a pooled channel without waiting on the dial lock.
        {
            let channels = self.channels.lock().await;
            if let Some(hit) = Self::reserve_under_cap(&channels, cap) {
                return Ok(hit);
            }
        }

        // Slow path: only one task dials at a time. Re-check after taking the dial lock,
        // since the previous holder may have just added a channel with spare capacity.
        let _dial_guard = self.dial_lock.lock().await;
        {
            let channels = self.channels.lock().await;
            if let Some(hit) = Self::reserve_under_cap(&channels, cap) {
                return Ok(hit);
            }
        }

        // Dial without holding `channels` so other callers can keep reusing pooled channels.
        let endpoint = build_endpoint(&self.opts)?;
        let channel = endpoint.connect().await.map_err(|e| {
            ConnectorError::ConnectionError(format!(
                "failed to connect tonic channel to {}: {e}",
                self.opts.host
            ))
        })?;

        let mut channels = self.channels.lock().await;
        // A slot may have been freed on an existing channel while we were dialing.
        if let Some(hit) = Self::reserve_under_cap(&channels, cap) {
            return Ok(hit);
        }
        // The new channel starts with the caller's reservation already counted.
        let active_streams = Arc::new(AtomicUsize::new(1));
        channels.push(ChannelEntry {
            channel: channel.clone(),
            active_streams: Arc::clone(&active_streams),
        });
        Ok(ChannelEntry {
            channel,
            active_streams,
        })
    }

    /// Reserves one stream slot on the first channel in `channels` below `cap`.
    ///
    /// Must only be called with the slice borrowed from the `channels` lock guard. Every
    /// increment of an `active_streams` counter happens under that lock, and the only
    /// unlocked updates are decrements from `ChannelStreamGuard`. Counts can therefore only
    /// shrink between the check and the increment, so a reservation never pushes a channel
    /// past `cap`.
    fn reserve_under_cap(channels: &[ChannelEntry], cap: usize) -> Option<ChannelEntry> {
        let hit = pick_under_cap(channels, cap)?;
        hit.active_streams.fetch_add(1, Ordering::SeqCst);
        Some(hit)
    }
}
fn pick_under_cap(channels: &[ChannelEntry], cap: usize) -> Option<ChannelEntry> {
channels.iter().find_map(|entry| {
if entry.active_streams.load(Ordering::SeqCst) < cap {
Some(ChannelEntry {
channel: entry.channel.clone(),
active_streams: entry.active_streams.clone(),
})
} else {
None
}
})
}
/// A `connect` stream opened on a pooled channel.
///
/// Fields drop in declaration order, so the outbound sender and inbound stream are released
/// before `_release` hands this stream's slot back to its channel.
pub(crate) struct SharedStream {
    /// Sender for messages after the initial one. The outbound stream ends once every clone
    /// of this sender has been dropped.
    // NOTE(review): presumably allowed because no code in this crate reads `tx` yet — confirm.
    #[allow(dead_code)]
    pub tx: mpsc::Sender<StreamMessage>,
    /// Messages received from the server on this stream.
    pub inbound: Streaming<StreamMessage>,
    /// Decrements the owning channel's stream count when dropped.
    _release: ChannelStreamGuard,
}
/// Returns one stream slot to a channel by decrementing its counter when dropped.
struct ChannelStreamGuard {
    /// Stream counter shared with the owning channel entry.
    counter: Arc<AtomicUsize>,
}

impl Drop for ChannelStreamGuard {
    fn drop(&mut self) {
        let counter: &AtomicUsize = &self.counter;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Builds the tonic [`Endpoint`] for `opts.host`.
///
/// Uses `https` when `opts.use_tls` is set and `http` otherwise, and enables HTTP/2
/// keep-alive pings (also while idle) plus the configured connect timeout.
///
/// # Errors
///
/// Returns [`ConnectorError::InvalidConfig`] when the resulting URL cannot be parsed.
fn build_endpoint(opts: &MultiTransportOptions) -> Result<Endpoint> {
    // Keep-alive ping interval and how long to wait for each ping's acknowledgement.
    const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(30);
    const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(10);

    let url = if opts.use_tls {
        format!("https://{}", opts.host)
    } else {
        format!("http://{}", opts.host)
    };
    let parsed = Endpoint::from_shared(url.clone()).map_err(|err| {
        ConnectorError::InvalidConfig(format!("invalid endpoint url '{url}': {err}"))
    })?;
    let connect_timeout = Duration::from_millis(opts.connect_timeout_ms);
    Ok(parsed
        .keep_alive_while_idle(true)
        .http2_keep_alive_interval(KEEP_ALIVE_INTERVAL)
        .keep_alive_timeout(KEEP_ALIVE_TIMEOUT)
        .connect_timeout(connect_timeout))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::TransportType;

    /// Constructing a pool must not dial any channel up front.
    #[tokio::test]
    async fn shared_channel_starts_with_no_channels() {
        let shared = SharedChannel::new(MultiTransportOptions {
            host: "localhost:50061".into(),
            transport_type: TransportType::Grpc,
            ..Default::default()
        });
        let pooled = shared.channel_count().await;
        assert_eq!(pooled, 0);
    }
}