openwire-core 0.1.1

Shared primitives, policies, bodies, and transport traits for OpenWire
Documentation
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::{SharedEventListener, SharedEventListenerFactory, TlsAlpnPreference};

// Process-wide counters from which new `CallId` / `ConnectionId` values are
// drawn. Both start at 1, so 0 is never handed out before wraparound. The
// allocation sites use `fetch_add` with `Relaxed` ordering: an atomic
// read-modify-write still yields a distinct value per call, and no other memory
// is synchronized through these counters.
static NEXT_CALL_ID: AtomicU64 = AtomicU64::new(1);
static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(1);

/// Numeric identifier of a single call.
///
/// Within this module, ids are minted by `CallContext` from a process-wide
/// atomic counter, so each context receives a distinct value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CallId(u64);

impl CallId {
    /// Returns the raw numeric value of this id.
    pub fn as_u64(self) -> u64 {
        let CallId(raw) = self;
        raw
    }
}

impl fmt::Display for CallId {
    /// Prints the bare number, forwarding the formatter so width, fill and
    /// alignment flags are honoured exactly as for a plain `u64`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

/// Numeric identifier of a single connection.
///
/// The tuple field is private, so within this module fresh values come from
/// `next_connection_id`, which draws them from a process-wide atomic counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
    /// Returns the raw numeric value of this id.
    pub fn as_u64(self) -> u64 {
        let ConnectionId(raw) = self;
        raw
    }
}

impl fmt::Display for ConnectionId {
    /// Prints the bare number, delegating to `u64`'s `Display` so formatter
    /// flags (width, fill, alignment) behave the same as for the raw integer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw: &u64 = &self.0;
        fmt::Display::fmt(raw, f)
    }
}

/// Allocates a new [`ConnectionId`] from the process-wide connection counter.
///
/// Every call returns a value distinct from all previous calls (until the
/// `u64` counter wraps around).
pub fn next_connection_id() -> ConnectionId {
    let raw = NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed);
    ConnectionId(raw)
}

/// Per-call context: the call's id, its event listener, creation time,
/// optional absolute deadline, TLS ALPN preference and a
/// "connection established" flag.
///
/// Cloning only bumps the `Arc` reference count. All clones share the same inner
/// state, so a flag set through one clone is observed through the others.
#[derive(Clone)]
pub struct CallContext {
    inner: Arc<CallContextInner>,
}

/// Shared state behind a [`CallContext`].
///
/// Every field is fixed at construction except `connection_established`. That
/// field is atomic so it can be updated through the shared `Arc`.
struct CallContextInner {
    /// Id taken from `NEXT_CALL_ID` when the context is built.
    call_id: CallId,
    /// Event listener for this call; `from_factory` obtains it via
    /// `SharedEventListenerFactory::create`.
    listener: SharedEventListener,
    /// `Instant::now()` captured at construction.
    created_at: Instant,
    /// Absolute deadline derived from `created_at` and the requested duration, if any.
    deadline: Option<Instant>,
    /// Set to `true` by `mark_connection_established`.
    connection_established: AtomicBool,
    /// ALPN preference resolved at construction.
    tls_alpn_preference: TlsAlpnPreference,
}

impl std::fmt::Debug for CallContext {
    /// Formats the context's id, timing, TLS ALPN preference and
    /// connection-established flag.
    ///
    /// The event listener is intentionally left out. `finish_non_exhaustive`
    /// appends a trailing `..` so readers can tell the output is partial rather
    /// than complete.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CallContext")
            .field("call_id", &self.call_id())
            .field("created_at", &self.created_at())
            .field("deadline", &self.deadline())
            .field("tls_alpn_preference", &self.tls_alpn_preference())
            // Mutable per-call state; including it makes debug dumps useful for
            // diagnosing whether a failure happened before or after connecting.
            .field("connection_established", &self.connection_established())
            .finish_non_exhaustive()
    }
}

impl CallContext {
    /// Creates a context with a fresh [`CallId`] and the
    /// [`TlsAlpnPreference::Auto`] ALPN preference.
    ///
    /// `deadline` is relative: it is added to the creation instant to form the
    /// absolute [`deadline`](Self::deadline). A duration too large to be
    /// represented as an [`Instant`] results in no deadline rather than a panic.
    pub fn new(listener: SharedEventListener, deadline: Option<Duration>) -> Self {
        Self::with_tls_alpn_preference(listener, deadline, TlsAlpnPreference::Auto)
    }

