1use crate::legacy::{encode_column_name, encode_value, WireValue};
4use std::fmt;
5
/// Decoded body of a declare-cursor payload.
///
/// Produced by `decode_declare_cursor_payload`; the same two values are what
/// `encode_declare_cursor_payload` writes to the wire.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeclareCursorPayload {
    /// Cursor identifier (a little-endian `u32` on the wire).
    pub cursor_id: u32,
    /// SQL text, already validated as UTF-8 during decoding.
    pub sql: String,
}
11
/// Decoded body of a fetch payload.
///
/// Produced by `decode_fetch_payload`; on the wire both fields are little-endian `u32`s.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FetchPayload {
    /// Cursor the fetch refers to.
    pub cursor_id: u32,
    /// Row limit sent by the peer (presumably the maximum number of rows for the next
    /// batch — confirm against the server handler).
    pub max_rows: u32,
}
17
/// Decoded body of a close-cursor payload: a single little-endian `u32` cursor id.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CloseCursorPayload {
    /// Cursor the close request refers to.
    pub cursor_id: u32,
}
22
/// Errors raised while encoding or decoding RedWire cursor payloads.
///
/// The `Truncated*` variants name the field that was cut short, so callers can tell
/// exactly where a malformed payload ended.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CursorPayloadError {
    /// Declare payload shorter than its 4-byte `cursor_id`.
    TruncatedDeclareCursorId,
    /// Declare payload missing (part of) its 4-byte `sql_len` prefix.
    TruncatedDeclareSqlLen,
    /// Declare payload holds fewer SQL bytes than `sql_len` announced.
    TruncatedDeclareSql,
    /// Declare payload SQL bytes are not valid UTF-8.
    InvalidDeclareSql,
    /// Fetch payload shorter than its 4-byte `cursor_id`.
    TruncatedFetchCursorId,
    /// Fetch payload missing (part of) its 4-byte `max_rows`.
    TruncatedFetchMaxRows,
    /// Close payload shorter than its 4-byte `cursor_id`.
    TruncatedCloseCursorId,
    /// SQL text longer than a `u32` length prefix can describe.
    SqlTooLarge,
    /// More columns than fit in the `u16` column count.
    ColumnCountOverflow,
    /// More rows than fit in the `u32` row count.
    RowCountOverflow,
}

impl fmt::Display for CursorPayloadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // These texts are user-visible and pinned by tests; keep them stable.
        let text = match self {
            Self::TruncatedDeclareCursorId => "truncated declare cursor_id",
            Self::TruncatedDeclareSqlLen => "truncated declare sql_len",
            Self::TruncatedDeclareSql => "truncated declare sql",
            Self::InvalidDeclareSql => "invalid UTF-8 in declare sql",
            Self::TruncatedFetchCursorId => "truncated fetch cursor_id",
            Self::TruncatedFetchMaxRows => "truncated fetch max_rows",
            Self::TruncatedCloseCursorId => "truncated close cursor_id",
            Self::SqlTooLarge => "declare sql is too large for RedWire cursor payload",
            Self::ColumnCountOverflow => "column count is too large for RedWire cursor payload",
            Self::RowCountOverflow => "row count is too large for RedWire cursor payload",
        };
        f.write_str(text)
    }
}

