Skip to main content

reddb_wire/redwire/
cursor.rs

1//! RedWire legacy cursor payload codec.
2
3use crate::legacy::{encode_column_name, encode_value, WireValue};
4use std::fmt;
5
/// Decoded form of a RedWire legacy declare-cursor payload.
///
/// Wire layout: `cursor_id` (u32, little-endian), the SQL byte length
/// (u32, little-endian), then the SQL text as UTF-8.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeclareCursorPayload {
    /// Identifier of the cursor being declared.
    pub cursor_id: u32,
    /// SQL statement the cursor is declared for.
    pub sql: String,
}
11
/// Decoded form of a RedWire legacy fetch payload: two little-endian `u32`s.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FetchPayload {
    /// Cursor the rows are requested from.
    pub cursor_id: u32,
    /// Requested row limit (the codec carries it verbatim and does not enforce it).
    pub max_rows: u32,
}
17
/// Decoded form of a RedWire legacy close-cursor payload: one little-endian `u32`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CloseCursorPayload {
    /// Cursor to be closed.
    pub cursor_id: u32,
}
22
/// Failures raised while encoding or decoding RedWire legacy cursor payloads.
///
/// `Truncated*` variants name the field that could not be read in full.
/// `SqlTooLarge` and the `*Overflow` variants come from encoders when a length
/// does not fit the width of its on-wire counter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CursorPayloadError {
    /// The declare payload ended before its 4-byte `cursor_id`.
    TruncatedDeclareCursorId,
    /// The declare payload ended before its 4-byte SQL length.
    TruncatedDeclareSqlLen,
    /// The declare payload holds fewer SQL bytes than its length prefix announced.
    TruncatedDeclareSql,
    /// The declared SQL bytes are not valid UTF-8.
    InvalidDeclareSql,
    /// The fetch payload ended before its 4-byte `cursor_id`.
    TruncatedFetchCursorId,
    /// The fetch payload ended before its 4-byte `max_rows`.
    TruncatedFetchMaxRows,
    /// The close payload ended before its 4-byte `cursor_id`.
    TruncatedCloseCursorId,
    /// SQL text longer than `u32::MAX` bytes cannot be length-prefixed.
    SqlTooLarge,
    /// More than `u16::MAX` columns cannot be counted in the cursor-ok header.
    ColumnCountOverflow,
    /// More than `u32::MAX` rows cannot be counted in a batch header.
    RowCountOverflow,
}

