1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::convert::TryFrom;
5use thiserror::Error;
6
/// RPC protocol version stamped into new calls by `CallBody::new` and the only
/// version `decode_call_body` accepts.
pub const RPC_VERSION_2: u32 = 2;

/// Largest credential or verifier body, in bytes, that an `OpaqueAuth` may
/// carry when it is constructed, encoded or decoded.
pub const MAX_AUTH_BYTES: usize = 400;

/// Largest fragment payload length a record marker can express: the low 31
/// bits of the marker word (bit 31 is the last-fragment flag).
pub const MAX_FRAGMENT_LEN: u32 = u32::MAX >> 1;
10
/// Transaction identifier: the first 32-bit word of every encoded RPC message.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Xid(pub u32);
13
/// Program number and program version addressed by a call.
///
/// The two fields are written to the wire in declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProgramVersion {
    /// Remote program number.
    pub program: u32,
    /// Version of that program.
    pub version: u32,
}
19
/// Procedure number within the addressed program, written right after the
/// program/version pair of a call.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Procedure(pub u32);
22
/// Pair of version numbers carried by the "mismatch" replies
/// (`AcceptedStatus::ProgramMismatch` and `RejectedReply::RpcMismatch`).
///
/// No ordering between `low` and `high` is enforced by this type.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct VersionRange {
    /// First word of the pair on the wire.
    pub low: u32,
    /// Second word of the pair on the wire.
    pub high: u32,
}
28
/// Authentication flavor tag that precedes every opaque auth body on the wire.
///
/// Flavor numbers without a dedicated variant are preserved in
/// [`AuthFlavor::Unknown`], so decoding never fails on the flavor alone.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthFlavor {
    /// Wire value 0.
    None,
    /// Wire value 1.
    Sys,
    /// Wire value 2.
    Short,
    /// Wire value 3.
    Dh,
    /// Wire value 6.
    RpcSecGss,
    /// Any other wire value, kept verbatim.
    Unknown(u32),
}

impl AuthFlavor {
    /// Every flavor that has a dedicated variant.
    const NAMED: [Self; 5] = [
        Self::None,
        Self::Sys,
        Self::Short,
        Self::Dh,
        Self::RpcSecGss,
    ];

    /// Returns the 32-bit wire value of this flavor.
    ///
    /// `Unknown` returns whatever number it wraps, even one that also has a
    /// named variant.
    pub fn number(self) -> u32 {
        match self {
            Self::Unknown(raw) => raw,
            Self::RpcSecGss => 6,
            Self::Dh => 3,
            Self::Short => 2,
            Self::Sys => 1,
            Self::None => 0,
        }
    }
}

