use std::io::Write;
use std::panic::UnwindSafe;
use libc::c_void;
use crate::nxt_unit::{
self, nxt_unit_buf_send, nxt_unit_buf_t, nxt_unit_request_info_t, nxt_unit_response_buf_alloc,
};
use crate::error::{IntoUnitResult, UnitError, UnitResult};
use crate::request::Request;
/// Builder for the head of a Unit response (header fields and inline content).
///
/// It only borrows the request; every operation forwards the request's
/// `nxt_request` handle to libunit.
pub struct Response<'a> {
    /// The request whose response is being built.
    pub(crate) request: &'a Request<'a>,
}
impl<'a> Response<'a> {
    /// Appends the header field `name: value` to the response.
    ///
    /// Both arguments are passed to libunit as raw bytes; no validation happens here.
    ///
    /// NOTE(review): the name length is narrowed with `as u8` and the value length
    /// with `as u32`, so a name longer than 255 bytes is silently truncated rather
    /// than rejected.
    ///
    /// # Errors
    /// Returns the error reported by `nxt_unit_response_add_field`.
    pub fn add_field<N: AsRef<[u8]>, V: AsRef<[u8]>>(&self, name: N, value: V) -> UnitResult<()> {
        // Call `as_ref()` exactly once per argument so the pointer and length given
        // to C always describe the same slice, even for an `AsRef` impl that is not
        // deterministic.
        let name = name.as_ref();
        let value = value.as_ref();
        unsafe {
            nxt_unit::nxt_unit_response_add_field(
                self.request.nxt_request,
                name.as_ptr() as *const libc::c_char,
                name.len() as u8,
                value.as_ptr() as *const libc::c_char,
                value.len() as u32,
            )
            .into_unit_result()
        }
    }

    /// Appends `content` to the response through `nxt_unit_response_add_content`.
    ///
    /// NOTE(review): the length is narrowed with `as u32`.
    ///
    /// # Errors
    /// Returns the error reported by `nxt_unit_response_add_content`.
    pub fn add_content<C: AsRef<[u8]>>(&self, content: C) -> UnitResult<()> {
        // Same single-`as_ref()` rule as in `add_field`.
        let content = content.as_ref();
        unsafe {
            nxt_unit::nxt_unit_response_add_content(
                self.request.nxt_request,
                content.as_ptr() as *const c_void,
                content.len() as u32,
            )
            .into_unit_result()
        }
    }

    /// Asks libunit to reallocate the response with room for `max_fields_count`
    /// fields and `max_fields_size` bytes of field data.
    ///
    /// NOTE(review): both values are narrowed with `as u32`.
    ///
    /// # Errors
    /// Returns the error reported by `nxt_unit_response_realloc`.
    pub fn realloc(&self, max_fields_count: usize, max_fields_size: usize) -> UnitResult<()> {
        unsafe {
            nxt_unit::nxt_unit_response_realloc(
                self.request.nxt_request,
                max_fields_count as u32,
                max_fields_size as u32,
            )
            .into_unit_result()
        }
    }

    /// Sends the response built so far through `nxt_unit_response_send`.
    ///
    /// # Errors
    /// Returns the error reported by `nxt_unit_response_send`.
    pub fn send(&self) -> UnitResult<()> {
        unsafe { nxt_unit::nxt_unit_response_send(self.request.nxt_request).into_unit_result() }
    }
}
/// Streaming writer for a response body, backed by chunks of memory obtained
/// from `nxt_unit_response_buf_alloc`.
///
/// Bytes are copied into the current chunk. A chunk is handed to Unit with
/// `nxt_unit_buf_send` when it is flushed, and a new one is allocated on demand.
pub struct BodyWriter<'a> {
    /// The libunit request this body belongs to.
    nxt_request: *mut nxt_unit_request_info_t,
    /// The chunk currently being filled; null before allocation and after a send.
    response_buffer: *mut nxt_unit_buf_t,
    /// Next free byte inside the current chunk.
    chunk_cursor: *mut u8,
    /// Free bytes left in the current chunk, starting at `chunk_cursor`.
    bytes_remaining: usize,
    /// Size requested for every chunk allocation.
    chunk_size: usize,
    /// Ties the writer to the borrow of the request it was created from.
    _lifetime: std::marker::PhantomData<&'a mut ()>,
}

