Skip to main content

rns_net/
pickle.rs

//! Minimal pickle codec (protocols 2-5).
//!
//! Supports a subset of pickle opcodes sufficient for RPC serialization
//! compatible with Python's `multiprocessing.connection`.
//!
//! The encoder always emits a protocol 2 header, for maximum compatibility.
//! The decoder accepts a supported subset of the opcodes from protocols 2-5
//! (Python 3.8+ pickles with protocol 4 or 5 by default).
//!
//! Security: unknown opcodes are rejected and only a small whitelist of
//! globals (used to represent `bytes`) is resolved, so decoding cannot
//! trigger arbitrary code execution.
10
11use std::collections::HashMap;
12
// Pickle opcodes understood by this codec, grouped by the protocol that
// introduced them. Names follow CPython's `pickle.py` except where noted.

// Protocol 0/1
const MARK: u8 = b'(';
const STOP: u8 = b'.';
const NONE: u8 = b'N';
const APPEND: u8 = b'a';
const SETITEM: u8 = b's';
const GLOBAL: u8 = b'c';
const REDUCE: u8 = b'R';
const BININT1: u8 = b'K';
const BININT2: u8 = b'M';
const BININT4: u8 = b'J'; // CPython name: BININT (signed 4-byte int)
const BINFLOAT: u8 = b'G';
const BINUNICODE: u8 = b'X'; // UTF-8 text, 4-byte length
const BINSTRING: u8 = b'T'; // byte string, signed 4-byte length
const SHORT_BINSTRING: u8 = b'U'; // byte string, 1-byte length
const EMPTY_LIST: u8 = b']';
const EMPTY_DICT: u8 = b'}';
const EMPTY_TUPLE: u8 = b')';
const APPENDS: u8 = b'e';
const SETITEMS: u8 = b'u';
const BINPUT: u8 = b'q';
const LONG_BINPUT: u8 = b'r';
const BINGET: u8 = b'h';
const LONG_BINGET: u8 = b'j';

// Protocol 2
const PROTO: u8 = 0x80;
const TUPLE1: u8 = 0x85;
const TUPLE2: u8 = 0x86;
const TUPLE3: u8 = 0x87;
const NEWTRUE: u8 = 0x88;
const NEWFALSE: u8 = 0x89;
const LONG1: u8 = 0x8a;

// Protocol 3
const BINBYTES: u8 = b'B'; // bytes, 4-byte length
const SHORT_BINBYTES: u8 = b'C'; // bytes, 1-byte length

// Protocol 4
const SHORT_BINUNICODE: u8 = 0x8c; // UTF-8 text, 1-byte length
const BINUNICODE8: u8 = 0x8d; // UTF-8 text, 8-byte length
const SHORT_BINBYTES8: u8 = 0x8e; // CPython name: BINBYTES8 (bytes, 8-byte length)
const MEMOIZE: u8 = 0x94;
const FRAME: u8 = 0x95;

// Protocol 5
const BYTEARRAY8: u8 = 0x96; // bytearray, 8-byte length
53
/// A pickle value.
///
/// This is the data model shared by the encoder and decoder. Python tuples
/// decode as `List`, and `Dict` keeps its key/value pairs in stream order
/// without deduplicating keys.
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    /// Python `None`.
    None,
    /// Python `bool`.
    Bool(bool),
    /// Python `int`, restricted to the `i64` range.
    Int(i64),
    /// Python `float` (IEEE 754 double).
    Float(f64),
    /// Python `str`.
    String(String),
    /// Python `bytes` (and `bytearray` when decoding).
    Bytes(Vec<u8>),
    /// Python `list` (and `tuple` when decoding).
    List(Vec<PickleValue>),
    /// Python `dict` as ordered key/value pairs; keys need not be strings.
    Dict(Vec<(PickleValue, PickleValue)>),
}

