use std::sync::Arc;
use iroh::endpoint::Connection;
use iroh::PublicKey;
/// A connection handed out by [`ConnectionPool`], paired with a printable
/// form of the remote peer's id.
#[derive(Clone)]
pub(crate) struct PooledConnection {
    /// The underlying iroh connection.
    pub conn: Connection,
    /// `crate::base32_encode` of `conn.remote_id()`, computed once at
    /// construction; the pool uses it as the `peer` field in its log lines.
    pub remote_id_str: String,
}

impl PooledConnection {
    /// Wraps `conn`, encoding its remote id up front.
    pub fn new(conn: Connection) -> Self {
        let remote_id_str = {
            let remote = conn.remote_id();
            crate::base32_encode(remote.as_bytes())
        };
        Self { conn, remote_id_str }
    }
}
/// Cache key: the pool holds at most one connection per
/// (remote node, ALPN) pair.
#[derive(Hash, Eq, PartialEq, Clone)]
struct PoolKey {
    /// Public key identifying the remote node.
    node_id: PublicKey,
    /// ALPN identifier the connection is pooled under.
    alpn: &'static [u8],
}
/// Cache of connections keyed by remote node id and ALPN.
///
/// Backed by a `moka::future::Cache` bounded by entry count and, optionally,
/// an idle timeout. Connections found to be closed are evicted on lookup.
pub(crate) struct ConnectionPool {
    cache: moka::future::Cache<PoolKey, Arc<PooledConnection>>,
}

impl ConnectionPool {
    /// Entry limit used when the caller does not supply `max_idle`.
    const DEFAULT_MAX_IDLE: usize = 512;

    /// Creates an empty pool.
    ///
    /// * `max_idle` — maximum number of cached connections
    ///   (defaults to [`Self::DEFAULT_MAX_IDLE`]).
    /// * `idle_timeout` — when `Some`, passed to moka's `time_to_idle`;
    ///   when `None`, no idle expiry is configured.
    pub fn new(max_idle: Option<usize>, idle_timeout: Option<std::time::Duration>) -> Self {
        // usize -> u64 is lossless: usize is at most 64 bits on every Rust target.
        let capacity = max_idle.unwrap_or(Self::DEFAULT_MAX_IDLE) as u64;
        let mut builder = moka::future::Cache::builder().max_capacity(capacity);
        if let Some(tti) = idle_timeout {
            builder = builder.time_to_idle(tti);
        }
        Self {
            cache: builder.build(),
        }
    }

    /// Returns a pooled connection to `node_id` for `alpn`, dialing with
    /// `connect_fn` when nothing usable is cached.
    ///
    /// A cached connection whose `close_reason()` is set is evicted and
    /// replaced. The dial goes through moka's `try_get_with`, so concurrent
    /// callers for the same key share a single in-flight initialisation and a
    /// failed dial is not cached.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `connect_fn`, or an error if the
    /// connection obtained from the cache is already closed.
    pub async fn get_or_connect<F, Fut>(
        &self,
        node_id: PublicKey,
        alpn: &'static [u8],
        connect_fn: F,
    ) -> Result<PooledConnection, String>
    where
        F: FnOnce() -> Fut + Send,
        Fut: std::future::Future<Output = Result<Connection, String>> + Send,
    {
        let key = PoolKey { node_id, alpn };

        if let Some(pooled) = self.cache.get(&key).await {
            if pooled.conn.close_reason().is_none() {
                tracing::debug!(peer = %pooled.remote_id_str, "iroh-http: pool hit");
                return Ok((*pooled).clone());
            }
            tracing::debug!(peer = %pooled.remote_id_str, "iroh-http: pool stale, reconnecting");
            // NOTE: a concurrent caller may already have replaced this stale entry
            // between our `get` and this `invalidate`; that fresh entry is then
            // evicted (its holder keeps using it) and we dial again. The cost is an
            // extra connection, never handing out one already known to be closed.
            self.cache.invalidate(&key).await;
        }

        tracing::debug!(
            peer = %crate::base32_encode(key.node_id.as_bytes()),
            "iroh-http: pool miss, connecting"
        );
        let pooled = self
            .cache
            .try_get_with(key.clone(), async move {
                connect_fn()
                    .await
                    .map(|conn| Arc::new(PooledConnection::new(conn)))
            })
            .await
            // moka shares the init error between coalesced callers as `Arc<String>`.
            .map_err(|e| (*e).clone())?;

        // The value may have been produced by another caller's dial and could
        // have closed since; don't leave a dead entry in the pool.
        if pooled.conn.close_reason().is_some() {
            self.cache.invalidate(&key).await;
            return Err("pooled connection closed immediately after connect".to_string());
        }
        Ok((*pooled).clone())
    }

    /// Returns the cached, still-open connection to `node_id` for `alpn`
    /// without dialing.
    ///
    /// Returns `None` on a cache miss or when the cached connection has
    /// closed. A closed entry is evicted, so it no longer occupies capacity or
    /// counts toward [`Self::entry_count_approx`]. This matches the eviction
    /// done by [`Self::get_or_connect`].
    pub async fn get_existing(
        &self,
        node_id: PublicKey,
        alpn: &'static [u8],
    ) -> Option<PooledConnection> {
        let key = PoolKey { node_id, alpn };
        let pooled = self.cache.get(&key).await?;
        if pooled.conn.close_reason().is_none() {
            return Some((*pooled).clone());
        }
        self.cache.invalidate(&key).await;
        None
    }

    /// Exact number of cached entries. Moka's pending maintenance tasks are
    /// flushed first so the count reflects recent inserts and invalidations.
    #[cfg(test)]
    pub async fn len(&self) -> usize {
        self.cache.run_pending_tasks().await;
        self.cache.entry_count() as usize
    }

    /// Number of cached entries without flushing pending maintenance tasks,
    /// so it may lag behind very recent inserts or evictions.
    pub(crate) fn entry_count_approx(&self) -> u64 {
        self.cache.entry_count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A pool built with default limits holds no entries.
    #[tokio::test]
    async fn pool_starts_empty() {
        let pool = ConnectionPool::new(None, None);
        assert_eq!(pool.len().await, 0);
    }

    /// Explicit capacity and idle-timeout settings also yield an empty pool,
    /// exercising the `time_to_idle` branch of `ConnectionPool::new`.
    #[tokio::test]
    async fn configured_pool_starts_empty() {
        let pool = ConnectionPool::new(Some(4), Some(std::time::Duration::from_secs(30)));
        assert_eq!(pool.len().await, 0);
        assert_eq!(pool.entry_count_approx(), 0);
    }
}