1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use std::sync::atomic::{AtomicUsize, Ordering};
use sha2::Digest;
use crate::util::*;
/// Identifier of a multipart transfer (the SHA-256 hash its reassembled payload must match)
pub type TransferId = [u8; 32];

/// Multipart transfer
///
/// Accumulates the payloads of ADNL `Part` messages until the whole data is received.
///
/// See [crate::proto::adnl::Message]
pub struct Transfer {
    /// Expected size of the complete data, in bytes
    total_len: usize,
    /// Sum of the lengths of all parts accepted so far
    received_len: AtomicUsize,
    /// Received data chunks, keyed by their byte offset
    parts: FastDashMap<usize, Vec<u8>>,
    /// Information about when the transfer was last updated, used to check its validity
    timings: UpdatedAt,
}
impl Transfer {
    /// Creates new multipart transfer with target length in bytes
    pub fn new(total_len: usize) -> Self {
        Self {
            parts: FastDashMap::with_capacity_and_hasher(0, Default::default()),
            received_len: Default::default(),
            total_len,
            timings: Default::default(),
        }
    }

    /// Returns transfer timings info (when it was last updated)
    #[inline(always)]
    pub fn timings(&self) -> &UpdatedAt {
        &self.timings
    }

    /// Tries to add new part to the transfer at given offset
    ///
    /// Does nothing (returns `Ok(None)`) if a part at the given offset already exists.
    /// The already stored part is never replaced.
    ///
    /// Returns `Ok(Some(data))` to exactly one caller. That is the caller whose part made
    /// the received length reach the total length, and only after all parts were joined
    /// and their SHA-256 hash matched `transfer_id`. While the transfer is still
    /// incomplete, `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// - [`TransferError::ReceivedTooMuch`] if the accumulated length exceeds the total
    ///   length, including parts that arrive after the transfer was completed
    /// - [`TransferError::PartMissing`] if the parts do not form a contiguous run of
    ///   non-empty chunks starting at offset 0
    /// - [`TransferError::InvalidHash`] if the joined data does not hash to `transfer_id`
    pub fn add_part(
        &self,
        offset: usize,
        data: Vec<u8>,
        transfer_id: &TransferId,
    ) -> Result<Option<Vec<u8>>, TransferError> {
        let len = data.len();

        // Only fill a vacant slot. A plain `insert` would silently replace an already
        // accounted part with data of a possibly different length. The entry guard is a
        // temporary and is released at the end of the statement.
        let mut inserted = false;
        self.parts.entry(offset).or_insert_with(|| {
            inserted = true;
            data
        });
        if !inserted {
            return Ok(None);
        }

        // `fetch_add` returns the previous value, so among concurrent callers only one
        // observes the step that makes the received length reach the total length.
        // `AcqRel` publishes our insertion to the completing caller. If we are that
        // caller, it also lets us see the insertions of everyone who added a part before us.
        let prev = self.received_len.fetch_add(len, Ordering::AcqRel);
        let received = prev.saturating_add(len);

        match received.cmp(&self.total_len) {
            std::cmp::Ordering::Less => {
                tracing::trace!(
                    received,
                    total = self.total_len,
                    transfer_id = %DisplayTransferId(transfer_id),
                    "received ADNL transfer part"
                );
                Ok(None)
            }
            // An empty part cannot complete a non-empty transfer. If the length already
            // equalled the total, another caller has completed (or is completing) it.
            std::cmp::Ordering::Equal if len > 0 || self.total_len == 0 => {
                tracing::debug!(
                    received,
                    total = self.total_len,
                    transfer_id = %DisplayTransferId(transfer_id),
                    "finished ADNL transfer"
                );
                self.assemble(transfer_id).map(Some)
            }
            _ => Err(TransferError::ReceivedTooMuch),
        }
    }

    /// Joins the stored parts in offset order and verifies the SHA-256 hash of the result
    fn assemble(&self, transfer_id: &TransferId) -> Result<Vec<u8>, TransferError> {
        let mut buffer = Vec::with_capacity(self.total_len);
        while buffer.len() < self.total_len {
            let offset = buffer.len();
            let part = self.parts.get(&offset).ok_or(TransferError::PartMissing)?;
            let data = part.value();
            // An empty part would never advance the offset and would loop forever
            if data.is_empty() {
                return Err(TransferError::PartMissing);
            }
            buffer.extend_from_slice(data);
        }

        let hash = sha2::Sha256::digest(&buffer);
        if hash.as_slice() != transfer_id {
            return Err(TransferError::InvalidHash);
        }
        Ok(buffer)
    }
}
#[derive(Copy, Clone)]
pub struct DisplayTransferId<'a>(pub &'a TransferId);
impl std::fmt::Display for DisplayTransferId<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut output = [0u8; 64];
hex::encode_to_slice(self.0, &mut output).ok();
// SAFETY: output is guaranteed to contain only [0-9a-f]
let output = unsafe { std::str::from_utf8_unchecked(&output) };
f.write_str(output)
}
}
#[derive(thiserror::Error, Debug)]
pub enum TransferError {
#[error("Invalid transfer part (received too much)")]
ReceivedTooMuch,
#[error("Invalid transfer (part is missing)")]
PartMissing,
#[error("Invalid transfer data hash")]
InvalidHash,
}