impl PickleValue {
    /// Get as string reference if this is a `String` variant.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            PickleValue::String(s) => Some(s),
            _ => None,
        }
    }

    /// Get as `i64` if this is an `Int` variant.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            PickleValue::Int(n) => Some(*n),
            _ => None,
        }
    }

    /// Get as `f64` if this is a `Float` variant.
    pub fn as_float(&self) -> Option<f64> {
        match self {
            PickleValue::Float(f) => Some(*f),
            _ => None,
        }
    }

    /// Get as `bool` if this is a `Bool` variant.
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            PickleValue::Bool(b) => Some(*b),
            _ => None,
        }
    }

    /// Get as bytes reference if this is a `Bytes` variant.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            PickleValue::Bytes(b) => Some(b),
            _ => None,
        }
    }

    /// Get as list reference if this is a `List` variant.
    pub fn as_list(&self) -> Option<&[PickleValue]> {
        match self {
            PickleValue::List(l) => Some(l),
            _ => None,
        }
    }

    /// Get the key/value pairs if this is a `Dict` variant.
    pub fn as_dict(&self) -> Option<&[(PickleValue, PickleValue)]> {
        match self {
            PickleValue::Dict(pairs) => Some(pairs),
            _ => None,
        }
    }

    /// Look up a string key in a `Dict`.
    ///
    /// If the key appears more than once, the *last* pair wins. This matches
    /// Python, where repeated assignments to the same key overwrite earlier
    /// ones. Returns `None` for non-`Dict` values and missing keys.
    pub fn get(&self, key: &str) -> Option<&PickleValue> {
        self.as_dict()?
            .iter()
            .rev()
            .find(|(k, _)| k.as_str() == Some(key))
            .map(|(_, v)| v)
    }
}
133
134/// Encode a PickleValue as pickle protocol 2 bytes.
135pub fn encode(value: &PickleValue) -> Vec<u8> {
136    let mut buf = Vec::new();
137    buf.push(PROTO);
138    buf.push(2); // protocol 2
139    encode_value(&mut buf, value);
140    buf.push(STOP);
141    buf
142}
143
144fn encode_value(buf: &mut Vec<u8>, value: &PickleValue) {
145    match value {
146        PickleValue::None => buf.push(NONE),
147        PickleValue::Bool(true) => buf.push(NEWTRUE),
148        PickleValue::Bool(false) => buf.push(NEWFALSE),
149        PickleValue::Int(n) => encode_int(buf, *n),
150        PickleValue::Float(f) => {
151            buf.push(BINFLOAT);
152            buf.extend_from_slice(&f.to_be_bytes());
153        }
154        PickleValue::String(s) => {
155            let bytes = s.as_bytes();
156            if bytes.len() < 256 {
157                buf.push(SHORT_BINUNICODE);
158                buf.push(bytes.len() as u8);
159            } else {
160                buf.push(BINUNICODE);
161                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
162            }
163            buf.extend_from_slice(bytes);
164        }
165        PickleValue::Bytes(data) => {
166            // Protocol 2 encodes bytes via _codecs.encode trick:
167            // GLOBAL _codecs.encode, then two args, TUPLE2, REDUCE
168            // No MARK needed since TUPLE2 takes exactly 2 items from stack.
169            buf.extend_from_slice(b"c_codecs\nencode\n");
170            // Encode the bytes as a latin-1 unicode string
171            // Bytes 0x00-0x7F map to same UTF-8; 0x80-0xFF need 2-byte UTF-8
172            let mut latin1_utf8 = Vec::with_capacity(data.len() * 2);
173            for &b in data.iter() {
174                if b < 0x80 {
175                    latin1_utf8.push(b);
176                } else {
177                    // UTF-8 encode U+0080..U+00FF
178                    latin1_utf8.push(0xC0 | (b >> 6));
179                    latin1_utf8.push(0x80 | (b & 0x3F));
180                }
181            }
182            if latin1_utf8.len() < 256 {
183                buf.push(SHORT_BINUNICODE);
184                buf.push(latin1_utf8.len() as u8);
185            } else {
186                buf.push(BINUNICODE);
187                buf.extend_from_slice(&(latin1_utf8.len() as u32).to_le_bytes());
188            }
189            buf.extend_from_slice(&latin1_utf8);
190            // encoding name
191            buf.push(SHORT_BINUNICODE);
192            buf.push(7); // "latin-1"
193            buf.extend_from_slice(b"latin-1");
194            buf.push(TUPLE2);
195            buf.push(REDUCE);
196        }
197        PickleValue::List(items) => {
198            buf.push(EMPTY_LIST);
199            if !items.is_empty() {
200                buf.push(MARK);
201                for item in items {
202                    encode_value(buf, item);
203                }
204                buf.push(APPENDS);
205            }
206        }
207        PickleValue::Dict(pairs) => {
208            buf.push(EMPTY_DICT);
209            if !pairs.is_empty() {
210                buf.push(MARK);
211                for (k, v) in pairs {
212                    encode_value(buf, k);
213                    encode_value(buf, v);
214                }
215                buf.push(SETITEMS);
216            }
217        }
218    }
219}
220
221fn encode_int(buf: &mut Vec<u8>, n: i64) {
222    if n >= 0 && n < 256 {
223        buf.push(BININT1);
224        buf.push(n as u8);
225    } else if n >= 0 && n < 65536 {
226        buf.push(BININT2);
227        buf.extend_from_slice(&(n as u16).to_le_bytes());
228    } else if n >= i32::MIN as i64 && n <= i32::MAX as i64 {
229        buf.push(BININT4);
230        buf.extend_from_slice(&(n as i32).to_le_bytes());
231    } else {
232        // Use LONG1 for values that don't fit in i32
233        buf.push(LONG1);
234        let bytes = long_to_bytes(n);
235        buf.push(bytes.len() as u8);
236        buf.extend_from_slice(&bytes);
237    }
238}
239
/// Encode `n` as the minimal little-endian two's-complement payload used by
/// the `LONG1` opcode.
///
/// Zero encodes as an empty payload. Otherwise, redundant high-order sign
/// bytes are dropped. An explicit `0x00` or `0xFF` byte is appended when the
/// top retained byte's high bit would otherwise give the wrong sign.
fn long_to_bytes(n: i64) -> Vec<u8> {
    if n == 0 {
        return Vec::new();
    }
    let negative = n < 0;
    // High-order bytes equal to `pad` are pure sign extension.
    let pad: u8 = if negative { 0xFF } else { 0x00 };
    let le = n.to_le_bytes();
    // Keep everything up to the highest non-padding byte (at least one byte).
    let keep = le.iter().rposition(|&b| b != pad).map_or(1, |idx| idx + 1);
    let mut out = le[..keep].to_vec();
    // The last kept byte's top bit must agree with the sign of `n`.
    if (out[keep - 1] & 0x80 != 0) != negative {
        out.push(pad);
    }
    out
}
270
/// Reasons the pickle decoder can reject its input.
#[derive(Debug)]
pub enum DecodeError {
    /// The data ended inside an opcode's argument.
    UnexpectedEnd,
    /// An opcode outside the supported subset; carries the opcode byte.
    UnknownOpcode(u8),
    /// A text payload was not valid UTF-8.
    InvalidUtf8,
    /// An opcode needed a stack entry (or memo entry) that was not available.
    StackUnderflow,
    /// A mark-consuming opcode found no MARK.
    NoMarkFound,
    /// The input ended without a STOP opcode.
    NoStop,
    /// A global outside the whitelist was referenced; carries its name.
    UnsupportedGlobal(String),
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnexpectedEnd => f.write_str("unexpected end of pickle data"),
            Self::UnknownOpcode(code) => write!(f, "unknown pickle opcode: 0x{:02x}", code),
            Self::InvalidUtf8 => f.write_str("invalid UTF-8 in pickle string"),
            Self::StackUnderflow => f.write_str("stack underflow"),
            Self::NoMarkFound => f.write_str("no mark found on stack"),
            Self::NoStop => f.write_str("no STOP opcode found"),
            Self::UnsupportedGlobal(global) => write!(f, "unsupported global: {}", global),
        }
    }
}

