1use crate::_events::*;
2use crate::_headers::*;
3use crate::_readers::*;
4use crate::_receivebuffer::*;
5use crate::_state::*;
6use crate::_util::*;
7use crate::_writers::*;
8use std::collections::HashMap;
9use std::collections::HashSet;
10
/// Default cap, in bytes, on how much unparsed data `Connection::next_event` tolerates while it
/// is still waiting for a complete event (used when `Connection::new` gets `None`).
const DEFAULT_MAX_INCOMPLETE_EVENT_SIZE: usize = 16 * 1024;
12
13enum RequestOrResponse {
14 Request(Request),
15 Response(Response),
16}
17
18impl RequestOrResponse {
19 pub fn headers(&self) -> &Headers {
20 match self {
21 Self::Request(request) => &request.headers,
22 Self::Response(response) => &response.headers,
23 }
24 }
25
26 pub fn http_version(&self) -> &Vec<u8> {
27 match self {
28 Self::Request(request) => &request.http_version,
29 Self::Response(response) => &response.http_version,
30 }
31 }
32}
33
34impl From<Request> for RequestOrResponse {
35 fn from(value: Request) -> Self {
36 Self::Request(value)
37 }
38}
39
40impl From<Response> for RequestOrResponse {
41 fn from(value: Response) -> Self {
42 Self::Response(value)
43 }
44}
45
46impl From<Event> for RequestOrResponse {
47 fn from(value: Event) -> Self {
48 match value {
49 Event::Request(request) => Self::Request(request),
50 Event::NormalResponse(response) => Self::Response(response),
51 _ => panic!("Invalid event type"),
52 }
53 }
54}
55
56fn _keep_alive<T: Into<RequestOrResponse>>(event: T) -> bool {
57 let event: RequestOrResponse = event.into();
58 let connection = get_comma_header(event.headers(), b"connection");
59 if connection.contains(&b"close".to_vec()) {
60 return false;
61 }
62 if event.http_version() < &b"1.1".to_vec() {
63 return false;
64 }
65 return true;
66}
67
68fn _body_framing<T: Into<RequestOrResponse>>(request_method: &[u8], event: T) -> (&str, isize) {
69 let event: RequestOrResponse = event.into();
70 if let RequestOrResponse::Response(response) = &event {
71 if response.status_code == 204
72 || response.status_code == 304
73 || request_method == b"HEAD"
74 || (request_method == b"CONNECT"
75 && 200 <= response.status_code
76 && response.status_code < 300)
77 {
78 return ("content-length", 0);
79 }
80 assert!(response.status_code >= 200);
81 }
82
83 let trasfer_encodings = get_comma_header(event.headers(), b"transfer-encoding");
84 if !trasfer_encodings.is_empty() {
85 assert!(trasfer_encodings == vec![b"chunked".to_vec()]);
86 return ("chunked", 0);
87 }
88
89 let content_lengths = get_comma_header(event.headers(), b"content-length");
90 if !content_lengths.is_empty() {
91 return (
92 "content-length",
93 std::str::from_utf8(&content_lengths[0])
94 .unwrap()
95 .parse()
96 .unwrap(),
97 );
98 }
99
100 if let RequestOrResponse::Request(_) = event {
101 return ("content-length", 0);
102 } else {
103 return ("http/1.0", 0);
104 }
105}
106
107pub struct Connection {
108 pub our_role: Role,
109 pub their_role: Role,
110 _cstate: ConnectionState,
111 _writer: Option<Box<WriterFnMut>>,
112 _reader: Option<Box<dyn Reader>>,
113 _max_incomplete_event_size: usize,
114 _receive_buffer: ReceiveBuffer,
115 _receive_buffer_closed: bool,
116 pub their_http_version: Option<Vec<u8>>,
117 _request_method: Option<Vec<u8>>,
118 client_is_waiting_for_100_continue: bool,
119}
120
121impl Connection {
122 pub fn new(our_role: Role, max_incomplete_event_size: Option<usize>) -> Self {
123 Self {
124 our_role,
125 their_role: if our_role == Role::Client {
126 Role::Server
127 } else {
128 Role::Client
129 },
130 _cstate: ConnectionState::new(),
131 _writer: match our_role {
132 Role::Client => Some(Box::new(write_request)),
133 Role::Server => Some(Box::new(write_response)),
134 },
135 _reader: match our_role {
136 Role::Server => Some(Box::new(IdleClientReader {})),
137 Role::Client => Some(Box::new(SendResponseServerReader {})),
138 },
139 _max_incomplete_event_size: max_incomplete_event_size
140 .unwrap_or(DEFAULT_MAX_INCOMPLETE_EVENT_SIZE),
141 _receive_buffer: ReceiveBuffer::new(),
142 _receive_buffer_closed: false,
143 their_http_version: None,
144 _request_method: None,
145 client_is_waiting_for_100_continue: false,
146 }
147 }
148
149 pub fn get_states(&self) -> HashMap<Role, State> {
150 self._cstate.states.clone()
151 }
152
153 pub fn get_our_state(&self) -> State {
154 self._cstate.states[&self.our_role]
155 }
156
157 pub fn get_their_state(&self) -> State {
158 self._cstate.states[&self.their_role]
159 }
160
161 pub fn get_client_is_waiting_for_100_continue(&self) -> bool {
162 self.client_is_waiting_for_100_continue
163 }
164
165 pub fn get_they_are_waiting_for_100_continue(&self) -> bool {
166 self.their_role == Role::Client && self.client_is_waiting_for_100_continue
167 }
168
169 pub fn start_next_cycle(&mut self) -> Result<(), ProtocolError> {
170 let old_states = self._cstate.states.clone();
171 self._cstate.start_next_cycle()?;
172 self._request_method = None;
173 self.their_http_version = None;
174 self.client_is_waiting_for_100_continue = false;
175 self._respond_to_state_changes(old_states, None);
176 Ok(())
177 }
178
179 fn _process_error(&mut self, role: Role) {
180 let old_states = self._cstate.states.clone();
181 self._cstate.process_error(role);
182 self._respond_to_state_changes(old_states, None);
183 }
184
185 fn _server_switch_event(&self, event: Event) -> Option<Switch> {
186 if let Event::InformationalResponse(informational_response) = &event {
187 if informational_response.status_code == 101 {
188 return Some(Switch::SwitchUpgrade);
189 }
190 }
191 if let Event::NormalResponse(response) = &event {
192 if self
193 ._cstate
194 .pending_switch_proposals
195 .contains(&Switch::SwitchConnect)
196 && 200 <= response.status_code
197 && response.status_code < 300
198 {
199 return Some(Switch::SwitchConnect);
200 }
201 }
202 return None;
203 }
204
205 fn _process_event(&mut self, role: Role, event: Event) -> Result<(), ProtocolError> {
206 let old_states = self._cstate.states.clone();
207 if role == Role::Client {
208 if let Event::Request(request) = event.clone() {
209 if request.method == b"CONNECT" {
210 self._cstate
211 .process_client_switch_proposal(Switch::SwitchConnect);
212 }
213 if get_comma_header(&request.headers, b"upgrade").len() > 0 {
214 self._cstate
215 .process_client_switch_proposal(Switch::SwitchUpgrade);
216 }
217 }
218 }
219 let server_switch_event = if role == Role::Server {
220 self._server_switch_event(event.clone())
221 } else {
222 None
223 };
224 self._cstate
225 .process_event(role, (&event).into(), server_switch_event)?;
226
227 if let Event::Request(request) = event.clone() {
228 self._request_method = Some(request.method);
229 }
230
231 if role == self.their_role {
232 if let Event::Request(request) = event.clone() {
233 self.their_http_version = Some(request.http_version);
234 }
235 if let Event::NormalResponse(response) = event.clone() {
236 self.their_http_version = Some(response.http_version);
237 }
238 if let Event::InformationalResponse(informational_response) = event.clone() {
239 self.their_http_version = Some(informational_response.http_version);
240 }
241 }
242
243 if let Event::Request(request) = event.clone() {
244 if !_keep_alive(RequestOrResponse::from(request)) {
245 self._cstate.process_keep_alive_disabled();
246 }
247 }
248 if let Event::NormalResponse(response) = event.clone() {
249 if !_keep_alive(RequestOrResponse::from(response)) {
250 self._cstate.process_keep_alive_disabled();
251 }
252 }
253
254 if let Event::Request(request) = event.clone() {
255 if has_expect_100_continue(&request) {
256 self.client_is_waiting_for_100_continue = true;
257 }
258 }
259 match (&event).into() {
260 EventType::InformationalResponse => {
261 self.client_is_waiting_for_100_continue = false;
262 }
263 EventType::NormalResponse => {
264 self.client_is_waiting_for_100_continue = false;
265 }
266 EventType::Data => {
267 if role == Role::Client {
268 self.client_is_waiting_for_100_continue = false;
269 }
270 }
271 EventType::EndOfMessage => {
272 if role == Role::Client {
273 self.client_is_waiting_for_100_continue = false;
274 }
275 }
276 _ => {}
277 }
278
279 self._respond_to_state_changes(old_states, Some(event));
280 Ok(())
281 }
282
283 fn _respond_to_state_changes(
284 &mut self,
285 old_states: HashMap<Role, State>,
286 event: Option<Event>,
287 ) {
288 if self.get_our_state() != old_states[&self.our_role] {
289 let state = self._cstate.states[&self.our_role];
290 self._writer = match state {
291 State::SendBody => {
292 let request_method = self._request_method.clone().unwrap_or(vec![]);
293 let (framing_type, length) = _body_framing(
294 &request_method,
295 RequestOrResponse::from(event.clone().unwrap()),
296 );
297
298 match framing_type {
299 "content-length" => Some(Box::new(content_length_writer(length))),
300 "chunked" => Some(Box::new(chunked_writer())),
301 "http/1.0" => Some(Box::new(http10_writer())),
302 _ => {
303 panic!("Invalid role and framing type combination");
304 }
305 }
306 }
307 _ => match (&self.our_role, state) {
308 (Role::Client, State::Idle) => Some(Box::new(write_request)),
309 (Role::Server, State::Idle) => Some(Box::new(write_response)),
310 (Role::Server, State::SendResponse) => Some(Box::new(write_response)),
311 _ => None,
312 },
313 };
314 }
315 if self.get_their_state() != old_states[&self.their_role] {
316 self._reader = match self._cstate.states[&self.their_role] {
317 State::SendBody => {
318 let request_method = self._request_method.clone().unwrap_or(vec![]);
319 let (framing_type, length) = _body_framing(
320 &request_method,
321 RequestOrResponse::from(event.clone().unwrap()),
322 );
323 match framing_type {
324 "content-length" => {
325 Some(Box::new(ContentLengthReader::new(length as usize)))
326 }
327 "chunked" => Some(Box::new(ChunkedReader::new())),
328 "http/1.0" => Some(Box::new(Http10Reader {})),
329 _ => {
330 panic!("Invalid role and framing type combination");
331 }
332 }
333 }
334 _ => match (&self.their_role, self._cstate.states[&self.their_role]) {
335 (Role::Client, State::Idle) => Some(Box::new(IdleClientReader {})),
336 (Role::Server, State::Idle) => Some(Box::new(SendResponseServerReader {})),
337 (Role::Server, State::SendResponse) => {
338 Some(Box::new(SendResponseServerReader {}))
339 }
340 (Role::Client, State::Done) => Some(Box::new(ClosedReader {})),
341 (Role::Client, State::MustClose) => Some(Box::new(ClosedReader {})),
342 (Role::Client, State::Closed) => Some(Box::new(ClosedReader {})),
343 (Role::Server, State::Done) => Some(Box::new(ClosedReader {})),
344 (Role::Server, State::MustClose) => Some(Box::new(ClosedReader {})),
345 (Role::Server, State::Closed) => Some(Box::new(ClosedReader {})),
346 _ => None,
347 },
348 };
349 }
350 }
351
352 pub fn get_trailing_data(&self) -> (Vec<u8>, bool) {
353 (
354 self._receive_buffer.bytes().to_vec(),
355 self._receive_buffer_closed,
356 )
357 }
358
359 pub fn receive_data(&mut self, data: &[u8]) -> Result<(), String> {
360 Ok(if data.len() > 0 {
361 if self._receive_buffer_closed {
362 return Err("received close, then received more data?".to_string());
363 }
364 self._receive_buffer.add(data);
365 } else {
366 self._receive_buffer_closed = true;
367 })
368 }
369
370 fn _extract_next_receive_event(&mut self) -> Result<Event, ProtocolError> {
371 let state = self.get_their_state();
372 if state == State::Done && self._receive_buffer.len() > 0 {
373 return Ok(Event::Paused());
374 }
375 if state == State::MightSwitchProtocol || state == State::SwitchedProtocol {
376 return Ok(Event::Paused());
377 }
378 let event = self
379 ._reader
380 .as_mut()
381 .unwrap()
382 .call(&mut self._receive_buffer)?;
383 if event.is_none() {
384 if self._receive_buffer.len() == 0 && self._receive_buffer_closed {
385 return self._reader.as_mut().unwrap().read_eof();
386 }
387 }
388 Ok(event.unwrap_or(Event::NeedData()))
389 }
390
391 pub fn next_event(&mut self) -> Result<Event, ProtocolError> {
392 if self.get_their_state() == State::Error {
393 return Err(ProtocolError::RemoteProtocolError(
394 "Can't receive data when peer state is ERROR".into(),
395 ));
396 }
397 match (|| {
398 let event = self._extract_next_receive_event()?;
399 match event {
400 Event::NeedData() | Event::Paused() => {}
401 _ => {
402 self._process_event(self.their_role, event.clone())?;
403 }
404 };
405
406 if let Event::NeedData() = event.clone() {
407 if self._receive_buffer.len() > self._max_incomplete_event_size {
408 return Err(ProtocolError::RemoteProtocolError(
409 ("Receive buffer too long".to_string(), 431).into(),
410 ));
411 }
412 if self._receive_buffer_closed {
413 return Err(ProtocolError::RemoteProtocolError(
414 "peer unexpectedly closed connection".to_string().into(),
415 ));
416 }
417 }
418
419 Ok(event)
420 })() {
421 Err(error) => {
422 self._process_error(self.their_role);
423 match error {
424 ProtocolError::LocalProtocolError(error) => {
425 Err(error._reraise_as_remote_protocol_error().into())
426 }
427 _ => Err(error),
428 }
429 }
430 Ok(any) => Ok(any),
431 }
432 }
433
434 pub fn send(&mut self, mut event: Event) -> Result<Option<Vec<u8>>, ProtocolError> {
435 if self.get_our_state() == State::Error {
436 return Err(ProtocolError::LocalProtocolError(
437 "Can't send data when our state is ERROR".to_string().into(),
438 ));
439 }
440 event = if let Event::NormalResponse(response) = &event {
441 Event::NormalResponse(self._clean_up_response_headers_for_sending(response.clone())?)
442 } else {
443 event
444 };
445 let event_type: EventType = (&event).into();
446 let res: Result<Vec<u8>, ProtocolError> = match self._writer.as_mut() {
447 Some(_) if event_type == EventType::ConnectionClosed => Ok(vec![]),
448 Some(writer) => writer(event.clone()),
449 None => Err(ProtocolError::LocalProtocolError(
450 "Can't send data when our state is not SEND_BODY"
451 .to_string()
452 .into(),
453 )),
454 };
455 self._process_event(self.our_role, event.clone())?;
456 if event_type == EventType::ConnectionClosed {
457 return Ok(None);
458 } else {
459 match res {
460 Ok(data_list) => Ok(Some(data_list)),
461 Err(error) => {
462 self._process_error(self.our_role);
463 Err(error)
464 }
465 }
466 }
467 }
468
469 pub fn send_failed(&mut self) {
470 self._process_error(self.our_role);
471 }
472
473 fn _clean_up_response_headers_for_sending(
474 &self,
475 response: Response,
476 ) -> Result<Response, ProtocolError> {
477 let mut headers = response.clone().headers;
478 let mut need_close = false;
479 let mut method_for_choosing_headers = self._request_method.clone().unwrap_or(vec![]);
480 if method_for_choosing_headers == b"HEAD".to_vec() {
481 method_for_choosing_headers = b"GET".to_vec();
482 }
483 let (framing_type, _) = _body_framing(&method_for_choosing_headers, response.clone());
484 if framing_type == "chunked" || framing_type == "http/1.0" {
485 headers = set_comma_header(&headers, b"content-length", vec![])?;
486 if self
487 .their_http_version
488 .clone()
489 .map(|v| v < b"1.1".to_vec())
490 .unwrap_or(true)
491 {
492 headers = set_comma_header(&headers, b"transfer-encoding", vec![])?;
493 if self._request_method.clone().unwrap_or(vec![]) != b"HEAD".to_vec() {
494 need_close = true;
495 }
496 } else {
497 headers =
498 set_comma_header(&headers, b"transfer-encoding", vec![b"chunked".to_vec()])?;
499 }
500 }
501 if !self._cstate.keep_alive || need_close {
502 let mut connection: HashSet<Vec<u8>> = get_comma_header(&headers, b"connection")
503 .into_iter()
504 .collect();
505 connection.retain(|x| x != &b"keep-alive".to_vec());
506 connection.insert(b"close".to_vec());
507 headers = set_comma_header(&headers, b"connection", connection.into_iter().collect())?;
508 }
509 return Ok(Response {
510 headers,
511 status_code: response.status_code,
512 http_version: response.http_version,
513 reason: response.reason,
514 });
515 }
516}
517
518#[cfg(test)]
519mod tests {
520 use super::*;
521
522 #[test]
523 fn test_keep_alive() {
524 assert!(_keep_alive(Request {
525 method: b"GET".to_vec(),
526 target: b"/".to_vec(),
527 headers: vec![(b"Host".to_vec(), b"Example.com".to_vec())].into(),
528 http_version: b"1.1".to_vec(),
529 }));
530 assert!(!_keep_alive(Request {
531 method: b"GET".to_vec(),
532 target: b"/".to_vec(),
533 headers: vec![
534 (b"Host".to_vec(), b"Example.com".to_vec()),
535 (b"Connection".to_vec(), b"close".to_vec()),
536 ]
537 .into(),
538 http_version: b"1.1".to_vec(),
539 }));
540 assert!(!_keep_alive(Request {
541 method: b"GET".to_vec(),
542 target: b"/".to_vec(),
543 headers: vec![
544 (b"Host".to_vec(), b"Example.com".to_vec()),
545 (b"Connection".to_vec(), b"a, b, cLOse, foo".to_vec()),
546 ]
547 .into(),
548 http_version: b"1.1".to_vec(),
549 }));
550 assert!(!_keep_alive(Request {
551 method: b"GET".to_vec(),
552 target: b"/".to_vec(),
553 headers: vec![].into(),
554 http_version: b"1.0".to_vec(),
555 }));
556
557 assert!(_keep_alive(Response {
558 status_code: 200,
559 headers: vec![].into(),
560 http_version: b"1.1".to_vec(),
561 reason: b"OK".to_vec(),
562 }));
563 assert!(!_keep_alive(Response {
564 status_code: 200,
565 headers: vec![(b"Connection".to_vec(), b"close".to_vec())].into(),
566 http_version: b"1.1".to_vec(),
567 reason: b"OK".to_vec(),
568 }));
569 assert!(!_keep_alive(Response {
570 status_code: 200,
571 headers: vec![(b"Connection".to_vec(), b"a, b, cLOse, foo".to_vec()),].into(),
572 http_version: b"1.1".to_vec(),
573 reason: b"OK".to_vec(),
574 }));
575 assert!(!_keep_alive(Response {
576 status_code: 200,
577 headers: vec![].into(),
578 http_version: b"1.0".to_vec(),
579 reason: b"OK".to_vec(),
580 }));
581 }
582
583 #[test]
584 fn test_body_framing() {
585 fn headers(cl: Option<usize>, te: bool) -> Headers {
586 let mut headers = vec![];
587 if let Some(cl) = cl {
588 headers.push((
589 b"Content-Length".to_vec(),
590 cl.to_string().as_bytes().to_vec(),
591 ));
592 }
593 if te {
594 headers.push((b"Transfer-Encoding".to_vec(), b"chunked".to_vec()));
595 }
596 headers.push((b"Host".to_vec(), b"example.com".to_vec()));
597 return headers.into();
598 }
599
600 fn resp(status_code: u16, cl: Option<usize>, te: bool) -> Response {
601 Response {
602 status_code,
603 headers: headers(cl, te),
604 http_version: b"1.1".to_vec(),
605 reason: b"OK".to_vec(),
606 }
607 }
608
609 fn req(cl: Option<usize>, te: bool) -> Request {
610 Request {
611 method: b"GET".to_vec(),
612 target: b"/".to_vec(),
613 headers: headers(cl, te),
614 http_version: b"1.1".to_vec(),
615 }
616 }
617
618 for (cl, te) in vec![(Some(100), false), (None, true), (Some(100), true)] {
620 for (meth, r) in vec![
621 (b"HEAD".to_vec(), resp(200, cl, te)),
622 (b"GET".to_vec(), resp(204, cl, te)),
623 (b"GET".to_vec(), resp(304, cl, te)),
624 ] {
625 assert_eq!(_body_framing(&meth, r), ("content-length", 0));
626 }
627 }
628
629 for (cl, te) in vec![(None, true), (Some(100), true)] {
631 for (meth, r) in vec![
632 (b"".to_vec(), RequestOrResponse::from(req(cl, te))),
633 (b"GET".to_vec(), RequestOrResponse::from(resp(200, cl, te))),
634 ] {
635 assert_eq!(_body_framing(&meth, r), ("chunked", 0));
636 }
637 }
638
639 for (meth, r) in vec![
641 (b"".to_vec(), RequestOrResponse::from(req(Some(100), false))),
642 (
643 b"GET".to_vec(),
644 RequestOrResponse::from(resp(200, Some(100), false)),
645 ),
646 ] {
647 assert_eq!(_body_framing(&meth, r), ("content-length", 100));
648 }
649
650 assert_eq!(_body_framing(b"", req(None, false)), ("content-length", 0));
652 assert_eq!(
653 _body_framing(b"GET", resp(200, None, false)),
654 ("http/1.0", 0)
655 );
656 }
657}