crowdstrike_cloudproto/services/lfo/response.rs

1use crate::framing::CloudProtoPacket;
2use crate::services::lfo::file_header::{CRC_LEN, LFO_RESP_HDR_LEN};
3use crate::services::lfo::pkt_kind::LfoPacketKind;
4use crate::services::lfo::{CompressionFormats, LfoError, LfoFileHeader};
5use bytes::Bytes;
6use std::cmp;
7use std::io::{Read, Write};
8use tracing::trace;
9
10#[cfg(feature = "lfo-compress-xz")]
11use bytes::Buf;
12#[cfg(feature = "lfo-compress-xz")]
13use xz2::read::XzDecoder;
14
/// Cursor/decoder state backing the incremental [`Read`] implementation on
/// [`LfoResponse`].
enum ResponseReadState {
    /// Payload is stored uncompressed; `read_pos` is the byte offset of the
    /// `Read` cursor into `lfo_data`.
    Direct {
        read_pos: usize,
    },
    /// Payload is XZ-compressed; reads are streamed through this decoder.
    #[cfg(feature = "lfo-compress-xz")]
    Compressed {
        stream: XzDecoder<bytes::buf::Reader<Bytes>>,
    },
}
24
/// The reply from the server corresponding to a single [`LfoRequest`](super::LfoRequest).
pub struct LfoResponse {
    // The complete serialized reply as received (LFO header + file data + CRC).
    raw_lfo_payload: Bytes,
    // Parsed LFO file header (payload size, compression format, expected hash).
    header: LfoFileHeader,
    // This could be the plain file data, or compressed
    lfo_data: Bytes,
    // Tracks how far the `Read` implementation has progressed.
    read_state: ResponseReadState,
    // Running SHA-256 over the bytes handed out via `Read`, verified at EOF.
    #[cfg(feature = "lfo-check-hash")]
    read_hasher: sha2::Sha256,
    // Zero-sized stand-in so the field always exists regardless of features.
    #[cfg(not(feature = "lfo-check-hash"))]
    read_hasher: (),
}
37
38impl LfoResponse {
39    /// Extracts the data of the requested file from the response.
40    /// May fail if the received data (after any decompression) has the wrong size or hash.
41    /// This ignores the [`Read`](std::io::Read) cursor and always returns the entire data.
42    pub fn data(&self) -> Result<Bytes, LfoError> {
43        let full_data = match self.read_state {
44            ResponseReadState::Direct { .. } => self.lfo_data.clone(),
45            #[cfg(feature = "lfo-compress-xz")]
46            ResponseReadState::Compressed { .. } => {
47                let mut stream = XzDecoder::new(self.lfo_data.clone().reader());
48                let mut buf = Vec::with_capacity(self.header.payload_size as usize);
49                stream.read_to_end(&mut buf)?;
50                buf.into()
51            }
52        };
53        // This explicitly does not use Read, so we have to do these checks here too
54        self.check_full_data_len(full_data.len())?;
55        self.validate_full_data_hash(full_data.as_ref())?;
56        Ok(full_data)
57    }
58
59    /// This returns the raw, still serialized LFO server's response.
60    /// You most likely want to use [`Self::data()`](Self::data) instead.
61    /// Only use this if you would like to parse some fields of the LFO header yourself.
62    pub fn raw_lfo_payload(&self) -> Bytes {
63        self.raw_lfo_payload.clone()
64    }
65
66    /// The LFO file header mostly contains low-level details about the file being downloaded.
67    /// You can use it check the size the decompressed file, before actually decompressing it.
68    pub fn lfo_file_header(&self) -> &LfoFileHeader {
69        &self.header
70    }
71
72    #[cfg(feature = "lfo-check-hash")]
73    fn update_running_hash(hasher: &mut sha2::Sha256, buf: &[u8]) {
74        use sha2::Digest;
75        hasher.update(buf);
76    }
77    #[cfg(not(feature = "lfo-check-hash"))]
78    fn update_running_hash(_hasher: &mut (), _buf: &[u8]) {}
79
80    #[cfg(feature = "lfo-check-hash")]
81    fn check_hash_matches(expected: &[u8; 32], hasher: &mut sha2::Sha256) -> Result<(), LfoError> {
82        use sha2::Digest;
83        let actual = hasher.finalize_reset();
84        if expected != actual.as_slice() {
85            return Err(LfoError::InvalidHash {
86                expected: *expected,
87                actual: *actual.as_ref(),
88            });
89        }
90        Ok(())
91    }
92    #[cfg(not(feature = "lfo-check-hash"))]
93    fn check_hash_matches(_expected: &[u8; 32], _actual: &()) -> Result<(), LfoError> {
94        Ok(())
95    }
96
97    #[cfg(feature = "lfo-check-hash")]
98    fn validate_full_data_hash(&self, data: &[u8]) -> Result<(), LfoError> {
99        use sha2::Digest;
100        let mut hasher = sha2::Sha256::new();
101        hasher.update(data);
102        Self::check_hash_matches(&self.header.data_hash, &mut hasher)
103    }
104    #[cfg(not(feature = "lfo-check-hash"))]
105    fn validate_full_data_hash(&self, _data: &[u8]) -> Result<(), LfoError> {
106        Ok(())
107    }
108
109    fn check_full_data_len(&self, data_len: usize) -> Result<(), LfoError> {
110        if data_len != self.header.payload_size as usize {
111            return Err(LfoError::ReplyParseError {
112                reason: format!(
113                    "LFO file data has length {:#x}, but expected {:#x}",
114                    data_len, self.header.payload_size
115                ),
116                raw_payload: Default::default(),
117            });
118        }
119        Ok(())
120    }
121
122    fn try_from_raw_lfo_payload(raw_payload: Vec<u8>) -> Result<Self, LfoError> {
123        let raw_payload = Bytes::from(raw_payload);
124        let header = match LfoFileHeader::try_from(raw_payload.as_ref()) {
125            Ok(h) => h,
126            Err(e) => {
127                return Err(LfoError::ReplyParseError {
128                    reason: e,
129                    raw_payload,
130                })
131            }
132        };
133        let chunk_data = raw_payload.slice(LFO_RESP_HDR_LEN..raw_payload.len() - CRC_LEN);
134        let read_state = if header.comp_format == CompressionFormats::None as u16 {
135            ResponseReadState::Direct { read_pos: 0 }
136        } else if cfg!(feature = "lfo-compress-xz")
137            && header.comp_format == CompressionFormats::Xz as u16
138        {
139            #[cfg(not(feature = "lfo-compress-xz"))]
140            unreachable!();
141            #[cfg(feature = "lfo-compress-xz")]
142            ResponseReadState::Compressed {
143                stream: XzDecoder::new(chunk_data.clone().reader()),
144            }
145        } else {
146            return Err(LfoError::ReplyParseError {
147                reason: format!("Unsupported compression format {}", header.comp_format),
148                raw_payload,
149            });
150        };
151        Ok(Self {
152            raw_lfo_payload: raw_payload,
153            header,
154            lfo_data: chunk_data,
155            read_state,
156            read_hasher: Default::default(),
157        })
158    }
159}
160
161impl TryFrom<CloudProtoPacket> for LfoResponse {
162    type Error = LfoError;
163
164    fn try_from(reply: CloudProtoPacket) -> Result<Self, Self::Error> {
165        if reply.kind == LfoPacketKind::ReplyFail && reply.payload.len() >= 8 {
166            let msg = String::from_utf8_lossy(&reply.payload[8..]);
167
168            // I realize this is terrible, but internal errors indicate file not found errors
169            // I have not seen any other internal errors, except for when the path is wrong
170            if msg == "internal error" {
171                Err(LfoError::NotFound)
172            } else {
173                Err(LfoError::ServerError(msg.to_string()))
174            }
175        } else if reply.kind == LfoPacketKind::ReplyOk {
176            trace!(
177                "Received LfoOk with {:#x} bytes raw payload",
178                reply.payload.len()
179            );
180            Self::try_from_raw_lfo_payload(reply.payload)
181        } else {
182            Err(LfoError::BadReplyKind(reply.kind))
183        }
184    }
185}
186
impl Read for LfoResponse {
    /// Streams the (decompressed) file data while hashing it on the fly; the
    /// digest is compared against the header once the full payload was read.
    fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
        // Borrow the hasher up front so the match below can mutably borrow
        // `read_state` without conflicting (split borrow of `self`).
        let hasher = &mut self.read_hasher;
        match &mut self.read_state {
            ResponseReadState::Direct { read_pos } => {
                let remaining = &self.lfo_data[*read_pos..];
                let attempted_count = cmp::min(buf.len(), remaining.len());
                // `impl Write for &mut [u8]` copies into `buf` and advances it.
                let count = buf.write(&remaining[..attempted_count])?;

                Self::update_running_hash(hasher, &remaining[..count]);
                // This read consumed the last remaining bytes: verify the digest.
                // NOTE(review): an empty payload never triggers this check
                // (count is always 0) — presumably acceptable; confirm.
                if count == remaining.len() && count != 0 {
                    Self::check_hash_matches(&self.header.data_hash, hasher)
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
                }

                *read_pos += count;
                Ok(count)
            }
            #[cfg(feature = "lfo-compress-xz")]
            ResponseReadState::Compressed { stream } => {
                let count = stream.read(buf)?;
                Self::update_running_hash(hasher, &buf[..count]);

                // total_out() counts decompressed bytes produced so far; more
                // than payload_size means the stream is larger than advertised.
                if stream.total_out() > self.header.payload_size as u64 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        LfoError::InvalidFinalSize {
                            expected: self.header.payload_size as usize,
                            actual: stream.total_out() as usize,
                        },
                    ));
                } else if count != 0 && stream.total_out() == self.header.payload_size as u64 {
                    // Exactly the advertised size was produced: verify the digest.
                    Self::check_hash_matches(&self.header.data_hash, hasher)
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
                }

                Ok(count)
            }
        }
    }
}
228
#[cfg(test)]
mod test {
    use crate::framing::{CloudProtoPacket, CloudProtoVersion};
    use crate::services::lfo::pkt_kind::LfoPacketKind;
    use crate::services::lfo::test::TEST_REPLY_DATA;
    use crate::services::lfo::{LfoError, LfoResponse};
    use crate::services::CloudProtoMagic;
    use std::io::Read;