impl From<u32> for AuthFlavor {
    /// Maps a wire value to its named flavor, or to `Unknown(value)` when no
    /// named flavor uses that number.
    fn from(value: u32) -> Self {
        AuthFlavor::NAMED
            .iter()
            .copied()
            .find(|flavor| flavor.number() == value)
            .unwrap_or(AuthFlavor::Unknown(value))
    }
}
64
65#[derive(Debug, Clone, PartialEq, Eq)]
66pub struct OpaqueAuth {
67 pub flavor: AuthFlavor,
68 pub body: Bytes,
69}
70
71impl OpaqueAuth {
72 pub fn new(flavor: AuthFlavor, body: Bytes) -> Result<Self, WireError> {
73 validate_auth_len(body.len())?;
74 Ok(Self { flavor, body })
75 }
76
77 pub fn none() -> Self {
78 Self {
79 flavor: AuthFlavor::None,
80 body: Bytes::new(),
81 }
82 }
83
84 fn encode_into(&self, output: &mut BytesMut) -> Result<(), WireError> {
85 validate_auth_len(self.body.len())?;
86 output.put_u32(self.flavor.number());
87 output.put_u32(self.body.len() as u32);
88 output.extend_from_slice(&self.body);
89 pad_to_xdr_alignment(output, self.body.len());
90 Ok(())
91 }
92
93 fn decode(input: &mut &[u8]) -> Result<Self, WireError> {
94 let flavor = AuthFlavor::from(read_u32(input)?);
95 let body_len = read_u32(input)? as usize;
96 validate_auth_len(body_len)?;
97 let body = read_opaque(input, body_len)?;
98
99 Ok(Self { flavor, body })
100 }
101}
102
/// Decoded form of the 4-byte record-marking header that precedes each
/// fragment: a last-fragment flag and the fragment's payload length.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RecordMarker {
    /// `true` when this fragment completes the record (bit 31 of the header).
    pub last_fragment: bool,
    /// Number of payload bytes that follow the header (low 31 bits).
    pub payload_len: u32,
}
108
109impl RecordMarker {
110 pub fn new(payload_len: u32, last_fragment: bool) -> Result<Self, WireError> {
111 if payload_len > MAX_FRAGMENT_LEN {
112 return Err(WireError::FragmentLengthTooLarge(payload_len));
113 }
114
115 Ok(Self {
116 last_fragment,
117 payload_len,
118 })
119 }
120
121 pub fn encode(self) -> u32 {
122 let final_fragment = if self.last_fragment { 1_u32 << 31 } else { 0 };
123 final_fragment | self.payload_len
124 }
125
126 pub fn encode_bytes(self) -> [u8; 4] {
127 self.encode().to_be_bytes()
128 }
129
130 pub fn decode(word: u32) -> Self {
131 Self {
132 last_fragment: (word & (1_u32 << 31)) != 0,
133 payload_len: word & MAX_FRAGMENT_LEN,
134 }
135 }
136
137 pub fn decode_bytes(header: [u8; 4]) -> Self {
138 Self::decode(u32::from_be_bytes(header))
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Eq)]
143pub enum RecordRead {
144 Incomplete,
145 Complete {
146 data: Bytes,
147 consumed: usize,
148 fragments: usize,
149 },
150}
151
152#[derive(Debug, Clone, Copy, PartialEq, Eq)]
153pub enum MessageType {
154 Call,
155 Reply,
156}
157
158impl MessageType {
159 fn number(self) -> u32 {
160 match self {
161 Self::Call => 0,
162 Self::Reply => 1,
163 }
164 }
165}
166
167impl TryFrom<u32> for MessageType {
168 type Error = WireError;
169
170 fn try_from(value: u32) -> Result<Self, Self::Error> {
171 match value {
172 0 => Ok(Self::Call),
173 1 => Ok(Self::Reply),
174 other => Err(WireError::InvalidMessageType(other)),
175 }
176 }
177}
178
179#[derive(Debug, Clone, Copy, PartialEq, Eq)]
180pub enum ReplyStat {
181 MessageAccepted,
182 MessageDenied,
183}
184
185impl ReplyStat {
186 fn number(self) -> u32 {
187 match self {
188 Self::MessageAccepted => 0,
189 Self::MessageDenied => 1,
190 }
191 }
192}
193
194impl TryFrom<u32> for ReplyStat {
195 type Error = WireError;
196
197 fn try_from(value: u32) -> Result<Self, Self::Error> {
198 match value {
199 0 => Ok(Self::MessageAccepted),
200 1 => Ok(Self::MessageDenied),
201 other => Err(WireError::InvalidReplyStat(other)),
202 }
203 }
204}
205
206#[derive(Debug, Clone, Copy, PartialEq, Eq)]
207pub enum AcceptStat {
208 Success,
209 ProgramUnavailable,
210 ProgramMismatch,
211 ProcedureUnavailable,
212 GarbageArgs,
213 SystemError,
214}
215
216impl AcceptStat {
217 fn number(self) -> u32 {
218 match self {
219 Self::Success => 0,
220 Self::ProgramUnavailable => 1,
221 Self::ProgramMismatch => 2,
222 Self::ProcedureUnavailable => 3,
223 Self::GarbageArgs => 4,
224 Self::SystemError => 5,
225 }
226 }
227}
228
229impl TryFrom<u32> for AcceptStat {
230 type Error = WireError;
231
232 fn try_from(value: u32) -> Result<Self, Self::Error> {
233 match value {
234 0 => Ok(Self::Success),
235 1 => Ok(Self::ProgramUnavailable),
236 2 => Ok(Self::ProgramMismatch),
237 3 => Ok(Self::ProcedureUnavailable),
238 4 => Ok(Self::GarbageArgs),
239 5 => Ok(Self::SystemError),
240 other => Err(WireError::InvalidAcceptStat(other)),
241 }
242 }
243}
244
245#[derive(Debug, Clone, Copy, PartialEq, Eq)]
246pub enum RejectStat {
247 RpcMismatch,
248 AuthError,
249}
250
251impl RejectStat {
252 fn number(self) -> u32 {
253 match self {
254 Self::RpcMismatch => 0,
255 Self::AuthError => 1,
256 }
257 }
258}
259
260impl TryFrom<u32> for RejectStat {
261 type Error = WireError;
262
263 fn try_from(value: u32) -> Result<Self, Self::Error> {
264 match value {
265 0 => Ok(Self::RpcMismatch),
266 1 => Ok(Self::AuthError),
267 other => Err(WireError::InvalidRejectStat(other)),
268 }
269 }
270}
271
/// Authentication status carried by an auth-error rejection.
///
/// Wire values `0..=14` have named variants; anything else is preserved in
/// [`AuthStat::Unknown`], so decoding never fails on this field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthStat {
    /// Wire value 0.
    Ok,
    /// Wire value 1.
    BadCred,
    /// Wire value 2.
    RejectedCred,
    /// Wire value 3.
    BadVerf,
    /// Wire value 4.
    RejectedVerf,
    /// Wire value 5.
    TooWeak,
    /// Wire value 6.
    InvalidResp,
    /// Wire value 7.
    Failed,
    /// Wire value 8.
    KerbGeneric,
    /// Wire value 9.
    TimeExpire,
    /// Wire value 10.
    TicketFile,
    /// Wire value 11.
    Decode,
    /// Wire value 12.
    NetAddr,
    /// Wire value 13.
    RpcSecGssCredProblem,
    /// Wire value 14.
    RpcSecGssCtxProblem,
    /// Any other wire value, kept verbatim.
    Unknown(u32),
}

