Skip to main content

qail_pg/protocol/wire/
frontend.rs

1//! FrontendMessage encoder — client-to-server wire format.
2
3use super::types::*;
4
5impl FrontendMessage {
6    #[inline]
7    fn has_nul(s: &str) -> bool {
8        s.as_bytes().contains(&0)
9    }
10
11    #[inline]
12    fn content_len_to_wire_len(content_len: usize) -> Result<i32, FrontendEncodeError> {
13        let total = content_len
14            .checked_add(4)
15            .ok_or(FrontendEncodeError::MessageTooLarge(usize::MAX))?;
16        i32::try_from(total).map_err(|_| FrontendEncodeError::MessageTooLarge(total))
17    }
18
19    /// Fallible encoder that returns explicit reason on invalid input.
20    pub fn encode_checked(&self) -> Result<Vec<u8>, FrontendEncodeError> {
21        match self {
22            FrontendMessage::Startup {
23                user,
24                database,
25                protocol_version,
26                startup_params,
27            } => {
28                if Self::has_nul(user) {
29                    return Err(FrontendEncodeError::InteriorNul("user"));
30                }
31                if Self::has_nul(database) {
32                    return Err(FrontendEncodeError::InteriorNul("database"));
33                }
34                let mut seen_startup_keys = std::collections::HashSet::new();
35                let mut buf = Vec::new();
36                buf.extend_from_slice(&protocol_version.to_be_bytes());
37                buf.extend_from_slice(b"user\0");
38                buf.extend_from_slice(user.as_bytes());
39                buf.push(0);
40                buf.extend_from_slice(b"database\0");
41                buf.extend_from_slice(database.as_bytes());
42                buf.push(0);
43                for (key, value) in startup_params {
44                    let key_trimmed = key.trim();
45                    if key_trimmed.is_empty() {
46                        return Err(FrontendEncodeError::InvalidStartupParam(
47                            "key must not be empty".to_string(),
48                        ));
49                    }
50                    let key_lc = key_trimmed.to_ascii_lowercase();
51                    if key_lc == "user" || key_lc == "database" {
52                        return Err(FrontendEncodeError::InvalidStartupParam(format!(
53                            "reserved key '{}'",
54                            key_trimmed
55                        )));
56                    }
57                    if !seen_startup_keys.insert(key_lc) {
58                        return Err(FrontendEncodeError::InvalidStartupParam(format!(
59                            "duplicate key '{}'",
60                            key_trimmed
61                        )));
62                    }
63                    if Self::has_nul(key) {
64                        return Err(FrontendEncodeError::InteriorNul("startup_param_key"));
65                    }
66                    if Self::has_nul(value) {
67                        return Err(FrontendEncodeError::InteriorNul("startup_param_value"));
68                    }
69                    buf.extend_from_slice(key.as_bytes());
70                    buf.push(0);
71                    buf.extend_from_slice(value.as_bytes());
72                    buf.push(0);
73                }
74                buf.push(0);
75
76                let len = Self::content_len_to_wire_len(buf.len())?;
77                let mut result = len.to_be_bytes().to_vec();
78                result.extend(buf);
79                Ok(result)
80            }
81            FrontendMessage::Query(sql) => {
82                if Self::has_nul(sql) {
83                    return Err(FrontendEncodeError::InteriorNul("sql"));
84                }
85                let mut buf = Vec::new();
86                buf.push(b'Q');
87                let mut content = Vec::with_capacity(sql.len() + 1);
88                content.extend_from_slice(sql.as_bytes());
89                content.push(0);
90                let len = Self::content_len_to_wire_len(content.len())?;
91                buf.extend_from_slice(&len.to_be_bytes());
92                buf.extend_from_slice(&content);
93                Ok(buf)
94            }
95            FrontendMessage::Terminate => Ok(vec![b'X', 0, 0, 0, 4]),
96            FrontendMessage::SASLInitialResponse { mechanism, data } => {
97                if Self::has_nul(mechanism) {
98                    return Err(FrontendEncodeError::InteriorNul("mechanism"));
99                }
100                if data.len() > i32::MAX as usize {
101                    return Err(FrontendEncodeError::MessageTooLarge(data.len()));
102                }
103                let mut buf = Vec::new();
104                buf.push(b'p');
105
106                let mut content = Vec::new();
107                content.extend_from_slice(mechanism.as_bytes());
108                content.push(0);
109                let data_len = i32::try_from(data.len())
110                    .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
111                content.extend_from_slice(&data_len.to_be_bytes());
112                content.extend_from_slice(data);
113
114                let len = Self::content_len_to_wire_len(content.len())?;
115                buf.extend_from_slice(&len.to_be_bytes());
116                buf.extend_from_slice(&content);
117                Ok(buf)
118            }
119            FrontendMessage::SASLResponse(data) | FrontendMessage::GSSResponse(data) => {
120                if data.len() > i32::MAX as usize {
121                    return Err(FrontendEncodeError::MessageTooLarge(data.len()));
122                }
123                let mut buf = Vec::new();
124                buf.push(b'p');
125                let len = Self::content_len_to_wire_len(data.len())?;
126                buf.extend_from_slice(&len.to_be_bytes());
127                buf.extend_from_slice(data);
128                Ok(buf)
129            }
130            FrontendMessage::PasswordMessage(password) => {
131                if Self::has_nul(password) {
132                    return Err(FrontendEncodeError::InteriorNul("password"));
133                }
134                let mut buf = Vec::new();
135                buf.push(b'p');
136                let mut content = Vec::with_capacity(password.len() + 1);
137                content.extend_from_slice(password.as_bytes());
138                content.push(0);
139                let len = Self::content_len_to_wire_len(content.len())?;
140                buf.extend_from_slice(&len.to_be_bytes());
141                buf.extend_from_slice(&content);
142                Ok(buf)
143            }
144            FrontendMessage::Parse {
145                name,
146                query,
147                param_types,
148            } => {
149                if Self::has_nul(name) {
150                    return Err(FrontendEncodeError::InteriorNul("name"));
151                }
152                if Self::has_nul(query) {
153                    return Err(FrontendEncodeError::InteriorNul("query"));
154                }
155                if param_types.len() > i16::MAX as usize {
156                    return Err(FrontendEncodeError::TooManyParams(param_types.len()));
157                }
158                let mut buf = Vec::new();
159                buf.push(b'P');
160
161                let mut content = Vec::new();
162                content.extend_from_slice(name.as_bytes());
163                content.push(0);
164                content.extend_from_slice(query.as_bytes());
165                content.push(0);
166                let param_count = i16::try_from(param_types.len())
167                    .map_err(|_| FrontendEncodeError::TooManyParams(param_types.len()))?;
168                content.extend_from_slice(&param_count.to_be_bytes());
169                for oid in param_types {
170                    content.extend_from_slice(&oid.to_be_bytes());
171                }
172
173                let len = Self::content_len_to_wire_len(content.len())?;
174                buf.extend_from_slice(&len.to_be_bytes());
175                buf.extend_from_slice(&content);
176                Ok(buf)
177            }
178            FrontendMessage::Bind {
179                portal,
180                statement,
181                params,
182            } => {
183                if Self::has_nul(portal) {
184                    return Err(FrontendEncodeError::InteriorNul("portal"));
185                }
186                if Self::has_nul(statement) {
187                    return Err(FrontendEncodeError::InteriorNul("statement"));
188                }
189                if params.len() > i16::MAX as usize {
190                    return Err(FrontendEncodeError::TooManyParams(params.len()));
191                }
192                if let Some(too_large) = params
193                    .iter()
194                    .flatten()
195                    .find(|p| p.len() > i32::MAX as usize)
196                {
197                    return Err(FrontendEncodeError::MessageTooLarge(too_large.len()));
198                }
199
200                let mut buf = Vec::new();
201                buf.push(b'B');
202
203                let mut content = Vec::new();
204                content.extend_from_slice(portal.as_bytes());
205                content.push(0);
206                content.extend_from_slice(statement.as_bytes());
207                content.push(0);
208                content.extend_from_slice(&0i16.to_be_bytes());
209                let param_count = i16::try_from(params.len())
210                    .map_err(|_| FrontendEncodeError::TooManyParams(params.len()))?;
211                content.extend_from_slice(&param_count.to_be_bytes());
212                for param in params {
213                    match param {
214                        Some(data) => {
215                            let data_len = i32::try_from(data.len())
216                                .map_err(|_| FrontendEncodeError::MessageTooLarge(data.len()))?;
217                            content.extend_from_slice(&data_len.to_be_bytes());
218                            content.extend_from_slice(data);
219                        }
220                        None => content.extend_from_slice(&(-1i32).to_be_bytes()),
221                    }
222                }
223                content.extend_from_slice(&0i16.to_be_bytes());
224
225                let len = Self::content_len_to_wire_len(content.len())?;
226                buf.extend_from_slice(&len.to_be_bytes());
227                buf.extend_from_slice(&content);
228                Ok(buf)
229            }
230            FrontendMessage::Execute { portal, max_rows } => {
231                if Self::has_nul(portal) {
232                    return Err(FrontendEncodeError::InteriorNul("portal"));
233                }
234                if *max_rows < 0 {
235                    return Err(FrontendEncodeError::InvalidMaxRows(*max_rows));
236                }
237                let mut buf = Vec::new();
238                buf.push(b'E');
239                let mut content = Vec::new();
240                content.extend_from_slice(portal.as_bytes());
241                content.push(0);
242                content.extend_from_slice(&max_rows.to_be_bytes());
243                let len = Self::content_len_to_wire_len(content.len())?;
244                buf.extend_from_slice(&len.to_be_bytes());
245                buf.extend_from_slice(&content);
246                Ok(buf)
247            }
248            FrontendMessage::Sync => Ok(vec![b'S', 0, 0, 0, 4]),
249            FrontendMessage::CopyFail(msg) => {
250                if Self::has_nul(msg) {
251                    return Err(FrontendEncodeError::InteriorNul("copy_fail"));
252                }
253                let mut buf = Vec::new();
254                buf.push(b'f');
255                let mut content = Vec::with_capacity(msg.len() + 1);
256                content.extend_from_slice(msg.as_bytes());
257                content.push(0);
258                let len = Self::content_len_to_wire_len(content.len())?;
259                buf.extend_from_slice(&len.to_be_bytes());
260                buf.extend_from_slice(&content);
261                Ok(buf)
262            }
263            FrontendMessage::Close { is_portal, name } => {
264                if Self::has_nul(name) {
265                    return Err(FrontendEncodeError::InteriorNul("name"));
266                }
267                let mut buf = Vec::new();
268                buf.push(b'C');
269                let type_byte = if *is_portal { b'P' } else { b'S' };
270                let mut content = vec![type_byte];
271                content.extend_from_slice(name.as_bytes());
272                content.push(0);
273                let len = Self::content_len_to_wire_len(content.len())?;
274                buf.extend_from_slice(&len.to_be_bytes());
275                buf.extend_from_slice(&content);
276                Ok(buf)
277            }
278        }
279    }
280}