    /// Common constructor behind [`new`](Self::new) and
    /// [`from_factory`](Self::from_factory).
    ///
    /// It allocates the call id, stamps the creation time and converts the
    /// relative `deadline` to an absolute instant.
    fn with_tls_alpn_preference(
        listener: SharedEventListener,
        deadline: Option<Duration>,
        tls_alpn_preference: TlsAlpnPreference,
    ) -> Self {
        let created_at = Instant::now();
        // `Instant + Duration` panics on overflow, so a "practically infinite"
        // timeout such as `Duration::MAX` would otherwise abort call setup. An
        // instant that far away cannot be represented and would never be reached
        // anyway, so treat it as having no deadline.
        let deadline = deadline.and_then(|duration| created_at.checked_add(duration));
        Self {
            inner: Arc::new(CallContextInner {
                // An atomic `fetch_add` yields a distinct id even with `Relaxed`;
                // no other memory is ordered against this counter.
                call_id: CallId(NEXT_CALL_ID.fetch_add(1, Ordering::Relaxed)),
                listener,
                created_at,
                deadline,
                connection_established: AtomicBool::new(false),
                tls_alpn_preference,
            }),
        }
    }

    /// Creates a context for `request`, obtaining its listener from `factory`.
    ///
    /// The TLS ALPN preference is read from the request's extensions. When none
    /// is present it falls back to `TlsAlpnPreference::default()`. `deadline`
    /// is handled as in [`new`](Self::new).
    pub fn from_factory(
        factory: &SharedEventListenerFactory,
        request: &http::Request<crate::RequestBody>,
        deadline: Option<Duration>,
    ) -> Self {
        let listener = factory.create(request);
        let tls_alpn_preference = request
            .extensions()
            .get::<TlsAlpnPreference>()
            .copied()
            .unwrap_or_default();
        Self::with_tls_alpn_preference(listener, deadline, tls_alpn_preference)
    }

    /// Returns the id assigned to this call at construction.
    pub fn call_id(&self) -> CallId {
        self.inner.call_id
    }

    /// Returns the event listener attached to this call.
    pub fn listener(&self) -> &SharedEventListener {
        &self.inner.listener
    }

    /// Returns the instant at which this context was constructed.
    pub fn created_at(&self) -> Instant {
        self.inner.created_at
    }

    /// Returns the absolute deadline, or `None` in two cases: no deadline was
    /// requested, or the requested duration was too large to represent.
    pub fn deadline(&self) -> Option<Instant> {
        self.inner.deadline
    }

    /// Returns the TLS ALPN preference resolved at construction.
    pub fn tls_alpn_preference(&self) -> TlsAlpnPreference {
        self.inner.tls_alpn_preference
    }

    /// Records that a connection has been established for this call.
    ///
    /// The flag lives in the shared inner state, so every clone observes it.
    pub fn mark_connection_established(&self) {
        // `Relaxed`: this is a standalone flag and does not publish any other
        // data written by this type.
        self.inner
            .connection_established
            .store(true, Ordering::Relaxed);
    }

    /// Returns whether [`mark_connection_established`](Self::mark_connection_established)
    /// has been called on this context or any of its clones.
    pub fn connection_established(&self) -> bool {
        self.inner.connection_established.load(Ordering::Relaxed)
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{NoopEventListenerFactory, RequestBody, SharedEventListenerFactory};

    /// Shared no-op listener factory for the tests below.
    fn noop_factory() -> SharedEventListenerFactory {
        Arc::new(NoopEventListenerFactory) as SharedEventListenerFactory
    }

    /// Builds an empty-bodied request to a fixed HTTPS URL.
    fn https_request() -> http::Request<RequestBody> {
        http::Request::builder()
            .uri("https://example.com/")
            .body(RequestBody::empty())
            .expect("request")
    }

    #[test]
    fn from_factory_defaults_tls_alpn_preference_to_auto() {
        let factory = noop_factory();
        let request = https_request();

        let preference = CallContext::from_factory(&factory, &request, None).tls_alpn_preference();

        assert_eq!(preference, TlsAlpnPreference::Auto);
    }

    #[test]
    fn from_factory_preserves_tls_alpn_preference_from_request_extensions() {
        let factory = noop_factory();
        let mut request = https_request();
        let extensions = request.extensions_mut();
        extensions.insert(TlsAlpnPreference::Http1Only);

        let preference = CallContext::from_factory(&factory, &request, None).tls_alpn_preference();

        assert_eq!(preference, TlsAlpnPreference::Http1Only);
    }
}