    /// Decodes `lfo_reply_hex` into a reply packet, parses it, and checks that
    /// `data()` and the `Read` implementation agree with each other and with
    /// `expected_hash` (hex-encoded SHA-256 of the file contents).
    fn check_test_vector(lfo_reply_hex: &str, expected_hash: &str) -> Result<(), LfoError> {
        let raw_reply = hex::decode(lfo_reply_hex).unwrap();
        let packet = CloudProtoPacket {
            magic: CloudProtoMagic::TS,
            kind: LfoPacketKind::ReplyOk.into(),
            version: CloudProtoVersion::Normal,
            payload: raw_reply.clone(),
        };
        let mut response = LfoResponse::try_from(packet)?;
        assert_eq!(response.raw_lfo_payload(), &raw_reply);

        // Extract via data() before and after draining the Read cursor;
        // all three results must be identical.
        let via_data_first = response.data()?;
        let mut via_read = Vec::new();
        response.read_to_end(&mut via_read)?;
        let via_data_second = response.data()?;
        assert_eq!(via_data_first, via_data_second);
        assert_eq!(via_data_first, via_read);

        use sha2::Digest;
        let mut hasher = sha2::Sha256::new();
        hasher.update(&via_read);
        assert_eq!(&hex::encode(hasher.finalize().as_slice()), expected_hash);

        // We should already check the hash by default, but let's do it again for good measure
        assert_eq!(
            expected_hash,
            &hex::encode(response.lfo_file_header().data_hash)
        );
        Ok(())
    }

    #[test]
    fn simple_test_vector() -> Result<(), LfoError> {
        let expected_hash = "a330869acb341ad81b4b64f92ed7b85e0a361ab0449017a9f7a5f09276a43655";
        check_test_vector(TEST_REPLY_DATA, expected_hash)
    }

    #[test]
    #[cfg(feature = "lfo-compress-xz")]
    fn xz_test_vector() -> Result<(), LfoError> {
        let hex = "000000000000015658dd00985ef1c304b973374fad8726aeac9769fe45d1bea2335630b0899b9ef60001fd377a585a0000016922de36020021011c00000010cf\
                   58cce0015500645d0055687c400160306c2cec9513bc4360c68796e3b982a76ad18024af592b8f044aae3937e42bec03336fa43a3ecd228463d4545ae8cf99a9\
                   6368bfc3d7137b5f1fe5cb4201c3928e6a07895cba5f7220d2a3f5400768f1a63acc53ae5abbf13d5b6b84000000c3d9916a00017cd602000000155b09133e30\
                   0d8b020000000001595a75e2d281";
        let expected_hash = "58dd00985ef1c304b973374fad8726aeac9769fe45d1bea2335630b0899b9ef6";
        check_test_vector(hex, expected_hash)
    }
}