redpanda_transform_sdk_varint/
lib.rs

1// Copyright 2023 Redpanda Data, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::fmt;
16use std::error::Error;
17
/// Errors that can occur while decoding a varint or a varint-prefixed buffer.
#[derive(Debug, PartialEq, Eq)]
pub enum DecodeError {
    /// The encoded varint is too long to be a valid 64-bit value.
    Overflow,
    /// The payload ended before the terminating byte of the varint was seen.
    ShortRead,
    /// A sized buffer announced more bytes than the payload still holds.
    ShortReadBuffer {
        /// Number of bytes the length prefix asked for.
        buf_size: usize,
        /// Number of bytes that were actually left after the length prefix.
        payload_remaining: usize,
    },
}

impl Error for DecodeError {}

impl fmt::Display for DecodeError {
    /// Renders a human readable description of the decoding failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            // The constant messages need no formatting machinery.
            Self::Overflow => f.write_str("decoded varint would overflow i64::MAX"),
            Self::ShortRead => f.write_str("short read when decoding varint"),
            Self::ShortReadBuffer {
                buf_size,
                payload_remaining,
            } => write!(
                f,
                "decoded sized buffer required size: {}, but only {} was remaining in buffer",
                buf_size, payload_remaining
            ),
        }
    }
}
46
47pub type Result<T> = std::result::Result<T, DecodeError>;
48
/// A successfully decoded value together with the number of bytes consumed.
///
/// `Debug` is derived so that values can be used with `assert_eq!` and so that
/// `Result<Decoded<T>, _>::unwrap_err()` is available to callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Decoded<T> {
    /// The decoded value.
    pub value: T,
    /// How many bytes of the input were consumed to produce `value`.
    pub read: usize,
}

impl<T> Decoded<T> {
    /// Transforms the decoded value with `op`, keeping the `read` byte count
    /// unchanged.
    pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> Decoded<U> {
        Decoded {
            value: op(self.value),
            read: self.read,
        }
    }
}
63
/// The maximum encoded size of an `i64`, in bytes.
///
/// Each encoded byte carries 7 bits of payload, so 64 bits need
/// `ceil(64 / 7) = 10` bytes.
pub const MAX_LENGTH: usize = (64 + 6) / 7;
66
/// Maps a signed integer onto an unsigned one so that values of small
/// magnitude (positive or negative) get small encodings:
/// `0 -> 0`, `-1 -> 1`, `1 -> 2`, `-2 -> 3`, ...
fn zigzag_encode(x: i64) -> u64 {
    // Arithmetic shift: all ones for negative inputs, all zeros otherwise.
    let sign_mask = x >> 63;
    let doubled = x << 1;
    (doubled ^ sign_mask) as u64
}
70
/// Inverse of the zigzag mapping: even inputs decode to non-negative values,
/// odd inputs decode to negative values.
fn zigzag_decode(x: u64) -> i64 {
    let magnitude = (x >> 1) as i64;
    if x & 1 == 0 {
        magnitude
    } else {
        // Flipping every bit is the same as xor-ing with -1 (all ones).
        !magnitude
    }
}
74
75fn read_unsigned(payload: &[u8]) -> Result<Decoded<u64>> {
76    let mut decoded: u64 = 0;
77    let mut shift = 0;
78    for (i, b) in payload.iter().enumerate() {
79        if i >= MAX_LENGTH {
80            return Err(DecodeError::Overflow);
81        }
82        decoded |= ((b & 0x7F) as u64) << shift;
83        if b & 0x80 == 0 {
84            return Ok(Decoded {
85                value: decoded,
86                read: i + 1,
87            });
88        }
89        shift += 7;
90    }
91    Err(DecodeError::ShortRead)
92}
93
94pub fn read(payload: &[u8]) -> Result<Decoded<i64>> {
95    read_unsigned(payload).map(|r| r.map(zigzag_decode))
96}
97
98pub fn read_sized_buffer(payload: &[u8]) -> Result<Decoded<Option<&[u8]>>> {
99    let result = read(payload)?;
100    if result.value < 0 {
101        return Ok(result.map(|_| None));
102    }
103    let payload = &payload[result.read..];
104    let buf_size = result.value as usize;
105    if buf_size > payload.len() {
106        return Err(DecodeError::ShortReadBuffer {
107            buf_size,
108            payload_remaining: payload.len(),
109        });
110    }
111    Ok(Decoded {
112        value: Some(&payload[..buf_size]),
113        read: result.read + buf_size,
114    })
115}
116
/// Appends `v` to `payload` as an unsigned base-128 varint: 7 bits per byte,
/// least significant group first, high bit set on every byte but the last.
fn write_unsigned(payload: &mut Vec<u8>, mut v: u64) {
    loop {
        let low_bits = (v & 0x7F) as u8;
        v >>= 7;
        if v == 0 {
            // Final group: leave the continuation bit clear.
            payload.push(low_bits);
            return;
        }
        payload.push(low_bits | 0x80);
    }
}
125
126pub fn write(payload: &mut Vec<u8>, v: i64) {
127    write_unsigned(payload, zigzag_encode(v))
128}
129
130pub fn write_sized_buffer(payload: &mut Vec<u8>, buf: Option<&[u8]>) {
131    match buf {
132        Some(b) => {
133            write(payload, b.len() as i64);
134            payload.extend_from_slice(b);
135        }
136        None => write(payload, -1),
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::{
        read, read_sized_buffer, read_unsigned, write, write_sized_buffer, write_unsigned,
        zigzag_decode, zigzag_encode,
    };

    use quickcheck::quickcheck;

    quickcheck! {
        // Zigzag decoding must invert zigzag encoding for every i64.
        fn zigzag(n: i64) -> bool {
            zigzag_decode(zigzag_encode(n)) == n
        }
    }
    quickcheck! {
        // Writing then reading an unsigned varint yields the input and
        // consumes exactly the bytes that were written.
        fn roundtrip_unsigned(n: u64) -> bool {
            let mut encoded = Vec::new();
            write_unsigned(&mut encoded, n);
            let decoded = read_unsigned(&encoded).expect("valid buffer");
            assert!(
                decoded.read == encoded.len(),
                "expected to consume the whole buffer: {read} != {remaining}",
                read = decoded.read,
                remaining = encoded.len()
            );
            decoded.value == n
        }
    }
    quickcheck! {
        // Same round trip for the signed (zigzag) encoding.
        fn roundtrip_signed(n: i64) -> bool {
            let mut encoded = Vec::new();
            write(&mut encoded, n);
            let decoded = read(&encoded).expect("valid buffer");
            assert!(
                decoded.read == encoded.len(),
                "expected to consume the whole buffer: {read} != {remaining}",
                read = decoded.read,
                remaining = encoded.len()
            );
            decoded.value == n
        }
    }
    quickcheck! {
        // Optional buffers survive a write/read round trip, including `None`.
        fn roundtrip_buffer(input: Option<Vec<u8>>) -> bool {
            let mut encoded = Vec::new();
            write_sized_buffer(&mut encoded, input.as_deref());
            let decoded = read_sized_buffer(&encoded).expect("valid buffer");
            assert!(
                decoded.read == encoded.len(),
                "expected to consume the whole buffer: {read} != {remaining}",
                read = decoded.read,
                remaining = encoded.len()
            );
            decoded.value == input.as_deref()
        }
    }
}