gel_protocol/
client_message.rs

/*!
([Website reference](https://www.edgedb.com/docs/reference/protocol/messages)) The [ClientMessage] enum and related types.

`Dump2` and `Dump3` share the same message type byte; which one is decoded
depends on the negotiated protocol version.

```rust,ignore
pub enum ClientMessage {
    AuthenticationSaslInitialResponse(SaslInitialResponse),
    AuthenticationSaslResponse(SaslResponse),
    ClientHandshake(ClientHandshake),
    Dump2(Dump2),
    Dump3(Dump3),
    Parse(Parse),
    Execute1(Execute1),
    Restore(Restore),
    RestoreBlock(RestoreBlock),
    RestoreEof,
    Sync,
    Terminate,
    UnknownMessage(u8, Bytes),
}
```
*/
21
22use std::collections::HashMap;
23use std::convert::TryFrom;
24use std::sync::Arc;
25
26use bytes::{Buf, BufMut, Bytes};
27use snafu::OptionExt;
28use uuid::Uuid;
29
30pub use crate::common::CompilationOptions;
31pub use crate::common::DumpFlags;
32pub use crate::common::{Capabilities, Cardinality, CompilationFlags};
33pub use crate::common::{RawTypedesc, State};
34use crate::encoding::{encode, Decode, Encode, Input, Output};
35use crate::encoding::{Annotations, KeyValues};
36use crate::errors::{self, DecodeError, EncodeError};
37use crate::new_protocol;
38
/// A message sent from the client to the server.
///
/// Each variant is identified on the wire by a one-byte message type; the
/// mapping lives in [`ClientMessage::encode`] and [`ClientMessage::decode`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientMessage {
    /// `AuthenticationSASLInitialResponse` (type byte `'p'`, 0x70).
    AuthenticationSaslInitialResponse(SaslInitialResponse),
    /// `AuthenticationSASLResponse` (type byte `'r'`, 0x72).
    AuthenticationSaslResponse(SaslResponse),
    /// `ClientHandshake` (type byte `'V'`, 0x56).
    ClientHandshake(ClientHandshake),
    /// `Dump` (type byte `'>'`, 0x3e) with numeric headers; decoded when the
    /// negotiated protocol's `is_3()` is false.
    Dump2(Dump2),
    /// `Dump` (type byte `'>'`, 0x3e) with annotations and flags; decoded
    /// when the negotiated protocol's `is_3()` is true.
    Dump3(Dump3),
    /// `Parse` (type byte `'P'`, 0x50).
    Parse(Parse),
    /// `Execute` (type byte `'O'`, 0x4f).
    Execute1(Execute1),
    /// `Restore` (type byte `'<'`, 0x3c).
    Restore(Restore),
    /// `RestoreBlock` (type byte `'='`, 0x3d).
    RestoreBlock(RestoreBlock),
    /// `RestoreEof` (type byte `'.'`, 0x2e); encoded with an empty body.
    RestoreEof,
    /// `Sync` (type byte `'S'`, 0x53); encoded with an empty body.
    Sync,
    /// `Terminate` (type byte `'X'`, 0x58); encoded with an empty body.
    Terminate,
    /// A message whose type byte `decode` does not recognize.
    ///
    /// Holds the type byte and the raw bytes of the whole frame as received,
    /// header included. Such messages cannot be re-encoded.
    UnknownMessage(u8, Bytes),
}
56
/// Body of an `AuthenticationSASLInitialResponse` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SaslInitialResponse {
    /// Name of the SASL method selected by the client.
    pub method: String,
    /// Method-specific initial response data.
    pub data: Bytes,
}

/// Body of an `AuthenticationSASLResponse` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SaslResponse {
    /// Method-specific response data.
    pub data: Bytes,
}

