use crate::Header;
use crate::StreamId;
use std::collections::hash_map;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fmt::Display;
use std::marker::PhantomPinned;
use std::pin::Pin;
/// Category of failure carried by a [`DecoderError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecoderErrorKind {
    /// `decode` was called for a stream that still has a pending (blocked) header block.
    DuplicateStreamId,
    /// lsqpack rejected the header block.
    DecodeFailed,
    /// lsqpack rejected the encoder-stream bytes passed to `feed`.
    FeedFailed,
    /// A decoded field could not be turned into a [`Header`](crate::Header).
    InvalidHeader,
}

/// Error returned by the decoder API; inspect it with [`DecoderError::kind`].
///
/// `Debug`, `Display` and `std::error::Error` are implemented by hand elsewhere in this file.
pub struct DecoderError {
    kind: DecoderErrorKind,
}

impl DecoderError {
    /// Wraps `kind` into an error value (crate-internal constructor).
    fn new(kind: DecoderErrorKind) -> Self {
        DecoderError { kind }
    }

    /// Returns the category of this error.
    pub fn kind(&self) -> DecoderErrorKind {
        let Self { kind } = self;
        *kind
    }
}
/// A fully decoded header block.
#[derive(Debug)]
pub struct BuffersDecoded {
    /// Decoded fields, in the order the decoder produced them.
    headers: Vec<Header>,
    /// Bytes lsqpack wrote to its output buffer for this block (presumably decoder-stream
    /// instructions to send back to the peer — NOTE(review): confirm against ls-qpack docs).
    stream: Box<[u8]>,
}

impl BuffersDecoded {
    /// Returns the decoded header fields.
    pub fn headers(&self) -> &Vec<Header> {
        let Self { headers, .. } = self;
        headers
    }

    /// Returns the output bytes produced alongside the decoded headers (may be empty).
    pub fn stream(&self) -> &[u8] {
        &self.stream[..]
    }
}
/// Outcome of submitting a header block to the decoder.
#[derive(Debug)]
pub enum DecoderOutput {
    /// The block decoded completely.
    Done(BuffersDecoded),
    /// The block is waiting for encoder-stream data; retry via `Decoder::unblocked`.
    BlockedStream,
}

impl DecoderOutput {
    /// Consumes the output, yielding the decoded buffers or `None` if the stream is blocked.
    pub fn take(self) -> Option<BuffersDecoded> {
        if let Self::Done(decoded) = self {
            Some(decoded)
        } else {
            None
        }
    }

    /// Reports whether this output is [`DecoderOutput::BlockedStream`].
    pub fn is_blocked(&self) -> bool {
        match self {
            Self::BlockedStream => true,
            Self::Done(_) => false,
        }
    }
}
/// QPACK decoder backed by ls-qpack.
///
/// Flow exercised by this file's tests: `feed` encoder-stream bytes, `decode` each header
/// block, and for blocks reported as blocked call `unblocked` after feeding more
/// encoder-stream data.
pub struct Decoder {
    inner: Pin<Box<InnerDecoder>>,
}

impl Decoder {
    /// Creates a decoder; both arguments are forwarded unchanged to `lsqpack_dec_init`.
    pub fn new(dyn_table_size: u32, max_blocked_streams: u32) -> Self {
        let inner = InnerDecoder::new(dyn_table_size, max_blocked_streams);
        Decoder { inner }
    }

    /// Decodes the complete header block `data` received on `stream_id`.
    ///
    /// Returns [`DecoderOutput::BlockedStream`] when the block cannot be finished yet; the
    /// block is then kept until retrieved with [`Decoder::unblocked`].
    ///
    /// # Errors
    ///
    /// [`DecoderErrorKind::DuplicateStreamId`] if a block for `stream_id` is still pending,
    /// otherwise an error describing why lsqpack rejected the block.
    pub fn decode<D>(&mut self, stream_id: StreamId, data: D) -> Result<DecoderOutput, DecoderError>
    where
        D: AsRef<[u8]>,
    {
        let header_block = data.as_ref();
        let inner = self.inner.as_mut();
        inner.feed_header_data(stream_id, header_block)
    }

    /// Feeds bytes from the peer's QPACK encoder stream.
    ///
    /// # Errors
    ///
    /// [`DecoderErrorKind::FeedFailed`] if lsqpack rejects the data.
    pub fn feed<D>(&mut self, data: D) -> Result<(), DecoderError>
    where
        D: AsRef<[u8]>,
    {
        let encoder_bytes = data.as_ref();
        self.inner.as_mut().feed_encoder_data(encoder_bytes)
    }