impl std::error::Error for DecodeError {}
300
301/// Decode pickle protocol 2 bytes into a PickleValue.
302pub fn decode(data: &[u8]) -> Result<PickleValue, DecodeError> {
303    let mut stack: Vec<PickleValue> = Vec::new();
304    let mut memo: HashMap<u32, PickleValue> = HashMap::new();
305    let mut memo_counter: u32 = 0;
306    let mut pos = 0;
307
308    // Skip protocol header if present
309    if pos < data.len() && data[pos] == PROTO {
310        pos += 2; // skip PROTO + version byte
311    }
312
313    loop {
314        if pos >= data.len() {
315            return Err(DecodeError::NoStop);
316        }
317
318        let op = data[pos];
319        pos += 1;
320
321        match op {
322            STOP => {
323                return stack.pop().ok_or(DecodeError::StackUnderflow);
324            }
325            NONE => stack.push(PickleValue::None),
326            NEWTRUE => stack.push(PickleValue::Bool(true)),
327            NEWFALSE => stack.push(PickleValue::Bool(false)),
328            BININT1 => {
329                if pos >= data.len() {
330                    return Err(DecodeError::UnexpectedEnd);
331                }
332                stack.push(PickleValue::Int(data[pos] as i64));
333                pos += 1;
334            }
335            BININT2 => {
336                if pos + 2 > data.len() {
337                    return Err(DecodeError::UnexpectedEnd);
338                }
339                let val = u16::from_le_bytes([data[pos], data[pos + 1]]);
340                stack.push(PickleValue::Int(val as i64));
341                pos += 2;
342            }
343            BININT4 => {
344                if pos + 4 > data.len() {
345                    return Err(DecodeError::UnexpectedEnd);
346                }
347                let val = i32::from_le_bytes([
348                    data[pos],
349                    data[pos + 1],
350                    data[pos + 2],
351                    data[pos + 3],
352                ]);
353                stack.push(PickleValue::Int(val as i64));
354                pos += 4;
355            }
356            LONG1 => {
357                if pos >= data.len() {
358                    return Err(DecodeError::UnexpectedEnd);
359                }
360                let n = data[pos] as usize;
361                pos += 1;
362                if pos + n > data.len() {
363                    return Err(DecodeError::UnexpectedEnd);
364                }
365                let val = bytes_to_long(&data[pos..pos + n]);
366                stack.push(PickleValue::Int(val));
367                pos += n;
368            }
369            BINFLOAT => {
370                if pos + 8 > data.len() {
371                    return Err(DecodeError::UnexpectedEnd);
372                }
373                let val = f64::from_be_bytes([
374                    data[pos],
375                    data[pos + 1],
376                    data[pos + 2],
377                    data[pos + 3],
378                    data[pos + 4],
379                    data[pos + 5],
380                    data[pos + 6],
381                    data[pos + 7],
382                ]);
383                stack.push(PickleValue::Float(val));
384                pos += 8;
385            }
386            SHORT_BINUNICODE => {
387                if pos >= data.len() {
388                    return Err(DecodeError::UnexpectedEnd);
389                }
390                let len = data[pos] as usize;
391                pos += 1;
392                if pos + len > data.len() {
393                    return Err(DecodeError::UnexpectedEnd);
394                }
395                let s = std::str::from_utf8(&data[pos..pos + len])
396                    .map_err(|_| DecodeError::InvalidUtf8)?;
397                stack.push(PickleValue::String(s.to_string()));
398                pos += len;
399            }
400            BINUNICODE => {
401                if pos + 4 > data.len() {
402                    return Err(DecodeError::UnexpectedEnd);
403                }
404                let len = u32::from_le_bytes([
405                    data[pos],
406                    data[pos + 1],
407                    data[pos + 2],
408                    data[pos + 3],
409                ]) as usize;
410                pos += 4;
411                if pos + len > data.len() {
412                    return Err(DecodeError::UnexpectedEnd);
413                }
414                let s = std::str::from_utf8(&data[pos..pos + len])
415                    .map_err(|_| DecodeError::InvalidUtf8)?;
416                stack.push(PickleValue::String(s.to_string()));
417                pos += len;
418            }
419            SHORT_BINSTRING => {
420                // Protocol 0/1 short binary string (used as bytes in some pickles)
421                if pos >= data.len() {
422                    return Err(DecodeError::UnexpectedEnd);
423                }
424                let len = data[pos] as usize;
425                pos += 1;
426                if pos + len > data.len() {
427                    return Err(DecodeError::UnexpectedEnd);
428                }
429                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
430                pos += len;
431            }
432            BINSTRING => {
433                // Protocol 0/1 binary string
434                if pos + 4 > data.len() {
435                    return Err(DecodeError::UnexpectedEnd);
436                }
437                let len = i32::from_le_bytes([
438                    data[pos],
439                    data[pos + 1],
440                    data[pos + 2],
441                    data[pos + 3],
442                ]) as usize;
443                pos += 4;
444                if pos + len > data.len() {
445                    return Err(DecodeError::UnexpectedEnd);
446                }
447                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
448                pos += len;
449            }
450            SHORT_BINBYTES => {
451                // SHORT_BINBYTES is actually opcode 'B' = 0x42
452                // Wait, Python pickle docs say:
453                // SHORT_BINBYTES = b'B' (no, that's wrong)
454                // Actually: BINBYTES = b'B', SHORT_BINBYTES = b'C'
455                // Let me just handle both...
456                if pos >= data.len() {
457                    return Err(DecodeError::UnexpectedEnd);
458                }
459                let len = data[pos] as usize;
460                pos += 1;
461                if pos + len > data.len() {
462                    return Err(DecodeError::UnexpectedEnd);
463                }
464                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
465                pos += len;
466            }
467            BINBYTES => {
468                if pos + 4 > data.len() {
469                    return Err(DecodeError::UnexpectedEnd);
470                }
471                let len = u32::from_le_bytes([
472                    data[pos],
473                    data[pos + 1],
474                    data[pos + 2],
475                    data[pos + 3],
476                ]) as usize;
477                pos += 4;
478                if pos + len > data.len() {
479                    return Err(DecodeError::UnexpectedEnd);
480                }
481                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
482                pos += len;
483            }
484            EMPTY_LIST => stack.push(PickleValue::List(Vec::new())),
485            EMPTY_DICT => stack.push(PickleValue::Dict(Vec::new())),
486            EMPTY_TUPLE => stack.push(PickleValue::List(Vec::new())), // treat tuple as list
487            MARK => stack.push(PickleValue::String("__mark__".into())),
488            FRAME => {
489                // Protocol 4: 8-byte frame length prefix. We just skip it
490                // since we already have the full data.
491                if pos + 8 > data.len() {
492                    return Err(DecodeError::UnexpectedEnd);
493                }
494                pos += 8;
495            }
496            MEMOIZE => {
497                // Protocol 4: implicit memo, auto-assigns next index
498                if let Some(val) = stack.last() {
499                    memo.insert(memo_counter, val.clone());
500                }
501                memo_counter += 1;
502            }
503            APPEND => {
504                // Pop item, append to list on top of stack
505                let item = stack.pop().ok_or(DecodeError::StackUnderflow)?;
506                if let Some(PickleValue::List(ref mut list)) = stack.last_mut() {
507                    list.push(item);
508                } else {
509                    return Err(DecodeError::StackUnderflow);
510                }
511            }
512            APPENDS => {
513                // Pop items until mark, then append them to the list before the mark
514                let mark_pos = find_mark(&stack)?;
515                let items: Vec<PickleValue> = stack.drain(mark_pos + 1..).collect();
516                stack.pop(); // remove mark
517                if let Some(PickleValue::List(ref mut list)) = stack.last_mut() {
518                    list.extend(items);
519                } else {
520                    return Err(DecodeError::StackUnderflow);
521                }
522            }
523            SETITEM => {
524                // Pop value, pop key, set on dict at top of stack
525                let value = stack.pop().ok_or(DecodeError::StackUnderflow)?;
526                let key = stack.pop().ok_or(DecodeError::StackUnderflow)?;
527                if let Some(PickleValue::Dict(ref mut dict)) = stack.last_mut() {
528                    dict.push((key, value));
529                } else {
530                    return Err(DecodeError::StackUnderflow);
531                }
532            }
533            SETITEMS => {
534                // Pop key-value pairs until mark, then set them on the dict before the mark
535                let mark_pos = find_mark(&stack)?;
536                let items: Vec<PickleValue> = stack.drain(mark_pos + 1..).collect();
537                stack.pop(); // remove mark
538                if let Some(PickleValue::Dict(ref mut dict)) = stack.last_mut() {
539                    for pair in items.chunks_exact(2) {
540                        dict.push((pair[0].clone(), pair[1].clone()));
541                    }
542                } else {
543                    return Err(DecodeError::StackUnderflow);
544                }
545            }
546            TUPLE1 => {
547                let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
548                stack.push(PickleValue::List(vec![a]));
549            }
550            TUPLE2 => {
551                let b = stack.pop().ok_or(DecodeError::StackUnderflow)?;
552                let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
553                stack.push(PickleValue::List(vec![a, b]));
554            }
555            TUPLE3 => {
556                let c = stack.pop().ok_or(DecodeError::StackUnderflow)?;
557                let b = stack.pop().ok_or(DecodeError::StackUnderflow)?;
558                let a = stack.pop().ok_or(DecodeError::StackUnderflow)?;
559                stack.push(PickleValue::List(vec![a, b, c]));
560            }
561            SHORT_BINBYTES8 => {
562                // Protocol 4: 8-byte length bytes
563                if pos + 8 > data.len() {
564                    return Err(DecodeError::UnexpectedEnd);
565                }
566                let len = u64::from_le_bytes([
567                    data[pos], data[pos+1], data[pos+2], data[pos+3],
568                    data[pos+4], data[pos+5], data[pos+6], data[pos+7],
569                ]) as usize;
570                pos += 8;
571                if pos + len > data.len() {
572                    return Err(DecodeError::UnexpectedEnd);
573                }
574                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
575                pos += len;
576            }
577            BINUNICODE8 => {
578                // Protocol 4: 8-byte length unicode
579                if pos + 8 > data.len() {
580                    return Err(DecodeError::UnexpectedEnd);
581                }
582                let len = u64::from_le_bytes([
583                    data[pos], data[pos+1], data[pos+2], data[pos+3],
584                    data[pos+4], data[pos+5], data[pos+6], data[pos+7],
585                ]) as usize;
586                pos += 8;
587                if pos + len > data.len() {
588                    return Err(DecodeError::UnexpectedEnd);
589                }
590                let s = std::str::from_utf8(&data[pos..pos + len])
591                    .map_err(|_| DecodeError::InvalidUtf8)?;
592                stack.push(PickleValue::String(s.to_string()));
593                pos += len;
594            }
595            BYTEARRAY8 => {
596                // Protocol 5: 8-byte length bytearray (treat as bytes)
597                if pos + 8 > data.len() {
598                    return Err(DecodeError::UnexpectedEnd);
599                }
600                let len = u64::from_le_bytes([
601                    data[pos], data[pos+1], data[pos+2], data[pos+3],
602                    data[pos+4], data[pos+5], data[pos+6], data[pos+7],
603                ]) as usize;
604                pos += 8;
605                if pos + len > data.len() {
606                    return Err(DecodeError::UnexpectedEnd);
607                }
608                stack.push(PickleValue::Bytes(data[pos..pos + len].to_vec()));
609                pos += len;
610            }
611            BINPUT => {
612                if pos >= data.len() {
613                    return Err(DecodeError::UnexpectedEnd);
614                }
615                let idx = data[pos] as u32;
616                pos += 1;
617                if let Some(val) = stack.last() {
618                    memo.insert(idx, val.clone());
619                }
620            }
621            LONG_BINPUT => {
622                if pos + 4 > data.len() {
623                    return Err(DecodeError::UnexpectedEnd);
624                }
625                let idx = u32::from_le_bytes([
626                    data[pos],
627                    data[pos + 1],
628                    data[pos + 2],
629                    data[pos + 3],
630                ]);
631                pos += 4;
632                if let Some(val) = stack.last() {
633                    memo.insert(idx, val.clone());
634                }
635            }
636            BINGET => {
637                if pos >= data.len() {
638                    return Err(DecodeError::UnexpectedEnd);
639                }
640                let idx = data[pos] as u32;
641                pos += 1;
642                let val = memo
643                    .get(&idx)
644                    .cloned()
645                    .ok_or(DecodeError::StackUnderflow)?;
646                stack.push(val);
647            }
648            LONG_BINGET => {
649                if pos + 4 > data.len() {
650                    return Err(DecodeError::UnexpectedEnd);
651                }
652                let idx = u32::from_le_bytes([
653                    data[pos],
654                    data[pos + 1],
655                    data[pos + 2],
656                    data[pos + 3],
657                ]);
658                pos += 4;
659                let val = memo
660                    .get(&idx)
661                    .cloned()
662                    .ok_or(DecodeError::StackUnderflow)?;
663                stack.push(val);
664            }
665            GLOBAL => {
666                // Read module\nname\n
667                let nl1 = data[pos..]
668                    .iter()
669                    .position(|&b| b == b'\n')
670                    .ok_or(DecodeError::UnexpectedEnd)?;
671                let module =
672                    std::str::from_utf8(&data[pos..pos + nl1]).map_err(|_| DecodeError::InvalidUtf8)?;
673                pos += nl1 + 1;
674                let nl2 = data[pos..]
675                    .iter()
676                    .position(|&b| b == b'\n')
677                    .ok_or(DecodeError::UnexpectedEnd)?;
678                let name =
679                    std::str::from_utf8(&data[pos..pos + nl2]).map_err(|_| DecodeError::InvalidUtf8)?;
680                pos += nl2 + 1;
681
682                // Only allow _codecs.encode (for bytes encoding)
683                if module == "_codecs" && name == "encode" {
684                    stack.push(PickleValue::String("__codecs_encode__".into()));
685                } else {
686                    return Err(DecodeError::UnsupportedGlobal(format!(
687                        "{}.{}",
688                        module, name
689                    )));
690                }
691            }
692            REDUCE => {
693                // Pop args tuple and callable, apply
694                let args = stack.pop().ok_or(DecodeError::StackUnderflow)?;
695                let callable = stack.pop().ok_or(DecodeError::StackUnderflow)?;
696
697                if let PickleValue::String(ref s) = callable {
698                    if s == "__codecs_encode__" {
699                        // args should be a tuple (string, encoding)
700                        if let PickleValue::List(ref items) = args {
701                            if let Some(PickleValue::String(ref text)) = items.first() {
702                                // Convert latin-1 string back to bytes
703                                let bytes: Vec<u8> =
704                                    text.chars().map(|c| c as u8).collect();
705                                stack.push(PickleValue::Bytes(bytes));
706                            } else {
707                                stack.push(PickleValue::None);
708                            }
709                        } else {
710                            stack.push(PickleValue::None);
711                        }
712                    } else {
713                        return Err(DecodeError::UnsupportedGlobal(s.clone()));
714                    }
715                } else {
716                    return Err(DecodeError::StackUnderflow);
717                }
718            }
719            other => {
720                return Err(DecodeError::UnknownOpcode(other));
721            }
722        }
723    }
724}
725
/// Decode a `LONG1` payload: a little-endian two's-complement integer of
/// arbitrary width. An empty payload means 0.
///
/// Payloads of up to 8 bytes are sign-extended exactly. Longer payloads come
/// straight from untrusted input and are handled without panicking. If the
/// extra high-order bytes are pure sign extension, the exact value is
/// returned. Otherwise the value does not fit in an `i64`, and it is clamped
/// to `i64::MAX` or `i64::MIN` according to its sign.
fn bytes_to_long(bytes: &[u8]) -> i64 {
    let last = match bytes.last() {
        Some(&b) => b,
        None => return 0,
    };
    let negative = last & 0x80 != 0;
    let fill: u8 = if negative { 0xFF } else { 0x00 };

    let (low, high) = bytes.split_at(bytes.len().min(8));
    // Starting from sign bytes sign-extends short payloads for free.
    let mut word = [fill; 8];
    word[..low.len()].copy_from_slice(low);
    let value = i64::from_le_bytes(word);

    // Beyond 8 bytes, the value fits only if every extra byte is sign
    // extension and bit 63 of the low word already carries the sign.
    let fits = high.iter().all(|&b| b == fill) && (value < 0) == negative;
    if fits {
        value
    } else if negative {
        i64::MIN
    } else {
        i64::MAX
    }
}
744
745fn find_mark(stack: &[PickleValue]) -> Result<usize, DecodeError> {
746    for i in (0..stack.len()).rev() {
747        if let PickleValue::String(ref s) = stack[i] {
748            if s == "__mark__" {
749                return Ok(i);
750            }
751        }
752    }
753    Err(DecodeError::NoMarkFound)
754}
755
756#[cfg(test)]
757mod tests {
758    use super::*;
759
760    #[test]
761    fn roundtrip_none() {
762        let val = PickleValue::None;
763        let encoded = encode(&val);
764        let decoded = decode(&encoded).unwrap();
765        assert_eq!(decoded, val);
766    }
767
768    #[test]
769    fn roundtrip_bool_true() {
770        let val = PickleValue::Bool(true);
771        let encoded = encode(&val);
772        let decoded = decode(&encoded).unwrap();
773        assert_eq!(decoded, val);
774    }
775
776    #[test]
777    fn roundtrip_bool_false() {
778        let val = PickleValue::Bool(false);
779        let encoded = encode(&val);
780        let decoded = decode(&encoded).unwrap();
781        assert_eq!(decoded, val);
782    }
783
784    #[test]
785    fn roundtrip_int_small() {
786        let val = PickleValue::Int(42);
787        let encoded = encode(&val);
788        let decoded = decode(&encoded).unwrap();
789        assert_eq!(decoded, val);
790    }
791
792    #[test]
793    fn roundtrip_int_medium() {
794        let val = PickleValue::Int(1000);
795        let encoded = encode(&val);
796        let decoded = decode(&encoded).unwrap();
797        assert_eq!(decoded, val);
798    }
799
800    #[test]
801    fn roundtrip_int_large() {
802        let val = PickleValue::Int(100000);
803        let encoded = encode(&val);
804        let decoded = decode(&encoded).unwrap();
805        assert_eq!(decoded, val);
806    }
807
808    #[test]
809    fn roundtrip_int_negative() {
810        let val = PickleValue::Int(-42);
811        let encoded = encode(&val);
812        let decoded = decode(&encoded).unwrap();
813        assert_eq!(decoded, val);
814    }
815
816    #[test]
817    fn roundtrip_float() {
818        let val = PickleValue::Float(3.14159);
819        let encoded = encode(&val);
820        let decoded = decode(&encoded).unwrap();
821        assert_eq!(decoded, val);
822    }
823
824    #[test]
825    fn roundtrip_string_short() {
826        let val = PickleValue::String("hello".into());
827        let encoded = encode(&val);
828        let decoded = decode(&encoded).unwrap();
829        assert_eq!(decoded, val);
830    }
831
832    #[test]
833    fn roundtrip_string_long() {
834        let val = PickleValue::String("x".repeat(300));
835        let encoded = encode(&val);
836        let decoded = decode(&encoded).unwrap();
837        assert_eq!(decoded, val);
838    }
839
840    #[test]
841    fn roundtrip_bytes() {
842        let val = PickleValue::Bytes(vec![0, 1, 2, 3, 255]);
843        let encoded = encode(&val);
844        let decoded = decode(&encoded).unwrap();
845        assert_eq!(decoded, val);
846    }
847
848    #[test]
849    fn roundtrip_empty_list() {
850        let val = PickleValue::List(vec![]);
851        let encoded = encode(&val);
852        let decoded = decode(&encoded).unwrap();
853        assert_eq!(decoded, val);
854    }
855
856    #[test]
857    fn roundtrip_list() {
858        let val = PickleValue::List(vec![
859            PickleValue::Int(1),
860            PickleValue::String("two".into()),
861            PickleValue::Bool(true),
862        ]);
863        let encoded = encode(&val);
864        let decoded = decode(&encoded).unwrap();
865        assert_eq!(decoded, val);
866    }
867
868    #[test]
869    fn roundtrip_empty_dict() {
870        let val = PickleValue::Dict(vec![]);
871        let encoded = encode(&val);
872        let decoded = decode(&encoded).unwrap();
873        assert_eq!(decoded, val);
874    }
875
876    #[test]
877    fn roundtrip_dict() {
878        let val = PickleValue::Dict(vec![
879            (
880                PickleValue::String("key".into()),
881                PickleValue::Int(42),
882            ),
883            (
884                PickleValue::String("flag".into()),
885                PickleValue::Bool(false),
886            ),
887        ]);
888        let encoded = encode(&val);
889        let decoded = decode(&encoded).unwrap();
890        assert_eq!(decoded, val);
891    }
892
893    #[test]
894    fn roundtrip_nested() {
895        let val = PickleValue::Dict(vec![
896            (
897                PickleValue::String("list".into()),
898                PickleValue::List(vec![
899                    PickleValue::Int(1),
900                    PickleValue::Dict(vec![(
901                        PickleValue::String("inner".into()),
902                        PickleValue::None,
903                    )]),
904                ]),
905            ),
906            (
907                PickleValue::String("bytes".into()),
908                PickleValue::Bytes(vec![0xDE, 0xAD]),
909            ),
910        ]);
911        let encoded = encode(&val);
912        let decoded = decode(&encoded).unwrap();
913        assert_eq!(decoded, val);
914    }
915
916    #[test]
917    fn reject_unknown_opcode() {
918        // 0x80 0x02 = protocol 2, then 0xFF = unknown
919        let data = vec![0x80, 0x02, 0xFF];
920        assert!(decode(&data).is_err());
921    }
922
923    #[test]
924    fn dict_get_helper() {
925        let val = PickleValue::Dict(vec![
926            (PickleValue::String("get".into()), PickleValue::String("interface_stats".into())),
927        ]);
928        assert_eq!(
929            val.get("get").unwrap().as_str().unwrap(),
930            "interface_stats"
931        );
932        assert!(val.get("missing").is_none());
933    }
934
935    #[test]
936    fn roundtrip_int_zero() {
937        let val = PickleValue::Int(0);
938        let encoded = encode(&val);
939        let decoded = decode(&encoded).unwrap();
940        assert_eq!(decoded, val);
941    }
942
943    #[test]
944    fn roundtrip_int_255() {
945        let val = PickleValue::Int(255);
946        let encoded = encode(&val);
947        let decoded = decode(&encoded).unwrap();
948        assert_eq!(decoded, val);
949    }
950
951    #[test]
952    fn roundtrip_bytes_empty() {
953        let val = PickleValue::Bytes(vec![]);
954        let encoded = encode(&val);
955        let decoded = decode(&encoded).unwrap();
956        assert_eq!(decoded, val);
957    }
958
959    #[test]
960    fn roundtrip_large_int() {
961        let val = PickleValue::Int(i64::MAX);
962        let encoded = encode(&val);
963        let decoded = decode(&encoded).unwrap();
964        assert_eq!(decoded, val);
965    }
966
967    #[test]
968    fn roundtrip_negative_large_int() {
969        let val = PickleValue::Int(i64::MIN);
970        let encoded = encode(&val);
971        let decoded = decode(&encoded).unwrap();
972        assert_eq!(decoded, val);
973    }
974
975    #[test]
976    fn decode_python_dict() {
977        // Manually constructed protocol 2 pickle of {"get": "stats"}
978        // PROTO 2, EMPTY_DICT, MARK, SHORT_BINUNICODE 3 "get", SHORT_BINUNICODE 5 "stats", SETITEMS, STOP
979        let data = vec![
980            0x80, 0x02, // PROTO 2
981            b'}',       // EMPTY_DICT
982            b'(',       // MARK
983            0x8c, 3, b'g', b'e', b't', // SHORT_BINUNICODE "get"
984            0x8c, 5, b's', b't', b'a', b't', b's', // SHORT_BINUNICODE "stats"
985            b'u',       // SETITEMS
986            b'.',       // STOP
987        ];
988        let val = decode(&data).unwrap();
989        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "stats");
990    }
991
992    #[test]
993    fn decode_protocol4_dict() {
994        // Protocol 4 pickle of {"get": "interface_stats"} (from Python 3.8+)
995        // Generated by: pickle.dumps({"get": "interface_stats"})
996        let data = vec![
997            0x80, 0x04,  // PROTO 4
998            0x95, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // FRAME (28 bytes)
999            b'}',        // EMPTY_DICT
1000            0x94,        // MEMOIZE
1001            0x8c, 0x03, b'g', b'e', b't', // SHORT_BINUNICODE "get"
1002            0x94,        // MEMOIZE
1003            0x8c, 0x0f,  // SHORT_BINUNICODE (15 bytes)
1004            b'i', b'n', b't', b'e', b'r', b'f', b'a', b'c', b'e',
1005            b'_', b's', b't', b'a', b't', b's',
1006            0x94,        // MEMOIZE
1007            b's',        // SETITEM
1008            b'.',        // STOP
1009        ];
1010        let val = decode(&data).unwrap();
1011        assert_eq!(val.get("get").unwrap().as_str().unwrap(), "interface_stats");
1012    }
1013
1014    #[test]
1015    fn decode_protocol4_with_bytes() {
1016        // Protocol 4 pickle of {"drop": "path", "destination_hash": b"\x01\x02\x03"}
1017        let data = vec![
1018            0x80, 0x04,  // PROTO 4
1019            0x95, 0x2c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // FRAME
1020            b'}',        // EMPTY_DICT
1021            0x94,        // MEMOIZE
1022            b'(',        // MARK
1023            0x8c, 0x04, b'd', b'r', b'o', b'p', // SHORT_BINUNICODE "drop"
1024            0x94,        // MEMOIZE
1025            0x8c, 0x04, b'p', b'a', b't', b'h', // SHORT_BINUNICODE "path"
1026            0x94,        // MEMOIZE
1027            0x8c, 0x10, b'd', b'e', b's', b't', b'i', b'n', b'a', b't',
1028            b'i', b'o', b'n', b'_', b'h', b'a', b's', b'h',
1029            0x94,        // MEMOIZE
1030            b'C', 0x03, 0x01, 0x02, 0x03, // SHORT_BINBYTES 3 bytes
1031            0x94,        // MEMOIZE
1032            b'u',        // SETITEMS
1033            b'.',        // STOP
1034        ];
1035        let val = decode(&data).unwrap();
1036        assert_eq!(val.get("drop").unwrap().as_str().unwrap(), "path");
1037        assert_eq!(val.get("destination_hash").unwrap().as_bytes().unwrap(), &[1, 2, 3]);
1038    }
1039}