use std::collections::HashMap;
use std::sync::Arc;
use magnetar_proto::ConnectionConfig;
use moonpool_core::{Providers, TaskProvider, TimeProvider};
use parking_lot::Mutex;
use tokio::sync::Notify;
use crate::dns::DnsResolver;
use crate::driver::{DriverHandle, ReconnectContext, spawn_supervised as spawn_supervised_driver};
use crate::transport::Transport;
use crate::{ConnectionShared, EngineError, handshake_plain, make_shared_with_providers};
/// Everything the proxy pool needs to open a fresh broker connection.
///
/// A clone of the factory is moved into every dial task started by `spawn_dial`.
#[derive(Clone)]
pub(crate) struct ConnectionFactory<P: Providers> {
    /// Bootstrap broker address (reported by `ProxyConnectionPool::bootstrap_addr`).
    pub(crate) addr: String,
    /// Template config; each dial clones it and overrides `proxy_to_broker_url`.
    pub(crate) bootstrap_config: ConnectionConfig,
    /// Runtime providers (time, network, task spawning) used for dialing.
    pub(crate) providers: P,
    /// Optional service-URL provider, forwarded into the driver's reconnect context.
    pub(crate) service_url_provider: Option<Arc<dyn magnetar_proto::ServiceUrlProvider>>,
    /// Optional custom DNS resolver used when connecting and reconnecting.
    pub(crate) dns_resolver: Option<Arc<dyn DnsResolver>>,
}

impl<P: Providers> std::fmt::Debug for ConnectionFactory<P> {
    // Prints the address plus presence flags only; the config, providers and trait
    // objects are omitted, which `finish_non_exhaustive` signals with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_service_url_provider = self.service_url_provider.is_some();
        let has_dns_resolver = self.dns_resolver.is_some();
        let mut builder = f.debug_struct("ConnectionFactory");
        builder.field("addr", &self.addr);
        builder.field("has_service_url_provider", &has_service_url_provider);
        builder.field("has_dns_resolver", &has_dns_resolver);
        builder.finish_non_exhaustive()
    }
}
/// Pool key: the `(logical, physical)` pair passed to `get_or_open`.
type PoolKey = (String, String);
/// Outcome of one dial, shared (behind an `Arc`) with every caller waiting on it.
type DialOutcome = Result<Arc<ConnectionShared>, EngineError>;

/// Rendezvous between one in-flight dial task and the callers waiting for it.
///
/// The dial task stores its outcome in `result` and then calls
/// `notify.notify_waiters()`. Waiters read `result` and wait on `notify` while it
/// is still `None`. Both fields are `Arc`s, so `handles` hands out views of the
/// same rendezvous rather than independent copies.
struct PendingDial {
    notify: Arc<Notify>,
    result: Arc<Mutex<Option<Arc<DialOutcome>>>>,
}

impl PendingDial {
    /// Creates a rendezvous with no outcome published yet.
    fn new() -> Self {
        let notify = Arc::new(Notify::new());
        let result = Arc::new(Mutex::new(None));
        Self { notify, result }
    }

    /// Returns another handle to the same rendezvous (both `Arc`s are shared).
    fn handles(&self) -> Self {
        let Self { notify, result } = self;
        Self {
            notify: Arc::clone(notify),
            result: Arc::clone(result),
        }
    }
}
/// State of one pool slot.
enum EntryState {
/// A dial for this key is in flight; callers wait on the shared rendezvous.
Pending(PendingDial),
/// The dial succeeded; the connection is handed out to every caller for this key.
Ready {
/// Connection shared with all callers of `get_or_open` for this key.
shared: Arc<ConnectionShared>,
/// Supervised driver handle; `close` takes it (leaving `None`) and joins it.
driver: Mutex<Option<DriverHandle>>,
},
}
/// Pool of connections opened through the proxy, keyed by `(logical, physical)`
/// and shared between concurrent callers of `get_or_open`.
pub(crate) struct ProxyConnectionPool<P: Providers> {
    /// Template used to dial new connections.
    factory: ConnectionFactory<P>,
    /// Pool slots. In this file the lock is never held across an `.await`.
    entries: Mutex<HashMap<PoolKey, Arc<EntryState>>>,
}

