minio_rsc/client/
select_object_reader.rs1use std::{collections::HashMap, ops::Range, pin::Pin};
2
3use async_stream::stream as Stream2;
4use bytes::{Bytes, BytesMut};
5use futures_core::Stream;
6use futures_util::StreamExt;
7
8use crate::{datatype::OutputSerialization, error::{Error, Result}};
9
/// Interprets exactly four bytes as a big-endian `u32`.
///
/// # Panics
/// Panics when `data` is not exactly four bytes long; every caller in this
/// module passes a fixed four-byte sub-slice.
#[inline]
fn read_u32(data: &[u8]) -> u32 {
    let be_bytes: [u8; 4] = data.try_into().unwrap();
    u32::from_be_bytes(be_bytes)
}
17
/// Interprets exactly two bytes as a big-endian `u16`.
///
/// # Panics
/// Panics when `data` is not exactly two bytes long.
#[inline]
fn read_u16(data: &[u8]) -> u16 {
    let be_bytes: [u8; 2] = data.try_into().unwrap();
    u16::from_be_bytes(be_bytes)
}
25
/// Classification of a decoded message.
///
/// All variants except `RequestLevelError` correspond to the value of the
/// `:event-type` header. `Debug`/`Clone`/`Copy` are derived to match
/// `HeaderName`, so the type can be logged and copied freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EventType {
    /// `:event-type` = `Records`; the payload carries query output.
    Records,
    /// `:event-type` = `Continuation`.
    Continuation,
    /// `:event-type` = `Progress`.
    Progress,
    /// `:event-type` = `Stats`.
    Stats,
    /// `:event-type` = `End`.
    End,
    /// A message with an `:error-code` header and no `:event-type` header.
    RequestLevelError,
}
36
/// Message headers this module recognises; any other header name is skipped
/// during decoding.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum HeaderName {
    /// Wire name `:message-type`.
    Messagetype,
    /// Wire name `:event-type`.
    EventType,
    /// Wire name `:error-code`.
    ErrorCode,
    /// Wire name `:error-message`.
    ErrorMessage,
}
44
45pub struct Message {
47 data: Bytes,
48 type_: EventType,
49 payload: Range<usize>,
50 headers: HashMap<HeaderName, String>,
51}
52
53impl<'a> Message {
54 pub fn payload(&self) -> &[u8] {
55 &self.data[self.payload.clone()]
56 }
57
58 pub fn is_records(&self) -> bool {
60 self.type_ == EventType::Records
61 }
62
63 pub fn is_progress(&self) -> bool {
65 self.type_ == EventType::Progress
66 }
67
68 pub fn is_stats(&self) -> bool {
70 self.type_ == EventType::Stats
71 }
72
73 pub fn is_continuation(&self) -> bool {
75 self.type_ == EventType::Continuation
76 }
77
78 pub fn is_end(&self) -> bool {
80 self.type_ == EventType::End
81 }
82
83 pub fn message_type(&self) -> Option<&String> {
85 self.headers.get(&HeaderName::Messagetype)
86 }
87
88 pub fn is_error(&self) -> bool {
91 self.type_ == EventType::RequestLevelError
92 }
93
94 pub fn error_code(&self) -> Option<&String> {
96 self.headers.get(&HeaderName::ErrorCode)
97 }
98
99 pub fn error_message(&self) -> Option<&String> {
101 self.headers.get(&HeaderName::ErrorMessage)
102 }
103}
104
105impl<'a> TryFrom<Bytes> for Message {
106 type Error = String;
107
108 fn try_from(data: Bytes) -> std::result::Result<Self, Self::Error> {
109 let prelude_crc = read_u32(&data[8..12]);
110 let prelude_crc_calc = crc32fast::hash(&data[0..8]);
111 if prelude_crc != prelude_crc_calc {
112 return Err(format!(
113 "prelude CRC mismatch; expected: {prelude_crc}, got: {prelude_crc_calc}"
114 ));
115 }
116 let message_crc = read_u32(&data[data.len() - 4..]);
117 let message_crc_calc = crc32fast::hash(&data[0..data.len() - 4]);
118 if message_crc != message_crc_calc {
119 return Err(format!(
120 "message CRC mismatch; expected: {message_crc}, got: {message_crc_calc}"
121 ));
122 }
123 let header_length = read_u32(&data[4..8]) as usize;
124 let header_end = 12 + header_length;
125
126 let payload = 12 + header_length..data.len() - 4;
127
128 let mut pos = 12;
129 let mut headers = HashMap::new();
130 loop {
131 let key_len = data[pos] as usize;
132 pos += 1;
133 let key = &data[pos..pos + key_len];
134 pos += key_len + 1;
135 let value_len = read_u16(&data[pos..pos + 2]) as usize;
136 pos += 2;
137 let val = &data[pos..pos + value_len];
138 let val = String::from_utf8(val.to_vec()).unwrap();
139 pos += value_len;
140 let header_name = match key {
141 b":message-type" => HeaderName::Messagetype,
142 b":event-type" => HeaderName::EventType,
143 b":error-code" => HeaderName::ErrorCode,
144 b":error-message" => HeaderName::ErrorMessage,
145 _ => continue,
146 };
147 headers.insert(header_name, val);
148 if pos >= header_end {
149 break;
150 }
151 }
152 if let Some(event_type) = headers.get(&HeaderName::EventType) {
153 let type_: EventType = match event_type.as_str() {
154 "Continuation" => EventType::Continuation,
155 "Progress" => EventType::Progress,
156 "Records" => EventType::Records,
157 "Stats" => EventType::Stats,
158 "End" => EventType::End,
159 ev => return Err(format!("unknown event type: {ev:?}")),
160 };
161 return Ok(Message {
162 data,
163 type_,
164 payload,
165 headers,
166 });
167 } else {
168 if headers.contains_key(&HeaderName::ErrorCode) {
169 return Ok(Message {
170 data,
171 type_: EventType::RequestLevelError,
172 payload,
173 headers,
174 });
175 } else {
176 Err(format!("unknown message"))
177 }
178 }
179 }
180}
181
182pub struct SelectObjectReader {
184 response: reqwest::Response,
185 output_serialization: OutputSerialization,
186}
187
188impl SelectObjectReader {
189 pub(crate) fn new(
190 response: reqwest::Response,
191 output_serialization: OutputSerialization,
192 ) -> Self {
193 Self {
194 response,
195 output_serialization,
196 }
197 }
198
199 pub fn read_message(mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send>> {
201 Box::pin(Stream2! {
202 let mut buf = BytesMut::new();
203 let mut msg_len = 0;
204 let mut is_over = false;
205 loop{
206 if !is_over{
207 if let Some(data) = self.response.chunk().await?{
208 buf.extend_from_slice(&data);
209 }else{
210 is_over = true;
211 };
212 }else{
213 match buf.len(){
214 0=>break,
215 l if l < 4 => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, format!("not enough data in the stream; expected: 4, got: {} bytes", l)))?,
216 l if l < msg_len => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, format!("not enough data in the stream; expected: {}, got: {} bytes", msg_len, l)))?,
217 _=>{}
218 }
219 }
220 if msg_len == 0 && buf.len() >= 4{
221 msg_len = read_u32(&buf[0..4]) as usize;
222 }
223 if msg_len > 0 && buf.len() >= msg_len{
224 let msg_data = buf.split_to(msg_len);
225 msg_len = 0;
226 yield Ok(Message::try_from(msg_data.freeze()).map_err(|e| Error::MessageDecodeError(e))?);
227 }
228 }
229 })
230 }
231
232 pub async fn read_all(self) -> Result<Bytes> {
234 let mut data = BytesMut::new();
235 let mut messages = self.read_message();
236 while let Some(message) = messages.next().await {
237 let message = message?;
238 if message.is_records() {
239 data.extend_from_slice(message.payload());
240 } else if message.is_error() {
241 Err(Error::SelectObjectError(format!(
242 "Select Message Error code: {:?}, error message: {:?}",
243 message.error_code(),
244 message.error_message(),
245 )))?
246 }
247 }
248 Ok(data.freeze())
249 }
250
251 pub fn output_serialization(&self) -> &OutputSerialization {
253 &self.output_serialization
254 }
255}