minio_rsc/client/
select_object_reader.rs

1use std::{collections::HashMap, ops::Range, pin::Pin};
2
3use async_stream::stream as Stream2;
4use bytes::{Bytes, BytesMut};
5use futures_core::Stream;
6use futures_util::StreamExt;
7
8use crate::{datatype::OutputSerialization, error::{Error, Result}};
9
10/// read u32 from `&[u8]`
11/// # Panics
12/// Panics if `data.len() != 4`.
13#[inline]
14fn read_u32(data: &[u8]) -> u32 {
15    u32::from_be_bytes(<[u8; 4]>::try_from(data).unwrap())
16}
17
18/// read u16 from `&[u8]`
19/// # Panics
20/// Panics if `data.len() != 2`.
21#[inline]
22fn read_u16(data: &[u8]) -> u16 {
23    u16::from_be_bytes(<[u8; 2]>::try_from(data).unwrap())
24}
25
26/// the event type of message from select object content.
27#[derive(PartialEq, Eq)]
28enum EventType {
29    Records,
30    Continuation,
31    Progress,
32    Stats,
33    End,
34    RequestLevelError,
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
38enum HeaderName {
39    Messagetype,
40    EventType,
41    ErrorCode,
42    ErrorMessage,
43}
44
45/// message from select object content.
46pub struct Message {
47    data: Bytes,
48    type_: EventType,
49    payload: Range<usize>,
50    headers: HashMap<HeaderName, String>,
51}
52
53impl<'a> Message {
54    pub fn payload(&self) -> &[u8] {
55        &self.data[self.payload.clone()]
56    }
57
58    /// Message type is Records. It can contain a single record, a partial record, or multiple records, depending on the number of search results.
59    pub fn is_records(&self) -> bool {
60        self.type_ == EventType::Records
61    }
62
63    /// Message type is Progress.
64    pub fn is_progress(&self) -> bool {
65        self.type_ == EventType::Progress
66    }
67
68    /// Message type is Stats.
69    pub fn is_stats(&self) -> bool {
70        self.type_ == EventType::Stats
71    }
72
73    /// Message type is Continuation.
74    pub fn is_continuation(&self) -> bool {
75        self.type_ == EventType::Continuation
76    }
77
78    /// Message type is End.
79    pub fn is_end(&self) -> bool {
80        self.type_ == EventType::End
81    }
82
83    /// return the value of *:message-type* header.
84    pub fn message_type(&self) -> Option<&String> {
85        self.headers.get(&HeaderName::Messagetype)
86    }
87
88    /// Message type is Error, more info by `error_code` `error_message` method.
89    /// If returns this information, the End message information will not be returned.
90    pub fn is_error(&self) -> bool {
91        self.type_ == EventType::RequestLevelError
92    }
93
94    /// return the value of *:error-code* header, None if this Message is not error.
95    pub fn error_code(&self) -> Option<&String> {
96        self.headers.get(&HeaderName::ErrorCode)
97    }
98
99    /// return the value of *:error-message* header, None if this Message is not error.
100    pub fn error_message(&self) -> Option<&String> {
101        self.headers.get(&HeaderName::ErrorMessage)
102    }
103}
104
105impl<'a> TryFrom<Bytes> for Message {
106    type Error = String;
107
108    fn try_from(data: Bytes) -> std::result::Result<Self, Self::Error> {
109        let prelude_crc = read_u32(&data[8..12]);
110        let prelude_crc_calc = crc32fast::hash(&data[0..8]);
111        if prelude_crc != prelude_crc_calc {
112            return Err(format!(
113                "prelude CRC mismatch; expected: {prelude_crc}, got: {prelude_crc_calc}"
114            ));
115        }
116        let message_crc = read_u32(&data[data.len() - 4..]);
117        let message_crc_calc = crc32fast::hash(&data[0..data.len() - 4]);
118        if message_crc != message_crc_calc {
119            return Err(format!(
120                "message CRC mismatch; expected: {message_crc}, got: {message_crc_calc}"
121            ));
122        }
123        let header_length = read_u32(&data[4..8]) as usize;
124        let header_end = 12 + header_length;
125
126        let payload = 12 + header_length..data.len() - 4;
127
128        let mut pos = 12;
129        let mut headers = HashMap::new();
130        loop {
131            let key_len = data[pos] as usize;
132            pos += 1;
133            let key = &data[pos..pos + key_len];
134            pos += key_len + 1;
135            let value_len = read_u16(&data[pos..pos + 2]) as usize;
136            pos += 2;
137            let val = &data[pos..pos + value_len];
138            let val = String::from_utf8(val.to_vec()).unwrap();
139            pos += value_len;
140            let header_name = match key {
141                b":message-type" => HeaderName::Messagetype,
142                b":event-type" => HeaderName::EventType,
143                b":error-code" => HeaderName::ErrorCode,
144                b":error-message" => HeaderName::ErrorMessage,
145                _ => continue,
146            };
147            headers.insert(header_name, val);
148            if pos >= header_end {
149                break;
150            }
151        }
152        if let Some(event_type) = headers.get(&HeaderName::EventType) {
153            let type_: EventType = match event_type.as_str() {
154                "Continuation" => EventType::Continuation,
155                "Progress" => EventType::Progress,
156                "Records" => EventType::Records,
157                "Stats" => EventType::Stats,
158                "End" => EventType::End,
159                ev => return Err(format!("unknown event type: {ev:?}")),
160            };
161            return Ok(Message {
162                data,
163                type_,
164                payload,
165                headers,
166            });
167        } else {
168            if headers.contains_key(&HeaderName::ErrorCode) {
169                return Ok(Message {
170                    data,
171                    type_: EventType::RequestLevelError,
172                    payload,
173                    headers,
174                });
175            } else {
176                Err(format!("unknown message"))
177            }
178        }
179    }
180}
181
182/// reader response data of `select_object_content` method
183pub struct SelectObjectReader {
184    response: reqwest::Response,
185    output_serialization: OutputSerialization,
186}
187
188impl SelectObjectReader {
189    pub(crate) fn new(
190        response: reqwest::Response,
191        output_serialization: OutputSerialization,
192    ) -> Self {
193        Self {
194            response,
195            output_serialization,
196        }
197    }
198
199    /// Read [Message] as streams
200    pub fn read_message(mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send>> {
201        Box::pin(Stream2! {
202            let mut buf = BytesMut::new();
203            let mut msg_len = 0;
204            let mut is_over = false;
205            loop{
206                if !is_over{
207                    if let Some(data) = self.response.chunk().await?{
208                        buf.extend_from_slice(&data);
209                    }else{
210                        is_over = true;
211                    };
212                }else{
213                    match buf.len(){
214                        0=>break,
215                        l if l < 4 => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, format!("not enough data in the stream; expected: 4, got: {} bytes", l)))?,
216                        l if l < msg_len => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, format!("not enough data in the stream; expected: {}, got: {} bytes", msg_len, l)))?,
217                        _=>{}
218                    }
219                }
220                if msg_len == 0 && buf.len() >= 4{
221                    msg_len = read_u32(&buf[0..4]) as usize;
222                }
223                if msg_len > 0 && buf.len() >= msg_len{
224                    let msg_data = buf.split_to(msg_len);
225                    msg_len = 0;
226                    yield Ok(Message::try_from(msg_data.freeze()).map_err(|e| Error::MessageDecodeError(e))?);
227                }
228            }
229        })
230    }
231
232    /// Read all response data at once and decode the content to bytes.
233    pub async fn read_all(self) -> Result<Bytes> {
234        let mut data = BytesMut::new();
235        let mut messages = self.read_message();
236        while let Some(message) = messages.next().await {
237            let message = message?;
238            if message.is_records() {
239                data.extend_from_slice(message.payload());
240            } else if message.is_error() {
241                Err(Error::SelectObjectError(format!(
242                    "Select Message Error code: {:?}, error message: {:?}",
243                    message.error_code(),
244                    message.error_message(),
245                )))?
246            }
247        }
248        Ok(data.freeze())
249    }
250
251    /// get [OutputSerialization]
252    pub fn output_serialization(&self) -> &OutputSerialization {
253        &self.output_serialization
254    }
255}