libsql_hrana/
protobuf.rs

1use std::mem::replace;
2use std::sync::Arc;
3
4use ::bytes::{Buf, BufMut, Bytes};
5use prost::encoding::{
6    bytes, double, message, sint64, skip_field, string, uint32, DecodeContext, WireType,
7};
8use prost::DecodeError;
9
10use super::proto::{
11    BatchCond, BatchCondList, BatchResult, CursorEntry, StreamRequest, StreamResponse,
12    StreamResult, Value,
13};
14
15impl prost::Message for StreamResult {
16    fn encode_raw<B>(&self, buf: &mut B)
17    where
18        B: BufMut,
19        Self: Sized,
20    {
21        match self {
22            StreamResult::None => {}
23            StreamResult::Ok { response } => message::encode(1, response, buf),
24            StreamResult::Error { error } => message::encode(2, error, buf),
25        }
26    }
27
28    fn encoded_len(&self) -> usize {
29        match self {
30            StreamResult::None => 0,
31            StreamResult::Ok { response } => message::encoded_len(1, response),
32            StreamResult::Error { error } => message::encoded_len(2, error),
33        }
34    }
35
36    fn merge_field<B>(
37        &mut self,
38        _tag: u32,
39        _wire_type: WireType,
40        _buf: &mut B,
41        _ctx: DecodeContext,
42    ) -> Result<(), DecodeError>
43    where
44        B: Buf,
45        Self: Sized,
46    {
47        panic!("StreamResult can only be encoded, not decoded")
48    }
49
50    fn clear(&mut self) {
51        panic!("StreamResult can only be encoded, not decoded")
52    }
53}
54
55impl prost::Message for StreamRequest {
56    fn encode_raw<B>(&self, _buf: &mut B)
57    where
58        B: BufMut,
59        Self: Sized,
60    {
61        panic!("StreamRequest can only be decoded, not encoded")
62    }
63
64    fn encoded_len(&self) -> usize {
65        panic!("StreamRequest can only be decoded, not encoded")
66    }
67
68    fn merge_field<B>(
69        &mut self,
70        tag: u32,
71        wire_type: WireType,
72        buf: &mut B,
73        ctx: DecodeContext,
74    ) -> Result<(), DecodeError>
75    where
76        B: Buf,
77        Self: Sized,
78    {
79        macro_rules! merge {
80            ($variant:ident) => {{
81                let mut msg = match replace(self, StreamRequest::None) {
82                    StreamRequest::$variant(msg) => msg,
83                    _ => Default::default(),
84                };
85                message::merge(wire_type, &mut msg, buf, ctx)?;
86                *self = StreamRequest::$variant(msg);
87            }};
88        }
89
90        match tag {
91            1 => merge!(Close),
92            2 => merge!(Execute),
93            3 => merge!(Batch),
94            4 => merge!(Sequence),
95            5 => merge!(Describe),
96            6 => merge!(StoreSql),
97            7 => merge!(CloseSql),
98            8 => merge!(GetAutocommit),
99            _ => skip_field(wire_type, tag, buf, ctx)?,
100        }
101        Ok(())
102    }
103
104    fn clear(&mut self) {
105        *self = StreamRequest::None;
106    }
107}
108
109impl prost::Message for StreamResponse {
110    fn encode_raw<B>(&self, buf: &mut B)
111    where
112        B: BufMut,
113        Self: Sized,
114    {
115        match self {
116            StreamResponse::Close(msg) => message::encode(1, msg, buf),
117            StreamResponse::Execute(msg) => message::encode(2, msg, buf),
118            StreamResponse::Batch(msg) => message::encode(3, msg, buf),
119            StreamResponse::Sequence(msg) => message::encode(4, msg, buf),
120            StreamResponse::Describe(msg) => message::encode(5, msg, buf),
121            StreamResponse::StoreSql(msg) => message::encode(6, msg, buf),
122            StreamResponse::CloseSql(msg) => message::encode(7, msg, buf),
123            StreamResponse::GetAutocommit(msg) => message::encode(8, msg, buf),
124        }
125    }
126
127    fn encoded_len(&self) -> usize {
128        match self {
129            StreamResponse::Close(msg) => message::encoded_len(1, msg),
130            StreamResponse::Execute(msg) => message::encoded_len(2, msg),
131            StreamResponse::Batch(msg) => message::encoded_len(3, msg),
132            StreamResponse::Sequence(msg) => message::encoded_len(4, msg),
133            StreamResponse::Describe(msg) => message::encoded_len(5, msg),
134            StreamResponse::StoreSql(msg) => message::encoded_len(6, msg),
135            StreamResponse::CloseSql(msg) => message::encoded_len(7, msg),
136            StreamResponse::GetAutocommit(msg) => message::encoded_len(8, msg),
137        }
138    }
139
140    fn merge_field<B>(
141        &mut self,
142        _tag: u32,
143        _wire_type: WireType,
144        _buf: &mut B,
145        _ctx: DecodeContext,
146    ) -> Result<(), DecodeError>
147    where
148        B: Buf,
149        Self: Sized,
150    {
151        panic!("StreamResponse can only be encoded, not decoded")
152    }
153
154    fn clear(&mut self) {
155        panic!("StreamResponse can only be encoded, not decoded")
156    }
157}
158
159impl prost::Message for BatchResult {
160    fn encode_raw<B>(&self, buf: &mut B)
161    where
162        B: BufMut,
163        Self: Sized,
164    {
165        vec_as_map::encode(1, &self.step_results, buf);
166        vec_as_map::encode(2, &self.step_errors, buf);
167    }
168
169    fn encoded_len(&self) -> usize {
170        vec_as_map::encoded_len(1, &self.step_results)
171            + vec_as_map::encoded_len(2, &self.step_errors)
172    }
173
174    fn merge_field<B>(
175        &mut self,
176        _tag: u32,
177        _wire_type: WireType,
178        _buf: &mut B,
179        _ctx: DecodeContext,
180    ) -> Result<(), DecodeError>
181    where
182        B: Buf,
183        Self: Sized,
184    {
185        panic!("BatchResult can only be encoded, not decoded")
186    }
187
188    fn clear(&mut self) {
189        self.step_results.clear();
190        self.step_errors.clear();
191    }
192}
193
194impl prost::Message for BatchCond {
195    fn encode_raw<B>(&self, _buf: &mut B)
196    where
197        B: BufMut,
198        Self: Sized,
199    {
200        panic!("BatchCond can only be decoded, not encoded")
201    }
202
203    fn encoded_len(&self) -> usize {
204        panic!("BatchCond can only be decoded, not encoded")
205    }
206
207    fn merge_field<B>(
208        &mut self,
209        tag: u32,
210        wire_type: WireType,
211        buf: &mut B,
212        ctx: DecodeContext,
213    ) -> Result<(), DecodeError>
214    where
215        B: Buf,
216        Self: Sized,
217    {
218        match tag {
219            1 => {
220                let mut step = 0;
221                uint32::merge(wire_type, &mut step, buf, ctx)?;
222                *self = BatchCond::Ok { step }
223            }
224            2 => {
225                let mut step = 0;
226                uint32::merge(wire_type, &mut step, buf, ctx)?;
227                *self = BatchCond::Error { step }
228            }
229            3 => {
230                let mut cond = match replace(self, BatchCond::None) {
231                    BatchCond::Not { cond } => cond,
232                    _ => Box::new(BatchCond::None),
233                };
234                message::merge(wire_type, &mut *cond, buf, ctx)?;
235                *self = BatchCond::Not { cond };
236            }
237            4 => {
238                let mut cond_list = match replace(self, BatchCond::None) {
239                    BatchCond::And(cond_list) => cond_list,
240                    _ => BatchCondList::default(),
241                };
242                message::merge(wire_type, &mut cond_list, buf, ctx)?;
243                *self = BatchCond::And(cond_list);
244            }
245            5 => {
246                let mut cond_list = match replace(self, BatchCond::None) {
247                    BatchCond::Or(cond_list) => cond_list,
248                    _ => BatchCondList::default(),
249                };
250                message::merge(wire_type, &mut cond_list, buf, ctx)?;
251                *self = BatchCond::Or(cond_list);
252            }
253            6 => {
254                skip_field(wire_type, tag, buf, ctx)?;
255                *self = BatchCond::IsAutocommit {};
256            }
257            _ => {
258                skip_field(wire_type, tag, buf, ctx)?;
259            }
260        }
261        Ok(())
262    }
263
264    fn clear(&mut self) {
265        *self = BatchCond::None;
266    }
267}
268
269impl prost::Message for CursorEntry {
270    fn encode_raw<B>(&self, buf: &mut B)
271    where
272        B: BufMut,
273        Self: Sized,
274    {
275        match self {
276            CursorEntry::None => {}
277            CursorEntry::StepBegin(entry) => message::encode(1, entry, buf),
278            CursorEntry::StepEnd(entry) => message::encode(2, entry, buf),
279            CursorEntry::StepError(entry) => message::encode(3, entry, buf),
280            CursorEntry::Row { row } => message::encode(4, row, buf),
281            CursorEntry::Error { error } => message::encode(5, error, buf),
282            CursorEntry::ReplicationIndex { replication_index } => {
283                if let Some(replication_index) = replication_index {
284                    message::encode(6, replication_index, buf)
285                }
286            }
287        }
288    }
289
290    fn encoded_len(&self) -> usize {
291        match self {
292            CursorEntry::None => 0,
293            CursorEntry::StepBegin(entry) => message::encoded_len(1, entry),
294            CursorEntry::StepEnd(entry) => message::encoded_len(2, entry),
295            CursorEntry::StepError(entry) => message::encoded_len(3, entry),
296            CursorEntry::Row { row } => message::encoded_len(4, row),
297            CursorEntry::Error { error } => message::encoded_len(5, error),
298            CursorEntry::ReplicationIndex { replication_index } => {
299                if let Some(replication_index) = replication_index {
300                    message::encoded_len(6, replication_index)
301                } else {
302                    0
303                }
304            }
305        }
306    }
307
308    fn merge_field<B>(
309        &mut self,
310        _tag: u32,
311        _wire_type: WireType,
312        _buf: &mut B,
313        _ctx: DecodeContext,
314    ) -> Result<(), DecodeError>
315    where
316        B: Buf,
317        Self: Sized,
318    {
319        panic!("CursorEntry can only be encoded, not decoded")
320    }
321
322    fn clear(&mut self) {
323        *self = CursorEntry::None;
324    }
325}
326
327impl prost::Message for Value {
328    fn encode_raw<B>(&self, buf: &mut B)
329    where
330        B: BufMut,
331        Self: Sized,
332    {
333        match self {
334            Value::None => {}
335            Value::Null => empty_message::encode(1, buf),
336            Value::Integer { value } => sint64::encode(2, value, buf),
337            Value::Float { value } => double::encode(3, value, buf),
338            Value::Text { value } => arc_str::encode(4, value, buf),
339            Value::Blob { value } => bytes::encode(5, value, buf),
340        }
341    }
342
343    fn encoded_len(&self) -> usize {
344        match self {
345            Value::None => 0,
346            Value::Null => empty_message::encoded_len(1),
347            Value::Integer { value } => sint64::encoded_len(2, value),
348            Value::Float { value } => double::encoded_len(3, value),
349            Value::Text { value } => arc_str::encoded_len(4, value),
350            Value::Blob { value } => bytes::encoded_len(5, value),
351        }
352    }
353
354    fn merge_field<B>(
355        &mut self,
356        tag: u32,
357        wire_type: WireType,
358        buf: &mut B,
359        ctx: DecodeContext,
360    ) -> Result<(), DecodeError>
361    where
362        B: Buf,
363        Self: Sized,
364    {
365        match tag {
366            1 => {
367                skip_field(wire_type, tag, buf, ctx)?;
368                *self = Value::Null
369            }
370            2 => {
371                let mut value = 0;
372                sint64::merge(wire_type, &mut value, buf, ctx)?;
373                *self = Value::Integer { value };
374            }
375            3 => {
376                let mut value = 0.;
377                double::merge(wire_type, &mut value, buf, ctx)?;
378                *self = Value::Float { value };
379            }
380            4 => {
381                let mut value = String::new();
382                string::merge(wire_type, &mut value, buf, ctx)?;
383                // TODO: this makes an unnecessary copy
384                let value: Arc<str> = value.into();
385                *self = Value::Text { value };
386            }
387            5 => {
388                let mut value = Bytes::new();
389                bytes::merge(wire_type, &mut value, buf, ctx)?;
390                *self = Value::Blob { value };
391            }
392            _ => {
393                skip_field(wire_type, tag, buf, ctx)?;
394            }
395        }
396        Ok(())
397    }
398
399    fn clear(&mut self) {
400        *self = Value::None;
401    }
402}
403
404mod vec_as_map {
405    use bytes::BufMut;
406    use prost::encoding::{
407        encode_key, encode_varint, encoded_len_varint, key_len, message, uint32, WireType,
408    };
409
410    pub fn encode<B, M>(tag: u32, values: &[Option<M>], buf: &mut B)
411    where
412        B: BufMut,
413        M: prost::Message,
414    {
415        for (index, msg) in values.iter().enumerate() {
416            if let Some(msg) = msg {
417                encode_map_entry(tag, index as u32, msg, buf);
418            }
419        }
420    }
421
422    pub fn encoded_len<M>(tag: u32, values: &[Option<M>]) -> usize
423    where
424        M: prost::Message,
425    {
426        values
427            .iter()
428            .enumerate()
429            .map(|(index, msg)| match msg {
430                Some(msg) => encoded_map_entry_len(tag, index as u32, msg),
431                None => 0,
432            })
433            .sum()
434    }
435
436    fn encode_map_entry<B, M>(tag: u32, key: u32, value: &M, buf: &mut B)
437    where
438        B: BufMut,
439        M: prost::Message,
440    {
441        encode_key(tag, WireType::LengthDelimited, buf);
442
443        let entry_key_len = uint32::encoded_len(1, &key);
444        let entry_value_len = message::encoded_len(2, value);
445
446        encode_varint((entry_key_len + entry_value_len) as u64, buf);
447        uint32::encode(1, &key, buf);
448        message::encode(2, value, buf);
449    }
450
451    fn encoded_map_entry_len<M>(tag: u32, key: u32, value: &M) -> usize
452    where
453        M: prost::Message,
454    {
455        let entry_key_len = uint32::encoded_len(1, &key);
456        let entry_value_len = message::encoded_len(2, value);
457        let entry_len = entry_key_len + entry_value_len;
458        key_len(tag) + encoded_len_varint(entry_len as u64) + entry_len
459    }
460}
461
462mod empty_message {
463    use bytes::BufMut;
464    use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType};
465
466    pub fn encode<B>(tag: u32, buf: &mut B)
467    where
468        B: BufMut,
469    {
470        encode_key(tag, WireType::LengthDelimited, buf);
471        encode_varint(0, buf);
472    }
473
474    pub fn encoded_len(tag: u32) -> usize {
475        key_len(tag) + encoded_len_varint(0)
476    }
477}
478
479mod arc_str {
480    use bytes::BufMut;
481    use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType};
482    use std::sync::Arc;
483
484    pub fn encode<B>(tag: u32, value: &Arc<str>, buf: &mut B)
485    where
486        B: BufMut,
487    {
488        encode_key(tag, WireType::LengthDelimited, buf);
489        encode_varint(value.len() as u64, buf);
490        buf.put_slice(value.as_bytes());
491    }
492
493    pub fn encoded_len(tag: u32, value: &Arc<str>) -> usize {
494        key_len(tag) + encoded_len_varint(value.len() as u64) + value.len()
495    }
496}