impl<P: Providers> std::fmt::Debug for ProxyConnectionPool<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Copy the keys out so the map lock is released before formatting.
        let keys: Vec<PoolKey> = {
            let guard = self.entries.lock();
            guard.keys().cloned().collect()
        };
        f.debug_struct("ProxyConnectionPool")
            .field("factory", &self.factory)
            .field("entries", &keys)
            .finish()
    }
}

impl<P: Providers> ProxyConnectionPool<P> {
    /// Creates an empty pool that dials through `factory`.
    pub(crate) fn new(factory: ConnectionFactory<P>) -> Arc<Self> {
        let entries = Mutex::new(HashMap::new());
        Arc::new(Self { factory, entries })
    }

    /// Address of the bootstrap broker the pool's factory was configured with.
    // NOTE(review): no caller is visible in this file; the `allow` presumably keeps
    // this accessor warning-free for builds that do not use it — confirm.
    #[allow(dead_code)]
    pub(crate) fn bootstrap_addr(&self) -> &str {
        self.factory.addr.as_str()
    }

    /// Number of slots (pending or ready) currently in the pool. Test-only.
    #[cfg(test)]
    #[must_use]
    pub(crate) fn len(&self) -> usize {
        let entries = self.entries.lock();
        entries.len()
    }
}
impl<P: Providers + Send + Sync> ProxyConnectionPool<P> {
    /// Empties the pool and shuts down every ready connection.
    ///
    /// For each `Ready` slot, the connection is closed, the driver is woken via
    /// `driver_waker`, and the driver handle (if still present) is awaited. The join
    /// result is ignored. Connections are shut down one after another. `Pending`
    /// slots are only removed from the map; their dial tasks are not cancelled here.
    pub(crate) async fn close(&self) {
        let slots: Vec<Arc<EntryState>> = {
            let mut entries = self.entries.lock();
            entries.drain().map(|(_key, slot)| slot).collect()
        };
        for slot in slots {
            let (shared, driver) = match &*slot {
                EntryState::Ready { shared, driver } => (shared, driver),
                EntryState::Pending(_) => continue,
            };
            // The connection lock guard is a temporary, released before any await.
            shared.inner.lock().close();
            shared.driver_waker.notify_one();
            let taken = driver.lock().take();
            if let Some(handle) = taken {
                let _ = handle.join().await;
            }
        }
    }
}
/// Returns the pooled connection for `(logical, physical)`, dialing it if needed.
///
/// A `Ready` slot is returned immediately. Otherwise concurrent callers for the
/// same key share one in-flight dial. The first caller inserts a `Pending` slot and
/// starts the dial with `spawn_dial`. Every caller, including the first, then waits
/// on that slot's rendezvous for the published outcome.
///
/// # Errors
///
/// - Returns the dial's error, re-created with `clone_engine_error`, when the dial
///   failed.
/// - Returns `EngineError::Io` with `ErrorKind::TimedOut` when no outcome is
///   published within the bootstrap config's `operation_timeout`. Timing out does
///   not cancel the in-flight dial.
pub(crate) async fn get_or_open<P>(
    pool: Arc<ProxyConnectionPool<P>>,
    logical: &str,
    physical: &str,
    proxy_to_broker_url: Option<String>,
) -> Result<Arc<ConnectionShared>, EngineError>
where
    P: Providers + Send + Sync,
{
    let key: PoolKey = (logical.to_owned(), physical.to_owned());
    // Lookup and registration happen under one lock acquisition, so only one caller
    // can insert the `Pending` slot (and spawn the dial) for an absent key.
    let pending = {
        let mut entries = pool.entries.lock();
        if let Some(state) = entries.get(&key).cloned() {
            match &*state {
                EntryState::Ready { shared, .. } => return Ok(shared.clone()),
                EntryState::Pending(pending) => pending.handles(),
            }
        } else {
            let pending = PendingDial::new();
            let handles = pending.handles();
            let clobbered = entries.insert(key.clone(), Arc::new(EntryState::Pending(pending)));
            debug_assert!(
                clobbered.is_none(),
                "pool entry insert clobbered a live entry — double registration for one key"
            );
            // Release the map lock before spawning: the dial task takes the same
            // (non-reentrant) lock when it finishes.
            drop(entries);
            spawn_dial(
                pool.clone(),
                physical.to_owned(),
                proxy_to_broker_url,
                key.clone(),
                handles.handles(),
            );
            handles
        }
    };
    let time = pool.factory.providers.time();
    let op_timeout = pool.factory.bootstrap_config.operation_timeout;
    let deadline = time.sleep(op_timeout);
    tokio::pin!(deadline);
    loop {
        // Subscribe BEFORE reading the result. `notify_waiters` stores no permit and
        // only wakes `Notified` futures that already exist. Checking first and
        // subscribing afterwards could miss a publish that lands in between, leaving
        // this caller parked until the deadline with a spurious timeout.
        let notified = pending.notify.notified();
        let published = pending.result.lock().as_ref().map(Arc::clone);
        if let Some(outcome) = published {
            return match &*outcome {
                Ok(shared) => Ok(shared.clone()),
                Err(err) => Err(clone_engine_error(err)),
            };
        }
        tokio::select! {
            biased;
            () = notified => {}
            _ = &mut deadline => {
                return Err(EngineError::Io(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    format!("pool dial to {physical} exceeded operation_timeout ({op_timeout:?})"),
                )));
            }
        }
    }
}
/// Spawns the detached task that performs one pooled dial and publishes its outcome.
///
/// The task:
/// 1. runs `build_entry_async` for `physical`;
/// 2. stores a shareable copy of the outcome in `pending.result` and wakes every
///    caller currently waiting on `pending.notify`;
/// 3. updates the pool map, but only if the slot for `key` is still the `Pending`
///    slot this dial was spawned for (same notifier `Arc`).
///
/// Step 3's ownership check matters because `ProxyConnectionPool::close` drains the
/// map while dials may still be running:
/// - Without the check, a dial finishing after `close` would re-insert a `Ready`
///   slot, with a live driver, that nothing shuts down.
/// - A stale dial could also clobber or remove a slot that a later caller
///   registered under the same key.
///
/// A successful dial that lost ownership is closed here the same way `close` shuts
/// down ready slots.
fn spawn_dial<P>(
    pool: Arc<ProxyConnectionPool<P>>,
    physical: String,
    proxy_to_broker_url: Option<String>,
    key: PoolKey,
    pending: PendingDial,
) where
    P: Providers + Send + Sync,
{
    let factory = pool.factory.clone();
    let task = pool.factory.providers.task().clone();
    let _detached = task.spawn_task("magnetar-moonpool-pool-dial", async move {
        let outcome = build_entry_async::<P>(&factory, &physical, proxy_to_broker_url).await;
        let outcome_for_waiters: Arc<DialOutcome> = Arc::new(match &outcome {
            Ok((shared, _)) => Ok(shared.clone()),
            Err(err) => Err(clone_engine_error(err)),
        });
        *pending.result.lock() = Some(outcome_for_waiters);
        pending.notify.notify_waiters();

        // Decide under the map lock; do any async shutdown only after releasing it.
        let orphaned = {
            let mut map = pool.entries.lock();
            let still_ours = matches!(
                map.get(&key).map(|state| &**state),
                Some(EntryState::Pending(registered))
                    if Arc::ptr_eq(&registered.notify, &pending.notify)
            );
            match outcome {
                Ok((shared, driver)) if still_ours => {
                    map.insert(
                        key,
                        Arc::new(EntryState::Ready {
                            shared,
                            driver: Mutex::new(Some(driver)),
                        }),
                    );
                    None
                }
                // The slot was drained (e.g. by `close`) or replaced: nobody will
                // own this connection, so shut it down below.
                Ok(ready) => Some(ready),
                Err(_) => {
                    if still_ours {
                        map.remove(&key);
                    }
                    None
                }
            }
        };
        if let Some((shared, driver)) = orphaned {
            {
                let mut conn = shared.inner.lock();
                conn.close();
            }
            shared.driver_waker.notify_one();
            let _ = driver.join().await;
        }
    });
}
/// Dials `physical`, performs the plain handshake, and starts a supervised driver.
///
/// The bootstrap config is cloned and its `proxy_to_broker_url` is replaced with the
/// given value before it backs the new connection. Connection attempts go through
/// `crate::dial_with_retry`, which receives `connect_max_retries` and
/// `operation_timeout`. Each attempt uses the factory's DNS resolver (if any) and
/// `connect_timeout`.
///
/// # Errors
///
/// Propagates any error from dialing or from `handshake_plain`.
async fn build_entry_async<P: Providers>(
    factory: &ConnectionFactory<P>,
    physical: &str,
    proxy_to_broker_url: Option<String>,
) -> Result<(Arc<ConnectionShared>, DriverHandle), EngineError> {
    let mut config = factory.bootstrap_config.clone();
    config.proxy_to_broker_url = proxy_to_broker_url;
    let (connect_timeout, operation_timeout) = (config.connect_timeout, config.operation_timeout);

    // One connection attempt; `dial_with_retry` decides how often to invoke it.
    let connect_once = || {
        Transport::<P>::connect_with_resolver(
            factory.providers.network(),
            physical,
            factory.dns_resolver.as_deref(),
            factory.providers.time(),
            connect_timeout,
        )
    };
    let mut transport = crate::dial_with_retry::<P, _, _>(
        factory.providers.time(),
        config.connect_max_retries,
        operation_timeout,
        connect_once,
    )
    .await?;

    let shared = make_shared_with_providers::<P>(&factory.providers, config);
    handshake_plain::<P>(
        &shared,
        &mut transport,
        factory.providers.time(),
        None,
        physical,
        false,
    )
    .await?;

    let reconnect = ReconnectContext {
        host_port: physical.to_owned(),
        service_url_provider: factory.service_url_provider.clone(),
        dns_resolver: factory.dns_resolver.clone(),
    };
    let driver = spawn_supervised_driver::<P>(
        Arc::clone(&shared),
        transport,
        reconnect,
        factory.providers.clone(),
    );
    Ok((shared, driver))
}
fn clone_engine_error(err: &EngineError) -> EngineError {
match err {
EngineError::Io(io) => EngineError::Io(std::io::Error::new(io.kind(), io.to_string())),
EngineError::PeerClosed => EngineError::PeerClosed,
EngineError::Config(msg) => EngineError::Config(msg.clone()),
EngineError::Protocol(p) => EngineError::Config(format!("protocol error: {p}")),
EngineError::HandshakeFailed(msg) => EngineError::HandshakeFailed(msg.clone()),
EngineError::Tls(t) => EngineError::Config(format!("tls error: {t}")),
EngineError::MemoryLimitExceeded {
current,
limit,
requested,
} => EngineError::MemoryLimitExceeded {
current: *current,
limit: *limit,
requested: *requested,
},
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use moonpool_core::TokioProviders;

    use super::*;

    /// Factory pointing at a placeholder broker address; these tests never dial it.
    fn dummy_factory() -> ConnectionFactory<TokioProviders> {
        let bootstrap_config = ConnectionConfig {
            operation_timeout: Duration::from_secs(30),
            ..ConnectionConfig::default()
        };
        ConnectionFactory {
            addr: String::from("broker.example.com:6650"),
            bootstrap_config,
            providers: TokioProviders::new(),
            service_url_provider: None,
            dns_resolver: None,
        }
    }

    #[test]
    fn fresh_pool_is_empty() {
        let pool = ProxyConnectionPool::new(dummy_factory());
        assert_eq!(pool.len(), 0, "a newly built pool must not contain entries");
    }

    #[test]
    fn debug_includes_pool_state() {
        let rendered = format!("{:?}", ProxyConnectionPool::new(dummy_factory()));
        for needle in ["ProxyConnectionPool", "entries"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
        }
    }
}