milstian_http/
request.rs

1//! # Handles everything related to HTTP requests.
2
3use std::collections::HashMap;
4use std::str;
5
6#[derive(Debug)]
7pub enum BodyContentType {
8    SinglePart(HashMap<String, String>),
9    MultiPart(HashMap<String, MultiPartValue>),
10}
11
12#[derive(Debug)]
13pub struct Message {
14    pub body: BodyContentType,
15    pub headers: HashMap<String, HeaderValueParts>,
16    pub request_line: Line,
17}
18
19#[derive(Debug)]
20pub struct Line {
21    pub method: Method,
22    pub protocol: Protocol,
23    pub raw: String,
24    pub request_uri: String,
25    pub request_uri_base: String,
26    pub query_arguments: HashMap<String, String>,
27    pub query_string: String,
28}
29
/// HTTP request methods recognised by the request-line parser.
///
/// `Invalid` stands for any method token that is not one of the other variants.
/// The enum is fieldless, so it is `Copy` and can be used as a hash key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Method {
    Connect,
    Delete,
    Get,
    Head,
    Invalid,
    Options,
    Patch,
    Post,
    Put,
    Trace,
}
43
/// Whether a `Content-Type` describes a multipart body.
#[derive(Debug)]
pub enum HeaderContentType {
    /// Multipart content; the `String` holds the multipart boundary string.
    MultiPart(String),
    /// Any non-multipart content.
    SinglePart,
}
49
/// One element of a header value: a bare token or a `key=value` parameter.
///
/// Derives `Clone`/`PartialEq`/`Eq` so parsed header values can be copied and compared.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HeaderValuePart {
    /// A token without `=`, e.g. `form-data` or `text/html`.
    Single(String),
    /// A parameter split at its first `=` into key and value; quotes are kept as received.
    KeyValue(String, String),
}
55
56#[derive(Debug)]
57pub struct HeaderValueParts {
58    pub parts: Vec<Vec<HeaderValuePart>>,
59}
60
61impl HeaderValueParts {
62    pub fn get_key_value(&self, key: &str) -> Option<String> {
63        for params_block in self.parts.iter() {
64            for params_subblock in params_block.iter() {
65                if let HeaderValuePart::KeyValue(key_value_key, key_value_value) = params_subblock {
66                    if key_value_key == key {
67                        return Some(key_value_value.to_string());
68                    }
69                }
70            }
71        }
72        None
73    }
74
75    pub fn to_string(&self) -> String {
76        let mut output = String::new();
77        let mut params_block_count = 0;
78        for params_block in self.parts.iter() {
79            if params_block_count > 0 {
80                output.push_str("; ");
81            }
82            let mut params_subblock_count = 0;
83            for params_subblock in params_block.iter() {
84                if params_subblock_count > 0 {
85                    output.push_str(", ");
86                }
87                match params_subblock {
88                    HeaderValuePart::Single(string) => {
89                        output.push_str(&string);
90                    }
91                    HeaderValuePart::KeyValue(key, value) => {
92                        output.push_str(&format!("{}={}", key, value).to_string());
93                    }
94                }
95                params_subblock_count = params_subblock_count + 1;
96            }
97            params_block_count = params_block_count + 1;
98        }
99        output
100    }
101}
102
103#[derive(Debug)]
104pub struct MultiPartValue {
105    pub body: Vec<u8>,
106    pub headers: HashMap<String, HeaderValueParts>,
107}
108
/// HTTP protocol versions recognised in the request line.
///
/// `Invalid` stands for any unrecognised protocol token. The enum is fieldless, so it is
/// `Copy` and can be used as a hash key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
    Invalid,
    V1_0,
    V1_1,
    V2_0,
    V0_9,
}
117
/// Which part of the request the line parser expects next.
enum ParserSection {
    /// The request line (`METHOD URI PROTOCOL`).
    Line,
    /// Header fields, up to the blank line that ends them.
    HeaderFields,
    /// The message body.
    MessageBody,
}
123
/// States of the multipart boundary scanner in `Message::from_tcp_stream`.
enum MultiPartSection {
    /// Inside part data, waiting for a `\r\n` that may precede a delimiter.
    End,
    /// Just after a `\r\n` in part data; consumes one byte (or one more `\r\n`) before
    /// delimiter matching resumes.
    EndSecondary,
    /// Matching a delimiter that would close the current part.
    EndBoundary,
    /// Ignoring bytes until the next `\r\n`.
    Skipping,
    /// Matching a delimiter at the start of a line.
    Start,
    /// Delimiter matched; waiting for the `\r\n` that ends its line.
    StartSuffix,
}
132
/// How the raw request bytes are being consumed.
enum ParserMode {
    /// Multipart body: scan for the given delimiter bytes at the start of lines.
    Boundaries(Vec<u8>),
    /// Split the input on `\r\n` and parse line by line.
    Lines,
}
137
/// Three-valued answer to "does this apply?", e.g. whether a request method carries a body.
#[derive(PartialEq, Eq, Debug)]
enum SettingValence {
    /// May or may not apply.
    Optional,
    /// Never applies.
    No,
    /// Always applies.
    Yes,
}
144
145impl Message {
146    fn method_has_request_body(method: &Method) -> SettingValence {
147        match method {
148            Method::Connect => SettingValence::Yes,
149            Method::Delete => SettingValence::No,
150            Method::Get => SettingValence::Optional,
151            Method::Head => SettingValence::No,
152            Method::Options => SettingValence::Optional,
153            Method::Patch => SettingValence::Yes,
154            Method::Post => SettingValence::Yes,
155            Method::Put => SettingValence::Yes,
156            Method::Trace => SettingValence::Yes,
157            Method::Invalid => SettingValence::Optional,
158        }
159    }
160
161    fn _method_has_response_body(method: &Method) -> bool {
162        match method {
163            Method::Connect => true,
164            Method::Delete => true,
165            Method::Get => true,
166            Method::Head => false,
167            Method::Options => true,
168            Method::Patch => true,
169            Method::Post => true,
170            Method::Put => true,
171            Method::Trace => true,
172            Method::Invalid => true,
173        }
174    }
175
176    fn _method_is_safe(method: &Method) -> bool {
177        match method {
178            Method::Connect => false,
179            Method::Delete => false,
180            Method::Get => true,
181            Method::Head => true,
182            Method::Options => true,
183            Method::Patch => false,
184            Method::Post => false,
185            Method::Put => false,
186            Method::Trace => true,
187            Method::Invalid => true,
188        }
189    }
190
191    fn _method_is_idempotent(method: &Method) -> bool {
192        match method {
193            Method::Connect => false,
194            Method::Delete => true,
195            Method::Get => true,
196            Method::Head => true,
197            Method::Options => true,
198            Method::Patch => false,
199            Method::Post => false,
200            Method::Put => true,
201            Method::Trace => true,
202            Method::Invalid => true,
203        }
204    }
205
206    fn _method_is_cacheable(method: &Method) -> bool {
207        match method {
208            Method::Connect => false,
209            Method::Delete => false,
210            Method::Get => true,
211            Method::Head => true,
212            Method::Options => false,
213            Method::Patch => false,
214            Method::Post => true,
215            Method::Put => false,
216            Method::Trace => false,
217            Method::Invalid => false,
218        }
219    }
220
221    fn get_query_args_from_multipart_blob(data: &[u8]) -> Option<(String, MultiPartValue)> {
222        let mut headers: HashMap<String, HeaderValueParts> = HashMap::new();
223        let mut last_was_carriage_return = false;
224        let mut index = 0;
225        let mut start = 0;
226        for byte in data.iter() {
227            if byte == &10 && last_was_carriage_return {
228                last_was_carriage_return = false;
229                if let Ok(utf8_line) = str::from_utf8(&data[start..index]) {
230                    if utf8_line.trim().is_empty() {
231                        start = index + 1;
232                        break;
233                    } else {
234                        if let Some((header_key, header_value)) =
235                            Message::get_header_field(utf8_line)
236                        {
237                            headers.insert(header_key, header_value);
238                        }
239                    }
240                    start = index + 1;
241                }
242            } else if byte == &13 {
243                last_was_carriage_return = true;
244            } else {
245                last_was_carriage_return = false;
246            }
247            index = index + 1;
248        }
249
250        // Did we find a name within the content-disposition header?
251        let mut name = String::new();
252        if let Some(content_disposition) = headers.get("Content-Disposition") {
253            if let Some(content_disposition_name) = content_disposition.get_key_value("name") {
254                name = content_disposition_name.trim_matches('"').to_string();
255            }
256        }
257        if !name.is_empty() {
258            let body = data[start..].to_vec();
259            if !body.is_empty() {
260                return Some((name, MultiPartValue { body, headers }));
261            }
262        }
263        None
264    }
265
266    fn get_query_args_from_string(subject: &str) -> Option<HashMap<String, String>> {
267        let mut args: HashMap<String, String> = HashMap::new();
268        if !subject.is_empty() {
269            let subject_arguments: Vec<&str> = subject.split("&").collect();
270            for item in subject_arguments {
271                let query_arg: Vec<&str> = item.split("=").collect();
272                if query_arg.len() == 2 {
273                    args.insert(query_arg.get(0)?.to_string(), query_arg.get(1)?.to_string());
274                } else {
275                    args.insert(query_arg.get(0)?.to_string(), String::from("1"));
276                }
277            }
278        }
279        if args.len() > 0 {
280            return Some(args);
281        }
282        None
283    }
284
285    pub fn get_protocol_text(protocol: &Protocol) -> String {
286        match protocol {
287            Protocol::V0_9 => String::from("HTTP/0.9"),
288            Protocol::V1_0 => String::from("HTTP/1.0"),
289            Protocol::V1_1 => String::from("HTTP/1.1"),
290            Protocol::V2_0 => String::from("HTTP/2.0"),
291            Protocol::Invalid => String::from("INVALID"),
292        }
293    }
294
295    pub fn get_message_body(body: &str) -> Option<BodyContentType> {
296        if let Some(body) = Message::get_query_args_from_string(body) {
297            return Some(BodyContentType::SinglePart(body));
298        }
299        None
300    }
301
302    // TODO Maybe change keys to Camel-Case to improve parsing
303    pub fn get_header_field(line: &str) -> Option<(String, HeaderValueParts)> {
304        let line = line.trim();
305        if !line.is_empty() {
306            let parts: Vec<&str> = line.splitn(2, ":").collect();
307            if parts.len() == 2 {
308                let header_key = parts.get(0)?.trim().to_string();
309                let header_value = parts.get(1)?.trim().to_string();
310                let mut header_parts: Vec<Vec<HeaderValuePart>> = Vec::new();
311
312                let params_blocks: Vec<&str> = header_value.split(";").collect();
313                for params_block in params_blocks.iter() {
314                    let mut header_value_part: Vec<HeaderValuePart> = Vec::new();
315                    let params_subblocks: Vec<&str> = params_block.split(",").collect();
316                    for params_subblock in params_subblocks.iter() {
317                        let params_subblock_clone = params_subblock.clone();
318                        let params_key_pair: Vec<&str> =
319                            params_subblock_clone.splitn(2, "=").collect();
320                        if params_key_pair.len() == 2 {
321                            let param_key = params_key_pair.get(0)?.trim().to_string();
322                            let param_value = params_key_pair.get(1)?.trim().to_string();
323                            header_value_part
324                                .push(HeaderValuePart::KeyValue(param_key, param_value));
325                        } else {
326                            header_value_part
327                                .push(HeaderValuePart::Single(params_subblock.trim().to_string()));
328                        }
329                    }
330                    header_parts.push(header_value_part);
331                }
332
333                return Some((
334                    header_key,
335                    HeaderValueParts {
336                        parts: header_parts,
337                    },
338                ));
339            }
340        }
341        None
342    }
343
344    pub fn get_request_line(line: &str) -> Option<Line> {
345        let line = line.trim();
346        let parts: Vec<&str> = line.split(" ").collect();
347        if parts.len() == 3 {
348            let method = match parts.get(0)?.as_ref() {
349                "CONNECT" => Method::Connect,
350                "DELETE" => Method::Delete,
351                "GET" => Method::Get,
352                "HEAD" => Method::Head,
353                "OPTIONS" => Method::Options,
354                "PATCH" => Method::Patch,
355                "PUT" => Method::Put,
356                "POST" => Method::Post,
357                "TRACE" => Method::Trace,
358                __ => Method::Invalid,
359            };
360
361            let request_uri = parts.get(1)?.to_string();
362            let request_uri_copy = request_uri.clone();
363            let mut request_uri_base = request_uri.clone();
364            let mut query_string = String::new();
365            let mut query_arguments: HashMap<String, String> = HashMap::new();
366            let uri_parts: Vec<&str> = request_uri_copy.splitn(2, "?").collect();
367            if uri_parts.len() == 2 {
368                request_uri_base = uri_parts.get(0)?.to_string();
369                query_string = uri_parts.get(1)?.to_string();
370                if let Some(query_args) = Message::get_query_args_from_string(&query_string) {
371                    query_arguments = query_args;
372                }
373            }
374
375            let protocol = match parts.get(2)?.as_ref() {
376                "HTTP/0.9" => Protocol::V0_9,
377                "HTTP/1.0" => Protocol::V1_0,
378                "HTTP/1.1" => Protocol::V1_1,
379                "HTTP/2.0" => Protocol::V2_0,
380                _ => Protocol::Invalid,
381            };
382
383            if method != Method::Invalid && protocol != Protocol::Invalid {
384                return Some(Line {
385                    method,
386                    protocol,
387                    raw: line.to_string(),
388                    request_uri,
389                    request_uri_base,
390                    query_arguments,
391                    query_string,
392                });
393            }
394        } else if parts.len() == 1 {
395            // Support for a request line containing only the path name is accepted by servers to
396            // maintain compatibility with  clients before the HTTP/1.0 specification.
397            let method = Method::Get;
398            let request_uri = parts.get(0)?.trim_matches(char::from(0)).to_string();
399            if !request_uri.is_empty() {
400                let protocol = Protocol::V0_9;
401
402                let request_uri_copy = request_uri.clone();
403                let mut request_uri_base = request_uri.clone();
404                let mut query_string = String::new();
405                let mut query_arguments: HashMap<String, String> = HashMap::new();
406
407                let uri_parts: Vec<&str> = request_uri_copy.splitn(2, "?").collect();
408                if uri_parts.len() == 2 {
409                    request_uri_base = uri_parts.get(0)?.to_string();
410                    query_string = uri_parts.get(1)?.to_string();
411                    if let Some(query_args) = Message::get_query_args_from_string(&query_string) {
412                        query_arguments = query_args;
413                    }
414                }
415
416                return Some(Line {
417                    method,
418                    protocol,
419                    raw: line.to_string(),
420                    request_uri,
421                    request_uri_base,
422                    query_arguments,
423                    query_string,
424                });
425            }
426        }
427        None
428    }
429
430    /// Try to decode a byte stream into a HTTP Message
431    /// ## Usage
432    /// ```rust
433    /// use milstian_http::request::{Message, Method, Protocol};
434    /// let response = Message::from_tcp_stream(b"GET / HTTP/2.0\r\n");
435    /// let response_unwrapped = response.expect("A decoded HTTP Message");
436    /// assert_eq!(response_unwrapped.request_line.method, Method::Get);
437    /// assert_eq!(response_unwrapped.request_line.request_uri, "/".to_string());
438    /// assert_eq!(response_unwrapped.request_line.protocol, Protocol::V2_0);
439    /// ```
440    pub fn from_tcp_stream(request: &[u8]) -> Option<Message> {
441        // Temporary message
442        let mut message = Message {
443            body: BodyContentType::SinglePart(HashMap::new()),
444            headers: HashMap::new(),
445            request_line: Line {
446                method: Method::Invalid,
447                protocol: Protocol::Invalid,
448                raw: String::new(),
449                request_uri: String::new(),
450                request_uri_base: String::new(),
451                query_arguments: HashMap::new(),
452                query_string: String::new(),
453            },
454        };
455
456        // Parsing variables
457        let mut start = 0;
458        let mut start_boundary = 0;
459        let mut start_data = 0;
460        let mut section = ParserSection::Line;
461        let mut end = 0;
462        let mut end_data = 0;
463        let last_index = match request.len() {
464            0 => 0,
465            _ => request.len() - 1,
466        };
467        let mut last_was_carriage_return = false;
468        let mut parser_mode = ParserMode::Lines;
469        let mut multipart_section = MultiPartSection::Start;
470
471        for byte in request.iter() {
472            match parser_mode {
473                // Are we parsing boundaries?
474                ParserMode::Boundaries(ref boundary) => {
475                    match multipart_section {
476                        // Stay here until we encounter \n\r
477                        MultiPartSection::Skipping => {
478                            if byte == &13 {
479                                last_was_carriage_return = true;
480                            } else if byte == &10 && last_was_carriage_return {
481                                multipart_section = MultiPartSection::Start;
482                                eprintln!("Going from 'skipping' -> 'start'");
483                                start_boundary = end + 1;
484                                last_was_carriage_return = false;
485                            } else if byte == &0 {
486                                break;
487                            } else {
488                                last_was_carriage_return = false;
489                            }
490                        }
491
492                        // Stay here until we encounter the boundary with optionally appending - characters
493                        MultiPartSection::Start => {
494                            // Does byte match next byte in boundary?
495                            if let Some(boundary_byte) = boundary.get(end - start_boundary) {
496                                if boundary_byte == byte {
497                                    // Was it the last character of boundary?
498                                    if end - start_boundary + 1 == boundary.len() {
499                                        multipart_section = MultiPartSection::StartSuffix;
500                                        eprintln!("Going from 'start' -> 'start suffix'");
501                                    }
502                                } else if byte == &45 && start_boundary < end {
503                                    if let Some(boundary_byte) =
504                                        boundary.get(end - start_boundary - 1)
505                                    {
506                                        if boundary_byte == byte {
507                                            start_boundary = start_boundary + 1;
508                                        } else {
509                                            multipart_section = MultiPartSection::Skipping;
510                                            eprintln!("Going from 'start' -> 'skipping'");
511                                        }
512                                    } else {
513                                        multipart_section = MultiPartSection::Skipping;
514                                        eprintln!("Going from 'start' -> 'skipping'");
515                                    }
516                                } else {
517                                    multipart_section = MultiPartSection::Skipping;
518                                    eprintln!("Going from 'start' -> 'skipping'");
519                                }
520                            } else if byte == &0 {
521                                break;
522                            } else {
523                                multipart_section = MultiPartSection::Skipping;
524                                eprintln!("Going from 'start' -> 'skipping'");
525                            }
526                        }
527
528                        // Stay here until we encounter \r\n after boundary
529                        MultiPartSection::StartSuffix => {
530                            if byte == &13 {
531                                last_was_carriage_return = true;
532                            } else if byte == &10 && last_was_carriage_return {
533                                multipart_section = MultiPartSection::End;
534                                eprintln!("Going from 'start suffix' -> 'end'");
535                                last_was_carriage_return = false;
536                                start_data = end;
537                            } else if byte == &0 {
538                                break;
539                            } else {
540                                last_was_carriage_return = false;
541                                multipart_section = MultiPartSection::Skipping;
542                                eprintln!("Going from 'start suffix' -> 'skipping'");
543                            }
544                        }
545
546                        // Stay here until we encounter \r\n
547                        MultiPartSection::End => {
548                            // Is it a carriage return?
549                            if byte == &13 {
550                                last_was_carriage_return = true;
551
552                            // Is it a new-line?
553                            } else if byte == &10 && last_was_carriage_return {
554                                multipart_section = MultiPartSection::EndSecondary;
555                                last_was_carriage_return = false;
556                                end_data = end - 1;
557                                start_boundary = end + 1;
558                                eprintln!("Going from 'end' -> 'end secondary'");
559                            } else if byte == &0 {
560                                break;
561                            }
562                        }
563
564                        // Stay here until we encounter \r\n
565                        MultiPartSection::EndSecondary => {
566                            // Is it a carriage return?
567                            if byte == &13 {
568                                last_was_carriage_return = true;
569
570                            // Is it a new-line?
571                            } else if byte == &10 && last_was_carriage_return {
572                                multipart_section = MultiPartSection::EndBoundary;
573                                last_was_carriage_return = false;
574                                eprintln!("Going from 'end secondary' -> 'end boundary'");
575                            } else if byte == &0 {
576                                break;
577                            } else {
578                                multipart_section = MultiPartSection::EndBoundary;
579                                last_was_carriage_return = false;
580                            }
581                        }
582
583                        // Stay here until we can't find boundary or find the full boundary
584                        MultiPartSection::EndBoundary => {
585                            // Does byte match next byte in boundary?
586                            if let Some(boundary_byte) = boundary.get(end - start_boundary) {
587                                if boundary_byte == byte {
588                                    eprintln!(
589                                        "Byte matched boundary byte {}",
590                                        *boundary_byte as char
591                                    );
592                                    // Was it the last character of boundary?
593                                    if end - start_boundary + 1 == boundary.len() {
594                                        multipart_section = MultiPartSection::StartSuffix;
595                                        eprintln!("Going from 'end boundary' -> 'start suffix'");
596
597                                        if start_data > 0
598                                            && start_data < end_data
599                                            && end_data < request.len()
600                                        {
601                                            let data = &request[start_data..end_data];
602                                            eprintln!(
603                                                "Trying to get query arg from {:?}",
604                                                str::from_utf8(&data)
605                                            );
606                                            if let Some((query_key, query_value)) =
607                                                Message::get_query_args_from_multipart_blob(&data)
608                                            {
609                                                if let BodyContentType::MultiPart(ref mut values) =
610                                                    message.body
611                                                {
612                                                    values.insert(query_key, query_value);
613                                                }
614                                            }
615                                        }
616                                    }
617
618                                // Was the character a '-' and does the start of boundary occur before the current position?
619                                } else if byte == &45 && start_boundary < end {
620                                    if let Some(boundary_byte) =
621                                        boundary.get(end - start_boundary - 1)
622                                    {
623                                        if boundary_byte == byte {
624                                            start_boundary = start_boundary + 1;
625                                            eprintln!(
626                                                "Character matches boundary byte '{}'",
627                                                *byte as char
628                                            );
629                                        } else {
630                                            multipart_section = MultiPartSection::End;
631                                            eprintln!("Going from 'end boundary' -> 'end'. Byte didnt match boundary {} vs {}", *boundary_byte as char, *byte as char);
632                                        }
633                                    } else {
634                                        multipart_section = MultiPartSection::End;
635                                        eprintln!("Going from 'end boundary' -> 'end'. Failed to find boundary byte");
636                                    }
637                                } else {
638                                    multipart_section = MultiPartSection::End;
639                                    eprintln!("Going from 'end boundary' -> 'end'. Not matching character was not a '-' but {:?}", *byte as char);
640                                    if byte == &13 {
641                                        last_was_carriage_return = true;
642                                    }
643                                }
644                            } else if byte == &0 {
645                                break;
646                            } else {
647                                multipart_section = MultiPartSection::End;
648                                eprintln!("Going from 'end boundary' -> 'end'");
649                            }
650                        }
651                    }
652                }
653
654                // Are we parsing lines?
655                ParserMode::Lines => {
656                    if byte == &13 {
657                        last_was_carriage_return = true;
658
659                    // Did we find a \r\n sequence?
660                    } else if byte == &10 && last_was_carriage_return {
661                        let clean_end = end - 1;
662                        if let Ok(utf8_line) = str::from_utf8(&request[start..clean_end]) {
663                            Message::parse_line(
664                                &utf8_line,
665                                &mut section,
666                                &mut message,
667                                &mut parser_mode,
668                            );
669                            start = end + 1;
670                            start_boundary = end + 1;
671                        }
672                        last_was_carriage_return = false;
673
674                    // When we get null bytes we are done or if we reach last index
675                    } else if byte == &0 || end == last_index {
676                        let clean_end = match byte {
677                            &0 => end,
678                            _ => end + 1,
679                        };
680                        if let Ok(utf8_line) = str::from_utf8(&request[start..clean_end]) {
681                            Message::parse_line(
682                                &utf8_line,
683                                &mut section,
684                                &mut message,
685                                &mut parser_mode,
686                            );
687                        }
688                        break;
689                    } else {
690                        last_was_carriage_return = false;
691                    }
692                }
693            }
694
695            // Increment byte position
696            end = end + 1;
697        }
698
699        // Did we find a valid method and protocol?
700        if message.request_line.method != Method::Invalid
701            && message.request_line.protocol != Protocol::Invalid
702        {
703            return Some(message);
704        }
705
706        None
707    }
708
709    fn parse_line(
710        line: &str,
711        section: &mut ParserSection,
712        message: &mut Message,
713        parser_mode: &mut ParserMode,
714    ) {
715        match section {
716            ParserSection::Line => {
717                if let Some(request_line_temp) = Message::get_request_line(line) {
718                    message.request_line = request_line_temp;
719                    *section = ParserSection::HeaderFields;
720                }
721            }
722            ParserSection::HeaderFields => {
723                // Is it the last line of the headers?
724                if line.trim().is_empty() {
725                    // Check if we have a multi-part body
726                    if let Some(content_type_header) = message.headers.get("Content-Type") {
727                        if let Some(boundary) = content_type_header.get_key_value("boundary") {
728                            *parser_mode = ParserMode::Boundaries(boundary.as_bytes().to_vec());
729                            message.body = BodyContentType::MultiPart(HashMap::new());
730                            eprintln!("Found boundary start: '{}'", &boundary);
731                        }
732                    }
733
734                    if Message::method_has_request_body(&message.request_line.method)
735                        != SettingValence::No
736                    {
737                        *section = ParserSection::MessageBody;
738                    }
739                } else {
740                    if let Some((header_key, header_value)) = Message::get_header_field(line) {
741                        message.headers.insert(header_key, header_value);
742                    }
743                }
744            }
745            ParserSection::MessageBody => {
746                if !line.is_empty() {
747                    if let Some(body_args) = Message::get_message_body(line) {
748                        message.body = body_args;
749                    }
750                }
751            }
752        }
753    }
754}
755
// Unit tests for the request parser: single-part bodies, multipart blobs, header
// fields, request lines and whole-message parsing via `from_tcp_stream`.
#[cfg(test)]
mod tests {
    use super::*;

