1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
32use chrono::{DateTime, Utc};
33use const_oid::ObjectIdentifier;
34use std::io::{Error as IoError, ErrorKind};
35
36use super::record::TimestampRecord;
37use crate::{DocumentId, Error, Result};
38
/// Well-known public RFC 3161 Time-Stamp Authority (TSA) endpoints.
///
/// [`Rfc3161Client::new`] tries these in declaration order.
pub mod servers {
    /// FreeTSA public timestamp service (HTTPS endpoint).
    pub const FREETSA: &str = "https://freetsa.org/tsr";

    /// Sectigo public timestamp service (plain HTTP endpoint).
    pub const SECTIGO: &str = "http://timestamp.sectigo.com";

    /// DigiCert public timestamp service (plain HTTP endpoint).
    pub const DIGICERT: &str = "http://timestamp.digicert.com";
}
48
49const OID_SHA256: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.1");
51const OID_SHA384: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.2");
52const OID_SHA512: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.3");
53
54#[derive(Debug, Clone)]
59pub struct Rfc3161Client {
60 servers: Vec<String>,
62 client: reqwest::Client,
64 timeout_secs: u64,
66 cert_req: bool,
68}
69
70impl Default for Rfc3161Client {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Rfc3161Client {
77 #[must_use]
79 pub fn new() -> Self {
80 Self {
81 servers: vec![
82 servers::FREETSA.to_string(),
83 servers::SECTIGO.to_string(),
84 servers::DIGICERT.to_string(),
85 ],
86 client: reqwest::Client::new(),
87 timeout_secs: 30,
88 cert_req: true,
89 }
90 }
91
92 #[must_use]
94 pub fn with_server(server: impl Into<String>) -> Self {
95 Self {
96 servers: vec![server.into()],
97 client: reqwest::Client::new(),
98 timeout_secs: 30,
99 cert_req: true,
100 }
101 }
102
103 #[must_use]
105 pub fn with_servers(servers: Vec<String>) -> Self {
106 Self {
107 servers,
108 client: reqwest::Client::new(),
109 timeout_secs: 30,
110 cert_req: true,
111 }
112 }
113
114 #[must_use]
116 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
117 self.timeout_secs = timeout_secs;
118 self
119 }
120
121 #[must_use]
123 pub fn with_cert_req(mut self, cert_req: bool) -> Self {
124 self.cert_req = cert_req;
125 self
126 }
127
128 pub async fn acquire_timestamp(&self, document_id: &DocumentId) -> Result<TimestampRecord> {
140 let hash_oid = match document_id.algorithm().as_str() {
142 "sha256" => OID_SHA256,
143 "sha384" => OID_SHA384,
144 "sha512" => OID_SHA512,
145 alg => {
146 return Err(Error::InvalidManifest {
147 reason: format!("Unsupported hash algorithm for RFC 3161: {alg}"),
148 })
149 }
150 };
151
152 let hash_hex = document_id.hex_digest();
154 let hash_bytes = hex_to_bytes(&hash_hex)?;
155
156 let mut nonce_bytes = [0u8; 8];
158 getrandom::fill(&mut nonce_bytes).map_err(|e| Error::Network {
159 message: format!("System RNG failed: {e}"),
160 })?;
161 let nonce = u64::from_be_bytes(nonce_bytes);
162
163 let request_der = encode_timestamp_request(&hash_bytes, hash_oid, nonce, self.cert_req);
165
166 let mut last_error = None;
168 for server_url in &self.servers {
169 match self.submit_to_tsa(server_url, &request_der).await {
170 Ok((token, time)) => {
171 return Ok(TimestampRecord::rfc3161(
172 server_url,
173 time,
174 BASE64.encode(&token),
175 ));
176 }
177 Err(e) => {
178 last_error = Some(e);
179 }
180 }
181 }
182
183 Err(last_error.unwrap_or_else(|| {
184 Error::Io(IoError::new(
185 ErrorKind::NotConnected,
186 "No TSA servers configured",
187 ))
188 }))
189 }
190
191 async fn submit_to_tsa(
193 &self,
194 server_url: &str,
195 request_der: &[u8],
196 ) -> Result<(Vec<u8>, DateTime<Utc>)> {
197 let response = self
198 .client
199 .post(server_url)
200 .timeout(std::time::Duration::from_secs(self.timeout_secs))
201 .header("Content-Type", "application/timestamp-query")
202 .body(request_der.to_vec())
203 .send()
204 .await
205 .map_err(|e| {
206 Error::Io(IoError::new(
207 ErrorKind::ConnectionRefused,
208 format!("Failed to contact TSA server: {e}"),
209 ))
210 })?;
211
212 if !response.status().is_success() {
213 let status = response.status();
214 let text = response.text().await.unwrap_or_default();
215 return Err(Error::Io(IoError::other(format!(
216 "TSA server returned error: {status} {text}"
217 ))));
218 }
219
220 let response_der = response.bytes().await.map_err(|e| {
221 Error::Io(IoError::new(
222 ErrorKind::InvalidData,
223 format!("Failed to read TSA response: {e}"),
224 ))
225 })?;
226
227 let (status, token) = parse_timestamp_response(&response_der)?;
229
230 if status > 1 {
232 let status_text = match status {
233 2 => "rejection",
234 3 => "waiting",
235 4 => "revocation warning",
236 5 => "revocation notification",
237 _ => "unknown error",
238 };
239 return Err(Error::Io(IoError::other(format!(
240 "TSA rejected request: {status_text}"
241 ))));
242 }
243
244 let time = Utc::now();
246
247 Ok((token, time))
248 }
249
250 pub fn verify_timestamp(
263 &self,
264 timestamp: &TimestampRecord,
265 _document_id: &DocumentId,
266 ) -> Result<TimestampVerification> {
267 let token_bytes = BASE64.decode(×tamp.token).map_err(|e| {
269 Error::Io(IoError::new(
270 ErrorKind::InvalidData,
271 format!("Invalid timestamp token: {e}"),
272 ))
273 })?;
274
275 if token_bytes.is_empty() {
277 return Ok(TimestampVerification {
278 valid: false,
279 status: VerificationStatus::Invalid,
280 message: "Empty token".to_string(),
281 });
282 }
283
284 if token_bytes[0] != 0x30 {
286 return Ok(TimestampVerification {
287 valid: false,
288 status: VerificationStatus::Invalid,
289 message: "Invalid ASN.1 structure".to_string(),
290 });
291 }
292
293 Ok(TimestampVerification {
294 valid: true,
295 status: VerificationStatus::Valid,
296 message: "Timestamp token is well-formed".to_string(),
297 })
298 }
299}
300
301#[derive(Debug, Clone)]
303pub struct TimestampVerification {
304 pub valid: bool,
306 pub status: VerificationStatus,
308 pub message: String,
310}
311
/// Verdict reported by [`Rfc3161Client::verify_timestamp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerificationStatus {
    /// The token passed the checks that were performed.
    Valid,

    /// The token is empty or structurally malformed.
    Invalid,
}
320
321fn encode_timestamp_request(
330 hash: &[u8],
331 hash_oid: ObjectIdentifier,
332 nonce: u64,
333 cert_req: bool,
334) -> Vec<u8> {
335 let mut content = Vec::new();
336
337 content.extend_from_slice(&encode_integer(1));
339
340 content.extend_from_slice(&encode_message_imprint(hash, hash_oid));
342
343 content.extend_from_slice(&encode_integer_u64(nonce));
345
346 if cert_req {
348 content.extend_from_slice(&[0x01, 0x01, 0xff]); }
350
351 encode_sequence(&content)
353}
354
355fn encode_message_imprint(hash: &[u8], hash_oid: ObjectIdentifier) -> Vec<u8> {
357 let mut content = Vec::new();
358
359 let mut alg_id = Vec::new();
361 alg_id.extend_from_slice(&encode_oid(&hash_oid));
362 alg_id.extend_from_slice(&[0x05, 0x00]); content.extend_from_slice(&encode_sequence(&alg_id));
364
365 content.extend_from_slice(&encode_octet_string(hash));
367
368 encode_sequence(&content)
369}
370
371fn encode_sequence(content: &[u8]) -> Vec<u8> {
373 let mut result = vec![0x30]; result.extend_from_slice(&encode_length(content.len()));
375 result.extend_from_slice(content);
376 result
377}
378
/// DER-encodes a non-negative `u8` as an INTEGER (tag `0x02`).
///
/// DER INTEGERs are two's complement, so a value with its high bit set
/// (`0x80..=0xff`) gets a leading `0x00` octet. Without it, the value would
/// decode as negative.
fn encode_integer(value: u8) -> Vec<u8> {
    if value & 0x80 == 0 {
        vec![0x02, 0x01, value]
    } else {
        vec![0x02, 0x02, 0x00, value]
    }
}
383
// Allowed: the content length is at most 9 octets, so the `as u8` below cannot truncate.
#[allow(clippy::cast_possible_truncation)]
/// DER-encodes a `u64` as a minimal, non-negative INTEGER (tag `0x02`).
///
/// Leading zero octets are stripped, but the last octet is always kept, so `0`
/// encodes as `02 01 00`. A `0x00` prefix is added when the first remaining
/// octet has its high bit set, so the value does not read as negative.
fn encode_integer_u64(value: u64) -> Vec<u8> {
    let be = value.to_be_bytes();
    let skip = be[..be.len() - 1]
        .iter()
        .take_while(|&&octet| octet == 0)
        .count();
    let magnitude = &be[skip..];
    let pad = magnitude[0] & 0x80 != 0;
    let content_len = magnitude.len() + usize::from(pad);

    let mut der = Vec::with_capacity(2 + content_len);
    der.push(0x02);
    der.push(content_len as u8);
    if pad {
        der.push(0x00);
    }
    der.extend_from_slice(magnitude);
    der
}
405
406fn encode_oid(oid: &ObjectIdentifier) -> Vec<u8> {
408 let oid_bytes = oid.as_bytes();
409 let mut result = vec![0x06]; result.extend_from_slice(&encode_length(oid_bytes.len()));
411 result.extend_from_slice(oid_bytes);
412 result
413}
414
415fn encode_octet_string(data: &[u8]) -> Vec<u8> {
417 let mut result = vec![0x04]; result.extend_from_slice(&encode_length(data.len()));
419 result.extend_from_slice(data);
420 result
421}
422
// Allowed: both casts below are bounded (short-form length < 0x80, and at most
// size_of::<usize>() length octets), so neither can truncate.
#[allow(clippy::cast_possible_truncation)]
/// DER-encodes a definite-form length.
///
/// Lengths below 128 use the single-octet short form. Larger lengths use the
/// long form: `0x80 | n` followed by the `n` big-endian octets of the length,
/// with no leading zero octets (DER requires the minimal encoding). Any
/// `usize` is supported; previously lengths of 65536 or more were silently
/// truncated into an incorrect two-octet form.
fn encode_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }

    let be = len.to_be_bytes();
    // `len >= 0x80`, so at least one octet is non-zero; the fallback is defensive.
    let first = be.iter().position(|&b| b != 0).unwrap_or(be.len() - 1);
    let octets = &be[first..];

    let mut out = Vec::with_capacity(1 + octets.len());
    out.push(0x80 | octets.len() as u8);
    out.extend_from_slice(octets);
    out
}
437
438fn parse_timestamp_response(data: &[u8]) -> Result<(u8, Vec<u8>)> {
440 if data.is_empty() || data[0] != 0x30 {
447 return Err(Error::Io(IoError::new(
448 ErrorKind::InvalidData,
449 "Invalid TSA response: not a SEQUENCE",
450 )));
451 }
452
453 let (content, _) = parse_tlv(data)?;
454
455 if content.is_empty() || content[0] != 0x30 {
457 return Err(Error::Io(IoError::new(
458 ErrorKind::InvalidData,
459 "Invalid PKIStatusInfo",
460 )));
461 }
462
463 let (status_info, rest) = parse_tlv(content)?;
464
465 if status_info.is_empty() || status_info[0] != 0x02 {
467 return Err(Error::Io(IoError::new(
468 ErrorKind::InvalidData,
469 "Invalid status in PKIStatusInfo",
470 )));
471 }
472
473 let (status_bytes, _) = parse_tlv(status_info)?;
474 let status = *status_bytes.last().unwrap_or(&255);
475
476 if rest.is_empty() {
478 return Err(Error::Io(IoError::new(
479 ErrorKind::InvalidData,
480 "No timestamp token in response",
481 )));
482 }
483
484 if rest[0] != 0x30 {
486 return Err(Error::Io(IoError::new(
487 ErrorKind::InvalidData,
488 "Invalid timestamp token format",
489 )));
490 }
491
492 let token_len = get_tlv_total_length(rest)?;
494 let token = rest[..token_len].to_vec();
495
496 Ok((status, token))
497}
498
499fn parse_tlv(data: &[u8]) -> Result<(&[u8], &[u8])> {
501 if data.len() < 2 {
502 return Err(Error::Io(IoError::new(
503 ErrorKind::InvalidData,
504 "TLV too short",
505 )));
506 }
507
508 let (len, header_len) = if data[1] < 128 {
509 (data[1] as usize, 2)
510 } else if data[1] == 0x81 {
511 if data.len() < 3 {
512 return Err(Error::Io(IoError::new(
513 ErrorKind::InvalidData,
514 "Invalid length encoding",
515 )));
516 }
517 (data[2] as usize, 3)
518 } else if data[1] == 0x82 {
519 if data.len() < 4 {
520 return Err(Error::Io(IoError::new(
521 ErrorKind::InvalidData,
522 "Invalid length encoding",
523 )));
524 }
525 (((data[2] as usize) << 8) | (data[3] as usize), 4)
526 } else {
527 return Err(Error::Io(IoError::new(
528 ErrorKind::InvalidData,
529 "Unsupported length encoding",
530 )));
531 };
532
533 if data.len() < header_len + len {
534 return Err(Error::Io(IoError::new(
535 ErrorKind::InvalidData,
536 "TLV length exceeds data",
537 )));
538 }
539
540 let value = &data[header_len..header_len + len];
541 let rest = &data[header_len + len..];
542 Ok((value, rest))
543}
544
545fn get_tlv_total_length(data: &[u8]) -> Result<usize> {
547 if data.len() < 2 {
548 return Err(Error::Io(IoError::new(
549 ErrorKind::InvalidData,
550 "TLV too short",
551 )));
552 }
553
554 let (len, header_len) = if data[1] < 128 {
555 (data[1] as usize, 2)
556 } else if data[1] == 0x81 {
557 if data.len() < 3 {
558 return Err(Error::Io(IoError::new(
559 ErrorKind::InvalidData,
560 "Invalid length encoding",
561 )));
562 }
563 (data[2] as usize, 3)
564 } else if data[1] == 0x82 {
565 if data.len() < 4 {
566 return Err(Error::Io(IoError::new(
567 ErrorKind::InvalidData,
568 "Invalid length encoding",
569 )));
570 }
571 (((data[2] as usize) << 8) | (data[3] as usize), 4)
572 } else {
573 return Err(Error::Io(IoError::new(
574 ErrorKind::InvalidData,
575 "Unsupported length encoding",
576 )));
577 };
578
579 Ok(header_len + len)
580}
581
582fn hex_to_bytes(hex: &str) -> Result<Vec<u8>> {
584 let hex = hex.trim();
585 if !hex.len().is_multiple_of(2) {
586 return Err(Error::InvalidHashFormat {
587 value: "Invalid hex string length".to_string(),
588 });
589 }
590
591 (0..hex.len())
592 .step_by(2)
593 .map(|i| {
594 u8::from_str_radix(&hex[i..i + 2], 16).map_err(|_| Error::InvalidHashFormat {
595 value: "Invalid hex character".to_string(),
596 })
597 })
598 .collect()
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{HashAlgorithm, Hasher};

    #[test]
    fn test_rfc3161_client_creation() {
        let client = Rfc3161Client::new();
        assert_eq!(
            client.servers,
            vec![servers::FREETSA, servers::SECTIGO, servers::DIGICERT]
        );
        assert_eq!(client.timeout_secs, 30);
        assert!(client.cert_req);
    }

    #[test]
    fn test_rfc3161_client_custom_server() {
        let client = Rfc3161Client::with_server("https://custom.example.com/tsa");
        assert_eq!(client.servers.len(), 1);
        assert_eq!(client.servers[0], "https://custom.example.com/tsa");
    }

    #[test]
    fn test_builder_overrides() {
        let client = Rfc3161Client::with_server("https://custom.example.com/tsa")
            .with_timeout(5)
            .with_cert_req(false);
        assert_eq!(client.timeout_secs, 5);
        assert!(!client.cert_req);
    }

    #[test]
    fn test_hex_to_bytes() {
        let bytes = hex_to_bytes("deadbeef").unwrap();
        assert_eq!(bytes, vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_hex_to_bytes_invalid() {
        // Odd length.
        assert!(hex_to_bytes("deadbee").is_err());
        // Non-hex character.
        assert!(hex_to_bytes("deadbeeg").is_err());
    }

    #[test]
    fn test_timestamp_req_encoding() {
        let hash = [0u8; 32];
        let req = encode_timestamp_request(&hash, OID_SHA256, 12345, true);

        let mut expected: Vec<u8> = vec![
            0x30, 0x3d, // TimeStampReq SEQUENCE, 61 content octets
            0x02, 0x01, 0x01, // version v1
            0x30, 0x31, // MessageImprint SEQUENCE
            0x30, 0x0d, // AlgorithmIdentifier SEQUENCE
            0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, // id-sha256
            0x05, 0x00, // NULL parameters
            0x04, 0x20, // hashedMessage OCTET STRING, 32 octets
        ];
        expected.extend_from_slice(&hash);
        expected.extend_from_slice(&[0x02, 0x02, 0x30, 0x39]); // nonce 12345
        expected.extend_from_slice(&[0x01, 0x01, 0xff]); // certReq TRUE
        assert_eq!(req, expected);
    }

    #[test]
    fn test_timestamp_req_without_cert_req() {
        let req = encode_timestamp_request(&[0u8; 32], OID_SHA256, 12345, false);
        // certReq is DEFAULT FALSE, so DER omits it entirely.
        assert_eq!(req.len(), 60);
        assert_eq!(req[1], 0x3a);
        assert!(req.ends_with(&[0x02, 0x02, 0x30, 0x39]));
    }

    #[test]
    fn test_encode_integer() {
        let encoded = encode_integer(1);
        assert_eq!(encoded, vec![0x02, 0x01, 0x01]);
    }

    #[test]
    fn test_encode_integer_u64() {
        assert_eq!(encode_integer_u64(0), vec![0x02, 0x01, 0x00]);
        assert_eq!(encode_integer_u64(256), vec![0x02, 0x02, 0x01, 0x00]);
        // High bit set: a leading zero keeps the INTEGER non-negative.
        assert_eq!(encode_integer_u64(0x80), vec![0x02, 0x02, 0x00, 0x80]);
        let mut max: Vec<u8> = vec![0x02, 0x09, 0x00];
        max.extend_from_slice(&[0xff; 8]);
        assert_eq!(encode_integer_u64(u64::MAX), max);
    }

    #[test]
    fn test_encode_length_boundaries() {
        assert_eq!(encode_length(127), vec![0x7f]);
        assert_eq!(encode_length(128), vec![0x81, 0x80]);
        assert_eq!(encode_length(255), vec![0x81, 0xff]);
        assert_eq!(encode_length(256), vec![0x82, 0x01, 0x00]);
    }

    #[test]
    fn test_parse_tlv_round_trip() {
        let content = vec![0xab_u8; 300];
        let der = encode_sequence(&content);
        let (value, rest) = parse_tlv(&der).unwrap();
        assert_eq!(value, content.as_slice());
        assert!(rest.is_empty());
        assert_eq!(get_tlv_total_length(&der).unwrap(), der.len());
    }

    #[test]
    fn test_parse_tlv_rejects_truncated() {
        assert!(parse_tlv(&[0x30]).is_err());
        assert!(parse_tlv(&[0x30, 0x05, 0x00]).is_err());
    }

    #[test]
    fn test_parse_granted_response() {
        let token: [u8; 5] = [0x30, 0x03, 0x02, 0x01, 0x05];
        // TimeStampResp { PKIStatusInfo { status 0 }, token }
        let mut response: Vec<u8> = vec![0x30, 0x0a, 0x30, 0x03, 0x02, 0x01, 0x00];
        response.extend_from_slice(&token);

        let (status, parsed) = parse_timestamp_response(&response).unwrap();
        assert_eq!(status, 0);
        assert_eq!(parsed, token);
    }

    #[test]
    fn test_parse_response_rejects_non_sequence() {
        assert!(parse_timestamp_response(&[]).is_err());
        assert!(parse_timestamp_response(&[0x02, 0x01, 0x00]).is_err());
    }

    #[test]
    fn test_verify_empty_token() {
        let client = Rfc3161Client::new();
        let doc_id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let timestamp = TimestampRecord::rfc3161("https://example.com", Utc::now(), "");

        let result = client.verify_timestamp(&timestamp, &doc_id).unwrap();
        assert!(!result.valid);
        assert_eq!(result.status, VerificationStatus::Invalid);
    }

    #[test]
    fn test_verify_invalid_base64() {
        let client = Rfc3161Client::new();
        let doc_id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let timestamp =
            TimestampRecord::rfc3161("https://example.com", Utc::now(), "!!!invalid!!!");

        let result = client.verify_timestamp(&timestamp, &doc_id);
        assert!(result.is_err());
    }

    #[test]
    fn test_verify_valid_structure() {
        let client = Rfc3161Client::new();
        let doc_id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        // An empty SEQUENCE: structurally well-formed, nothing more.
        let token = BASE64.encode([0x30, 0x00]);
        let timestamp = TimestampRecord::rfc3161("https://example.com", Utc::now(), token);

        let result = client.verify_timestamp(&timestamp, &doc_id).unwrap();
        assert!(result.valid);
        assert_eq!(result.status, VerificationStatus::Valid);
    }
}