use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use chrono::{DateTime, Utc};
use const_oid::ObjectIdentifier;
use std::io::{Error as IoError, ErrorKind};
use super::record::TimestampRecord;
use crate::{DocumentId, Error, Result};
/// Public RFC 3161 Time-Stamp Authority endpoints, used as the default
/// server list of `Rfc3161Client::new`.
pub mod servers {
    /// FreeTSA.org time-stamping endpoint (HTTPS).
    pub const FREETSA: &str = "https://freetsa.org/tsr";

    /// Sectigo public time-stamping endpoint (plain HTTP).
    pub const SECTIGO: &str = "http://timestamp.sectigo.com";

    /// DigiCert public time-stamping endpoint (plain HTTP).
    pub const DIGICERT: &str = "http://timestamp.digicert.com";
}
/// Object identifier for SHA-256 (`2.16.840.1.101.3.4.2.1`); placed in the
/// request's MessageImprint when the document digest is `sha256`.
const OID_SHA256: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.1");
/// Object identifier for SHA-384 (`2.16.840.1.101.3.4.2.2`); used for `sha384` digests.
const OID_SHA384: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.2");
/// Object identifier for SHA-512 (`2.16.840.1.101.3.4.2.3`); used for `sha512` digests.
const OID_SHA512: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.2.3");
/// Client that obtains RFC 3161 timestamp tokens from one or more
/// Time-Stamp Authorities (TSAs) over HTTP.
#[derive(Debug, Clone)]
pub struct Rfc3161Client {
/// TSA endpoint URLs, tried in order until one returns a token.
servers: Vec<String>,
/// HTTP client used for every TSA request.
client: reqwest::Client,
/// Per-request timeout, in seconds.
timeout_secs: u64,
/// Whether each request sets the `certReq` flag (RFC 3161: asks the TSA to
/// include its signing certificate in the token).
cert_req: bool,
}
impl Default for Rfc3161Client {
fn default() -> Self {
Self::new()
}
}
impl Rfc3161Client {
#[must_use]
pub fn new() -> Self {
Self {
servers: vec![
servers::FREETSA.to_string(),
servers::SECTIGO.to_string(),
servers::DIGICERT.to_string(),
],
client: reqwest::Client::new(),
timeout_secs: 30,
cert_req: true,
}
}
#[must_use]
pub fn with_server(server: impl Into<String>) -> Self {
Self {
servers: vec![server.into()],
client: reqwest::Client::new(),
timeout_secs: 30,
cert_req: true,
}
}
#[must_use]
pub fn with_servers(servers: Vec<String>) -> Self {
Self {
servers,
client: reqwest::Client::new(),
timeout_secs: 30,
cert_req: true,
}
}
#[must_use]
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
self.timeout_secs = timeout_secs;
self
}
#[must_use]
pub fn with_cert_req(mut self, cert_req: bool) -> Self {
self.cert_req = cert_req;
self
}
pub async fn acquire_timestamp(&self, document_id: &DocumentId) -> Result<TimestampRecord> {
let hash_oid = match document_id.algorithm().as_str() {
"sha256" => OID_SHA256,
"sha384" => OID_SHA384,
"sha512" => OID_SHA512,
alg => {
return Err(Error::InvalidManifest {
reason: format!("Unsupported hash algorithm for RFC 3161: {alg}"),
})
}
};
let hash_hex = document_id.hex_digest();
let hash_bytes = hex_to_bytes(&hash_hex)?;
let mut nonce_bytes = [0u8; 8];
getrandom::fill(&mut nonce_bytes).map_err(|e| Error::Network {
message: format!("System RNG failed: {e}"),
})?;
let nonce = u64::from_be_bytes(nonce_bytes);
let request_der = encode_timestamp_request(&hash_bytes, hash_oid, nonce, self.cert_req);
let mut last_error = None;
for server_url in &self.servers {
match self.submit_to_tsa(server_url, &request_der).await {
Ok((token, time)) => {
return Ok(TimestampRecord::rfc3161(
server_url,
time,
BASE64.encode(&token),
));
}
Err(e) => {
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| {
Error::Io(IoError::new(
ErrorKind::NotConnected,
"No TSA servers configured",
))
}))
}
async fn submit_to_tsa(
&self,
server_url: &str,
request_der: &[u8],
) -> Result<(Vec<u8>, DateTime<Utc>)> {
let response = self
.client
.post(server_url)
.timeout(std::time::Duration::from_secs(self.timeout_secs))
.header("Content-Type", "application/timestamp-query")
.body(request_der.to_vec())
.send()
.await
.map_err(|e| {
Error::Io(IoError::new(
ErrorKind::ConnectionRefused,
format!("Failed to contact TSA server: {e}"),
))
})?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(Error::Io(IoError::other(format!(
"TSA server returned error: {status} {text}"
))));
}
let response_der = response.bytes().await.map_err(|e| {
Error::Io(IoError::new(
ErrorKind::InvalidData,
format!("Failed to read TSA response: {e}"),
))
})?;
let (status, token) = parse_timestamp_response(&response_der)?;
if status > 1 {
let status_text = match status {
2 => "rejection",
3 => "waiting",
4 => "revocation warning",
5 => "revocation notification",
_ => "unknown error",
};
return Err(Error::Io(IoError::other(format!(
"TSA rejected request: {status_text}"
))));
}
let time = Utc::now();
Ok((token, time))
}
pub fn verify_timestamp(
&self,
timestamp: &TimestampRecord,
_document_id: &DocumentId,
) -> Result<TimestampVerification> {
let token_bytes = BASE64.decode(×tamp.token).map_err(|e| {
Error::Io(IoError::new(
ErrorKind::InvalidData,
format!("Invalid timestamp token: {e}"),
))
})?;
if token_bytes.is_empty() {
return Ok(TimestampVerification {
valid: false,
status: VerificationStatus::Invalid,
message: "Empty token".to_string(),
});
}
if token_bytes[0] != 0x30 {
return Ok(TimestampVerification {
valid: false,
status: VerificationStatus::Invalid,
message: "Invalid ASN.1 structure".to_string(),
});
}
Ok(TimestampVerification {
valid: true,
status: VerificationStatus::Valid,
message: "Timestamp token is well-formed".to_string(),
})
}
}
/// Outcome of `Rfc3161Client::verify_timestamp`.
#[derive(Debug, Clone)]
pub struct TimestampVerification {
/// `true` exactly when `status` is `VerificationStatus::Valid`.
pub valid: bool,
/// Categorised result of the check.
pub status: VerificationStatus,
/// Human-readable explanation of the result.
pub message: String,
}
/// Result category reported by a timestamp verification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerificationStatus {
/// The token passed every check that was performed.
Valid,
/// The token failed a check (for example, it was empty or not a DER SEQUENCE).
Invalid,
}
/// Builds the DER encoding of an RFC 3161 `TimeStampReq`.
///
/// Fields are emitted in order:
/// - `version` (1);
/// - `messageImprint`;
/// - `nonce`;
/// - `certReq`, only when `cert_req` is true.
///
/// `reqPolicy` and `extensions` are never sent.
fn encode_timestamp_request(
    hash: &[u8],
    hash_oid: ObjectIdentifier,
    nonce: u64,
    cert_req: bool,
) -> Vec<u8> {
    // BOOLEAN TRUE in DER. certReq is DEFAULT FALSE, so FALSE must be omitted.
    const CERT_REQ_TRUE: [u8; 3] = [0x01, 0x01, 0xff];

    let mut body = [
        encode_integer(1),
        encode_message_imprint(hash, hash_oid),
        encode_integer_u64(nonce),
    ]
    .concat();
    if cert_req {
        body.extend_from_slice(&CERT_REQ_TRUE);
    }
    encode_sequence(&body)
}
/// DER-encodes `MessageImprint ::= SEQUENCE { hashAlgorithm, hashedMessage }`.
///
/// The `AlgorithmIdentifier` carries an explicit NULL parameters field, and
/// `hash` is wrapped as an OCTET STRING.
fn encode_message_imprint(hash: &[u8], hash_oid: ObjectIdentifier) -> Vec<u8> {
    // AlgorithmIdentifier parameters: ASN.1 NULL.
    const DER_NULL: [u8; 2] = [0x05, 0x00];

    let mut algorithm_id = encode_oid(&hash_oid);
    algorithm_id.extend_from_slice(&DER_NULL);

    let mut imprint = encode_sequence(&algorithm_id);
    imprint.extend(encode_octet_string(hash));
    encode_sequence(&imprint)
}
/// Wraps `content` in a DER SEQUENCE header (tag `0x30` plus length).
fn encode_sequence(content: &[u8]) -> Vec<u8> {
    const SEQUENCE_TAG: u8 = 0x30;
    let length = encode_length(content.len());
    let mut out = Vec::with_capacity(1 + length.len() + content.len());
    out.push(SEQUENCE_TAG);
    out.extend_from_slice(&length);
    out.extend_from_slice(content);
    out
}
/// DER-encodes a non-negative INTEGER (tag `0x02`) from a single byte value.
///
/// DER integers are two's complement. A value with its high bit set therefore
/// needs a leading `0x00` octet to stay positive: 200 encodes as `02 02 00 C8`,
/// not `02 01 C8`, which would mean -56.
fn encode_integer(value: u8) -> Vec<u8> {
    if value & 0x80 == 0 {
        vec![0x02, 0x01, value]
    } else {
        vec![0x02, 0x02, 0x00, value]
    }
}
/// DER-encodes `value` as a minimal, non-negative INTEGER (tag `0x02`).
///
/// Redundant leading zero octets are dropped, always keeping at least one. A
/// `0x00` pad is prepended when the first kept octet has its high bit set, so
/// the value is not read as negative.
// The content is at most 9 octets, so the length always fits in a u8.
#[allow(clippy::cast_possible_truncation)]
fn encode_integer_u64(value: u64) -> Vec<u8> {
    let be = value.to_be_bytes();
    // Whole zero octets at the front, capped so the final octet is always kept.
    let skip = ((value.leading_zeros() / 8) as usize).min(be.len() - 1);
    let magnitude = &be[skip..];
    let pad = magnitude[0] & 0x80 != 0;

    let mut out = Vec::with_capacity(2 + usize::from(pad) + magnitude.len());
    out.push(0x02);
    out.push((magnitude.len() + usize::from(pad)) as u8);
    if pad {
        out.push(0x00);
    }
    out.extend_from_slice(magnitude);
    out
}
/// DER-encodes `oid` as an OBJECT IDENTIFIER (tag `0x06`) around its encoded arc bytes.
fn encode_oid(oid: &ObjectIdentifier) -> Vec<u8> {
    let arcs = oid.as_bytes();
    std::iter::once(0x06)
        .chain(encode_length(arcs.len()))
        .chain(arcs.iter().copied())
        .collect()
}
/// DER-encodes `data` as an OCTET STRING (tag `0x04`).
fn encode_octet_string(data: &[u8]) -> Vec<u8> {
    std::iter::once(0x04)
        .chain(encode_length(data.len()))
        .chain(data.iter().copied())
        .collect()
}
/// Encodes a DER definite-length field for content of `len` octets.
///
/// - Lengths below 128 use the single-octet short form.
/// - Longer lengths use the long form: `0x80 | n`, followed by the `n`
///   big-endian octets of `len` with leading zeros removed.
///
/// Every `usize` is representable. The previous two-octet cap silently
/// truncated lengths of 65536 and above.
// `len < 0x80` and the octet count (at most size_of::<usize>()) both fit in a u8.
#[allow(clippy::cast_possible_truncation)]
fn encode_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }
    let be = len.to_be_bytes();
    // len >= 0x80, so at least one octet is non-zero and `significant` is non-empty.
    let skip = be.iter().take_while(|&&b| b == 0).count();
    let significant = &be[skip..];

    let mut out = Vec::with_capacity(1 + significant.len());
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}
fn parse_timestamp_response(data: &[u8]) -> Result<(u8, Vec<u8>)> {
if data.is_empty() || data[0] != 0x30 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid TSA response: not a SEQUENCE",
)));
}
let (content, _) = parse_tlv(data)?;
if content.is_empty() || content[0] != 0x30 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid PKIStatusInfo",
)));
}
let (status_info, rest) = parse_tlv(content)?;
if status_info.is_empty() || status_info[0] != 0x02 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid status in PKIStatusInfo",
)));
}
let (status_bytes, _) = parse_tlv(status_info)?;
let status = *status_bytes.last().unwrap_or(&255);
if rest.is_empty() {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"No timestamp token in response",
)));
}
if rest[0] != 0x30 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid timestamp token format",
)));
}
let token_len = get_tlv_total_length(rest)?;
let token = rest[..token_len].to_vec();
Ok((status, token))
}
fn parse_tlv(data: &[u8]) -> Result<(&[u8], &[u8])> {
if data.len() < 2 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"TLV too short",
)));
}
let (len, header_len) = if data[1] < 128 {
(data[1] as usize, 2)
} else if data[1] == 0x81 {
if data.len() < 3 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid length encoding",
)));
}
(data[2] as usize, 3)
} else if data[1] == 0x82 {
if data.len() < 4 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid length encoding",
)));
}
(((data[2] as usize) << 8) | (data[3] as usize), 4)
} else {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Unsupported length encoding",
)));
};
if data.len() < header_len + len {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"TLV length exceeds data",
)));
}
let value = &data[header_len..header_len + len];
let rest = &data[header_len + len..];
Ok((value, rest))
}
fn get_tlv_total_length(data: &[u8]) -> Result<usize> {
if data.len() < 2 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"TLV too short",
)));
}
let (len, header_len) = if data[1] < 128 {
(data[1] as usize, 2)
} else if data[1] == 0x81 {
if data.len() < 3 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid length encoding",
)));
}
(data[2] as usize, 3)
} else if data[1] == 0x82 {
if data.len() < 4 {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Invalid length encoding",
)));
}
(((data[2] as usize) << 8) | (data[3] as usize), 4)
} else {
return Err(Error::Io(IoError::new(
ErrorKind::InvalidData,
"Unsupported length encoding",
)));
};
Ok(header_len + len)
}
fn hex_to_bytes(hex: &str) -> Result<Vec<u8>> {
let hex = hex.trim();
if !hex.len().is_multiple_of(2) {
return Err(Error::InvalidHashFormat {
value: "Invalid hex string length".to_string(),
});
}
(0..hex.len())
.step_by(2)
.map(|i| {
u8::from_str_radix(&hex[i..i + 2], 16).map_err(|_| Error::InvalidHashFormat {
value: "Invalid hex character".to_string(),
})
})
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{HashAlgorithm, Hasher};

    #[test]
    fn test_rfc3161_client_creation() {
        let client = Rfc3161Client::new();
        assert!(!client.servers.is_empty());
    }

    #[test]
    fn test_rfc3161_client_custom_server() {
        let client = Rfc3161Client::with_server("https://custom.example.com/tsa");
        assert_eq!(client.servers.len(), 1);
        assert_eq!(client.servers[0], "https://custom.example.com/tsa");
    }

    #[test]
    fn test_hex_to_bytes() {
        let bytes = hex_to_bytes("deadbeef").unwrap();
        assert_eq!(bytes, vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_hex_to_bytes_invalid() {
        // Odd length.
        assert!(hex_to_bytes("deadbee").is_err());
        // Non-hex character.
        assert!(hex_to_bytes("deadbeeg").is_err());
    }

    #[test]
    fn test_timestamp_req_encoding() {
        // All-zero SHA-256-sized digest.
        let hash = vec![0u8; 32];
        let req = encode_timestamp_request(&hash, OID_SHA256, 12345, true);
        assert!(!req.is_empty());
        // The outer TimeStampReq must be a DER SEQUENCE.
        assert_eq!(req[0], 0x30);
    }

    #[test]
    fn test_encode_integer() {
        let encoded = encode_integer(1);
        assert_eq!(encoded, vec![0x02, 0x01, 0x01]);
    }

    #[test]
    fn test_encode_integer_u64() {
        assert_eq!(encode_integer_u64(256), vec![0x02, 0x02, 0x01, 0x00]);
        assert_eq!(encode_integer_u64(0), vec![0x02, 0x01, 0x00]);
        // High bit set: a 0x00 pad keeps the INTEGER positive.
        assert_eq!(encode_integer_u64(0x80), vec![0x02, 0x02, 0x00, 0x80]);
    }

    #[test]
    fn test_parse_granted_response() {
        // TimeStampResp { PKIStatusInfo { status 0 }, token: empty SEQUENCE }
        let resp = [0x30, 0x07, 0x30, 0x03, 0x02, 0x01, 0x00, 0x30, 0x00];
        let (status, token) = parse_timestamp_response(&resp).unwrap();
        assert_eq!(status, 0);
        assert_eq!(token, vec![0x30, 0x00]);
    }

    #[test]
    fn test_verify_empty_token() {
        let client = Rfc3161Client::new();
        let doc_id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let timestamp = TimestampRecord::rfc3161("https://example.com", Utc::now(), "");
        let result = client.verify_timestamp(&timestamp, &doc_id).unwrap();
        assert!(!result.valid);
        assert_eq!(result.status, VerificationStatus::Invalid);
    }

    #[test]
    fn test_verify_invalid_base64() {
        let client = Rfc3161Client::new();
        let doc_id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let timestamp =
            TimestampRecord::rfc3161("https://example.com", Utc::now(), "!!!invalid!!!");
        let result = client.verify_timestamp(&timestamp, &doc_id);
        assert!(result.is_err());
    }

    #[test]
    fn test_verify_valid_structure() {
        let client = Rfc3161Client::new();
        let doc_id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let token = BASE64.encode([0x30, 0x00]);
        let timestamp = TimestampRecord::rfc3161("https://example.com", Utc::now(), token);
        let result = client.verify_timestamp(&timestamp, &doc_id).unwrap();
        assert!(result.valid);
        assert_eq!(result.status, VerificationStatus::Valid);
    }
}