Skip to main content

reddb_wire/redwire/
prepared.rs

1//! RedWire legacy prepared-statement payload codec.
2
3use crate::legacy::{encode_value, try_decode_value, WireValue};
4use std::fmt;
5
/// Decoded body of a legacy RedWire prepare request.
///
/// [`decode_prepare_payload`] produces this value; [`encode_prepare_payload`]
/// takes the same two pieces of data as separate arguments.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PreparePayload {
    /// Statement identifier, carried on the wire as a little-endian `u32`.
    pub stmt_id: u32,
    /// Statement text; the decoder only accepts valid UTF-8.
    pub sql: String,
}
11
12#[derive(Debug, Clone, PartialEq)]
13pub struct ExecutePreparedPayload {
14    pub stmt_id: u32,
15    pub params: Vec<WireValue>,
16}
17
/// Decoded body of a legacy RedWire deallocate request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeallocatePayload {
    /// Statement identifier, carried on the wire as a little-endian `u32`.
    pub stmt_id: u32,
}
22
/// Statement id and parameter count carried by a legacy RedWire prepared-OK payload.
///
/// [`encode_prepared_ok_payload`] writes the same two values as a little-endian
/// `u32` followed by a little-endian `u16`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PreparedOkPayload {
    /// Statement identifier.
    pub stmt_id: u32,
    /// Number of parameters the statement takes.
    pub param_count: u16,
}
28
/// Failure modes of the legacy prepared-statement payload codec.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PreparedPayloadError {
    /// Prepare payload ended before the 4-byte statement id.
    TruncatedPrepareStmtId,
    /// Prepare payload ended before the 4-byte SQL length prefix.
    TruncatedPrepareSqlLen,
    /// Prepare payload is shorter than its declared SQL length.
    TruncatedPrepareSql,
    /// Prepare SQL bytes are not valid UTF-8.
    InvalidPrepareSql,
    /// Execute payload ended before the 4-byte statement id.
    TruncatedExecuteStmtId,
    /// Execute payload ended before the 2-byte parameter count.
    TruncatedExecuteParamCount,
    /// A parameter value failed to decode; carries the value decoder's message.
    ExecuteParamValue(&'static str),
    /// Deallocate payload ended before the 4-byte statement id.
    TruncatedDeallocateStmtId,
    /// SQL text is longer than a `u32` length prefix can describe.
    SqlTooLarge,
    /// More parameters than a `u16` count can describe.
    ParamCountOverflow,
}

