1use crate::header::TryIntoHeader;
27use crate::StreamId;
28use std::collections::HashMap;
29use std::fmt::Debug;
30use std::fmt::Display;
31use std::marker::PhantomPinned;
32use std::pin::Pin;
33
/// Category of a failure reported through [`EncoderError`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum EncoderErrorKind {
    /// A header could not be converted with `TryIntoHeader`.
    InvalidHeader,
    /// `lsqpack_enc_init` reported a failure.
    InitFailed,
    /// Starting a header block or encoding a single header failed.
    EncodeFailed,
    /// `lsqpack_enc_end_header` did not produce a header-block prefix.
    EndHeaderFailed,
    /// `lsqpack_enc_decoder_in` rejected the supplied decoder-stream data.
    FeedFailed,
}
49
50pub struct EncoderError {
52 kind: EncoderErrorKind,
53}
54
55impl EncoderError {
56 pub fn kind(&self) -> EncoderErrorKind {
58 self.kind
59 }
60
61 fn new(kind: EncoderErrorKind) -> Self {
62 Self { kind }
63 }
64}
65
66pub struct Encoder {
68 inner: Pin<Box<InnerEncoder>>,
69 seqnos: HashMap<StreamId, u32>,
70}
71
72impl Encoder {
73 #[inline]
80 pub fn new() -> Self {
81 Self {
82 inner: InnerEncoder::new(),
83 seqnos: HashMap::new(),
84 }
85 }
86
87 #[inline]
97 pub fn configure(
98 &mut self,
99 max_table_size: u32,
100 dyn_table_size: u32,
101 max_blocked_streams: u32,
102 ) -> Result<SDTCInstruction, EncoderError> {
103 self.inner
104 .as_mut()
105 .init(max_table_size, dyn_table_size, max_blocked_streams)
106 .map(SDTCInstruction)
107 }
108
109 pub fn encode_all<I, H>(
129 &mut self,
130 stream_id: StreamId,
131 headers: I,
132 ) -> Result<BuffersEncoded, EncoderError>
133 where
134 I: IntoIterator<Item = H>,
135 H: TryIntoHeader,
136 {
137 let mut encoding = self.encoding(stream_id);
138
139 for header in headers {
140 encoding.append(header)?;
141 }
142
143 encoding.encode()
144 }
145
146 #[inline]
167 pub fn encoding(&mut self, stream_id: StreamId) -> EncodingBlock<'_> {
168 let seqno = {
169 let seqno_ref = self.seqnos.entry(stream_id).or_default();
170 std::mem::replace(seqno_ref, seqno_ref.wrapping_add(1))
171 };
172
173 EncodingBlock::new(self, stream_id, seqno)
174 }
175
176 pub fn feed<D>(&mut self, data: D) -> Result<(), EncoderError>
178 where
179 D: AsRef<[u8]>,
180 {
181 self.inner.as_mut().feed_decoder_data(data.as_ref())
182 }
183
184 #[inline]
190 pub fn ratio(&self) -> f32 {
191 self.inner.as_ref().ratio()
192 }
193
194 #[inline]
195 fn inner_mut(&mut self) -> Pin<&mut InnerEncoder> {
196 self.inner.as_mut()
197 }
198}
199
200impl Default for Encoder {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
/// Bytes produced by the encoder while it is being configured.
///
/// Returned by `Encoder::configure`.
/// NOTE(review): the name suggests a QPACK "Set Dynamic Table Capacity"
/// instruction to be written to the encoder stream — confirm.
#[derive(Debug)]
pub struct SDTCInstruction(Box<[u8]>);

impl SDTCInstruction {
    /// Borrows the instruction bytes.
    #[inline]
    pub fn data(&self) -> &[u8] {
        &self.0[..]
    }

    /// Consumes the instruction and returns the owned bytes.
    #[inline]
    pub fn take(self) -> Box<[u8]> {
        let Self(bytes) = self;
        bytes
    }
}
226
227impl AsRef<[u8]> for SDTCInstruction {
228 #[inline]
229 fn as_ref(&self) -> &[u8] {
230 self.data()
231 }
232}
233
234impl From<SDTCInstruction> for Box<[u8]> {
235 fn from(sdtc_instruction: SDTCInstruction) -> Self {
236 sdtc_instruction.0
237 }
238}
239
240pub struct EncodingBlock<'a>(&'a mut Encoder);
244
245impl<'a> EncodingBlock<'a> {
246 fn new(encoder: &'a mut Encoder, stream_id: StreamId, seqno: u32) -> Self {
247 encoder
248 .inner_mut()
249 .start_header_block(stream_id, seqno)
250 .map(|()| Self(encoder))
251 .unwrap() }
253
254 pub fn append<H>(&mut self, header: H) -> Result<&mut Self, EncoderError>
256 where
257 H: TryIntoHeader,
258 {
259 self.0.inner_mut().encode(header).map(|()| self)
260 }
261
262 pub fn encode(self) -> Result<BuffersEncoded, EncoderError> {
264 self.0
265 .inner_mut()
266 .end_header_block()
267 .map(|(header, stream)| BuffersEncoded {
268 header: header.into_boxed_slice(),
269 stream: stream.into_boxed_slice(),
270 })
271 }
272}
273
/// Output of a finished header block.
#[derive(Debug)]
pub struct BuffersEncoded {
    /// Header-block bytes: the prefix followed by the encoded headers.
    header: Box<[u8]>,
    /// Encoder-stream bytes produced while the block was encoded.
    stream: Box<[u8]>,
}

