1use core::fmt;
4
5use bitflags::bitflags;
6use byteorder::{ByteOrder, NetworkEndian};
7use getset::{CopyGetters, Getters};
8use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
9
10use super::{
11 AuthenticationContext, AuthenticationType, DeserializeError, MinorVersion, PacketBody,
12 PacketType, Serialize, SerializeError, UserInformation,
13};
14use crate::{Deserialize, FieldText};
15
16#[cfg(test)]
17mod tests;
18
19#[cfg(feature = "std")]
20mod owned;
21
22mod data;
23pub use data::{DataTooLong, PacketData};
24
25#[cfg(feature = "std")]
26pub use owned::ReplyOwned;
27
/// Authentication action carried in the first byte of a [`Start`] packet body.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Action {
    /// Login action (wire value `0x01`).
    Login = 0x01,

    /// Password change action (wire value `0x02`).
    ChangePassword = 0x02,

    /// Send-authentication action (wire value `0x04`).
    SendAuth = 0x04,
}

impl Action {
    /// Number of bytes an [`Action`] occupies on the wire (a single `u8`).
    const WIRE_SIZE: usize = core::mem::size_of::<u8>();
}
50
51#[repr(u8)]
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive)]
54pub enum Status {
55 Pass = 0x01,
57
58 Fail = 0x02,
60
61 GetData = 0x03,
63
64 GetUser = 0x04,
66
67 GetPassword = 0x05,
69
70 Restart = 0x06,
72
73 Error = 0x07,
75
76 #[deprecated = "Forwarding to an alternative daemon was deprecated in RFC-8907."]
78 Follow = 0x21,
79}
80
81impl Status {
82 const WIRE_SIZE: usize = 1;
84}
85
86#[doc(hidden)]
87impl From<TryFromPrimitiveError<Status>> for DeserializeError {
88 fn from(value: TryFromPrimitiveError<Status>) -> Self {
89 Self::InvalidStatus(value.number)
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Hash)]
95pub struct Start<'packet> {
96 action: Action,
97 authentication: AuthenticationContext,
98 user_information: UserInformation<'packet>,
99 data: Option<PacketData<'packet>>,
100}
101
/// Reasons [`Start::new`] can refuse to build an authentication START body.
#[non_exhaustive]
#[derive(Debug, Eq, PartialEq)]
pub enum BadStart {
    /// The authentication type in the supplied context was `AuthenticationType::NotSet`.
    AuthTypeNotSet,

    /// The requested action is not permitted for the chosen authentication type.
    IncompatibleActionAndType,
}

