Skip to main content

openwire_core/
context.rs

1use std::fmt;
2use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use crate::{SharedEventListener, SharedEventListenerFactory, TlsAlpnPreference};
7
/// Value handed out first by each of the identifier counters below.
const FIRST_ASSIGNED_ID: u64 = 1;

/// Counter from which every new call context draws its [`CallId`].
static NEXT_CALL_ID: AtomicU64 = AtomicU64::new(FIRST_ASSIGNED_ID);
/// Counter backing [`next_connection_id`].
static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(FIRST_ASSIGNED_ID);
10
/// Identifier of a single call.
///
/// The inner value is private to this module, so new identifiers are only
/// minted here (by the `CallContext` constructors).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CallId(u64);

impl CallId {
    /// Returns the raw numeric value of this identifier.
    pub fn as_u64(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for CallId {
    /// Writes the bare number. Formatting is delegated directly to `u64`, so
    /// width, fill and alignment flags on the formatter are respected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
25
/// Identifier of a connection.
///
/// The inner value is private to this module; fresh identifiers are obtained
/// through [`next_connection_id`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
    /// Returns the raw numeric value of this identifier.
    pub fn as_u64(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for ConnectionId {
    /// Writes the bare number. Formatting is delegated directly to `u64`, so
    /// width, fill and alignment flags on the formatter are respected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
40
41pub fn next_connection_id() -> ConnectionId {
42    ConnectionId(NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed))
43}
44
45#[derive(Clone)]
46pub struct CallContext {
47    inner: Arc<CallContextInner>,
48}
49
50struct CallContextInner {
51    call_id: CallId,
52    listener: SharedEventListener,
53    created_at: Instant,
54    deadline: Option<Instant>,
55    connection_established: AtomicBool,
56    tls_alpn_preference: TlsAlpnPreference,
57}
58
59impl std::fmt::Debug for CallContext {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("CallContext")
62            .field("call_id", &self.call_id())
63            .field("created_at", &self.created_at())
64            .field("deadline", &self.deadline())
65            .field("tls_alpn_preference", &self.tls_alpn_preference())
66            .finish()
67    }
68}
69
70impl CallContext {
71    pub fn new(listener: SharedEventListener, deadline: Option<Duration>) -> Self {
72        Self::with_tls_alpn_preference(listener, deadline, TlsAlpnPreference::Auto)
73    }
74
75    fn with_tls_alpn_preference(
76        listener: SharedEventListener,
77        deadline: Option<Duration>,
78        tls_alpn_preference: TlsAlpnPreference,
79    ) -> Self {
80        let created_at = Instant::now();
81        let deadline = deadline.map(|duration| created_at + duration);
82        Self {
83            inner: Arc::new(CallContextInner {
84                call_id: CallId(NEXT_CALL_ID.fetch_add(1, Ordering::Relaxed)),
85                listener,
86                created_at,
87                deadline,
88                connection_established: AtomicBool::new(false),
89                tls_alpn_preference,
90            }),
91        }
92    }
93
94    pub fn from_factory(
95        factory: &SharedEventListenerFactory,
96        request: &http::Request<crate::RequestBody>,
97        deadline: Option<Duration>,
98    ) -> Self {
99        let listener = factory.create(request);
100        let tls_alpn_preference = request
101            .extensions()
102            .get::<TlsAlpnPreference>()
103            .copied()
104            .unwrap_or_default();
105        Self::with_tls_alpn_preference(listener, deadline, tls_alpn_preference)
106    }
107
108    pub fn call_id(&self) -> CallId {
109        self.inner.call_id
110    }
111
112    pub fn listener(&self) -> &SharedEventListener {
113        &self.inner.listener
114    }
115
116    pub fn created_at(&self) -> Instant {
117        self.inner.created_at
118    }
119
120    pub fn deadline(&self) -> Option<Instant> {
121        self.inner.deadline
122    }
123
124    pub fn tls_alpn_preference(&self) -> TlsAlpnPreference {
125        self.inner.tls_alpn_preference
126    }
127
128    pub fn mark_connection_established(&self) {
129        self.inner
130            .connection_established
131            .store(true, Ordering::Relaxed);
132    }
133
134    pub fn connection_established(&self) -> bool {
135        self.inner.connection_established.load(Ordering::Relaxed)
136    }
137}
138
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{NoopEventListenerFactory, RequestBody, SharedEventListenerFactory};

    /// Listener factory whose listeners ignore every event.
    fn noop_factory() -> SharedEventListenerFactory {
        Arc::new(NoopEventListenerFactory)
    }

    /// Empty-bodied request for `https://example.com/` with no extensions set.
    fn example_request() -> http::Request<RequestBody> {
        http::Request::builder()
            .uri("https://example.com/")
            .body(RequestBody::empty())
            .expect("request")
    }

    #[test]
    fn from_factory_defaults_tls_alpn_preference_to_auto() {
        let request = example_request();

        let ctx = CallContext::from_factory(&noop_factory(), &request, None);

        assert_eq!(ctx.tls_alpn_preference(), TlsAlpnPreference::Auto);
    }

    #[test]
    fn from_factory_preserves_tls_alpn_preference_from_request_extensions() {
        let mut request = example_request();
        request.extensions_mut().insert(TlsAlpnPreference::Http1Only);

        let ctx = CallContext::from_factory(&noop_factory(), &request, None);

        assert_eq!(ctx.tls_alpn_preference(), TlsAlpnPreference::Http1Only);
    }
}