crowdstrike_cloudproto/services/lfo/
response.rs1use crate::framing::CloudProtoPacket;
2use crate::services::lfo::file_header::{CRC_LEN, LFO_RESP_HDR_LEN};
3use crate::services::lfo::pkt_kind::LfoPacketKind;
4use crate::services::lfo::{CompressionFormats, LfoError, LfoFileHeader};
5use bytes::Bytes;
6use std::cmp;
7use std::io::{Read, Write};
8use tracing::trace;
9
10#[cfg(feature = "lfo-compress-xz")]
11use bytes::Buf;
12#[cfg(feature = "lfo-compress-xz")]
13use xz2::read::XzDecoder;
14
/// Streaming state of an [`LfoResponse`]; also records whether its payload is compressed.
enum ResponseReadState {
    /// Uncompressed payload, copied straight out of `lfo_data`.
    Direct {
        /// Offset in `lfo_data` of the next byte to hand out.
        read_pos: usize,
    },
    /// xz-compressed payload, decoded incrementally by `stream`.
    #[cfg(feature = "lfo-compress-xz")]
    Compressed {
        stream: XzDecoder<bytes::buf::Reader<Bytes>>,
    },
}
24
25pub struct LfoResponse {
27 raw_lfo_payload: Bytes,
28 header: LfoFileHeader,
29 lfo_data: Bytes,
31 read_state: ResponseReadState,
32 #[cfg(feature = "lfo-check-hash")]
33 read_hasher: sha2::Sha256,
34 #[cfg(not(feature = "lfo-check-hash"))]
35 read_hasher: (),
36}
37
38impl LfoResponse {
39 pub fn data(&self) -> Result<Bytes, LfoError> {
43 let full_data = match self.read_state {
44 ResponseReadState::Direct { .. } => self.lfo_data.clone(),
45 #[cfg(feature = "lfo-compress-xz")]
46 ResponseReadState::Compressed { .. } => {
47 let mut stream = XzDecoder::new(self.lfo_data.clone().reader());
48 let mut buf = Vec::with_capacity(self.header.payload_size as usize);
49 stream.read_to_end(&mut buf)?;
50 buf.into()
51 }
52 };
53 self.check_full_data_len(full_data.len())?;
55 self.validate_full_data_hash(full_data.as_ref())?;
56 Ok(full_data)
57 }
58
59 pub fn raw_lfo_payload(&self) -> Bytes {
63 self.raw_lfo_payload.clone()
64 }
65
66 pub fn lfo_file_header(&self) -> &LfoFileHeader {
69 &self.header
70 }
71
72 #[cfg(feature = "lfo-check-hash")]
73 fn update_running_hash(hasher: &mut sha2::Sha256, buf: &[u8]) {
74 use sha2::Digest;
75 hasher.update(buf);
76 }
77 #[cfg(not(feature = "lfo-check-hash"))]
78 fn update_running_hash(_hasher: &mut (), _buf: &[u8]) {}
79
80 #[cfg(feature = "lfo-check-hash")]
81 fn check_hash_matches(expected: &[u8; 32], hasher: &mut sha2::Sha256) -> Result<(), LfoError> {
82 use sha2::Digest;
83 let actual = hasher.finalize_reset();
84 if expected != actual.as_slice() {
85 return Err(LfoError::InvalidHash {
86 expected: *expected,
87 actual: *actual.as_ref(),
88 });
89 }
90 Ok(())
91 }
92 #[cfg(not(feature = "lfo-check-hash"))]
93 fn check_hash_matches(_expected: &[u8; 32], _actual: &()) -> Result<(), LfoError> {
94 Ok(())
95 }
96
97 #[cfg(feature = "lfo-check-hash")]
98 fn validate_full_data_hash(&self, data: &[u8]) -> Result<(), LfoError> {
99 use sha2::Digest;
100 let mut hasher = sha2::Sha256::new();
101 hasher.update(data);
102 Self::check_hash_matches(&self.header.data_hash, &mut hasher)
103 }
104 #[cfg(not(feature = "lfo-check-hash"))]
105 fn validate_full_data_hash(&self, _data: &[u8]) -> Result<(), LfoError> {
106 Ok(())
107 }
108
109 fn check_full_data_len(&self, data_len: usize) -> Result<(), LfoError> {
110 if data_len != self.header.payload_size as usize {
111 return Err(LfoError::ReplyParseError {
112 reason: format!(
113 "LFO file data has length {:#x}, but expected {:#x}",
114 data_len, self.header.payload_size
115 ),
116 raw_payload: Default::default(),
117 });
118 }
119 Ok(())
120 }
121
122 fn try_from_raw_lfo_payload(raw_payload: Vec<u8>) -> Result<Self, LfoError> {
123 let raw_payload = Bytes::from(raw_payload);
124 let header = match LfoFileHeader::try_from(raw_payload.as_ref()) {
125 Ok(h) => h,
126 Err(e) => {
127 return Err(LfoError::ReplyParseError {
128 reason: e,
129 raw_payload,
130 })
131 }
132 };
133 let chunk_data = raw_payload.slice(LFO_RESP_HDR_LEN..raw_payload.len() - CRC_LEN);
134 let read_state = if header.comp_format == CompressionFormats::None as u16 {
135 ResponseReadState::Direct { read_pos: 0 }
136 } else if cfg!(feature = "lfo-compress-xz")
137 && header.comp_format == CompressionFormats::Xz as u16
138 {
139 #[cfg(not(feature = "lfo-compress-xz"))]
140 unreachable!();
141 #[cfg(feature = "lfo-compress-xz")]
142 ResponseReadState::Compressed {
143 stream: XzDecoder::new(chunk_data.clone().reader()),
144 }
145 } else {
146 return Err(LfoError::ReplyParseError {
147 reason: format!("Unsupported compression format {}", header.comp_format),
148 raw_payload,
149 });
150 };
151 Ok(Self {
152 raw_lfo_payload: raw_payload,
153 header,
154 lfo_data: chunk_data,
155 read_state,
156 read_hasher: Default::default(),
157 })
158 }
159}
160
161impl TryFrom<CloudProtoPacket> for LfoResponse {
162 type Error = LfoError;
163
164 fn try_from(reply: CloudProtoPacket) -> Result<Self, Self::Error> {
165 if reply.kind == LfoPacketKind::ReplyFail && reply.payload.len() >= 8 {
166 let msg = String::from_utf8_lossy(&reply.payload[8..]);
167
168 if msg == "internal error" {
171 Err(LfoError::NotFound)
172 } else {
173 Err(LfoError::ServerError(msg.to_string()))
174 }
175 } else if reply.kind == LfoPacketKind::ReplyOk {
176 trace!(
177 "Received LfoOk with {:#x} bytes raw payload",
178 reply.payload.len()
179 );
180 Self::try_from_raw_lfo_payload(reply.payload)
181 } else {
182 Err(LfoError::BadReplyKind(reply.kind))
183 }
184 }
185}
186
187impl Read for LfoResponse {
188 fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
189 let hasher = &mut self.read_hasher;
190 match &mut self.read_state {
191 ResponseReadState::Direct { read_pos } => {
192 let remaining = &self.lfo_data[*read_pos..];
193 let attempted_count = cmp::min(buf.len(), remaining.len());
194 let count = buf.write(&remaining[..attempted_count])?;
195
196 Self::update_running_hash(hasher, &remaining[..count]);
197 if count == remaining.len() && count != 0 {
198 Self::check_hash_matches(&self.header.data_hash, hasher)
199 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
200 }
201
202 *read_pos += count;
203 Ok(count)
204 }
205 #[cfg(feature = "lfo-compress-xz")]
206 ResponseReadState::Compressed { stream } => {
207 let count = stream.read(buf)?;
208 Self::update_running_hash(hasher, &buf[..count]);
209
210 if stream.total_out() > self.header.payload_size as u64 {
211 return Err(std::io::Error::new(
212 std::io::ErrorKind::InvalidData,
213 LfoError::InvalidFinalSize {
214 expected: self.header.payload_size as usize,
215 actual: stream.total_out() as usize,
216 },
217 ));
218 } else if count != 0 && stream.total_out() == self.header.payload_size as u64 {
219 Self::check_hash_matches(&self.header.data_hash, hasher)
220 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
221 }
222
223 Ok(count)
224 }
225 }
226 }
227}
228
#[cfg(test)]
mod test {
    use crate::framing::{CloudProtoPacket, CloudProtoVersion};
    use crate::services::lfo::pkt_kind::LfoPacketKind;
    use crate::services::lfo::test::TEST_REPLY_DATA;
    use crate::services::lfo::{LfoError, LfoResponse};
    use crate::services::CloudProtoMagic;
    use std::io::Read;

