1use std::collections::HashMap;
12
// Pickle opcodes understood by this module. Each constant is the single
// opcode byte as it appears in the pickle stream.

// --- Stream framing -------------------------------------------------------

/// Protocol header; followed by one protocol-version byte.
const PROTO: u8 = 0x80;
/// Frame header; followed by an 8-byte frame length.
const FRAME: u8 = 0x95;
/// End of pickle; the value on top of the stack is the result.
const STOP: u8 = b'.';

// --- Scalars --------------------------------------------------------------

const NONE: u8 = b'N';
const NEWTRUE: u8 = 0x88;
const NEWFALSE: u8 = 0x89;
/// Integer stored in one unsigned byte.
const BININT1: u8 = b'K';
/// Integer stored in two little-endian unsigned bytes.
const BININT2: u8 = b'M';
/// Integer stored in four little-endian signed bytes.
const BININT4: u8 = b'J';
/// Integer stored as a one-byte length plus little-endian two's-complement bytes.
const LONG1: u8 = 0x8a;
/// 64-bit float stored big-endian.
const BINFLOAT: u8 = b'G';

// --- Text and byte strings --------------------------------------------------

/// UTF-8 text with a one-byte length prefix.
const SHORT_BINUNICODE: u8 = 0x8c;
/// UTF-8 text with a four-byte little-endian length prefix.
const BINUNICODE: u8 = b'X';
/// Byte string with a one-byte length prefix.
const SHORT_BINBYTES: u8 = b'C';
/// Byte string with a four-byte little-endian length prefix.
const BINBYTES: u8 = b'B';
/// Legacy byte string with a one-byte length prefix.
const SHORT_BINSTRING: u8 = b'U';
/// Legacy byte string with a four-byte little-endian length prefix.
const BINSTRING: u8 = b'T';

// --- Containers -----------------------------------------------------------

const EMPTY_LIST: u8 = b']';
const EMPTY_DICT: u8 = b'}';
const EMPTY_TUPLE: u8 = b')';
const TUPLE1: u8 = 0x85;
const TUPLE2: u8 = 0x86;
const TUPLE3: u8 = 0x87;
/// Marks the start of a run of items consumed by `APPENDS` / `SETITEMS`.
const MARK: u8 = b'(';
const APPEND: u8 = b'a';
const APPENDS: u8 = b'e';
const SETITEM: u8 = b's';
const SETITEMS: u8 = b'u';

// --- Memo -----------------------------------------------------------------

/// Store the stack top in the memo under a one-byte index.
const BINPUT: u8 = b'q';
/// Store the stack top in the memo under a four-byte index.
const LONG_BINPUT: u8 = b'r';
/// Push the memo entry with a one-byte index.
const BINGET: u8 = b'h';
/// Push the memo entry with a four-byte index.
const LONG_BINGET: u8 = b'j';
/// Store the stack top in the memo under the next implicit index.
const MEMOIZE: u8 = 0x94;

// --- Callables ------------------------------------------------------------

/// Reference to `module\nname\n`.
const GLOBAL: u8 = b'c';
/// Call the callable below the stack top with the argument tuple on top.
const REDUCE: u8 = b'R';
// Opcodes whose payload carries an 8-byte little-endian length prefix.

/// Byte string with an 8-byte length (this is the opcode pickle calls `BINBYTES8`).
const SHORT_BINBYTES8: u8 = 0x8e;
/// UTF-8 text with an 8-byte length.
const BINUNICODE8: u8 = 0x8d;
/// Mutable byte array with an 8-byte length; decoded here as plain bytes.
const BYTEARRAY8: u8 = 0x96;