impl AuthStat {
    /// Every status that has a dedicated variant.
    const NAMED: [Self; 15] = [
        Self::Ok,
        Self::BadCred,
        Self::RejectedCred,
        Self::BadVerf,
        Self::RejectedVerf,
        Self::TooWeak,
        Self::InvalidResp,
        Self::Failed,
        Self::KerbGeneric,
        Self::TimeExpire,
        Self::TicketFile,
        Self::Decode,
        Self::NetAddr,
        Self::RpcSecGssCredProblem,
        Self::RpcSecGssCtxProblem,
    ];

    /// Wire value of this status; `Unknown` returns the number it wraps.
    fn number(self) -> u32 {
        match self {
            Self::Unknown(raw) => raw,
            Self::RpcSecGssCtxProblem => 14,
            Self::RpcSecGssCredProblem => 13,
            Self::NetAddr => 12,
            Self::Decode => 11,
            Self::TicketFile => 10,
            Self::TimeExpire => 9,
            Self::KerbGeneric => 8,
            Self::Failed => 7,
            Self::InvalidResp => 6,
            Self::TooWeak => 5,
            Self::RejectedVerf => 4,
            Self::BadVerf => 3,
            Self::RejectedCred => 2,
            Self::BadCred => 1,
            Self::Ok => 0,
        }
    }
}