impl fmt::Display for CursorPayloadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The test suite pins these as legacy messages; keep them byte-stable.
        let message = match self {
            Self::TruncatedDeclareCursorId => "truncated declare cursor_id",
            Self::TruncatedDeclareSqlLen => "truncated declare sql_len",
            Self::TruncatedDeclareSql => "truncated declare sql",
            Self::InvalidDeclareSql => "invalid UTF-8 in declare sql",
            Self::TruncatedFetchCursorId => "truncated fetch cursor_id",
            Self::TruncatedFetchMaxRows => "truncated fetch max_rows",
            Self::TruncatedCloseCursorId => "truncated close cursor_id",
            Self::SqlTooLarge => "declare sql is too large for RedWire cursor payload",
            Self::ColumnCountOverflow => "column count is too large for RedWire cursor payload",
            Self::RowCountOverflow => "row count is too large for RedWire cursor payload",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CursorPayloadError {}
59
60pub fn encode_declare_cursor_payload(
61    cursor_id: u32,
62    sql: &str,
63) -> Result<Vec<u8>, CursorPayloadError> {
64    let sql_len = u32::try_from(sql.len()).map_err(|_| CursorPayloadError::SqlTooLarge)?;
65    let mut out = Vec::with_capacity(8 + sql.len());
66    out.extend_from_slice(&cursor_id.to_le_bytes());
67    out.extend_from_slice(&sql_len.to_le_bytes());
68    out.extend_from_slice(sql.as_bytes());
69    Ok(out)
70}
71
72pub fn decode_declare_cursor_payload(
73    payload: &[u8],
74) -> Result<DeclareCursorPayload, CursorPayloadError> {
75    let mut pos = 0usize;
76    let cursor_id = u32::from_le_bytes(read_array(
77        payload,
78        &mut pos,
79        CursorPayloadError::TruncatedDeclareCursorId,
80    )?);
81    let sql_len = u32::from_le_bytes(read_array(
82        payload,
83        &mut pos,
84        CursorPayloadError::TruncatedDeclareSqlLen,
85    )?) as usize;
86    let sql_bytes = read_bytes(
87        payload,
88        &mut pos,
89        sql_len,
90        CursorPayloadError::TruncatedDeclareSql,
91    )?;
92    let sql = std::str::from_utf8(sql_bytes)
93        .map(str::to_string)
94        .map_err(|_| CursorPayloadError::InvalidDeclareSql)?;
95    Ok(DeclareCursorPayload { cursor_id, sql })
96}
97
/// Encodes a fetch payload: `cursor_id` then `max_rows`, each as a little-endian `u32`.
pub fn encode_fetch_payload(cursor_id: u32, max_rows: u32) -> Vec<u8> {
    [cursor_id.to_le_bytes(), max_rows.to_le_bytes()].concat()
}
104
105pub fn decode_fetch_payload(payload: &[u8]) -> Result<FetchPayload, CursorPayloadError> {
106    let mut pos = 0usize;
107    let cursor_id = u32::from_le_bytes(read_array(
108        payload,
109        &mut pos,
110        CursorPayloadError::TruncatedFetchCursorId,
111    )?);
112    let max_rows = u32::from_le_bytes(read_array(
113        payload,
114        &mut pos,
115        CursorPayloadError::TruncatedFetchMaxRows,
116    )?);
117    Ok(FetchPayload {
118        cursor_id,
119        max_rows,
120    })
121}
122
/// Encodes a close-cursor payload: just `cursor_id` as a little-endian `u32`.
pub fn encode_close_cursor_payload(cursor_id: u32) -> Vec<u8> {
    Vec::from(cursor_id.to_le_bytes())
}
126
127pub fn decode_close_cursor_payload(
128    payload: &[u8],
129) -> Result<CloseCursorPayload, CursorPayloadError> {
130    let mut pos = 0usize;
131    let cursor_id = u32::from_le_bytes(read_array(
132        payload,
133        &mut pos,
134        CursorPayloadError::TruncatedCloseCursorId,
135    )?);
136    Ok(CloseCursorPayload { cursor_id })
137}
138
139pub fn encode_cursor_ok_payload(
140    cursor_id: u32,
141    columns: &[impl AsRef<str>],
142    total_rows: u64,
143) -> Result<Vec<u8>, CursorPayloadError> {
144    let ncols =
145        u16::try_from(columns.len()).map_err(|_| CursorPayloadError::ColumnCountOverflow)?;
146    let mut out = Vec::with_capacity(4 + 2 + 8 + columns.len() * 16);
147    out.extend_from_slice(&cursor_id.to_le_bytes());
148    out.extend_from_slice(&ncols.to_le_bytes());
149    for col in columns {
150        encode_column_name(&mut out, col.as_ref());
151    }
152    out.extend_from_slice(&total_rows.to_le_bytes());
153    Ok(out)
154}
155
156pub fn encode_cursor_batch_payload(
157    cursor_id: u32,
158    rows: &[Vec<WireValue>],
159    has_more: bool,
160) -> Result<Vec<u8>, CursorPayloadError> {
161    let nrows = u32::try_from(rows.len()).map_err(|_| CursorPayloadError::RowCountOverflow)?;
162    let mut out = Vec::new();
163    out.extend_from_slice(&cursor_id.to_le_bytes());
164    out.extend_from_slice(&nrows.to_le_bytes());
165    out.push(u8::from(has_more));
166    for row in rows {
167        for value in row {
168            encode_value(&mut out, value);
169        }
170    }
171    Ok(out)
172}
173
/// Borrows the next `len` bytes of `payload` starting at `*pos` and advances
/// `*pos` past them.
///
/// Returns `err` and leaves `*pos` unchanged when `*pos + len` overflows
/// `usize` or runs past the end of `payload`. The helper is generic over the
/// error type because it only hands the caller's error back.
fn read_bytes<'a, E>(
    payload: &'a [u8],
    pos: &mut usize,
    len: usize,
    err: E,
) -> Result<&'a [u8], E> {
    // `slice::get` folds the bounds check into the slicing, and `ok_or`
    // consumes `err` once, so no clone is made on the success path.
    let start = *pos;
    let bytes = start
        .checked_add(len)
        .and_then(|end| payload.get(start..end))
        .ok_or(err)?;
    *pos = start + bytes.len();
    Ok(bytes)
}
188
189fn read_array<const N: usize>(
190    payload: &[u8],
191    pos: &mut usize,
192    err: CursorPayloadError,
193) -> Result<[u8; N], CursorPayloadError> {
194    let bytes = read_bytes(payload, pos, N, err)?;
195    let mut out = [0u8; N];
196    out.copy_from_slice(bytes);
197    Ok(out)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn declare_cursor_payload_round_trips() {
        let bytes = encode_declare_cursor_payload(7, "SELECT id FROM users").unwrap();
        assert_eq!(
            decode_declare_cursor_payload(&bytes).unwrap(),
            DeclareCursorPayload {
                cursor_id: 7,
                sql: "SELECT id FROM users".to_string(),
            }
        );
    }

    #[test]
    fn declare_cursor_payload_round_trips_empty_sql() {
        let bytes = encode_declare_cursor_payload(1, "").unwrap();
        assert_eq!(bytes, [1, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(
            decode_declare_cursor_payload(&bytes).unwrap(),
            DeclareCursorPayload {
                cursor_id: 1,
                sql: String::new(),
            }
        );
    }

    #[test]
    fn declare_cursor_payload_reports_each_failure() {
        // cursor_id present, sql_len cut short.
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 5]),
            Err(CursorPayloadError::TruncatedDeclareSqlLen)
        );
        // sql_len announces 5 bytes but only 3 follow.
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 5, 0, 0, 0, b'a', b'b', b'c']),
            Err(CursorPayloadError::TruncatedDeclareSql)
        );
        // A huge sql_len must fail cleanly instead of over-reading or overflowing.
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 0xff, 0xff, 0xff, 0xff]),
            Err(CursorPayloadError::TruncatedDeclareSql)
        );
        // Length is satisfied but the byte is not UTF-8.
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 1, 0, 0, 0, 0xff]),
            Err(CursorPayloadError::InvalidDeclareSql)
        );
    }

    #[test]
    fn fetch_and_close_payloads_round_trip() {
        assert_eq!(
            decode_fetch_payload(&encode_fetch_payload(3, 50)).unwrap(),
            FetchPayload {
                cursor_id: 3,
                max_rows: 50,
            }
        );
        assert_eq!(
            decode_close_cursor_payload(&encode_close_cursor_payload(9)).unwrap(),
            CloseCursorPayload { cursor_id: 9 }
        );
    }

    #[test]
    fn fetch_and_close_payloads_use_little_endian_layout() {
        assert_eq!(encode_fetch_payload(0x0403_0201, 50), [1, 2, 3, 4, 50, 0, 0, 0]);
        assert_eq!(encode_close_cursor_payload(0x0403_0201), [1, 2, 3, 4]);
        assert_eq!(
            decode_fetch_payload(&[]),
            Err(CursorPayloadError::TruncatedFetchCursorId)
        );
    }

    #[test]
    fn cursor_ok_and_batch_payloads_encode_expected_headers() {
        let ok = encode_cursor_ok_payload(5, &["id", "name"], 20).unwrap();
        assert_eq!(u32::from_le_bytes([ok[0], ok[1], ok[2], ok[3]]), 5);
        assert_eq!(u16::from_le_bytes([ok[4], ok[5]]), 2);
        // total_rows is always the trailing 8 bytes, whatever the name encoding is.
        assert_eq!(&ok[ok.len() - 8..], &20u64.to_le_bytes()[..]);

        let batch = encode_cursor_batch_payload(
            5,
            &[vec![WireValue::I64(1), WireValue::Text("ada".to_string())]],
            true,
        )
        .unwrap();
        assert_eq!(
            u32::from_le_bytes([batch[0], batch[1], batch[2], batch[3]]),
            5
        );
        assert_eq!(
            u32::from_le_bytes([batch[4], batch[5], batch[6], batch[7]]),
            1
        );
        assert_eq!(batch[8], 1);
    }

    #[test]
    fn empty_batch_is_just_the_header() {
        let empty = encode_cursor_batch_payload(2, &[], false).unwrap();
        assert_eq!(empty, [2, 0, 0, 0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_ok_rejects_too_many_columns() {
        let columns = vec![""; usize::from(u16::MAX) + 1];
        assert_eq!(
            encode_cursor_ok_payload(1, columns.as_slice(), 0),
            Err(CursorPayloadError::ColumnCountOverflow)
        );
    }

    #[test]
    fn cursor_errors_preserve_legacy_messages() {
        assert_eq!(
            decode_declare_cursor_payload(&[0, 0, 0])
                .unwrap_err()
                .to_string(),
            "truncated declare cursor_id"
        );
        assert_eq!(
            decode_fetch_payload(&[0, 0, 0, 0]).unwrap_err().to_string(),
            "truncated fetch max_rows"
        );
        assert_eq!(
            decode_close_cursor_payload(&[0]).unwrap_err().to_string(),
            "truncated close cursor_id"
        );
    }
}