1use crate::{
2 builtins::{PyBaseExceptionRef, PyBytesRef, PyTuple, PyTupleRef, PyTypeRef},
3 common::{static_cell, str::wchar_t},
4 convert::ToPyObject,
5 function::{ArgBytesLike, ArgIntoBool, ArgIntoFloat},
6 PyObjectRef, PyResult, TryFromObject, VirtualMachine,
7};
8use half::f16;
9use itertools::Itertools;
10use malachite_bigint::BigInt;
11use num_traits::{PrimInt, ToPrimitive};
12use std::{fmt, iter::Peekable, mem, os::raw};
13
/// Signature of a per-item packer: converts one Python object and writes it
/// into that item's byte slot.
type PackFunc = fn(&VirtualMachine, PyObjectRef, &mut [u8]) -> PyResult<()>;
/// Signature of a per-item unpacker: reads one item's byte slot and builds the
/// corresponding Python object.
type UnpackFunc = fn(&VirtualMachine, &[u8]) -> PyObjectRef;
16
/// Message reported whenever a repeat count or the total struct size overflows.
static OVERFLOW_MSG: &str = "total struct size too long";

/// Byte-order / size / alignment mode selected by the optional first
/// character of a struct format string.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum Endianness {
    /// `@` or no prefix: native byte order, native sizes and alignment.
    Native,
    /// `<`: little-endian, standard sizes, no alignment.
    Little,
    /// `>` or `!`: big-endian, standard sizes, no alignment.
    Big,
    /// `=`: target byte order, standard sizes, no alignment.
    Host,
}