impl From<u32> for AuthStat {
    /// Maps a wire value to its named status, or to `Unknown(value)` when no
    /// named status uses that number.
    fn from(value: u32) -> Self {
        AuthStat::NAMED
            .iter()
            .copied()
            .find(|status| status.number() == value)
            .unwrap_or(AuthStat::Unknown(value))
    }
}
337
338#[derive(Debug, Clone, PartialEq, Eq)]
339pub struct CallBody {
340 pub rpc_version: u32,
341 pub program: ProgramVersion,
342 pub procedure: Procedure,
343 pub credentials: OpaqueAuth,
344 pub verifier: OpaqueAuth,
345 pub payload: Bytes,
346}
347
348impl CallBody {
349 pub fn new(
350 program: ProgramVersion,
351 procedure: Procedure,
352 credentials: OpaqueAuth,
353 verifier: OpaqueAuth,
354 payload: Bytes,
355 ) -> Self {
356 Self {
357 rpc_version: RPC_VERSION_2,
358 program,
359 procedure,
360 credentials,
361 verifier,
362 payload,
363 }
364 }
365}
366
367#[derive(Debug, Clone, PartialEq, Eq)]
368pub enum AcceptedStatus {
369 Success(Bytes),
370 ProgramUnavailable,
371 ProgramMismatch(VersionRange),
372 ProcedureUnavailable,
373 GarbageArgs,
374 SystemError,
375}
376
377#[derive(Debug, Clone, PartialEq, Eq)]
378pub struct AcceptedReply {
379 pub verifier: OpaqueAuth,
380 pub status: AcceptedStatus,
381}
382
383#[derive(Debug, Clone, PartialEq, Eq)]
384pub enum RejectedReply {
385 RpcMismatch(VersionRange),
386 AuthError(AuthStat),
387}
388
389#[derive(Debug, Clone, PartialEq, Eq)]
390pub enum ReplyBody {
391 Accepted(AcceptedReply),
392 Denied(RejectedReply),
393}
394
395#[derive(Debug, Clone, PartialEq, Eq)]
396pub enum MessageBody {
397 Call(CallBody),
398 Reply(ReplyBody),
399}
400
401#[derive(Debug, Clone, PartialEq, Eq)]
402pub struct RpcMessage {
403 pub xid: Xid,
404 pub body: MessageBody,
405}
406
407impl RpcMessage {
408 pub fn encode(&self) -> Result<Bytes, WireError> {
409 let mut output = BytesMut::new();
410 output.put_u32(self.xid.0);
411
412 match &self.body {
413 MessageBody::Call(call) => {
414 output.put_u32(MessageType::Call.number());
415 output.put_u32(call.rpc_version);
416 output.put_u32(call.program.program);
417 output.put_u32(call.program.version);
418 output.put_u32(call.procedure.0);
419 call.credentials.encode_into(&mut output)?;
420 call.verifier.encode_into(&mut output)?;
421 output.extend_from_slice(&call.payload);
422 }
423 MessageBody::Reply(reply) => {
424 output.put_u32(MessageType::Reply.number());
425 encode_reply_body(reply, &mut output)?;
426 }
427 }
428
429 Ok(output.freeze())
430 }
431
432 pub fn decode(input: &[u8]) -> Result<Self, WireError> {
433 let mut input = input;
434 let xid = Xid(read_u32(&mut input)?);
435 let message_type = MessageType::try_from(read_u32(&mut input)?)?;
436
437 let body = match message_type {
438 MessageType::Call => MessageBody::Call(decode_call_body(&mut input)?),
439 MessageType::Reply => MessageBody::Reply(decode_reply_body(&mut input)?),
440 };
441
442 Ok(Self { xid, body })
443 }
444}
445
446pub fn fragment_record(data: &[u8], max_fragment_len: u32) -> Result<Vec<Bytes>, WireError> {
447 if max_fragment_len == 0 {
448 return Err(WireError::ZeroFragmentLength);
449 }
450 if max_fragment_len > MAX_FRAGMENT_LEN {
451 return Err(WireError::FragmentLengthTooLarge(max_fragment_len));
452 }
453
454 let chunk_len = max_fragment_len as usize;
455 let fragment_count = data.len().div_ceil(chunk_len).max(1);
456 let mut fragments = Vec::with_capacity(fragment_count);
457
458 if data.is_empty() {
459 let marker = RecordMarker::new(0, true)?;
460 fragments.push(Bytes::copy_from_slice(&marker.encode_bytes()));
461 return Ok(fragments);
462 }
463
464 for (idx, chunk) in data.chunks(chunk_len).enumerate() {
465 let marker = RecordMarker::new(chunk.len() as u32, idx + 1 == fragment_count)?;
466 let mut fragment = BytesMut::with_capacity(4 + chunk.len());
467 fragment.extend_from_slice(&marker.encode_bytes());
468 fragment.extend_from_slice(chunk);
469 fragments.push(fragment.freeze());
470 }
471
472 Ok(fragments)
473}
474
475pub fn read_record(input: &[u8]) -> Result<RecordRead, WireError> {
476 let mut offset = 0_usize;
477 let mut fragments = 0_usize;
478 let mut record = BytesMut::new();
479
480 loop {
481 let remaining = &input[offset..];
482 if remaining.is_empty() && fragments == 0 {
483 return Ok(RecordRead::Incomplete);
484 }
485 if remaining.len() < 4 {
486 return Ok(RecordRead::Incomplete);
487 }
488
489 let marker =
490 RecordMarker::decode_bytes(remaining[..4].try_into().expect("slice len checked"));
491 let fragment_len = marker.payload_len as usize;
492 let fragment_end = 4 + fragment_len;
493
494 if remaining.len() < fragment_end {
495 return Ok(RecordRead::Incomplete);
496 }
497
498 record.extend_from_slice(&remaining[4..fragment_end]);
499 offset += fragment_end;
500 fragments += 1;
501
502 if marker.last_fragment {
503 return Ok(RecordRead::Complete {
504 data: record.freeze(),
505 consumed: offset,
506 fragments,
507 });
508 }
509 }
510}
511
512fn decode_call_body(input: &mut &[u8]) -> Result<CallBody, WireError> {
513 let rpc_version = read_u32(input)?;
514 if rpc_version != RPC_VERSION_2 {
515 return Err(WireError::UnsupportedRpcVersion(rpc_version));
516 }
517
518 let program = ProgramVersion {
519 program: read_u32(input)?,
520 version: read_u32(input)?,
521 };
522 let procedure = Procedure(read_u32(input)?);
523 let credentials = OpaqueAuth::decode(input)?;
524 let verifier = OpaqueAuth::decode(input)?;
525 let payload = Bytes::copy_from_slice(input);
526 *input = &[];
527
528 Ok(CallBody {
529 rpc_version,
530 program,
531 procedure,
532 credentials,
533 verifier,
534 payload,
535 })
536}
537
538fn decode_reply_body(input: &mut &[u8]) -> Result<ReplyBody, WireError> {
539 match ReplyStat::try_from(read_u32(input)?)? {
540 ReplyStat::MessageAccepted => decode_accepted_reply(input).map(ReplyBody::Accepted),
541 ReplyStat::MessageDenied => decode_rejected_reply(input).map(ReplyBody::Denied),
542 }
543}
544
545fn decode_accepted_reply(input: &mut &[u8]) -> Result<AcceptedReply, WireError> {
546 let verifier = OpaqueAuth::decode(input)?;
547 let status = match AcceptStat::try_from(read_u32(input)?)? {
548 AcceptStat::Success => {
549 let payload = Bytes::copy_from_slice(input);
550 *input = &[];
551 AcceptedStatus::Success(payload)
552 }
553 AcceptStat::ProgramUnavailable => expect_end(input, AcceptedStatus::ProgramUnavailable)?,
554 AcceptStat::ProgramMismatch => {
555 let range = VersionRange {
556 low: read_u32(input)?,
557 high: read_u32(input)?,
558 };
559 expect_end(input, AcceptedStatus::ProgramMismatch(range))?
560 }
561 AcceptStat::ProcedureUnavailable => {
562 expect_end(input, AcceptedStatus::ProcedureUnavailable)?
563 }
564 AcceptStat::GarbageArgs => expect_end(input, AcceptedStatus::GarbageArgs)?,
565 AcceptStat::SystemError => expect_end(input, AcceptedStatus::SystemError)?,
566 };
567
568 Ok(AcceptedReply { verifier, status })
569}
570
571fn decode_rejected_reply(input: &mut &[u8]) -> Result<RejectedReply, WireError> {
572 match RejectStat::try_from(read_u32(input)?)? {
573 RejectStat::RpcMismatch => {
574 let range = VersionRange {
575 low: read_u32(input)?,
576 high: read_u32(input)?,
577 };
578 expect_end(input, RejectedReply::RpcMismatch(range))
579 }
580 RejectStat::AuthError => {
581 let status = AuthStat::from(read_u32(input)?);
582 expect_end(input, RejectedReply::AuthError(status))
583 }
584 }
585}
586
587fn encode_reply_body(reply: &ReplyBody, output: &mut BytesMut) -> Result<(), WireError> {
588 match reply {
589 ReplyBody::Accepted(accepted) => {
590 output.put_u32(ReplyStat::MessageAccepted.number());
591 accepted.verifier.encode_into(output)?;
592 match &accepted.status {
593 AcceptedStatus::Success(results) => {
594 output.put_u32(AcceptStat::Success.number());
595 output.extend_from_slice(results);
596 }
597 AcceptedStatus::ProgramUnavailable => {
598 output.put_u32(AcceptStat::ProgramUnavailable.number());
599 }
600 AcceptedStatus::ProgramMismatch(range) => {
601 output.put_u32(AcceptStat::ProgramMismatch.number());
602 output.put_u32(range.low);
603 output.put_u32(range.high);
604 }
605 AcceptedStatus::ProcedureUnavailable => {
606 output.put_u32(AcceptStat::ProcedureUnavailable.number());
607 }
608 AcceptedStatus::GarbageArgs => {
609 output.put_u32(AcceptStat::GarbageArgs.number());
610 }
611 AcceptedStatus::SystemError => {
612 output.put_u32(AcceptStat::SystemError.number());
613 }
614 }
615 }
616 ReplyBody::Denied(rejected) => {
617 output.put_u32(ReplyStat::MessageDenied.number());
618 match rejected {
619 RejectedReply::RpcMismatch(range) => {
620 output.put_u32(RejectStat::RpcMismatch.number());
621 output.put_u32(range.low);
622 output.put_u32(range.high);
623 }
624 RejectedReply::AuthError(status) => {
625 output.put_u32(RejectStat::AuthError.number());
626 output.put_u32(status.number());
627 }
628 }
629 }
630 }
631
632 Ok(())
633}
634
635fn read_u32(input: &mut &[u8]) -> Result<u32, WireError> {
636 if input.len() < 4 {
637 return Err(WireError::UnexpectedEof {
638 needed: 4,
639 remaining: input.len(),
640 });
641 }
642
643 Ok(input.get_u32())
644}
645
646fn read_opaque(input: &mut &[u8], len: usize) -> Result<Bytes, WireError> {
647 if input.len() < len {
648 return Err(WireError::UnexpectedEof {
649 needed: len,
650 remaining: input.len(),
651 });
652 }
653
654 let body = Bytes::copy_from_slice(&input[..len]);
655 *input = &input[len..];
656
657 let padding = xdr_padding(len);
658 if input.len() < padding {
659 return Err(WireError::UnexpectedEof {
660 needed: padding,
661 remaining: input.len(),
662 });
663 }
664 if input[..padding].iter().any(|byte| *byte != 0) {
665 return Err(WireError::NonZeroXdrPadding);
666 }
667 *input = &input[padding..];
668
669 Ok(body)
670}
671
672fn pad_to_xdr_alignment(output: &mut BytesMut, len: usize) {
673 let padding = xdr_padding(len);
674 for _ in 0..padding {
675 output.put_u8(0);
676 }
677}
678
/// Number of padding bytes (0 to 3) that round `len` up to a multiple of four.
fn xdr_padding(len: usize) -> usize {
    match len % 4 {
        0 => 0,
        remainder => 4 - remainder,
    }
}
682
683fn validate_auth_len(len: usize) -> Result<(), WireError> {
684 if len > MAX_AUTH_BYTES {
685 return Err(WireError::AuthBodyTooLong(len));
686 }
687
688 Ok(())
689}
690
691fn expect_end<T>(input: &mut &[u8], value: T) -> Result<T, WireError> {
692 if input.is_empty() {
693 Ok(value)
694 } else {
695 Err(WireError::TrailingBytes(input.len()))
696 }
697}
698
699#[derive(Debug, Clone, Error, PartialEq, Eq)]
700pub enum WireError {
701 #[error("unexpected end of input: need {needed} bytes, have {remaining}")]
702 UnexpectedEof { needed: usize, remaining: usize },
703 #[error("record fragment length {0} exceeds the 31-bit record marking limit")]
704 FragmentLengthTooLarge(u32),
705 #[error("record fragment size cannot be zero when fragmenting a non-empty record")]
706 ZeroFragmentLength,
707 #[error("unsupported rpc version {0}")]
708 UnsupportedRpcVersion(u32),
709 #[error("invalid message type discriminant {0}")]
710 InvalidMessageType(u32),
711 #[error("invalid reply status discriminant {0}")]
712 InvalidReplyStat(u32),
713 #[error("invalid accept status discriminant {0}")]
714 InvalidAcceptStat(u32),
715 #[error("invalid reject status discriminant {0}")]
716 InvalidRejectStat(u32),
717 #[error("auth body exceeds RFC 5531 limit of 400 bytes: {0}")]
718 AuthBodyTooLong(usize),
719 #[error("encountered non-zero XDR padding bytes")]
720 NonZeroXdrPadding,
721 #[error("unexpected trailing bytes after fixed-width reply branch: {0}")]
722 TrailingBytes(usize),
723}
724
// Unit tests for record marking, opaque auth validation and message
// encode/decode round trips and rejections.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AuthFlavor::None` auth value around `body`; panics if the
    /// body is rejected by `OpaqueAuth::new`.
    fn auth(body: &[u8]) -> OpaqueAuth {
        OpaqueAuth::new(AuthFlavor::None, Bytes::copy_from_slice(body)).expect("valid auth")
    }

    /// A final-fragment marker sets bit 31 and survives a byte round trip.
    #[test]
    fn record_marker_encodes_and_decodes() {
        let marker = RecordMarker::new(1024, true).expect("valid marker");

        // 0x8000_0000 (last fragment) | 0x400 (1024 payload bytes).
        assert_eq!(marker.encode(), 0x8000_0400);
        assert_eq!(
            RecordMarker::decode_bytes(marker.encode_bytes()),
            RecordMarker {
                last_fragment: true,
                payload_len: 1024,
            }
        );
    }

    /// Eight bytes with a limit of three give fragments of 3, 3 and 2 bytes;
    /// only the last header has the last-fragment bit set.
    #[test]
    fn fragment_record_splits_payload_into_multiple_fragments() {
        let data = b"abcdefgh";
        let fragments = fragment_record(data, 3).expect("fragmentation should succeed");

        assert_eq!(fragments.len(), 3);
        assert_eq!(&fragments[0][..4], &[0x00, 0x00, 0x00, 0x03]);
        assert_eq!(&fragments[1][..4], &[0x00, 0x00, 0x00, 0x03]);
        assert_eq!(&fragments[2][..4], &[0x80, 0x00, 0x00, 0x02]);
        assert_eq!(&fragments[0][4..], b"abc");
        assert_eq!(&fragments[1][4..], b"def");
        assert_eq!(&fragments[2][4..], b"gh");
    }

    /// Concatenated fragments are reassembled into the original payload and
    /// the whole buffer is reported as consumed.
    #[test]
    fn read_record_reassembles_fragmented_message() {
        let mut encoded = Vec::new();
        for fragment in fragment_record(b"abcdefgh", 3).expect("fragmentation should succeed") {
            encoded.extend_from_slice(&fragment);
        }

        let record = read_record(&encoded).expect("record read should succeed");
        assert_eq!(
            record,
            RecordRead::Complete {
                data: Bytes::from_static(b"abcdefgh"),
                consumed: encoded.len(),
                fragments: 3,
            }
        );
    }

    /// Fewer than four header bytes is "need more data", not an error.
    #[test]
    fn read_record_reports_incomplete_header() {
        let record = read_record(&[0x80, 0x00]).expect("partial header is not malformed");
        assert_eq!(record, RecordRead::Incomplete);
    }

    /// A header announcing four payload bytes followed by only two is
    /// "need more data", not an error.
    #[test]
    fn read_record_reports_incomplete_fragment_payload() {
        let record = read_record(&[0x80, 0x00, 0x00, 0x04, 0xaa, 0xbb])
            .expect("short payload is not malformed");
        assert_eq!(record, RecordRead::Incomplete);
    }

    /// One byte over the limit is rejected and the error reports the length.
    #[test]
    fn opaque_auth_rejects_oversized_body() {
        let body = Bytes::from(vec![0_u8; MAX_AUTH_BYTES + 1]);
        let error = OpaqueAuth::new(AuthFlavor::Sys, body).expect_err("body must be rejected");

        assert_eq!(error, WireError::AuthBodyTooLong(MAX_AUTH_BYTES + 1));
    }

    /// A call with an empty credential, a 3-byte verifier (which needs one
    /// padding byte) and a 4-byte payload decodes back to the same value.
    #[test]
    fn rpc_call_round_trips() {
        let message = RpcMessage {
            xid: Xid(0x0102_0304),
            body: MessageBody::Call(CallBody::new(
                ProgramVersion {
                    program: 100_003,
                    version: 3,
                },
                Procedure(1),
                auth(&[]),
                auth(&[0xaa, 0xbb, 0xcc]),
                Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef]),
            )),
        };

        let encoded = message.encode().expect("call should encode");
        let decoded = RpcMessage::decode(&encoded).expect("call should decode");

        assert_eq!(decoded, message);
    }

    /// An accepted, successful reply decodes back to the same value.
    #[test]
    fn rpc_reply_round_trips() {
        let message = RpcMessage {
            xid: Xid(42),
            body: MessageBody::Reply(ReplyBody::Accepted(AcceptedReply {
                verifier: auth(&[]),
                status: AcceptedStatus::Success(Bytes::from_static(&[0, 1, 2, 3])),
            })),
        };

        let encoded = message.encode().expect("reply should encode");
        let decoded = RpcMessage::decode(&encoded).expect("reply should decode");

        assert_eq!(decoded, message);
    }

    /// A message-type word of 9 is rejected with the offending value.
    #[test]
    fn rpc_decode_rejects_invalid_message_type() {
        let mut bytes = BytesMut::new();
        bytes.put_u32(7); // xid
        bytes.put_u32(9); // message type: neither call (0) nor reply (1)

        let error = RpcMessage::decode(&bytes).expect_err("invalid message type must fail");
        assert_eq!(error, WireError::InvalidMessageType(9));
    }

    /// An accept-status word of 88 is rejected with the offending value.
    #[test]
    fn rpc_decode_rejects_invalid_accept_status() {
        let mut bytes = BytesMut::new();
        bytes.put_u32(7); // xid
        bytes.put_u32(MessageType::Reply.number());
        bytes.put_u32(ReplyStat::MessageAccepted.number());
        auth(&[]).encode_into(&mut bytes).expect("verifier encodes");
        bytes.put_u32(88); // accept status outside 0..=5

        let error = RpcMessage::decode(&bytes).expect_err("invalid accept status must fail");
        assert_eq!(error, WireError::InvalidAcceptStat(88));
    }

    /// A credential whose single padding byte is 0xff instead of zero is
    /// rejected.
    #[test]
    fn rpc_decode_rejects_non_zero_auth_padding() {
        let bytes = [
            0, 0, 0, 1, // xid
            0, 0, 0, 0, // message type: call
            0, 0, 0, 2, // rpc version
            0, 0, 0, 3, // program
            0, 0, 0, 4, // program version
            0, 0, 0, 5, // procedure
            0, 0, 0, 0, // credential flavor: none
            0, 0, 0, 3, // credential body length
            1, 2, 3, 0xff, // credential body + one non-zero padding byte
            0, 0, 0, 0, // verifier flavor: none
            0, 0, 0, 0, // verifier body length
        ];

        let error = RpcMessage::decode(&bytes).expect_err("invalid padding must fail");
        assert_eq!(error, WireError::NonZeroXdrPadding);
    }

    /// An auth-error rejection is fixed width, so one extra word after it is
    /// reported as four trailing bytes.
    #[test]
    fn rpc_decode_rejects_trailing_bytes_on_fixed_reply_branch() {
        let mut bytes = BytesMut::new();
        bytes.put_u32(9); // xid
        bytes.put_u32(MessageType::Reply.number());
        bytes.put_u32(ReplyStat::MessageDenied.number());
        bytes.put_u32(RejectStat::AuthError.number());
        bytes.put_u32(AuthStat::Failed.number());
        bytes.put_u32(0xdead_beef); // unexpected extra word

        let error = RpcMessage::decode(&bytes).expect_err("trailing bytes must fail");
        assert_eq!(error, WireError::TrailingBytes(4));
    }

    /// A call announcing RPC version 99 is rejected with that version.
    #[test]
    fn rpc_decode_rejects_wrong_rpc_version() {
        let mut bytes = BytesMut::new();
        bytes.put_u32(1); // xid
        bytes.put_u32(MessageType::Call.number());
        bytes.put_u32(99); // rpc version (unsupported)
        bytes.put_u32(3); // program
        bytes.put_u32(4); // program version
        bytes.put_u32(5); // procedure
        auth(&[]).encode_into(&mut bytes).expect("cred encodes");
        auth(&[]).encode_into(&mut bytes).expect("verifier encodes");

        let error = RpcMessage::decode(&bytes).expect_err("unexpected rpc version must fail");
        assert_eq!(error, WireError::UnsupportedRpcVersion(99));
    }
}