hdfs_client/
data_transfer.rs

1use std::io::{self, Read, Write};
2
3use crate::{crc32, hrpc::HRpc, HDFSError};
4use crc::Digest;
5use hdfs_types::hdfs::{
6    op_write_block_proto::BlockConstructionStage, BaseHeaderProto, BlockOpResponseProto,
7    ChecksumProto, ChecksumTypeProto, ClientOperationHeaderProto, ExtendedBlockProto,
8    LocatedBlockProto, OpReadBlockProto, OpWriteBlockProto, PacketHeaderProto, PipelineAckProto,
9    Status, UpdateBlockForPipelineRequestProto,
10};
11use prost::{
12    bytes::{BufMut, BytesMut},
13    Message,
14};
15
// Framing constants written at the start of every data-transfer operation request:
// a big-endian `u16` protocol version followed by a one-byte opcode.

/// Data-transfer protocol version that prefixes each request.
const DATA_TRANSFER_PROTO: u16 = 28;
/// Opcode of the write-block operation.
const WRITE_BLOCK: u8 = 80;
/// Opcode of the read-block operation.
const READ_BLOCK: u8 = 81;
19
20fn read_prefixed_message<S: Read, M: Message + Default>(stream: &mut S) -> Result<M, HDFSError> {
21    use prost::encoding::decode_varint;
22    let mut buf = BytesMut::new();
23    let mut tmp_buf = [0u8];
24    let length = loop {
25        stream.read_exact(&mut tmp_buf)?;
26        buf.put_u8(tmp_buf[0]);
27        match decode_varint(&mut buf.clone()) {
28            Ok(length) => break length,
29            Err(_) => {
30                continue;
31            }
32        }
33    };
34    buf.clear();
35    buf.resize(length as usize, 0);
36    stream.read_exact(&mut buf)?;
37    let msg = M::decode(buf)?;
38    Ok(msg)
39}
40
/// Reads exactly four bytes from `stream` and interprets them as a big-endian `u32`.
///
/// Fails with `UnexpectedEof` if the stream ends before four bytes are available.
fn read_be_u32<S: Read>(stream: &mut S) -> io::Result<u32> {
    let mut word = [0u8; 4];
    stream
        .read_exact(&mut word)
        .map(|()| u32::from_be_bytes(word))
}
46
/// Reads exactly two bytes from `stream` and interprets them as a big-endian `u16`.
///
/// Fails with `UnexpectedEof` if the stream ends before two bytes are available.
fn read_be_u16<S: Read>(stream: &mut S) -> io::Result<u16> {
    let mut pair = [0u8; 2];
    stream
        .read_exact(&mut pair)
        .map(|()| u16::from_be_bytes(pair))
}
52
/// Compiles the enclosed statements only when the `trace_valuable` feature is
/// enabled, with `valuable::Valuable` in scope so values can be traced through
/// `.as_value()`.
///
/// The body tokens are forwarded verbatim (separators included), so any number of
/// statements can be passed without relying on how `stmt` fragments re-expand.
macro_rules! trace_valuable {
    ($($body:tt)*) => {
        #[cfg(feature = "trace_valuable")]
        {
            use valuable::Valuable;
            $($body)*
        }
    };
}
62
/// Compiles the enclosed statements only when the `trace_dbg` feature is enabled.
///
/// The body tokens are forwarded verbatim (separators included), so any number of
/// statements can be passed without relying on how `stmt` fragments re-expand.
macro_rules! trace_dbg {
    ($($body:tt)*) => {
        #[cfg(feature = "trace_dbg")]
        {
            $($body)*
        }
    };
}
71
72#[allow(unused)]
73pub struct BlockReadStream<S> {
74    stream: S,
75    pub(crate) packet_remain: usize,
76    offset: u64,
77    block: LocatedBlockProto,
78    checksum: ChecksumProto,
79    checksum_data: Vec<u32>,
80    checksum_idx: usize,
81    checksum_read: usize,
82    digest_fn: Box<dyn Fn() -> Digest<'static, u32>>,
83    digest: Digest<'static, u32>,
84}
85
86impl<S: Read + Write> BlockReadStream<S> {
87    pub fn new(
88        client_name: String,
89        mut stream: S,
90        offset: u64,
91        send_checksums: Option<bool>,
92        block: LocatedBlockProto,
93    ) -> Result<Self, HDFSError> {
94        let len = block.b.num_bytes() - offset;
95        let req = OpReadBlockProto {
96            header: ClientOperationHeaderProto {
97                base_header: BaseHeaderProto {
98                    block: block.b.clone(),
99                    token: Some(block.block_token.clone()),
100                    trace_info: None,
101                },
102                client_name,
103            },
104            offset,
105            len,
106            send_checksums,
107            caching_strategy: None,
108        };
109        let mut buf = BytesMut::new();
110        buf.put_u16(DATA_TRANSFER_PROTO);
111        buf.put_u8(READ_BLOCK);
112        let length = req.encoded_len();
113        buf.reserve(prost::length_delimiter_len(length) + length + 2);
114        trace_dbg! {
115            tracing::trace!(target: "data-transfer", "\nreq: {req:#?}");
116        }
117        trace_valuable! {
118            tracing::trace!(target: "data-transfer", req=req.as_value());
119
120        }
121        req.encode_length_delimited(&mut buf)?;
122        stream.write_all(&buf)?;
123        stream.flush()?;
124        let resp: BlockOpResponseProto = read_prefixed_message(&mut stream)?;
125        trace_dbg! {
126            tracing::trace!(target: "data-transfer", "\nresp: {resp:#?}");
127        }
128        trace_valuable! {
129            tracing::trace!(target: "data-transfer", resp=resp.as_value());
130        }
131        if !matches!(resp.status(), Status::Success) {
132            tracing::warn!(
133                "init data transfer error {}",
134                resp.message.clone().unwrap_or_default()
135            );
136            return Err(HDFSError::DataNodeError(Box::new(resp)));
137        }
138        let checksum = resp
139            .read_op_checksum_info
140            .clone()
141            .map(|c| c.checksum)
142            .unwrap_or_else(|| ChecksumProto {
143                bytes_per_checksum: 512,
144                r#type: ChecksumTypeProto::ChecksumNull as i32,
145            });
146        let checksum_ty = checksum.r#type();
147        let (header, checksum_data) = start_new_packet(&checksum, &mut stream)?;
148
149        let digest_fn = move || match checksum_ty {
150            ChecksumTypeProto::ChecksumNull | ChecksumTypeProto::ChecksumCrc32 => {
151                crc32::CRC32.digest()
152            }
153            ChecksumTypeProto::ChecksumCrc32c => crc32::CRC32C.digest(),
154        };
155        let mut stream = Self {
156            stream,
157            block,
158            offset,
159            packet_remain: header.data_len as usize,
160            checksum,
161            checksum_data,
162            checksum_read: 0,
163            checksum_idx: 0,
164            digest: digest_fn(),
165            digest_fn: Box::new(digest_fn),
166        };
167        if let Some(ck_resp) = resp.read_op_checksum_info {
168            if ck_resp.chunk_offset < offset {
169                let diff = offset - ck_resp.chunk_offset;
170                let mut buf = vec![0; diff as usize];
171                stream.read_exact(&mut buf)?;
172                stream.offset = offset;
173            }
174        }
175        Ok(stream)
176    }
177
178    pub fn remaining(&self) -> u64 {
179        self.block.b.num_bytes() - self.offset
180    }
181
182    fn inner_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
183        if self.packet_remain == 0 {
184            if self.offset >= self.block.b.num_bytes() {
185                return Ok(0);
186            } else {
187                let (header, checksum_data) = start_new_packet(&self.checksum, &mut self.stream)?;
188                self.checksum_read = 0;
189                self.checksum_idx = 0;
190                self.packet_remain = header.data_len as usize;
191                self.checksum_data = checksum_data;
192            }
193        }
194
195        let max_read = self.packet_remain.min(buf.len());
196        self.stream.read_exact(&mut buf[..max_read])?;
197        self.packet_remain -= max_read;
198        self.offset += max_read as u64;
199        if matches!(
200            self.checksum.r#type(),
201            ChecksumTypeProto::ChecksumCrc32 | ChecksumTypeProto::ChecksumCrc32c
202        ) {
203            let bytes_per_checksum = self.checksum.bytes_per_checksum as usize;
204            if self.checksum_read + max_read >= bytes_per_checksum || self.packet_remain == 0 {
205                let step = max_read.div_ceil(bytes_per_checksum);
206                for i in 0..step {
207                    let (start, end) = if i == 0 {
208                        (0, (bytes_per_checksum - self.checksum_read))
209                    } else {
210                        (
211                            bytes_per_checksum * i - self.checksum_read,
212                            (bytes_per_checksum * (i + 1) - self.checksum_read),
213                        )
214                    };
215                    let end = end.min(max_read);
216                    self.digest.update(&buf[start..end]);
217                    if (self.packet_remain == 0 && i + 1 == step)
218                        || (end + self.checksum_read) % bytes_per_checksum == 0
219                    {
220                        let checksum = self.checksum_data[self.checksum_idx];
221                        let digest = (self.digest_fn)();
222                        let old_digest = std::mem::replace(&mut self.digest, digest);
223                        let cal = old_digest.finalize();
224                        if cal != checksum {
225                            return Err(io::Error::new(
226                                io::ErrorKind::InvalidData,
227                                format!(
228                                    "checksum validation failed expect: {} got: {}",
229                                    checksum, cal
230                                ),
231                            ));
232                        }
233                    }
234                    self.checksum_idx += 1;
235                }
236                self.checksum_read = (self.checksum_read + max_read) % bytes_per_checksum;
237            } else {
238                self.digest.update(&buf[..max_read]);
239                self.checksum_read += max_read;
240            }
241        }
242        Ok(max_read)
243    }
244}
245
246fn start_new_packet<S: Read + Write>(
247    checksum: &ChecksumProto,
248    stream: &mut S,
249) -> Result<(PacketHeaderProto, Vec<u32>), HDFSError> {
250    let mut buf = BytesMut::new();
251    let _length = read_be_u32(stream)?;
252    let header_size = read_be_u16(stream)?;
253    buf.resize(header_size as usize, 0);
254    stream.read_exact(&mut buf)?;
255    let header = PacketHeaderProto::decode(buf)?;
256    trace_dbg! {
257        tracing::trace!(target: "data-transfer", "\npacket header: {header:#?}");
258    }
259    trace_valuable! {
260        tracing::trace!(target: "data-transfer", packet_header=header.as_value());
261    }
262    // if header.data_len == 0 {
263    //     let read_resp = ClientReadStatusProto { status: 0 };
264    //     trace_dbg! {
265    //         tracing::trace!(target: "data-transfer", "\nresp: {read_resp:#?}");
266    //     }
267    //     trace_valuable! {
268    //         tracing::trace!(target: "data-transfer", resp=read_resp.as_value());
269    //     }
270    //     stream.write_all(&read_resp.encode_length_delimited_to_vec())?;
271    //     stream.flush()?;
272    // }
273    let checksum_data = match checksum.r#type() {
274        ChecksumTypeProto::ChecksumNull => vec![],
275        _ => {
276            if header.data_len == 0 {
277                vec![]
278            } else {
279                let len = (header.data_len as u32).div_ceil(checksum.bytes_per_checksum) * 4;
280                let mut data = vec![0; len as usize];
281                stream.read_exact(&mut data)?;
282                data.as_slice()
283                    .chunks_exact(4)
284                    .map(|s| u32::from_be_bytes(s.try_into().unwrap()))
285                    .collect()
286            }
287        }
288    };
289    Ok((header, checksum_data))
290}
291
292impl<S: Read + Write> Read for BlockReadStream<S> {
293    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
294        self.inner_read(buf)
295    }
296}
297
298pub struct BlockWriteStream<S> {
299    pub(crate) stream: S,
300    pub(crate) offset: u64,
301    pub(crate) closed: bool,
302    seq_no: i64,
303    bytes_per_checksum: u32,
304    checksum_ty: ChecksumTypeProto,
305    block: LocatedBlockProto,
306    client_name: String,
307}
308
309impl<S: Read + Write> BlockWriteStream<S> {
310    pub fn close<D: Read + Write>(
311        &mut self,
312        ipc: &mut HRpc<D>,
313    ) -> Result<ExtendedBlockProto, HDFSError> {
314        if self.closed {
315            return Ok(self.block.b.clone());
316        }
317        self.write(&[], true)?;
318        self.stream.flush()?;
319        self.block.b.num_bytes = Some(self.offset);
320        let req = UpdateBlockForPipelineRequestProto {
321            block: self.block.b.clone(),
322            client_name: self.client_name.clone(),
323        };
324        ipc.update_block_for_pipeline(req)?;
325        self.closed = true;
326        Ok(self.block.b.clone())
327    }
328
329    pub fn create(
330        client_name: String,
331        mut stream: S,
332        block: LocatedBlockProto,
333        bytes_per_checksum: u32,
334        checksum_ty: ChecksumTypeProto,
335        offset: u64,
336        append: bool,
337    ) -> Result<Self, HDFSError> {
338        let req = OpWriteBlockProto {
339            header: ClientOperationHeaderProto {
340                base_header: BaseHeaderProto {
341                    block: block.b.clone(),
342                    token: Some(block.block_token.clone()),
343                    trace_info: None,
344                },
345                client_name: client_name.clone(),
346            },
347            stage: if append && offset != 0 {
348                BlockConstructionStage::PipelineSetupAppend as i32
349            } else {
350                BlockConstructionStage::PipelineSetupCreate as i32
351            },
352            pipeline_size: block.locs.len() as u32,
353            targets: vec![],
354            min_bytes_rcvd: block.b.num_bytes(),
355            max_bytes_rcvd: offset,
356            latest_generation_stamp: block.b.generation_stamp,
357            requested_checksum: ChecksumProto {
358                r#type: checksum_ty as i32,
359                bytes_per_checksum,
360            },
361            ..Default::default()
362        };
363        let mut buf = BytesMut::new();
364        buf.put_u16(DATA_TRANSFER_PROTO);
365        buf.put_u8(WRITE_BLOCK);
366        let length = req.encoded_len();
367        buf.reserve(prost::length_delimiter_len(length) + length + 2);
368        trace_dbg! {
369            tracing::trace!(target: "data-transfer", "\nreq: {req:#?}");
370        }
371        trace_valuable! {
372            tracing::trace!(target: "data-transfer", req=req.as_value());
373        }
374        req.encode_length_delimited(&mut buf)?;
375        stream.write_all(&buf)?;
376        stream.flush()?;
377
378        let message: BlockOpResponseProto = read_prefixed_message(&mut stream)?;
379        trace_dbg! {
380            tracing::trace!(target: "data-transfer", "\nresp: {message:#?}");
381        }
382        trace_valuable! {
383            tracing::trace!(target: "data-transfer", resp=message.as_value());
384        }
385        if !matches!(message.status(), Status::Success) {
386            tracing::warn!(
387                "init data transfer error {}",
388                message.message.unwrap_or_default()
389            );
390            // TODO more proper error type
391            return Err(io::Error::new(io::ErrorKind::InvalidData, "write block failed").into());
392        }
393        Ok(Self {
394            stream,
395            offset,
396            seq_no: 0,
397            closed: false,
398            block,
399            bytes_per_checksum,
400            checksum_ty,
401            client_name,
402        })
403    }
404
405    fn inner_write(&mut self, data: &[u8], last: bool) -> Result<(), HDFSError> {
406        let header = PacketHeaderProto {
407            offset_in_block: self.offset as i64,
408            seqno: self.seq_no,
409            last_packet_in_block: last,
410            data_len: data.len() as i32,
411            sync_block: None,
412        };
413
414        #[cfg(any(feature = "trace_dbg", feature = "trace_valuable"))]
415        {
416            tracing::trace!(
417                target: "data-transfer",
418                seq = self.seq_no,
419                offset = self.offset,
420                data_len = data.len(),
421                last
422            );
423        }
424        let chunks = if data.is_empty() {
425            0
426        } else {
427            data.len().div_ceil(self.bytes_per_checksum as usize)
428        };
429        let total_len = data.len() + chunks * 4 + 4;
430        let mut buffer = BytesMut::new();
431        buffer.put_u32(total_len as u32);
432        buffer.put_u16(header.encoded_len() as u16);
433        header.encode(&mut buffer)?;
434
435        match self.checksum_ty {
436            ChecksumTypeProto::ChecksumNull => {}
437            ChecksumTypeProto::ChecksumCrc32 => {
438                for chunk in data.chunks(self.bytes_per_checksum as usize) {
439                    buffer.put_u32(crc32::CRC32.checksum(chunk));
440                }
441            }
442            ChecksumTypeProto::ChecksumCrc32c => {
443                for chunk in data.chunks(self.bytes_per_checksum as usize) {
444                    buffer.put_u32(crc32::CRC32C.checksum(chunk));
445                }
446            }
447        }
448        self.stream.write_all(&buffer)?;
449        self.stream.write_all(data)?;
450        self.stream.flush()?;
451
452        let ack: PipelineAckProto = read_prefixed_message(&mut self.stream)?;
453        trace_dbg! {
454            tracing::trace!(target: "data-transfer", "\nack: {ack:#?}");
455        }
456        trace_valuable! {
457            tracing::trace!(target: "data-transfer", ack=ack.as_value());
458        }
459        if ack.seqno != self.seq_no {
460            return Err(HDFSError::IOError(io::Error::new(
461                io::ErrorKind::InvalidData,
462                "mis match seq",
463            )));
464        }
465        self.seq_no += 1;
466        self.offset += data.len() as u64;
467        Ok(())
468    }
469
470    pub fn write(&mut self, data: &[u8], last: bool) -> Result<(), HDFSError> {
471        match (data.is_empty(), last) {
472            (true, false) => Ok(()),
473            (true, true) => self.inner_write(data, last),
474            _ => {
475                if self.offset % (self.bytes_per_checksum as u64) != 0
476                    && self.offset + data.len() as u64 > self.bytes_per_checksum as u64
477                {
478                    let split = self.bytes_per_checksum as usize
479                        - self.offset as usize % (self.bytes_per_checksum as usize);
480                    let split = split.min(data.len());
481                    self.write_by_i32_max(&data[..split], false)?;
482                    self.write_by_i32_max(&data[split..], last)?;
483                    Ok(())
484                } else {
485                    self.write_by_i32_max(data, last)
486                }
487            }
488        }
489    }
490
491    fn write_by_i32_max(&mut self, data: &[u8], last: bool) -> Result<(), HDFSError> {
492        let total = data.len().div_ceil(i32::MAX as usize);
493        for (idx, part) in data.chunks(i32::MAX as usize).enumerate() {
494            let last = (idx + 1 == total) && last;
495            self.inner_write(part, last)?
496        }
497        Ok(())
498    }
499}