    /// Polls a previously blocked header block.
    ///
    /// Returns `None` if no block is pending for `stream_id`, `Some(Ok(BlockedStream))` if it
    /// is still blocked, and otherwise hands out (and forgets) the finished or failed block.
    pub fn unblocked(
        &mut self,
        stream_id: StreamId,
    ) -> Option<Result<DecoderOutput, DecoderError>> {
        let inner = self.inner.as_mut();
        inner.process_decoded_data(stream_id)
    }
}
/// Pinned owner of the lsqpack decoder and of every header block still pending.
///
/// `lsqpack_dec` is initialised in place and each header block context stores a raw pointer
/// to it, so this value must never move: it is `!Unpin` and only handled as `Pin<Box<_>>`.
struct InnerDecoder {
    decoder: ls_qpack_rs_sys::lsqpack_dec,
    /// Blocked header blocks keyed by stream; an entry is removed when
    /// `process_decoded_data` hands out its final result.
    header_blocks: HashMap<StreamId, Pin<Box<callbacks::HeaderBlockCtx>>>,
    _marker: PhantomPinned,
}

impl InnerDecoder {
    /// Allocates and initialises a decoder with the given dynamic table size and
    /// blocked-stream limit.
    fn new(dyn_table_size: u32, max_blocked_streams: u32) -> Pin<Box<Self>> {
        let mut this = Box::new(Self {
            decoder: ls_qpack_rs_sys::lsqpack_dec::default(),
            header_blocks: HashMap::new(),
            _marker: PhantomPinned,
        });
        // SAFETY: `this.decoder` lives in a heap allocation that is pinned below without
        // moving, and the callback table is a `static`, so both pointers stay valid for the
        // decoder's lifetime. A null logger context is passed (NOTE(review): presumably
        // disables lsqpack logging — confirm).
        unsafe {
            ls_qpack_rs_sys::lsqpack_dec_init(
                &mut this.decoder,
                std::ptr::null_mut(),
                dyn_table_size,
                max_blocked_streams,
                &callbacks::HSET_IF_CALLBACKS,
                0,
            );
        }
        Box::into_pin(this)
    }

    /// Starts decoding the complete header block `data` for `stream_id`.
    ///
    /// # Errors
    ///
    /// * `DuplicateStreamId` — a block for `stream_id` is still pending.
    /// * `InvalidHeader` — a decoded field was rejected while building a `Header`.
    /// * `DecodeFailed` — `data` is empty, or lsqpack failed for any other reason.
    fn feed_header_data(
        self: Pin<&mut Self>,
        stream_id: StreamId,
        data: &[u8],
    ) -> Result<DecoderOutput, DecoderError> {
        // SAFETY: nothing is moved out of `this`; only fields are borrowed.
        let this = unsafe { self.get_unchecked_mut() };
        if this.header_blocks.contains_key(&stream_id) {
            return Err(DecoderError::new(DecoderErrorKind::DuplicateStreamId));
        }
        // An encoded field section always starts with a non-empty prefix (RFC 9204 §4.5.1),
        // so an empty block can never decode. Rejecting it here also keeps the empty slice
        // away from `encoded_cursor`, whose debug assertion would otherwise fire.
        if data.is_empty() {
            return Err(DecoderError::new(DecoderErrorKind::DecodeFailed));
        }
        let mut hblock_ctx =
            callbacks::HeaderBlockCtx::new(&mut this.decoder, data.to_vec().into_boxed_slice());
        let encoded_cursor = hblock_ctx.as_ref().encoded_cursor();
        // The whole block is supplied at once, so its total size equals the bytes available.
        let header_block_len = encoded_cursor.len();
        let mut cursor_after = encoded_cursor.as_ptr();
        // NOTE(review): sized with LSQPACK_LONGEST_SDTC; confirm this is large enough for
        // the decoder-stream output of streams with large IDs.
        let mut buffer = vec![0; ls_qpack_rs_sys::LSQPACK_LONGEST_SDTC as usize];
        let mut sdtc_buffer_size = buffer.len();
        // SAFETY: `hblock_ctx` is pinned on the heap and is either stored in `header_blocks`
        // (if blocked) or dropped only after lsqpack has finished with it; `cursor_after`
        // and `buffer` point into live allocations of the advertised sizes.
        let result = unsafe {
            ls_qpack_rs_sys::lsqpack_dec_header_in(
                &mut this.decoder,
                hblock_ctx.as_mut().as_mut_ptr() as *mut libc::c_void,
                stream_id.value(),
                header_block_len,
                &mut cursor_after,
                header_block_len,
                buffer.as_mut_ptr(),
                &mut sdtc_buffer_size,
            )
        };
        match result {
            ls_qpack_rs_sys::lsqpack_read_header_status_LQRHS_DONE => {
                debug_assert!(!hblock_ctx.as_ref().is_blocked());
                debug_assert!(!hblock_ctx.as_ref().is_error());
                buffer.truncate(sdtc_buffer_size);
                // `decoded_headers` takes `&self`, reachable through `Pin<Box<_>>`'s Deref,
                // so no unpinning is needed.
                Ok(DecoderOutput::Done(BuffersDecoded {
                    headers: hblock_ctx.decoded_headers(),
                    stream: buffer.into_boxed_slice(),
                }))
            }
            // NEED is not expected because the whole block is supplied up front; it is
            // treated like BLOCKED, as before.
            ls_qpack_rs_sys::lsqpack_read_header_status_LQRHS_BLOCKED
            | ls_qpack_rs_sys::lsqpack_read_header_status_LQRHS_NEED => {
                // SAFETY: lsqpack only advances `cursor_after` within `encoded_data`, so both
                // pointers belong to the same allocation.
                let offset = unsafe {
                    cursor_after.offset_from(hblock_ctx.as_ref().encoded_cursor().as_ptr())
                };
                debug_assert!(offset >= 0);
                hblock_ctx.as_mut().advance_cursor(offset as usize);
                hblock_ctx.as_mut().set_blocked(true);
                this.header_blocks.insert(stream_id, hblock_ctx);
                Ok(DecoderOutput::BlockedStream)
            }
            // A field rejected by `dhi_process_header` — reported the same way as on the
            // unblocked path in `process_decoded_data`.
            _ if hblock_ctx.as_ref().is_error() => {
                Err(DecoderError::new(DecoderErrorKind::InvalidHeader))
            }
            _ => Err(DecoderError::new(DecoderErrorKind::DecodeFailed)),
        }
    }

