1#[cfg(feature = "encryption")]
2use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher};
3use crate::util::{get_sized_buf, move_data_rightwards, VAR_INT_BUF_SIZE};
4use crate::wrapper::{CraftIo, CraftWrapper};
5use crate::DEAFULT_MAX_PACKET_SIZE;
6#[cfg(feature = "compression")]
7use flate2::{CompressError, Compression, FlushCompress, Status};
8use mcproto_rs::protocol::{Id, Packet, PacketDirection, RawPacket, State};
9use mcproto_rs::types::VarInt;
10use mcproto_rs::{Serialize, SerializeErr, SerializeResult, Serializer};
11#[cfg(feature = "backtrace")]
12use std::backtrace::Backtrace;
13use std::ops::{Deref, DerefMut};
14use thiserror::Error;
15#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
16use async_trait::async_trait;
17
18#[derive(Debug, Error)]
19pub enum WriteError {
20 #[error("packet serialization error")]
21 Serialize {
22 #[from]
23 err: PacketSerializeFail,
24 #[cfg(feature = "backtrace")]
25 backtrace: Backtrace,
26 },
27 #[error("failed to compress packet")]
28 #[cfg(feature = "compression")]
29 CompressFail {
30 #[from]
31 err: CompressError,
32 #[cfg(feature = "backtrace")]
33 backtrace: Backtrace,
34 },
35 #[error("compression gave buf error")]
36 #[cfg(feature = "compression")]
37 CompressBufError {
38 #[cfg(feature = "backtrace")]
39 backtrace: Backtrace,
40 },
41 #[error("io error while writing data")]
42 IoFail {
43 #[from]
44 err: std::io::Error,
45 #[cfg(feature = "backtrace")]
46 backtrace: Backtrace,
47 },
48 #[error("bad direction")]
49 BadDirection {
50 attempted: PacketDirection,
51 expected: PacketDirection,
52 #[cfg(feature = "backtrace")]
53 backtrace: Backtrace,
54 },
55 #[error("bad state")]
56 BadState {
57 attempted: State,
58 expected: State,
59 #[cfg(feature = "backtrace")]
60 backtrace: Backtrace,
61 },
62 #[error("packet size {size} exceeds maximum size {max_size}")]
63 PacketTooLarge {
64 size: usize,
65 max_size: usize,
66 #[cfg(feature = "backtrace")]
67 backtrace: Backtrace,
68 }
69}
70
/// A serialization failure, tagged with which part of the packet was being serialized.
#[derive(Debug, Error)]
pub enum PacketSerializeFail {
    /// Failed while serializing the packet id.
    #[error("failed to serialize packet header")]
    Header(#[source] SerializeErr),
    /// Failed while serializing the packet body (`Packet::mc_serialize_body`).
    #[error("failed to serialize packet contents")]
    Body(#[source] SerializeErr),
}
78
79impl Deref for PacketSerializeFail {
80 type Target = SerializeErr;
81
82 fn deref(&self) -> &Self::Target {
83 use PacketSerializeFail::*;
84 match self {
85 Header(err) => err,
86 Body(err) => err,
87 }
88 }
89}
90
91impl DerefMut for PacketSerializeFail {
92 fn deref_mut(&mut self) -> &mut Self::Target {
93 use PacketSerializeFail::*;
94 match self {
95 Header(err) => err,
96 Body(err) => err,
97 }
98 }
99}
100
101impl Into<SerializeErr> for PacketSerializeFail {
102 fn into(self) -> SerializeErr {
103 use PacketSerializeFail::*;
104 match self {
105 Header(err) => err,
106 Body(err) => err,
107 }
108 }
109}
110
/// Result type returned by every write operation in this module.
pub type WriteResult<P> = Result<P, WriteError>;
112
/// Asynchronous packet writing; available when the `futures-io` or `tokio-io` feature is enabled.
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
#[async_trait]
pub trait CraftAsyncWriter {
    /// Serializes `packet` (id + body), frames it, and writes it to the underlying writer.
    async fn write_packet_async<P>(&mut self, packet: P) -> WriteResult<()>
    where
        P: Packet + Send + Sync;

    /// Frames and writes a raw packet whose body bytes are already serialized.
    async fn write_raw_packet_async<'a, P>(&mut self, packet: P) -> WriteResult<()>
    where
        P: RawPacket<'a> + Send + Sync;
}
141
/// Blocking packet writing.
pub trait CraftSyncWriter {
    /// Serializes `packet` (id + body), frames it, and writes it to the underlying writer.
    fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
    where
        P: Packet;

    /// Frames and writes a raw packet whose body bytes are already serialized.
    fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
    where
        P: RawPacket<'a>;
}
169
/// Serializes, frames (optionally compressing/encrypting) and writes packets to an inner writer `W`.
///
/// Packets are staged in reusable, lazily allocated buffers so that repeated writes can reuse
/// the same allocations.
pub struct CraftWriter<W> {
    /// The sink that finished packet bytes are written to.
    inner: W,
    /// Staging buffer: header space followed by the serialized packet id and body.
    raw_buf: Option<Vec<u8>>,
    /// Staging buffer for compressed packets (header space followed by the zlib output).
    #[cfg(feature = "compression")]
    compress_buf: Option<Vec<u8>>,
    /// `None` means compression is off. With `Some(t)`, bodies of at least `t` bytes are
    /// compressed (when `t >= 0`); other packets use the compressed format with a 0 data length.
    #[cfg(feature = "compression")]
    compression_threshold: Option<i32>,
    /// Protocol state every written packet id must match.
    state: State,
    /// Direction every written packet id must match.
    direction: PacketDirection,
    /// Stream cipher applied to the finished frame, once `enable_encryption` has been called.
    #[cfg(feature = "encryption")]
    encryption: Option<CraftCipher>,
    /// Upper bound on serialized packet data; larger packets fail with `PacketTooLarge`.
    max_packet_size: usize,
}
193
194impl<W> CraftWrapper<W> for CraftWriter<W> {
195 fn into_inner(self) -> W {
196 self.inner
197 }
198}
199
200impl<W> CraftIo for CraftWriter<W> {
201 fn set_state(&mut self, next: State) {
202 self.state = next;
203 }
204
205 #[cfg(feature = "compression")]
206 fn set_compression_threshold(&mut self, threshold: Option<i32>) {
207 self.compression_threshold = threshold;
208 }
209
210 #[cfg(feature = "encryption")]
211 fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> {
212 setup_craft_cipher(&mut self.encryption, key, iv)
213 }
214
215 fn set_max_packet_size(&mut self, max_size: usize) {
216 debug_assert!(max_size > 5);
217 self.max_packet_size = max_size;
218 }
219
220 fn ensure_buf_capacity(&mut self, capacity: usize) {
221 get_sized_buf(&mut self.raw_buf, 0, if capacity > self.max_packet_size {
222 self.max_packet_size
223 } else {
224 capacity
225 });
226 }
227
228 #[cfg(feature = "compression")]
229 fn ensure_compression_buf_capacity(&mut self, capacity: usize) {
230 get_sized_buf(&mut self.compress_buf, 0, if capacity > self.max_packet_size {
231 self.max_packet_size
232 } else {
233 capacity
234 });
235 }
236}
237
238impl<W> CraftSyncWriter for CraftWriter<W>
239where
240 W: std::io::Write,
241{
242 fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
243 where
244 P: Packet,
245 {
246 let prepared = self.serialize_packet_to_buf(packet)?;
247 write_data_to_target_sync(self.prepare_packet_in_buf(prepared)?)?;
248 Ok(())
249 }
250
251 fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
252 where
253 P: RawPacket<'a>,
254 {
255 let prepared = self.serialize_raw_packet_to_buf(packet)?;
256 write_data_to_target_sync(self.prepare_packet_in_buf(prepared)?)?;
257 Ok(())
258 }
259}
260
/// Writes the whole of `data` into `target`, taking both as the pair produced by
/// `prepare_packet_in_buf`.
fn write_data_to_target_sync<'a, W>(
    (data, target): (&'a [u8], &'a mut W),
) -> Result<(), std::io::Error>
where
    W: std::io::Write,
{
    target.write_all(data)
}
268
/// Minimal async "write everything" abstraction, so `CraftWriter` can target either futures-io
/// or tokio writers through one bound. When both features are enabled, the tokio-based
/// implementation is the one compiled in.
#[cfg(any(feature = "tokio-io", feature = "futures-io"))]
#[async_trait]
pub trait AsyncWriteAll: Unpin + Send + Sync {
    /// Writes all of `data`, or returns the first I/O error encountered.
    async fn write_all(&mut self, data: &[u8]) -> Result<(), std::io::Error>;
}
274
/// futures-io backed `AsyncWriteAll`; only compiled when tokio-io is not enabled.
#[cfg(all(feature = "futures-io", not(feature = "tokio-io")))]
#[async_trait]
impl<W> AsyncWriteAll for W
where
    W: futures::AsyncWrite + Unpin + Send + Sync,
{
    async fn write_all(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
        // The extension future already resolves to `io::Result<()>`; pass it through unchanged.
        futures::AsyncWriteExt::write_all(self, data).await
    }
}
286
/// tokio backed `AsyncWriteAll`.
#[cfg(feature = "tokio-io")]
#[async_trait]
impl<W> AsyncWriteAll for W
where
    W: tokio::io::AsyncWrite + Unpin + Send + Sync,
{
    async fn write_all(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
        // The extension future already resolves to `io::Result<()>`; pass it through unchanged.
        tokio::io::AsyncWriteExt::write_all(self, data).await
    }
}
298
/// Async writes onto any `AsyncWriteAll` sink. Serialization and framing are synchronous; only
/// the final write of the finished bytes is awaited.
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
#[async_trait]
impl<W> CraftAsyncWriter for CraftWriter<W>
where
    W: AsyncWriteAll,
{
    async fn write_packet_async<P>(&mut self, packet: P) -> WriteResult<()>
    where
        P: Packet + Send + Sync,
    {
        let handle = self.serialize_packet_to_buf(packet)?;
        let frame = self.prepare_packet_in_buf(handle)?;
        write_data_to_target_async(frame).await.map_err(WriteError::from)
    }

    async fn write_raw_packet_async<'a, P>(&mut self, packet: P) -> WriteResult<()>
    where
        P: RawPacket<'a> + Send + Sync,
    {
        let handle = self.serialize_raw_packet_to_buf(packet)?;
        let frame = self.prepare_packet_in_buf(handle)?;
        write_data_to_target_async(frame).await.map_err(WriteError::from)
    }
}
323
/// Async counterpart of `write_data_to_target_sync`: writes all of `data` into `target`.
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
async fn write_data_to_target_async<'a, W>(
    (data, target): (&'a [u8], &'a mut W),
) -> Result<(), std::io::Error>
where
    W: AsyncWriteAll,
{
    target.write_all(data).await
}
334
/// Bytes reserved in front of the packet id/body in `raw_buf` for the frame header. This is room
/// for the packet-length VarInt plus, with compression compiled in, one extra byte. That byte
/// holds the `0` data-length marker of packets sent uncompressed below the threshold.
#[cfg(feature = "compression")]
const HEADER_OFFSET: usize = VAR_INT_BUF_SIZE + 1;

