Skip to main content

rns_net/
pickle.rs

1//! Minimal pickle codec (protocols 2-5).
2//!
3//! Supports a subset of pickle opcodes sufficient for RPC serialization
4//! compatible with Python's `multiprocessing.connection`.
5//!
6//! Encoder always produces protocol 2 (maximum compatibility).
7//! Decoder accepts protocols 2-5 (Python 3.8+ defaults to protocol 4/5).
8//!
9//! Security: rejects unknown opcodes (no arbitrary code execution).
10
11use std::collections::HashMap;
12
// Pickle opcodes understood by this module. Each value is the single byte
// that introduces the instruction in the pickle stream. Despite the original
// "protocol 2" label, the table also holds opcodes from protocols 3, 4 and 5,
// as noted on the individual lines.
const PROTO: u8 = 0x80;
const STOP: u8 = b'.';
const NONE: u8 = b'N';
const NEWTRUE: u8 = 0x88;
const NEWFALSE: u8 = 0x89;
const BININT1: u8 = b'K';
const BININT2: u8 = b'M';
const BININT4: u8 = b'J';
const BINFLOAT: u8 = b'G';
const SHORT_BINUNICODE: u8 = 0x8c; // protocol 4: 1-byte length unicode
const BINUNICODE: u8 = b'X';
const BINBYTES: u8 = b'B'; // protocol 3+
const SHORT_BINBYTES: u8 = b'C'; // protocol 3+
const EMPTY_LIST: u8 = b']';
const EMPTY_DICT: u8 = b'}';
const APPENDS: u8 = b'e';
const APPEND: u8 = b'a';
const SETITEM: u8 = b's';
const SETITEMS: u8 = b'u';
const MARK: u8 = b'(';
const BINPUT: u8 = b'q';
const LONG_BINPUT: u8 = b'r';
const BINGET: u8 = b'h';
const LONG_BINGET: u8 = b'j';
const GLOBAL: u8 = b'c';
const REDUCE: u8 = b'R';
const TUPLE1: u8 = 0x85;
const TUPLE2: u8 = 0x86;
const TUPLE3: u8 = 0x87;
const EMPTY_TUPLE: u8 = b')';
const LONG1: u8 = 0x8a;
const SHORT_BINSTRING: u8 = b'U'; // protocol 0/1 but appears in some pickles
const BINSTRING: u8 = b'T'; // protocol 0/1

// Protocol 4+ opcodes
const FRAME: u8 = 0x95;
const MEMOIZE: u8 = 0x94;
// NOTE(review): CPython's opcode table calls 0x8e `BINBYTES8`; the `SHORT_`
// prefix here is a misnomer kept because other code refers to this name.
const SHORT_BINBYTES8: u8 = 0x8e; // protocol 4: 8-byte length bytes
const BINUNICODE8: u8 = 0x8d; // protocol 4: 8-byte length unicode
const BYTEARRAY8: u8 = 0x96; // protocol 5: bytearray
53
/// A pickle value.
///
/// This is the in-memory tree produced by the decoder and consumed by the
/// encoder.
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    /// Python `None`.
    None,
    /// Python `bool`.
    Bool(bool),
    /// Integer that fits in 64 signed bits.
    Int(i64),
    /// Double-precision float.
    Float(f64),
    /// Unicode text.
    String(String),
    /// Raw byte string.
    Bytes(Vec<u8>),
    /// Ordered sequence of values.
    List(Vec<PickleValue>),
    /// Key/value pairs kept in insertion order.
    Dict(Vec<(PickleValue, PickleValue)>),
}