    /// Hands encoder-stream bytes to lsqpack; this may resume blocked header blocks through
    /// the `dhi_unblocked` callback.
    ///
    /// # Errors
    ///
    /// `FeedFailed` if lsqpack returns a non-zero status.
    fn feed_encoder_data(self: Pin<&mut Self>, data: &[u8]) -> Result<(), DecoderError> {
        // SAFETY: nothing is moved out of `this`.
        let this = unsafe { self.get_unchecked_mut() };
        // SAFETY: `data` is a valid slice for the duration of the call; the decoder was
        // initialised in `new`.
        let result = unsafe {
            ls_qpack_rs_sys::lsqpack_dec_enc_in(&mut this.decoder, data.as_ptr(), data.len())
        };
        if result == 0 {
            Ok(())
        } else {
            Err(DecoderError::new(DecoderErrorKind::FeedFailed))
        }
    }

    /// Returns the state of the pending block for `stream_id`.
    ///
    /// `None` if nothing is pending; `BlockedStream` if it is still blocked (it stays
    /// stored). Otherwise the entry is removed and either its decoded headers or
    /// `InvalidHeader` are returned.
    fn process_decoded_data(
        self: Pin<&mut Self>,
        stream_id: StreamId,
    ) -> Option<Result<DecoderOutput, DecoderError>> {
        // SAFETY: nothing is moved out of `this`.
        let this = unsafe { self.get_unchecked_mut() };
        match this.header_blocks.entry(stream_id) {
            hash_map::Entry::Occupied(hdbk) => {
                if hdbk.get().as_ref().is_blocked() {
                    debug_assert!(!hdbk.get().as_ref().is_error());
                    return Some(Ok(DecoderOutput::BlockedStream));
                }
                // No longer blocked, so lsqpack is done with this context; take ownership.
                let hdbk = hdbk.remove();
                if hdbk.as_ref().is_error() {
                    debug_assert!(!hdbk.as_ref().is_blocked());
                    return Some(Err(DecoderError::new(DecoderErrorKind::InvalidHeader)));
                }
                // Both accessors take `&self`, so the pinned box can be read through Deref.
                Some(Ok(DecoderOutput::Done(BuffersDecoded {
                    headers: hdbk.decoded_headers(),
                    stream: hdbk.stream_data().into_boxed_slice(),
                })))
            }
            hash_map::Entry::Vacant(_) => None,
        }
    }
}
/// Releases lsqpack's internal decoder state.
///
/// `Drop::drop` runs before the fields are dropped, so `lsqpack_dec_cleanup` executes while
/// the pending `header_blocks` contexts (which lsqpack may still reference) are still alive.
impl Drop for InnerDecoder {
    fn drop(&mut self) {
        // SAFETY: the decoder was initialised by `lsqpack_dec_init` in `InnerDecoder::new`
        // and is cleaned up exactly once, here.
        unsafe { ls_qpack_rs_sys::lsqpack_dec_cleanup(&mut self.decoder) }
    }
}
// SAFETY (NOTE(review): confirm): every method that touches the decoder takes
// `Pin<&mut Self>`, so a shared `&InnerDecoder` exposes no mutation. Sending relies on
// lsqpack's decoder state not being tied to the creating thread.
unsafe impl Send for InnerDecoder {}
unsafe impl Sync for InnerDecoder {}
// Compile-time check that the public `Decoder` stays `Send + Sync`.
const _: () = {
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}
    fn _assert_all() {
        _assert_send::<Decoder>();
        _assert_sync::<Decoder>();
    }
};
/// FFI glue: the callback table registered with lsqpack and the per-header-block state
/// those callbacks mutate.
mod callbacks {
    use crate::header::HeaderError;
    use crate::Header;
    use std::ffi::c_char;
    use std::marker::PhantomPinned;
    use std::pin::Pin;