/// Without compression the header is only the packet-length VarInt.
#[cfg(not(feature = "compression"))]
const HEADER_OFFSET: usize = VAR_INT_BUF_SIZE;

/// Bytes reserved at the front of `compress_buf`: the packet-length VarInt followed by the
/// (uncompressed) data-length VarInt. The zlib output is written right after them.
#[cfg(feature = "compression")]
const COMPRESSED_HEADER_OFFSET: usize = VAR_INT_BUF_SIZE * 2;
381
/// Sizes of a packet that has been serialized into `raw_buf` starting at `HEADER_OFFSET`.
struct PreparedPacketHandle {
    /// Bytes written for the packet id.
    id_size: usize,
    /// Bytes written for the packet body (immediately after the id).
    data_size: usize,
}
386
/// Constructors plus the serialize -> frame pipeline shared by the sync and async writers.
impl<W> CraftWriter<W> {
    /// Wraps `inner` for packets travelling in `direction`, starting in `State::Handshaking`.
    pub fn wrap(inner: W, direction: PacketDirection) -> Self {
        Self::wrap_with_state(inner, direction, State::Handshaking)
    }

    /// Wraps `inner` starting in the given protocol `state`.
    ///
    /// Buffers are not allocated until first use. Compression and encryption start disabled.
    /// The maximum packet size starts at `DEAFULT_MAX_PACKET_SIZE`.
    pub fn wrap_with_state(inner: W, direction: PacketDirection, state: State) -> Self {
        Self {
            inner,
            raw_buf: None,
            #[cfg(feature = "compression")]
            compression_threshold: None,
            #[cfg(feature = "compression")]
            compress_buf: None,
            state,
            direction,
            #[cfg(feature = "encryption")]
            encryption: None,
            max_packet_size: DEAFULT_MAX_PACKET_SIZE,
        }
    }