impl fmt::Display for PreparedPayloadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant maps to one fixed message; these strings are matched by
        // existing tests as "legacy messages", so they must not drift.
        let message = match *self {
            Self::TruncatedPrepareStmtId => "truncated prepare stmt_id",
            Self::TruncatedPrepareSqlLen => "truncated prepare sql_len",
            Self::TruncatedPrepareSql => "truncated prepare sql",
            Self::InvalidPrepareSql => "invalid UTF-8 in prepare sql",
            Self::TruncatedExecuteStmtId => "truncated execute stmt_id",
            Self::TruncatedExecuteParamCount => "truncated execute nparams",
            Self::ExecuteParamValue(detail) => detail,
            Self::TruncatedDeallocateStmtId => "truncated deallocate stmt_id",
            Self::SqlTooLarge => "prepare sql is too large for RedWire prepared payload",
            Self::ParamCountOverflow => "parameter count is too large for RedWire prepared payload",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PreparedPayloadError {}
66
67pub fn encode_prepare_payload(stmt_id: u32, sql: &str) -> Result<Vec<u8>, PreparedPayloadError> {
68    let sql_len = u32::try_from(sql.len()).map_err(|_| PreparedPayloadError::SqlTooLarge)?;
69    let mut out = Vec::with_capacity(8 + sql.len());
70    out.extend_from_slice(&stmt_id.to_le_bytes());
71    out.extend_from_slice(&sql_len.to_le_bytes());
72    out.extend_from_slice(sql.as_bytes());
73    Ok(out)
74}
75
76pub fn decode_prepare_payload(payload: &[u8]) -> Result<PreparePayload, PreparedPayloadError> {
77    let mut pos = 0usize;
78    let stmt_id = u32::from_le_bytes(read_array(
79        payload,
80        &mut pos,
81        PreparedPayloadError::TruncatedPrepareStmtId,
82    )?);
83    let sql_len = u32::from_le_bytes(read_array(
84        payload,
85        &mut pos,
86        PreparedPayloadError::TruncatedPrepareSqlLen,
87    )?) as usize;
88    let sql_bytes = read_bytes(
89        payload,
90        &mut pos,
91        sql_len,
92        PreparedPayloadError::TruncatedPrepareSql,
93    )?;
94    let sql = std::str::from_utf8(sql_bytes)
95        .map(str::to_string)
96        .map_err(|_| PreparedPayloadError::InvalidPrepareSql)?;
97    Ok(PreparePayload { stmt_id, sql })
98}
99
100pub fn encode_execute_prepared_payload(
101    stmt_id: u32,
102    params: &[WireValue],
103) -> Result<Vec<u8>, PreparedPayloadError> {
104    let param_count =
105        u16::try_from(params.len()).map_err(|_| PreparedPayloadError::ParamCountOverflow)?;
106    let mut out = Vec::new();
107    out.extend_from_slice(&stmt_id.to_le_bytes());
108    out.extend_from_slice(&param_count.to_le_bytes());
109    for param in params {
110        encode_value(&mut out, param);
111    }
112    Ok(out)
113}
114
115pub fn decode_execute_prepared_payload(
116    payload: &[u8],
117) -> Result<ExecutePreparedPayload, PreparedPayloadError> {
118    let mut pos = 0usize;
119    let stmt_id = u32::from_le_bytes(read_array(
120        payload,
121        &mut pos,
122        PreparedPayloadError::TruncatedExecuteStmtId,
123    )?);
124    let nparams = u16::from_le_bytes(read_array(
125        payload,
126        &mut pos,
127        PreparedPayloadError::TruncatedExecuteParamCount,
128    )?) as usize;
129    let mut params = Vec::with_capacity(nparams);
130    for _ in 0..nparams {
131        params.push(
132            try_decode_value(payload, &mut pos).map_err(PreparedPayloadError::ExecuteParamValue)?,
133        );
134    }
135    Ok(ExecutePreparedPayload { stmt_id, params })
136}
137
/// Encodes a deallocate request, which consists solely of `stmt_id` as 4
/// little-endian bytes.
pub fn encode_deallocate_payload(stmt_id: u32) -> Vec<u8> {
    Vec::from(stmt_id.to_le_bytes())
}
141
142pub fn decode_deallocate_payload(
143    payload: &[u8],
144) -> Result<DeallocatePayload, PreparedPayloadError> {
145    let mut pos = 0usize;
146    let stmt_id = u32::from_le_bytes(read_array(
147        payload,
148        &mut pos,
149        PreparedPayloadError::TruncatedDeallocateStmtId,
150    )?);
151    Ok(DeallocatePayload { stmt_id })
152}
153
154pub fn encode_prepared_ok_payload(
155    stmt_id: u32,
156    param_count: usize,
157) -> Result<Vec<u8>, PreparedPayloadError> {
158    let param_count =
159        u16::try_from(param_count).map_err(|_| PreparedPayloadError::ParamCountOverflow)?;
160    let mut out = Vec::with_capacity(6);
161    out.extend_from_slice(&stmt_id.to_le_bytes());
162    out.extend_from_slice(&param_count.to_le_bytes());
163    Ok(out)
164}
165
166fn read_bytes<'a>(
167    payload: &'a [u8],
168    pos: &mut usize,
169    len: usize,
170    err: PreparedPayloadError,
171) -> Result<&'a [u8], PreparedPayloadError> {
172    let end = pos.checked_add(len).ok_or(err.clone())?;
173    if end > payload.len() {
174        return Err(err);
175    }
176    let bytes = &payload[*pos..end];
177    *pos = end;
178    Ok(bytes)
179}
180
181fn read_array<const N: usize>(
182    payload: &[u8],
183    pos: &mut usize,
184    err: PreparedPayloadError,
185) -> Result<[u8; N], PreparedPayloadError> {
186    let bytes = read_bytes(payload, pos, N, err)?;
187    let mut out = [0u8; N];
188    out.copy_from_slice(bytes);
189    Ok(out)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prepare_payload_round_trips() {
        let bytes = encode_prepare_payload(42, "SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(
            decode_prepare_payload(&bytes).unwrap(),
            PreparePayload {
                stmt_id: 42,
                sql: "SELECT * FROM users WHERE id = ?".to_string(),
            }
        );
    }

    #[test]
    fn execute_prepared_payload_round_trips_wire_values() {
        let params = vec![WireValue::I64(7), WireValue::Text("ada".to_string())];
        let bytes = encode_execute_prepared_payload(9, &params).unwrap();
        assert_eq!(
            decode_execute_prepared_payload(&bytes).unwrap(),
            ExecutePreparedPayload { stmt_id: 9, params }
        );
    }

    #[test]
    fn deallocate_payload_round_trips() {
        let bytes = encode_deallocate_payload(11);
        assert_eq!(
            decode_deallocate_payload(&bytes).unwrap(),
            DeallocatePayload { stmt_id: 11 }
        );
    }

    #[test]
    fn prepared_errors_preserve_legacy_messages() {
        assert_eq!(
            decode_prepare_payload(&[0, 0, 0]).unwrap_err().to_string(),
            "truncated prepare stmt_id"
        );
        assert_eq!(
            decode_execute_prepared_payload(&[1, 0, 0, 0])
                .unwrap_err()
                .to_string(),
            "truncated execute nparams"
        );
        assert_eq!(
            decode_deallocate_payload(&[1]).unwrap_err().to_string(),
            "truncated deallocate stmt_id"
        );
    }

    #[test]
    fn prepare_payload_reports_len_body_and_utf8_errors() {
        // stmt_id present, but only 2 of the 4 sql_len bytes.
        assert_eq!(
            decode_prepare_payload(&[1, 0, 0, 0, 5, 0]).unwrap_err(),
            PreparedPayloadError::TruncatedPrepareSqlLen
        );
        // sql_len = 5 but only one body byte follows.
        assert_eq!(
            decode_prepare_payload(&[1, 0, 0, 0, 5, 0, 0, 0, b'x']).unwrap_err(),
            PreparedPayloadError::TruncatedPrepareSql
        );
        // sql_len = 1 with a byte that is never valid UTF-8.
        assert_eq!(
            decode_prepare_payload(&[1, 0, 0, 0, 1, 0, 0, 0, 0xFF]).unwrap_err(),
            PreparedPayloadError::InvalidPrepareSql
        );
    }

    #[test]
    fn prepare_payload_round_trips_empty_sql() {
        let bytes = encode_prepare_payload(3, "").unwrap();
        assert_eq!(bytes, [3, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(
            decode_prepare_payload(&bytes).unwrap(),
            PreparePayload {
                stmt_id: 3,
                sql: String::new(),
            }
        );
    }

    #[test]
    fn execute_prepared_payload_round_trips_without_params() {
        let bytes = encode_execute_prepared_payload(5, &[]).unwrap();
        assert_eq!(bytes, [5, 0, 0, 0, 0, 0]);
        assert_eq!(
            decode_execute_prepared_payload(&bytes).unwrap(),
            ExecutePreparedPayload {
                stmt_id: 5,
                params: Vec::new(),
            }
        );
    }

    #[test]
    fn prepared_ok_payload_layout_and_overflow() {
        assert_eq!(
            encode_prepared_ok_payload(0x0102_0304, 3).unwrap(),
            [4, 3, 2, 1, 3, 0]
        );
        assert_eq!(
            encode_prepared_ok_payload(7, usize::from(u16::MAX)).unwrap(),
            [7, 0, 0, 0, 0xFF, 0xFF]
        );
        assert_eq!(
            encode_prepared_ok_payload(7, usize::from(u16::MAX) + 1).unwrap_err(),
            PreparedPayloadError::ParamCountOverflow
        );
    }
}