1use std::ffi::{OsStr, OsString};
4use std::mem::MaybeUninit;
5use std::os::unix::prelude::OsStrExt;
6use std::ptr::NonNull;
7
8use mountpoint_s3_crt_sys::*;
9use thiserror::Error;
10
11use crate::common::allocator::Allocator;
12use crate::common::error::Error;
13use crate::http::http_library_init;
14use crate::io::stream::InputStream;
15use crate::{CrtError, ToAwsByteCursor, aws_byte_cursor_as_slice};
16
/// A single HTTP header: a name/value pair together with the CRT `aws_http_header` built from it.
///
/// `N` and `V` may be any types that can be viewed as an [`OsStr`], so callers can use borrowed
/// literals as well as owned strings.
#[derive(Debug)]
pub struct Header<N: AsRef<OsStr>, V: AsRef<OsStr>> {
    /// CRT view of this header. `Header::new` fills it with byte cursors taken over `name` and
    /// `value`.
    // NOTE(review): this assumes the bytes returned by `as_ref()` stay at the same address when
    // the `Header` is moved (true for heap-backed or `'static` strings) — confirm for other types.
    inner: aws_http_header,
    /// The header name, kept alongside `inner` for as long as the header lives.
    name: N,
    /// The header value, kept alongside `inner` for as long as the header lives.
    value: V,
}
24
25impl<N: AsRef<OsStr>, V: AsRef<OsStr>> Header<N, V> {
26 pub fn new(name: N, value: V) -> Self {
28 let inner = unsafe {
31 aws_http_header {
32 name: name.as_ref().as_aws_byte_cursor(),
33 value: value.as_ref().as_aws_byte_cursor(),
34 ..Default::default()
35 }
36 };
37
38 Self { name, value, inner }
39 }
40
41 pub fn name(&self) -> &N {
43 &self.name
44 }
45
46 pub fn value(&self) -> &V {
48 &self.value
49 }
50}
51
/// A handle to a CRT `aws_http_headers` collection.
///
/// The handle holds one CRT reference on the collection; that reference is released when the
/// handle is dropped.
#[derive(Debug)]
pub struct Headers {
    /// Non-null pointer to the underlying CRT header collection.
    inner: NonNull<aws_http_headers>,
}