    /// Turns a packet already staged in `raw_buf` (id + body at `HEADER_OFFSET`) into the final
    /// frame bytes, and returns them together with the inner writer.
    ///
    /// Framing depends on compression:
    /// * no threshold set: `[packet len][id+body]`;
    /// * threshold set and body at least that large: zlib-compressed into `compress_buf`;
    /// * threshold set but body smaller (or threshold negative): `[packet len][0][id+body]`.
    ///
    /// If encryption is enabled, the returned frame has already been encrypted in place.
    fn prepare_packet_in_buf(
        &mut self,
        prepared: PreparedPacketHandle,
    ) -> WriteResult<(&[u8], &mut W)> {
        let body_size = prepared.id_size + prepared.data_size;
        // Covers the reserved header space plus the already-serialized id and body.
        let buf = get_sized_buf(&mut self.raw_buf, 0, HEADER_OFFSET + body_size);

        #[cfg(feature = "compression")]
        let packet_data = if let Some(threshold) = self.compression_threshold {
            if threshold >= 0 && (threshold as usize) <= body_size {
                // Only the id+body is compressed; the header is rebuilt in `compress_buf`.
                let body_data = &buf[HEADER_OFFSET..];
                prepare_packet_compressed(body_data, &mut self.compress_buf)?
            } else {
                prepare_packet_compressed_below_threshold(buf, body_size)?
            }
        } else {
            prepare_packet_normally(buf, body_size)?
        };

        #[cfg(not(feature = "compression"))]
        let packet_data = prepare_packet_normally(buf, body_size)?;

        #[cfg(feature = "encryption")]
        handle_encryption(self.encryption.as_mut(), packet_data);

        Ok((packet_data, &mut self.inner))
    }

