gnostr_web/headers/
mod.rs

1use crate::parser::parse_url_encoded;
2use regex::Regex;
3use std::collections::HashMap;
4use std::io::Read;
5use std::net::TcpStream;
6
/// Parsed request headers: maps a header name to every value received under that name.
///
/// When built by `extract_headers`, names and values are stored trimmed and with the
/// casing the client sent; repeated header lines append to the same `Vec` in arrival order.
pub type Headers = HashMap<String, Vec<String>>;
8
/// Reasons why the request headers could not be read from a connection.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be propagated with
/// `?` into `Box<dyn Error>` and shown in logs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RequestHeaderError {
    /// Occurs if header size is larger than the given limit
    MaxSizeExceed,
    /// Occurs if client is disconnected
    ClientDisconnected,
}

impl std::fmt::Display for RequestHeaderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::MaxSizeExceed => "request headers exceed the maximum allowed size",
            Self::ClientDisconnected => {
                "client disconnected before the request headers were received"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for RequestHeaderError {}
16
17/// It will try to read headers from the tcp stream.
18/// Returns type `RequestHeaderError` if failed to extract headers.
19pub fn extract_headers(
20    stream: &mut TcpStream,
21    start_header: &mut String,
22    partial_body_bytes: &mut Vec<u8>,
23    max_size: usize,
24) -> Result<Headers, RequestHeaderError> {
25    let mut header_bytes = Vec::new();
26
27    let mut read_all_headers = false;
28
29    while !read_all_headers {
30        if header_bytes.len() > max_size {
31            return Err(RequestHeaderError::MaxSizeExceed);
32        }
33
34        let mut buffer = [0u8; 1024];
35        let read_result = stream.read(&mut buffer);
36
37        let read_size;
38
39        match read_result {
40            Ok(bytes_read) => {
41                if bytes_read == 0 {
42                    return Err(RequestHeaderError::ClientDisconnected);
43                }
44                read_size = bytes_read;
45            }
46
47            Err(_) => {
48                return Err(RequestHeaderError::ClientDisconnected);
49            }
50        }
51
52        // There will be index if the header is ended. However, contains_full_header don't take
53        // complete request header.
54        if let Some(header_end_index) = contains_full_headers(&buffer) {
55            header_bytes.extend(&buffer[..header_end_index]);
56
57            // Body starts from header_end_index + "\r\n\r\n"
58            partial_body_bytes.extend(&buffer[header_end_index + 4..read_size]);
59            read_all_headers = true;
60        } else {
61            header_bytes.extend(&buffer[..read_size]);
62        }
63    }
64
65    let raw_request_headers =
66        String::from_utf8(header_bytes).expect("Unsupported header encoding.");
67    let header_lines: Vec<&str> = raw_request_headers.split("\r\n").collect();
68
69    let mut headers: Headers = HashMap::new();
70    for (index, header_line) in header_lines.iter().enumerate() {
71        if index == 0 {
72            *start_header = header_line.to_string();
73        }
74
75        let key_value = parse_header(header_line);
76
77        if let Some((key, value)) = key_value {
78            if headers.contains_key(&key) {
79                let values = headers.get_mut(&key).unwrap();
80                values.push(value);
81            } else {
82                let header_value: Vec<String> = vec![value];
83                headers.insert(key, header_value);
84            }
85        }
86    }
87
88    return Ok(headers);
89}
90
91/// Returns content length from the `Header` if available
92pub fn content_length(headers: &Headers) -> Option<usize> {
93    if let Some(values) = headers.get("Content-Length") {
94        if values.len() > 0 {
95            let value = values.get(0).unwrap();
96            let content_length_value = value.parse::<usize>().expect("Invalid content length");
97            return Some(content_length_value);
98        }
99    }
100
101    return None;
102}
103
104/// Returns the value of `Connection` header if available
105pub fn connection_type(headers: &Headers) -> Option<String> {
106    if let Some(values) = headers.get("Connection") {
107        if values.len() > 0 {
108            let value = values.get(0).unwrap();
109            return Some(value.to_owned());
110        }
111    }
112
113    return None;
114}
115
116/// Returns `Host` value from the Header if available.
117pub fn host(headers: &Headers) -> Option<String> {
118    let host = headers.get("Host");
119    if let Some(host) = host {
120        if host.len() > 0 {
121            let value = host.get(0).unwrap();
122            return Some(value.to_string());
123        }
124    }
125
126    return None;
127}
128
129/// Returns `Content-Type` value from the header if available
130pub fn extract_content_type(headers: &Headers) -> Option<String> {
131    if let Some(values) = headers.get("Content-Type") {
132        let value = values.get(0).expect("Content-Type implementation error");
133        return Some(value.to_owned());
134    }
135
136    return None;
137}
138
/// Locates the end of the header section inside `buffer`.
///
/// Returns the index of the first byte of the first `"\r\n\r\n"` sequence, i.e. the
/// number of bytes that precede the terminator, or `None` if the sequence does not occur.
pub fn contains_full_headers(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    // Walk every start offset and report the first one where the terminator begins.
    (0..buffer.len()).find(|&start| buffer[start..].starts_with(TERMINATOR))
}
146
147/// Returns the request method and raw path from the header line if matched
148/// ```markdown
149/// GET / HTTP/1.1
150/// ```
151pub fn parse_request_method_header(line: &str) -> Option<(String, String)> {
152    let pattern = Regex::new(r"(?<method>.+) (?<path>.+) (.+)").unwrap();
153
154    if let Some(groups) = pattern.captures(line) {
155        let request_method = &groups["method"];
156        let path = &groups["path"];
157        return Some((request_method.to_string(), path.to_string()));
158    }
159
160    return None;
161}
162
/// Splits a single header line into its name and value.
///
/// Input example:
/// ```markdown
/// Content-Length: 10
/// ```
///
/// The line is cut at the first `:`; surrounding whitespace is trimmed from both halves.
/// Any further `:` characters stay in the value. Returns `None` if the line has no `:`.
pub fn parse_header(line: &str) -> Option<(String, String)> {
    line.split_once(':')
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
}
178
179/// Returns map of url encoded key values
180/// Example: `/search?name=John&age=22`
181pub fn query_params_from_raw(raw_path: &String) -> HashMap<String, Vec<String>> {
182    let query_params: HashMap<String, Vec<String>> = HashMap::new();
183    let match_result = raw_path.find("?");
184
185    if !match_result.is_some() {
186        return query_params;
187    }
188
189    let index = match_result.unwrap();
190    if index == raw_path.len() - 1 {
191        return query_params;
192    }
193
194    let slice = &raw_path[index + 1..raw_path.len()];
195    return parse_url_encoded(slice);
196}