// NOTE(review): these impls assert that the underlying CRT collection may be moved to and shared
// between threads; nothing in this file enforces that — confirm against the CRT's guarantees.
unsafe impl Send for Headers {}
unsafe impl Sync for Headers {}
64
/// Errors produced by [`Headers`] operations.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum HeadersError {
    /// The requested header name is not present in the collection.
    #[error("Header {0:?} not found")]
    HeaderNotFound(OsString),

    /// Any other error reported by the CRT.
    #[error("CRT error: {0}")]
    CrtError(#[source] Error),

    /// The header exists, but its value could not be converted to a Rust `String`.
    #[error("Header {name:?} had invalid string value: {value:?}")]
    Invalid {
        /// Name of the offending header.
        name: OsString,
        /// The raw value that failed the string conversion.
        value: OsString,
    },
}
89
90impl HeadersError {
91 fn try_convert(err: Error, header_name: &OsStr) -> HeadersError {
93 if err.raw_error() == (aws_http_errors::AWS_ERROR_HTTP_HEADER_NOT_FOUND as i32) {
94 HeadersError::HeaderNotFound(header_name.to_owned())
95 } else {
96 HeadersError::CrtError(err)
97 }
98 }
99}
100
101impl Headers {
102 pub(crate) unsafe fn from_crt(ptr: NonNull<aws_http_headers>) -> Self {
110 unsafe { aws_http_headers_acquire(ptr.as_ptr()) };
112 Self { inner: ptr }
113 }
114
115 pub fn new(allocator: &Allocator) -> Result<Self, HeadersError> {
117 let inner = unsafe {
119 aws_http_headers_new(allocator.inner.as_ptr())
120 .ok_or_last_error()
121 .map_err(HeadersError::CrtError)?
122 };
123
124 Ok(Self { inner })
125 }
126
127 pub fn count(&self) -> usize {
129 unsafe { aws_http_headers_count(self.inner.as_ptr()) }
132 }
133
134 fn get_index(&self, index: usize) -> Result<Header<OsString, OsString>, HeadersError> {
136 let header = unsafe {
139 let mut header: MaybeUninit<aws_http_header> = MaybeUninit::uninit();
140 aws_http_headers_get_index(self.inner.as_ptr(), index, header.as_mut_ptr())
141 .ok_or_last_error()
142 .map_err(HeadersError::CrtError)?;
143 header.assume_init()
144 };
145
146 let (name, value) = unsafe {
149 (
150 OsStr::from_bytes(aws_byte_cursor_as_slice(&header.name)).to_owned(),
151 OsStr::from_bytes(aws_byte_cursor_as_slice(&header.value)).to_owned(),
152 )
153 };
154
155 Ok(Header::new(name, value))
156 }
157
158 pub fn add_header(&mut self, header: &Header<impl AsRef<OsStr>, impl AsRef<OsStr>>) -> Result<(), HeadersError> {
160 if self.has_header(header.name()) {
166 self.erase_header(header.name())?;
167 }
168
169 unsafe {
172 aws_http_headers_add_header(self.inner.as_ptr(), &header.inner)
173 .ok_or_last_error()
174 .map_err(HeadersError::CrtError)?;
175 }
176
177 Ok(())
178 }
179
180 pub fn has_header(&self, name: impl AsRef<OsStr>) -> bool {
182 unsafe { aws_http_headers_has(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor()) }
185 }
186
187 pub fn erase_header(&self, name: impl AsRef<OsStr>) -> Result<(), HeadersError> {
189 unsafe {
192 aws_http_headers_erase(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor())
193 .ok_or_last_error()
194 .map_err(|err| HeadersError::try_convert(err, name.as_ref()))?;
195 }
196
197 Ok(())
198 }
199
200 pub fn get<H: AsRef<OsStr>>(&self, name: H) -> Result<Header<OsString, OsString>, HeadersError> {
202 let value = unsafe {
205 let mut value: MaybeUninit<aws_byte_cursor> = MaybeUninit::uninit();
206
207 aws_http_headers_get(
208 self.inner.as_ptr(),
209 name.as_ref().as_aws_byte_cursor(),
210 value.as_mut_ptr(),
211 )
212 .ok_or_last_error()
213 .map_err(|err| HeadersError::try_convert(err, name.as_ref()))?;
214
215 value.assume_init()
216 };
217
218 let name = name.as_ref().to_os_string();
219
220 let value = unsafe { OsStr::from_bytes(aws_byte_cursor_as_slice(&value)).to_owned() };
223
224 Ok(Header::new(name, value))
225 }
226
227 pub fn get_as_string<H: AsRef<OsStr>>(&self, name: H) -> Result<String, HeadersError> {
229 let name = name.as_ref();
230 let header = self.get(name)?;
231 let value = header.value();
232 if let Some(s) = value.to_str() {
233 Ok(s.to_string())
234 } else {
235 let err = HeadersError::Invalid {
236 name: name.to_owned(),
237 value: value.clone(),
238 };
239 Err(err)
240 }
241 }
242
243 pub fn get_as_optional_string<H: AsRef<OsStr>>(&self, name: H) -> Result<Option<String>, HeadersError> {
245 Ok(if self.has_header(&name) {
246 Some(self.get_as_string(name)?)
247 } else {
248 None
249 })
250 }
251
252 pub fn iter(&self) -> impl Iterator<Item = (OsString, OsString)> + '_ {
254 HeadersIterator {
255 headers: self,
256 offset: 0,
257 }
258 }
259}
260
261impl Clone for Headers {
262 fn clone(&self) -> Self {
263 unsafe { Headers::from_crt(self.inner) }
266 }
267}
268
269impl Drop for Headers {
270 fn drop(&mut self) {
271 unsafe {
274 aws_http_headers_release(self.inner.as_ptr());
275 }
276 }
277}
278
/// Iterator state for walking a [`Headers`] collection by index.
#[derive(Debug)]
struct HeadersIterator<'a> {
    /// The collection being iterated; borrowed for the iterator's whole lifetime.
    headers: &'a Headers,
    /// Index of the next header to yield.
    offset: usize,
}
285
286impl Iterator for HeadersIterator<'_> {
287 type Item = (OsString, OsString);
288
289 fn next(&mut self) -> Option<Self::Item> {
290 if self.offset < self.headers.count() {
291 let header = self
292 .headers
293 .get_index(self.offset)
294 .expect("headers at any offset smaller than original count should always exist given mut access");
295 self.offset += 1;
296
297 Some((header.name, header.value))
298 } else {
299 None
300 }
301 }
302}
303
/// A handle to a CRT `aws_http_message`.
#[derive(Debug)]
pub struct Message<'a> {
    /// Non-null pointer to the underlying CRT message.
    pub(crate) inner: NonNull<aws_http_message>,

    /// The body stream most recently installed through `set_body_stream`, if any. It is stored
    /// here because the CRT message is only handed a raw pointer to it.
    body_input_stream: Option<InputStream<'a>>,
}
313
314impl<'a> Message<'a> {
315 pub fn new_request(allocator: &Allocator) -> Result<Self, Error> {
317 http_library_init(allocator);
319
320 let inner = unsafe { aws_http_message_new_request(allocator.inner.as_ptr()).ok_or_last_error()? };
322
323 Ok(Self {
324 inner,
325 body_input_stream: None,
326 })
327 }
328
329 pub fn add_header(&mut self, header: &Header<impl AsRef<OsStr>, impl AsRef<OsStr>>) -> Result<(), Error> {
333 unsafe { aws_http_message_add_header(self.inner.as_ptr(), header.inner).ok_or_last_error() }
335 }
336
337 pub fn set_header(&mut self, header: &Header<impl AsRef<OsStr>, impl AsRef<OsStr>>) -> Result<(), Error> {
340 let headers = unsafe { aws_http_message_get_headers(self.inner.as_ptr()) };
342 assert!(!headers.is_null(), "headers are always initialized");
343 unsafe {
345 aws_http_headers_set(
346 headers,
347 header.name().as_aws_byte_cursor(),
348 header.value().as_aws_byte_cursor(),
349 )
350 .ok_or_last_error()
351 }
352 }
353
354 pub fn set_request_path(&mut self, path: impl AsRef<OsStr>) -> Result<(), Error> {
356 unsafe { aws_http_message_set_request_path(self.inner.as_ptr(), path.as_aws_byte_cursor()).ok_or_last_error() }
358 }
359
360 pub fn set_request_method(&mut self, method: impl AsRef<OsStr>) -> Result<(), Error> {
362 unsafe {
364 aws_http_message_set_request_method(self.inner.as_ptr(), method.as_aws_byte_cursor()).ok_or_last_error()
365 }
366 }
367
368 pub fn get_headers(&mut self) -> Result<Headers, Error> {
370 let header_ptr = unsafe { aws_http_message_get_headers(self.inner.as_ptr()).ok_or_last_error()? };
372 let headers = unsafe { Headers::from_crt(header_ptr) };
375 Ok(headers)
376 }
377
378 pub fn set_body_stream(&mut self, input_stream: Option<InputStream<'a>>) -> Option<InputStream<'a>> {
381 let old_input_stream = std::mem::replace(&mut self.body_input_stream, input_stream);
382
383 let new_input_stream_ptr = self
384 .body_input_stream
385 .as_ref()
386 .map(|s| s.inner.as_ptr())
387 .unwrap_or(std::ptr::null_mut());
388
389 unsafe {
394 aws_http_message_set_body_stream(self.inner.as_ptr(), new_input_stream_ptr);
395 }
396
397 old_input_stream
398 }
399}
400
401impl Drop for Message<'_> {
402 fn drop(&mut self) {
403 unsafe {
406 aws_http_message_release(self.inner.as_ptr());
407 }
408 }
409}
410
#[cfg(test)]
mod test {
    use super::*;
    use crate::common::allocator::Allocator;
    use std::collections::HashMap;