    /// Callback table passed to `lsqpack_dec_init`. It is a `static`, so the pointer lsqpack
    /// keeps remains valid for as long as any decoder exists.
    pub(super) static HSET_IF_CALLBACKS: ls_qpack_rs_sys::lsqpack_dec_hset_if =
        ls_qpack_rs_sys::lsqpack_dec_hset_if {
            dhi_unblocked: Some(dhi_unblocked),
            dhi_prepare_decode: Some(dhi_prepare_decode),
            dhi_process_header: Some(dhi_process_header),
        };

    /// State for one header block while lsqpack decodes it.
    ///
    /// Its address is handed to lsqpack as the opaque header-block context and comes back in
    /// every callback, so it must not move: it is `!Unpin` and only ever created pinned.
    #[derive(Debug)]
    pub(super) struct HeaderBlockCtx {
        /// Decoder owning this block; used to resume decoding from `dhi_unblocked`.
        decoder: *mut ls_qpack_rs_sys::lsqpack_dec,
        /// Owned copy of the encoded header block.
        encoded_data: Box<[u8]>,
        /// Number of bytes of `encoded_data` already consumed by lsqpack.
        encoded_data_offset: usize,
        /// Scratch buffer the current field is decoded into (`header.buf` points here).
        decoding_buffer: Vec<u8>,
        /// Field descriptor shared with lsqpack for the field currently being decoded.
        header: ls_qpack_rs_sys::lsxpack_header,
        /// True while lsqpack reports this block as blocked.
        blocked: bool,
        /// True once any step of decoding this block has failed.
        error: bool,
        /// Output bytes lsqpack produced when a previously blocked block finished.
        stream_data: Vec<u8>,
        /// Fields decoded so far, in callback order.
        decoded_headers: Vec<Header>,
        _marker: PhantomPinned,
    }

    impl HeaderBlockCtx {
        /// Creates pinned state for `encoded_data`, remembering the owning `decoder`.
        pub(super) fn new(
            decoder: *mut ls_qpack_rs_sys::lsqpack_dec,
            encoded_data: Box<[u8]>,
        ) -> Pin<Box<Self>> {
            debug_assert!(!decoder.is_null());
            Box::pin(Self {
                decoder,
                encoded_data,
                encoded_data_offset: 0,
                decoding_buffer: Vec::new(),
                stream_data: Vec::new(),
                header: Default::default(),
                blocked: false,
                error: false,
                decoded_headers: Default::default(),
                _marker: PhantomPinned,
            })
        }

        /// Returns this context's address, for use as lsqpack's opaque pointer.
        ///
        /// # Safety
        ///
        /// The pointer may only be turned back into a reference (via `from_void_ptr`) while
        /// the context is alive and still pinned, and must never be used to move it.
        pub(super) unsafe fn as_mut_ptr(mut self: Pin<&mut Self>) -> *mut HeaderBlockCtx {
            self.as_mut().get_unchecked_mut()
        }

        /// Returns the bytes not yet consumed by lsqpack (empty once fully consumed).
        pub(super) fn encoded_cursor(self: Pin<&Self>) -> &[u8] {
            // `<=`: a fully consumed block legitimately yields an empty slice.
            debug_assert!(self.encoded_data_offset <= self.encoded_data.len());
            &self.get_ref().encoded_data[self.encoded_data_offset..]
        }

        /// Marks `offset` more bytes as consumed.
        pub(super) fn advance_cursor(self: Pin<&mut Self>, offset: usize) {
            // Bound against the *remaining* bytes, not the total length.
            debug_assert!(offset <= self.encoded_data.len() - self.encoded_data_offset);
            // SAFETY: only a plain field is updated; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.encoded_data_offset += offset;
        }

        /// Sets or clears the blocked flag.
        pub(super) fn set_blocked(self: Pin<&mut Self>, blocked: bool) {
            // SAFETY: only a plain field is updated; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.blocked = blocked;
        }

        /// Stores a copy of the output bytes produced when decoding finished.
        pub(super) fn set_stream_data(self: Pin<&mut Self>, data: &[u8]) {
            // SAFETY: only a plain field is replaced; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.stream_data = data.to_vec();
        }

