gmcrypto_core/asn1/
ciphertext.rs1use alloc::vec::Vec;
41use crypto_bigint::U256;
42use subtle::ConstantTimeLess;
43
44use crate::sm2::curve::Fp;
45
/// Byte length of the `hash` field: SM3 yields a 256-bit digest, i.e. 32 bytes.
const HASH_LEN: usize = 256 / 8;
48
49#[derive(Clone, Debug)]
55pub struct Sm2Ciphertext {
56 pub x: U256,
58 pub y: U256,
60 pub hash: [u8; HASH_LEN],
62 pub ciphertext: Vec<u8>,
64}
65
66#[must_use]
68pub fn encode(ct: &Sm2Ciphertext) -> Vec<u8> {
69 let x_der = encode_integer(&ct.x.to_be_bytes());
70 let y_der = encode_integer(&ct.y.to_be_bytes());
71 let hash_der = encode_octet_string(&ct.hash);
72 let ciphertext_der = encode_octet_string(&ct.ciphertext);
73 let body_len = x_der.len() + y_der.len() + hash_der.len() + ciphertext_der.len();
74 let mut out = Vec::with_capacity(body_len + 8);
75 out.push(0x30); push_length(&mut out, body_len);
77 out.extend_from_slice(&x_der);
78 out.extend_from_slice(&y_der);
79 out.extend_from_slice(&hash_der);
80 out.extend_from_slice(&ciphertext_der);
81 out
82}
83
84#[must_use]
88pub fn decode(input: &[u8]) -> Option<Sm2Ciphertext> {
89 let (tag, rest) = input.split_first()?;
90 if *tag != 0x30 {
91 return None;
92 }
93 let (body_len, rest) = read_length(rest)?;
94 if rest.len() != body_len {
95 return None;
96 }
97 let (x, rest) = read_integer(rest)?;
98 let (y, rest) = read_integer(rest)?;
99 let (hash_bytes, rest) = read_octet_string(rest)?;
100 let (ciphertext, rest) = read_octet_string(rest)?;
101 if !rest.is_empty() {
102 return None;
103 }
104 if hash_bytes.len() != HASH_LEN {
105 return None;
106 }
107 let mut hash = [0u8; HASH_LEN];
108 hash.copy_from_slice(hash_bytes);
109 Some(Sm2Ciphertext {
110 x,
111 y,
112 hash,
113 ciphertext: ciphertext.to_vec(),
114 })
115}
116
117fn encode_integer(value_be: &[u8]) -> Vec<u8> {
126 let mut start = 0;
130 while start < value_be.len() - 1 && value_be[start] == 0 {
131 start += 1;
132 }
133 let trimmed = &value_be[start..];
134 let needs_pad = (trimmed[0] & 0x80) != 0;
135 let int_len = trimmed.len() + usize::from(needs_pad);
136 let mut out = Vec::with_capacity(int_len + 4);
137 out.push(0x02); push_length(&mut out, int_len);
139 if needs_pad {
140 out.push(0x00);
141 }
142 out.extend_from_slice(trimmed);
143 out
144}
145
146fn read_integer(input: &[u8]) -> Option<(U256, &[u8])> {
147 let (tag, rest) = input.split_first()?;
148 if *tag != 0x02 {
149 return None;
150 }
151 let (int_len, rest) = read_length(rest)?;
152 if rest.len() < int_len {
153 return None;
154 }
155 let (int_bytes, rest_after) = rest.split_at(int_len);
156
157 if int_bytes.is_empty() {
177 return None;
178 }
179 if int_bytes[0] & 0x80 != 0 {
180 return None;
181 }
182 let bytes = if int_bytes[0] == 0x00 {
183 if int_bytes.len() == 1 {
184 int_bytes
186 } else if int_bytes[1] & 0x80 == 0 {
187 return None;
190 } else {
191 &int_bytes[1..]
192 }
193 } else {
194 int_bytes
195 };
196 if bytes.len() > 32 {
197 return None;
198 }
199 let mut padded = [0u8; 32];
200 padded[32 - bytes.len()..].copy_from_slice(bytes);
201 let value = U256::from_be_slice(&padded);
202 let in_field: bool = value.ct_lt(Fp::MODULUS.as_ref()).into();
206 if !in_field {
207 return None;
208 }
209 Some((value, rest_after))
210}
211
212fn encode_octet_string(value: &[u8]) -> Vec<u8> {
213 let mut out = Vec::with_capacity(value.len() + 4);
214 out.push(0x04); push_length(&mut out, value.len());
216 out.extend_from_slice(value);
217 out
218}
219
220fn read_octet_string(input: &[u8]) -> Option<(&[u8], &[u8])> {
221 let (tag, rest) = input.split_first()?;
222 if *tag != 0x04 {
223 return None;
224 }
225 let (len, rest) = read_length(rest)?;
226 if rest.len() < len {
227 return None;
228 }
229 Some(rest.split_at(len))
230}
231
/// Append a DER definite-form length to `out`.
///
/// Lengths below 128 use the one-octet short form. Larger lengths use the long form:
/// `0x80 | n` followed by the `n` significant big-endian octets of `len`, with `n` at
/// most 3.
///
/// # Panics
///
/// Panics if `len` is 16 MiB (`1 << 24`) or more, which would need a fourth length octet.
fn push_length(out: &mut Vec<u8>, len: usize) {
    // Count of long-form length octets; 0 selects the short form.
    let long_octets: u8 = match len {
        0..=0x7f => 0,
        0x80..=0xff => 1,
        0x100..=0xffff => 2,
        0x1_0000..=0xff_ffff => 3,
        _ => panic!("ciphertext DER length overflow (> 16 MB)"),
    };
    let be = len.to_be_bytes();
    if long_octets == 0 {
        // Short form: the single octet is the length itself.
        out.push(be[be.len() - 1]);
    } else {
        out.push(0x80 | long_octets);
        out.extend_from_slice(&be[be.len() - usize::from(long_octets)..]);
    }
}
263
/// Parse a DER definite-form length from the front of `input`.
///
/// Returns the decoded length and the bytes that follow it. Returns `None` when the input:
/// - is truncated;
/// - uses the indefinite form (`0x80`) or more than three length octets;
/// - is not minimally encoded (long form for a value below 128, or a leading zero length
///   octet).
fn read_length(input: &[u8]) -> Option<(usize, &[u8])> {
    let (&head, tail) = input.split_first()?;
    if head < 0x80 {
        return Some((usize::from(head), tail));
    }

    // Long form: the low seven bits give how many length octets follow.
    let octet_count = usize::from(head & 0x7f);
    if !(1..=3).contains(&octet_count) {
        return None;
    }
    let digits = tail.get(..octet_count)?;
    let len = digits
        .iter()
        .fold(0usize, |acc, &digit| (acc << 8) | usize::from(digit));

    // Minimality: a single octet must carry a value >= 128 (otherwise the short form
    // applies); with two or more octets the leading one must be non-zero.
    let minimal = if octet_count == 1 {
        len >= 0x80
    } else {
        digits[0] != 0
    };
    if !minimal {
        return None;
    }
    Some((len, &tail[octet_count..]))
}
295
/// Unit tests for the SM2 ciphertext DER codec: round trips, length-form selection, and
/// rejection of malformed or non-canonical encodings.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a ciphertext around `ciphertext` with fixed, in-range coordinates and hash.
    fn make_ct(ciphertext: Vec<u8>) -> Sm2Ciphertext {
        Sm2Ciphertext {
            x: U256::from_be_hex(
                "1234567890ABCDEF1234567890ABCDEF1234567890ABCDEF1234567890ABCDEF",
            ),
            y: U256::from_be_hex(
                "FEDCBA0987654321FEDCBA0987654321FEDCBA0987654321FEDCBA0987654321",
            ),
            hash: [0xa5u8; 32],
            ciphertext,
        }
    }

    /// Wrap hand-built field encodings in an outer SEQUENCE header (for malformed-input tests).
    fn wrap_sequence(body: &[u8]) -> Vec<u8> {
        let mut der = Vec::new();
        der.push(0x30);
        push_length(&mut der, body.len());
        der.extend_from_slice(body);
        der
    }

    #[test]
    fn round_trip_short() {
        let ct = make_ct(b"hello world".to_vec());
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode round-trip");
        assert_eq!(decoded.x, ct.x);
        assert_eq!(decoded.y, ct.y);
        assert_eq!(decoded.hash, ct.hash);
        assert_eq!(decoded.ciphertext, ct.ciphertext);
    }

    #[test]
    fn round_trip_x_high_bit_set() {
        let mut ct = make_ct(b"x".to_vec());
        ct.x =
            U256::from_be_hex("FFEDCBA9876543210FEDCBA9876543210FEDCBA9876543210FEDCBA987654321");
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode high-bit round-trip");
        assert_eq!(decoded.x, ct.x);
    }

    #[test]
    fn round_trip_medium_ciphertext_300_bytes() {
        let mut payload = alloc::vec![0u8; 300];
        for (i, b) in payload.iter_mut().enumerate() {
            #[allow(clippy::cast_possible_truncation)]
            {
                *b = (i as u8).wrapping_mul(13);
            }
        }
        let ct = make_ct(payload.clone());
        let der = encode(&ct);
        // The outer body is a few hundred bytes, so it needs the two-octet long form.
        assert_eq!(der[1], 0x82, "outer SEQUENCE should use the 0x82 length form");
        let decoded = decode(&der).expect("decode 300-byte round-trip");
        assert_eq!(decoded.ciphertext, payload);
    }

    #[test]
    fn round_trip_empty_ciphertext() {
        let ct = make_ct(Vec::new());
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode empty-ciphertext round-trip");
        assert!(decoded.ciphertext.is_empty());
    }

    #[test]
    fn rejects_malformed() {
        assert!(decode(&[]).is_none(), "empty input");
        assert!(decode(&[0x30]).is_none(), "truncated SEQUENCE header");
        assert!(decode(&[0x31, 0x00]).is_none(), "wrong outer tag");
        assert!(
            decode(&[0x30, 0x05, 0x02, 0x01, 0x01]).is_none(),
            "declared SEQUENCE length exceeds the available bytes"
        );
    }

    #[test]
    fn rejects_wrong_hash_length() {
        let mut body = Vec::new();
        body.extend_from_slice(&encode_integer(&[0x01]));
        body.extend_from_slice(&encode_integer(&[0x02]));
        body.extend_from_slice(&encode_octet_string(&[0x55u8; 31]));
        body.extend_from_slice(&encode_octet_string(b"x"));
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "31-byte HASH must be rejected; SM3 always produces 32 bytes"
        );
    }

    #[test]
    fn rejects_non_canonical_x_leading_zero() {
        let mut body = Vec::new();
        // x = 02 02 00 01: the 0x00 pad is redundant because 0x01 has its top bit clear.
        body.extend_from_slice(&[0x02, 0x02, 0x00, 0x01]);
        body.extend_from_slice(&encode_integer(&[0x02]));
        body.extend_from_slice(&encode_octet_string(&[0u8; 32]));
        body.extend_from_slice(&encode_octet_string(b""));
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "non-canonical 00-pad on x must be rejected"
        );
    }

    #[test]
    fn rejects_negative_y_encoding() {
        let mut body = Vec::new();
        body.extend_from_slice(&encode_integer(&[0x01]));
        // y = 02 01 80: top bit set without a pad octet reads as a negative number.
        body.extend_from_slice(&[0x02, 0x01, 0x80]);
        body.extend_from_slice(&encode_octet_string(&[0u8; 32]));
        body.extend_from_slice(&encode_octet_string(b""));
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "negative y must be rejected"
        );
    }

    #[test]
    fn rejects_trailing_bytes() {
        let ct = make_ct(b"hi".to_vec());
        let mut der = encode(&ct);
        der.push(0xff);
        assert!(decode(&der).is_none(), "bytes after the SEQUENCE must be rejected");
    }

    #[test]
    fn round_trip_x_zero() {
        let mut ct = make_ct(b"z".to_vec());
        ct.x = U256::ZERO;
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode round-trip with x = 0");
        assert_eq!(decoded.x, U256::ZERO);
        assert_eq!(decoded.y, ct.y);
    }

    #[test]
    fn rejects_x_at_or_above_p() {
        let p = *Fp::MODULUS.as_ref();
        let p_bytes = p.to_be_bytes();
        let mut body = Vec::new();
        body.extend_from_slice(&encode_integer(&p_bytes));
        body.extend_from_slice(&encode_integer(&[0x01]));
        body.extend_from_slice(&encode_octet_string(&[0u8; 32]));
        body.extend_from_slice(&encode_octet_string(b""));
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "x = p is not a field element and must be rejected"
        );

        let max_bytes = [0xffu8; 32];
        let mut body = Vec::new();
        body.extend_from_slice(&encode_integer(&max_bytes));
        body.extend_from_slice(&encode_integer(&[0x01]));
        body.extend_from_slice(&encode_octet_string(&[0u8; 32]));
        body.extend_from_slice(&encode_octet_string(b""));
        assert!(
            decode(&wrap_sequence(&body)).is_none(),
            "x = 2^256 - 1 must be rejected"
        );
    }

    #[test]
    fn round_trip_x_p_minus_one() {
        let p_minus_one = Fp::MODULUS.as_ref().wrapping_sub(&U256::ONE);
        let mut ct = make_ct(b"q".to_vec());
        ct.x = p_minus_one;
        let der = encode(&ct);
        let decoded = decode(&der).expect("decode round-trip with x = p - 1");
        assert_eq!(decoded.x, p_minus_one);
    }

    #[test]
    fn round_trip_65536_byte_ciphertext_uses_3byte_length() {
        let payload = alloc::vec![0xa5u8; 65_536];
        let ct = make_ct(payload.clone());
        let der = encode(&ct);
        // The outer SEQUENCE body exceeds 65,535 bytes, so it needs the three-octet long form.
        assert_eq!(der[1], 0x83, "outer SEQUENCE must use the 0x83 length form");
        // The ciphertext OCTET STRING is the last field; its 5-byte header sits just before
        // the payload and must also use the three-octet form (65,536 = 0x01_0000).
        let header_start = der.len() - payload.len() - 5;
        assert_eq!(
            &der[header_start..der.len() - payload.len()],
            &[0x04u8, 0x83, 0x01, 0x00, 0x00][..],
            "ciphertext OCTET STRING must use the 0x83 length form"
        );
        let decoded = decode(&der).expect("decode 65,536-byte round-trip");
        assert_eq!(decoded.ciphertext, payload);
    }
}