    /// Adds three distinct headers and checks counting, presence, lookup and iteration.
    #[test]
    fn test_headers() {
        let mut headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        headers.add_header(&Header::new("a", "1")).unwrap();
        headers.add_header(&Header::new("b", "2")).unwrap();
        headers.add_header(&Header::new("c", "3")).unwrap();

        assert_eq!(headers.count(), 3);

        assert!(headers.has_header("a"));
        assert!(headers.has_header("b"));
        assert!(headers.has_header("c"));

        assert_eq!(headers.get("a").unwrap().name(), "a");
        assert_eq!(headers.get("a").unwrap().value(), "1");

        // Both string accessors should agree with the raw lookup.
        assert_eq!(headers.get_as_string("a"), Ok("1".to_string()));
        assert_eq!(headers.get_as_optional_string("a"), Ok(Some("1".to_string())));

        // Iteration yields owned pairs, so they can be collected into a map.
        let map: HashMap<OsString, OsString> = headers.iter().collect();

        assert_eq!(map.len(), 3);
        assert_eq!(map.get(OsStr::new("a")), Some(&OsString::from("1")));
    }

    /// Looks up a header that was never added and checks every accessor's "absent" behaviour.
    #[test]
    fn test_header_not_present() {
        let headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        assert!(!headers.has_header("a"), "header should not be present");

        // `get` reports the missing header by name, both in the variant and in its display text.
        let error = headers.get("a").expect_err("should fail because header is not present");
        assert_eq!(
            error.to_string(),
            "Header \"a\" not found",
            "header error display should match expected output",
        );
        if let HeadersError::HeaderNotFound(name) = error {
            assert_eq!(name, "a", "header name should match original argument");
        } else {
            panic!("should fail with HeaderNotFound");
        }

        // `get_as_string` surfaces the same error.
        let error = headers
            .get_as_string("a")
            .expect_err("should fail because header is not present");
        if let HeadersError::HeaderNotFound(name) = error {
            assert_eq!(name, "a", "header name should match original argument");
        } else {
            panic!("should fail with HeaderNotFound");
        }

        // The optional accessor turns absence into `None` rather than an error.
        let header = headers
            .get_as_optional_string("a")
            .expect("Should not fail as optional is expected here");
        assert_eq!(header, None, "should return None");
    }

    /// Adds the same header name twice and checks that the second value replaces the first.
    #[test]
    fn test_headers_overwrite() {
        let mut headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        headers.add_header(&Header::new("a", "1")).unwrap();
        headers.add_header(&Header::new("a", "2")).unwrap();

        assert_eq!(headers.count(), 1);

        assert_eq!(headers.get("a").unwrap().name(), "a");
        assert_eq!(headers.get("a").unwrap().value(), "2");

        let map: HashMap<OsString, OsString> = headers.iter().collect();

        assert_eq!(map.len(), 1);
        assert_eq!(map.get(OsStr::new("a")), Some(&OsString::from("2")));
    }

    /// Adds a header, erases it, and checks the collection is empty afterwards.
    #[test]
    fn test_headers_erase() {
        let mut headers = Headers::new(&Allocator::default()).expect("failed to create headers");

        headers.add_header(&Header::new("a", "1")).unwrap();
        assert_eq!(headers.count(), 1);

        headers.erase_header("a").unwrap();

        assert_eq!(headers.count(), 0);

        let map: HashMap<OsString, OsString> = headers.iter().collect();

        assert_eq!(map.len(), 0);
    }
}