mod grpc;
mod websocket;
pub use grpc::GrpcTransport;
pub use websocket::WebSocketTransport;
use crate::error::Result;
use async_trait::async_trait;
use std::fmt;
use strike48_proto::proto::StreamMessage;
use tokio::sync::mpsc;
/// Identifies which transport implementation should be used.
///
/// [`TransportType::Grpc`] is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportType {
    /// gRPC transport (the default).
    #[default]
    Grpc,
    /// WebSocket transport.
    WebSocket,
}

impl fmt::Display for TransportType {
    /// Writes the human-readable transport name: `gRPC` or `WebSocket`.
    ///
    /// The label is written verbatim; width/fill flags on the formatter are
    /// not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            TransportType::Grpc => "gRPC",
            TransportType::WebSocket => "WebSocket",
        };
        f.write_str(label)
    }
}
/// Configuration handed to a transport when it is constructed
/// (see [`create_transport`]).
///
/// How each field is interpreted is decided by the concrete transport
/// implementations, which are not defined in this module.
#[derive(Debug, Clone)]
pub struct TransportOptions {
    /// Host to connect to.
    // NOTE(review): exact format (port / scheme included?) is not visible here — confirm.
    pub host: String,
    /// Whether TLS should be used.
    pub use_tls: bool,
    /// Optional connect timeout; milliseconds, going by the `_ms` suffix.
    // NOTE(review): behavior when `None` is up to the transport — confirm.
    pub connect_timeout_ms: Option<u64>,
    /// Optional default timeout; milliseconds, going by the `_ms` suffix.
    // NOTE(review): which operations this applies to is not visible here — confirm.
    pub default_timeout_ms: Option<u64>,
    /// Optional channel capacity.
    // NOTE(review): presumably sizes the bounded channels used by the transports — confirm.
    pub channel_capacity: Option<usize>,
}
/// Common interface implemented by every transport.
///
/// Implementors must be `Send + Sync`; the trait is used as a trait object
/// (`Box<dyn Transport>`) by [`create_transport`].
#[async_trait]
pub trait Transport: Send + Sync {
    /// Returns which [`TransportType`] this transport is.
    // NOTE(review): `dead_code` is allowed presumably because no caller in
    // this crate uses the method yet — confirm before removing the attribute.
    #[allow(dead_code)]
    fn transport_type(&self) -> TransportType;

    /// Establishes the connection.
    ///
    /// # Errors
    ///
    /// Returns the crate's error type on failure; the concrete failure
    /// conditions are defined by each implementation.
    async fn connect(&mut self) -> Result<()>;

    /// Starts a message stream, optionally supplying an initial message.
    ///
    /// On success returns an unbounded sender / receiver pair of
    /// [`StreamMessage`]s.
    // NOTE(review): presumably the sender carries outbound messages and the
    // receiver yields inbound ones — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Returns the crate's error type on failure; the concrete failure
    /// conditions are defined by each implementation.
    async fn start_stream(
        &mut self,
        initial_message: Option<StreamMessage>,
    ) -> Result<(
        mpsc::UnboundedSender<StreamMessage>,
        mpsc::UnboundedReceiver<StreamMessage>,
    )>;

    /// Reports whether the transport considers itself connected.
    // NOTE(review): `dead_code` is allowed presumably because no caller in
    // this crate uses the method yet — confirm before removing the attribute.
    #[allow(dead_code)]
    fn is_connected(&self) -> bool;

    /// Tears the connection down.
    ///
    /// # Errors
    ///
    /// Returns the crate's error type on failure; the concrete failure
    /// conditions are defined by each implementation.
    async fn disconnect(&mut self) -> Result<()>;
}
/// Builds the transport implementation that corresponds to `transport_type`.
///
/// `options` is moved into the chosen implementation's constructor:
/// [`GrpcTransport::new`] for [`TransportType::Grpc`] and
/// [`WebSocketTransport::new`] for [`TransportType::WebSocket`]. The result
/// is returned as a boxed [`Transport`] trait object.
pub fn create_transport(
    transport_type: TransportType,
    options: TransportOptions,
) -> Box<dyn Transport> {
    // Exhaustive on purpose: adding a `TransportType` variant must fail to
    // compile here until it is wired to an implementation.
    match transport_type {
        TransportType::WebSocket => Box::new(WebSocketTransport::new(options)),
        TransportType::Grpc => Box::new(GrpcTransport::new(options)),
    }
}
/// Wraps a bounded channel pair behind unbounded endpoints.
///
/// Two forwarding tasks are spawned:
///
/// 1. everything sent on the returned `UnboundedSender` is forwarded into
///    `bounded_tx` (awaiting the bounded send for each message);
/// 2. everything received from `bounded_rx` is forwarded to the returned
///    `UnboundedReceiver`.
///
/// The returned `Vec` holds the join handles of those tasks, in that order.
///
/// Each task stops when its source channel is closed, or when a send to its
/// destination fails. Note that a closed destination is only noticed when a
/// send is attempted, so a task may stay parked on `recv` until the next
/// message arrives or its source closes.
///
/// # Panics
///
/// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
pub(crate) fn create_unbounded_wrapper(
    bounded_tx: mpsc::Sender<StreamMessage>,
    bounded_rx: mpsc::Receiver<StreamMessage>,
) -> (
    mpsc::UnboundedSender<StreamMessage>,
    mpsc::UnboundedReceiver<StreamMessage>,
    Vec<tokio::task::JoinHandle<()>>,
) {
    use tracing::debug;

    // Caller-facing sender -> task 1 -> `bounded_tx`.
    let (caller_tx, mut to_bounded_rx) = mpsc::unbounded_channel::<StreamMessage>();
    // `bounded_rx` -> task 2 -> caller-facing receiver.
    let (from_bounded_tx, caller_rx) = mpsc::unbounded_channel::<StreamMessage>();

    let into_bounded = tokio::spawn(async move {
        while let Some(message) = to_bounded_rx.recv().await {
            let delivered = bounded_tx.send(message).await.is_ok();
            if !delivered {
                debug!("Bounded channel closed, stopping forwarder");
                break;
            }
        }
    });

    let out_of_bounded = tokio::spawn(async move {
        // Rebind mutably inside the task; `recv` needs `&mut self`.
        let mut source = bounded_rx;
        while let Some(message) = source.recv().await {
            let delivered = from_bounded_tx.send(message).is_ok();
            if !delivered {
                debug!("Unbounded channel closed, stopping forwarder");
                break;
            }
        }
    });

    (caller_tx, caller_rx, vec![into_bounded, out_of_bounded])
}