/// Body of a `ClientHandshake` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientHandshake {
    /// Protocol major version.
    pub major_ver: u16,
    /// Protocol minor version.
    pub minor_ver: u16,
    /// Connection parameters as name/value strings (at most `u16::MAX`).
    pub params: HashMap<String, String>,
    /// Protocol extensions keyed by extension name, each with its own
    /// name/value annotations.
    pub extensions: HashMap<String, Annotations>,
}
75
/// Body of a `Parse` (`'P'`) message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Parse {
    /// Optional annotations. `None` is encoded the same as an empty map (a
    /// zero count), and an empty list decodes back to `None`.
    pub annotations: Option<Arc<Annotations>>,
    /// Capabilities the command is allowed to use.
    pub allowed_capabilities: Capabilities,
    /// Compilation flags, sent as a u64 bit set.
    pub compilation_flags: CompilationFlags,
    /// Implicit result limit. `None` is sent as 0, and 0 decodes as `None`.
    pub implicit_limit: Option<u64>,
    /// Requested output format.
    pub output_format: IoFormat,
    /// Expected result cardinality.
    pub expected_cardinality: Cardinality,
    /// Text of the command.
    pub command_text: String,
    /// Session state: type descriptor id plus encoded data.
    pub state: State,
    /// Language of `command_text`. Only sent on multilingual protocol
    /// versions; older versions decode it as `InputLanguage::EdgeQL`.
    pub input_language: InputLanguage,
}

/// Body of an `Execute` (`'O'`) message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Execute1 {
    /// Optional annotations. `None` is encoded the same as an empty map (a
    /// zero count), and an empty list decodes back to `None`.
    pub annotations: Option<Arc<Annotations>>,
    /// Capabilities the command is allowed to use.
    pub allowed_capabilities: Capabilities,
    /// Compilation flags, sent as a u64 bit set.
    pub compilation_flags: CompilationFlags,
    /// Implicit result limit. `None` is sent as 0, and 0 decodes as `None`.
    pub implicit_limit: Option<u64>,
    /// Requested output format.
    pub output_format: IoFormat,
    /// Expected result cardinality.
    pub expected_cardinality: Cardinality,
    /// Text of the command.
    pub command_text: String,
    /// Session state: type descriptor id plus encoded data.
    pub state: State,
    /// Id of the input (arguments) type descriptor.
    pub input_typedesc_id: Uuid,
    /// Id of the output type descriptor.
    pub output_typedesc_id: Uuid,
    /// Encoded argument data.
    pub arguments: Bytes,
    /// Language of `command_text`. Only sent on multilingual protocol
    /// versions; older versions decode it as `InputLanguage::EdgeQL`.
    pub input_language: InputLanguage,
}
104
/// Body of a `Dump` message for protocol versions whose `is_3()` is false.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dump2 {
    /// Headers keyed by their numeric (u16) code.
    pub headers: KeyValues,
}

/// Body of a `Dump` message for protocol versions whose `is_3()` is true.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dump3 {
    /// Optional annotations; `None` is encoded as a zero count.
    pub annotations: Option<Arc<Annotations>>,
    /// Dump flags, sent as a u64 bit set.
    pub flags: DumpFlags,
}

/// Body of a `Restore` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Restore {
    /// Headers keyed by their numeric (u16) code.
    pub headers: KeyValues,
    /// Number of jobs, sent as a u16.
    pub jobs: u16,
    /// Raw payload, appended after `jobs` without a length prefix.
    pub data: Bytes,
}