impl Endianness {
    /// Consumes a leading byte-order character from `chars`, if there is one.
    ///
    /// If the next byte is not one of `@ = < > !`, nothing is consumed and
    /// [`Endianness::Native`] is returned.
    fn parse<I>(chars: &mut Peekable<I>) -> Endianness
    where
        I: Sized + Iterator<Item = u8>,
    {
        let detected = chars.peek().and_then(|&byte| match byte {
            b'@' => Some(Endianness::Native),
            b'=' => Some(Endianness::Host),
            b'<' => Some(Endianness::Little),
            b'>' | b'!' => Some(Endianness::Big),
            _ => None,
        });
        match detected {
            Some(endianness) => {
                // The prefix was only peeked; drop it so the code parser
                // starts at the first format code.
                chars.next();
                endianness
            }
            None => Endianness::Native,
        }
    }
}
45
46trait ByteOrder {
47 fn convert<I: PrimInt>(i: I) -> I;
48}
49enum BigEndian {}
50impl ByteOrder for BigEndian {
51 fn convert<I: PrimInt>(i: I) -> I {
52 i.to_be()
53 }
54}
55enum LittleEndian {}
56impl ByteOrder for LittleEndian {
57 fn convert<I: PrimInt>(i: I) -> I {
58 i.to_le()
59 }
60}
61
62#[cfg(target_endian = "big")]
63type NativeEndian = BigEndian;
64#[cfg(target_endian = "little")]
65type NativeEndian = LittleEndian;
66
/// A single struct format character.
///
/// Each discriminant is the character's ASCII byte, so the derived
/// `TryFromPrimitive` maps a format byte straight to its variant.
#[derive(Copy, Clone, num_enum::TryFromPrimitive)]
#[repr(u8)]
pub(crate) enum FormatType {
    /// `x`: pad byte; consumes no argument.
    Pad = b'x',
    /// `b`: signed byte.
    SByte = b'b',
    /// `B`: unsigned byte.
    UByte = b'B',
    /// `c`: a bytes object of length 1.
    Char = b'c',
    /// `u`: native `wchar_t` (only described for native mode by `info`).
    WideChar = b'u',
    /// `s`: fixed-width bytes field; the repeat count is the field width.
    Str = b's',
    /// `p`: Pascal string (length byte followed by data); the repeat count is
    /// the total field width.
    Pascal = b'p',
    /// `h`: short.
    Short = b'h',
    /// `H`: unsigned short.
    UShort = b'H',
    /// `i`: int.
    Int = b'i',
    /// `I`: unsigned int.
    UInt = b'I',
    /// `l`: long.
    Long = b'l',
    /// `L`: unsigned long.
    ULong = b'L',
    /// `n`: signed size (`isize`), native mode only.
    SSizeT = b'n',
    /// `N`: size (`usize`), native mode only.
    SizeT = b'N',
    /// `q`: long long.
    LongLong = b'q',
    /// `Q`: unsigned long long.
    ULongLong = b'Q',
    /// `?`: bool.
    Bool = b'?',
    /// `e`: IEEE 754 half precision float.
    Half = b'e',
    /// `f`: single precision float.
    Float = b'f',
    /// `d`: double precision float.
    Double = b'd',
    /// `P`: void pointer, native mode only.
    VoidP = b'P',
}
93
94impl fmt::Debug for FormatType {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 fmt::Debug::fmt(&(*self as u8 as char), f)
97 }
98}
99
100impl FormatType {
101 fn info(self, e: Endianness) -> &'static FormatInfo {
102 use mem::{align_of, size_of};
103 use FormatType::*;
104 macro_rules! native_info {
105 ($t:ty) => {{
106 &FormatInfo {
107 size: size_of::<$t>(),
108 align: align_of::<$t>(),
109 pack: Some(<$t as Packable>::pack::<NativeEndian>),
110 unpack: Some(<$t as Packable>::unpack::<NativeEndian>),
111 }
112 }};
113 }
114 macro_rules! nonnative_info {
115 ($t:ty, $end:ty) => {{
116 &FormatInfo {
117 size: size_of::<$t>(),
118 align: 0,
119 pack: Some(<$t as Packable>::pack::<$end>),
120 unpack: Some(<$t as Packable>::unpack::<$end>),
121 }
122 }};
123 }
124 macro_rules! match_nonnative {
125 ($zelf:expr, $end:ty) => {{
126 match $zelf {
127 Pad | Str | Pascal => &FormatInfo {
128 size: size_of::<u8>(),
129 align: 0,
130 pack: None,
131 unpack: None,
132 },
133 SByte => nonnative_info!(i8, $end),
134 UByte => nonnative_info!(u8, $end),
135 Char => &FormatInfo {
136 size: size_of::<u8>(),
137 align: 0,
138 pack: Some(pack_char),
139 unpack: Some(unpack_char),
140 },
141 Short => nonnative_info!(i16, $end),
142 UShort => nonnative_info!(u16, $end),
143 Int | Long => nonnative_info!(i32, $end),
144 UInt | ULong => nonnative_info!(u32, $end),
145 LongLong => nonnative_info!(i64, $end),
146 ULongLong => nonnative_info!(u64, $end),
147 Bool => nonnative_info!(bool, $end),
148 Half => nonnative_info!(f16, $end),
149 Float => nonnative_info!(f32, $end),
150 Double => nonnative_info!(f64, $end),
151 _ => unreachable!(), }
153 }};
154 }
155 match e {
156 Endianness::Native => match self {
157 Pad | Str | Pascal => &FormatInfo {
158 size: size_of::<raw::c_char>(),
159 align: 0,
160 pack: None,
161 unpack: None,
162 },
163 SByte => native_info!(raw::c_schar),
164 UByte => native_info!(raw::c_uchar),
165 Char => &FormatInfo {
166 size: size_of::<raw::c_char>(),
167 align: 0,
168 pack: Some(pack_char),
169 unpack: Some(unpack_char),
170 },
171 WideChar => native_info!(wchar_t),
172 Short => native_info!(raw::c_short),
173 UShort => native_info!(raw::c_ushort),
174 Int => native_info!(raw::c_int),
175 UInt => native_info!(raw::c_uint),
176 Long => native_info!(raw::c_long),
177 ULong => native_info!(raw::c_ulong),
178 SSizeT => native_info!(isize), SizeT => native_info!(usize), LongLong => native_info!(raw::c_longlong),
181 ULongLong => native_info!(raw::c_ulonglong),
182 Bool => native_info!(bool),
183 Half => native_info!(f16),
184 Float => native_info!(raw::c_float),
185 Double => native_info!(raw::c_double),
186 VoidP => native_info!(*mut raw::c_void),
187 },
188 Endianness::Big => match_nonnative!(self, BigEndian),
189 Endianness::Little => match_nonnative!(self, LittleEndian),
190 Endianness::Host => match_nonnative!(self, NativeEndian),
191 }
192 }
193}
194
/// One parsed format code with its repeat count and layout information.
#[derive(Debug, Clone)]
pub(crate) struct FormatCode {
    /// Repeat count. For `s` and `p` this is the field width in bytes.
    pub repeat: usize,
    /// Which format character this is.
    pub code: FormatType,
    /// Size/alignment/packer entry for `code` under the spec's endianness.
    pub info: &'static FormatInfo,
    /// Alignment padding bytes placed before this code's data.
    pub pre_padding: usize,
}
202
203impl FormatCode {
204 pub fn arg_count(&self) -> usize {
205 match self.code {
206 FormatType::Pad => 0,
207 FormatType::Str | FormatType::Pascal => 1,
208 _ => self.repeat,
209 }
210 }
211
212 pub fn parse<I>(
213 chars: &mut Peekable<I>,
214 endianness: Endianness,
215 ) -> Result<(Vec<Self>, usize, usize), String>
216 where
217 I: Sized + Iterator<Item = u8>,
218 {
219 let mut offset = 0isize;
220 let mut arg_count = 0usize;
221 let mut codes = vec![];
222 while chars.peek().is_some() {
223 let repeat = match chars.peek() {
225 Some(b'0'..=b'9') => {
226 let mut repeat = 0isize;
227 while let Some(b'0'..=b'9') = chars.peek() {
228 if let Some(c) = chars.next() {
229 let current_digit = c - b'0';
230 repeat = repeat
231 .checked_mul(10)
232 .and_then(|r| r.checked_add(current_digit as _))
233 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
234 }
235 }
236 repeat
237 }
238 _ => 1,
239 };
240
241 let c = chars
243 .next()
244 .ok_or_else(|| "repeat count given without format specifier".to_owned())?;
245 let code = FormatType::try_from(c)
246 .ok()
247 .filter(|c| match c {
248 FormatType::SSizeT | FormatType::SizeT | FormatType::VoidP => {
249 endianness == Endianness::Native
250 }
251 _ => true,
252 })
253 .ok_or_else(|| "bad char in struct format".to_owned())?;
254
255 let info = code.info(endianness);
256
257 let padding = compensate_alignment(offset as usize, info.align)
258 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
259 offset = padding
260 .to_isize()
261 .and_then(|extra| offset.checked_add(extra))
262 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
263
264 let code = FormatCode {
265 repeat: repeat as usize,
266 code,
267 info,
268 pre_padding: padding,
269 };
270 arg_count += code.arg_count();
271 codes.push(code);
272
273 offset = (info.size as isize)
274 .checked_mul(repeat)
275 .and_then(|item_size| offset.checked_add(item_size))
276 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
277 }
278
279 Ok((codes, offset as usize, arg_count))
280 }
281}
282
/// Number of padding bytes needed to move `offset` up to a multiple of
/// `align`.
///
/// An `align` of 0 means "no alignment", and offset 0 never needs padding.
/// The mask arithmetic assumes `align` is a power of two, which holds for
/// the `align_of` values used by the format table.
fn compensate_alignment(offset: usize, align: usize) -> Option<usize> {
    if align == 0 || offset == 0 {
        return Some(0);
    }
    let mask = align - 1;
    mask.checked_sub((offset - 1) & mask)
}
292
293pub(crate) struct FormatInfo {
294 pub size: usize,
295 pub align: usize,
296 pub pack: Option<PackFunc>,
297 pub unpack: Option<UnpackFunc>,
298}
299impl fmt::Debug for FormatInfo {
300 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301 f.debug_struct("FormatInfo")
302 .field("size", &self.size)
303 .field("align", &self.align)
304 .finish()
305 }
306}
307
/// A fully parsed struct format: its byte-order mode, codes, total size, and
/// the number of Python values it packs/unpacks.
#[derive(Debug, Clone)]
pub struct FormatSpec {
    // Byte-order mode from the format prefix. Nothing in this file reads it,
    // hence the `dead_code` allowance. NOTE(review): presumably kept for
    // `Debug` output or future use; confirm before removing.
    #[allow(dead_code)]
    pub(crate) endianness: Endianness,
    /// Parsed format codes, in order.
    pub(crate) codes: Vec<FormatCode>,
    /// Total packed size in bytes, including alignment padding.
    pub size: usize,
    /// Number of Python values consumed by `pack` / produced by `unpack`.
    pub arg_count: usize,
}
316
317impl FormatSpec {
318 pub fn parse(fmt: &[u8], vm: &VirtualMachine) -> PyResult<FormatSpec> {
319 let mut chars = fmt.iter().copied().peekable();
320
321 let endianness = Endianness::parse(&mut chars);
323
324 let (codes, size, arg_count) =
326 FormatCode::parse(&mut chars, endianness).map_err(|err| new_struct_error(vm, err))?;
327
328 Ok(FormatSpec {
329 endianness,
330 codes,
331 size,
332 arg_count,
333 })
334 }
335
336 pub fn pack(&self, args: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
337 let mut data = vec![0; self.size];
339
340 self.pack_into(&mut data, args, vm)?;
341
342 Ok(data)
343 }
344
345 pub fn pack_into(
346 &self,
347 mut buffer: &mut [u8],
348 args: Vec<PyObjectRef>,
349 vm: &VirtualMachine,
350 ) -> PyResult<()> {
351 if self.arg_count != args.len() {
352 return Err(new_struct_error(
353 vm,
354 format!(
355 "pack expected {} items for packing (got {})",
356 self.codes.len(),
357 args.len()
358 ),
359 ));
360 }
361
362 let mut args = args.into_iter();
363 for code in &self.codes {
365 buffer = &mut buffer[code.pre_padding..];
366 debug!("code: {:?}", code);
367 match code.code {
368 FormatType::Str => {
369 let (buf, rest) = buffer.split_at_mut(code.repeat);
370 pack_string(vm, args.next().unwrap(), buf)?;
371 buffer = rest;
372 }
373 FormatType::Pascal => {
374 let (buf, rest) = buffer.split_at_mut(code.repeat);
375 pack_pascal(vm, args.next().unwrap(), buf)?;
376 buffer = rest;
377 }
378 FormatType::Pad => {
379 let (pad_buf, rest) = buffer.split_at_mut(code.repeat);
380 for el in pad_buf {
381 *el = 0
382 }
383 buffer = rest;
384 }
385 _ => {
386 let pack = code.info.pack.unwrap();
387 for arg in args.by_ref().take(code.repeat) {
388 let (item_buf, rest) = buffer.split_at_mut(code.info.size);
389 pack(vm, arg, item_buf)?;
390 buffer = rest;
391 }
392 }
393 }
394 }
395
396 Ok(())
397 }
398
399 pub fn unpack(&self, mut data: &[u8], vm: &VirtualMachine) -> PyResult<PyTupleRef> {
400 if self.size != data.len() {
401 return Err(new_struct_error(
402 vm,
403 format!("unpack requires a buffer of {} bytes", self.size),
404 ));
405 }
406
407 let mut items = Vec::with_capacity(self.arg_count);
408 for code in &self.codes {
409 data = &data[code.pre_padding..];
410 debug!("unpack code: {:?}", code);
411 match code.code {
412 FormatType::Pad => {
413 data = &data[code.repeat..];
414 }
415 FormatType::Str => {
416 let (str_data, rest) = data.split_at(code.repeat);
417 items.push(vm.ctx.new_bytes(str_data.to_vec()).into());
419 data = rest;
420 }
421 FormatType::Pascal => {
422 let (str_data, rest) = data.split_at(code.repeat);
423 items.push(unpack_pascal(vm, str_data));
424 data = rest;
425 }
426 _ => {
427 let unpack = code.info.unpack.unwrap();
428 for _ in 0..code.repeat {
429 let (item_data, rest) = data.split_at(code.info.size);
430 items.push(unpack(vm, item_data));
431 data = rest;
432 }
433 }
434 };
435 }
436
437 Ok(PyTuple::new_ref(items, &vm.ctx))
438 }
439
440 #[inline]
441 pub fn size(&self) -> usize {
442 self.size
443 }
444}
445
/// Conversion between a Python object and the byte image of one item, in
/// byte order `E`.
trait Packable {
    /// Converts `arg` and writes it into `data`, that item's byte slot.
    fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()>;
    /// Reads one item from `data` and builds the matching Python object.
    fn unpack<E: ByteOrder>(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef;
}
450
/// Raw integer <-> bytes conversion in byte order `E`. Used by the integer,
/// float and bool packers.
trait PackInt: PrimInt {
    /// Writes `self` into `data` in byte order `E`. The implementations copy
    /// with `copy_from_slice`, so `data` must be exactly `size_of::<Self>()`
    /// bytes.
    fn pack_int<E: ByteOrder>(self, data: &mut [u8]);
    /// Reads a value stored in byte order `E` from `data`, which must be
    /// exactly `size_of::<Self>()` bytes.
    fn unpack_int<E: ByteOrder>(data: &[u8]) -> Self;
}
455
/// Implements `PackInt` and `Packable` for the primitive integer type `$T`.
macro_rules! make_pack_primint {
    ($T:ty) => {
        impl PackInt for $T {
            fn pack_int<E: ByteOrder>(self, data: &mut [u8]) {
                let ordered = E::convert(self).to_ne_bytes();
                data.copy_from_slice(&ordered);
            }

            #[inline]
            fn unpack_int<E: ByteOrder>(data: &[u8]) -> Self {
                let mut raw = [0u8; std::mem::size_of::<$T>()];
                raw.copy_from_slice(data);
                let native = <$T>::from_ne_bytes(raw);
                E::convert(native)
            }
        }

        impl Packable for $T {
            fn pack<E: ByteOrder>(
                vm: &VirtualMachine,
                arg: PyObjectRef,
                data: &mut [u8],
            ) -> PyResult<()> {
                get_int_or_index::<$T>(vm, arg).map(|value| value.pack_int::<E>(data))
            }

            fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
                vm.ctx.new_int(<$T>::unpack_int::<E>(rdr)).into()
            }
        }
    };
}
489
490fn get_int_or_index<T>(vm: &VirtualMachine, arg: PyObjectRef) -> PyResult<T>
491where
492 T: PrimInt + for<'a> TryFrom<&'a BigInt>,
493{
494 let index = arg.try_index_opt(vm).unwrap_or_else(|| {
495 Err(new_struct_error(
496 vm,
497 "required argument is not an integer".to_owned(),
498 ))
499 })?;
500 index
501 .try_to_primitive(vm)
502 .map_err(|_| new_struct_error(vm, "argument out of range".to_owned()))
503}
504
505make_pack_primint!(i8);
506make_pack_primint!(u8);
507make_pack_primint!(i16);
508make_pack_primint!(u16);
509make_pack_primint!(i32);
510make_pack_primint!(u32);
511make_pack_primint!(i64);
512make_pack_primint!(u64);
513make_pack_primint!(usize);
514make_pack_primint!(isize);
515
/// Implements `Packable` for the IEEE 754 float type `$T`, storing its bit
/// pattern through the matching unsigned integer.
macro_rules! make_pack_float {
    ($T:ty) => {
        impl Packable for $T {
            /// # Errors
            /// Raises `OverflowError` when a finite value is too large for
            /// `$T`. Narrowing would otherwise silently produce infinity.
            /// This matches CPython's `PyFloat_Pack4`, and the `e` (f16)
            /// packer does the same check.
            fn pack<E: ByteOrder>(
                vm: &VirtualMachine,
                arg: PyObjectRef,
                data: &mut [u8],
            ) -> PyResult<()> {
                let f_64 = *ArgIntoFloat::try_from_object(vm, arg)?;
                let f = f_64 as $T;
                // Only `f32` can overflow here (f64 -> f64 is exact), so the
                // message names the `f` format.
                if f.is_infinite() && !f_64.is_infinite() {
                    return Err(vm.new_overflow_error(
                        "float too large to pack with f format".to_owned(),
                    ));
                }
                f.to_bits().pack_int::<E>(data);
                Ok(())
            }

            fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
                let i = PackInt::unpack_int::<E>(rdr);
                <$T>::from_bits(i).to_pyobject(vm)
            }
        }
    };
}
536
537make_pack_float!(f32);
538make_pack_float!(f64);
539
540impl Packable for f16 {
541 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
542 let f_64 = *ArgIntoFloat::try_from_object(vm, arg)?;
543 let f_16 = f16::from_f64(f_64);
544 if f_16.is_infinite() != f_64.is_infinite() {
545 return Err(vm.new_overflow_error("float too large to pack with e format".to_owned()));
546 }
547 f_16.to_bits().pack_int::<E>(data);
548 Ok(())
549 }
550
551 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
552 let i = PackInt::unpack_int::<E>(rdr);
553 f16::from_bits(i).to_f64().to_pyobject(vm)
554 }
555}
556
557impl Packable for *mut raw::c_void {
558 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
559 usize::pack::<E>(vm, arg, data)
560 }
561
562 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
563 usize::unpack::<E>(vm, rdr)
564 }
565}
566
567impl Packable for bool {
568 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
569 let v = *ArgIntoBool::try_from_object(vm, arg)? as u8;
570 v.pack_int::<E>(data);
571 Ok(())
572 }
573
574 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
575 let i = u8::unpack_int::<E>(rdr);
576 vm.ctx.new_bool(i != 0).into()
577 }
578}
579
580fn pack_char(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
581 let v = PyBytesRef::try_from_object(vm, arg)?;
582 let ch = *v.as_bytes().iter().exactly_one().map_err(|_| {
583 new_struct_error(
584 vm,
585 "char format requires a bytes object of length 1".to_owned(),
586 )
587 })?;
588 data[0] = ch;
589 Ok(())
590}
591
592fn pack_string(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResult<()> {
593 let b = ArgBytesLike::try_from_object(vm, arg)?;
594 b.with_ref(|data| write_string(buf, data));
595 Ok(())
596}
597
598fn pack_pascal(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResult<()> {
599 if buf.is_empty() {
600 return Ok(());
601 }
602 let b = ArgBytesLike::try_from_object(vm, arg)?;
603 b.with_ref(|data| {
604 let string_length = std::cmp::min(std::cmp::min(data.len(), 255), buf.len() - 1);
605 buf[0] = string_length as u8;
606 write_string(&mut buf[1..], data);
607 });
608 Ok(())
609}
610
/// Copies as much of `data` as fits into `buf` and zero-fills the rest of
/// `buf`.
fn write_string(buf: &mut [u8], data: &[u8]) {
    let copied = data.len().min(buf.len());
    let (head, tail) = buf.split_at_mut(copied);
    head.copy_from_slice(&data[..copied]);
    tail.fill(0);
}
618
619fn unpack_char(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef {
620 vm.ctx.new_bytes(vec![data[0]]).into()
621}
622
623fn unpack_pascal(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef {
624 let (&len, data) = match data.split_first() {
625 Some(x) => x,
626 None => {
627 return vm.ctx.new_bytes(vec![]).into();
629 }
630 };
631 let len = std::cmp::min(len as usize, data.len());
632 vm.ctx.new_bytes(data[..len].to_vec()).into()
633}
634
635pub fn struct_error_type(vm: &VirtualMachine) -> &'static PyTypeRef {
637 static_cell! {
638 static INSTANCE: PyTypeRef;
639 }
640 INSTANCE.get_or_init(|| vm.ctx.new_exception_type("struct", "error", None))
641}
642
643pub fn new_struct_error(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
644 vm.new_exception_msg(struct_error_type(vm).clone(), msg)
647}