1#[cfg(test)]
4mod tests;
5
6mod header;
7use header::Header;
8
9pub mod types;
10
11use enum_dispatch::enum_dispatch;
12
13use crate::common::{DecodeError, DecodeResult, Reader, SliceReader, VecWriter, Writer};
14use core::borrow::Borrow;
15use std::ops::DerefMut;
16
17#[enum_dispatch]
24#[derive(Clone, Debug, Eq, PartialEq)]
25pub enum AVP {
26 MessageType(types::MessageType),
27 RandomVector(types::RandomVector),
28 ResultCode(types::ResultCode),
29 ProtocolVersion(types::ProtocolVersion),
30 FramingCapabilities(types::FramingCapabilities),
31 BearerCapabilities(types::BearerCapabilities),
32 TieBreaker(types::TieBreaker),
33 FirmwareRevision(types::FirmwareRevision),
34 HostName(types::HostName),
35 VendorName(types::VendorName),
36 AssignedTunnelId(types::AssignedTunnelId),
37 ReceiveWindowSize(types::ReceiveWindowSize),
38 Challenge(types::Challenge),
39 ChallengeResponse(types::ChallengeResponse),
40 Q931CauseCode(types::Q931CauseCode),
41 AssignedSessionId(types::AssignedSessionId),
42 CallSerialNumber(types::CallSerialNumber),
43 MinimumBps(types::MinimumBps),
44 MaximumBps(types::MaximumBps),
45 BearerType(types::BearerType),
46 FramingType(types::FramingType),
47 CalledNumber(types::CalledNumber),
48 CallingNumber(types::CallingNumber),
49 SubAddress(types::SubAddress),
50 TxConnectSpeed(types::TxConnectSpeed),
51 RxConnectSpeed(types::RxConnectSpeed),
52 PhysicalChannelId(types::PhysicalChannelId),
53 PrivateGroupId(types::PrivateGroupId),
54 SequencingRequired(types::SequencingRequired),
55 InitialReceivedLcpConfReq(types::InitialReceivedLcpConfReq),
56 LastSentLcpConfReq(types::LastSentLcpConfReq),
57 LastReceivedLcpConfReq(types::LastReceivedLcpConfReq),
58 ProxyAuthenType(types::ProxyAuthenType),
59 ProxyAuthenName(types::ProxyAuthenName),
60 ProxyAuthenChallenge(types::ProxyAuthenChallenge),
61 ProxyAuthenId(types::ProxyAuthenId),
62 ProxyAuthenResponse(types::ProxyAuthenResponse),
63 CallErrors(types::CallErrors),
64 Accm(types::Accm),
65 Hidden(types::Hidden),
66}
67
68#[enum_dispatch(AVP)]
69pub(crate) trait QueryableAVP {
70 fn get_length(&self) -> usize;
71}
72
73#[enum_dispatch(AVP)]
74pub(crate) trait WritableAVP {
75 fn write(&self, writer: &mut impl Writer);
76}
77
78use AVP::*;
79
/// Returns a printable name for an AVP attribute type.
///
/// Known attribute types yield the name of the matching [`AVP`] variant.
/// Any other value (including the unassigned type 20) is rendered as its
/// decimal number.
pub(crate) fn avp_name(attribute_type: u16) -> String {
    let known_name: Option<&'static str> = match attribute_type {
        0 => Some("MessageType"),
        1 => Some("ResultCode"),
        2 => Some("ProtocolVersion"),
        3 => Some("FramingCapabilities"),
        4 => Some("BearerCapabilities"),
        5 => Some("TieBreaker"),
        6 => Some("FirmwareRevision"),
        7 => Some("HostName"),
        8 => Some("VendorName"),
        9 => Some("AssignedTunnelId"),
        10 => Some("ReceiveWindowSize"),
        11 => Some("Challenge"),
        12 => Some("Q931CauseCode"),
        13 => Some("ChallengeResponse"),
        14 => Some("AssignedSessionId"),
        15 => Some("CallSerialNumber"),
        16 => Some("MinimumBps"),
        17 => Some("MaximumBps"),
        18 => Some("BearerType"),
        19 => Some("FramingType"),
        21 => Some("CalledNumber"),
        22 => Some("CallingNumber"),
        23 => Some("SubAddress"),
        24 => Some("TxConnectSpeed"),
        25 => Some("PhysicalChannelId"),
        26 => Some("InitialReceivedLcpConfReq"),
        27 => Some("LastSentLcpConfReq"),
        28 => Some("LastReceivedLcpConfReq"),
        29 => Some("ProxyAuthenType"),
        30 => Some("ProxyAuthenName"),
        31 => Some("ProxyAuthenChallenge"),
        32 => Some("ProxyAuthenId"),
        33 => Some("ProxyAuthenResponse"),
        34 => Some("CallErrors"),
        35 => Some("Accm"),
        36 => Some("RandomVector"),
        37 => Some("PrivateGroupId"),
        38 => Some("RxConnectSpeed"),
        39 => Some("SequencingRequired"),
        _ => None,
    };

    // Unknown types fall back to their numeric representation.
    known_name.map_or_else(|| attribute_type.to_string(), str::to_owned)
}
126
127fn decode_avp<T: Borrow<[u8]>>(
128 attribute_type: u16,
129 reader: &mut impl Reader<T>,
130) -> DecodeResult<AVP> {
131 Ok(match attribute_type {
132 0u16 => MessageType(types::MessageType::try_read(reader)?),
133 1u16 => ResultCode(types::ResultCode::try_read(reader)?),
134 2u16 => ProtocolVersion(types::ProtocolVersion::try_read(reader)?),
135 3u16 => FramingCapabilities(types::FramingCapabilities::try_read(reader)?),
136 4u16 => BearerCapabilities(types::BearerCapabilities::try_read(reader)?),
137 5u16 => TieBreaker(types::TieBreaker::try_read(reader)?),
138 6u16 => FirmwareRevision(types::FirmwareRevision::try_read(reader)?),
139 7u16 => HostName(types::HostName::try_read(reader)?),
140 8u16 => VendorName(types::VendorName::try_read(reader)?),
141 9u16 => AssignedTunnelId(types::AssignedTunnelId::try_read(reader)?),
142 10u16 => ReceiveWindowSize(types::ReceiveWindowSize::try_read(reader)?),
143 11u16 => Challenge(types::Challenge::try_read(reader)?),
144 12u16 => Q931CauseCode(types::Q931CauseCode::try_read(reader)?),
145 13u16 => ChallengeResponse(types::ChallengeResponse::try_read(reader)?),
146 14u16 => AssignedSessionId(types::AssignedSessionId::try_read(reader)?),
147 15u16 => CallSerialNumber(types::CallSerialNumber::try_read(reader)?),
148 16u16 => MinimumBps(types::MinimumBps::try_read(reader)?),
149 17u16 => MaximumBps(types::MaximumBps::try_read(reader)?),
150 18u16 => BearerType(types::BearerType::try_read(reader)?),
151 19u16 => FramingType(types::FramingType::try_read(reader)?),
152 21u16 => CalledNumber(types::CalledNumber::try_read(reader)?),
153 22u16 => CallingNumber(types::CallingNumber::try_read(reader)?),
154 23u16 => SubAddress(types::SubAddress::try_read(reader)?),
155 24u16 => TxConnectSpeed(types::TxConnectSpeed::try_read(reader)?),
156 25u16 => PhysicalChannelId(types::PhysicalChannelId::try_read(reader)?),
157 26u16 => InitialReceivedLcpConfReq(types::InitialReceivedLcpConfReq::try_read(reader)?),
158 27u16 => LastSentLcpConfReq(types::LastSentLcpConfReq::try_read(reader)?),
159 28u16 => LastReceivedLcpConfReq(types::LastReceivedLcpConfReq::try_read(reader)?),
160 29u16 => ProxyAuthenType(types::ProxyAuthenType::try_read(reader)?),
161 30u16 => ProxyAuthenName(types::ProxyAuthenName::try_read(reader)?),
162 31u16 => ProxyAuthenChallenge(types::ProxyAuthenChallenge::try_read(reader)?),
163 32u16 => ProxyAuthenId(types::ProxyAuthenId::try_read(reader)?),
164 33u16 => ProxyAuthenResponse(types::ProxyAuthenResponse::try_read(reader)?),
165 34u16 => CallErrors(types::CallErrors::try_read(reader)?),
166 35u16 => Accm(types::Accm::try_read(reader)?),
167 36u16 => RandomVector(types::RandomVector::try_read(reader)?),
168 37u16 => PrivateGroupId(types::PrivateGroupId::try_read(reader)?),
169 38u16 => RxConnectSpeed(types::RxConnectSpeed::try_read(reader)?),
170 39u16 => SequencingRequired(types::SequencingRequired::default()),
171 x => Err(DecodeError::UnknownAvp(x))?,
172 })
173}
174
175impl AVP {
176 pub const CRYPTO_CHUNK_SIZE: usize = 16;
177
178 const ATTRIBUTE_TYPE_SIZE: usize = 2;
179 const LENGTH_BITS: u8 = 10;
180 const MAX_LENGTH: u16 = (1 << Self::LENGTH_BITS) - 1;
181
182 pub fn hide(
193 self,
194 secret: &[u8],
195 random_vector: &types::RandomVector,
196 length_padding: &[u8],
197 alignment_padding: &[u8; Self::CRYPTO_CHUNK_SIZE],
198 ) -> Self {
199 match &self {
200 Hidden(_) => self,
201 avp => {
202 let chunk_size: usize = Self::CRYPTO_CHUNK_SIZE;
203
204 let mut writer = VecWriter::new();
205
206 WritableAVP::write(avp, &mut writer);
207 assert!(writer.len() >= Self::ATTRIBUTE_TYPE_SIZE);
208
209 let attribute_type_octets: [u8; Self::ATTRIBUTE_TYPE_SIZE] =
211 writer.data[..Self::ATTRIBUTE_TYPE_SIZE].try_into().unwrap();
212
213 let length =
215 writer.data.len() + Header::LENGTH as usize - Self::ATTRIBUTE_TYPE_SIZE;
216
217 assert!(length <= Self::MAX_LENGTH as usize);
219 let length_octets = (length as u16).to_be_bytes();
220 writer.write_bytes_at(&length_octets, 0);
221
222 let mut input = writer.data;
223
224 input.extend_from_slice(length_padding);
226
227 let chunk_padding_length = (chunk_size - (input.len() % chunk_size)) % chunk_size;
228
229 input.extend_from_slice(&alignment_padding[..chunk_padding_length]);
231
232 let n_chunks = input.len() / chunk_size;
233
234 let buffer_length =
236 Self::ATTRIBUTE_TYPE_SIZE + secret.len() + random_vector.value.len();
237 let mut buffer = Vec::with_capacity(buffer_length);
238
239 buffer.extend_from_slice(&attribute_type_octets);
241 buffer.extend_from_slice(secret);
242 buffer.extend_from_slice(&random_vector.value);
243 let mut intermediate = md5::compute(&buffer);
244 for j in 0..chunk_size {
246 input[j] ^= intermediate[j];
247 }
248
249 if n_chunks > 1 {
250 buffer.clear();
252 buffer.extend_from_slice(secret);
253
254 for i in 1..n_chunks {
256 let prev_chunk_start = (i - 1) * chunk_size;
257 let chunk_start = prev_chunk_start + chunk_size;
258
259 buffer.truncate(secret.len());
261
262 buffer.extend_from_slice(&input[prev_chunk_start..chunk_start]);
264 intermediate = md5::compute(&buffer);
265
266 for j in 0..chunk_size {
268 input[chunk_start + j] ^= intermediate[j];
269 }
270 }
271 }
272
273 Hidden(types::Hidden {
274 attribute_type: u16::from_be_bytes(attribute_type_octets),
275 value: input,
276 })
277 }
278 }
279 }
280
281 pub fn reveal(self, secret: &[u8], random_vector: &types::RandomVector) -> DecodeResult<Self> {
290 if let Hidden(mut hidden) = self {
291 let chunk_size: usize = Self::CRYPTO_CHUNK_SIZE;
292
293 let chunk_data = &mut hidden.value;
294
295 if chunk_data.is_empty() {
296 return Err(DecodeError::EmptyHiddenAVP);
297 }
298 if chunk_data.len() % chunk_size != 0 {
299 return Err(DecodeError::MisalignedHiddenAVP);
300 }
301
302 let n_chunks = chunk_data.len() / chunk_size;
303
304 let buffer_length =
306 Self::ATTRIBUTE_TYPE_SIZE + secret.len() + random_vector.value.len();
307 let mut buffer = Vec::with_capacity(buffer_length);
308
309 if n_chunks > 1 {
310 buffer.extend_from_slice(secret);
312
313 for i in (1..n_chunks).rev() {
315 let prev_chunk_start = (i - 1) * chunk_size;
316 let chunk_start = prev_chunk_start + chunk_size;
317
318 buffer.truncate(secret.len());
320
321 buffer.extend_from_slice(&chunk_data[prev_chunk_start..chunk_start]);
323 let intermediate = md5::compute(&buffer);
324
325 for j in 0..chunk_size {
327 chunk_data[chunk_start + j] ^= intermediate[j];
328 }
329 }
330 }
331
332 buffer.clear();
334 buffer.extend_from_slice(&hidden.attribute_type.to_be_bytes());
335 buffer.extend_from_slice(secret);
336 buffer.extend_from_slice(&random_vector.value);
337 let intermediate = md5::compute(&buffer);
338
339 for j in 0..chunk_size {
341 chunk_data[j] ^= intermediate[j];
342 }
343
344 let mut reader = SliceReader::from(chunk_data.deref_mut());
346 let total_length = unsafe { reader.read_u16_be_unchecked() };
347 if !(Header::LENGTH..=Self::MAX_LENGTH).contains(&total_length) {
348 return Err(DecodeError::InvalidOriginalAVPLength(total_length));
349 }
350 let payload_length = total_length - Header::LENGTH;
351
352 let mut payload_reader = reader.subreader(payload_length as usize);
354
355 return decode_avp(hidden.attribute_type, &mut payload_reader);
356 }
357
358 Ok(self)
359 }
360
361 #[inline]
366 pub fn try_read_greedy<T: Borrow<[u8]>>(
367 reader: &mut impl Reader<T>,
368 ) -> Vec<DecodeResult<Self>> {
369 let mut result = Vec::new();
370 while let Some(header) = Header::try_read(reader) {
371 if header.payload_length as usize > reader.len() {
372 result.push(Err(DecodeError::InvalidAVPLength(header.payload_length)));
373 break;
374 }
375 if header.vendor_id != 0 {
376 result.push(Err(DecodeError::UnsupportedVendorId(header.vendor_id)));
377 reader.skip_bytes(header.payload_length as usize);
378 continue;
379 }
380
381 let avp = if header.flags.is_hidden() {
382 let hidden_data = reader
384 .bytes(header.payload_length as usize)
385 .map(|x| x.borrow().to_owned())
386 .unwrap_or_default();
387 Ok(Self::Hidden(types::Hidden {
388 attribute_type: header.attribute_type,
389 value: hidden_data,
390 }))
391 } else {
392 let mut subreader = reader.subreader(header.payload_length as usize);
394 decode_avp(header.attribute_type, &mut subreader)
395 };
396 result.push(avp);
397 }
398
399 result
400 }
401
402 #[inline]
405 pub fn get_length(&self) -> usize {
406 QueryableAVP::get_length(self)
407 }
408
409 #[inline]
410 fn make_flags_and_length(is_mandatory: bool, is_hidden: bool, length: usize) -> [u8; 2] {
411 assert!(length <= Self::MAX_LENGTH as usize);
412
413 let msb = ((length >> 8) & 0x3) as u8;
414 let lsb = length as u8;
415 let m_bit = is_mandatory as u8;
416 let h_bit = (is_hidden as u8) << 1;
417 let octet1 = (msb << 6) | m_bit | h_bit;
418 let octet2 = lsb;
419 [octet1, octet2]
420 }
421
422 #[inline]
425 pub fn write(&self, writer: &mut impl Writer) {
426 const VENDOR_ID: u16 = 0;
427 const IS_MANDATORY: bool = true;
428
429 let start_position = writer.len();
431
432 writer.write_bytes(&[0, 0]);
434
435 writer.write_u16_be(VENDOR_ID);
437
438 WritableAVP::write(self, writer);
440
441 let end_position = writer.len();
443 let length = end_position - start_position;
444
445 let is_hidden = matches!(self, Hidden(_));
446
447 let flags_and_length = Self::make_flags_and_length(IS_MANDATORY, is_hidden, length);
448
449 writer.write_bytes_at(&flags_and_length, start_position);
451 }
452}