/// The subset of Python values this module can encode and decode.
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    /// Python `None`.
    None,
    /// Python `bool`.
    Bool(bool),
    /// Python `int`, limited to the `i64` range.
    Int(i64),
    /// Python `float`.
    Float(f64),
    /// Python `str`.
    String(String),
    /// Python `bytes`.
    Bytes(Vec<u8>),
    /// Python `list`; the decoder also produces this for tuples.
    List(Vec<PickleValue>),
    /// Python `dict`, kept as key/value pairs in insertion order.
    Dict(Vec<(PickleValue, PickleValue)>),
}
66
67impl PickleValue {
68 pub fn as_str(&self) -> Option<&str> {
70 match self {
71 PickleValue::String(s) => Some(s),
72 _ => None,
73 }
74 }
75
76 pub fn as_int(&self) -> Option<i64> {
78 match self {
79 PickleValue::Int(n) => Some(*n),
80 _ => None,
81 }
82 }
83
84 pub fn as_float(&self) -> Option<f64> {
86 match self {
87 PickleValue::Float(f) => Some(*f),
88 _ => None,
89 }
90 }
91
92 pub fn as_bool(&self) -> Option<bool> {
94 match self {
95 PickleValue::Bool(b) => Some(*b),
96 _ => None,
97 }
98 }
99
100 pub fn as_bytes(&self) -> Option<&[u8]> {
102 match self {
103 PickleValue::Bytes(b) => Some(b),
104 _ => None,
105 }
106 }
107
108 pub fn as_list(&self) -> Option<&[PickleValue]> {
110 match self {
111 PickleValue::List(l) => Some(l),
112 _ => None,
113 }
114 }
115
116 pub fn get(&self, key: &str) -> Option<&PickleValue> {
118 match self {
119 PickleValue::Dict(pairs) => {
120 for (k, v) in pairs {
121 if let PickleValue::String(s) = k {
122 if s == key {
123 return Some(v);
124 }
125 }
126 }
127 None
128 }
129 _ => None,
130 }
131 }
132}
133
134pub fn encode(value: &PickleValue) -> Vec<u8> {
136 let mut buf = Vec::new();
137 buf.push(PROTO);
138 buf.push(2); encode_value(&mut buf, value);
140 buf.push(STOP);
141 buf
142}
143
144fn encode_value(buf: &mut Vec<u8>, value: &PickleValue) {
145 match value {
146 PickleValue::None => buf.push(NONE),
147 PickleValue::Bool(true) => buf.push(NEWTRUE),
148 PickleValue::Bool(false) => buf.push(NEWFALSE),
149 PickleValue::Int(n) => encode_int(buf, *n),
150 PickleValue::Float(f) => {
151 buf.push(BINFLOAT);
152 buf.extend_from_slice(&f.to_be_bytes());
153 }
154 PickleValue::String(s) => {
155 let bytes = s.as_bytes();
156 if bytes.len() < 256 {
157 buf.push(SHORT_BINUNICODE);
158 buf.push(bytes.len() as u8);
159 } else {
160 buf.push(BINUNICODE);
161 buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
162 }
163 buf.extend_from_slice(bytes);
164 }
165 PickleValue::Bytes(data) => {
166 buf.extend_from_slice(b"c_codecs\nencode\n");
170 let mut latin1_utf8 = Vec::with_capacity(data.len() * 2);
173 for &b in data.iter() {
174 if b < 0x80 {
175 latin1_utf8.push(b);
176 } else {
177 latin1_utf8.push(0xC0 | (b >> 6));
179 latin1_utf8.push(0x80 | (b & 0x3F));
180 }
181 }
182 if latin1_utf8.len() < 256 {
183 buf.push(SHORT_BINUNICODE);
184 buf.push(latin1_utf8.len() as u8);
185 } else {
186 buf.push(BINUNICODE);
187 buf.extend_from_slice(&(latin1_utf8.len() as u32).to_le_bytes());
188 }
189 buf.extend_from_slice(&latin1_utf8);
190 buf.push(SHORT_BINUNICODE);
192 buf.push(7); buf.extend_from_slice(b"latin-1");
194 buf.push(TUPLE2);
195 buf.push(REDUCE);
196 }
197 PickleValue::List(items) => {
198 buf.push(EMPTY_LIST);
199 if !items.is_empty() {
200 buf.push(MARK);
201 for item in items {
202 encode_value(buf, item);
203 }
204 buf.push(APPENDS);
205 }
206 }
207 PickleValue::Dict(pairs) => {
208 buf.push(EMPTY_DICT);
209 if !pairs.is_empty() {
210 buf.push(MARK);
211 for (k, v) in pairs {
212 encode_value(buf, k);
213 encode_value(buf, v);
214 }
215 buf.push(SETITEMS);
216 }
217 }
218 }
219}
220
221fn encode_int(buf: &mut Vec<u8>, n: i64) {
222 if n >= 0 && n < 256 {
223 buf.push(BININT1);
224 buf.push(n as u8);
225 } else if n >= 0 && n < 65536 {
226 buf.push(BININT2);
227 buf.extend_from_slice(&(n as u16).to_le_bytes());
228 } else if n >= i32::MIN as i64 && n <= i32::MAX as i64 {
229 buf.push(BININT4);
230 buf.extend_from_slice(&(n as i32).to_le_bytes());
231 } else {
232 buf.push(LONG1);
234 let bytes = long_to_bytes(n);
235 buf.push(bytes.len() as u8);
236 buf.extend_from_slice(&bytes);
237 }
238}
239
/// Converts `n` to the shortest little-endian two's-complement byte string
/// that still carries the correct sign. Zero yields an empty vector.
fn long_to_bytes(n: i64) -> Vec<u8> {
    if n == 0 {
        return Vec::new();
    }

    // Bytes equal to `pad` at the high end are pure sign extension.
    let (pad, wanted_sign_bit): (u8, u8) = if n < 0 { (0xFF, 0x80) } else { (0x00, 0x00) };
    let raw = n.to_le_bytes();

    // Keep everything up to the highest non-padding byte, but at least one byte.
    let keep = raw
        .iter()
        .rposition(|&byte| byte != pad)
        .map_or(1, |index| index + 1);

    let mut out = raw[..keep].to_vec();
    // If trimming flipped the apparent sign, restore it with one padding byte.
    if raw[keep - 1] & 0x80 != wanted_sign_bit {
        out.push(pad);
    }
    out
}
270
/// Reasons why a pickle byte stream could not be decoded.
#[derive(Debug)]
pub enum DecodeError {
    /// The input ended in the middle of an opcode's operand.
    UnexpectedEnd,
    /// An opcode byte this decoder does not handle.
    UnknownOpcode(u8),
    /// A text payload or global name was not valid UTF-8.
    InvalidUtf8,
    /// An opcode needed a stack or memo entry that was missing or of the
    /// wrong kind.
    StackUnderflow,
    /// `APPENDS` / `SETITEMS` was used without a preceding mark.
    NoMarkFound,
    /// The input ended before a `STOP` opcode was seen.
    NoStop,
    /// The pickle references a global or callable this decoder refuses to
    /// handle; the payload names it.
    UnsupportedGlobal(String),
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants carrying data are formatted directly; the rest map to a
        // fixed message.
        let message = match self {
            DecodeError::UnknownOpcode(op) => {
                return write!(f, "unknown pickle opcode: 0x{:02x}", op);
            }
            DecodeError::UnsupportedGlobal(name) => {
                return write!(f, "unsupported global: {}", name);
            }
            DecodeError::UnexpectedEnd => "unexpected end of pickle data",
            DecodeError::InvalidUtf8 => "invalid UTF-8 in pickle string",
            DecodeError::StackUnderflow => "stack underflow",
            DecodeError::NoMarkFound => "no mark found on stack",
            DecodeError::NoStop => "no STOP opcode found",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DecodeError {}
300
301pub fn decode(data: &[u8]) -> Result<PickleValue, DecodeError> {
303 let mut stack: Vec<PickleValue> = Vec::new();
304 let mut memo: HashMap<u32, PickleValue> = HashMap::new();
305 let mut memo_counter: u32 = 0;
306 let mut pos = 0;
307
308 if pos < data.len() && data[pos] == PROTO {
310 pos += 2; }
312
313 loop {
314 if pos >= data.len() {
315 return Err(DecodeError::NoStop);
316 }
317
318 let op = data[pos];
319 pos += 1;
320
321 match op {
322 STOP => {
323 return stack.pop().ok_or(DecodeError::StackUnderflow);
324 }
325 NONE => stack.push(PickleValue::None),
326 NEWTRUE => stack.push(PickleValue::Bool(true)),
327 NEWFALSE => stack.push(PickleValue::Bool(false)),
328 BININT1 => {
329 if pos >= data.len() {
330 return Err(DecodeError::UnexpectedEnd);
331 }
332 stack.push(PickleValue::Int(data[pos] as i64));
333 pos += 1;
334 }
335 BININT2 => {
336 if pos + 2 > data.len() {
337 return Err(DecodeError::UnexpectedEnd);
338 }
339 let val = u16::from_le_bytes([data[pos], data[pos + 1]]);
340 stack.push(PickleValue::Int(val as i64));
341 pos += 2;
342 }
343 BININT4 => {
344 if pos + 4 > data.len() {
345 return Err(DecodeError::UnexpectedEnd);
346 }
347 let val = i32::from_le_bytes([
348 data[pos],
349 data[pos + 1],
350 data[pos + 2],
351 data[pos + 3],
352 ]);
353 stack.push(PickleValue::Int(val as i64));
354 pos += 4;
355 }
356 LONG1 => {
357 if pos >= data.len() {
358 return Err(DecodeError::UnexpectedEnd);
359 }
360 let n = data[pos] as usize;
361 pos += 1;
362 if pos + n > data.len() {
363 return Err(DecodeError::UnexpectedEnd);
364 }
365 let val = bytes_to_long(&data[pos..pos + n]);
366 stack.push(PickleValue::Int(val));
367 pos += n;
368 }
369 BINFLOAT => {
370 if pos + 8 > data.len() {
371 return Err(DecodeError::UnexpectedEnd);
372 }
373 let val = f64::from_be_bytes([
374 data[pos],
375 data[pos + 1],
376 data[pos + 2],
377 data[pos + 3],
378 data[pos + 4],
379 data[pos + 5],
380 data[pos + 6],
381 data[pos + 7],
382 ]);
383 stack.push(PickleValue::Float(val));
384 pos += 8;
385 }
386 SHORT_BINUNICODE => {
387 if pos >= data.len() {
388 return Err(DecodeError::UnexpectedEnd);
389 }
390 let len = data[pos] as usize;
391 pos += 1;
392 if pos + len > data.len() {
393 return Err(DecodeError::UnexpectedEnd);
394 }
395 let s = std::str::from_utf8(&data[pos..pos + len])
396 .map_err(|_| DecodeError::InvalidUtf8)?;
397 stack.push(PickleValue::String(s.to_string()));
398 pos += len;
399 }
400 BINUNICODE => {
401 if pos + 4 > data.len() {
402 return Err(DecodeError::UnexpectedEnd);
403 }
404 let len = u32::from_le_bytes([
405 data[pos],
406 data[pos + 1],
407 data[pos + 2],
408 data[pos + 3],
409 ]) as usize;
410 pos += 4;
411 if pos + len > data.len() {
412 return Err(DecodeError::UnexpectedEnd);
413 }
414 let s = std::str::from_utf8(&data[pos..pos + len])
415 .map_err(|_| DecodeError::InvalidUtf8)?;
416 stack.push(PickleValue::String(s.to_string()));
417 pos += len;
418 }
419 SHORT_BINSTRING => {
420 if pos >= data.len() {
422 return Err(DecodeError::UnexpectedEnd);
423 }
424 let len = data[pos] as usize;
425 pos += 1;
426 if pos + len > data.len() {
427 return Err(DecodeError::UnexpectedEnd);
428 }
429 stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
430 pos += len;
431 }
432 BINSTRING => {
433 if pos + 4 > data.len() {
435 return Err(DecodeError::UnexpectedEnd);
436 }
437 let len = i32::from_le_bytes([
438 data[pos],
439 data[pos + 1],
440 data[pos + 2],
441 data[pos + 3],
442 ]) as usize;
443 pos += 4;
444 if pos + len > data.len() {
445 return Err(DecodeError::UnexpectedEnd);
446 }
447 stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
448 pos += len;
449 }
450 SHORT_BINBYTES => {
451 if pos >= data.len() {
457 return Err(DecodeError::UnexpectedEnd);
458 }
459 let len = data[pos] as usize;
460 pos += 1;
461 if pos + len > data.len() {
462 return Err(DecodeError::UnexpectedEnd);
463 }
464 stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
465 pos += len;
466 }
467 BINBYTES => {
468 if pos + 4 > data.len() {
469 return Err(DecodeError::UnexpectedEnd);
470 }
471 let len = u32::from_le_bytes([
472 data[pos],
473 data[pos + 1],
474 data[pos + 2],
475 data[pos + 3],
476 ]) as usize;
477 pos += 4;
478 if pos + len > data.len() {
479 return Err(DecodeError::UnexpectedEnd);
480 }
481 stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
482 pos += len;
483 }
484 EMPTY_LIST => stack.push(PickleValue::List(Vec::new())),
485 EMPTY_DICT => stack.push(PickleValue::Dict(Vec::new())),
486 EMPTY_TUPLE => stack.push(PickleValue::List(Vec::new())), MARK => stack.push(PickleValue::String("__mark__".into())),
488 FRAME => {
489 if pos + 8 > data.len() {
492 return Err(DecodeError::UnexpectedEnd);
493 }
494 pos += 8;
495 }
496 MEMOIZE => {
497 if let Some(val) = stack.last() {
499 memo.insert(memo_counter, val.clone());
500 }
501 memo_counter += 1;
502 }
503 APPEND => {
504 let item = stack.pop().ok_or(DecodeError::StackUnderflow)?;
506 if let Some(PickleValue::List(ref mut list)) = stack.last_mut() {
507 list.push(item);
508 } else {
509 return Err(DecodeError::StackUnderflow);
510 }
511 }
512 APPENDS => {
513 let mark_pos = find_mark(&stack)?;
515 let items: Vec<PickleValue> = stack.drain(mark_pos + 1..).collect();
516 stack.pop(); if let Some(PickleValue::List(ref mut list)) = stack.last_mut() {
518 list.extend(items);
519 } else {
520 return Err(DecodeError::StackUnderflow);
521 }
522 }
523 SETITEM => {
524 let value = stack.pop().ok_or(DecodeError::StackUnderflow)?;
526 let key = stack.pop().ok_or(DecodeError::StackUnderflow)?;
527 if let Some(PickleValue::Dict(ref mut dict)) = stack.last_mut() {
528 dict.push((key, value));
529 } else {
530 return Err(DecodeError::StackUnderflow);
531 }
532 }
533 SETITEMS => {
534 let mark_pos = find_mark(&stack)?;
536 let items: Vec<PickleValue> = stack.drain(mark_pos + 1..).collect();
537 stack.pop(); if let Some(PickleValue::Dict(ref mut dict)) = stack.last_mut() {
539 for pair in items.chunks_exact(2) {
540 dict.push((pair[0].clone(), pair[1].clone()));
541 }
542 } else {
543 return Err(DecodeError::StackUnderflow);
544 }
545 }
546 TUPLE1 => {
547 let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
548 stack.push(PickleValue::List(vec![a]));
549 }
550 TUPLE2 => {
551 let b = stack.pop().ok_or(DecodeError::StackUnderflow)?;
552 let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
553 stack.push(PickleValue::List(vec![a, b]));
554 }
555 TUPLE3 => {
556 let c = stack.pop().ok_or(DecodeError::StackUnderflow)?;
557 let b = stack.pop().ok_or(DecodeError::StackUnderflow)?;
558 let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
559 stack.push(PickleValue::List(vec![a, b, c]));
560 }
561 SHORT_BINBYTES8 => {
562 if pos + 8 > data.len() {
564 return Err(DecodeError::UnexpectedEnd);
565 }
566 let len = u64::from_le_bytes([
567 data[pos], data[pos+1], data[pos+2], data[pos+3],
568 data[pos+4], data[pos+5], data[pos+6], data[pos+7],
569 ]) as usize;
570 pos += 8;
571 if pos + len > data.len() {
572 return Err(DecodeError::UnexpectedEnd);
573 }
574 stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
575 pos += len;
576 }
577 BINUNICODE8 => {
578 if pos + 8 > data.len() {
580 return Err(DecodeError::UnexpectedEnd);
581 }
582 let len = u64::from_le_bytes([
583 data[pos], data[pos+1], data[pos+2], data[pos+3],
584 data[pos+4], data[pos+5], data[pos+6], data[pos+7],
585 ]) as usize;
586 pos += 8;
587 if pos + len > data.len() {
588 return Err(DecodeError::UnexpectedEnd);
589 }
590 let s = std::str::from_utf8(&data[pos..pos + len])
591 .map_err(|_| DecodeError::InvalidUtf8)?;
592 stack.push(PickleValue::String(s.to_string()));
593 pos += len;
594 }
595 BYTEARRAY8 => {
596 if pos + 8 > data.len() {
598 return Err(DecodeError::UnexpectedEnd);
599 }
600 let len = u64::from_le_bytes([
601 data[pos], data[pos+1], data[pos+2], data[pos+3],
602 data[pos+4], data[pos+5], data[pos+6], data[pos+7],
603 ]) as usize;
604 pos += 8;
605 if pos + len > data.len() {
606 return Err(DecodeError::UnexpectedEnd);
607 }
608 stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
609 pos += len;
610 }
611 BINPUT => {
612 if pos >= data.len() {
613 return Err(DecodeError::UnexpectedEnd);
614 }
615 let idx = data[pos] as u32;
616 pos += 1;
617 if let Some(val) = stack.last() {
618 memo.insert(idx, val.clone());
619 }
620 }
621 LONG_BINPUT => {
622 if pos + 4 > data.len() {
623 return Err(DecodeError::UnexpectedEnd);
624 }
625 let idx = u32::from_le_bytes([
626 data[pos],
627 data[pos + 1],
628 data[pos + 2],
629 data[pos + 3],
630 ]);
631 pos += 4;
632 if let Some(val) = stack.last() {
633 memo.insert(idx, val.clone());
634 }
635 }
636 BINGET => {
637 if pos >= data.len() {
638 return Err(DecodeError::UnexpectedEnd);
639 }
640 let idx = data[pos] as u32;
641 pos += 1;
642 let val = memo
643 .get(&idx)
644 .cloned()
645 .ok_or(DecodeError::StackUnderflow)?;
646 stack.push(val);
647 }
648 LONG_BINGET => {
649 if pos + 4 > data.len() {
650 return Err(DecodeError::UnexpectedEnd);
651 }
652 let idx = u32::from_le_bytes([
653 data[pos],
654 data[pos + 1],
655 data[pos + 2],
656 data[pos + 3],
657 ]);
658 pos += 4;
659 let val = memo
660 .get(&idx)
661 .cloned()
662 .ok_or(DecodeError::StackUnderflow)?;
663 stack.push(val);
664 }
665 GLOBAL => {
666 let nl1 = data[pos..]
668 .iter()
669 .position(|&b| b == b'\n')
670 .ok_or(DecodeError::UnexpectedEnd)?;
671 let module =
672 std::str::from_utf8(&data[pos..pos + nl1]).map_err(|_| DecodeError::InvalidUtf8)?;
673 pos += nl1 + 1;
674 let nl2 = data[pos..]
675 .iter()
676 .position(|&b| b == b'\n')
677 .ok_or(DecodeError::UnexpectedEnd)?;
678 let name =
679 std::str::from_utf8(&data[pos..pos + nl2]).map_err(|_| DecodeError::InvalidUtf8)?;
680 pos += nl2 + 1;
681
682 if module == "_codecs" && name == "encode" {
684 stack.push(PickleValue::String("__codecs_encode__".into()));
685 } else {
686 return Err(DecodeError::UnsupportedGlobal(format!(
687 "{}.{}",
688 module, name
689 )));
690 }
691 }
692 REDUCE => {
693 let args = stack.pop().ok_or(DecodeError::StackUnderflow)?;
695 let callable = stack.pop().ok_or(DecodeError::StackUnderflow)?;
696
697 if let PickleValue::String(ref s) = callable {
698 if s == "__codecs_encode__" {
699 if let PickleValue::List(ref items) = args {
701 if let Some(PickleValue::String(ref text)) = items.first() {
702 let bytes: Vec<u8> =
704 text.chars().map(|c| c as u8).collect();
705 stack.push(PickleValue::Bytes(bytes));
706 } else {
707 stack.push(PickleValue::None);
708 }
709 } else {
710 stack.push(PickleValue::None);
711 }
712 } else {
713 return Err(DecodeError::UnsupportedGlobal(s.clone()));
714 }
715 } else {
716 return Err(DecodeError::StackUnderflow);
717 }
718 }
719 other => {
720 return Err(DecodeError::UnknownOpcode(other));
721 }
722 }
723 }
724}
725
726fn bytes_to_long(bytes: &[u8]) -> i64 {
727 if bytes.is_empty() {
728 return 0;
729 }
730 let negative = bytes[bytes.len() - 1] & 0x80 != 0;
731 let mut result: i64 = 0;
732 for (i, &b) in bytes.iter().enumerate() {
733 result |= (b as i64) << (i * 8);
734 }
735 if negative {
736 let bits = bytes.len() * 8;
738 if bits < 64 {
739 result |= !0i64 << bits;
740 }
741 }
742 result
743}
744
745fn find_mark(stack: &[PickleValue]) -> Result<usize, DecodeError> {
746 for i in (0..stack.len()).rev() {
747 if let PickleValue::String(ref s) = stack[i] {
748 if s == "__mark__" {
749 return Ok(i);
750 }
751 }
752 }
753 Err(DecodeError::NoMarkFound)
754}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value`, decodes the result, and checks it comes back equal.
    fn assert_roundtrip(value: PickleValue) {
        let wire = encode(&value);
        let back = decode(&wire).unwrap();
        assert_eq!(back, value);
    }

    /// Shorthand for a `PickleValue::String`.
    fn text(s: &str) -> PickleValue {
        PickleValue::String(s.into())
    }

    // --- Round-trips through this module's own encoder --------------------

    #[test]
    fn roundtrip_none() {
        assert_roundtrip(PickleValue::None);
    }

    #[test]
    fn roundtrip_bool_true() {
        assert_roundtrip(PickleValue::Bool(true));
    }

    #[test]
    fn roundtrip_bool_false() {
        assert_roundtrip(PickleValue::Bool(false));
    }

    #[test]
    fn roundtrip_int_small() {
        assert_roundtrip(PickleValue::Int(42));
    }

    #[test]
    fn roundtrip_int_medium() {
        assert_roundtrip(PickleValue::Int(1000));
    }

    #[test]
    fn roundtrip_int_large() {
        assert_roundtrip(PickleValue::Int(100000));
    }

    #[test]
    fn roundtrip_int_negative() {
        assert_roundtrip(PickleValue::Int(-42));
    }

    #[test]
    fn roundtrip_float() {
        assert_roundtrip(PickleValue::Float(3.14159));
    }

    #[test]
    fn roundtrip_string_short() {
        assert_roundtrip(text("hello"));
    }

    #[test]
    fn roundtrip_string_long() {
        assert_roundtrip(PickleValue::String("x".repeat(300)));
    }

    #[test]
    fn roundtrip_bytes() {
        assert_roundtrip(PickleValue::Bytes(vec![0, 1, 2, 3, 255]));
    }

    #[test]
    fn roundtrip_empty_list() {
        assert_roundtrip(PickleValue::List(vec![]));
    }

    #[test]
    fn roundtrip_list() {
        assert_roundtrip(PickleValue::List(vec![
            PickleValue::Int(1),
            text("two"),
            PickleValue::Bool(true),
        ]));
    }

    #[test]
    fn roundtrip_empty_dict() {
        assert_roundtrip(PickleValue::Dict(vec![]));
    }

    #[test]
    fn roundtrip_dict() {
        assert_roundtrip(PickleValue::Dict(vec![
            (text("key"), PickleValue::Int(42)),
            (text("flag"), PickleValue::Bool(false)),
        ]));
    }

    #[test]
    fn roundtrip_nested() {
        let inner = PickleValue::Dict(vec![(text("inner"), PickleValue::None)]);
        assert_roundtrip(PickleValue::Dict(vec![
            (
                text("list"),
                PickleValue::List(vec![PickleValue::Int(1), inner]),
            ),
            (text("bytes"), PickleValue::Bytes(vec![0xDE, 0xAD])),
        ]));
    }

    #[test]
    fn roundtrip_int_zero() {
        assert_roundtrip(PickleValue::Int(0));
    }

    #[test]
    fn roundtrip_int_255() {
        assert_roundtrip(PickleValue::Int(255));
    }

    #[test]
    fn roundtrip_bytes_empty() {
        assert_roundtrip(PickleValue::Bytes(vec![]));
    }

    #[test]
    fn roundtrip_large_int() {
        assert_roundtrip(PickleValue::Int(i64::MAX));
    }

    #[test]
    fn roundtrip_negative_large_int() {
        assert_roundtrip(PickleValue::Int(i64::MIN));
    }

    // --- Error handling and helpers ----------------------------------------

    #[test]
    fn reject_unknown_opcode() {
        // PROTO 2 header followed by a byte that is not an opcode we handle.
        let data = vec![0x80, 0x02, 0xFF];
        assert!(decode(&data).is_err());
    }

    #[test]
    fn dict_get_helper() {
        let val = PickleValue::Dict(vec![(text("get"), text("interface_stats"))]);
        assert_eq!(
            val.get("get").unwrap().as_str().unwrap(),
            "interface_stats"
        );
        assert!(val.get("missing").is_none());
    }

    // --- Hand-written byte streams ------------------------------------------

    #[test]
    fn decode_python_dict() {
        let data = vec![
            0x80, 0x02, // PROTO 2
            b'}', b'(', // EMPTY_DICT, MARK
            0x8c, 3, b'g', b'e', b't', // "get"
            0x8c, 5, b's', b't', b'a', b't', b's', // "stats"
            b'u', b'.', // SETITEMS, STOP
        ];
        let val = decode(&data).unwrap();
        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "stats");
    }

    #[test]
    fn decode_protocol4_dict() {
        let data = vec![
            0x80, 0x04, // PROTO 4
            0x95, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // FRAME
            b'}', 0x94, // EMPTY_DICT, MEMOIZE
            0x8c, 0x03, b'g', b'e', b't', 0x94, // "get", MEMOIZE
            0x8c, 0x0f, b'i', b'n', b't', b'e', b'r', b'f', b'a', b'c', b'e', b'_', b's', b't',
            b'a', b't', b's', // "interface_stats"
            0x94, b's', b'.', // MEMOIZE, SETITEM, STOP
        ];
        let val = decode(&data).unwrap();
        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "interface_stats");
    }

    #[test]
    fn decode_protocol4_with_bytes() {
        let data = vec![
            0x80, 0x04, // PROTO 4
            0x95, 0x2c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // FRAME
            b'}', 0x94, b'(', // EMPTY_DICT, MEMOIZE, MARK
            0x8c, 0x04, b'd', b'r', b'o', b'p', 0x94, // "drop", MEMOIZE
            0x8c, 0x04, b'p', b'a', b't', b'h', 0x94, // "path", MEMOIZE
            0x8c, 0x10, b'd', b'e', b's', b't', b'i', b'n', b'a', b't', b'i', b'o', b'n', b'_',
            b'h', b'a', b's', b'h', // "destination_hash"
            0x94, // MEMOIZE
            b'C', 0x03, 0x01, 0x02, 0x03, 0x94, // SHORT_BINBYTES [1, 2, 3], MEMOIZE
            b'u', b'.', // SETITEMS, STOP
        ];
        let val = decode(&data).unwrap();
        assert_eq!(val.get("drop").unwrap().as_str().unwrap(), "path");
        assert_eq!(
            val.get("destination_hash").unwrap().as_bytes().unwrap(),
            &[1, 2, 3]
        );
    }
}