mountpoint_s3_crt/http/
request_response.rs

1//! Management of HTTP streams
2
3use std::ffi::{OsStr, OsString};
4use std::mem::MaybeUninit;
5use std::os::unix::prelude::OsStrExt;
6use std::ptr::NonNull;
7
8use mountpoint_s3_crt_sys::*;
9use thiserror::Error;
10
11use crate::common::allocator::Allocator;
12use crate::common::error::Error;
13use crate::http::http_library_init;
14use crate::io::stream::InputStream;
15use crate::{CrtError, ToAwsByteCursor, aws_byte_cursor_as_slice};
16
17/// An HTTP header.
18#[derive(Debug)]
19pub struct Header<N: AsRef<OsStr>, V: AsRef<OsStr>> {
20    inner: aws_http_header,
21    name: N,
22    value: V,
23}
24
25impl<N: AsRef<OsStr>, V: AsRef<OsStr>> Header<N, V> {
26    /// Create a new header.
27    pub fn new(name: N, value: V) -> Self {
28        // SAFETY: this struct will own `name` and `value` so they will live as long as the byte
29        // cursors do.
30        let inner = unsafe {
31            aws_http_header {
32                name: name.as_ref().as_aws_byte_cursor(),
33                value: value.as_ref().as_aws_byte_cursor(),
34                ..Default::default()
35            }
36        };
37
38        Self { name, value, inner }
39    }
40
41    /// Get the name of this header
42    pub fn name(&self) -> &N {
43        &self.name
44    }
45
46    /// Get the value of this header
47    pub fn value(&self) -> &V {
48        &self.value
49    }
50}
51
52/// A block of HTTP headers that provides a nice API for getting/setting header names and values
53#[derive(Debug)]
54pub struct Headers {
55    inner: NonNull<aws_http_headers>,
56}
57
/// Safety: This is okay since we don't implement a shallow Clone for Headers that could otherwise
/// allow threads to simultaneously modify it.
///
/// NOTE(review): the `Clone` impl for [Headers] in this file only acquires another reference to
/// the same underlying CRT structure (via `Headers::from_crt`), i.e. it looks like a shallow
/// clone. That appears to contradict the justification above — confirm that clones moved to
/// different threads cannot mutate the shared CRT headers concurrently.
unsafe impl Send for Headers {}
/// Safety: This is okay since we don't implement a shallow Clone for Headers that could otherwise
/// allow threads to simultaneously modify it.
///
/// NOTE(review): same caveat as for `Send` — the `Clone` impl in this file shares the underlying
/// CRT structure between clones, so this justification should be re-verified.
unsafe impl Sync for Headers {}
64
65/// Errors returned by operations on [Headers].
66///
67/// TODO: Where the variant contains an [OsString] for the header name,
68/// we could explore using a static [OsStr] to avoid unnecessary memory copies
69/// since we know the values at compilation time.
70#[derive(Debug, Error, PartialEq, Eq)]
71pub enum HeadersError {
72    /// The header was not found
73    #[error("Header {0:?} not found")]
74    HeaderNotFound(OsString),
75
76    /// Internal CRT error
77    #[error("CRT error: {0}")]
78    CrtError(#[source] Error),
79
80    /// Header value could not be converted to String
81    #[error("Header {name:?} had invalid string value: {value:?}")]
82    Invalid {
83        /// Name of the header
84        name: OsString,
85        /// Value of the header, which was not valid to convert to [String]
86        value: OsString,
87    },
88}
89
90impl HeadersError {
91    /// Try to convert the CRT [Error] into [HeadersError::HeaderNotFound], or return [HeadersError::CrtError].
92    fn try_convert(err: Error, header_name: &OsStr) -> HeadersError {
93        if err.raw_error() == (aws_http_errors::AWS_ERROR_HTTP_HEADER_NOT_FOUND as i32) {
94            HeadersError::HeaderNotFound(header_name.to_owned())
95        } else {
96            HeadersError::CrtError(err)
97        }
98    }
99}
100
101impl Headers {
102    /// Construct a [Headers] from an existing instance of the underlying CRT structure. The
103    /// returned [Headers] will increment the reference count of the underlying CRT structure, and
104    /// so there are no lifetime issues here.
105    ///
106    /// ## Safety
107    ///
108    /// `ptr` must point to a valid `aws_http_headers` struct.
109    pub(crate) unsafe fn from_crt(ptr: NonNull<aws_http_headers>) -> Self {
110        // SAFETY: `ptr` points to a valid `aws_http_headers`.
111        unsafe { aws_http_headers_acquire(ptr.as_ptr()) };
112        Self { inner: ptr }
113    }
114
115    /// Create a new [Headers] object in the given allocator.
116    pub fn new(allocator: &Allocator) -> Result<Self, HeadersError> {
117        // SAFETY: allocator is a valid aws_allocator, and we check the return is non-null.
118        let inner = unsafe {
119            aws_http_headers_new(allocator.inner.as_ptr())
120                .ok_or_last_error()
121                .map_err(HeadersError::CrtError)?
122        };
123
124        Ok(Self { inner })
125    }
126
127    /// Return how many headers there are.
128    pub fn count(&self) -> usize {
129        // SAFETY: `self.inner` is a valid aws_http_headers, and `aws_http_headers_count` returns a
130        // value of a primitive type so there's no potential lifetime issues.
131        unsafe { aws_http_headers_count(self.inner.as_ptr()) }
132    }
133
134    /// Get the header at the specified index.
135    fn get_index(&self, index: usize) -> Result<Header<OsString, OsString>, HeadersError> {
136        // SAFETY: `self.inner` is a valid aws_http_headers, and `aws_http_headers_get_index`
137        // promises to initialize the output `struct aws_http_header *out_header` on success.
138        let header = unsafe {
139            let mut header: MaybeUninit<aws_http_header> = MaybeUninit::uninit();
140            aws_http_headers_get_index(self.inner.as_ptr(), index, header.as_mut_ptr())
141                .ok_or_last_error()
142                .map_err(HeadersError::CrtError)?;
143            header.assume_init()
144        };
145
146        // SAFETY: `header.name` and `header.value are assumed to be valid byte cursors since they
147        // came from the CRT, and we immediately make copies of them before they could expire.
148        let (name, value) = unsafe {
149            (
150                OsStr::from_bytes(aws_byte_cursor_as_slice(&header.name)).to_owned(),
151                OsStr::from_bytes(aws_byte_cursor_as_slice(&header.value)).to_owned(),
152            )
153        };
154
155        Ok(Header::new(name, value))
156    }
157
158    /// Add a [Header] to these [Headers]. Overrides any existing header with the same name.
159    pub fn add_header(&mut self, header: &Header<impl AsRef<OsStr>, impl AsRef<OsStr>>) -> Result<(), HeadersError> {
160        // CRT's default behavior is to always use the first header set with a given name, and
161        // ignore any later-added headers with that name. But this is non-obvious to users and could
162        // be a source of tricky bugs, so we tweak the semantics so that any existing headers with
163        // the same name are erased. This also makes the behavior match the behavior of
164        // `headers.iter().collect()` into a data structure like a HashMap or RbTree.
165        if self.has_header(header.name()) {
166            self.erase_header(header.name())?;
167        }
168
169        // SAFETY: `aws_http_headers_add_header` makes a copy of the underlying strings.
170        // Also, this function takes a mut reference to `self`, since this function modifies the headers.
171        unsafe {
172            aws_http_headers_add_header(self.inner.as_ptr(), &header.inner)
173                .ok_or_last_error()
174                .map_err(HeadersError::CrtError)?;
175        }
176
177        Ok(())
178    }
179
180    /// Returns whether a header with the given name is present in these [Headers].
181    pub fn has_header(&self, name: impl AsRef<OsStr>) -> bool {
182        // SAFETY: `aws_http_headers_has` doesn't hold on to a copy of the name we pass in, so it's
183        // okay to call with with an `aws_byte_cursor` that may not outlive this `Headers`.
184        unsafe { aws_http_headers_has(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor()) }
185    }
186
187    /// Erases a header with the given name from these [Headers].
188    pub fn erase_header(&self, name: impl AsRef<OsStr>) -> Result<(), HeadersError> {
189        // SAFETY: `aws_http_headers_erase` doesn't hold on to a copy of the name we pass in, so it's
190        // okay to call with with an `aws_byte_cursor` that may not outlive this `Headers`.
191        unsafe {
192            aws_http_headers_erase(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor())
193                .ok_or_last_error()
194                .map_err(|err| HeadersError::try_convert(err, name.as_ref()))?;
195        }
196
197        Ok(())
198    }
199
200    /// Get a single header by name from this block of headers
201    pub fn get<H: AsRef<OsStr>>(&self, name: H) -> Result<Header<OsString, OsString>, HeadersError> {
202        // SAFETY: `self.inner` is a valid aws_http_headers, and `aws_http_headers_get` promises to
203        // initialize the output `struct aws_byte_cursor *out_value` on success.
204        let value = unsafe {
205            let mut value: MaybeUninit<aws_byte_cursor> = MaybeUninit::uninit();
206
207            aws_http_headers_get(
208                self.inner.as_ptr(),
209                name.as_ref().as_aws_byte_cursor(),
210                value.as_mut_ptr(),
211            )
212            .ok_or_last_error()
213            .map_err(|err| HeadersError::try_convert(err, name.as_ref()))?;
214
215            value.assume_init()
216        };
217
218        let name = name.as_ref().to_os_string();
219
220        // SAFETY: `value` is assumed to be a valid byte cursor since it came from the CRT, and we
221        // immediately make a copy of it before the byte cursor can expire.
222        let value = unsafe { OsStr::from_bytes(aws_byte_cursor_as_slice(&value)).to_owned() };
223
224        Ok(Header::new(name, value))
225    }
226
227    /// Get a single header by name as a [String].
228    pub fn get_as_string<H: AsRef<OsStr>>(&self, name: H) -> Result<String, HeadersError> {
229        let name = name.as_ref();
230        let header = self.get(name)?;
231        let value = header.value();
232        if let Some(s) = value.to_str() {
233            Ok(s.to_string())
234        } else {
235            let err = HeadersError::Invalid {
236                name: name.to_owned(),
237                value: value.clone(),
238            };
239            Err(err)
240        }
241    }
242
243    /// Get an optional header by name as a [String].
244    pub fn get_as_optional_string<H: AsRef<OsStr>>(&self, name: H) -> Result<Option<String>, HeadersError> {
245        Ok(if self.has_header(&name) {
246            Some(self.get_as_string(name)?)
247        } else {
248            None
249        })
250    }
251
252    /// Iterate over the headers as (name, value) pairs.
253    pub fn iter(&self) -> impl Iterator<Item = (OsString, OsString)> + '_ {
254        HeadersIterator {
255            headers: self,
256            offset: 0,
257        }
258    }
259}
260
261impl Clone for Headers {
262    fn clone(&self) -> Self {
263        // SAFETY: `self.inner` is a valid `aws_http_headers`, and on Clone it's safe and required to increment
264        // the reference count, dropping new Headers object will decrement it
265        unsafe { Headers::from_crt(self.inner) }
266    }
267}
268
269impl Drop for Headers {
270    fn drop(&mut self) {
271        // SAFETY: `self.inner` is a valid `aws_http_headers`, and on Drop it's safe to decrement
272        // the reference count since we won't use it again through `self.`
273        unsafe {
274            aws_http_headers_release(self.inner.as_ptr());
275        }
276    }
277}
278
279/// [HeadersIterator] iterates through (name, value) pairs in the iterator.
280#[derive(Debug)]
281struct HeadersIterator<'a> {
282    headers: &'a Headers,
283    offset: usize,
284}
285
286impl Iterator for HeadersIterator<'_> {
287    type Item = (OsString, OsString);
288
289    fn next(&mut self) -> Option<Self::Item> {
290        if self.offset < self.headers.count() {
291            let header = self
292                .headers
293                .get_index(self.offset)
294                .expect("headers at any offset smaller than original count should always exist given mut access");
295            self.offset += 1;
296
297            Some((header.name, header.value))
298        } else {
299            None
300        }
301    }
302}
303
304/// A single HTTP message, initialized to be empty (i.e., no headers, no body).
305#[derive(Debug)]
306pub struct Message<'a> {
307    /// The pointer to the inner `aws_http_message`.
308    pub(crate) inner: NonNull<aws_http_message>,
309
310    /// Input stream for the body of the http message, if present.
311    body_input_stream: Option<InputStream<'a>>,
312}
313
314impl<'a> Message<'a> {
315    /// Creates a new HTTP/1.1 request message.
316    pub fn new_request(allocator: &Allocator) -> Result<Self, Error> {
317        // TODO: figure out a better place to call this
318        http_library_init(allocator);
319
320        // SAFETY: `allocator.inner` is a valid `aws_allocator`.
321        let inner = unsafe { aws_http_message_new_request(allocator.inner.as_ptr()).ok_or_last_error()? };
322
323        Ok(Self {
324            inner,
325            body_input_stream: None,
326        })
327    }
328
329    /// Add a header to this message. If the header already exists in the message, this will add a
330    /// another header instead of overwriting the existing one. Use [Self::set_header] to overwrite
331    /// potentially existing headers.
332    pub fn add_header(&mut self, header: &Header<impl AsRef<OsStr>, impl AsRef<OsStr>>) -> Result<(), Error> {
333        // SAFETY: `aws_http_message_add_header` makes a copy of the values in `header`.
334        unsafe { aws_http_message_add_header(self.inner.as_ptr(), header.inner).ok_or_last_error() }
335    }
336
337    /// Set a header in this message. The header is added if necessary and any existing values for
338    /// this name are removed.
339    pub fn set_header(&mut self, header: &Header<impl AsRef<OsStr>, impl AsRef<OsStr>>) -> Result<(), Error> {
340        // SAFETY: `self.inner` is a valid aws_http_message
341        let headers = unsafe { aws_http_message_get_headers(self.inner.as_ptr()) };
342        assert!(!headers.is_null(), "headers are always initialized");
343        // SAFETY: `aws_http_headers_set` makes a copy of the values in `header`
344        unsafe {
345            aws_http_headers_set(
346                headers,
347                header.name().as_aws_byte_cursor(),
348                header.value().as_aws_byte_cursor(),
349            )
350            .ok_or_last_error()
351        }
352    }
353
354    /// Set the request path for this message.
355    pub fn set_request_path(&mut self, path: impl AsRef<OsStr>) -> Result<(), Error> {
356        // SAFETY: `aws_http_message_set_request_path` makes a copy of `path`.
357        unsafe { aws_http_message_set_request_path(self.inner.as_ptr(), path.as_aws_byte_cursor()).ok_or_last_error() }
358    }
359
360    /// Set the request method for this message.
361    pub fn set_request_method(&mut self, method: impl AsRef<OsStr>) -> Result<(), Error> {
362        // SAFETY: `aws_http_message_set_request_method` makes a copy of `method`.
363        unsafe {
364            aws_http_message_set_request_method(self.inner.as_ptr(), method.as_aws_byte_cursor()).ok_or_last_error()
365        }
366    }
367
368    /// get the headers from the message and increases the reference count for the Headers in CRT.
369    pub fn get_headers(&mut self) -> Result<Headers, Error> {
370        // SAFETY: `aws_http_message_get_headers` is safe because self.inner is a valid NonNull `aws_http_message`.
371        let header_ptr = unsafe { aws_http_message_get_headers(self.inner.as_ptr()).ok_or_last_error()? };
372        // SAFETY: `Headers::from_crt` increments the reference count of the Headers object in CRT so there are
373        // no lifetime issues. And `header_ptr` is valid `aws_http_header` pointer.
374        let headers = unsafe { Headers::from_crt(header_ptr) };
375        Ok(headers)
376    }
377
378    /// Sets the body input stream for this message, and returns any previously set input stream.
379    /// If input_stream is None, unsets the body.
380    pub fn set_body_stream(&mut self, input_stream: Option<InputStream<'a>>) -> Option<InputStream<'a>> {
381        let old_input_stream = std::mem::replace(&mut self.body_input_stream, input_stream);
382
383        let new_input_stream_ptr = self
384            .body_input_stream
385            .as_ref()
386            .map(|s| s.inner.as_ptr())
387            .unwrap_or(std::ptr::null_mut());
388
389        // SAFETY: `aws_http_message_set_request_method` does _not_ take ownership of the underlying
390        // input stream. We take ownership of the input stream to make sure it doesn't get dropped
391        // while the CRT has a pointer to it. We also use lifetime parameters to enforce that this
392        // message does not outlive any data borrowed by the input stream.
393        unsafe {
394            aws_http_message_set_body_stream(self.inner.as_ptr(), new_input_stream_ptr);
395        }
396
397        old_input_stream
398    }
399}
400
401impl Drop for Message<'_> {
402    fn drop(&mut self) {
403        // SAFETY: `self.inner` is a valid `aws_http_message`, and on Drop it's safe to decrement
404        // the reference count since we won't use it again through `self.`
405        unsafe {
406            aws_http_message_release(self.inner.as_ptr());
407        }
408    }
409}
410
#[cfg(test)]
mod test {
    use super::*;
    use crate::common::allocator::Allocator;
    use std::collections::HashMap;