    /// Checks and serializes `packet`'s id, then its body, into `raw_buf` after the header space.
    ///
    /// # Errors
    /// `BadDirection`/`BadState` for a mismatched id, `Serialize` on serializer failure, and
    /// `PacketTooLarge` when the id or body exceeds `max_packet_size`.
    fn serialize_packet_to_buf<P>(&mut self, packet: P) -> WriteResult<PreparedPacketHandle>
    where
        P: Packet,
    {
        let id_size = self.serialize_id_to_buf(packet.id())?;
        let data_size = self.serialize_to_buf(HEADER_OFFSET + id_size, move |serializer| {
            packet
                .mc_serialize_body(serializer)
                .map_err(move |err| PacketSerializeFail::Body(err).into())
        })?;

        Ok(PreparedPacketHandle { id_size, data_size })
    }

    /// Like `serialize_packet_to_buf`, but the body bytes are already serialized and are copied
    /// in directly.
    ///
    /// # Errors
    /// Same id checks as above. `PacketTooLarge` is returned when the body alone exceeds
    /// `max_packet_size`.
    fn serialize_raw_packet_to_buf<'a, P>(&mut self, packet: P) -> WriteResult<PreparedPacketHandle>
    where
        P: RawPacket<'a>,
    {
        let id_size = self.serialize_id_to_buf(packet.id())?;
        let packet_data = packet.data();
        let data_size = packet_data.len();
        if data_size > self.max_packet_size {
            return Err(WriteError::PacketTooLarge {
                size: data_size,
                max_size: self.max_packet_size,
                #[cfg(feature = "backtrace")]
                backtrace: Backtrace::capture()
            })
        }
        // Region starting at HEADER_OFFSET: the id is already at its front; the body goes after it.
        let buf = get_sized_buf(&mut self.raw_buf, HEADER_OFFSET, id_size + data_size);

        (&mut buf[id_size..]).copy_from_slice(packet_data);

        Ok(PreparedPacketHandle { id_size, data_size })
    }