impl fmt::Display for BadStart {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::AuthTypeNotSet => "authentication type must be set for authentication packets",
            Self::IncompatibleActionAndType => "authentication action & type are incompatible",
        };

        f.write_str(description)
    }
}
130
131impl<'packet> Start<'packet> {
132 pub fn new(
134 action: Action,
135 authentication: AuthenticationContext,
136 user_information: UserInformation<'packet>,
137 data: Option<PacketData<'packet>>,
138 ) -> Result<Self, BadStart> {
139 if authentication.authentication_type == AuthenticationType::NotSet {
140 Err(BadStart::AuthTypeNotSet)
142 } else if !Self::action_and_type_compatible(authentication.authentication_type, action) {
143 Err(BadStart::IncompatibleActionAndType)
144 } else {
145 Ok(Self {
146 action,
147 authentication,
148 user_information,
149 data,
150 })
151 }
152 }
153
154 fn action_and_type_compatible(auth_type: AuthenticationType, action: Action) -> bool {
162 match (auth_type, action) {
163 (AuthenticationType::Ascii, Action::Login | Action::ChangePassword) => true,
165
166 (AuthenticationType::Ascii, Action::SendAuth) => false,
168
169 (_, Action::ChangePassword) => false,
171
172 (AuthenticationType::NotSet, _) => unreachable!(),
174
175 _ => true,
177 }
178 }
179}
180
181impl PacketBody for Start<'_> {
182 const TYPE: PacketType = PacketType::Authentication;
183
184 const REQUIRED_FIELDS_LENGTH: usize = Action::WIRE_SIZE
186 + AuthenticationContext::WIRE_SIZE
187 + UserInformation::HEADER_INFORMATION_SIZE
188 + 1;
189
190 fn required_minor_version(&self) -> Option<MinorVersion> {
191 match self.authentication.authentication_type {
193 AuthenticationType::Ascii => Some(MinorVersion::Default),
194 _ => Some(MinorVersion::V1),
195 }
196 }
197}
198
199impl Serialize for Start<'_> {
200 fn wire_size(&self) -> usize {
201 Action::WIRE_SIZE
202 + AuthenticationContext::WIRE_SIZE
203 + self.user_information.wire_size()
204 + 1 + self.data.as_ref().map_or(0, |data| data.as_bytes().len())
206 }
207
208 fn serialize_into_buffer(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
209 let wire_size = self.wire_size();
210
211 if buffer.len() >= self.wire_size() {
212 buffer[0] = self.action as u8;
213
214 self.authentication.serialize(&mut buffer[1..4]);
215
216 self.user_information
217 .serialize_field_lengths(&mut buffer[4..7])?;
218
219 let mut total_bytes_written = 8;
221
222 let user_info_written_len = self
225 .user_information
226 .serialize_field_values(&mut buffer[8..wire_size])?;
227 total_bytes_written += user_info_written_len;
228
229 let data_start = 8 + user_info_written_len;
231 if let Some(data) = self.data.as_ref() {
232 let data_len = data.len();
233
234 buffer[7] = data.len();
236
237 buffer[data_start..data_start + data_len as usize].copy_from_slice(data.as_bytes());
239
240 total_bytes_written += data_len as usize;
241 } else {
242 buffer[7] = 0;
244 }
245
246 if total_bytes_written == wire_size {
247 Ok(total_bytes_written)
248 } else {
249 Err(SerializeError::LengthMismatch {
250 expected: wire_size,
251 actual: total_bytes_written,
252 })
253 }
254 } else {
255 Err(SerializeError::NotEnoughSpace)
256 }
257 }
258}
259
260bitflags! {
261 #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
263 #[repr(transparent)]
264 pub struct ReplyFlags: u8 {
265 const NO_ECHO = 0b00000001;
267 }
268}
269
270impl ReplyFlags {
271 const WIRE_SIZE: usize = 1;
273}
274
275crate::util::bitflags_display_impl!(ReplyFlags);
276
277#[derive(Debug, Clone, PartialEq, Eq, Hash, Getters, CopyGetters)]
279pub struct Reply<'packet> {
280 #[getset(get = "pub")]
282 status: Status,
283
284 #[getset(get = "pub")]
286 server_message: FieldText<'packet>,
287
288 #[getset(get_copy = "pub")]
290 data: &'packet [u8],
291
292 #[getset(get = "pub")]
294 flags: ReplyFlags,
295}
296
/// Lengths decoded from the fixed-size header of a reply body.
struct ReplyFieldLengths {
    /// Full body length: the fixed header plus both variable-length fields.
    total_length: u32,

    /// Length of the server message field, in bytes.
    server_message_length: u16,

