use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::{Mutex, RwLock, Semaphore, watch};
use crate::connector::{BaseConnector, ConnectorConfig, ShutdownHandle};
use crate::error::{ConnectorError, Result};
use crate::logger::Logger;
use crate::transport::TransportType;
use crate::types::ConnectorMetrics;
mod registration_runner;
mod shared_channel;
use registration_runner::RegistrationRunner;
use shared_channel::SharedChannel;
/// Transport-level settings shared by every registration that a
/// `MultiConnectorRunner` drives.
///
/// Obtain a value from `MultiTransportOptions::builder()` or from
/// `Default::default()`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct MultiTransportOptions {
/// Server address. In WebSocket mode it is copied onto every registration's
/// config. The default value is `"localhost:50061"`.
pub host: String,
/// TLS switch; in WebSocket mode it is copied onto every registration's config.
pub use_tls: bool,
/// Selects which run path `MultiConnectorRunner::run` takes (gRPC or WebSocket).
pub transport_type: TransportType,
/// Not read in this file; presumably consumed by the shared gRPC channel
/// (`SharedChannel`) — TODO confirm.
pub max_streams_per_channel: usize,
/// Connect timeout in milliseconds (per the field name). Not read in this
/// file — TODO confirm where it is applied.
pub connect_timeout_ms: u64,
/// Reconnect switch; in WebSocket mode it is copied onto every registration's config.
pub reconnect_enabled: bool,
/// Reconnect delay (ms); in WebSocket mode it is copied onto every registration's config.
pub reconnect_delay_ms: u64,
/// Backoff ceiling (ms); in WebSocket mode it is copied onto every registration's config.
pub max_backoff_delay_ms: u64,
/// Reconnect jitter (ms); in WebSocket mode it is copied onto every registration's config.
pub reconnect_jitter_ms: u64,
/// Number of permits in the per-registration request semaphore created in
/// gRPC mode (a value of 0 is treated as 1 there).
pub max_concurrent_requests: usize,
/// Optional heartbeat interval. The builder logs a warning when
/// `heartbeat_timeout` is shorter than this value.
pub heartbeat_interval: Option<Duration>,
/// Optional heartbeat timeout; see `heartbeat_interval`.
pub heartbeat_timeout: Option<Duration>,
}
impl MultiTransportOptions {
pub fn builder() -> MultiTransportOptionsBuilder {
MultiTransportOptionsBuilder::default()
}
}
impl Default for MultiTransportOptions {
    /// Baseline options: plaintext gRPC to `localhost:50061`, reconnect
    /// enabled, and no heartbeat configured.
    fn default() -> Self {
        let host = String::from("localhost:50061");
        Self {
            // Endpoint and transport selection.
            host,
            use_tls: false,
            transport_type: TransportType::Grpc,
            // Capacity limits.
            max_streams_per_channel: 80,
            max_concurrent_requests: 100,
            // Timing values, all in milliseconds.
            connect_timeout_ms: 10_000,
            reconnect_enabled: true,
            reconnect_delay_ms: 500,
            max_backoff_delay_ms: 60_000,
            reconnect_jitter_ms: 500,
            // Heartbeat is opt-in.
            heartbeat_interval: None,
            heartbeat_timeout: None,
        }
    }
}
/// Checks that a configured heartbeat timeout is not shorter than the
/// heartbeat interval.
///
/// Returns `false`, after emitting a `tracing` warning, only when both values
/// are present and `timeout < interval`. Every other combination, including
/// either value being `None`, returns `true`.
fn validate_heartbeat_pair(interval: Option<Duration>, timeout: Option<Duration>) -> bool {
    if let (Some(every), Some(limit)) = (interval, timeout) {
        if limit < every {
            tracing::warn!(
                target: "strike48_connector::heartbeat",
                interval_ms = every.as_millis() as u64,
                timeout_ms = limit.as_millis() as u64,
                "heartbeat_timeout < heartbeat_interval; the watchdog can fire before the first heartbeat reply has a chance to arrive"
            );
            return false;
        }
    }
    true
}
/// Chainable builder for `MultiTransportOptions`.
#[derive(Debug, Clone, Default)]
pub struct MultiTransportOptionsBuilder {
/// Options under construction. Stays `None` until the first setter runs, at
/// which point it is seeded from `MultiTransportOptions::default()`.
inner: Option<MultiTransportOptions>,
}
impl MultiTransportOptionsBuilder {
    /// Returns the options under construction, seeding them from
    /// `MultiTransportOptions::default()` on first access.
    fn opts(&mut self) -> &mut MultiTransportOptions {
        self.inner.get_or_insert_with(Default::default)
    }

