// h11/_connection.rs

1use crate::_events::*;
2use crate::_headers::*;
3use crate::_readers::*;
4use crate::_receivebuffer::*;
5use crate::_state::*;
6use crate::_util::*;
7use crate::_writers::*;
8use std::collections::HashMap;
9use std::collections::HashSet;
10
/// Default cap (16 KiB) on the number of received-but-unparsed bytes that
/// `Connection::next_event` will buffer while waiting for a complete event.
const DEFAULT_MAX_INCOMPLETE_EVENT_SIZE: usize = 16 * 1024;
12
13enum RequestOrResponse {
14    Request(Request),
15    Response(Response),
16}
17
18impl RequestOrResponse {
19    pub fn headers(&self) -> &Headers {
20        match self {
21            Self::Request(request) => &request.headers,
22            Self::Response(response) => &response.headers,
23        }
24    }
25
26    pub fn http_version(&self) -> &Vec<u8> {
27        match self {
28            Self::Request(request) => &request.http_version,
29            Self::Response(response) => &response.http_version,
30        }
31    }
32}
33
34impl From<Request> for RequestOrResponse {
35    fn from(value: Request) -> Self {
36        Self::Request(value)
37    }
38}
39
40impl From<Response> for RequestOrResponse {
41    fn from(value: Response) -> Self {
42        Self::Response(value)
43    }
44}
45
46impl From<Event> for RequestOrResponse {
47    fn from(value: Event) -> Self {
48        match value {
49            Event::Request(request) => Self::Request(request),
50            Event::NormalResponse(response) => Self::Response(response),
51            _ => panic!("Invalid event type"),
52        }
53    }
54}
55
56fn _keep_alive<T: Into<RequestOrResponse>>(event: T) -> bool {
57    let event: RequestOrResponse = event.into();
58    let connection = get_comma_header(event.headers(), b"connection");
59    if connection.contains(&b"close".to_vec()) {
60        return false;
61    }
62    if event.http_version() < &b"1.1".to_vec() {
63        return false;
64    }
65    return true;
66}
67
68fn _body_framing<T: Into<RequestOrResponse>>(request_method: &[u8], event: T) -> (&str, isize) {
69    let event: RequestOrResponse = event.into();
70    if let RequestOrResponse::Response(response) = &event {
71        if response.status_code == 204
72            || response.status_code == 304
73            || request_method == b"HEAD"
74            || (request_method == b"CONNECT"
75                && 200 <= response.status_code
76                && response.status_code < 300)
77        {
78            return ("content-length", 0);
79        }
80        assert!(response.status_code >= 200);
81    }
82
83    let trasfer_encodings = get_comma_header(event.headers(), b"transfer-encoding");
84    if !trasfer_encodings.is_empty() {
85        assert!(trasfer_encodings == vec![b"chunked".to_vec()]);
86        return ("chunked", 0);
87    }
88
89    let content_lengths = get_comma_header(event.headers(), b"content-length");
90    if !content_lengths.is_empty() {
91        return (
92            "content-length",
93            std::str::from_utf8(&content_lengths[0])
94                .unwrap()
95                .parse()
96                .unwrap(),
97        );
98    }
99
100    if let RequestOrResponse::Request(_) = event {
101        return ("content-length", 0);
102    } else {
103        return ("http/1.0", 0);
104    }
105}
106
107pub struct Connection {
108    pub our_role: Role,
109    pub their_role: Role,
110    _cstate: ConnectionState,
111    _writer: Option<Box<WriterFnMut>>,
112    _reader: Option<Box<dyn Reader>>,
113    _max_incomplete_event_size: usize,
114    _receive_buffer: ReceiveBuffer,
115    _receive_buffer_closed: bool,
116    pub their_http_version: Option<Vec<u8>>,
117    _request_method: Option<Vec<u8>>,
118    client_is_waiting_for_100_continue: bool,
119}
120
121impl Connection {
122    pub fn new(our_role: Role, max_incomplete_event_size: Option<usize>) -> Self {
123        Self {
124            our_role,
125            their_role: if our_role == Role::Client {
126                Role::Server
127            } else {
128                Role::Client
129            },
130            _cstate: ConnectionState::new(),
131            _writer: match our_role {
132                Role::Client => Some(Box::new(write_request)),
133                Role::Server => Some(Box::new(write_response)),
134            },
135            _reader: match our_role {
136                Role::Server => Some(Box::new(IdleClientReader {})),
137                Role::Client => Some(Box::new(SendResponseServerReader {})),
138            },
139            _max_incomplete_event_size: max_incomplete_event_size
140                .unwrap_or(DEFAULT_MAX_INCOMPLETE_EVENT_SIZE),
141            _receive_buffer: ReceiveBuffer::new(),
142            _receive_buffer_closed: false,
143            their_http_version: None,
144            _request_method: None,
145            client_is_waiting_for_100_continue: false,
146        }
147    }
148
149    pub fn get_states(&self) -> HashMap<Role, State> {
150        self._cstate.states.clone()
151    }
152
153    pub fn get_our_state(&self) -> State {
154        self._cstate.states[&self.our_role]
155    }
156
157    pub fn get_their_state(&self) -> State {
158        self._cstate.states[&self.their_role]
159    }
160
161    pub fn get_client_is_waiting_for_100_continue(&self) -> bool {
162        self.client_is_waiting_for_100_continue
163    }
164
165    pub fn get_they_are_waiting_for_100_continue(&self) -> bool {
166        self.their_role == Role::Client && self.client_is_waiting_for_100_continue
167    }
168
169    pub fn start_next_cycle(&mut self) -> Result<(), ProtocolError> {
170        let old_states = self._cstate.states.clone();
171        self._cstate.start_next_cycle()?;
172        self._request_method = None;
173        self.their_http_version = None;
174        self.client_is_waiting_for_100_continue = false;
175        self._respond_to_state_changes(old_states, None);
176        Ok(())
177    }
178
179    fn _process_error(&mut self, role: Role) {
180        let old_states = self._cstate.states.clone();
181        self._cstate.process_error(role);
182        self._respond_to_state_changes(old_states, None);
183    }
184
185    fn _server_switch_event(&self, event: Event) -> Option<Switch> {
186        if let Event::InformationalResponse(informational_response) = &event {
187            if informational_response.status_code == 101 {
188                return Some(Switch::SwitchUpgrade);
189            }
190        }
191        if let Event::NormalResponse(response) = &event {
192            if self
193                ._cstate
194                .pending_switch_proposals
195                .contains(&Switch::SwitchConnect)
196                && 200 <= response.status_code
197                && response.status_code < 300
198            {
199                return Some(Switch::SwitchConnect);
200            }
201        }
202        return None;
203    }
204
205    fn _process_event(&mut self, role: Role, event: Event) -> Result<(), ProtocolError> {
206        let old_states = self._cstate.states.clone();
207        if role == Role::Client {
208            if let Event::Request(request) = event.clone() {
209                if request.method == b"CONNECT" {
210                    self._cstate
211                        .process_client_switch_proposal(Switch::SwitchConnect);
212                }
213                if get_comma_header(&request.headers, b"upgrade").len() > 0 {
214                    self._cstate
215                        .process_client_switch_proposal(Switch::SwitchUpgrade);
216                }
217            }
218        }
219        let server_switch_event = if role == Role::Server {
220            self._server_switch_event(event.clone())
221        } else {
222            None
223        };
224        self._cstate
225            .process_event(role, (&event).into(), server_switch_event)?;
226
227        if let Event::Request(request) = event.clone() {
228            self._request_method = Some(request.method);
229        }
230
231        if role == self.their_role {
232            if let Event::Request(request) = event.clone() {
233                self.their_http_version = Some(request.http_version);
234            }
235            if let Event::NormalResponse(response) = event.clone() {
236                self.their_http_version = Some(response.http_version);
237            }
238            if let Event::InformationalResponse(informational_response) = event.clone() {
239                self.their_http_version = Some(informational_response.http_version);
240            }
241        }
242
243        if let Event::Request(request) = event.clone() {
244            if !_keep_alive(RequestOrResponse::from(request)) {
245                self._cstate.process_keep_alive_disabled();
246            }
247        }
248        if let Event::NormalResponse(response) = event.clone() {
249            if !_keep_alive(RequestOrResponse::from(response)) {
250                self._cstate.process_keep_alive_disabled();
251            }
252        }
253
254        if let Event::Request(request) = event.clone() {
255            if has_expect_100_continue(&request) {
256                self.client_is_waiting_for_100_continue = true;
257            }
258        }
259        match (&event).into() {
260            EventType::InformationalResponse => {
261                self.client_is_waiting_for_100_continue = false;
262            }
263            EventType::NormalResponse => {
264                self.client_is_waiting_for_100_continue = false;
265            }
266            EventType::Data => {
267                if role == Role::Client {
268                    self.client_is_waiting_for_100_continue = false;
269                }
270            }
271            EventType::EndOfMessage => {
272                if role == Role::Client {
273                    self.client_is_waiting_for_100_continue = false;
274                }
275            }
276            _ => {}
277        }
278
279        self._respond_to_state_changes(old_states, Some(event));
280        Ok(())
281    }
282
283    fn _respond_to_state_changes(
284        &mut self,
285        old_states: HashMap<Role, State>,
286        event: Option<Event>,
287    ) {
288        if self.get_our_state() != old_states[&self.our_role] {
289            let state = self._cstate.states[&self.our_role];
290            self._writer = match state {
291                State::SendBody => {
292                    let request_method = self._request_method.clone().unwrap_or(vec![]);
293                    let (framing_type, length) = _body_framing(
294                        &request_method,
295                        RequestOrResponse::from(event.clone().unwrap()),
296                    );
297
298                    match framing_type {
299                        "content-length" => Some(Box::new(content_length_writer(length))),
300                        "chunked" => Some(Box::new(chunked_writer())),
301                        "http/1.0" => Some(Box::new(http10_writer())),
302                        _ => {
303                            panic!("Invalid role and framing type combination");
304                        }
305                    }
306                }
307                _ => match (&self.our_role, state) {
308                    (Role::Client, State::Idle) => Some(Box::new(write_request)),
309                    (Role::Server, State::Idle) => Some(Box::new(write_response)),
310                    (Role::Server, State::SendResponse) => Some(Box::new(write_response)),
311                    _ => None,
312                },
313            };
314        }
315        if self.get_their_state() != old_states[&self.their_role] {
316            self._reader = match self._cstate.states[&self.their_role] {
317                State::SendBody => {
318                    let request_method = self._request_method.clone().unwrap_or(vec![]);
319                    let (framing_type, length) = _body_framing(
320                        &request_method,
321                        RequestOrResponse::from(event.clone().unwrap()),
322                    );
323                    match framing_type {
324                        "content-length" => {
325                            Some(Box::new(ContentLengthReader::new(length as usize)))
326                        }
327                        "chunked" => Some(Box::new(ChunkedReader::new())),
328                        "http/1.0" => Some(Box::new(Http10Reader {})),
329                        _ => {
330                            panic!("Invalid role and framing type combination");
331                        }
332                    }
333                }
334                _ => match (&self.their_role, self._cstate.states[&self.their_role]) {
335                    (Role::Client, State::Idle) => Some(Box::new(IdleClientReader {})),
336                    (Role::Server, State::Idle) => Some(Box::new(SendResponseServerReader {})),
337                    (Role::Server, State::SendResponse) => {
338                        Some(Box::new(SendResponseServerReader {}))
339                    }
340                    (Role::Client, State::Done) => Some(Box::new(ClosedReader {})),
341                    (Role::Client, State::MustClose) => Some(Box::new(ClosedReader {})),
342                    (Role::Client, State::Closed) => Some(Box::new(ClosedReader {})),
343                    (Role::Server, State::Done) => Some(Box::new(ClosedReader {})),
344                    (Role::Server, State::MustClose) => Some(Box::new(ClosedReader {})),
345                    (Role::Server, State::Closed) => Some(Box::new(ClosedReader {})),
346                    _ => None,
347                },
348            };
349        }
350    }
351
352    pub fn get_trailing_data(&self) -> (Vec<u8>, bool) {
353        (
354            self._receive_buffer.bytes().to_vec(),
355            self._receive_buffer_closed,
356        )
357    }
358
359    pub fn receive_data(&mut self, data: &[u8]) -> Result<(), String> {
360        Ok(if data.len() > 0 {
361            if self._receive_buffer_closed {
362                return Err("received close, then received more data?".to_string());
363            }
364            self._receive_buffer.add(data);
365        } else {
366            self._receive_buffer_closed = true;
367        })
368    }
369
370    fn _extract_next_receive_event(&mut self) -> Result<Event, ProtocolError> {
371        let state = self.get_their_state();
372        if state == State::Done && self._receive_buffer.len() > 0 {
373            return Ok(Event::Paused());
374        }
375        if state == State::MightSwitchProtocol || state == State::SwitchedProtocol {
376            return Ok(Event::Paused());
377        }
378        let event = self
379            ._reader
380            .as_mut()
381            .unwrap()
382            .call(&mut self._receive_buffer)?;
383        if event.is_none() {
384            if self._receive_buffer.len() == 0 && self._receive_buffer_closed {
385                return self._reader.as_mut().unwrap().read_eof();
386            }
387        }
388        Ok(event.unwrap_or(Event::NeedData()))
389    }
390
391    pub fn next_event(&mut self) -> Result<Event, ProtocolError> {
392        if self.get_their_state() == State::Error {
393            return Err(ProtocolError::RemoteProtocolError(
394                "Can't receive data when peer state is ERROR".into(),
395            ));
396        }
397        match (|| {
398            let event = self._extract_next_receive_event()?;
399            match event {
400                Event::NeedData() | Event::Paused() => {}
401                _ => {
402                    self._process_event(self.their_role, event.clone())?;
403                }
404            };
405
406            if let Event::NeedData() = event.clone() {
407                if self._receive_buffer.len() > self._max_incomplete_event_size {
408                    return Err(ProtocolError::RemoteProtocolError(
409                        ("Receive buffer too long".to_string(), 431).into(),
410                    ));
411                }
412                if self._receive_buffer_closed {
413                    return Err(ProtocolError::RemoteProtocolError(
414                        "peer unexpectedly closed connection".to_string().into(),
415                    ));
416                }
417            }
418
419            Ok(event)
420        })() {
421            Err(error) => {
422                self._process_error(self.their_role);
423                match error {
424                    ProtocolError::LocalProtocolError(error) => {
425                        Err(error._reraise_as_remote_protocol_error().into())
426                    }
427                    _ => Err(error),
428                }
429            }
430            Ok(any) => Ok(any),
431        }
432    }
433
434    pub fn send(&mut self, mut event: Event) -> Result<Option<Vec<u8>>, ProtocolError> {
435        if self.get_our_state() == State::Error {
436            return Err(ProtocolError::LocalProtocolError(
437                "Can't send data when our state is ERROR".to_string().into(),
438            ));
439        }
440        event = if let Event::NormalResponse(response) = &event {
441            Event::NormalResponse(self._clean_up_response_headers_for_sending(response.clone())?)
442        } else {
443            event
444        };
445        let event_type: EventType = (&event).into();
446        let res: Result<Vec<u8>, ProtocolError> = match self._writer.as_mut() {
447            Some(_) if event_type == EventType::ConnectionClosed => Ok(vec![]),
448            Some(writer) => writer(event.clone()),
449            None => Err(ProtocolError::LocalProtocolError(
450                "Can't send data when our state is not SEND_BODY"
451                    .to_string()
452                    .into(),
453            )),
454        };
455        self._process_event(self.our_role, event.clone())?;
456        if event_type == EventType::ConnectionClosed {
457            return Ok(None);
458        } else {
459            match res {
460                Ok(data_list) => Ok(Some(data_list)),
461                Err(error) => {
462                    self._process_error(self.our_role);
463                    Err(error)
464                }
465            }
466        }
467    }
468
469    pub fn send_failed(&mut self) {
470        self._process_error(self.our_role);
471    }
472
473    fn _clean_up_response_headers_for_sending(
474        &self,
475        response: Response,
476    ) -> Result<Response, ProtocolError> {
477        let mut headers = response.clone().headers;
478        let mut need_close = false;
479        let mut method_for_choosing_headers = self._request_method.clone().unwrap_or(vec![]);
480        if method_for_choosing_headers == b"HEAD".to_vec() {
481            method_for_choosing_headers = b"GET".to_vec();
482        }
483        let (framing_type, _) = _body_framing(&method_for_choosing_headers, response.clone());
484        if framing_type == "chunked" || framing_type == "http/1.0" {
485            headers = set_comma_header(&headers, b"content-length", vec![])?;
486            if self
487                .their_http_version
488                .clone()
489                .map(|v| v < b"1.1".to_vec())
490                .unwrap_or(true)
491            {
492                headers = set_comma_header(&headers, b"transfer-encoding", vec![])?;
493                if self._request_method.clone().unwrap_or(vec![]) != b"HEAD".to_vec() {
494                    need_close = true;
495                }
496            } else {
497                headers =
498                    set_comma_header(&headers, b"transfer-encoding", vec![b"chunked".to_vec()])?;
499            }
500        }
501        if !self._cstate.keep_alive || need_close {
502            let mut connection: HashSet<Vec<u8>> = get_comma_header(&headers, b"connection")
503                .into_iter()
504                .collect();
505            connection.retain(|x| x != &b"keep-alive".to_vec());
506            connection.insert(b"close".to_vec());
507            headers = set_comma_header(&headers, b"connection", connection.into_iter().collect())?;
508        }
509        return Ok(Response {
510            headers,
511            status_code: response.status_code,
512            http_version: response.http_version,
513            reason: response.reason,
514        });
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GET /` request carrying `headers` at the given HTTP version.
    fn get_request(headers: Vec<(Vec<u8>, Vec<u8>)>, http_version: &[u8]) -> Request {
        Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: headers.into(),
            http_version: http_version.to_vec(),
        }
    }

    /// Builds a `200 OK` response carrying `headers` at the given HTTP version.
    fn ok_response(headers: Vec<(Vec<u8>, Vec<u8>)>, http_version: &[u8]) -> Response {
        Response {
            status_code: 200,
            headers: headers.into(),
            http_version: http_version.to_vec(),
            reason: b"OK".to_vec(),
        }
    }

    fn header(name: &[u8], value: &[u8]) -> (Vec<u8>, Vec<u8>) {
        (name.to_vec(), value.to_vec())
    }

    #[test]
    fn test_keep_alive() {
        let host = || header(b"Host", b"Example.com");

        assert!(_keep_alive(get_request(vec![host()], b"1.1")));
        assert!(!_keep_alive(get_request(
            vec![host(), header(b"Connection", b"close")],
            b"1.1"
        )));
        assert!(!_keep_alive(get_request(
            vec![host(), header(b"Connection", b"a, b, cLOse, foo")],
            b"1.1"
        )));
        assert!(!_keep_alive(get_request(vec![], b"1.0")));

        assert!(_keep_alive(ok_response(vec![], b"1.1")));
        assert!(!_keep_alive(ok_response(
            vec![header(b"Connection", b"close")],
            b"1.1"
        )));
        assert!(!_keep_alive(ok_response(
            vec![header(b"Connection", b"a, b, cLOse, foo")],
            b"1.1"
        )));
        assert!(!_keep_alive(ok_response(vec![], b"1.0")));
        // An HTTP/1.0 message is not reusable even if it asks for keep-alive.
        assert!(!_keep_alive(ok_response(
            vec![header(b"Connection", b"keep-alive")],
            b"1.0"
        )));
    }

    #[test]
    fn test_body_framing() {
        fn headers(cl: Option<usize>, te: bool) -> Headers {
            let mut headers = vec![];
            if let Some(cl) = cl {
                headers.push((b"Content-Length".to_vec(), cl.to_string().into_bytes()));
            }
            if te {
                headers.push((b"Transfer-Encoding".to_vec(), b"chunked".to_vec()));
            }
            headers.push((b"Host".to_vec(), b"example.com".to_vec()));
            headers.into()
        }

        fn resp(status_code: u16, cl: Option<usize>, te: bool) -> Response {
            Response {
                status_code,
                headers: headers(cl, te),
                http_version: b"1.1".to_vec(),
                reason: b"OK".to_vec(),
            }
        }

        fn req(cl: Option<usize>, te: bool) -> Request {
            Request {
                method: b"GET".to_vec(),
                target: b"/".to_vec(),
                headers: headers(cl, te),
                http_version: b"1.1".to_vec(),
            }
        }

        // Special cases where the headers are ignored (including when there are
        // no framing headers at all):
        for (cl, te) in vec![(None, false), (Some(100), false), (None, true), (Some(100), true)] {
            for (meth, r) in vec![
                (b"HEAD".to_vec(), resp(200, cl, te)),
                (b"GET".to_vec(), resp(204, cl, te)),
                (b"GET".to_vec(), resp(304, cl, te)),
                (b"CONNECT".to_vec(), resp(200, cl, te)),
            ] {
                assert_eq!(_body_framing(&meth, r), ("content-length", 0));
            }
        }

        // Transfer-encoding (wins over Content-Length)
        for (cl, te) in vec![(None, true), (Some(100), true)] {
            for (meth, r) in vec![
                (b"".to_vec(), RequestOrResponse::from(req(cl, te))),
                (b"GET".to_vec(), RequestOrResponse::from(resp(200, cl, te))),
            ] {
                assert_eq!(_body_framing(&meth, r), ("chunked", 0));
            }
        }

        // Content-Length
        for (meth, r) in vec![
            (b"".to_vec(), RequestOrResponse::from(req(Some(100), false))),
            (
                b"GET".to_vec(),
                RequestOrResponse::from(resp(200, Some(100), false)),
            ),
        ] {
            assert_eq!(_body_framing(&meth, r), ("content-length", 100));
        }

        // No headers
        assert_eq!(_body_framing(b"", req(None, false)), ("content-length", 0));
        assert_eq!(
            _body_framing(b"GET", resp(200, None, false)),
            ("http/1.0", 0)
        );
    }
}
655        );
656    }
657}