1use crate::legacy::{encode_value, try_decode_value, WireValue};
4use std::fmt;
5
/// Decoded body of a prepare message.
///
/// Produced by `decode_prepare_payload`; `encode_prepare_payload` writes the
/// same two fields to the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparePayload {
    /// Statement identifier, carried on the wire as a little-endian `u32`.
    pub stmt_id: u32,
    /// SQL text of the statement; validated as UTF-8 while decoding.
    pub sql: String,
}
11
12#[derive(Debug, Clone, PartialEq)]
13pub struct ExecutePreparedPayload {
14 pub stmt_id: u32,
15 pub params: Vec<WireValue>,
16}
17
/// Decoded body of a deallocate message.
///
/// Produced by `decode_deallocate_payload`; the payload carries nothing but
/// the statement identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeallocatePayload {
    /// Statement identifier, carried on the wire as a little-endian `u32`.
    pub stmt_id: u32,
}
22
/// Statement id and parameter count pair for a prepared-ok reply.
///
/// `encode_prepared_ok_payload` writes the same two fields to the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PreparedOkPayload {
    /// Statement identifier.
    pub stmt_id: u32,
    /// Number of parameters of the prepared statement.
    pub param_count: u16,
}
28
/// Failures raised while encoding or decoding prepared-statement payloads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PreparedPayloadError {
    /// A prepare payload ended before its 4-byte statement id.
    TruncatedPrepareStmtId,
    /// A prepare payload ended before its 4-byte SQL length.
    TruncatedPrepareSqlLen,
    /// A prepare payload held fewer SQL bytes than its length field declared.
    TruncatedPrepareSql,
    /// The SQL bytes of a prepare payload were not valid UTF-8.
    InvalidPrepareSql,
    /// An execute payload ended before its 4-byte statement id.
    TruncatedExecuteStmtId,
    /// An execute payload ended before its 2-byte parameter count.
    TruncatedExecuteParamCount,
    /// A parameter value could not be decoded; carries the decoder's message.
    ExecuteParamValue(&'static str),
    /// A deallocate payload ended before its 4-byte statement id.
    TruncatedDeallocateStmtId,
    /// The SQL text is too long for its length to fit in a `u32`.
    SqlTooLarge,
    /// The parameter count does not fit in a `u16`.
    ParamCountOverflow,
}
42
43impl fmt::Display for PreparedPayloadError {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Self::TruncatedPrepareStmtId => write!(f, "truncated prepare stmt_id"),
47 Self::TruncatedPrepareSqlLen => write!(f, "truncated prepare sql_len"),
48 Self::TruncatedPrepareSql => write!(f, "truncated prepare sql"),
49 Self::InvalidPrepareSql => write!(f, "invalid UTF-8 in prepare sql"),
50 Self::TruncatedExecuteStmtId => write!(f, "truncated execute stmt_id"),
51 Self::TruncatedExecuteParamCount => write!(f, "truncated execute nparams"),
52 Self::ExecuteParamValue(err) => write!(f, "{err}"),
53 Self::TruncatedDeallocateStmtId => write!(f, "truncated deallocate stmt_id"),
54 Self::SqlTooLarge => write!(f, "prepare sql is too large for RedWire prepared payload"),
55 Self::ParamCountOverflow => {
56 write!(
57 f,
58 "parameter count is too large for RedWire prepared payload"
59 )
60 }
61 }
62 }
63}
64
65impl std::error::Error for PreparedPayloadError {}
66
67pub fn encode_prepare_payload(stmt_id: u32, sql: &str) -> Result<Vec<u8>, PreparedPayloadError> {
68 let sql_len = u32::try_from(sql.len()).map_err(|_| PreparedPayloadError::SqlTooLarge)?;
69 let mut out = Vec::with_capacity(8 + sql.len());
70 out.extend_from_slice(&stmt_id.to_le_bytes());
71 out.extend_from_slice(&sql_len.to_le_bytes());
72 out.extend_from_slice(sql.as_bytes());
73 Ok(out)
74}
75
76pub fn decode_prepare_payload(payload: &[u8]) -> Result<PreparePayload, PreparedPayloadError> {
77 let mut pos = 0usize;
78 let stmt_id = u32::from_le_bytes(read_array(
79 payload,
80 &mut pos,
81 PreparedPayloadError::TruncatedPrepareStmtId,
82 )?);
83 let sql_len = u32::from_le_bytes(read_array(
84 payload,
85 &mut pos,
86 PreparedPayloadError::TruncatedPrepareSqlLen,
87 )?) as usize;
88 let sql_bytes = read_bytes(
89 payload,
90 &mut pos,
91 sql_len,
92 PreparedPayloadError::TruncatedPrepareSql,
93 )?;
94 let sql = std::str::from_utf8(sql_bytes)
95 .map(str::to_string)
96 .map_err(|_| PreparedPayloadError::InvalidPrepareSql)?;
97 Ok(PreparePayload { stmt_id, sql })
98}
99
100pub fn encode_execute_prepared_payload(
101 stmt_id: u32,
102 params: &[WireValue],
103) -> Result<Vec<u8>, PreparedPayloadError> {
104 let param_count =
105 u16::try_from(params.len()).map_err(|_| PreparedPayloadError::ParamCountOverflow)?;
106 let mut out = Vec::new();
107 out.extend_from_slice(&stmt_id.to_le_bytes());
108 out.extend_from_slice(¶m_count.to_le_bytes());
109 for param in params {
110 encode_value(&mut out, param);
111 }
112 Ok(out)
113}
114
115pub fn decode_execute_prepared_payload(
116 payload: &[u8],
117) -> Result<ExecutePreparedPayload, PreparedPayloadError> {
118 let mut pos = 0usize;
119 let stmt_id = u32::from_le_bytes(read_array(
120 payload,
121 &mut pos,
122 PreparedPayloadError::TruncatedExecuteStmtId,
123 )?);
124 let nparams = u16::from_le_bytes(read_array(
125 payload,
126 &mut pos,
127 PreparedPayloadError::TruncatedExecuteParamCount,
128 )?) as usize;
129 let mut params = Vec::with_capacity(nparams);
130 for _ in 0..nparams {
131 params.push(
132 try_decode_value(payload, &mut pos).map_err(PreparedPayloadError::ExecuteParamValue)?,
133 );
134 }
135 Ok(ExecutePreparedPayload { stmt_id, params })
136}
137
/// Serializes a deallocate message body: the 4 little-endian bytes of `stmt_id`.
pub fn encode_deallocate_payload(stmt_id: u32) -> Vec<u8> {
    let id_bytes: [u8; 4] = stmt_id.to_le_bytes();
    Vec::from(&id_bytes[..])
}
141
142pub fn decode_deallocate_payload(
143 payload: &[u8],
144) -> Result<DeallocatePayload, PreparedPayloadError> {
145 let mut pos = 0usize;
146 let stmt_id = u32::from_le_bytes(read_array(
147 payload,
148 &mut pos,
149 PreparedPayloadError::TruncatedDeallocateStmtId,
150 )?);
151 Ok(DeallocatePayload { stmt_id })
152}
153
154pub fn encode_prepared_ok_payload(
155 stmt_id: u32,
156 param_count: usize,
157) -> Result<Vec<u8>, PreparedPayloadError> {
158 let param_count =
159 u16::try_from(param_count).map_err(|_| PreparedPayloadError::ParamCountOverflow)?;
160 let mut out = Vec::with_capacity(6);
161 out.extend_from_slice(&stmt_id.to_le_bytes());
162 out.extend_from_slice(¶m_count.to_le_bytes());
163 Ok(out)
164}
165
166fn read_bytes<'a>(
167 payload: &'a [u8],
168 pos: &mut usize,
169 len: usize,
170 err: PreparedPayloadError,
171) -> Result<&'a [u8], PreparedPayloadError> {
172 let end = pos.checked_add(len).ok_or(err.clone())?;
173 if end > payload.len() {
174 return Err(err);
175 }
176 let bytes = &payload[*pos..end];
177 *pos = end;
178 Ok(bytes)
179}
180
181fn read_array<const N: usize>(
182 payload: &[u8],
183 pos: &mut usize,
184 err: PreparedPayloadError,
185) -> Result<[u8; N], PreparedPayloadError> {
186 let bytes = read_bytes(payload, pos, N, err)?;
187 let mut out = [0u8; N];
188 out.copy_from_slice(bytes);
189 Ok(out)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding a prepare payload yields the original fields.
    #[test]
    fn prepare_payload_round_trips() {
        let bytes = encode_prepare_payload(42, "SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(
            decode_prepare_payload(&bytes).unwrap(),
            PreparePayload {
                stmt_id: 42,
                sql: "SELECT * FROM users WHERE id = ?".to_string(),
            }
        );
    }

    /// Parameters survive an encode/decode cycle in order.
    #[test]
    fn execute_prepared_payload_round_trips_wire_values() {
        let params = vec![WireValue::I64(7), WireValue::Text("ada".to_string())];
        let bytes = encode_execute_prepared_payload(9, &params).unwrap();
        assert_eq!(
            decode_execute_prepared_payload(&bytes).unwrap(),
            ExecutePreparedPayload { stmt_id: 9, params }
        );
    }

    /// Encoding then decoding a deallocate payload yields the original id.
    #[test]
    fn deallocate_payload_round_trips() {
        let bytes = encode_deallocate_payload(11);
        assert_eq!(
            decode_deallocate_payload(&bytes).unwrap(),
            DeallocatePayload { stmt_id: 11 }
        );
    }

    /// The prepared-ok body is the id then the count, both little-endian, and
    /// a count beyond `u16::MAX` is rejected.
    #[test]
    fn prepared_ok_payload_encodes_stmt_id_then_param_count() {
        assert_eq!(
            encode_prepared_ok_payload(0x0102_0304, 2).unwrap(),
            vec![0x04, 0x03, 0x02, 0x01, 0x02, 0x00]
        );
        assert_eq!(
            encode_prepared_ok_payload(1, usize::from(u16::MAX) + 1),
            Err(PreparedPayloadError::ParamCountOverflow)
        );
    }

    /// Truncation errors keep their exact display strings.
    #[test]
    fn prepared_errors_preserve_legacy_messages() {
        assert_eq!(
            decode_prepare_payload(&[0, 0, 0]).unwrap_err().to_string(),
            "truncated prepare stmt_id"
        );
        assert_eq!(
            decode_execute_prepared_payload(&[1, 0, 0, 0])
                .unwrap_err()
                .to_string(),
            "truncated execute nparams"
        );
        assert_eq!(
            decode_deallocate_payload(&[1]).unwrap_err().to_string(),
            "truncated deallocate stmt_id"
        );
    }
}