        /// Records that decoding this block failed.
        ///
        /// Idempotent on purpose: when `dhi_process_header` rejects a field it sets the flag
        /// and returns -1, after which lsqpack reports an error status back to
        /// `dhi_unblocked`, which sets it again. Asserting "not yet set" would panic inside
        /// an `extern "C"` callback (an abort) in debug builds.
        pub(super) fn enable_error(self: Pin<&mut Self>) {
            // SAFETY: only a plain field is updated; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.error = true;
        }

        /// Reports whether the block is currently blocked.
        pub(super) fn is_blocked(self: Pin<&Self>) -> bool {
            self.blocked
        }

        /// Reports whether decoding the block failed.
        pub(super) fn is_error(self: Pin<&Self>) -> bool {
            self.error
        }

        /// Returns a copy of the fields decoded so far.
        pub(super) fn decoded_headers(&self) -> Vec<Header> {
            self.decoded_headers.clone()
        }

        /// Returns a copy of the stored output bytes.
        pub(super) fn stream_data(&self) -> Vec<u8> {
            self.stream_data.clone()
        }

        /// Recovers the context from lsqpack's opaque pointer.
        ///
        /// # Safety
        ///
        /// `ptr` must come from `as_mut_ptr` on a live, pinned context. The `'static` lifetime
        /// is a lie: the returned reference must not outlive the current callback.
        unsafe fn from_void_ptr(ptr: *mut libc::c_void) -> Pin<&'static mut Self> {
            debug_assert!(!ptr.is_null());
            Pin::new_unchecked(&mut *(ptr as *mut Self))
        }

        /// Clears the field descriptor before a new field is decoded.
        fn reset_header(self: Pin<&mut Self>) {
            // SAFETY: only a plain field is replaced; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.header = Default::default()
        }

        /// Grows the scratch buffer to `space` bytes and points the descriptor at it.
        ///
        /// `resize` keeps existing bytes, so data lsqpack already wrote is preserved when it
        /// asks for more room.
        fn resize_header(self: Pin<&mut Self>, space: u16) {
            // SAFETY: fields are updated in place; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.decoding_buffer
                .resize(space as usize, Default::default());
            this.header.buf = this.decoding_buffer.as_mut_ptr() as *mut c_char;
            this.header.val_len = space;
        }

        /// Mutable access to the field descriptor handed to lsqpack.
        fn header_mut(self: Pin<&mut Self>) -> &mut ls_qpack_rs_sys::lsxpack_header {
            // SAFETY: the returned reference points at a field; the context is not moved.
            let this = unsafe { self.get_unchecked_mut() };
            &mut this.header
        }