/// Body of a `RestoreBlock` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RestoreBlock {
    /// Raw block payload, written without a length prefix.
    pub data: Bytes,
}
127
128pub use crate::new_protocol::{InputLanguage, IoFormat};
129
/// Placeholder body for messages that carry no payload (`RestoreEof`,
/// `Sync`, `Terminate`); it encodes to zero bytes.
struct Empty;
131impl ClientMessage {
132    pub fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
133        use ClientMessage::*;
134        match self {
135            ClientHandshake(h) => encode(buf, 0x56, h),
136            AuthenticationSaslInitialResponse(h) => encode(buf, 0x70, h),
137            AuthenticationSaslResponse(h) => encode(buf, 0x72, h),
138            Parse(h) => encode(buf, 0x50, h),
139            Execute1(h) => encode(buf, 0x4f, h),
140            Dump2(h) => encode(buf, 0x3e, h),
141            Dump3(h) => encode(buf, 0x3e, h),
142            Restore(h) => encode(buf, 0x3c, h),
143            RestoreBlock(h) => encode(buf, 0x3d, h),
144            RestoreEof => encode(buf, 0x2e, &Empty),
145            Sync => encode(buf, 0x53, &Empty),
146            Terminate => encode(buf, 0x58, &Empty),
147
148            UnknownMessage(_, _) => errors::UnknownMessageCantBeEncoded.fail()?,
149        }
150    }
151    /// Decode exactly one frame from the buffer.
152    ///
153    /// This expects a full frame to already be in the buffer. It can return
154    /// an arbitrary error or be silent if a message is only partially present
155    /// in the buffer or if extra data is present.
156    pub fn decode(buf: &mut Input) -> Result<ClientMessage, DecodeError> {
157        let message = new_protocol::Message::new(buf)?;
158        let mut next = buf.slice(..message.mlen() + 1);
159        buf.advance(message.mlen() + 1);
160        let buf = &mut next;
161
162        use self::ClientMessage as M;
163        let result = match buf[0] {
164            0x56 => ClientHandshake::decode(buf).map(M::ClientHandshake)?,
165            0x70 => SaslInitialResponse::decode(buf).map(M::AuthenticationSaslInitialResponse)?,
166            0x72 => SaslResponse::decode(buf).map(M::AuthenticationSaslResponse)?,
167            0x50 => Parse::decode(buf).map(M::Parse)?,
168            0x4f => Execute1::decode(buf).map(M::Execute1)?,
169            0x3e => {
170                if buf.proto().is_3() {
171                    Dump3::decode(buf).map(M::Dump3)?
172                } else {
173                    Dump2::decode(buf).map(M::Dump2)?
174                }
175            }
176            0x3c => Restore::decode(buf).map(M::Restore)?,
177            0x3d => RestoreBlock::decode(buf).map(M::RestoreBlock)?,
178            0x2e => M::RestoreEof,
179            0x53 => M::Sync,
180            0x58 => M::Terminate,
181            code => M::UnknownMessage(code, buf.copy_to_bytes(buf.remaining())),
182        };
183        Ok(result)
184    }
185}
186
187impl Encode for Empty {
188    fn encode(&self, _buf: &mut Output) -> Result<(), EncodeError> {
189        Ok(())
190    }
191}
192
193impl Encode for ClientHandshake {
194    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
195        buf.reserve(8);
196        buf.put_u16(self.major_ver);
197        buf.put_u16(self.minor_ver);
198        buf.put_u16(
199            u16::try_from(self.params.len())
200                .ok()
201                .context(errors::TooManyParams)?,
202        );
203        for (k, v) in &self.params {
204            k.encode(buf)?;
205            v.encode(buf)?;
206        }
207        buf.reserve(2);
208        buf.put_u16(
209            u16::try_from(self.extensions.len())
210                .ok()
211                .context(errors::TooManyExtensions)?,
212        );
213        for (name, headers) in &self.extensions {
214            String::encode(name, buf)?;
215            buf.reserve(2);
216            buf.put_u16(
217                u16::try_from(headers.len())
218                    .ok()
219                    .context(errors::TooManyHeaders)?,
220            );
221            for (name, value) in headers {
222                String::encode(name, buf)?;
223                String::encode(value, buf)?;
224            }
225        }
226        Ok(())
227    }
228}
229
230impl Decode for ClientHandshake {
231    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
232        let message = new_protocol::ClientHandshake::new(buf)?;
233        let mut params = HashMap::new();
234        for param in message.params() {
235            params.insert(
236                param.name().to_string_lossy().to_string(),
237                param.value().to_string_lossy().to_string(),
238            );
239        }
240
241        let mut extensions = HashMap::new();
242        for ext in message.extensions() {
243            let mut headers = HashMap::new();
244            for ann in ext.annotations() {
245                headers.insert(
246                    ann.name().to_string_lossy().to_string(),
247                    ann.value().to_string_lossy().to_string(),
248                );
249            }
250            extensions.insert(ext.name().to_string_lossy().to_string(), headers);
251        }
252
253        let decoded = ClientHandshake {
254            major_ver: message.major_ver(),
255            minor_ver: message.minor_ver(),
256            params,
257            extensions,
258        };
259        buf.advance(message.as_ref().len());
260        Ok(decoded)
261    }
262}
263
264impl Encode for SaslInitialResponse {
265    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
266        self.method.encode(buf)?;
267        self.data.encode(buf)?;
268        Ok(())
269    }
270}
271
272impl Decode for SaslInitialResponse {
273    fn decode(buf: &mut Input) -> Result<SaslInitialResponse, DecodeError> {
274        let message = new_protocol::AuthenticationSASLInitialResponse::new(buf)?;
275        let decoded = SaslInitialResponse {
276            method: message.method().to_string_lossy().to_string(),
277            data: message.sasl_data().into_slice().to_owned().into(),
278        };
279        buf.advance(message.as_ref().len());
280        Ok(decoded)
281    }
282}
283
284impl Encode for SaslResponse {
285    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
286        self.data.encode(buf)?;
287        Ok(())
288    }
289}
290
291impl Decode for SaslResponse {
292    fn decode(buf: &mut Input) -> Result<SaslResponse, DecodeError> {
293        let message = new_protocol::AuthenticationSASLResponse::new(buf)?;
294        let decoded = SaslResponse {
295            data: message.sasl_data().into_slice().to_owned().into(),
296        };
297        buf.advance(message.as_ref().len());
298        Ok(decoded)
299    }
300}
301
302impl Encode for Execute1 {
303    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
304        buf.reserve(2 + 3 * 8 + 1 + 1 + 4 + 16 + 4 + 16 + 16 + 4);
305        if let Some(annotations) = self.annotations.as_deref() {
306            buf.put_u16(
307                u16::try_from(annotations.len())
308                    .ok()
309                    .context(errors::TooManyHeaders)?,
310            );
311            for (name, value) in annotations {
312                buf.reserve(4);
313                name.encode(buf)?;
314                value.encode(buf)?;
315            }
316        } else {
317            buf.put_u16(0);
318        }
319        buf.reserve(3 * 8 + 1 + 1 + 4 + 16 + 4 + 16 + 16 + 4);
320        buf.put_u64(self.allowed_capabilities.bits());
321        buf.put_u64(self.compilation_flags.bits());
322        buf.put_u64(self.implicit_limit.unwrap_or(0));
323        if buf.proto().is_multilingual() {
324            buf.put_u8(self.input_language as u8);
325        }
326        buf.put_u8(self.output_format as u8);
327        buf.put_u8(self.expected_cardinality as u8);
328        self.command_text.encode(buf)?;
329        self.state.typedesc_id.encode(buf)?;
330        self.state.data.encode(buf)?;
331        self.input_typedesc_id.encode(buf)?;
332        self.output_typedesc_id.encode(buf)?;
333        self.arguments.encode(buf)?;
334        Ok(())
335    }
336}
337
338impl Decode for Execute1 {
339    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
340        if buf.proto().is_multilingual() {
341            let message = new_protocol::Execute::new(buf)?;
342
343            // Convert annotations
344            let annotations = if !message.annotations().is_empty() {
345                let mut ann_map = HashMap::new();
346                for ann in message.annotations() {
347                    ann_map.insert(
348                        ann.name().to_string_lossy().to_string(),
349                        ann.value().to_string_lossy().to_string(),
350                    );
351                }
352                Some(Arc::new(ann_map))
353            } else {
354                None
355            };
356
357            // Convert state
358            let state = State {
359                typedesc_id: message.state_typedesc_id(),
360                data: message.state_data().into_slice().to_owned().into(),
361            };
362
363            let decoded = Execute1 {
364                annotations,
365                allowed_capabilities: Capabilities::from_bits_retain(
366                    message.allowed_capabilities(),
367                ),
368                compilation_flags: decode_compilation_flags(message.compilation_flags())?,
369                implicit_limit: match message.implicit_limit() {
370                    0 => None,
371                    val => Some(val),
372                },
373                output_format: message.output_format(),
374                expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
375                command_text: message.command_text().to_string_lossy().to_string(),
376                state,
377                input_typedesc_id: message.input_typedesc_id(),
378                output_typedesc_id: message.output_typedesc_id(),
379                arguments: message.arguments().into_slice().to_owned().into(),
380                input_language: message.input_language(),
381            };
382            buf.advance(message.as_ref().len());
383            Ok(decoded)
384        } else {
385            let message = new_protocol::Execute2::new(buf)?;
386
387            // Convert annotations
388            let annotations = if !message.annotations().is_empty() {
389                let mut ann_map = HashMap::new();
390                for ann in message.annotations() {
391                    ann_map.insert(
392                        ann.name().to_string_lossy().to_string(),
393                        ann.value().to_string_lossy().to_string(),
394                    );
395                }
396                Some(Arc::new(ann_map))
397            } else {
398                None
399            };
400
401            // Convert state
402            let state = State {
403                typedesc_id: message.state_typedesc_id(),
404                data: message.state_data().into_slice().to_owned().into(),
405            };
406
407            let decoded = Execute1 {
408                annotations,
409                allowed_capabilities: decode_capabilities(message.allowed_capabilities())?,
410                compilation_flags: decode_compilation_flags(message.compilation_flags())?,
411                implicit_limit: match message.implicit_limit() {
412                    0 => None,
413                    val => Some(val),
414                },
415                output_format: message.output_format(),
416                expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
417                command_text: message.command_text().to_string_lossy().to_string(),
418                state,
419                input_typedesc_id: message.input_typedesc_id(),
420                output_typedesc_id: message.output_typedesc_id(),
421                arguments: message.arguments().into_slice().to_owned().into(),
422                input_language: InputLanguage::EdgeQL,
423            };
424            buf.advance(message.as_ref().len());
425            Ok(decoded)
426        }
427    }
428}
429
430impl Encode for Dump2 {
431    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
432        buf.reserve(10);
433        buf.put_u16(
434            u16::try_from(self.headers.len())
435                .ok()
436                .context(errors::TooManyHeaders)?,
437        );
438        for (&name, value) in &self.headers {
439            buf.reserve(2);
440            buf.put_u16(name);
441            value.encode(buf)?;
442        }
443        Ok(())
444    }
445}
446
447impl Decode for Dump2 {
448    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
449        let message = new_protocol::Dump2::new(buf)?;
450        let mut headers = HashMap::new();
451        for header in message.headers() {
452            headers.insert(header.code(), header.value().into_slice().to_owned().into());
453        }
454
455        let decoded = Dump2 { headers };
456        buf.advance(message.as_ref().len());
457        Ok(decoded)
458    }
459}
460
461impl Encode for Dump3 {
462    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
463        buf.reserve(2 + 8);
464        if let Some(annotations) = self.annotations.as_deref() {
465            buf.put_u16(
466                u16::try_from(annotations.len())
467                    .ok()
468                    .context(errors::TooManyHeaders)?,
469            );
470            for (name, value) in annotations {
471                buf.reserve(4);
472                name.encode(buf)?;
473                value.encode(buf)?;
474            }
475        } else {
476            buf.put_u16(0);
477        }
478        buf.put_u64(self.flags.bits());
479        Ok(())
480    }
481}
482
483impl Decode for Dump3 {
484    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
485        let message = new_protocol::Dump3::new(buf)?;
486        let mut annotations = HashMap::new();
487        for ann in message.annotations() {
488            annotations.insert(
489                ann.name().to_string_lossy().to_string(),
490                ann.value().to_string_lossy().to_string(),
491            );
492        }
493
494        let decoded = Dump3 {
495            annotations: Some(Arc::new(annotations)),
496            flags: decode_dump_flags(message.flags())?,
497        };
498        buf.advance(message.as_ref().len());
499        Ok(decoded)
500    }
501}
502
503impl Encode for Restore {
504    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
505        buf.reserve(4 + self.data.len());
506        buf.put_u16(
507            u16::try_from(self.headers.len())
508                .ok()
509                .context(errors::TooManyHeaders)?,
510        );
511        for (&name, value) in &self.headers {
512            buf.reserve(2);
513            buf.put_u16(name);
514            value.encode(buf)?;
515        }
516        buf.put_u16(self.jobs);
517        buf.extend(&self.data);
518        Ok(())
519    }
520}
521
522impl Decode for Restore {
523    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
524        let message = new_protocol::Restore::new(buf)?;
525        let mut headers = HashMap::new();
526        for header in message.headers() {
527            headers.insert(header.code(), header.value().into_slice().to_owned().into());
528        }
529
530        let decoded = Restore {
531            headers,
532            jobs: message.jobs(),
533            data: message.data().as_ref().to_owned().into(),
534        };
535        buf.advance(message.as_ref().len());
536        Ok(decoded)
537    }
538}
539
540impl Encode for RestoreBlock {
541    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
542        buf.extend(&self.data);
543        Ok(())
544    }
545}
546
547impl Decode for RestoreBlock {
548    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
549        let message = new_protocol::RestoreBlock::new(buf)?;
550        let decoded = RestoreBlock {
551            data: message.block_data().into_slice().to_owned().into(),
552        };
553        buf.advance(message.as_ref().len());
554        Ok(decoded)
555    }
556}
557
558impl Parse {
559    pub fn new(
560        opts: &CompilationOptions,
561        query: &str,
562        state: State,
563        annotations: Option<Arc<Annotations>>,
564    ) -> Parse {
565        Parse {
566            annotations,
567            allowed_capabilities: opts.allow_capabilities,
568            compilation_flags: opts.flags(),
569            implicit_limit: opts.implicit_limit,
570            output_format: opts.io_format,
571            expected_cardinality: opts.expected_cardinality,
572            command_text: query.into(),
573            state,
574            input_language: opts.input_language,
575        }
576    }
577}
578
579fn decode_capabilities(val: u64) -> Result<Capabilities, DecodeError> {
580    Capabilities::from_bits(val)
581        .ok_or_else(|| errors::InvalidCapabilities { capabilities: val }.build())
582}
583
584fn decode_compilation_flags(val: u64) -> Result<CompilationFlags, DecodeError> {
585    CompilationFlags::from_bits(val).ok_or_else(|| {
586        errors::InvalidCompilationFlags {
587            compilation_flags: val,
588        }
589        .build()
590    })
591}
592
593fn decode_dump_flags(val: u64) -> Result<DumpFlags, DecodeError> {
594    DumpFlags::from_bits(val).ok_or_else(|| errors::InvalidDumpFlags { dump_flags: val }.build())
595}
596
597impl Decode for Parse {
598    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
599        if buf.proto().is_multilingual() {
600            let message = new_protocol::Parse::new(buf)?;
601
602            // Convert annotations
603            let annotations = if !message.annotations().is_empty() {
604                let mut ann_map = HashMap::new();
605                for ann in message.annotations() {
606                    ann_map.insert(
607                        ann.name().to_string_lossy().to_string(),
608                        ann.value().to_string_lossy().to_string(),
609                    );
610                }
611                Some(Arc::new(ann_map))
612            } else {
613                None
614            };
615
616            // Convert state
617            let state = State {
618                typedesc_id: message.state_typedesc_id(),
619                data: message.state_data().into_slice().to_owned().into(),
620            };
621
622            let decoded = Parse {
623                annotations,
624                allowed_capabilities: decode_capabilities(message.allowed_capabilities())?,
625                compilation_flags: decode_compilation_flags(message.compilation_flags())?,
626                implicit_limit: match message.implicit_limit() {
627                    0 => None,
628                    val => Some(val),
629                },
630                output_format: message.output_format(),
631                expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
632                command_text: message.command_text().to_string_lossy().to_string(),
633                state,
634                input_language: message.input_language(),
635            };
636            buf.advance(message.as_ref().len());
637            Ok(decoded)
638        } else {
639            let message = new_protocol::Parse2::new(buf)?;
640
641            // Convert annotations
642            let annotations = if !message.annotations().is_empty() {
643                let mut ann_map = HashMap::new();
644                for ann in message.annotations() {
645                    ann_map.insert(
646                        ann.name().to_string_lossy().to_string(),
647                        ann.value().to_string_lossy().to_string(),
648                    );
649                }
650                Some(Arc::new(ann_map))
651            } else {
652                None
653            };
654
655            // Convert state
656            let state = State {
657                typedesc_id: message.state_typedesc_id(),
658                data: message.state_data().into_slice().to_owned().into(),
659            };
660
661            let decoded = Parse {
662                annotations,
663                allowed_capabilities: decode_capabilities(message.allowed_capabilities())?,
664                compilation_flags: decode_compilation_flags(message.compilation_flags())?,
665                implicit_limit: match message.implicit_limit() {
666                    0 => None,
667                    val => Some(val),
668                },
669                output_format: message.output_format(),
670                expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
671                command_text: message.command_text().to_string_lossy().to_string(),
672                state,
673                input_language: InputLanguage::EdgeQL, // Default for non-multilingual
674            };
675            buf.advance(message.as_ref().len());
676            Ok(decoded)
677        }
678    }
679}
680
681impl Encode for Parse {
682    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
683        buf.reserve(52);
684        if let Some(annotations) = self.annotations.as_deref() {
685            buf.put_u16(
686                u16::try_from(annotations.len())
687                    .ok()
688                    .context(errors::TooManyHeaders)?,
689            );
690            for (name, value) in annotations {
691                buf.reserve(8);
692                name.encode(buf)?;
693                value.encode(buf)?;
694            }
695        } else {
696            buf.put_u16(0);
697        }
698        buf.reserve(50);
699        buf.put_u64(self.allowed_capabilities.bits());
700        buf.put_u64(self.compilation_flags.bits());
701        buf.put_u64(self.implicit_limit.unwrap_or(0));
702        if buf.proto().is_multilingual() {
703            buf.put_u8(self.input_language as u8);
704        }
705        buf.put_u8(self.output_format as u8);
706        buf.put_u8(self.expected_cardinality as u8);
707        self.command_text.encode(buf)?;
708        self.state.typedesc_id.encode(buf)?;
709        self.state.data.encode(buf)?;
710        Ok(())
711    }
712}