openwire 0.1.0

OkHttp-inspired async HTTP client for Rust built on hyper and tower
Documentation
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, Weak};

use hyper::client::conn::{http1, http2};
use openwire_core::{BoxTaskHandle, ConnectionId, ConnectionInfo, RequestBody};

use crate::connection::{ConnectionAvailability, ExchangeFinder, RealConnection};
use crate::sync_util::lock_mutex;

const CONNECTION_BINDING_SHARDS: usize = 32;

#[derive(Clone)]
pub(super) struct ConnectionBindings {
    shards: Arc<[Mutex<HashMap<ConnectionId, ConnectionBinding>>]>,
}

enum ConnectionBinding {
    Http1(Http1Binding),
    Http2(Http2Binding),
}

struct Http1Binding {
    info: ConnectionInfo,
    sender: Option<http1::SendRequest<RequestBody>>,
}

struct Http2Binding {
    info: ConnectionInfo,
    sender: http2::SendRequest<RequestBody>,
}

pub(super) enum AcquiredBinding {
    Http1 {
        info: ConnectionInfo,
        sender: http1::SendRequest<RequestBody>,
    },
    Http2 {
        info: ConnectionInfo,
        sender: http2::SendRequest<RequestBody>,
    },
}

pub(super) enum BindingAcquireResult {
    Acquired(AcquiredBinding),
    Busy,
    Stale,
}

impl ConnectionBindings {
    fn shard(
        &self,
        connection_id: ConnectionId,
    ) -> &Mutex<HashMap<ConnectionId, ConnectionBinding>> {
        &self.shards[(connection_id.as_u64() as usize) % self.shards.len()]
    }

    pub(super) fn insert_http1(
        &self,
        connection_id: ConnectionId,
        info: ConnectionInfo,
        sender: http1::SendRequest<RequestBody>,
    ) {
        lock_mutex(self.shard(connection_id)).insert(
            connection_id,
            ConnectionBinding::Http1(Http1Binding {
                info,
                sender: Some(sender),
            }),
        );
    }

    pub(super) fn insert_http2(
        &self,
        connection_id: ConnectionId,
        info: ConnectionInfo,
        sender: http2::SendRequest<RequestBody>,
    ) {
        lock_mutex(self.shard(connection_id)).insert(
            connection_id,
            ConnectionBinding::Http2(Http2Binding { info, sender }),
        );
    }

    pub(super) fn acquire(&self, connection_id: ConnectionId) -> BindingAcquireResult {
        let mut bindings = lock_mutex(self.shard(connection_id));
        let mut remove_stale = false;
        let acquired = match bindings.get_mut(&connection_id) {
            Some(ConnectionBinding::Http1(binding)) => {
                let Some(sender) = binding.sender.take() else {
                    return BindingAcquireResult::Busy;
                };
                if sender.is_closed() {
                    remove_stale = true;
                    BindingAcquireResult::Stale
                } else {
                    BindingAcquireResult::Acquired(AcquiredBinding::Http1 {
                        info: binding.info.clone(),
                        sender,
                    })
                }
            }
            Some(ConnectionBinding::Http2(binding)) => {
                if binding.sender.is_closed() {
                    remove_stale = true;
                    BindingAcquireResult::Stale
                } else if !binding.sender.is_ready() {
                    BindingAcquireResult::Busy
                } else {
                    BindingAcquireResult::Acquired(AcquiredBinding::Http2 {
                        info: binding.info.clone(),
                        sender: binding.sender.clone(),
                    })
                }
            }
            None => BindingAcquireResult::Stale,
        };
        if remove_stale {
            bindings.remove(&connection_id);
        }
        acquired
    }

    pub(super) fn release_http1(
        &self,
        connection_id: ConnectionId,
        sender: http1::SendRequest<RequestBody>,
    ) -> bool {
        if sender.is_closed() {
            self.remove(connection_id);
            return false;
        }

        let mut bindings = lock_mutex(self.shard(connection_id));
        let Some(ConnectionBinding::Http1(binding)) = bindings.get_mut(&connection_id) else {
            return false;
        };
        debug_assert!(
            binding.sender.is_none(),
            "HTTP/1 sender should be checked out"
        );
        binding.sender = Some(sender);
        true
    }

    pub(super) fn remove(&self, connection_id: ConnectionId) {
        lock_mutex(self.shard(connection_id)).remove(&connection_id);
    }
}

impl Default for ConnectionBindings {
    fn default() -> Self {
        let shards = (0..CONNECTION_BINDING_SHARDS)
            .map(|_| Mutex::new(HashMap::new()))
            .collect::<Vec<_>>();
        Self {
            shards: Arc::<[Mutex<HashMap<ConnectionId, ConnectionBinding>>]>::from(shards),
        }
    }
}

#[derive(Clone, Default)]
pub(super) struct ConnectionTaskRegistry {
    inner: Arc<ConnectionTaskRegistryInner>,
}

#[derive(Default)]
pub(super) struct ConnectionTaskRegistryInner {
    next_id: AtomicU64,
    handles: Mutex<HashMap<u64, Option<BoxTaskHandle>>>,
}

impl ConnectionTaskRegistry {
    pub(super) fn reserve(&self) -> (u64, Weak<ConnectionTaskRegistryInner>) {
        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed) + 1;
        lock_mutex(&self.inner.handles).insert(id, None);
        (id, Arc::downgrade(&self.inner))
    }

    pub(super) fn attach(&self, task_id: u64, handle: BoxTaskHandle) {
        let mut handles = lock_mutex(&self.inner.handles);
        if let Some(slot) = handles.get_mut(&task_id) {
            *slot = Some(handle);
            return;
        }
        drop(handles);
        handle.abort();
    }

    pub(super) fn cancel(&self, task_id: u64) {
        lock_mutex(&self.inner.handles).remove(&task_id);
    }

    pub(super) fn complete_weak(inner: &Weak<ConnectionTaskRegistryInner>, task_id: u64) {
        let Some(inner) = inner.upgrade() else {
            return;
        };
        lock_mutex(&inner.handles).remove(&task_id);
    }

    #[cfg(test)]
    pub(super) fn poison_handles_for_test(&self) {
        let _guard = self
            .inner
            .handles
            .lock()
            .expect("poison connection task registry lock for test");
        panic!("poison connection task registry");
    }
}

impl Drop for ConnectionTaskRegistryInner {
    fn drop(&mut self) {
        let handles = lock_mutex(&self.handles);
        for handle in handles.values().filter_map(Option::as_ref) {
            handle.abort();
        }
    }
}

pub(super) fn release_acquired_connection(
    exchange_finder: &Arc<ExchangeFinder>,
    bindings: &Arc<ConnectionBindings>,
    availability: &ConnectionAvailability,
    connection: RealConnection,
    binding: AcquiredBinding,
) {
    match binding {
        AcquiredBinding::Http1 { sender, .. } => {
            if bindings.release_http1(connection.id(), sender)
                && exchange_finder.release(&connection)
            {
                availability.notify();
                return;
            }
            bindings.remove(connection.id());
            let _ = exchange_finder.pool().remove(connection.id());
            availability.notify();
        }
        AcquiredBinding::Http2 { .. } => {
            if exchange_finder.release(&connection) {
                availability.notify();
                return;
            }
            let _ = exchange_finder.pool().remove(connection.id());
            availability.notify();
        }
    }
}