    /// Exercise adding, counting, looking up and iterating over [Headers].
    #[test]
    fn test_headers() {
        let mut headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        for (name, value) in [("a", "1"), ("b", "2"), ("c", "3")] {
            headers.add_header(&Header::new(name, value)).unwrap();
        }

        assert_eq!(headers.count(), 3);

        for name in ["a", "b", "c"] {
            assert!(headers.has_header(name));
        }

        let header_a = headers.get("a").unwrap();
        assert_eq!(header_a.name(), "a");
        assert_eq!(header_a.value(), "1");

        assert_eq!(headers.get_as_string("a"), Ok("1".to_string()));
        assert_eq!(headers.get_as_optional_string("a"), Ok(Some("1".to_string())));

        let collected: HashMap<OsString, OsString> = headers.iter().collect();

        assert_eq!(collected.len(), 3);
        assert_eq!(collected.get(OsStr::new("a")), Some(&OsString::from("1")));
    }

    /// Check the errors (or `None`) produced when looking up a header that is absent.
    #[test]
    fn test_header_not_present() {
        let headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        assert!(!headers.has_header("a"), "header should not be present");

        let get_error = headers.get("a").expect_err("should fail because header is not present");
        assert_eq!(
            get_error.to_string(),
            "Header \"a\" not found",
            "header error display should match expected output",
        );
        match get_error {
            HeadersError::HeaderNotFound(name) => {
                assert_eq!(name, "a", "header name should match original argument")
            }
            _ => panic!("should fail with HeaderNotFound"),
        }

        let string_error = headers
            .get_as_string("a")
            .expect_err("should fail because header is not present");
        match string_error {
            HeadersError::HeaderNotFound(name) => {
                assert_eq!(name, "a", "header name should match original argument")
            }
            _ => panic!("should fail with HeaderNotFound"),
        }

        let optional_value = headers
            .get_as_optional_string("a")
            .expect("Should not fail as optional is expected here");
        assert_eq!(optional_value, None, "should return None");
    }

    /// Adding a header whose name already exists must replace the earlier value.
    #[test]
    fn test_headers_overwrite() {
        let mut headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        for value in ["1", "2"] {
            headers.add_header(&Header::new("a", value)).unwrap();
        }

        assert_eq!(headers.count(), 1);

        let header_a = headers.get("a").unwrap();
        assert_eq!(header_a.name(), "a");
        assert_eq!(header_a.value(), "2");

        let collected: HashMap<OsString, OsString> = headers.iter().collect();

        assert_eq!(collected.len(), 1);
        assert_eq!(collected.get(OsStr::new("a")), Some(&OsString::from("2")));
    }

    /// Erasing the only header must leave the collection empty.
    #[test]
    fn test_headers_erase() {
        let mut headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        headers.add_header(&Header::new("a", "1")).unwrap();
        assert_eq!(headers.count(), 1);

        headers.erase_header("a").unwrap();
        assert_eq!(headers.count(), 0);

        let collected: HashMap<OsString, OsString> = headers.iter().collect();
        assert_eq!(collected.len(), 0);
    }
}
513}