    /// Validates `id` against this writer's direction and state, then serializes it at
    /// `HEADER_OFFSET`, returning the number of bytes written.
    fn serialize_id_to_buf(&mut self, id: Id) -> WriteResult<usize> {
        if id.direction != self.direction {
            return Err(WriteError::BadDirection {
                expected: self.direction,
                attempted: id.direction,
                #[cfg(feature = "backtrace")]
                backtrace: Backtrace::capture(),
            });
        }

        if id.state != self.state {
            return Err(WriteError::BadState {
                expected: self.state,
                attempted: id.state,
                #[cfg(feature = "backtrace")]
                backtrace: Backtrace::capture(),
            });
        }

        self.serialize_to_buf(HEADER_OFFSET, move |serializer| {
            id.mc_serialize(serializer)
                .map_err(move |err| PacketSerializeFail::Header(err).into())
        })
    }

    /// Runs `f` against a growable serializer writing into `raw_buf` at `offset`, and returns
    /// the number of bytes `f` serialized.
    ///
    /// The serializer stops copying once `max_packet_size` would be exceeded but keeps counting,
    /// so the `PacketTooLarge` error can report the full size.
    fn serialize_to_buf<'a, F>(&'a mut self, offset: usize, f: F) -> WriteResult<usize>
    where
        F: FnOnce(&mut GrowVecSerializer<'a>) -> Result<(), WriteError>,
    {
        let mut serializer = GrowVecSerializer::create(&mut self.raw_buf, offset, self.max_packet_size);
        f(&mut serializer)?;
        let packet_size = serializer.written_data_len();
        if serializer.exceeded_max_size {
            Err(WriteError::PacketTooLarge {
                size: packet_size,
                max_size: self.max_packet_size,
                #[cfg(feature = "backtrace")]
                backtrace: Backtrace::capture(),
            })
        } else {
            Ok(packet_size)
        }
    }
}
517
518fn prepare_packet_normally(buf: &mut [u8], body_size: usize) -> WriteResult<&mut [u8]> {
519 #[cfg(feature = "compression")]
520 const BUF_SKIP_BYTES: usize = 1;
521
522 #[cfg(not(feature = "compression"))]
523 const BUF_SKIP_BYTES: usize = 0;
524
525 let packet_len_target = &mut buf[BUF_SKIP_BYTES..HEADER_OFFSET];
526 let mut packet_len_serializer = SliceSerializer::create(packet_len_target);
527 VarInt(body_size as i32)
528 .mc_serialize(&mut packet_len_serializer)
529 .map_err(move |err| PacketSerializeFail::Header(err))?;
530 let packet_len_bytes = packet_len_serializer.finish().len();
531
532 let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes;
533 move_data_rightwards(
534 &mut buf[BUF_SKIP_BYTES..HEADER_OFFSET],
535 packet_len_bytes,
536 n_shift_packet_len,
537 );
538
539 let start_offset = n_shift_packet_len + BUF_SKIP_BYTES;
540 let end_at = start_offset + packet_len_bytes + body_size;
541 Ok(&mut buf[start_offset..end_at])
542}
543
/// Frames a compressed packet as `[packet length][data length][zlib(id + body)]` in `compress_buf`.
///
/// `buf` is the uncompressed id+body; its length becomes the data-length VarInt. The packet
/// length covers the data-length VarInt plus the compressed bytes. Returns the slice of
/// `compress_buf` holding exactly the finished frame.
///
/// NOTE(review): assumes `move_data_rightwards(slice, n, shift)` moves the first `n` bytes of
/// `slice` right by `shift` positions. Confirm against `crate::util`.
#[cfg(feature = "compression")]
fn prepare_packet_compressed<'a>(
    buf: &'a [u8],
    compress_buf: &'a mut Option<Vec<u8>>,
) -> WriteResult<&'a mut [u8]> {
    // The compressed bytes land right after the two reserved VarInt slots.
    let compressed_size = compress(buf, compress_buf, COMPRESSED_HEADER_OFFSET)?.len();
    let compress_buf = get_sized_buf(compress_buf, 0, compressed_size + COMPRESSED_HEADER_OFFSET);

    // Second slot: uncompressed data length.
    let data_len_target = &mut compress_buf[VAR_INT_BUF_SIZE..COMPRESSED_HEADER_OFFSET];
    let mut data_len_serializer = SliceSerializer::create(data_len_target);
    VarInt(buf.len() as i32)
        .mc_serialize(&mut data_len_serializer)
        .map_err(move |err| PacketSerializeFail::Header(err))?;
    let data_len_bytes = data_len_serializer.finish().len();

    // First slot: packet length = data-length VarInt + compressed bytes.
    let packet_len_target = &mut compress_buf[..VAR_INT_BUF_SIZE];
    let mut packet_len_serializer = SliceSerializer::create(packet_len_target);
    VarInt((compressed_size + data_len_bytes) as i32)
        .mc_serialize(&mut packet_len_serializer)
        .map_err(move |err| PacketSerializeFail::Header(err))?;
    let packet_len_bytes = packet_len_serializer.finish().len();

    // Step 1: push the packet-length VarInt to the end of the first slot, so it directly
    // precedes the data-length VarInt.
    let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes;
    move_data_rightwards(
        &mut compress_buf[..COMPRESSED_HEADER_OFFSET],
        packet_len_bytes,
        n_shift_packet_len,
    );
    // Step 2: the two VarInts are now contiguous; push both so they end exactly where the
    // compressed data begins.
    let n_shift_data_len = VAR_INT_BUF_SIZE - data_len_bytes;
    move_data_rightwards(
        &mut compress_buf[n_shift_packet_len..COMPRESSED_HEADER_OFFSET],
        packet_len_bytes + data_len_bytes,
        n_shift_data_len,
    );
    let start_offset = n_shift_data_len + n_shift_packet_len;
    let end_at = start_offset + data_len_bytes + packet_len_bytes + compressed_size;

    Ok(&mut compress_buf[start_offset..end_at])
}
583
/// Frames a packet that is below the compression threshold as
/// `[packet length][0][id + body]`, where the `0` data length marks the body as uncompressed.
///
/// `buf` holds `HEADER_OFFSET` bytes of header space (length slot + one marker byte) followed
/// by `body_size` bytes of id+body. Returns the slice holding exactly the frame.
#[cfg(feature = "compression")]
fn prepare_packet_compressed_below_threshold(
    buf: &mut [u8],
    body_size: usize,
) -> WriteResult<&mut [u8]> {
    // Index of the single marker byte that directly precedes the body.
    let marker_at = HEADER_OFFSET - 1;

    // The packet length also counts the marker byte.
    let len_width = {
        let mut len_writer = SliceSerializer::create(&mut buf[..marker_at]);
        VarInt((body_size + 1) as i32)
            .mc_serialize(&mut len_writer)
            .map_err(PacketSerializeFail::Header)?;
        len_writer.finish().len()
    };

    let shift = VAR_INT_BUF_SIZE - len_width;
    move_data_rightwards(&mut buf[..marker_at], len_width, shift);

    buf[marker_at] = 0;
    Ok(&mut buf[shift..HEADER_OFFSET + body_size])
}
607
/// Encrypts the finished frame in place if a cipher is installed; otherwise does nothing.
#[cfg(feature = "encryption")]
fn handle_encryption(encryption: Option<&mut CraftCipher>, buf: &mut [u8]) {
    if let Some(cipher) = encryption {
        cipher.encrypt(buf);
    }
}
614
615#[derive(Debug)]
616struct GrowVecSerializer<'a> {
617 target: &'a mut Option<Vec<u8>>,
618 at: usize,
619 offset: usize,
620 max_size: usize,
621 exceeded_max_size: bool,
622}
623
624impl<'a> Serializer for GrowVecSerializer<'a> {
625 fn serialize_bytes(&mut self, data: &[u8]) -> SerializeResult {
626 if !self.exceeded_max_size {
627 let cur_len = self.written_data_len();
628 let new_len = cur_len + data.len();
629 if new_len > self.max_size {
630 self.exceeded_max_size = true;
631 } else {
632 get_sized_buf(self.target, self.at + self.offset, data.len()).copy_from_slice(data);
633 }
634 }
635
636 self.at += data.len();
637
638 Ok(())
639 }
640}
641
642impl<'a> GrowVecSerializer<'a> {
643 fn create(target: &'a mut Option<Vec<u8>>, offset: usize, max_size: usize) -> Self {
644 Self {
645 target,
646 at: 0,
647 offset,
648 max_size,
649 exceeded_max_size: false,
650 }
651 }
652
653 fn written_data_len(&self) -> usize {
654 self.at
655 }
656}
657
658struct SliceSerializer<'a> {
659 target: &'a mut [u8],
660 at: usize,
661}
662
663impl<'a> Serializer for SliceSerializer<'a> {
664 fn serialize_bytes(&mut self, data: &[u8]) -> SerializeResult {
665 let end_at = self.at + data.len();
666 if end_at >= self.target.len() {
667 panic!(
668 "cannot fit data in slice ({} exceeds length {} at {})",
669 data.len(),
670 self.target.len(),
671 self.at
672 );
673 }
674
675 (&mut self.target[self.at..end_at]).copy_from_slice(data);
676 self.at = end_at;
677 Ok(())
678 }
679}
680
681impl<'a> SliceSerializer<'a> {
682 fn create(target: &'a mut [u8]) -> Self {
683 Self { target, at: 0 }
684 }
685
686 fn finish(self) -> &'a [u8] {
687 &self.target[..self.at]
688 }
689}
690
/// Zlib-compresses `src` (fast level, 15-bit window) into `output`, starting at byte `offset`.
///
/// Bytes before `offset` are left for the caller's header. Returns the slice holding exactly the
/// compressed bytes.
///
/// Deflate can make incompressible or very small input larger than it started, because of the
/// zlib wrapper and block headers. The output region is therefore sized from zlib's
/// `compressBound` estimate and grown if the compressor ever fills it. It used to be sized to
/// `src.len()`, which made such packets fail with `CompressBufError`.
///
/// # Errors
/// * `WriteError::CompressFail` if the compressor reports an error.
/// * `WriteError::CompressBufError` if the compressor cannot make progress despite free output space.
#[cfg(feature = "compression")]
fn compress<'a, 'b>(
    src: &'b [u8],
    output: &'a mut Option<Vec<u8>>,
    offset: usize,
) -> Result<&'a mut [u8], WriteError> {
    let mut compressor = flate2::Compress::new_with_window_bits(Compression::fast(), true, 15);
    // zlib's compressBound(): worst-case zlib-wrapped deflate size for `src.len()` input bytes.
    let mut capacity =
        src.len() + (src.len() >> 12) + (src.len() >> 14) + (src.len() >> 25) + 13;

    loop {
        let consumed = compressor.total_in() as usize;
        let produced = compressor.total_out() as usize;
        // Never hand the compressor an empty output slice; that only yields `BufError`.
        // Growing keeps already-written bytes, as `prepare_packet_compressed` also relies on.
        if produced >= capacity {
            capacity *= 2;
        }
        let target = get_sized_buf(output, offset, capacity);

        let input = &src[consumed..];
        let flush = if input.is_empty() {
            FlushCompress::Finish
        } else {
            FlushCompress::None
        };

        match compressor.compress(input, &mut target[produced..], flush)? {
            Status::Ok => {}
            Status::BufError => {
                return Err(WriteError::CompressBufError {
                    #[cfg(feature = "backtrace")]
                    backtrace: Backtrace::capture(),
                })
            }
            Status::StreamEnd => break,
        }
    }

    let produced = compressor.total_out() as usize;
    Ok(get_sized_buf(output, offset, produced))
}