// `PhantomData<&mut ()>` makes this type `!UnwindSafe` by default, because
// `&mut T` is not `UnwindSafe`. This impl opts back in explicitly.
impl UnwindSafe for BodyWriter<'_> {}
impl<'a> BodyWriter<'a> {
    /// Creates a writer that streams the body of `request` in chunks of
    /// `chunk_size` bytes. The first chunk is allocated immediately.
    ///
    /// # Errors
    /// Fails if the first chunk cannot be allocated.
    pub(crate) fn new(request: &'a Request<'a>, chunk_size: usize) -> std::io::Result<Self> {
        let mut writer = BodyWriter {
            _lifetime: Default::default(),
            nxt_request: request.nxt_request,
            response_buffer: std::ptr::null_mut(),
            chunk_cursor: std::ptr::null_mut(),
            chunk_size,
            bytes_remaining: 0,
        };
        writer.allocate_buffer()?;
        Ok(writer)
    }

    /// Allocates a fresh chunk of `chunk_size` bytes and makes it the current one.
    ///
    /// # Errors
    /// Fails with `InvalidInput` if `chunk_size` does not fit the `u32` size
    /// parameter. Fails with `Other` if libunit returns a null buffer.
    fn allocate_buffer(&mut self) -> std::io::Result<()> {
        // The allocator takes a `u32`. A truncating `as` cast would shrink the real
        // allocation while `bytes_remaining` still claimed the full `chunk_size`,
        // letting later writes overrun the buffer.
        if self.chunk_size > u32::MAX as usize {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Response chunk size does not fit in a u32",
            ));
        }
        let buf = unsafe { nxt_unit_response_buf_alloc(self.nxt_request, self.chunk_size as u32) };
        if buf.is_null() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Could not allocate response buffer in Unit's shared memory",
            ));
        }
        self.response_buffer = buf;
        // SAFETY: `buf` was just checked to be non-null.
        self.chunk_cursor = unsafe { (*buf).start as *mut u8 };
        self.bytes_remaining = self.chunk_size;
        Ok(())
    }

    /// Copies everything from `r` into the response body until `r` reports
    /// end of input (a zero-length read).
    ///
    /// Full chunks are flushed as they fill. The last, partially filled chunk is
    /// left buffered in the writer.
    ///
    /// # Errors
    /// - Propagates any read error other than `ErrorKind::Interrupted`, which
    ///   is retried.
    /// - Propagates flush and allocation errors.
    /// - Fails with `InvalidData` if the reader claims to have read more bytes
    ///   than the slice it was given.
    pub fn copy_from_reader<R: std::io::Read>(&mut self, mut r: R) -> std::io::Result<()> {
        // Whether the bytes from `chunk_cursor` to the end of the chunk are
        // initialized. A freshly allocated chunk, or one partly filled by `write`,
        // has an uninitialized tail. That tail must be zeroed before it may be
        // exposed as `&mut [u8]`. After that it stays initialized, because a reader
        // can only store valid bytes through the slice. Zeroing once per chunk
        // avoids re-clearing the whole tail on every short read.
        let mut tail_initialized = false;
        loop {
            if self.bytes_remaining == 0 {
                self.flush()?;
                self.allocate_buffer()?;
                tail_initialized = false;
            }
            let capacity = self.bytes_remaining;
            if !tail_initialized {
                // SAFETY: `chunk_cursor .. chunk_cursor + capacity` lies inside the
                // chunk set up by `allocate_buffer`.
                unsafe {
                    libc::memset(self.chunk_cursor as *mut c_void, 0, capacity);
                }
                tail_initialized = true;
            }
            // SAFETY: same range as above, now fully initialized, and exclusively
            // reachable through `&mut self`.
            let write_buffer =
                unsafe { std::slice::from_raw_parts_mut(self.chunk_cursor, capacity) };
            let bytes = match r.read(write_buffer) {
                Ok(0) => break,
                Ok(n) if n <= capacity => n,
                Ok(_) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Reader reported more bytes than the buffer it was given",
                    ))
                }
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            // SAFETY: `bytes <= capacity`, so the cursor stays within the chunk.
            self.chunk_cursor = unsafe { self.chunk_cursor.add(bytes) };
            self.bytes_remaining -= bytes;
        }
        Ok(())
    }
}
impl std::io::Write for BodyWriter<'_> {
    /// Copies as much of `buf` as fits into the current chunk and returns the
    /// number of bytes taken.
    ///
    /// If `buf` is non-empty and the current chunk has no room left, the chunk
    /// is flushed and a new one allocated first. An empty `buf` writes nothing.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut accepted = buf;
        let fits_entirely = accepted.is_empty() || accepted.len() < self.bytes_remaining;
        if !fits_entirely {
            if self.bytes_remaining == 0 {
                self.flush()?;
                self.allocate_buffer()?;
            }
            let room = self.bytes_remaining;
            accepted = &accepted[..accepted.len().min(room)];
        }
        let count = accepted.len();
        // SAFETY: `count <= bytes_remaining`, so both the copy and the cursor
        // advance stay inside the current chunk.
        unsafe {
            std::ptr::copy_nonoverlapping(accepted.as_ptr(), self.chunk_cursor, count);
            self.chunk_cursor = self.chunk_cursor.add(count);
        }
        self.bytes_remaining -= count;
        Ok(count)
    }

    /// Sends the filled part of the current chunk with `nxt_unit_buf_send`.
    ///
    /// Does nothing when there is no current chunk or nothing has been written
    /// into it. On success the chunk pointer is cleared and `bytes_remaining` is
    /// zeroed, so the next non-empty `write` allocates a new chunk.
    fn flush(&mut self) -> std::io::Result<()> {
        let idle = self.response_buffer.is_null() || self.bytes_remaining == self.chunk_size;
        if idle {
            return Ok(());
        }
        let filled = self.chunk_size - self.bytes_remaining;
        unsafe {
            let chunk = self.response_buffer;
            // Set `free` to just past the bytes written so far before sending.
            (*chunk).free = (*chunk).start.add(filled);
            nxt_unit_buf_send(chunk)
                .into_unit_result()
                .map_err(|UnitError(_)| {
                    std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "Could not send response buffer to Unit server",
                    )
                })?;
        }
        self.response_buffer = std::ptr::null_mut();
        self.bytes_remaining = 0;
        Ok(())
    }
}
impl Drop for BodyWriter<'_> {
    /// Finishes the body when the writer goes away.
    ///
    /// Normally any buffered bytes are flushed. While the thread is unwinding,
    /// the current chunk is freed instead of being sent.
    ///
    /// # Panics
    /// Panics if flushing fails while the thread is not already panicking.
    fn drop(&mut self) {
        // A null `response_buffer` means there is nothing to release: either
        // allocation never succeeded, or `flush` already sent the chunk. `flush`
        // clears `response_buffer` but leaves `chunk_cursor` pointing into the sent
        // chunk, so checking `chunk_cursor` here would pass a null pointer to
        // `nxt_unit_buf_free` when unwinding after a flush.
        if self.response_buffer.is_null() {
            return;
        }
        if std::thread::panicking() {
            // Don't send a partial body while unwinding; just give the chunk back.
            unsafe {
                nxt_unit::nxt_unit_buf_free(self.response_buffer);
            }
            self.response_buffer = std::ptr::null_mut();
            return;
        }
        if let Err(err) = self.flush() {
            panic!("Error while dropping ResponseWriter: {}", err);
        }
        // `flush` returns early without sending a chunk that has nothing written
        // into it, leaving it allocated. Free it here, the same way the unwinding
        // path does.
        if !self.response_buffer.is_null() {
            unsafe {
                nxt_unit::nxt_unit_buf_free(self.response_buffer);
            }
            self.response_buffer = std::ptr::null_mut();
        }
    }
}