use subtle::ConstantTimeEq;
use ecdh_wrapper::{PublicKey, PrivateKey};
use super::commands::{RoutingCommand, parse_routing_commands};
use super::constants::{AD_SIZE, ROUTING_INFO_SIZE, V0_AD, PER_HOP_ROUTING_INFO_SIZE, PAYLOAD_TAG_SIZE};
use super::error::SphinxUnwrapError;
use super::internal_crypto::{HASH_SIZE, MAC_SIZE, GROUP_ELEMENT_SIZE, StreamCipher, hash, kdf, hmac, sprp_decrypt};
// Sphinx header layout: ad || group_element || routing_info || mac.
// Each offset below is the byte position where that field begins.

/// Start of the group element; it directly follows the associated-data prefix.
const GROUP_ELEMENT_OFFSET: usize = AD_SIZE;
/// Start of the encrypted routing information; it directly follows the group element.
const ROUTING_INFO_OFFSET: usize = GROUP_ELEMENT_SIZE + GROUP_ELEMENT_OFFSET;
/// Start of the header MAC; it directly follows the routing information.
const MAC_OFFSET: usize = ROUTING_INFO_SIZE + ROUTING_INFO_OFFSET;
/// Unwraps one layer of a Sphinx packet, in place.
///
/// `packet` is `header || payload`, where the header is
/// `ad || group_element || routing_info || mac`. The steps are:
///
/// 1. check the associated-data prefix against `V0_AD`;
/// 2. derive this hop's keys from the ECDH shared secret between
///    `private_key` and the packet's group element;
/// 3. verify the header MAC in constant time;
/// 4. decrypt the routing information and parse this hop's routing commands;
/// 5. decrypt the payload.
///
/// If the commands include a `NextHop`, `packet` is rewritten in place into
/// the packet to forward. That packet has the blinded group element, the
/// shifted routing information, the next hop's MAC and the decrypted payload.
/// Otherwise this is the terminal hop and the payload is returned instead.
/// For non-SURB-reply packets, the leading zero tag is verified and stripped.
///
/// `packet` is only modified after every check has passed; on any error it is
/// left untouched.
///
/// # Returns
///
/// A tuple `(payload, replay_tag, commands, error)`:
/// - `payload`: `Some` only at the terminal hop.
/// - `replay_tag`: hash of the incoming group element. It is `Some` once the
///   header has been parsed past the AD check and group-element decoding.
///   The name suggests callers use it for replay detection; that is not
///   visible from here.
/// - `commands`: this hop's routing commands, including the `NextHop` or
///   SURB-reply command when present. Also returned with `PayloadDecryptError`.
/// - `error`: `Some` on failure:
///   - `InvalidPacketError`: the packet is shorter than a header, or the AD
///     prefix does not match.
///   - `ImpossibleError`: the group element bytes fail to decode.
///   - `MACError`, `RouteInfoParseError`, `PayloadDecryptError`.
///   - `PayloadError`: a terminal non-SURB payload has a non-zero tag or is
///     too short to contain one.
pub fn sphinx_packet_unwrap(private_key: &PrivateKey, packet: &mut [u8]) -> (Option<Vec<u8>>, Option<[u8; HASH_SIZE]>, Option<Vec<RoutingCommand>>, Option<SphinxUnwrapError>) {
    const HEADER_SIZE: usize = MAC_OFFSET + MAC_SIZE;

    // `split_at_mut` panics when the split point is past the end of the slice,
    // so reject truncated (untrusted) input before slicing it.
    if packet.len() < HEADER_SIZE {
        return (None, None, None, Some(SphinxUnwrapError::InvalidPacketError));
    }

    // header = ad || group_element || routing_info || mac
    let (header, payload) = packet.split_at_mut(HEADER_SIZE);
    let (authed_header, mac) = header.split_at_mut(MAC_OFFSET);
    let (ad, after_ad) = authed_header.split_at_mut(AD_SIZE);
    let (group_element_bytes, routing_info) = after_ad.split_at_mut(GROUP_ELEMENT_SIZE);

    if ad.ct_eq(&V0_AD).unwrap_u8() == 0 {
        return (None, None, None, Some(SphinxUnwrapError::InvalidPacketError));
    }

    let mut group_element = PublicKey::default();
    if group_element.from_bytes(group_element_bytes).is_err() {
        return (None, None, None, Some(SphinxUnwrapError::ImpossibleError));
    }

    let shared_secret = private_key.exp(&group_element);
    let mut replay_tag = [0u8; HASH_SIZE];
    replay_tag.copy_from_slice(&hash(&group_element.as_array()));
    let keys = kdf(&shared_secret);

    // The MAC covers every header byte that precedes it.
    let mut mac_input = Vec::with_capacity(MAC_OFFSET);
    mac_input.extend_from_slice(ad);
    mac_input.extend_from_slice(group_element_bytes);
    mac_input.extend_from_slice(routing_info);
    let calculated_mac = hmac(&keys.header_mac, &mac_input);
    if calculated_mac.ct_eq(mac).unwrap_u8() == 0 {
        return (None, Some(replay_tag), None, Some(SphinxUnwrapError::MACError));
    }

    // Decrypt the routing info extended with PER_HOP_ROUTING_INFO_SIZE zero
    // bytes. The first PER_HOP_ROUTING_INFO_SIZE bytes of the result hold this
    // hop's commands; the remaining ROUTING_INFO_SIZE bytes are the routing
    // info to forward.
    let mut stream_cipher = StreamCipher::new(&keys.header_encryption, &keys.header_encryption_iv);
    let mut padded_routing_info = [0u8; ROUTING_INFO_SIZE + PER_HOP_ROUTING_INFO_SIZE];
    padded_routing_info[..ROUTING_INFO_SIZE].copy_from_slice(routing_info);
    let mut decrypted_routing_info = [0u8; ROUTING_INFO_SIZE + PER_HOP_ROUTING_INFO_SIZE];
    stream_cipher.xor_key_stream(&mut decrypted_routing_info, &padded_routing_info);
    let (cmd_buf, new_routing_info) = decrypted_routing_info.split_at(PER_HOP_ROUTING_INFO_SIZE);

    let (cmds, maybe_next_hop, maybe_surb_reply) = match parse_routing_commands(cmd_buf) {
        Ok(parsed) => parsed,
        Err(_) => {
            return (None, Some(replay_tag), None, Some(SphinxUnwrapError::RouteInfoParseError));
        }
    };

    // NOTE(review): the header stream-cipher IV is also passed as the SPRP IV
    // here; confirm this matches the packet-construction side.
    let decrypted_payload = match sprp_decrypt(&keys.payload_encryption, &keys.header_encryption_iv, payload.to_vec()) {
        Ok(decrypted) => decrypted,
        Err(_) => {
            return (None, Some(replay_tag), Some(cmds), Some(SphinxUnwrapError::PayloadDecryptError));
        }
    };

    let mut final_cmds = cmds;
    let final_payload = match maybe_next_hop {
        Some(RoutingCommand::NextHop(next_hop_cmd)) => {
            // Intermediate hop: rewrite the packet in place for forwarding.
            group_element.blind(&keys.blinding_factor);
            group_element_bytes.copy_from_slice(&group_element.as_array());
            routing_info.copy_from_slice(new_routing_info);
            mac.copy_from_slice(&next_hop_cmd.mac);
            payload.copy_from_slice(&decrypted_payload);
            final_cmds.push(RoutingCommand::NextHop(next_hop_cmd));
            None
        }
        Some(_) => unreachable!("parse_routing_commands yielded a non-NextHop command as the next hop"),
        None => match maybe_surb_reply {
            Some(surb_reply) => {
                // SURB replies carry no zero tag; hand back the whole payload.
                final_cmds.push(surb_reply);
                Some(decrypted_payload)
            }
            None => {
                // Terminal forward payloads must begin with PAYLOAD_TAG_SIZE zero
                // bytes. Check the length first so a short payload is reported
                // as an error rather than panicking on the slice.
                let zero_tag = [0u8; PAYLOAD_TAG_SIZE];
                let tag_ok = decrypted_payload.len() >= PAYLOAD_TAG_SIZE
                    && decrypted_payload[..PAYLOAD_TAG_SIZE].ct_eq(&zero_tag).unwrap_u8() == 1;
                if !tag_ok {
                    return (None, Some(replay_tag), None, Some(SphinxUnwrapError::PayloadError));
                }
                Some(decrypted_payload[PAYLOAD_TAG_SIZE..].to_vec())
            }
        },
    };

    (final_payload, Some(replay_tag), Some(final_cmds), None)
}