    /// Length of the data field, in bytes.
    data_length: u16,
}
302
303impl Reply<'_> {
304 const SERVER_MESSAGE_OFFSET: usize = 6;
306
307 pub fn extract_total_length(buffer: &[u8]) -> Result<u32, DeserializeError> {
309 Self::extract_field_lengths(buffer).map(|lengths| lengths.total_length)
310 }
311
312 fn extract_field_lengths(buffer: &[u8]) -> Result<ReplyFieldLengths, DeserializeError> {
314 if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
316 let server_message_length = NetworkEndian::read_u16(&buffer[2..4]);
317 let data_length = NetworkEndian::read_u16(&buffer[4..6]);
318
319 let total_length = u32::try_from(Self::REQUIRED_FIELDS_LENGTH).unwrap()
322 + u32::from(server_message_length)
323 + u32::from(data_length);
324
325 Ok(ReplyFieldLengths {
326 server_message_length,
327 data_length,
328 total_length,
329 })
330 } else {
331 Err(DeserializeError::UnexpectedEnd)
332 }
333 }
334}
335
336impl PacketBody for Reply<'_> {
337 const TYPE: PacketType = PacketType::Authentication;
338
339 const REQUIRED_FIELDS_LENGTH: usize = Status::WIRE_SIZE + ReplyFlags::WIRE_SIZE + 4;
341}
342
343#[doc(hidden)]
345impl<'raw> Deserialize<'raw> for Reply<'raw> {
346 fn deserialize_from_buffer(buffer: &'raw [u8]) -> Result<Self, DeserializeError> {
347 let field_lengths = Self::extract_field_lengths(buffer)?;
348
349 let length_from_header = buffer.len();
352
353 if field_lengths.total_length as usize == length_from_header {
355 let status = Status::try_from(buffer[0])?;
356 let flag_byte = buffer[1];
357 let flags = ReplyFlags::from_bits(flag_byte)
358 .ok_or(DeserializeError::InvalidBodyFlags(flag_byte))?;
359
360 let data_begin =
361 Self::SERVER_MESSAGE_OFFSET + field_lengths.server_message_length as usize;
362
363 let server_message =
364 FieldText::try_from(&buffer[Self::SERVER_MESSAGE_OFFSET..data_begin])
365 .map_err(|_| DeserializeError::BadText)?;
366 let data = &buffer[data_begin..data_begin + field_lengths.data_length as usize];
367
368 Ok(Reply {
369 status,
370 server_message,
371 data,
372 flags,
373 })
374 } else {
375 Err(DeserializeError::WrongBodyBufferSize {
376 expected: field_lengths.total_length as usize,
377 buffer_size: length_from_header,
378 })
379 }
380 }
381}
382
383bitflags! {
384 #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
386 #[repr(transparent)]
387 pub struct ContinueFlags: u8 {
388 const ABORT = 0b00000001;
390 }
391}
392
393crate::util::bitflags_display_impl!(ContinueFlags);
394
395#[derive(PartialEq, Eq, Clone, Debug, Hash)]
397pub struct Continue<'packet> {
398 user_message: Option<&'packet [u8]>,
399 data: Option<&'packet [u8]>,
400 flags: ContinueFlags,
401}
402
403impl<'packet> Continue<'packet> {
404 const USER_MESSAGE_OFFSET: usize = 5;
406
407 pub fn new(
409 user_message: Option<&'packet [u8]>,
410 data: Option<&'packet [u8]>,
411 flags: ContinueFlags,
412 ) -> Option<Self> {
413 if user_message.map_or(true, |message| u16::try_from(message.len()).is_ok())
414 && data.map_or(true, |data_slice| u16::try_from(data_slice.len()).is_ok())
415 {
416 Some(Continue {
417 user_message,
418 data,
419 flags,
420 })
421 } else {
422 None
423 }
424 }
425}
426
427impl PacketBody for Continue<'_> {
428 const TYPE: PacketType = PacketType::Authentication;
429
430 const REQUIRED_FIELDS_LENGTH: usize = 5;
432}
433
434impl Serialize for Continue<'_> {
435 fn wire_size(&self) -> usize {
436 Self::REQUIRED_FIELDS_LENGTH
437 + self.user_message.map_or(0, <[u8]>::len)
438 + self.data.map_or(0, <[u8]>::len)
439 }
440
441 fn serialize_into_buffer(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
442 let wire_size = self.wire_size();
443
444 if buffer.len() >= wire_size {
445 let user_message_len = self.user_message.map_or(0, <[u8]>::len).try_into()?;
447 NetworkEndian::write_u16(&mut buffer[..2], user_message_len);
448
449 let data_len = self.data.map_or(0, <[u8]>::len).try_into()?;
450 NetworkEndian::write_u16(&mut buffer[2..4], data_len);
451
452 let data_offset = Self::USER_MESSAGE_OFFSET + user_message_len as usize;
453
454 buffer[4] = self.flags.bits();
456
457 if let Some(message) = self.user_message {
459 buffer[Self::USER_MESSAGE_OFFSET..data_offset].copy_from_slice(message);
460 }
461
462 if let Some(data) = self.data {
464 buffer[data_offset..data_offset + data_len as usize].copy_from_slice(data);
465 }
466
467 let actual_written_len =
469 Self::REQUIRED_FIELDS_LENGTH + user_message_len as usize + data_len as usize;
470
471 if actual_written_len == wire_size {
472 Ok(actual_written_len)
473 } else {
474 Err(SerializeError::LengthMismatch {
475 expected: wire_size,
476 actual: actual_written_len,
477 })
478 }
479 } else {
480 Err(SerializeError::NotEnoughSpace)
481 }
482 }
483}