    /// Lowercase hex encoding of the SHA-256 digest of `bytes`.
    fn sha256_hex(bytes: &[u8]) -> String {
        use sha2::Digest;
        let mut hasher = sha2::Sha256::new();
        hasher.update(bytes);
        hex::encode(hasher.finalize().as_slice())
    }

    /// Wraps `lfo_reply_hex` in an LFO `ReplyOk` packet, parses it, and checks that
    /// `data()` (before and after streaming), the `Read` stream and the header
    /// all agree on contents whose SHA-256 is `expected_hash`.
    fn check_test_vector(lfo_reply_hex: &str, expected_hash: &str) -> Result<(), LfoError> {
        let raw_reply = hex::decode(lfo_reply_hex).unwrap();
        let packet = CloudProtoPacket {
            magic: CloudProtoMagic::TS,
            kind: LfoPacketKind::ReplyOk.into(),
            version: CloudProtoVersion::Normal,
            payload: raw_reply.clone(),
        };
        let mut resp = LfoResponse::try_from(packet)?;
        assert_eq!(resp.raw_lfo_payload(), &raw_reply);

        // `data()` must not be affected by draining the stream, and both must
        // match what the stream produced.
        let eager_before = resp.data()?;
        let mut streamed = Vec::new();
        resp.read_to_end(&mut streamed)?;
        let eager_after = resp.data()?;
        assert_eq!(eager_before, eager_after);
        assert_eq!(eager_before, streamed);

        assert_eq!(&sha256_hex(&streamed), expected_hash);
        assert_eq!(
            expected_hash,
            &hex::encode(resp.lfo_file_header().data_hash)
        );
        Ok(())
    }

    #[test]
    fn simple_test_vector() -> Result<(), LfoError> {
        check_test_vector(
            TEST_REPLY_DATA,
            "a330869acb341ad81b4b64f92ed7b85e0a361ab0449017a9f7a5f09276a43655",
        )
    }

    #[test]
    #[cfg(feature = "lfo-compress-xz")]
    fn xz_test_vector() -> Result<(), LfoError> {
        let reply_hex = "000000000000015658dd00985ef1c304b973374fad8726aeac9769fe45d1bea2335630b0899b9ef60001fd377a585a0000016922de36020021011c00000010cf\
            58cce0015500645d0055687c400160306c2cec9513bc4360c68796e3b982a76ad18024af592b8f044aae3937e42bec03336fa43a3ecd228463d4545ae8cf99a9\
            6368bfc3d7137b5f1fe5cb4201c3928e6a07895cba5f7220d2a3f5400768f1a63acc53ae5abbf13d5b6b84000000c3d9916a00017cd602000000155b09133e30\
            0d8b020000000001595a75e2d281";
        check_test_vector(
            reply_hex,
            "58dd00985ef1c304b973374fad8726aeac9769fe45d1bea2335630b0899b9ef6",
        )
    }
}