        /// Turns the completed field into a `Header`, taking ownership of the scratch buffer,
        /// and resets the descriptor for the next field.
        fn process_header(self: Pin<&mut Self>) -> Result<(), HeaderError> {
            // SAFETY: fields are taken/updated in place; nothing is moved.
            let this = unsafe { self.get_unchecked_mut() };
            let header = Header::with_buffer(
                std::mem::take(&mut this.decoding_buffer).into_boxed_slice(),
                this.header.name_offset as usize,
                this.header.name_len as usize,
                this.header.val_offset as usize,
                this.header.val_len as usize,
            )?;
            this.decoded_headers.push(header);
            this.header = Default::default();
            Ok(())
        }
    }

    // SAFETY (NOTE(review): confirm): the raw `decoder` pointer is only dereferenced by lsqpack
    // while the owning `InnerDecoder` is exclusively borrowed.
    unsafe impl Send for HeaderBlockCtx {}
    unsafe impl Sync for HeaderBlockCtx {}

    /// lsqpack callback: a blocked block may continue. Resumes decoding from the saved
    /// cursor and records the result (done, blocked again, or error) in the context.
    extern "C" fn dhi_unblocked(hblock_ctx: *mut libc::c_void) {
        // SAFETY: lsqpack hands back the pointer registered via `as_mut_ptr`.
        let mut hblock_ctx = unsafe { HeaderBlockCtx::from_void_ptr(hblock_ctx) };
        debug_assert!(hblock_ctx.as_ref().is_blocked());
        // Cleared before reading so `dhi_process_header`'s debug assertion holds.
        hblock_ctx.as_mut().set_blocked(false);
        let encoded_cursor = hblock_ctx.as_ref().encoded_cursor();
        let encoded_cursor_len = encoded_cursor.len();
        let mut cursor_after = encoded_cursor.as_ptr();
        let mut buffer = vec![0; ls_qpack_rs_sys::LSQPACK_LONGEST_SDTC as usize];
        let mut sdtc_buffer_size = buffer.len();
        // SAFETY: `decoder` points at the pinned decoder that owns this block; the cursor
        // and output buffer point into live allocations of the advertised sizes.
        // NOTE(review): this re-enters lsqpack from inside its own callback — confirm that
        // ls-qpack permits `lsqpack_dec_header_read` here.
        let result = unsafe {
            ls_qpack_rs_sys::lsqpack_dec_header_read(
                hblock_ctx.decoder,
                hblock_ctx.as_mut().as_mut_ptr() as *mut libc::c_void,
                &mut cursor_after,
                encoded_cursor_len,
                buffer.as_mut_ptr(),
                &mut sdtc_buffer_size,
            )
        };
        match result {
            ls_qpack_rs_sys::lsqpack_read_header_status_LQRHS_DONE => {
                buffer.truncate(sdtc_buffer_size);
                hblock_ctx.as_mut().set_stream_data(&buffer);
            }
            ls_qpack_rs_sys::lsqpack_read_header_status_LQRHS_BLOCKED
            | ls_qpack_rs_sys::lsqpack_read_header_status_LQRHS_NEED => {
                // SAFETY: lsqpack only advances `cursor_after` within `encoded_data`.
                let offset = unsafe {
                    cursor_after.offset_from(hblock_ctx.as_ref().encoded_cursor().as_ptr())
                };
                debug_assert!(offset >= 0);
                hblock_ctx.as_mut().advance_cursor(offset as usize);
                hblock_ctx.as_mut().set_blocked(true);
            }
            _ => {
                // May already be set by `dhi_process_header`; `enable_error` is idempotent.
                hblock_ctx.as_mut().enable_error();
            }
        }
    }

    /// lsqpack callback: provide (or enlarge) the buffer for the next field.
    ///
    /// A null `header` starts a new field; otherwise lsqpack is asking for more room for the
    /// descriptor it already has. Returns null if `space` does not fit the `u16` length.
    extern "C" fn dhi_prepare_decode(
        hblock_ctx: *mut libc::c_void,
        header: *mut ls_qpack_rs_sys::lsxpack_header,
        space: libc::size_t,
    ) -> *mut ls_qpack_rs_sys::lsxpack_header {
        const MAX_SPACE: usize = u16::MAX as usize;
        // SAFETY: lsqpack hands back the pointer registered via `as_mut_ptr`.
        let mut hblock_ctx = unsafe {
            HeaderBlockCtx::from_void_ptr(hblock_ctx)
        };
        if space > MAX_SPACE {
            return std::ptr::null_mut();
        }
        let space = space as u16;
        if header.is_null() {
            hblock_ctx.as_mut().reset_header();
        } else {
            assert!(std::ptr::eq(&hblock_ctx.header, header));
            assert!(space > hblock_ctx.header.val_len);
        }
        hblock_ctx.as_mut().resize_header(space);
        hblock_ctx.as_mut().header_mut()
    }

    /// lsqpack callback: a field has been fully decoded. Returns 0 on success and -1 (after
    /// flagging the context) if it could not be converted into a `Header`.
    extern "C" fn dhi_process_header(
        hblock_ctx: *mut libc::c_void,
        header: *mut ls_qpack_rs_sys::lsxpack_header,
    ) -> libc::c_int {
        // SAFETY: lsqpack hands back the pointer registered via `as_mut_ptr`.
        let mut hblock_ctx = unsafe { HeaderBlockCtx::from_void_ptr(hblock_ctx) };
        debug_assert!(!hblock_ctx.blocked);
        // Compare addresses as raw pointers (a `*const T` cannot be `==`-compared to a `&T`).
        debug_assert!(std::ptr::eq(&hblock_ctx.header, header));
        match hblock_ctx.as_mut().process_header() {
            Ok(()) => 0,
            Err(_) => {
                hblock_ctx.as_mut().enable_error();
                -1
            }
        }
    }
}
/// Debug output has the shape `DecoderError { kind: <kind> }`.
impl Debug for DecoderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DecoderError");
        out.field("kind", &self.kind);
        out.finish()
    }
}

/// Human-readable, lowercase message for each error kind.
impl Display for DecoderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self.kind {
            DecoderErrorKind::DuplicateStreamId => "stream ID already has a pending header block",
            DecoderErrorKind::DecodeFailed => "decoding operation failed",
            DecoderErrorKind::FeedFailed => "failed to process encoder stream data",
            DecoderErrorKind::InvalidHeader => "invalid header encountered during decoding",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DecoderError {}
// Round-trip tests of `Decoder` against the crate's `Encoder`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::Encoder;