impl PickleValue {
    /// Borrow the text when this is a `String`, otherwise `None`.
    pub fn as_str(&self) -> Option<&str> {
        if let PickleValue::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Copy out the integer when this is an `Int`, otherwise `None`.
    pub fn as_int(&self) -> Option<i64> {
        if let PickleValue::Int(number) = *self {
            Some(number)
        } else {
            None
        }
    }

    /// Copy out the float when this is a `Float`, otherwise `None`.
    pub fn as_float(&self) -> Option<f64> {
        if let PickleValue::Float(number) = *self {
            Some(number)
        } else {
            None
        }
    }

    /// Copy out the flag when this is a `Bool`, otherwise `None`.
    pub fn as_bool(&self) -> Option<bool> {
        if let PickleValue::Bool(flag) = *self {
            Some(flag)
        } else {
            None
        }
    }

    /// Borrow the payload when this is a `Bytes`, otherwise `None`.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        if let PickleValue::Bytes(payload) = self {
            Some(payload.as_slice())
        } else {
            None
        }
    }

    /// Borrow the elements when this is a `List`, otherwise `None`.
    pub fn as_list(&self) -> Option<&[PickleValue]> {
        if let PickleValue::List(elements) = self {
            Some(elements.as_slice())
        } else {
            None
        }
    }

    /// Look up `key` in a `Dict`.
    ///
    /// Only `String` keys are compared. The value of the first matching pair
    /// is returned; non-dict values and missing keys yield `None`.
    pub fn get(&self, key: &str) -> Option<&PickleValue> {
        let pairs = match self {
            PickleValue::Dict(pairs) => pairs,
            _ => return None,
        };
        pairs
            .iter()
            .find(|(candidate, _)| candidate.as_str() == Some(key))
            .map(|(_, found)| found)
    }
}
133
134/// Encode a PickleValue as pickle protocol 2 bytes.
135pub fn encode(value: &PickleValue) -> Vec<u8> {
136    let mut buf = Vec::new();
137    buf.push(PROTO);
138    buf.push(2); // protocol 2
139    encode_value(&mut buf, value);
140    buf.push(STOP);
141    buf
142}
143
144fn encode_value(buf: &mut Vec<u8>, value: &PickleValue) {
145    match value {
146        PickleValue::None => buf.push(NONE),
147        PickleValue::Bool(true) => buf.push(NEWTRUE),
148        PickleValue::Bool(false) => buf.push(NEWFALSE),
149        PickleValue::Int(n) => encode_int(buf, *n),
150        PickleValue::Float(f) => {
151            buf.push(BINFLOAT);
152            buf.extend_from_slice(&f.to_be_bytes());
153        }
154        PickleValue::String(s) => {
155            let bytes = s.as_bytes();
156            if bytes.len() < 256 {
157                buf.push(SHORT_BINUNICODE);
158                buf.push(bytes.len() as u8);
159            } else {
160                buf.push(BINUNICODE);
161                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
162            }
163            buf.extend_from_slice(bytes);
164        }
165        PickleValue::Bytes(data) => {
166            // Protocol 2 encodes bytes via _codecs.encode trick:
167            // GLOBAL _codecs.encode, then two args, TUPLE2, REDUCE
168            // No MARK needed since TUPLE2 takes exactly 2 items from stack.
169            buf.extend_from_slice(b"c_codecs\nencode\n");
170            // Encode the bytes as a latin-1 unicode string
171            // Bytes 0x00-0x7F map to same UTF-8; 0x80-0xFF need 2-byte UTF-8
172            let mut latin1_utf8 = Vec::with_capacity(data.len() * 2);
173            for &b in data.iter() {
174                if b < 0x80 {
175                    latin1_utf8.push(b);
176                } else {
177                    // UTF-8 encode U+0080..U+00FF
178                    latin1_utf8.push(0xC0 | (b >> 6));
179                    latin1_utf8.push(0x80 | (b & 0x3F));
180                }
181            }
182            if latin1_utf8.len() < 256 {
183                buf.push(SHORT_BINUNICODE);
184                buf.push(latin1_utf8.len() as u8);
185            } else {
186                buf.push(BINUNICODE);
187                buf.extend_from_slice(&(latin1_utf8.len() as u32).to_le_bytes());
188            }
189            buf.extend_from_slice(&latin1_utf8);
190            // encoding name
191            buf.push(SHORT_BINUNICODE);
192            buf.push(7); // "latin-1"
193            buf.extend_from_slice(b"latin-1");
194            buf.push(TUPLE2);
195            buf.push(REDUCE);
196        }
197        PickleValue::List(items) => {
198            buf.push(EMPTY_LIST);
199            if !items.is_empty() {
200                buf.push(MARK);
201                for item in items {
202                    encode_value(buf, item);
203                }
204                buf.push(APPENDS);
205            }
206        }
207        PickleValue::Dict(pairs) => {
208            buf.push(EMPTY_DICT);
209            if !pairs.is_empty() {
210                buf.push(MARK);
211                for (k, v) in pairs {
212                    encode_value(buf, k);
213                    encode_value(buf, v);
214                }
215                buf.push(SETITEMS);
216            }
217        }
218    }
219}
220
221fn encode_int(buf: &mut Vec<u8>, n: i64) {
222    if n >= 0 && n < 256 {
223        buf.push(BININT1);
224        buf.push(n as u8);
225    } else if n >= 0 && n < 65536 {
226        buf.push(BININT2);
227        buf.extend_from_slice(&(n as u16).to_le_bytes());
228    } else if n >= i32::MIN as i64 && n <= i32::MAX as i64 {
229        buf.push(BININT4);
230        buf.extend_from_slice(&(n as i32).to_le_bytes());
231    } else {
232        // Use LONG1 for values that don't fit in i32
233        buf.push(LONG1);
234        let bytes = long_to_bytes(n);
235        buf.push(bytes.len() as u8);
236        buf.extend_from_slice(&bytes);
237    }
238}
239
/// Produce the minimal little-endian two's-complement encoding of `n`.
///
/// Zero encodes as an empty vector. Redundant sign bytes at the high end are
/// dropped, and one sign byte is put back when the remaining top byte's high
/// bit would otherwise give the value the wrong sign.
fn long_to_bytes(n: i64) -> Vec<u8> {
    if n == 0 {
        return Vec::new();
    }

    let raw = n.to_le_bytes();
    // Padding byte for this sign, and the value bit 7 of the top byte must have.
    let (filler, sign_bit): (u8, u8) = if n > 0 { (0x00, 0x00) } else { (0xFF, 0x80) };

    // Index of the highest byte that is not pure padding; always keep one byte.
    let used = raw
        .iter()
        .rposition(|&byte| byte != filler)
        .map_or(1, |index| index + 1);

    let mut out = raw[..used].to_vec();
    if out[used - 1] & 0x80 != sign_bit {
        out.push(filler);
    }
    out
}
270
/// Decode error.
#[derive(Debug)]
pub enum DecodeError {
    /// The input ended in the middle of an instruction.
    UnexpectedEnd,
    /// An opcode byte the decoder does not handle.
    UnknownOpcode(u8),
    /// A text field was not valid UTF-8.
    InvalidUtf8,
    /// An instruction needed an operand that was not on the stack.
    StackUnderflow,
    /// An instruction needed a mark and the stack held none.
    NoMarkFound,
    /// The input ran out before a `STOP` opcode.
    NoStop,
    /// A global other than the permitted one was referenced.
    UnsupportedGlobal(String),
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants with a payload format it; the rest share one write below.
        let message = match self {
            DecodeError::UnknownOpcode(op) => {
                return write!(f, "unknown pickle opcode: 0x{:02x}", op);
            }
            DecodeError::UnsupportedGlobal(name) => {
                return write!(f, "unsupported global: {}", name);
            }
            DecodeError::UnexpectedEnd => "unexpected end of pickle data",
            DecodeError::InvalidUtf8 => "invalid UTF-8 in pickle string",
            DecodeError::StackUnderflow => "stack underflow",
            DecodeError::NoMarkFound => "no mark found on stack",
            DecodeError::NoStop => "no STOP opcode found",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DecodeError {}
300
301/// Decode pickle protocol 2 bytes into a PickleValue.
302pub fn decode(data: &[u8]) -> Result<PickleValue, DecodeError> {
303    let mut stack: Vec<PickleValue> = Vec::new();
304    let mut memo: HashMap<u32, PickleValue> = HashMap::new();
305    let mut memo_counter: u32 = 0;
306    let mut pos = 0;
307
308    // Skip protocol header if present
309    if pos < data.len() && data[pos] == PROTO {
310        pos += 2; // skip PROTO + version byte
311    }
312
313    loop {
314        if pos >= data.len() {
315            return Err(DecodeError::NoStop);
316        }
317
318        let op = data[pos];
319        pos += 1;
320
321        match op {
322            STOP => {
323                return stack.pop().ok_or(DecodeError::StackUnderflow);
324            }
325            NONE => stack.push(PickleValue::None),
326            NEWTRUE => stack.push(PickleValue::Bool(true)),
327            NEWFALSE => stack.push(PickleValue::Bool(false)),
328            BININT1 => {
329                if pos >= data.len() {
330                    return Err(DecodeError::UnexpectedEnd);
331                }
332                stack.push(PickleValue::Int(data[pos] as i64));
333                pos += 1;
334            }
335            BININT2 => {
336                if pos + 2 > data.len() {
337                    return Err(DecodeError::UnexpectedEnd);
338                }
339                let val = u16::from_le_bytes([data[pos], data[pos + 1]]);
340                stack.push(PickleValue::Int(val as i64));
341                pos += 2;
342            }
343            BININT4 => {
344                if pos + 4 > data.len() {
345                    return Err(DecodeError::UnexpectedEnd);
346                }
347                let val =
348                    i32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
349                stack.push(PickleValue::Int(val as i64));
350                pos += 4;
351            }
352            LONG1 => {
353                if pos >= data.len() {
354                    return Err(DecodeError::UnexpectedEnd);
355                }
356                let n = data[pos] as usize;
357                pos += 1;
358                if pos + n > data.len() {
359                    return Err(DecodeError::UnexpectedEnd);
360                }
361                let val = bytes_to_long(&data[pos..pos + n]);
362                stack.push(PickleValue::Int(val));
363                pos += n;
364            }
365            BINFLOAT => {
366                if pos + 8 > data.len() {
367                    return Err(DecodeError::UnexpectedEnd);
368                }
369                let val = f64::from_be_bytes([
370                    data[pos],
371                    data[pos + 1],
372                    data[pos + 2],
373                    data[pos + 3],
374                    data[pos + 4],
375                    data[pos + 5],
376                    data[pos + 6],
377                    data[pos + 7],
378                ]);
379                stack.push(PickleValue::Float(val));
380                pos += 8;
381            }
382            SHORT_BINUNICODE => {
383                if pos >= data.len() {
384                    return Err(DecodeError::UnexpectedEnd);
385                }
386                let len = data[pos] as usize;
387                pos += 1;
388                if pos + len > data.len() {
389                    return Err(DecodeError::UnexpectedEnd);
390                }
391                let s = std::str::from_utf8(&data[pos..pos + len])
392                    .map_err(|_| DecodeError::InvalidUtf8)?;
393                stack.push(PickleValue::String(s.to_string()));
394                pos += len;
395            }
396            BINUNICODE => {
397                if pos + 4 > data.len() {
398                    return Err(DecodeError::UnexpectedEnd);
399                }
400                let len =
401                    u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]])
402                        as usize;
403                pos += 4;
404                if pos + len > data.len() {
405                    return Err(DecodeError::UnexpectedEnd);
406                }
407                let s = std::str::from_utf8(&data[pos..pos + len])
408                    .map_err(|_| DecodeError::InvalidUtf8)?;
409                stack.push(PickleValue::String(s.to_string()));
410                pos += len;
411            }
412            SHORT_BINSTRING => {
413                // Protocol 0/1 short binary string (used as bytes in some pickles)
414                if pos >= data.len() {
415                    return Err(DecodeError::UnexpectedEnd);
416                }
417                let len = data[pos] as usize;
418                pos += 1;
419                if pos + len > data.len() {
420                    return Err(DecodeError::UnexpectedEnd);
421                }
422                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
423                pos += len;
424            }
425            BINSTRING => {
426                // Protocol 0/1 binary string
427                if pos + 4 > data.len() {
428                    return Err(DecodeError::UnexpectedEnd);
429                }
430                let len =
431                    i32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]])
432                        as usize;
433                pos += 4;
434                if pos + len > data.len() {
435                    return Err(DecodeError::UnexpectedEnd);
436                }
437                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
438                pos += len;
439            }
440            SHORT_BINBYTES => {
441                // SHORT_BINBYTES is actually opcode 'B' = 0x42
442                // Wait, Python pickle docs say:
443                // SHORT_BINBYTES = b'B' (no, that's wrong)
444                // Actually: BINBYTES = b'B', SHORT_BINBYTES = b'C'
445                // Let me just handle both...
446                if pos >= data.len() {
447                    return Err(DecodeError::UnexpectedEnd);
448                }
449                let len = data[pos] as usize;
450                pos += 1;
451                if pos + len > data.len() {
452                    return Err(DecodeError::UnexpectedEnd);
453                }
454                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
455                pos += len;
456            }
457            BINBYTES => {
458                if pos + 4 > data.len() {
459                    return Err(DecodeError::UnexpectedEnd);
460                }
461                let len =
462                    u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]])
463                        as usize;
464                pos += 4;
465                if pos + len > data.len() {
466                    return Err(DecodeError::UnexpectedEnd);
467                }
468                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
469                pos += len;
470            }
471            EMPTY_LIST => stack.push(PickleValue::List(Vec::new())),
472            EMPTY_DICT => stack.push(PickleValue::Dict(Vec::new())),
473            EMPTY_TUPLE => stack.push(PickleValue::List(Vec::new())), // treat tuple as list
474            MARK => stack.push(PickleValue::String("__mark__".into())),
475            FRAME => {
476                // Protocol 4: 8-byte frame length prefix. We just skip it
477                // since we already have the full data.
478                if pos + 8 > data.len() {
479                    return Err(DecodeError::UnexpectedEnd);
480                }
481                pos += 8;
482            }
483            MEMOIZE => {
484                // Protocol 4: implicit memo, auto-assigns next index
485                if let Some(val) = stack.last() {
486                    memo.insert(memo_counter, val.clone());
487                }
488                memo_counter += 1;
489            }
490            APPEND => {
491                // Pop item, append to list on top of stack
492                let item = stack.pop().ok_or(DecodeError::StackUnderflow)?;
493                if let Some(PickleValue::List(ref mut list)) = stack.last_mut() {
494                    list.push(item);
495                } else {
496                    return Err(DecodeError::StackUnderflow);
497                }
498            }
499            APPENDS => {
500                // Pop items until mark, then append them to the list before the mark
501                let mark_pos = find_mark(&stack)?;
502                let items: Vec<PickleValue> = stack.drain(mark_pos + 1..).collect();
503                stack.pop(); // remove mark
504                if let Some(PickleValue::List(ref mut list)) = stack.last_mut() {
505                    list.extend(items);
506                } else {
507                    return Err(DecodeError::StackUnderflow);
508                }
509            }
510            SETITEM => {
511                // Pop value, pop key, set on dict at top of stack
512                let value = stack.pop().ok_or(DecodeError::StackUnderflow)?;
513                let key = stack.pop().ok_or(DecodeError::StackUnderflow)?;
514                if let Some(PickleValue::Dict(ref mut dict)) = stack.last_mut() {
515                    dict.push((key, value));
516                } else {
517                    return Err(DecodeError::StackUnderflow);
518                }
519            }
520            SETITEMS => {
521                // Pop key-value pairs until mark, then set them on the dict before the mark
522                let mark_pos = find_mark(&stack)?;
523                let items: Vec<PickleValue> = stack.drain(mark_pos + 1..).collect();
524                stack.pop(); // remove mark
525                if let Some(PickleValue::Dict(ref mut dict)) = stack.last_mut() {
526                    for pair in items.chunks_exact(2) {
527                        dict.push((pair[0].clone(), pair[1].clone()));
528                    }
529                } else {
530                    return Err(DecodeError::StackUnderflow);
531                }
532            }
533            TUPLE1 => {
534                let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
535                stack.push(PickleValue::List(vec![a]));
536            }
537            TUPLE2 => {
538                let b = stack.pop().ok_or(DecodeError::StackUnderflow)?;
539                let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
540                stack.push(PickleValue::List(vec![a, b]));
541            }
542            TUPLE3 => {
543                let c = stack.pop().ok_or(DecodeError::StackUnderflow)?;
544                let b = stack.pop().ok_or(DecodeError::StackUnderflow)?;
545                let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
546                stack.push(PickleValue::List(vec![a, b, c]));
547            }
548            SHORT_BINBYTES8 => {
549                // Protocol 4: 8-byte length bytes
550                if pos + 8 > data.len() {
551                    return Err(DecodeError::UnexpectedEnd);
552                }
553                let len = u64::from_le_bytes([
554                    data[pos],
555                    data[pos + 1],
556                    data[pos + 2],
557                    data[pos + 3],
558                    data[pos + 4],
559                    data[pos + 5],
560                    data[pos + 6],
561                    data[pos + 7],
562                ]) as usize;
563                pos += 8;
564                if pos + len > data.len() {
565                    return Err(DecodeError::UnexpectedEnd);
566                }
567                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
568                pos += len;
569            }
570            BINUNICODE8 => {
571                // Protocol 4: 8-byte length unicode
572                if pos + 8 > data.len() {
573                    return Err(DecodeError::UnexpectedEnd);
574                }
575                let len = u64::from_le_bytes([
576                    data[pos],
577                    data[pos + 1],
578                    data[pos + 2],
579                    data[pos + 3],
580                    data[pos + 4],
581                    data[pos + 5],
582                    data[pos + 6],
583                    data[pos + 7],
584                ]) as usize;
585                pos += 8;
586                if pos + len > data.len() {
587                    return Err(DecodeError::UnexpectedEnd);
588                }
589                let s = std::str::from_utf8(&data[pos..pos + len])
590                    .map_err(|_| DecodeError::InvalidUtf8)?;
591                stack.push(PickleValue::String(s.to_string()));
592                pos += len;
593            }
594            BYTEARRAY8 => {
595                // Protocol 5: 8-byte length bytearray (treat as bytes)
596                if pos + 8 > data.len() {
597                    return Err(DecodeError::UnexpectedEnd);
598                }
599                let len = u64::from_le_bytes([
600                    data[pos],
601                    data[pos + 1],
602                    data[pos + 2],
603                    data[pos + 3],
604                    data[pos + 4],
605                    data[pos + 5],
606                    data[pos + 6],
607                    data[pos + 7],
608                ]) as usize;
609                pos += 8;
610                if pos + len > data.len() {
611                    return Err(DecodeError::UnexpectedEnd);
612                }
613                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
614                pos += len;
615            }
616            BINPUT => {
617                if pos >= data.len() {
618                    return Err(DecodeError::UnexpectedEnd);
619                }
620                let idx = data[pos] as u32;
621                pos += 1;
622                if let Some(val) = stack.last() {
623                    memo.insert(idx, val.clone());
624                }
625            }
626            LONG_BINPUT => {
627                if pos + 4 > data.len() {
628                    return Err(DecodeError::UnexpectedEnd);
629                }
630                let idx =
631                    u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
632                pos += 4;
633                if let Some(val) = stack.last() {
634                    memo.insert(idx, val.clone());
635                }
636            }
637            BINGET => {
638                if pos >= data.len() {
639                    return Err(DecodeError::UnexpectedEnd);
640                }
641                let idx = data[pos] as u32;
642                pos += 1;
643                let val = memo.get(&idx).cloned().ok_or(DecodeError::StackUnderflow)?;
644                stack.push(val);
645            }
646            LONG_BINGET => {
647                if pos + 4 > data.len() {
648                    return Err(DecodeError::UnexpectedEnd);
649                }
650                let idx =
651                    u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
652                pos += 4;
653                let val = memo.get(&idx).cloned().ok_or(DecodeError::StackUnderflow)?;
654                stack.push(val);
655            }
656            GLOBAL => {
657                // Read module\nname\n
658                let nl1 = data[pos..]
659                    .iter()
660                    .position(|&b| b == b'\n')
661                    .ok_or(DecodeError::UnexpectedEnd)?;
662                let module = std::str::from_utf8(&data[pos..pos + nl1])
663                    .map_err(|_| DecodeError::InvalidUtf8)?;
664                pos += nl1 + 1;
665                let nl2 = data[pos..]
666                    .iter()
667                    .position(|&b| b == b'\n')
668                    .ok_or(DecodeError::UnexpectedEnd)?;
669                let name = std::str::from_utf8(&data[pos..pos + nl2])
670                    .map_err(|_| DecodeError::InvalidUtf8)?;
671                pos += nl2 + 1;
672
673                // Only allow _codecs.encode (for bytes encoding)
674                if module == "_codecs" && name == "encode" {
675                    stack.push(PickleValue::String("__codecs_encode__".into()));
676                } else {
677                    return Err(DecodeError::UnsupportedGlobal(format!(
678                        "{}.{}",
679                        module, name
680                    )));
681                }
682            }
683            REDUCE => {
684                // Pop args tuple and callable, apply
685                let args = stack.pop().ok_or(DecodeError::StackUnderflow)?;
686                let callable = stack.pop().ok_or(DecodeError::StackUnderflow)?;
687
688                if let PickleValue::String(ref s) = callable {
689                    if s == "__codecs_encode__" {
690                        // args should be a tuple (string, encoding)
691                        if let PickleValue::List(ref items) = args {
692                            if let Some(PickleValue::String(ref text)) = items.first() {
693                                // Convert latin-1 string back to bytes
694                                let bytes: Vec<u8> = text.chars().map(|c| c as u8).collect();
695                                stack.push(PickleValue::Bytes(bytes));
696                            } else {
697                                stack.push(PickleValue::None);
698                            }
699                        } else {
700                            stack.push(PickleValue::None);
701                        }
702                    } else {
703                        return Err(DecodeError::UnsupportedGlobal(s.clone()));
704                    }
705                } else {
706                    return Err(DecodeError::StackUnderflow);
707                }
708            }
709            other => {
710                return Err(DecodeError::UnknownOpcode(other));
711            }
712        }
713    }
714}
715
716fn bytes_to_long(bytes: &[u8]) -> i64 {
717    if bytes.is_empty() {
718        return 0;
719    }
720    let negative = bytes[bytes.len() - 1] & 0x80 != 0;
721    let mut result: i64 = 0;
722    for (i, &b) in bytes.iter().enumerate() {
723        result |= (b as i64) << (i * 8);
724    }
725    if negative {
726        // Sign-extend
727        let bits = bytes.len() * 8;
728        if bits < 64 {
729            result |= !0i64 << bits;
730        }
731    }
732    result
733}
734
735fn find_mark(stack: &[PickleValue]) -> Result<usize, DecodeError> {
736    for i in (0..stack.len()).rev() {
737        if let PickleValue::String(ref s) = stack[i] {
738            if s == "__mark__" {
739                return Ok(i);
740            }
741        }
742    }
743    Err(DecodeError::NoMarkFound)
744}
745
746#[cfg(test)]
747mod tests {
748    use super::*;
749
750    #[test]
751    fn roundtrip_none() {
752        let val = PickleValue::None;
753        let encoded = encode(&val);
754        let decoded = decode(&encoded).unwrap();
755        assert_eq!(decoded, val);
756    }
757
758    #[test]
759    fn roundtrip_bool_true() {
760        let val = PickleValue::Bool(true);
761        let encoded = encode(&val);
762        let decoded = decode(&encoded).unwrap();
763        assert_eq!(decoded, val);
764    }
765
766    #[test]
767    fn roundtrip_bool_false() {
768        let val = PickleValue::Bool(false);
769        let encoded = encode(&val);
770        let decoded = decode(&encoded).unwrap();
771        assert_eq!(decoded, val);
772    }
773
774    #[test]
775    fn roundtrip_int_small() {
776        let val = PickleValue::Int(42);
777        let encoded = encode(&val);
778        let decoded = decode(&encoded).unwrap();
779        assert_eq!(decoded, val);
780    }
781
782    #[test]
783    fn roundtrip_int_medium() {
784        let val = PickleValue::Int(1000);
785        let encoded = encode(&val);
786        let decoded = decode(&encoded).unwrap();
787        assert_eq!(decoded, val);
788    }
789
790    #[test]
791    fn roundtrip_int_large() {
792        let val = PickleValue::Int(100000);
793        let encoded = encode(&val);
794        let decoded = decode(&encoded).unwrap();
795        assert_eq!(decoded, val);
796    }
797
798    #[test]
799    fn roundtrip_int_negative() {
800        let val = PickleValue::Int(-42);
801        let encoded = encode(&val);
802        let decoded = decode(&encoded).unwrap();
803        assert_eq!(decoded, val);
804    }
805
806    #[test]
807    fn roundtrip_float() {
808        let val = PickleValue::Float(3.14159);
809        let encoded = encode(&val);
810        let decoded = decode(&encoded).unwrap();
811        assert_eq!(decoded, val);
812    }
813
814    #[test]
815    fn roundtrip_string_short() {
816        let val = PickleValue::String("hello".into());
817        let encoded = encode(&val);
818        let decoded = decode(&encoded).unwrap();
819        assert_eq!(decoded, val);
820    }
821
822    #[test]
823    fn roundtrip_string_long() {
824        let val = PickleValue::String("x".repeat(300));
825        let encoded = encode(&val);
826        let decoded = decode(&encoded).unwrap();
827        assert_eq!(decoded, val);
828    }
829
830    #[test]
831    fn roundtrip_bytes() {
832        let val = PickleValue::Bytes(vec![0, 1, 2, 3, 255]);
833        let encoded = encode(&val);
834        let decoded = decode(&encoded).unwrap();
835        assert_eq!(decoded, val);
836    }
837
838    #[test]
839    fn roundtrip_empty_list() {
840        let val = PickleValue::List(vec![]);
841        let encoded = encode(&val);
842        let decoded = decode(&encoded).unwrap();
843        assert_eq!(decoded, val);
844    }
845
846    #[test]
847    fn roundtrip_list() {
848        let val = PickleValue::List(vec![
849            PickleValue::Int(1),
850            PickleValue::String("two".into()),
851            PickleValue::Bool(true),
852        ]);
853        let encoded = encode(&val);
854        let decoded = decode(&encoded).unwrap();
855        assert_eq!(decoded, val);
856    }
857
858    #[test]
859    fn roundtrip_empty_dict() {
860        let val = PickleValue::Dict(vec![]);
861        let encoded = encode(&val);
862        let decoded = decode(&encoded).unwrap();
863        assert_eq!(decoded, val);
864    }
865
866    #[test]
867    fn roundtrip_dict() {
868        let val = PickleValue::Dict(vec![
869            (PickleValue::String("key".into()), PickleValue::Int(42)),
870            (PickleValue::String("flag".into()), PickleValue::Bool(false)),
871        ]);
872        let encoded = encode(&val);
873        let decoded = decode(&encoded).unwrap();
874        assert_eq!(decoded, val);
875    }
876
877    #[test]
878    fn roundtrip_nested() {
879        let val = PickleValue::Dict(vec![
880            (
881                PickleValue::String("list".into()),
882                PickleValue::List(vec![
883                    PickleValue::Int(1),
884                    PickleValue::Dict(vec![(
885                        PickleValue::String("inner".into()),
886                        PickleValue::None,
887                    )]),
888                ]),
889            ),
890            (
891                PickleValue::String("bytes".into()),
892                PickleValue::Bytes(vec![0xDE, 0xAD]),
893            ),
894        ]);
895        let encoded = encode(&val);
896        let decoded = decode(&encoded).unwrap();
897        assert_eq!(decoded, val);
898    }
899
900    #[test]
901    fn reject_unknown_opcode() {
902        // 0x80 0x02 = protocol 2, then 0xFF = unknown
903        let data = vec![0x80, 0x02, 0xFF];
904        assert!(decode(&data).is_err());
905    }
906
907    #[test]
908    fn dict_get_helper() {
909        let val = PickleValue::Dict(vec![(
910            PickleValue::String("get".into()),
911            PickleValue::String("interface_stats".into()),
912        )]);
913        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "interface_stats");
914        assert!(val.get("missing").is_none());
915    }
916
917    #[test]
918    fn roundtrip_int_zero() {
919        let val = PickleValue::Int(0);
920        let encoded = encode(&val);
921        let decoded = decode(&encoded).unwrap();
922        assert_eq!(decoded, val);
923    }
924
925    #[test]
926    fn roundtrip_int_255() {
927        let val = PickleValue::Int(255);
928        let encoded = encode(&val);
929        let decoded = decode(&encoded).unwrap();
930        assert_eq!(decoded, val);
931    }
932
933    #[test]
934    fn roundtrip_bytes_empty() {
935        let val = PickleValue::Bytes(vec![]);
936        let encoded = encode(&val);
937        let decoded = decode(&encoded).unwrap();
938        assert_eq!(decoded, val);
939    }
940
941    #[test]
942    fn roundtrip_large_int() {
943        let val = PickleValue::Int(i64::MAX);
944        let encoded = encode(&val);
945        let decoded = decode(&encoded).unwrap();
946        assert_eq!(decoded, val);
947    }
948
949    #[test]
950    fn roundtrip_negative_large_int() {
951        let val = PickleValue::Int(i64::MIN);
952        let encoded = encode(&val);
953        let decoded = decode(&encoded).unwrap();
954        assert_eq!(decoded, val);
955    }
956
957    #[test]
958    fn decode_python_dict() {
959        // Manually constructed protocol 2 pickle of {"get": "stats"}
960        // PROTO 2, EMPTY_DICT, MARK, SHORT_BINUNICODE 3 "get", SHORT_BINUNICODE 5 "stats", SETITEMS, STOP
961        let data = vec![
962            0x80, 0x02, // PROTO 2
963            b'}', // EMPTY_DICT
964            b'(', // MARK
965            0x8c, 3, b'g', b'e', b't', // SHORT_BINUNICODE "get"
966            0x8c, 5, b's', b't', b'a', b't', b's', // SHORT_BINUNICODE "stats"
967            b'u', // SETITEMS
968            b'.', // STOP
969        ];
970        let val = decode(&data).unwrap();
971        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "stats");
972    }
973
974    #[test]
975    fn decode_protocol4_dict() {
976        // Protocol 4 pickle of {"get": "interface_stats"} (from Python 3.8+)
977        // Generated by: pickle.dumps({"get": "interface_stats"})
978        let data = vec![
979            0x80, 0x04, // PROTO 4
980            0x95, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // FRAME (28 bytes)
981            b'}', // EMPTY_DICT
982            0x94, // MEMOIZE
983            0x8c, 0x03, b'g', b'e', b't', // SHORT_BINUNICODE "get"
984            0x94, // MEMOIZE
985            0x8c, 0x0f, // SHORT_BINUNICODE (15 bytes)
986            b'i', b'n', b't', b'e', b'r', b'f', b'a', b'c', b'e', b'_', b's', b't', b'a', b't',
987            b's', 0x94, // MEMOIZE
988            b's', // SETITEM
989            b'.', // STOP
990        ];
991        let val = decode(&data).unwrap();
992        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "interface_stats");
993    }
994
995    #[test]
996    fn decode_protocol4_with_bytes() {
997        // Protocol 4 pickle of {"drop": "path", "destination_hash": b"\x01\x02\x03"}
998        let data = vec![
999            0x80, 0x04, // PROTO 4
1000            0x95, 0x2c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // FRAME
1001            b'}', // EMPTY_DICT
1002            0x94, // MEMOIZE
1003            b'(', // MARK
1004            0x8c, 0x04, b'd', b'r', b'o', b'p', // SHORT_BINUNICODE "drop"
1005            0x94, // MEMOIZE
1006            0x8c, 0x04, b'p', b'a', b't', b'h', // SHORT_BINUNICODE "path"
1007            0x94, // MEMOIZE
1008            0x8c, 0x10, b'd', b'e', b's', b't', b'i', b'n', b'a', b't', b'i', b'o', b'n', b'_',
1009            b'h', b'a', b's', b'h', 0x94, // MEMOIZE
1010            b'C', 0x03, 0x01, 0x02, 0x03, // SHORT_BINBYTES 3 bytes
1011            0x94, // MEMOIZE
1012            b'u', // SETITEMS
1013            b'.', // STOP
1014        ];
1015        let val = decode(&data).unwrap();
1016        assert_eq!(val.get("drop").unwrap().as_str().unwrap(), "path");
1017        assert_eq!(
1018            val.get("destination_hash").unwrap().as_bytes().unwrap(),
1019            &[1, 2, 3]
1020        );
1021    }
1022}