impl BuffersEncoded {
    /// Borrows the header-block bytes.
    pub fn header(&self) -> &[u8] {
        &self.header[..]
    }

    /// Borrows the encoder-stream bytes.
    pub fn stream(&self) -> &[u8] {
        &self.stream[..]
    }

    /// Consumes the value and returns `(header, stream)`.
    pub fn take(self) -> (Box<[u8]>, Box<[u8]>) {
        <(Box<[u8]>, Box<[u8]>)>::from(self)
    }
}

/// Splits the buffers into a `(header, stream)` pair.
impl From<BuffersEncoded> for (Box<[u8]>, Box<[u8]>) {
    fn from(buffers_encoded: BuffersEncoded) -> Self {
        let BuffersEncoded { header, stream } = buffers_encoded;
        (header, stream)
    }
}
304
305struct InnerEncoder {
306 encoder: ls_qpack_rs_sys::lsqpack_enc,
307 enc_buffer: Vec<u8>,
308 hdr_buffer: Vec<u8>,
309 _marker: PhantomPinned,
310}
311
312impl InnerEncoder {
313 fn new() -> Pin<Box<Self>> {
314 let mut this = Box::new(Self {
315 encoder: ls_qpack_rs_sys::lsqpack_enc::default(),
316 enc_buffer: Vec::new(),
317 hdr_buffer: Vec::new(),
318 _marker: PhantomPinned,
319 });
320
321 unsafe {
325 ls_qpack_rs_sys::lsqpack_enc_preinit(&mut this.encoder, std::ptr::null_mut());
326 }
327
328 Box::into_pin(this)
329 }
330
331 fn init(
332 self: Pin<&mut Self>,
333 max_table_size: u32,
334 dyn_table_size: u32,
335 max_blocked_streams: u32,
336 ) -> Result<Box<[u8]>, EncoderError> {
337 let this = unsafe { self.get_unchecked_mut() };
341
342 let mut buffer = vec![0; ls_qpack_rs_sys::LSQPACK_LONGEST_SDTC as usize];
343 let mut sdtc_buffer_size = buffer.len();
344
345 let result = unsafe {
349 ls_qpack_rs_sys::lsqpack_enc_init(
350 &mut this.encoder,
351 std::ptr::null_mut(),
352 max_table_size,
353 dyn_table_size,
354 max_blocked_streams,
355 ls_qpack_rs_sys::lsqpack_enc_opts_LSQPACK_ENC_OPT_STAGE_2,
356 buffer.as_mut_ptr(),
357 &mut sdtc_buffer_size,
358 )
359 };
360
361 if result == 0 {
362 buffer.truncate(sdtc_buffer_size);
363 Ok(buffer.into_boxed_slice())
364 } else {
365 Err(EncoderError::new(EncoderErrorKind::InitFailed))
366 }
367 }
368
369 fn start_header_block(
371 self: Pin<&mut Self>,
372 stream_id: StreamId,
373 seqno: u32,
374 ) -> Result<(), EncoderError> {
375 let this = unsafe { self.get_unchecked_mut() };
377
378 let result = unsafe {
381 ls_qpack_rs_sys::lsqpack_enc_start_header(&mut this.encoder, stream_id.value(), seqno)
382 };
383
384 if result == 0 {
385 this.enc_buffer.clear();
386 this.hdr_buffer.clear();
387
388 Ok(())
389 } else {
390 Err(EncoderError::new(EncoderErrorKind::EncodeFailed))
391 }
392 }
393
394 fn encode<H>(self: Pin<&mut Self>, header: H) -> Result<(), EncoderError>
395 where
396 H: TryIntoHeader,
397 {
398 const BUFFER_SIZE: usize = 4096;
399
400 let mut header = header
401 .try_into_header()
402 .map_err(|_| EncoderError::new(EncoderErrorKind::InvalidHeader))?;
403
404 let this = unsafe { self.get_unchecked_mut() };
407
408 let enc_buffer_offset = this.enc_buffer.len();
409 this.enc_buffer.resize(enc_buffer_offset + BUFFER_SIZE, 0);
410
411 let hdr_buffer_offset = this.hdr_buffer.len();
412 this.hdr_buffer.resize(hdr_buffer_offset + BUFFER_SIZE, 0);
413
414 let mut enc_buffer_size = this.enc_buffer.len() - enc_buffer_offset;
415 let mut hdr_buffer_size = this.hdr_buffer.len() - hdr_buffer_offset;
416
417 let result = unsafe {
422 ls_qpack_rs_sys::lsqpack_enc_encode(
423 &mut this.encoder,
424 this.enc_buffer.as_mut_ptr().add(enc_buffer_offset),
425 &mut enc_buffer_size,
426 this.hdr_buffer.as_mut_ptr().add(hdr_buffer_offset),
427 &mut hdr_buffer_size,
428 header.build_lsxpack_header().as_ref(),
429 0,
430 )
431 };
432
433 if result == ls_qpack_rs_sys::lsqpack_enc_status_LQES_OK {
434 this.enc_buffer
435 .truncate(enc_buffer_offset + enc_buffer_size);
436 this.hdr_buffer
437 .truncate(hdr_buffer_offset + hdr_buffer_size);
438
439 Ok(())
440 } else {
441 this.enc_buffer.truncate(enc_buffer_offset);
442 this.hdr_buffer.truncate(hdr_buffer_offset);
443
444 Err(EncoderError::new(EncoderErrorKind::EncodeFailed))
445 }
446 }
447
448 fn end_header_block(self: Pin<&mut Self>) -> Result<(Vec<u8>, Vec<u8>), EncoderError> {
455 let this = unsafe { self.get_unchecked_mut() };
457
458 let max_prefix_len =
461 unsafe { ls_qpack_rs_sys::lsqpack_enc_header_block_prefix_size(&this.encoder) };
462
463 let mut hdr_block = vec![0; max_prefix_len + this.hdr_buffer.len()];
464
465 let hdr_prefix_len = unsafe {
468 ls_qpack_rs_sys::lsqpack_enc_end_header(
469 &mut this.encoder,
470 hdr_block.as_mut_ptr(),
471 max_prefix_len,
472 std::ptr::null_mut(),
473 )
474 };
475
476 if hdr_prefix_len > 0 {
477 hdr_block.truncate(hdr_prefix_len as usize);
478 hdr_block.extend_from_slice(&this.hdr_buffer);
479
480 Ok((hdr_block, std::mem::take(&mut this.enc_buffer)))
481 } else {
482 Err(EncoderError::new(EncoderErrorKind::EndHeaderFailed))
483 }
484 }
485
486 fn feed_decoder_data(self: Pin<&mut Self>, data: &[u8]) -> Result<(), EncoderError> {
487 let this = unsafe { self.get_unchecked_mut() };
489
490 let result = unsafe {
493 ls_qpack_rs_sys::lsqpack_enc_decoder_in(&mut this.encoder, data.as_ptr(), data.len())
494 };
495
496 if result == 0 {
497 Ok(())
498 } else {
499 Err(EncoderError::new(EncoderErrorKind::FeedFailed))
500 }
501 }
502
503 fn ratio(self: Pin<&Self>) -> f32 {
504 unsafe { ls_qpack_rs_sys::lsqpack_enc_ratio(&self.encoder) }
506 }
507}
508
509impl Drop for InnerEncoder {
510 fn drop(&mut self) {
511 unsafe { ls_qpack_rs_sys::lsqpack_enc_cleanup(&mut self.encoder) }
515 }
516}
517
518unsafe impl Send for InnerEncoder {}
523
524unsafe impl Sync for InnerEncoder {}
529
530const _: () = {
531 fn _assert_send<T: Send>() {}
532 fn _assert_sync<T: Sync>() {}
533 fn _assert_all() {
534 _assert_send::<Encoder>();
535 _assert_sync::<Encoder>();
536 }
537};
538
539impl Debug for EncoderError {
540 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541 f.debug_struct("EncoderError")
542 .field("kind", &self.kind)
543 .finish()
544 }
545}
546
547impl Display for EncoderError {
548 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549 match self.kind {
550 EncoderErrorKind::InvalidHeader => write!(f, "invalid header"),
551 EncoderErrorKind::InitFailed => write!(f, "encoder initialization failed"),
552 EncoderErrorKind::EncodeFailed => write!(f, "encoding operation failed"),
553 EncoderErrorKind::EndHeaderFailed => write!(f, "failed to finalize header block"),
554 EncoderErrorKind::FeedFailed => write!(f, "failed to process decoder stream data"),
555 }
556 }
557}
558
559impl std::error::Error for EncoderError {}
560
#[cfg(test)]
mod tests {
    use super::Encoder;
    use super::StreamId;

    /// Encoding the same header list 1024 times on one stream must always
    /// yield the same header block and no encoder-stream bytes.
    #[test]
    fn test_encoder_determinism_static() {
        const ROUNDS: usize = 1024;

        let mut encoder = Encoder::new();

        let reference = encoder
            .encode_all(StreamId::new(0), utilities::HEADERS_LIST_1)
            .unwrap();
        assert!(reference.stream().is_empty());

        // The reference above is round one.
        for _ in 1..ROUNDS {
            let encoded = encoder
                .encode_all(StreamId::new(0), utilities::HEADERS_LIST_1)
                .unwrap();

            assert!(encoded.header() == reference.header());
            assert!(encoded.stream().is_empty());
        }
    }

    mod utilities {
        /// Header list used by the determinism test.
        pub(super) const HEADERS_LIST_1: [(&str, &str); 2] =
            [(":status", "404"), (":method", "connect")];
    }
}