    // Headers the encoder emits without any encoder-stream output (asserted below) decode
    // on a decoder with a zero-sized dynamic table.
    #[test]
    fn test_decode_static_table_round_trip() {
        let headers = [(":status", "200"), ("content-type", "text/html")];
        let encoded = Encoder::new()
            .encode_all(StreamId::new(0), headers)
            .unwrap();
        let (hdr_data, stream_data) = encoded.into();
        assert!(
            stream_data.is_empty(),
            "static table should produce no stream data"
        );
        let output = Decoder::new(0, 0)
            .decode(StreamId::new(0), hdr_data)
            .unwrap();
        let decoded = output.take().expect("should not be blocked");
        let hdrs = decoded.headers();
        assert_eq!(hdrs.len(), 2);
        assert_eq!(hdrs[0].name(), ":status");
        assert_eq!(hdrs[0].value(), "200");
        assert_eq!(hdrs[1].name(), "content-type");
        assert_eq!(hdrs[1].value(), "text/html");
    }

    // One decoder handles several streams in sequence; each block completes immediately.
    #[test]
    fn test_decode_multiple_streams() {
        let mut encoder = Encoder::new();
        let mut decoder = Decoder::new(0, 0);
        for stream_id in 0..5u64 {
            let encoded = encoder
                .encode_all(StreamId::new(stream_id), [(":method", "GET")])
                .unwrap();
            let (hdr_data, _) = encoded.into();
            let output = decoder.decode(StreamId::new(stream_id), hdr_data).unwrap();
            let decoded = output.take().expect("should not be blocked");
            assert_eq!(decoded.headers()[0].name(), ":method");
            assert_eq!(decoded.headers()[0].value(), "GET");
        }
    }

    // Submitting a second block for a stream whose first block is still blocked must fail
    // with `DuplicateStreamId`.
    // NOTE(review): encoder-stream chunks of later iterations are fed before the withheld
    // chunk. The encoder stream is order-sensitive, so confirm the withheld block really is
    // still blocked on this path. If the encoder emits no encoder-stream data at all, only
    // the `else` branch runs and the duplicate check is never exercised.
    #[test]
    fn test_decode_duplicate_stream_id_error() {
        let mut encoder = Encoder::new();
        let sdtc = encoder.configure(4096, 4096, 100).unwrap();
        let mut decoder = Decoder::new(4096, 100);
        decoder.feed(sdtc.data()).unwrap();
        let mut hdr_data_blocked = None;
        let mut enc_stream_blocked = None;
        for i in 0..10u64 {
            let name = format!("x-unique-{}", i);
            let value = format!("value-{}", i);
            let encoded = encoder
                .encode_all(StreamId::new(i), [(name.as_str(), value.as_str())])
                .unwrap();
            let (hdr_data, enc_stream) = encoded.into();
            // Withhold the first block that depends on encoder-stream data.
            if !enc_stream.is_empty() && hdr_data_blocked.is_none() {
                hdr_data_blocked = Some((i, hdr_data));
                enc_stream_blocked = Some(enc_stream);
                continue;
            }
            if !enc_stream.is_empty() {
                decoder.feed(&enc_stream).unwrap();
            }
            let _ = decoder.decode(StreamId::new(i), hdr_data);
        }
        if let Some((stream_id, hdr_data)) = hdr_data_blocked {
            let output = decoder.decode(StreamId::new(stream_id), &hdr_data).unwrap();
            assert!(
                output.is_blocked(),
                "should be blocked without encoder stream"
            );
            let err = decoder
                .decode(StreamId::new(stream_id), &hdr_data)
                .unwrap_err();
            assert_eq!(err.kind(), DecoderErrorKind::DuplicateStreamId);
            if let Some(enc_stream) = enc_stream_blocked {
                decoder.feed(&enc_stream).unwrap();
            }
        } else {
            // Fallback when nothing was ever blocked: just check a plain decode succeeds.
            let encoded = encoder
                .encode_all(StreamId::new(100), [(":status", "200")])
                .unwrap();
            let (hdr_data, enc_stream) = encoded.into();
            if !enc_stream.is_empty() {
                decoder.feed(&enc_stream).unwrap();
            }
            let _ = decoder.decode(StreamId::new(100), &hdr_data).unwrap();
        }
    }

    // Feeding the encoder stream before the header block lets dynamic-table references
    // resolve without blocking.
    #[test]
    fn test_dynamic_table_round_trip() {
        let mut encoder = Encoder::new();
        let sdtc = encoder.configure(4096, 4096, 100).unwrap();
        let mut decoder = Decoder::new(4096, 100);
        decoder.feed(sdtc.data()).unwrap();
        let headers = [
            (":status", "200"),
            ("x-custom", "hello"),
            ("x-another", "world"),
        ];
        let encoded = encoder.encode_all(StreamId::new(0), headers).unwrap();
        let (hdr_data, enc_stream) = encoded.into();
        decoder.feed(&enc_stream).unwrap();
        let output = decoder.decode(StreamId::new(0), hdr_data).unwrap();
        let decoded = output
            .take()
            .expect("should not be blocked after feeding encoder stream");
        let hdrs = decoded.headers();
        assert_eq!(hdrs.len(), 3);
        assert_eq!(hdrs[0].name(), ":status");
        assert_eq!(hdrs[0].value(), "200");
        assert_eq!(hdrs[1].name(), "x-custom");
        assert_eq!(hdrs[1].value(), "hello");
        assert_eq!(hdrs[2].name(), "x-another");
        assert_eq!(hdrs[2].value(), "world");
    }

