1use crate::{
2 Buffer, Conn, Headers, HttpContext, Method, ProtocolSession, Status, TypeSet, Version,
3 h2::H2Connection, h3::H3Connection, received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8 borrow::Cow,
9 fmt::{self, Debug, Formatter},
10 io,
11 net::IpAddr,
12 pin::Pin,
13 str,
14 sync::Arc,
15 task::{self, Poll},
16 time::Instant,
17};
18use trillium_macros::AsyncWrite;
19
/// An HTTP connection that has been handed off for protocol upgrade.
///
/// It carries over the request/response metadata and the underlying transport from a
/// [`Conn`] (see the `From<Conn<Transport>>` impl), or is assembled directly with
/// `Upgrade::new`. Writes go straight to the transport via the derived `AsyncWrite`;
/// reads are routed through `read_buffered` together with `buffer`.
///
/// `Fieldwork` generates the per-field accessor methods enabled by the struct-level
/// `#[fieldwork(...)]` list, adjusted per field by the `#[field(...)]` attributes below.
#[derive(AsyncWrite, Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// Headers of the request that initiated this upgrade.
    request_headers: Headers,

    /// Response headers; carried over from the conn, or empty when built with `new`.
    response_headers: Headers,

    /// The raw request target, possibly including a `?querystring`.
    ///
    /// The generated getter is disabled because a hand-written `path()` (path without the
    /// query) and `querystring()` (text after the first `?`) are provided instead.
    #[field(get = false)]
    path: Cow<'static, str>,

    /// The request method.
    #[field(copy)]
    method: Method,

    /// Arbitrary typed state carried over from the conn (empty when built with `new`).
    state: TypeSet,

    /// The underlying transport; `#[async_write]` makes the derived `AsyncWrite`
    /// delegate writes to this field.
    #[async_write]
    transport: Transport,

    /// Bytes held alongside the transport and consumed by `read_buffered` on reads.
    ///
    /// NOTE(review): presumably bytes already read off the transport past the request
    /// head — confirm against `received_body::read_buffered`. The generated getter
    /// dereferences to `[u8]`; setters are disabled and `take_buffer` moves the bytes out.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    buffer: Buffer,

    /// Shared server context; `shared_state()` reads from it.
    #[field(deref = false)]
    context: Arc<HttpContext>,

    /// The peer's IP address, if known (`None` when built with `new`).
    #[field(copy)]
    peer_ip: Option<IpAddr>,

    /// Start time carried over from the conn; `Instant::now()` when built with `new`.
    #[field(copy)]
    start_time: Instant,

    /// Request authority, if any.
    authority: Option<Cow<'static, str>>,

    /// Request scheme, if any.
    scheme: Option<Cow<'static, str>>,

    /// Which protocol session (HTTP/1, or an HTTP/2 / HTTP/3 stream) this came from.
    ///
    /// No accessors are generated; use `h2_connection`, `h2_stream_id`,
    /// `h3_connection` and `h3_stream_id` instead.
    #[field = false]
    protocol_session: ProtocolSession,

    /// Protocol name, if any.
    ///
    /// NOTE(review): presumably the extended-CONNECT `:protocol` value — confirm.
    protocol: Option<Cow<'static, str>>,

    /// HTTP version of the originating request; accessors are named `http_version`.
    #[field = "http_version"]
    version: Version,

    /// Response status carried over from the conn, if one was set.
    #[field(copy)]
    status: Option<Status>,

    /// Whether the connection is marked secure (`false` when built with `new`).
    /// `rename_predicates` gives this bool an `is_`-prefixed getter.
    secure: bool,
}
100
101impl<Transport> Upgrade<Transport> {
102 #[doc(hidden)]
103 pub fn new(
104 request_headers: Headers,
105 path: impl Into<Cow<'static, str>>,
106 method: Method,
107 transport: Transport,
108 buffer: Buffer,
109 version: Version,
110 ) -> Self {
111 Self {
112 request_headers,
113 response_headers: Headers::new(),
114 path: path.into(),
115 method,
116 transport,
117 buffer,
118 state: TypeSet::new(),
119 context: Arc::default(),
120 peer_ip: None,
121 start_time: Instant::now(),
122 authority: None,
123 scheme: None,
124 protocol_session: ProtocolSession::Http1,
125 protocol: None,
126 secure: false,
127 version,
128 status: None,
129 }
130 }
131
132 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
134 self.protocol_session.h2_connection()
135 }
136
137 pub fn h2_stream_id(&self) -> Option<u32> {
139 self.protocol_session.h2_stream_id()
140 }
141
142 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
144 self.protocol_session.h3_connection()
145 }
146
147 pub fn h3_stream_id(&self) -> Option<u64> {
149 self.protocol_session.h3_stream_id()
150 }
151
152 pub fn take_buffer(&mut self) -> Vec<u8> {
154 std::mem::take(&mut self.buffer).into()
155 }
156
157 #[doc(hidden)]
158 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
159 (&mut self.buffer, &mut self.transport)
160 }
161
162 pub fn shared_state(&self) -> &TypeSet {
164 self.context.shared_state()
165 }
166
167 pub fn path(&self) -> &str {
169 match self.path.split_once('?') {
170 Some((path, _)) => path,
171 None => &self.path,
172 }
173 }
174
175 pub fn querystring(&self) -> &str {
177 self.path
178 .split_once('?')
179 .map(|(_, query)| query)
180 .unwrap_or_default()
181 }
182
183 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
187 self,
188 f: impl Fn(Transport) -> T,
189 ) -> Upgrade<T> {
190 Upgrade {
195 transport: f(self.transport),
196 path: self.path,
197 method: self.method,
198 state: self.state,
199 buffer: self.buffer,
200 request_headers: self.request_headers,
201 response_headers: self.response_headers,
202 context: self.context,
203 peer_ip: self.peer_ip,
204 start_time: self.start_time,
205 authority: self.authority,
206 scheme: self.scheme,
207 protocol_session: self.protocol_session,
208 protocol: self.protocol,
209 version: self.version,
210 status: self.status,
211 secure: self.secure,
212 }
213 }
214}
215
216impl<Transport> Debug for Upgrade<Transport> {
217 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
218 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
219 .field("request_headers", &self.request_headers)
220 .field("response_headers", &self.response_headers)
221 .field("path", &self.path)
222 .field("method", &self.method)
223 .field("buffer", &self.buffer)
224 .field("context", &self.context)
225 .field("state", &self.state)
226 .field("transport", &format_args!(".."))
227 .field("peer_ip", &self.peer_ip)
228 .field("start_time", &self.start_time)
229 .field("authority", &self.authority)
230 .field("scheme", &self.scheme)
231 .field("protocol_session", &self.protocol_session)
232 .field("protocol", &self.protocol)
233 .field("version", &self.version)
234 .field("status", &self.status)
235 .field("secure", &self.secure)
236 .finish()
237 }
238}
239
240impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
241 fn from(conn: Conn<Transport>) -> Self {
242 let Conn {
249 request_headers,
250 response_headers,
251 path,
252 method,
253 state,
254 transport,
255 buffer,
256 context,
257 peer_ip,
258 start_time,
259 authority,
260 scheme,
261 protocol_session,
262 protocol,
263 version,
264 status,
265 secure,
266 response_body: _,
269 request_body_state: _,
270 after_send: _,
271 request_trailers: _,
272 } = conn;
273
274 Self {
275 request_headers,
276 response_headers,
277 path,
278 method,
279 state,
280 transport,
281 buffer,
282 context,
283 peer_ip,
284 start_time,
285 authority,
286 scheme,
287 protocol_session,
288 protocol,
289 version,
290 status,
291 secure,
292 }
293 }
294}
295
296impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
297 fn poll_read(
298 mut self: Pin<&mut Self>,
299 cx: &mut task::Context<'_>,
300 buf: &mut [u8],
301 ) -> Poll<io::Result<usize>> {
302 let Self {
303 transport, buffer, ..
304 } = &mut *self;
305 read_buffered(buffer, transport, cx, buf)
306 }
307}