impl std::error::Error for CursorPayloadError {}
59
60pub fn encode_declare_cursor_payload(
61 cursor_id: u32,
62 sql: &str,
63) -> Result<Vec<u8>, CursorPayloadError> {
64 let sql_len = u32::try_from(sql.len()).map_err(|_| CursorPayloadError::SqlTooLarge)?;
65 let mut out = Vec::with_capacity(8 + sql.len());
66 out.extend_from_slice(&cursor_id.to_le_bytes());
67 out.extend_from_slice(&sql_len.to_le_bytes());
68 out.extend_from_slice(sql.as_bytes());
69 Ok(out)
70}
71
72pub fn decode_declare_cursor_payload(
73 payload: &[u8],
74) -> Result<DeclareCursorPayload, CursorPayloadError> {
75 let mut pos = 0usize;
76 let cursor_id = u32::from_le_bytes(read_array(
77 payload,
78 &mut pos,
79 CursorPayloadError::TruncatedDeclareCursorId,
80 )?);
81 let sql_len = u32::from_le_bytes(read_array(
82 payload,
83 &mut pos,
84 CursorPayloadError::TruncatedDeclareSqlLen,
85 )?) as usize;
86 let sql_bytes = read_bytes(
87 payload,
88 &mut pos,
89 sql_len,
90 CursorPayloadError::TruncatedDeclareSql,
91 )?;
92 let sql = std::str::from_utf8(sql_bytes)
93 .map(str::to_string)
94 .map_err(|_| CursorPayloadError::InvalidDeclareSql)?;
95 Ok(DeclareCursorPayload { cursor_id, sql })
96}
97
/// Encodes a fetch payload: `cursor_id` followed by `max_rows`, each a little-endian `u32`.
///
/// The result is always exactly 8 bytes.
pub fn encode_fetch_payload(cursor_id: u32, max_rows: u32) -> Vec<u8> {
    [cursor_id.to_le_bytes(), max_rows.to_le_bytes()].concat()
}
104
105pub fn decode_fetch_payload(payload: &[u8]) -> Result<FetchPayload, CursorPayloadError> {
106 let mut pos = 0usize;
107 let cursor_id = u32::from_le_bytes(read_array(
108 payload,
109 &mut pos,
110 CursorPayloadError::TruncatedFetchCursorId,
111 )?);
112 let max_rows = u32::from_le_bytes(read_array(
113 payload,
114 &mut pos,
115 CursorPayloadError::TruncatedFetchMaxRows,
116 )?);
117 Ok(FetchPayload {
118 cursor_id,
119 max_rows,
120 })
121}
122
/// Encodes a close-cursor payload: the cursor id as a single little-endian `u32` (4 bytes).
pub fn encode_close_cursor_payload(cursor_id: u32) -> Vec<u8> {
    Vec::from(cursor_id.to_le_bytes())
}
126
127pub fn decode_close_cursor_payload(
128 payload: &[u8],
129) -> Result<CloseCursorPayload, CursorPayloadError> {
130 let mut pos = 0usize;
131 let cursor_id = u32::from_le_bytes(read_array(
132 payload,
133 &mut pos,
134 CursorPayloadError::TruncatedCloseCursorId,
135 )?);
136 Ok(CloseCursorPayload { cursor_id })
137}
138
139pub fn encode_cursor_ok_payload(
140 cursor_id: u32,
141 columns: &[impl AsRef<str>],
142 total_rows: u64,
143) -> Result<Vec<u8>, CursorPayloadError> {
144 let ncols =
145 u16::try_from(columns.len()).map_err(|_| CursorPayloadError::ColumnCountOverflow)?;
146 let mut out = Vec::with_capacity(4 + 2 + 8 + columns.len() * 16);
147 out.extend_from_slice(&cursor_id.to_le_bytes());
148 out.extend_from_slice(&ncols.to_le_bytes());
149 for col in columns {
150 encode_column_name(&mut out, col.as_ref());
151 }
152 out.extend_from_slice(&total_rows.to_le_bytes());
153 Ok(out)
154}
155
156pub fn encode_cursor_batch_payload(
157 cursor_id: u32,
158 rows: &[Vec<WireValue>],
159 has_more: bool,
160) -> Result<Vec<u8>, CursorPayloadError> {
161 let nrows = u32::try_from(rows.len()).map_err(|_| CursorPayloadError::RowCountOverflow)?;
162 let mut out = Vec::new();
163 out.extend_from_slice(&cursor_id.to_le_bytes());
164 out.extend_from_slice(&nrows.to_le_bytes());
165 out.push(u8::from(has_more));
166 for row in rows {
167 for value in row {
168 encode_value(&mut out, value);
169 }
170 }
171 Ok(out)
172}
173
174fn read_bytes<'a>(
175 payload: &'a [u8],
176 pos: &mut usize,
177 len: usize,
178 err: CursorPayloadError,
179) -> Result<&'a [u8], CursorPayloadError> {
180 let end = pos.checked_add(len).ok_or(err.clone())?;
181 if end > payload.len() {
182 return Err(err);
183 }
184 let bytes = &payload[*pos..end];
185 *pos = end;
186 Ok(bytes)
187}
188
189fn read_array<const N: usize>(
190 payload: &[u8],
191 pos: &mut usize,
192 err: CursorPayloadError,
193) -> Result<[u8; N], CursorPayloadError> {
194 let bytes = read_bytes(payload, pos, N, err)?;
195 let mut out = [0u8; N];
196 out.copy_from_slice(bytes);
197 Ok(out)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn declare_cursor_payload_round_trips() {
        let bytes = encode_declare_cursor_payload(7, "SELECT id FROM users").unwrap();
        assert_eq!(
            decode_declare_cursor_payload(&bytes).unwrap(),
            DeclareCursorPayload {
                cursor_id: 7,
                sql: "SELECT id FROM users".to_string(),
            }
        );
    }

    #[test]
    fn declare_cursor_payload_round_trips_empty_and_multibyte_sql() {
        let empty = encode_declare_cursor_payload(0, "").unwrap();
        assert_eq!(empty, vec![0, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(
            decode_declare_cursor_payload(&empty).unwrap(),
            DeclareCursorPayload {
                cursor_id: 0,
                sql: String::new(),
            }
        );

        let sql = "SELECT 'ü' AS ✓";
        let bytes = encode_declare_cursor_payload(u32::MAX, sql).unwrap();
        // The length prefix counts UTF-8 bytes, not chars.
        assert_eq!(bytes.len(), 8 + sql.len());
        assert_eq!(
            decode_declare_cursor_payload(&bytes).unwrap(),
            DeclareCursorPayload {
                cursor_id: u32::MAX,
                sql: sql.to_string(),
            }
        );
    }

    #[test]
    fn declare_cursor_decode_reports_each_failure() {
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 5]).unwrap_err(),
            CursorPayloadError::TruncatedDeclareSqlLen
        );
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 5, 0, 0, 0, b'a']).unwrap_err(),
            CursorPayloadError::TruncatedDeclareSql
        );
        // A huge announced length must be rejected, not overflow or panic.
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF]).unwrap_err(),
            CursorPayloadError::TruncatedDeclareSql
        );
        assert_eq!(
            decode_declare_cursor_payload(&[1, 0, 0, 0, 2, 0, 0, 0, 0xFF, 0xFE]).unwrap_err(),
            CursorPayloadError::InvalidDeclareSql
        );
    }

    #[test]
    fn fetch_and_close_payloads_round_trip() {
        assert_eq!(
            decode_fetch_payload(&encode_fetch_payload(3, 50)).unwrap(),
            FetchPayload {
                cursor_id: 3,
                max_rows: 50,
            }
        );
        assert_eq!(
            decode_close_cursor_payload(&encode_close_cursor_payload(9)).unwrap(),
            CloseCursorPayload { cursor_id: 9 }
        );
    }

    #[test]
    fn fetch_and_close_payloads_have_fixed_sizes() {
        assert_eq!(encode_fetch_payload(1, 2).len(), 8);
        assert_eq!(encode_close_cursor_payload(1).len(), 4);
    }

    #[test]
    fn cursor_ok_and_batch_payloads_encode_expected_headers() {
        let ok = encode_cursor_ok_payload(5, &["id", "name"], 20).unwrap();
        assert_eq!(u32::from_le_bytes([ok[0], ok[1], ok[2], ok[3]]), 5);
        assert_eq!(u16::from_le_bytes([ok[4], ok[5]]), 2);
        // total_rows is always the trailing 8 bytes.
        let mut tail = [0u8; 8];
        tail.copy_from_slice(&ok[ok.len() - 8..]);
        assert_eq!(u64::from_le_bytes(tail), 20);

        let batch = encode_cursor_batch_payload(
            5,
            &[vec![WireValue::I64(1), WireValue::Text("ada".to_string())]],
            true,
        )
        .unwrap();
        assert_eq!(
            u32::from_le_bytes([batch[0], batch[1], batch[2], batch[3]]),
            5
        );
        assert_eq!(
            u32::from_le_bytes([batch[4], batch[5], batch[6], batch[7]]),
            1
        );
        assert_eq!(batch[8], 1);
    }

    #[test]
    fn cursor_ok_and_batch_payloads_without_data_are_header_only() {
        let no_columns: [&str; 0] = [];
        assert_eq!(
            encode_cursor_ok_payload(1, &no_columns, 0).unwrap(),
            vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(
            encode_cursor_batch_payload(5, &[], false).unwrap(),
            vec![5, 0, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    #[test]
    fn cursor_errors_preserve_legacy_messages() {
        assert_eq!(
            decode_declare_cursor_payload(&[0, 0, 0])
                .unwrap_err()
                .to_string(),
            "truncated declare cursor_id"
        );
        assert_eq!(
            decode_fetch_payload(&[0, 0, 0, 0]).unwrap_err().to_string(),
            "truncated fetch max_rows"
        );
        assert_eq!(
            decode_fetch_payload(&[]).unwrap_err().to_string(),
            "truncated fetch cursor_id"
        );
        assert_eq!(
            decode_close_cursor_payload(&[0]).unwrap_err().to_string(),
            "truncated close cursor_id"
        );
    }
}