    // A block decoded before its encoder-stream data is reported blocked; after that data is
    // fed, `unblocked` returns the decoded headers.
    // NOTE(review): as in the duplicate-ID test, later encoder-stream chunks are fed before
    // the withheld one — confirm the expected header really belongs to the withheld insert.
    #[test]
    fn test_dynamic_table_blocked_then_unblocked() {
        let mut encoder = Encoder::new();
        let sdtc = encoder.configure(4096, 4096, 100).unwrap();
        let mut decoder = Decoder::new(4096, 100);
        decoder.feed(sdtc.data()).unwrap();
        let mut blocked_info = None;
        for i in 0..10u64 {
            let name = format!("x-blocked-test-{}", i);
            let value = format!("value-{}", i);
            let encoded = encoder
                .encode_all(StreamId::new(i), [(name.as_str(), value.as_str())])
                .unwrap();
            let (hdr_data, enc_stream) = encoded.into();
            if !enc_stream.is_empty() && blocked_info.is_none() {
                let output = decoder.decode(StreamId::new(i), &hdr_data).unwrap();
                if output.is_blocked() {
                    blocked_info = Some((i, enc_stream, name, value));
                    continue;
                }
            }
            if !enc_stream.is_empty() {
                decoder.feed(&enc_stream).unwrap();
            }
            let _ = decoder.decode(StreamId::new(i), hdr_data);
        }
        if let Some((stream_id, enc_stream, name, value)) = blocked_info {
            decoder.feed(&enc_stream).unwrap();
            let result = decoder.unblocked(StreamId::new(stream_id));
            let output = result
                .expect("should have result for blocked stream")
                .unwrap();
            let decoded = output.take().expect("should be unblocked now");
            assert_eq!(decoded.headers().len(), 1);
            assert_eq!(decoded.headers()[0].name(), name);
            assert_eq!(decoded.headers()[0].value(), value);
        }
    }

    // Pins the exact `Display` text of every error kind.
    #[test]
    fn test_decoder_error_display() {
        let err = DecoderError::new(DecoderErrorKind::DuplicateStreamId);
        assert_eq!(
            err.to_string(),
            "stream ID already has a pending header block"
        );
        let err = DecoderError::new(DecoderErrorKind::DecodeFailed);
        assert_eq!(err.to_string(), "decoding operation failed");
        let err = DecoderError::new(DecoderErrorKind::FeedFailed);
        assert_eq!(err.to_string(), "failed to process encoder stream data");
        let err = DecoderError::new(DecoderErrorKind::InvalidHeader);
        assert_eq!(
            err.to_string(),
            "invalid header encountered during decoding"
        );
    }

    // `Debug` output names both the type and the kind.
    #[test]
    fn test_decoder_error_debug() {
        let err = DecoderError::new(DecoderErrorKind::DecodeFailed);
        let debug = format!("{:?}", err);
        assert!(debug.contains("DecoderError"));
        assert!(debug.contains("DecodeFailed"));
    }

    // `is_blocked` and `take` agree for both variants.
    #[test]
    fn test_decoder_output_variants() {
        let blocked = DecoderOutput::BlockedStream;
        assert!(blocked.is_blocked());
        assert!(blocked.take().is_none());
        let done = DecoderOutput::Done(BuffersDecoded {
            headers: vec![],
            stream: vec![].into_boxed_slice(),
        });
        assert!(!done.is_blocked());
        assert!(done.take().is_some());
    }

    // NOTE(review): despite its name, this test never inspects `decoded.stream()`; it only
    // checks the decoded header.
    #[test]
    fn test_decode_produces_stream_data() {
        let encoded = Encoder::new()
            .encode_all(StreamId::new(0), [(":method", "GET")])
            .unwrap();
        let (hdr_data, _) = encoded.into();
        let output = Decoder::new(0, 0)
            .decode(StreamId::new(0), hdr_data)
            .unwrap();
        let decoded = output.take().expect("should decode");
        assert_eq!(decoded.headers().len(), 1);
        assert_eq!(decoded.headers()[0].name(), ":method");
        assert_eq!(decoded.headers()[0].value(), "GET");
    }
}