    /// `get_message_body` splits `key=value` pairs on `&`, maps a bare key to "1",
    /// and returns `None` for empty input.
    #[test]
    fn test_get_message_body_single_part() {
        let response = Message::get_message_body("random=abc&hej=def&def");
        assert!(response.is_some());

        let response_unwrapped = response.unwrap();
        // NOTE(review): there is no `else` branch, so these asserts are skipped
        // silently if the body is not `SinglePart`.
        if let BodyContentType::SinglePart(response_unwrapped) = response_unwrapped {
            assert_eq!(
                response_unwrapped
                    .get(&"random".to_string())
                    .unwrap()
                    .to_string(),
                "abc".to_string()
            );
            assert_eq!(
                response_unwrapped
                    .get(&"hej".to_string())
                    .unwrap()
                    .to_string(),
                "def".to_string()
            );
            // A key without `=value` gets the value "1".
            assert_eq!(
                response_unwrapped
                    .get(&"def".to_string())
                    .unwrap()
                    .to_string(),
                "1".to_string()
            );
            assert!(response_unwrapped.get(&"defs".to_string()).is_none());
        }

        let response = Message::get_message_body("");
        assert!(response.is_none());
    }

    /// `get_query_args_from_multipart_blob` turns one multipart part (headers, blank
    /// line, body) into its form-field name plus headers and raw body bytes; input
    /// that is not a part yields `None`.
    #[test]
    fn test_get_query_args_from_multipart_blob() {
        let response = Message::get_query_args_from_multipart_blob(
            b"Content-Disposition: form-data; name=\"losen\"\r\n\r\nabc\n123",
        );
        assert!(response.is_some());
        if let Some((query_key, query_value)) = response {
            assert_eq!(query_key, "losen".to_string());
            assert_eq!(
                query_value
                    .headers
                    .get("Content-Disposition")
                    .unwrap()
                    .to_string(),
                "form-data; name=\"losen\""
            );
            // The body is kept byte-for-byte, including the bare `\n`.
            assert_eq!(query_value.body, b"abc\n123");
        } else {
            panic!("Expected multipart body but received: {:?}", response);
        }

        // A part with two headers and a multi-line (PGP signature) body.
        let response = Message::get_query_args_from_multipart_blob(
            b"Content-Disposition: form-data; name=\"file\"; filename=\"KeePassXC-2.3.1.dmg.sig\"\r\nContent-Type: application/octet-stream\r\n\r\n
-----BEGIN PGP SIGNATURE-----

iQEzBAABCAAdFiEEweTLo61406/YlPngt6ZvA7WQdqgFAlqfE5MACgkQt6ZvA7WQ
dqgnEAgAjtdbsMPaULGXKX6H+fcsYeGEN8OjiUTNz+StwNDkDxhxB4MT0N0lYZ4L
xUv86kwMdWAaxp8pvVWo6gWXTEM5gWmN302bBxkpbhBl9fnq6WdcCCDGs4GM5vHX
lOrHXWTsK+8ayLNZ0dCcP054srAtMmJHscPiuUYPfvKSgLxl+JxkPC147EktCCzv
5O+2AtQPwIEPuaMewFqP9KjaGOhWgAc0nauIKa0ASt9FXXrexq1EoZnoZ3ZQ0p/w
/otAB2D27yQ4kv+X2Rn94Ky9W0lMT2MYEF+/tQH4aEKsdMBQ7REQtfLGFlEzTMB/
BNUI5YCF3PV9MKr3N53vEVYvkbXLbw==
=LO1E
-----END PGP SIGNATURE-----
");

        assert!(response.is_some());
        if let Some((query_key, query_value)) = response {
            assert_eq!(query_key, "file".to_string());
            assert_eq!(
                query_value
                    .headers
                    .get("Content-Disposition")
                    .unwrap()
                    .to_string(),
                "form-data; name=\"file\"; filename=\"KeePassXC-2.3.1.dmg.sig\""
            );
        } else {
            panic!("Expected multipart body but received: {:?}", response);
        }

        // Arbitrary text with no Content-Disposition header is rejected.
        let response = Message::get_query_args_from_multipart_blob(
            b"okasdokadsokasd oa skoasdk\r\nokadsokasdokoadskods\r\n123123",
        );
        assert!(response.is_none());
    }

    /// `get_header_field` splits a header line into name and value. Colons inside
    /// the value are preserved, and trailing whitespace and CRLF are trimmed from it.
    /// Lines without a name/value separator and empty lines yield `None`; key=value
    /// parameters are reachable through `get_key_value`.
    #[test]
    fn test_get_header_field() {
        let response = Message::get_header_field(
            "User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:12.0) Gecko/20100101 Firefox/12.0\r\n",
        );
        assert!(response.is_some());

        let (key, value) = response.unwrap();
        assert_eq!(key, "User-Agent".to_string());
        assert_eq!(
            value.to_string(),
            "Mozilla/5.0 (X11; Linux x86_64; rv:12.0) Gecko/20100101 Firefox/12.0".to_string()
        );

        let response = Message::get_header_field("Cache-Control: no-cache \r\n");
        assert!(response.is_some());

        let (key, value) = response.unwrap();
        assert_eq!(key, "Cache-Control".to_string());
        assert_eq!(value.to_string(), "no-cache".to_string());

        let response = Message::get_header_field("Just various text here\r\n");
        assert!(response.is_none());

        let response = Message::get_header_field("");
        assert!(response.is_none());

        // The `boundary` parameter of a multipart Content-Type is exposed via get_key_value.
        let response = Message::get_header_field(
            "Content-Type: multipart/form-data; boundary=---------------------------208201381313076108731815782760\r\n",
        );
        assert!(response.is_some());
        let (key, value) = response.unwrap();
        assert_eq!(key, "Content-Type".to_string());
        assert_eq!(value.to_string(), "multipart/form-data; boundary=---------------------------208201381313076108731815782760".to_string());
        assert_eq!(
            value.get_key_value("boundary").unwrap(),
            "---------------------------208201381313076108731815782760".to_string()
        );
    }

    /// `get_request_line` parses the method, the request URI, and the URI base and
    /// query string split at `?`. It also parses the query arguments (a bare key maps
    /// to "1") and the protocol version; an unknown version (`HTTP/2.2`) yields `None`.
    #[test]
    fn test_get_request_line() {
        let response = Message::get_request_line("POST /random?abc=test HTTP/0.9\r\n");
        assert!(response.is_some());

        let response_unpacked = response.unwrap();
        assert_eq!(response_unpacked.method, Method::Post);
        assert_eq!(
            response_unpacked.request_uri,
            String::from("/random?abc=test")
        );
        assert_eq!(response_unpacked.request_uri_base, String::from("/random"));
        assert_eq!(response_unpacked.query_string, String::from("abc=test"));
        assert_eq!(
            response_unpacked
                .query_arguments
                .get(&"abc".to_string())
                .unwrap()
                .to_string(),
            String::from("test")
        );
        assert_eq!(response_unpacked.protocol, Protocol::V0_9);

        // No query string: base equals the full URI and the query string is empty.
        let response = Message::get_request_line("GET / HTTP/1.0\r\n");
        assert!(response.is_some());

        let response_unpacked = response.unwrap();
        assert_eq!(response_unpacked.method, Method::Get);
        assert_eq!(response_unpacked.request_uri, String::from("/"));
        assert_eq!(response_unpacked.request_uri_base, String::from("/"));
        assert_eq!(response_unpacked.query_string, String::from(""));
        assert_eq!(response_unpacked.protocol, Protocol::V1_0);

        let response = Message::get_request_line("HEAD /moradish.html?test&abc=def HTTP/1.1\r\n");
        assert!(response.is_some());

        let response_unpacked = response.unwrap();
        assert_eq!(response_unpacked.method, Method::Head);
        assert_eq!(
            response_unpacked.request_uri,
            String::from("/moradish.html?test&abc=def")
        );
        assert_eq!(
            response_unpacked.request_uri_base,
            String::from("/moradish.html")
        );
        assert_eq!(response_unpacked.query_string, String::from("test&abc=def"));
        assert_eq!(
            response_unpacked
                .query_arguments
                .get(&"test".to_string())
                .unwrap()
                .to_string(),
            String::from("1")
        );
        assert_eq!(
            response_unpacked
                .query_arguments
                .get(&"abc".to_string())
                .unwrap()
                .to_string(),
            String::from("def")
        );
        assert_eq!(response_unpacked.protocol, Protocol::V1_1);

        let response = Message::get_request_line("OPTIONS /random/random2.txt HTTP/2.0\r\n");
        assert!(response.is_some());

        let response_unpacked = response.unwrap();
        assert_eq!(response_unpacked.method, Method::Options);
        assert_eq!(
            response_unpacked.request_uri,
            String::from("/random/random2.txt")
        );
        assert_eq!(response_unpacked.protocol, Protocol::V2_0);

        let response = Message::get_request_line("GET / HTTP/2.2\r\n");
        assert!(response.is_none());
    }

    /// End-to-end parsing of raw request bytes with `from_tcp_stream`. It covers
    /// request-line-only requests, trailing NUL bytes, headers plus single-part
    /// bodies, invalid and empty input, and multipart bodies with one and two parts.
    /// It also checks body handling for GET and HEAD, HTTP/0.9-style bare URIs, and
    /// all-zero input.
    #[test]
    fn test_from_tcp_stream() {
        // GET request with no headers or body
        let response = Message::from_tcp_stream(b"GET / HTTP/2.0\r\n");
        assert!(response.is_some());
        let response_unwrapped = response.expect("GET HTTP2");
        assert_eq!(response_unwrapped.request_line.method, Method::Get);
        assert_eq!(response_unwrapped.request_line.request_uri, "/".to_string());
        assert_eq!(response_unwrapped.request_line.protocol, Protocol::V2_0);

        // POST request with random header and null bytes
        let mut request: Vec<u8> =
            b"POST /random HTTP/1.0\r\nAgent: Random browser\r\n\r\ntest=abc".to_vec();
        request.push(0);
        request.push(0);
        let response = Message::from_tcp_stream(&request);
        assert!(response.is_some());
        assert_eq!(
            "/random".to_string(),
            response.expect("/random").request_line.request_uri
        );

        // POST request with random header
        let response =
            Message::from_tcp_stream(b"POST / HTTP/1.0\r\nAgent: Random browser\r\n\r\ntest=abc");
        assert!(response.is_some());
        let response_unwrapped = response.expect("POST HTTP1");
        assert_eq!(response_unwrapped.request_line.method, Method::Post);
        assert_eq!(response_unwrapped.request_line.protocol, Protocol::V1_0);
        assert_eq!(
            response_unwrapped
                .headers
                .get(&"Agent".to_string())
                .expect("Agent")
                .to_string(),
            "Random browser".to_string()
        );
        // NOTE(review): without an `else` branch this check is skipped silently if
        // the body is not `SinglePart`.
        if let BodyContentType::SinglePart(body) = response_unwrapped.body {
            assert_eq!(
                body.get(&"test".to_string()).expect("test-abc").to_string(),
                "abc".to_string()
            );
        }

        // Two invalid requests
        let response = Message::from_tcp_stream(b"RANDOM /stuff HTTP/2.5\r\n");
        assert!(response.is_none());
        let response = Message::from_tcp_stream(b"");
        assert!(response.is_none());

        // Multi-part with one form-data
        let response = Message::from_tcp_stream(b"POST /?test=abcdef HTTP/1.1\r\nHost: localhost:8888\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:63.0) Gecko/20100101 Firefox/63.0\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.5\r\nAccept-Encoding: gzip, deflate\r\nReferer: http://localhost:8888/?test=abcdef\r\nContent-Type: multipart/form-data; boundary=-----------------------------3204198641555151219403070096\r\nContent-Length: 733\r\nDNT: 1\r\nConnection: keep-alive\r\nUpgrade-Insecure-Requests: 1\r\nPragma: no-cache\r\nCache-Control: no-cache\r\n\r\n-----------------------------3204198641555151219403070096\r\nContent-Disposition: form-data; name=\"file\"; filename=\"KeePassXC-2.3.3.dmg.DIGEST\"\r\nContent-Type: application/octet-stream\r\n\r\n1219dd686aee2549ef8fe688aeef22e85272a8ccbefdbbb64c0e5601db17fbdb  KeePassXC-2.3.3.dmg\r\n\r\n-----------------------------3204198641555151219403070096\r\n");
        assert!(response.is_some());
        let response_unwrapped = response.expect("multipart");
        if let BodyContentType::MultiPart(body) = response_unwrapped.body {
            eprintln!("body: {:?}", body);
            // The part body is expected without the trailing CRLF pair before the next boundary.
            assert_eq!(
                String::from_utf8(body.get(&"file".to_string()).expect("expecting file data 1").body.clone()).expect("expecting utf-8 file data"),
                "1219dd686aee2549ef8fe688aeef22e85272a8ccbefdbbb64c0e5601db17fbdb  KeePassXC-2.3.3.dmg".to_string()
            );
        } else {
            // Print the parsed boundary to help diagnose why multipart parsing did not kick in.
            eprintln!(
                "Boundary header: {:?}",
                response_unwrapped
                    .headers
                    .get("Content-Type")
                    .expect("A content-type header")
                    .get_key_value("boundary")
                    .expect("A boundary")
            );
            panic!(
                "Expected multipart content but got: {:?}",
                response_unwrapped
            );
        }

        // Multi-part data with two data
        let response = Message::from_tcp_stream(b"POST /?test=abcdef HTTP/1.1\r\nHost: localhost:8888\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:63.0) Gecko/20100101 Firefox/63.0\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.5\r\nAccept-Encoding: gzip, deflate\r\nReferer: http://localhost:8888/?test=abcdef\r\nContent-Type: multipart/form-data; boundary=-----------------------------3204198641555151219403070096\r\nContent-Length: 733\r\nDNT: 1\r\nConnection: keep-alive\r\nUpgrade-Insecure-Requests: 1\r\nPragma: no-cache\r\nCache-Control: no-cache\r\n\r\n-----------------------------3204198641555151219403070096\r\nContent-Disposition: form-data; name=\"file\"; filename=\"KeePassXC-2.3.3.dmg.DIGEST\"\r\nContent-Type: application/octet-stream\r\n\r\n1219dd686aee2549ef8fe688aeef22e85272a8ccbefdbbb64c0e5601db17fbdb  KeePassXC-2.3.3.dmg\r\n\r\n-----------------------------3204198641555151219403070096\r\nContent-Disposition: form-data; name=\"file2\"; filename=\"KeePassXC-2.3.3.dmg.sig\"\r\nContent-Type: application/octet-stream\r\n\r\n-----BEGIN PGP SIGNATURE-----\n\niQEzBAABCAAdFiEEweTLo61406/YlPngt6ZvA7WQdqgFAlrzMl4ACgkQt6ZvA7WQ\ndqhkrQf9G3r5thluX7Ogx9BCnot2L17nH7DFcwcWe2k1gHyC7ttkbdYSXQXaCDGN\nYmedemyvdE7d/TZxbbPuo09LYvj/+5WAUx8KBJHsE6xMK7kwbZJ5i3BBO2NY7p2b\no68XU+Emg6VuynjoW9xDTQO/2PUSSzJeU9Jql7RXPY2RpJp0+BbGkC356vavZk9a\n8oX8/abn1iZgzfY1lyC4aBNHFf7ycalEbOgGAfw/iT5qtDIihLf4QwFqCKO0/stn\nB118cEtpnKmAQuQMoAqKXlPg8f3xxVf2plJZkRMaynX39ykf3gAeRDnkCoQWx0GN\nFr5IBrP1bBbAWAKn2C4TqKb9QyMwJw==\n=icrk\n-----END PGP SIGNATURE-----\r\n\r\n-----------------------------3204198641555151219403070096--\r\n");
        let response_unwrapped = response.expect("multipart");
        if let BodyContentType::MultiPart(body) = response_unwrapped.body {
            assert_eq!(
                String::from_utf8(body.get(&"file".to_string()).expect("expecting file data").body.clone()).expect("expecting utf-8 file data"),
                "1219dd686aee2549ef8fe688aeef22e85272a8ccbefdbbb64c0e5601db17fbdb  KeePassXC-2.3.3.dmg".to_string()
            );
            // The second part ends at the closing `--` delimiter.
            assert_eq!(
                String::from_utf8(body.get(&"file2".to_string()).expect("expecting file data").body.clone()).expect("expecting utf-8 file data"),
                "-----BEGIN PGP SIGNATURE-----\n\niQEzBAABCAAdFiEEweTLo61406/YlPngt6ZvA7WQdqgFAlrzMl4ACgkQt6ZvA7WQ\ndqhkrQf9G3r5thluX7Ogx9BCnot2L17nH7DFcwcWe2k1gHyC7ttkbdYSXQXaCDGN\nYmedemyvdE7d/TZxbbPuo09LYvj/+5WAUx8KBJHsE6xMK7kwbZJ5i3BBO2NY7p2b\no68XU+Emg6VuynjoW9xDTQO/2PUSSzJeU9Jql7RXPY2RpJp0+BbGkC356vavZk9a\n8oX8/abn1iZgzfY1lyC4aBNHFf7ycalEbOgGAfw/iT5qtDIihLf4QwFqCKO0/stn\nB118cEtpnKmAQuQMoAqKXlPg8f3xxVf2plJZkRMaynX39ykf3gAeRDnkCoQWx0GN\nFr5IBrP1bBbAWAKn2C4TqKb9QyMwJw==\n=icrk\n-----END PGP SIGNATURE-----".to_string()
            );
        } else {
            eprintln!(
                "Boundary header: {:?}",
                response_unwrapped
                    .headers
                    .get("Content-Type")
                    .expect("A content-type header")
                    .get_key_value("boundary")
                    .expect("A boundary")
            );
            panic!(
                "Expected multipart content but got: {:?}",
                response_unwrapped
            );
        }

        // Get requests should get their message body parsed
        let response = Message::from_tcp_stream(b"GET / HTTP/2.0\r\n\r\nabc=123");
        assert!(response.is_some());
        let response_unwrapped = response.unwrap();
        // NOTE(review): skipped silently if the body is not `SinglePart`.
        if let BodyContentType::SinglePart(body) = response_unwrapped.body {
            assert_eq!(
                body.get(&"abc".to_string()).unwrap().to_string(),
                "123".to_string()
            );
        }

        // HEAD requests should not get their message body parsed
        let response = Message::from_tcp_stream(b"HEAD / HTTP/2.0\r\n\r\nabc=123");
        assert!(response.is_some());
        let response_unwrapped = response.unwrap();
        // NOTE(review): skipped silently if the body is not `SinglePart`.
        if let BodyContentType::SinglePart(body) = response_unwrapped.body {
            assert!(body.get(&"abc".to_string()).is_none());
        }

        // A bare URI with no method or protocol is accepted as an HTTP/0.9 GET.
        let response = Message::from_tcp_stream(b"html/index.html\r\n");
        assert!(response.is_some());
        let response_unwrapped = response.unwrap();
        assert_eq!(response_unwrapped.request_line.method, Method::Get);
        assert_eq!(
            response_unwrapped.request_line.request_uri,
            "html/index.html".to_string()
        );
        assert_eq!(response_unwrapped.request_line.protocol, Protocol::V0_9);

        // A buffer of only NUL bytes is not a request.
        let response = Message::from_tcp_stream(&[0; 100]);
        assert!(response.is_none());
    }

}