    /// Applies `edit` to the options under construction and returns the builder.
    fn with(mut self, edit: impl FnOnce(&mut MultiTransportOptions)) -> Self {
        edit(self.opts());
        self
    }

    /// Sets `host`.
    pub fn host(self, host: impl Into<String>) -> Self {
        let host = host.into();
        self.with(|o| o.host = host)
    }

    /// Sets `use_tls`.
    pub fn use_tls(self, use_tls: bool) -> Self {
        self.with(|o| o.use_tls = use_tls)
    }

    /// Sets `transport_type`.
    pub fn transport_type(self, t: TransportType) -> Self {
        self.with(|o| o.transport_type = t)
    }

    /// Sets `max_streams_per_channel`.
    pub fn max_streams_per_channel(self, n: usize) -> Self {
        self.with(|o| o.max_streams_per_channel = n)
    }

    /// Sets `connect_timeout_ms`.
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        self.with(|o| o.connect_timeout_ms = ms)
    }

    /// Sets `max_concurrent_requests`, raising a value of 0 to 1.
    pub fn max_concurrent_requests(self, n: usize) -> Self {
        let at_least_one = n.max(1);
        self.with(|o| o.max_concurrent_requests = at_least_one)
    }

    /// Sets `reconnect_enabled`.
    pub fn reconnect_enabled(self, enabled: bool) -> Self {
        self.with(|o| o.reconnect_enabled = enabled)
    }

    /// Sets `reconnect_delay_ms`.
    pub fn reconnect_delay_ms(self, ms: u64) -> Self {
        self.with(|o| o.reconnect_delay_ms = ms)
    }

    /// Sets `max_backoff_delay_ms`.
    pub fn max_backoff_delay_ms(self, ms: u64) -> Self {
        self.with(|o| o.max_backoff_delay_ms = ms)
    }

    /// Sets `reconnect_jitter_ms`.
    pub fn reconnect_jitter_ms(self, ms: u64) -> Self {
        self.with(|o| o.reconnect_jitter_ms = ms)
    }

    /// Sets `heartbeat_interval` and re-checks it against the current timeout.
    ///
    /// A mis-ordered pair only produces a warning; the value is still stored.
    pub fn heartbeat_interval(self, d: Duration) -> Self {
        self.with(|o| {
            o.heartbeat_interval = Some(d);
            let _ = validate_heartbeat_pair(o.heartbeat_interval, o.heartbeat_timeout);
        })
    }

    /// Sets `heartbeat_timeout` and re-checks it against the current interval.
    ///
    /// A mis-ordered pair only produces a warning; the value is still stored.
    pub fn heartbeat_timeout(self, d: Duration) -> Self {
        self.with(|o| {
            o.heartbeat_timeout = Some(d);
            let _ = validate_heartbeat_pair(o.heartbeat_interval, o.heartbeat_timeout);
        })
    }

    /// Finishes the builder, falling back to `MultiTransportOptions::default()`
    /// when no setter was ever called.
    pub fn build(self) -> MultiTransportOptions {
        self.inner.unwrap_or_default()
    }
}
/// One connector to be driven by a `MultiConnectorRunner`: its configuration
/// plus the connector implementation.
#[non_exhaustive]
pub struct ConnectorRegistration {
/// Configuration for this registration. Its `tenant_id`, `connector_type`
/// and `instance_id` form the `RegistrationKey`.
pub config: ConnectorConfig,
/// The connector implementation, shared behind an `Arc`.
pub connector: Arc<dyn BaseConnector>,
}
impl ConnectorRegistration {
    /// Wraps an owned connector in an `Arc` and pairs it with `config`.
    pub fn new<T: BaseConnector + 'static>(config: ConnectorConfig, connector: T) -> Self {
        let shared: Arc<dyn BaseConnector> = Arc::new(connector);
        Self::from_arc(config, shared)
    }

    /// Pairs an already shared connector with `config`.
    pub fn from_arc(config: ConnectorConfig, connector: Arc<dyn BaseConnector>) -> Self {
        Self { config, connector }
    }
}
impl std::fmt::Debug for ConnectorRegistration {
    /// Formats the registration. The trait object is rendered through its
    /// `connector_type()` string rather than a `Debug` impl of its own.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ConnectorRegistration");
        out.field("config", &self.config);
        out.field(
            "connector",
            &format_args!(
                "Arc<dyn BaseConnector>(\"{}\")",
                self.connector.connector_type()
            ),
        );
        out.finish()
    }
}
/// Identity of a registration, used for duplicate detection and as the key of
/// the metrics snapshot map. Displays as `tenant_id.connector_type.instance_id`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub struct RegistrationKey {
/// Copied from `ConnectorConfig::tenant_id`.
pub tenant_id: String,
/// Copied from `ConnectorConfig::connector_type`.
pub connector_type: String,
/// Copied from `ConnectorConfig::instance_id`.
pub instance_id: String,
}
impl RegistrationKey {
pub fn from_config(config: &ConnectorConfig) -> Self {
Self {
tenant_id: config.tenant_id.clone(),
connector_type: config.connector_type.clone(),
instance_id: config.instance_id.clone(),
}
}
}
impl std::fmt::Display for RegistrationKey {
    /// Renders the key as `tenant_id.connector_type.instance_id`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            tenant_id,
            connector_type,
            instance_id,
        } = self;
        write!(f, "{tenant_id}.{connector_type}.{instance_id}")
    }
}
/// Internal record the runner keeps for each accepted registration.
struct RegistrationEntry {
/// Identity derived from `config`; used for duplicate checks.
key: RegistrationKey,
/// Configuration supplied with the registration.
config: ConnectorConfig,
/// Shared connector implementation.
connector: Arc<dyn BaseConnector>,
/// Per-registration metrics. Shared so that `metrics_snapshot` can read the
/// same value that is handed to the gRPC registration runner.
metrics: Arc<Mutex<ConnectorMetrics>>,
}
/// Runs several connector registrations using one set of transport options.
pub struct MultiConnectorRunner {
/// Transport options applied to every registration.
opts: MultiTransportOptions,
/// Accepted registrations. `add` and `run` both take the write lock, so the
/// two cannot interleave.
registrations: RwLock<Vec<RegistrationEntry>>,
/// Shutdown flag; exposed through `shutdown_handle` and cloned into the
/// tasks spawned by `run`.
shutdown_requested: Arc<AtomicBool>,
/// True while `run` is in progress; `add` and a second `run` fail with
/// `AlreadyRunning` while it is set.
running: Arc<AtomicBool>,
}
impl MultiConnectorRunner {
/// Creates a runner from `opts` and an initial list of registrations.
///
/// When several registrations share the same [`RegistrationKey`], only the
/// first is kept; each later one is dropped with a warning. Insertion order
/// of the kept registrations is preserved.
pub fn new(opts: MultiTransportOptions, registrations: Vec<ConnectorRegistration>) -> Self {
    let mut unique: Vec<RegistrationEntry> = Vec::with_capacity(registrations.len());
    for incoming in registrations {
        let key = RegistrationKey::from_config(&incoming.config);
        let seen_before = unique.iter().any(|kept| kept.key == key);
        if seen_before {
            tracing::warn!(
                target: "strike48_connector::multi",
                registration = %key,
                "duplicate registration ignored"
            );
        } else {
            unique.push(RegistrationEntry {
                key,
                config: incoming.config,
                connector: incoming.connector,
                metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
            });
        }
    }

    let shutdown_requested = Arc::new(AtomicBool::new(false));
    let running = Arc::new(AtomicBool::new(false));
    Self {
        opts,
        registrations: RwLock::new(unique),
        shutdown_requested,
        running,
    }
}
/// Adds a registration to a runner that has not started running.
///
/// # Errors
///
/// * [`ConnectorError::AlreadyRunning`] if `run` is currently in progress.
/// * [`ConnectorError::InvalidConfig`] if a registration with the same
///   [`RegistrationKey`] already exists.
///
/// In both cases the registration list is left unchanged.
pub async fn add(&self, registration: ConnectorRegistration) -> Result<()> {
    let key = RegistrationKey::from_config(&registration.config);
    // Take the write lock *before* reading `running`: `run` flips the flag
    // while holding the same lock, so an `add` either lands before `run`
    // snapshots the list or is rejected.
    let mut regs = self.registrations.write().await;
    if self.running.load(Ordering::SeqCst) {
        return Err(ConnectorError::AlreadyRunning);
    }
    if regs.iter().any(|e| e.key == key) {
        return Err(ConnectorError::InvalidConfig(format!(
            "duplicate registration: {key}"
        )));
    }
    regs.push(RegistrationEntry {
        key,
        config: registration.config,
        connector: registration.connector,
        metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
    });
    Ok(())
}
pub fn shutdown_handle(&self) -> ShutdownHandle {
ShutdownHandle::from_flag(self.shutdown_requested.clone())
}
/// Returns the keys of all current registrations, in insertion order.
pub async fn registrations(&self) -> Vec<RegistrationKey> {
    let guard = self.registrations.read().await;
    let mut keys = Vec::with_capacity(guard.len());
    for entry in guard.iter() {
        keys.push(entry.key.clone());
    }
    keys
}
pub async fn metrics_snapshot(&self) -> HashMap<RegistrationKey, ConnectorMetrics> {
let regs = self.registrations.read().await;
let mut out = HashMap::with_capacity(regs.len());
for entry in regs.iter() {
let snapshot = entry.metrics.lock().await.clone();
out.insert(entry.key.clone(), snapshot);
}
out
}
pub async fn run(&self) -> Result<()> {
let logger = Logger::new("multi");
let entries: Vec<RegistrationEntry> = {
let regs = self.registrations.write().await;
if self
.running
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return Err(ConnectorError::AlreadyRunning);
}
regs.iter()
.map(|e| RegistrationEntry {
key: e.key.clone(),
config: e.config.clone(),
connector: e.connector.clone(),
metrics: e.metrics.clone(),
})
.collect()
};
if self.shutdown_requested.load(Ordering::SeqCst) {
logger.debug("shutdown signalled before run; exiting");
self.running.store(false, Ordering::SeqCst);
return Ok(());
}
if entries.is_empty() {
logger.warn("no registrations configured; run() exiting immediately");
self.running.store(false, Ordering::SeqCst);
return Ok(());
}
let result = match self.opts.transport_type {
TransportType::Grpc => self.run_grpc(entries, logger).await,
TransportType::WebSocket => self.run_websocket(entries, logger).await,
};
self.running.store(false, Ordering::SeqCst);
result
}
async fn run_grpc(&self, entries: Vec<RegistrationEntry>, logger: Logger) -> Result<()> {
let shared = Arc::new(SharedChannel::new(self.opts.clone()));
let mut tasks = Vec::with_capacity(entries.len());
for entry in entries {
let runner = RegistrationRunner {
key: entry.key.clone(),
config: Arc::new(RwLock::new(entry.config)),
connector: entry.connector,
shared_channel: shared.clone(),
shutdown: self.shutdown_requested.clone(),
metrics: entry.metrics,
opts: self.opts.clone(),
request_semaphore: Arc::new(Semaphore::new(
self.opts.max_concurrent_requests.max(1),
)),
session_token: Arc::new(RwLock::new(None)),
};
tasks.push(tokio::spawn(async move { runner.run().await }));
}
for task in tasks {
match task.await {
Ok(Ok(())) => {}
Ok(Err(e)) => {
logger.warn(&format!("registration runner exited with error: {e}"));
}
Err(join_err) => {
logger.error("registration task panicked", &join_err.to_string());
}
}
}
Ok(())
}
async fn run_websocket(&self, entries: Vec<RegistrationEntry>, logger: Logger) -> Result<()> {
use crate::ConnectorRunner;
let (shutdown_tx, shutdown_rx_template) = watch::channel(false);
let multi_shutdown = self.shutdown_requested.clone();
let bridge_tx = shutdown_tx.clone();
let bridge = tokio::spawn(async move {
while !multi_shutdown.load(Ordering::SeqCst) {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
if bridge_tx.is_closed() {
return;
}
}
let _ = bridge_tx.send(true);
});
let mut tasks = Vec::with_capacity(entries.len());
for entry in entries {
let mut config = entry.config.clone();
config.transport_type = TransportType::WebSocket;
config.host = self.opts.host.clone();
config.use_tls = self.opts.use_tls;
config.reconnect_enabled = self.opts.reconnect_enabled;
config.reconnect_delay_ms = self.opts.reconnect_delay_ms;
config.max_backoff_delay_ms = self.opts.max_backoff_delay_ms;
config.reconnect_jitter_ms = self.opts.reconnect_jitter_ms;
let runner = ConnectorRunner::new(config, entry.connector);
let child_shutdown = runner.shutdown_handle();
let mut shutdown_rx = shutdown_rx_template.clone();
let key = entry.key.clone();
tasks.push(tokio::spawn(async move {
let mut runner_fut = Box::pin(runner.run());
let res = loop {
tokio::select! {
biased;
changed = shutdown_rx.changed() => {
match changed {
Ok(()) if *shutdown_rx.borrow() => {
child_shutdown.shutdown();
}
Err(_) => {
break runner_fut.await;
}
_ => {}
}
}
result = &mut runner_fut => break result,
}
};
(key, res)
}));
}
drop(shutdown_rx_template);
for task in tasks {
match task.await {
Ok((_key, Ok(()))) => {}
Ok((key, Err(e))) => {
logger.warn(&format!("ws registration {key} exited with error: {e}"));
}
Err(join_err) => {
logger.error("ws registration task panicked", &join_err.to_string());
}
}
}
bridge.abort();
let _ = bridge.await;
drop(shutdown_tx);
Ok(())
}
}
// Unit tests for the options builder, registration bookkeeping, and the
// run/shutdown lifecycle of `MultiConnectorRunner`.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::ConnectorBehavior;
// Minimal connector: fixed type/version strings, `execute` returns `{}`.
struct DummyConnector;
impl BaseConnector for DummyConnector {
fn connector_type(&self) -> &str {
"dummy"
}
fn version(&self) -> &str {
"0.0.0"
}
fn execute(
&self,
_: serde_json::Value,
_: Option<&str>,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>,
> {
Box::pin(async { Ok(serde_json::json!({})) })
}
fn behavior(&self) -> ConnectorBehavior {
ConnectorBehavior::Tool
}
}
// Builds a registration whose key is (tenant, ty, inst); every other config
// field takes its default.
fn reg(tenant: &str, ty: &str, inst: &str) -> ConnectorRegistration {
ConnectorRegistration::new(
ConnectorConfig {
tenant_id: tenant.into(),
connector_type: ty.into(),
instance_id: inst.into(),
..ConnectorConfig::default()
},
DummyConnector,
)
}
// An untouched builder must agree with `Default` (spot-checks four fields).
#[test]
fn options_builder_defaults_match_default_impl() {
let built = MultiTransportOptions::builder().build();
let defaulted = MultiTransportOptions::default();
assert_eq!(built.host, defaulted.host);
assert_eq!(built.use_tls, defaulted.use_tls);
assert_eq!(
built.max_streams_per_channel,
defaulted.max_streams_per_channel
);
assert_eq!(built.transport_type, defaulted.transport_type);
}
// Heartbeat values set through the builder come back unchanged.
#[test]
fn options_builder_heartbeat_roundtrip() {
let opts = MultiTransportOptions::builder()
.heartbeat_interval(Duration::from_secs(5))
.heartbeat_timeout(Duration::from_secs(15))
.build();
assert_eq!(opts.heartbeat_interval, Some(Duration::from_secs(5)));
assert_eq!(opts.heartbeat_timeout, Some(Duration::from_secs(15)));
}
// Heartbeat is opt-in: both fields are `None` unless set.
#[test]
fn options_builder_heartbeat_defaults_are_none() {
let opts = MultiTransportOptions::builder().build();
assert!(opts.heartbeat_interval.is_none());
assert!(opts.heartbeat_timeout.is_none());
}
// Only "both present and timeout < interval" is reported as invalid.
#[test]
fn validate_heartbeat_pair_flags_misordered_pair() {
assert!(!validate_heartbeat_pair(
Some(Duration::from_secs(30)),
Some(Duration::from_secs(10))
));
assert!(validate_heartbeat_pair(
Some(Duration::from_secs(30)),
Some(Duration::from_secs(45))
));
assert!(validate_heartbeat_pair(None, Some(Duration::from_secs(5))));
assert!(validate_heartbeat_pair(Some(Duration::from_secs(5)), None));
}
// Explicit builder setters override the defaults.
#[test]
fn options_builder_overrides_apply() {
let opts = MultiTransportOptions::builder()
.host("h:1")
.use_tls(true)
.max_streams_per_channel(42)
.transport_type(TransportType::WebSocket)
.build();
assert_eq!(opts.host, "h:1");
assert!(opts.use_tls);
assert_eq!(opts.max_streams_per_channel, 42);
assert_eq!(opts.transport_type, TransportType::WebSocket);
}
// The key's `Display` form is `tenant.type.instance`.
#[tokio::test]
async fn registration_key_from_config_matches_display_form() {
let r = reg("t", "c", "i");
let k = RegistrationKey::from_config(&r.config);
assert_eq!(k.to_string(), "t.c.i");
}
// `new` keeps the first registration per key and preserves order.
#[tokio::test]
async fn duplicate_registrations_in_new_are_collapsed() {
let runner = MultiConnectorRunner::new(
MultiTransportOptions::default(),
vec![reg("t", "c", "i"), reg("t", "c", "i"), reg("t", "c", "j")],
);
let keys = runner.registrations().await;
assert_eq!(keys.len(), 2, "second duplicate should be dropped");
assert_eq!(keys[0].instance_id, "i");
assert_eq!(keys[1].instance_id, "j");
}
// Unlike `new`, `add` reports a duplicate key as an error.
#[tokio::test]
async fn add_rejects_duplicates() {
let runner =
MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
let err = runner.add(reg("t", "c", "i")).await.unwrap_err();
assert!(matches!(err, ConnectorError::InvalidConfig(_)));
}
// Simulates an in-progress run by setting the flag directly.
#[tokio::test]
async fn add_after_run_starts_is_rejected() {
let runner =
MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
runner.running.store(true, Ordering::SeqCst);
let err = runner.add(reg("t", "c", "j")).await.unwrap_err();
assert!(matches!(&err, ConnectorError::AlreadyRunning));
}
// Same scenario as `add_rejects_duplicates`, additionally pinning that the
// error message mentions "duplicate".
#[tokio::test]
async fn add_rejects_duplicate_with_invalid_config() {
let runner =
MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
let err = runner.add(reg("t", "c", "i")).await.unwrap_err();
assert!(matches!(&err, ConnectorError::InvalidConfig(m) if m.contains("duplicate")));
}
// The handle and the runner share one flag.
#[tokio::test]
async fn shutdown_handle_signals_internal_flag() {
let runner =
MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
let h = runner.shutdown_handle();
assert!(!runner.shutdown_requested.load(Ordering::SeqCst));
h.shutdown();
assert!(runner.shutdown_requested.load(Ordering::SeqCst));
}
// With nothing registered, `run` returns `Ok` without starting a transport.
#[tokio::test]
async fn run_with_empty_registrations_is_ok() {
let runner = MultiConnectorRunner::new(MultiTransportOptions::default(), vec![]);
runner.run().await.expect("empty run should succeed");
}
// Shutdown requested before `run` yields an immediate `Ok`.
#[tokio::test]
async fn run_with_pre_signalled_shutdown_is_ok() {
let runner =
MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
runner.shutdown_handle().shutdown();
runner
.run()
.await
.expect("pre-signalled shutdown should be a clean Ok exit");
}
// NOTE(review): shutdown is signalled before `run()`, so `run()` returns at
// its pre-signalled check and never reaches `run_websocket`; this only
// covers option plumbing plus the early exit.
#[tokio::test]
async fn run_websocket_accepts_config_and_shuts_down_cleanly() {
let opts = MultiTransportOptions::builder()
.transport_type(TransportType::WebSocket)
.host("localhost:65535") .build();
let runner = MultiConnectorRunner::new(opts, vec![reg("t", "c", "i")]);
runner.shutdown_handle().shutdown();
runner
.run()
.await
.expect("pre-signalled WS shutdown should be a clean Ok exit");
}
// Exercises the real WebSocket path: shutdown arrives ~50 ms after `run`
// starts, and `run` must return within the 5 s timeout.
#[tokio::test]
async fn run_websocket_shutdown_does_not_leak_watcher_tasks() {
let opts = MultiTransportOptions::builder()
.transport_type(TransportType::WebSocket)
.host("127.0.0.1:1")
.build();
// Reconnect is switched off by mutating the built options directly.
let mut opts = opts;
opts.reconnect_enabled = false;
let runner = MultiConnectorRunner::new(opts, vec![reg("t", "c", "i")]);
let shutdown = runner.shutdown_handle();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
shutdown.shutdown();
});
let res = tokio::time::timeout(std::time::Duration::from_secs(5), runner.run()).await;
assert!(
res.is_ok(),
"run_websocket must exit within 5s of shutdown signal"
);
res.unwrap()
.expect("run_websocket should return Ok after clean shutdown");
}
// Races `add` against `run` 200 times. Whatever the interleaving, an `Ok`
// from `add` must leave the registration visible, and `AlreadyRunning` must
// leave it absent. Shutdown is pre-signalled so each `run` returns at once.
#[tokio::test]
async fn add_races_with_run_either_lands_or_rejects_never_silently_dropped() {
for _ in 0..200 {
let opts = MultiTransportOptions::default();
let runner = std::sync::Arc::new(MultiConnectorRunner::new(opts, vec![]));
runner.shutdown_handle().shutdown();
let r1 = runner.clone();
let run_task = tokio::spawn(async move { r1.run().await });
// Give the spawned `run` a chance to be polled before `add`.
tokio::task::yield_now().await;
let add_res = runner.add(reg("t", "c", "i")).await;
run_task.await.expect("run join").expect("run ok");
match add_res {
Ok(()) => {
let keys = runner.registrations().await;
assert!(
keys.iter().any(|k| k.instance_id == "i"),
"add() succeeded but registration is not visible"
);
}
Err(ConnectorError::AlreadyRunning) => {
let keys = runner.registrations().await;
assert!(
!keys.iter().any(|k| k.instance_id == "i"),
"add() returned AlreadyRunning but registration was still inserted"
);
}
Err(other) => panic!("unexpected add() error: {other:?}"),
}
}
}
// Simulates an in-progress run by setting the flag directly; a second `run`
// must fail instead of starting.
#[tokio::test]
async fn run_called_twice_returns_already_running() {
let runner = MultiConnectorRunner::new(MultiTransportOptions::default(), vec![]);
runner.running.store(true, Ordering::SeqCst);
let err = runner.run().await.unwrap_err();
assert!(matches!(err, ConnectorError::AlreadyRunning));
}
}