1use std::collections::HashMap;
23use std::convert::TryFrom;
24use std::sync::Arc;
25
26use bytes::{Buf, BufMut, Bytes};
27use snafu::OptionExt;
28use uuid::Uuid;
29
30pub use crate::common::CompilationOptions;
31pub use crate::common::DumpFlags;
32pub use crate::common::{Capabilities, Cardinality, CompilationFlags};
33pub use crate::common::{RawTypedesc, State};
34use crate::encoding::{encode, Decode, Encode, Input, Output};
35use crate::encoding::{Annotations, KeyValues};
36use crate::errors::{self, DecodeError, EncodeError};
37use crate::new_protocol;
38
/// A message sent by the client.
///
/// [`ClientMessage::encode`] writes each variant behind a one-byte message
/// tag, and [`ClientMessage::decode`] dispatches on that tag.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientMessage {
    /// Tag `0x70`.
    AuthenticationSaslInitialResponse(SaslInitialResponse),
    /// Tag `0x72`.
    AuthenticationSaslResponse(SaslResponse),
    /// Tag `0x56`.
    ClientHandshake(ClientHandshake),
    /// Tag `0x3e`; chosen by the decoder when `proto().is_3()` is false.
    Dump2(Dump2),
    /// Tag `0x3e`; chosen by the decoder when `proto().is_3()` is true.
    Dump3(Dump3),
    /// Tag `0x50`.
    Parse(Parse),
    /// Tag `0x4f`.
    Execute1(Execute1),
    /// Tag `0x3c`.
    Restore(Restore),
    /// Tag `0x3d`.
    RestoreBlock(RestoreBlock),
    /// Tag `0x2e`; no payload.
    RestoreEof,
    /// Tag `0x53`; no payload.
    Sync,
    /// Tag `0x58`; no payload.
    Terminate,
    /// A message whose tag was not recognised while decoding.
    ///
    /// Holds the tag and the raw frame bytes, starting at the tag byte.
    /// It is only produced by decoding; encoding it returns an error.
    UnknownMessage(u8, Bytes),
}
56
/// Initial SASL response from the client (message tag `0x70`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SaslInitialResponse {
    /// SASL method name; encoded before `data`.
    pub method: String,
    /// Method-specific SASL payload.
    pub data: Bytes,
}
62
/// Follow-up SASL response from the client (message tag `0x72`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SaslResponse {
    /// Raw SASL payload.
    pub data: Bytes,
}
67
/// Client handshake message (message tag `0x56`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientHandshake {
    /// Protocol major version.
    pub major_ver: u16,
    /// Protocol minor version.
    pub minor_ver: u16,
    /// Connection parameters as name/value strings.
    pub params: HashMap<String, String>,
    /// Protocol extensions keyed by name, each carrying its own annotations.
    pub extensions: HashMap<String, Annotations>,
}
75
/// A `Parse` request (message tag `0x50`): a command plus the options it
/// should be compiled with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Parse {
    /// Optional annotations. `None` is encoded exactly like an empty list,
    /// and an empty list decodes back to `None`.
    pub annotations: Option<Arc<Annotations>>,
    pub allowed_capabilities: Capabilities,
    pub compilation_flags: CompilationFlags,
    /// `None` is sent as `0`, and `0` decodes back to `None`.
    pub implicit_limit: Option<u64>,
    pub output_format: IoFormat,
    pub expected_cardinality: Cardinality,
    pub command_text: String,
    /// Session state: a type descriptor id followed by its encoded data.
    pub state: State,
    /// Only written on multilingual protocol versions; decoded as
    /// `InputLanguage::EdgeQL` on older ones.
    pub input_language: InputLanguage,
}
88
/// An `Execute` request (message tag `0x4f`): a command, its compilation
/// options, the expected descriptor ids, and the encoded arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Execute1 {
    /// Optional annotations. `None` is encoded exactly like an empty list,
    /// and an empty list decodes back to `None`.
    pub annotations: Option<Arc<Annotations>>,
    pub allowed_capabilities: Capabilities,
    pub compilation_flags: CompilationFlags,
    /// `None` is sent as `0`, and `0` decodes back to `None`.
    pub implicit_limit: Option<u64>,
    pub output_format: IoFormat,
    pub expected_cardinality: Cardinality,
    pub command_text: String,
    /// Session state: a type descriptor id followed by its encoded data.
    pub state: State,
    pub input_typedesc_id: Uuid,
    pub output_typedesc_id: Uuid,
    /// Encoded argument data; the last field written.
    pub arguments: Bytes,
    /// Only written on multilingual protocol versions; decoded as
    /// `InputLanguage::EdgeQL` on older ones.
    pub input_language: InputLanguage,
}
104
/// Dump request in the pre-3.x layout (message tag `0x3e`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dump2 {
    /// Headers keyed by numeric code.
    pub headers: KeyValues,
}
109
/// Dump request in the 3.x layout (message tag `0x3e`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dump3 {
    /// Optional annotations; `None` is encoded exactly like an empty list.
    pub annotations: Option<Arc<Annotations>>,
    /// Dump flags, written as a `u64` bit set.
    pub flags: DumpFlags,
}
115
/// Restore request (message tag `0x3c`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Restore {
    /// Headers keyed by numeric code.
    pub headers: KeyValues,
    /// Written as a `u16` right after the headers.
    pub jobs: u16,
    /// Appended verbatim after `jobs`, with no length prefix of its own.
    pub data: Bytes,
}
122
/// One block of restore data (message tag `0x3d`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RestoreBlock {
    /// Written verbatim as the whole message body.
    pub data: Bytes,
}
127
128pub use crate::new_protocol::{InputLanguage, IoFormat};
129
/// Placeholder payload for messages without a body (`RestoreEof`, `Sync`,
/// `Terminate`).
struct Empty;
131impl ClientMessage {
132 pub fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
133 use ClientMessage::*;
134 match self {
135 ClientHandshake(h) => encode(buf, 0x56, h),
136 AuthenticationSaslInitialResponse(h) => encode(buf, 0x70, h),
137 AuthenticationSaslResponse(h) => encode(buf, 0x72, h),
138 Parse(h) => encode(buf, 0x50, h),
139 Execute1(h) => encode(buf, 0x4f, h),
140 Dump2(h) => encode(buf, 0x3e, h),
141 Dump3(h) => encode(buf, 0x3e, h),
142 Restore(h) => encode(buf, 0x3c, h),
143 RestoreBlock(h) => encode(buf, 0x3d, h),
144 RestoreEof => encode(buf, 0x2e, &Empty),
145 Sync => encode(buf, 0x53, &Empty),
146 Terminate => encode(buf, 0x58, &Empty),
147
148 UnknownMessage(_, _) => errors::UnknownMessageCantBeEncoded.fail()?,
149 }
150 }
151 pub fn decode(buf: &mut Input) -> Result<ClientMessage, DecodeError> {
157 let message = new_protocol::Message::new(buf)?;
158 let mut next = buf.slice(..message.mlen() + 1);
159 buf.advance(message.mlen() + 1);
160 let buf = &mut next;
161
162 use self::ClientMessage as M;
163 let result = match buf[0] {
164 0x56 => ClientHandshake::decode(buf).map(M::ClientHandshake)?,
165 0x70 => SaslInitialResponse::decode(buf).map(M::AuthenticationSaslInitialResponse)?,
166 0x72 => SaslResponse::decode(buf).map(M::AuthenticationSaslResponse)?,
167 0x50 => Parse::decode(buf).map(M::Parse)?,
168 0x4f => Execute1::decode(buf).map(M::Execute1)?,
169 0x3e => {
170 if buf.proto().is_3() {
171 Dump3::decode(buf).map(M::Dump3)?
172 } else {
173 Dump2::decode(buf).map(M::Dump2)?
174 }
175 }
176 0x3c => Restore::decode(buf).map(M::Restore)?,
177 0x3d => RestoreBlock::decode(buf).map(M::RestoreBlock)?,
178 0x2e => M::RestoreEof,
179 0x53 => M::Sync,
180 0x58 => M::Terminate,
181 code => M::UnknownMessage(code, buf.copy_to_bytes(buf.remaining())),
182 };
183 Ok(result)
184 }
185}
186
187impl Encode for Empty {
188 fn encode(&self, _buf: &mut Output) -> Result<(), EncodeError> {
189 Ok(())
190 }
191}
192
193impl Encode for ClientHandshake {
194 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
195 buf.reserve(8);
196 buf.put_u16(self.major_ver);
197 buf.put_u16(self.minor_ver);
198 buf.put_u16(
199 u16::try_from(self.params.len())
200 .ok()
201 .context(errors::TooManyParams)?,
202 );
203 for (k, v) in &self.params {
204 k.encode(buf)?;
205 v.encode(buf)?;
206 }
207 buf.reserve(2);
208 buf.put_u16(
209 u16::try_from(self.extensions.len())
210 .ok()
211 .context(errors::TooManyExtensions)?,
212 );
213 for (name, headers) in &self.extensions {
214 String::encode(name, buf)?;
215 buf.reserve(2);
216 buf.put_u16(
217 u16::try_from(headers.len())
218 .ok()
219 .context(errors::TooManyHeaders)?,
220 );
221 for (name, value) in headers {
222 String::encode(name, buf)?;
223 String::encode(value, buf)?;
224 }
225 }
226 Ok(())
227 }
228}
229
230impl Decode for ClientHandshake {
231 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
232 let message = new_protocol::ClientHandshake::new(buf)?;
233 let mut params = HashMap::new();
234 for param in message.params() {
235 params.insert(
236 param.name().to_string_lossy().to_string(),
237 param.value().to_string_lossy().to_string(),
238 );
239 }
240
241 let mut extensions = HashMap::new();
242 for ext in message.extensions() {
243 let mut headers = HashMap::new();
244 for ann in ext.annotations() {
245 headers.insert(
246 ann.name().to_string_lossy().to_string(),
247 ann.value().to_string_lossy().to_string(),
248 );
249 }
250 extensions.insert(ext.name().to_string_lossy().to_string(), headers);
251 }
252
253 let decoded = ClientHandshake {
254 major_ver: message.major_ver(),
255 minor_ver: message.minor_ver(),
256 params,
257 extensions,
258 };
259 buf.advance(message.as_ref().len());
260 Ok(decoded)
261 }
262}
263
264impl Encode for SaslInitialResponse {
265 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
266 self.method.encode(buf)?;
267 self.data.encode(buf)?;
268 Ok(())
269 }
270}
271
272impl Decode for SaslInitialResponse {
273 fn decode(buf: &mut Input) -> Result<SaslInitialResponse, DecodeError> {
274 let message = new_protocol::AuthenticationSASLInitialResponse::new(buf)?;
275 let decoded = SaslInitialResponse {
276 method: message.method().to_string_lossy().to_string(),
277 data: message.sasl_data().into_slice().to_owned().into(),
278 };
279 buf.advance(message.as_ref().len());
280 Ok(decoded)
281 }
282}
283
284impl Encode for SaslResponse {
285 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
286 self.data.encode(buf)?;
287 Ok(())
288 }
289}
290
291impl Decode for SaslResponse {
292 fn decode(buf: &mut Input) -> Result<SaslResponse, DecodeError> {
293 let message = new_protocol::AuthenticationSASLResponse::new(buf)?;
294 let decoded = SaslResponse {
295 data: message.sasl_data().into_slice().to_owned().into(),
296 };
297 buf.advance(message.as_ref().len());
298 Ok(decoded)
299 }
300}
301
302impl Encode for Execute1 {
303 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
304 buf.reserve(2 + 3 * 8 + 1 + 1 + 4 + 16 + 4 + 16 + 16 + 4);
305 if let Some(annotations) = self.annotations.as_deref() {
306 buf.put_u16(
307 u16::try_from(annotations.len())
308 .ok()
309 .context(errors::TooManyHeaders)?,
310 );
311 for (name, value) in annotations {
312 buf.reserve(4);
313 name.encode(buf)?;
314 value.encode(buf)?;
315 }
316 } else {
317 buf.put_u16(0);
318 }
319 buf.reserve(3 * 8 + 1 + 1 + 4 + 16 + 4 + 16 + 16 + 4);
320 buf.put_u64(self.allowed_capabilities.bits());
321 buf.put_u64(self.compilation_flags.bits());
322 buf.put_u64(self.implicit_limit.unwrap_or(0));
323 if buf.proto().is_multilingual() {
324 buf.put_u8(self.input_language as u8);
325 }
326 buf.put_u8(self.output_format as u8);
327 buf.put_u8(self.expected_cardinality as u8);
328 self.command_text.encode(buf)?;
329 self.state.typedesc_id.encode(buf)?;
330 self.state.data.encode(buf)?;
331 self.input_typedesc_id.encode(buf)?;
332 self.output_typedesc_id.encode(buf)?;
333 self.arguments.encode(buf)?;
334 Ok(())
335 }
336}
337
338impl Decode for Execute1 {
339 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
340 if buf.proto().is_multilingual() {
341 let message = new_protocol::Execute::new(buf)?;
342
343 let annotations = if !message.annotations().is_empty() {
345 let mut ann_map = HashMap::new();
346 for ann in message.annotations() {
347 ann_map.insert(
348 ann.name().to_string_lossy().to_string(),
349 ann.value().to_string_lossy().to_string(),
350 );
351 }
352 Some(Arc::new(ann_map))
353 } else {
354 None
355 };
356
357 let state = State {
359 typedesc_id: message.state_typedesc_id(),
360 data: message.state_data().into_slice().to_owned().into(),
361 };
362
363 let decoded = Execute1 {
364 annotations,
365 allowed_capabilities: Capabilities::from_bits_retain(
366 message.allowed_capabilities(),
367 ),
368 compilation_flags: decode_compilation_flags(message.compilation_flags())?,
369 implicit_limit: match message.implicit_limit() {
370 0 => None,
371 val => Some(val),
372 },
373 output_format: message.output_format(),
374 expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
375 command_text: message.command_text().to_string_lossy().to_string(),
376 state,
377 input_typedesc_id: message.input_typedesc_id(),
378 output_typedesc_id: message.output_typedesc_id(),
379 arguments: message.arguments().into_slice().to_owned().into(),
380 input_language: message.input_language(),
381 };
382 buf.advance(message.as_ref().len());
383 Ok(decoded)
384 } else {
385 let message = new_protocol::Execute2::new(buf)?;
386
387 let annotations = if !message.annotations().is_empty() {
389 let mut ann_map = HashMap::new();
390 for ann in message.annotations() {
391 ann_map.insert(
392 ann.name().to_string_lossy().to_string(),
393 ann.value().to_string_lossy().to_string(),
394 );
395 }
396 Some(Arc::new(ann_map))
397 } else {
398 None
399 };
400
401 let state = State {
403 typedesc_id: message.state_typedesc_id(),
404 data: message.state_data().into_slice().to_owned().into(),
405 };
406
407 let decoded = Execute1 {
408 annotations,
409 allowed_capabilities: decode_capabilities(message.allowed_capabilities())?,
410 compilation_flags: decode_compilation_flags(message.compilation_flags())?,
411 implicit_limit: match message.implicit_limit() {
412 0 => None,
413 val => Some(val),
414 },
415 output_format: message.output_format(),
416 expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
417 command_text: message.command_text().to_string_lossy().to_string(),
418 state,
419 input_typedesc_id: message.input_typedesc_id(),
420 output_typedesc_id: message.output_typedesc_id(),
421 arguments: message.arguments().into_slice().to_owned().into(),
422 input_language: InputLanguage::EdgeQL,
423 };
424 buf.advance(message.as_ref().len());
425 Ok(decoded)
426 }
427 }
428}
429
430impl Encode for Dump2 {
431 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
432 buf.reserve(10);
433 buf.put_u16(
434 u16::try_from(self.headers.len())
435 .ok()
436 .context(errors::TooManyHeaders)?,
437 );
438 for (&name, value) in &self.headers {
439 buf.reserve(2);
440 buf.put_u16(name);
441 value.encode(buf)?;
442 }
443 Ok(())
444 }
445}
446
447impl Decode for Dump2 {
448 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
449 let message = new_protocol::Dump2::new(buf)?;
450 let mut headers = HashMap::new();
451 for header in message.headers() {
452 headers.insert(header.code(), header.value().into_slice().to_owned().into());
453 }
454
455 let decoded = Dump2 { headers };
456 buf.advance(message.as_ref().len());
457 Ok(decoded)
458 }
459}
460
461impl Encode for Dump3 {
462 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
463 buf.reserve(2 + 8);
464 if let Some(annotations) = self.annotations.as_deref() {
465 buf.put_u16(
466 u16::try_from(annotations.len())
467 .ok()
468 .context(errors::TooManyHeaders)?,
469 );
470 for (name, value) in annotations {
471 buf.reserve(4);
472 name.encode(buf)?;
473 value.encode(buf)?;
474 }
475 } else {
476 buf.put_u16(0);
477 }
478 buf.put_u64(self.flags.bits());
479 Ok(())
480 }
481}
482
483impl Decode for Dump3 {
484 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
485 let message = new_protocol::Dump3::new(buf)?;
486 let mut annotations = HashMap::new();
487 for ann in message.annotations() {
488 annotations.insert(
489 ann.name().to_string_lossy().to_string(),
490 ann.value().to_string_lossy().to_string(),
491 );
492 }
493
494 let decoded = Dump3 {
495 annotations: Some(Arc::new(annotations)),
496 flags: decode_dump_flags(message.flags())?,
497 };
498 buf.advance(message.as_ref().len());
499 Ok(decoded)
500 }
501}
502
503impl Encode for Restore {
504 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
505 buf.reserve(4 + self.data.len());
506 buf.put_u16(
507 u16::try_from(self.headers.len())
508 .ok()
509 .context(errors::TooManyHeaders)?,
510 );
511 for (&name, value) in &self.headers {
512 buf.reserve(2);
513 buf.put_u16(name);
514 value.encode(buf)?;
515 }
516 buf.put_u16(self.jobs);
517 buf.extend(&self.data);
518 Ok(())
519 }
520}
521
522impl Decode for Restore {
523 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
524 let message = new_protocol::Restore::new(buf)?;
525 let mut headers = HashMap::new();
526 for header in message.headers() {
527 headers.insert(header.code(), header.value().into_slice().to_owned().into());
528 }
529
530 let decoded = Restore {
531 headers,
532 jobs: message.jobs(),
533 data: message.data().as_ref().to_owned().into(),
534 };
535 buf.advance(message.as_ref().len());
536 Ok(decoded)
537 }
538}
539
impl Encode for RestoreBlock {
    /// Writes the block data verbatim as the message body, with no length
    /// prefix of its own.
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
        buf.extend(&self.data);
        Ok(())
    }
}
546
547impl Decode for RestoreBlock {
548 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
549 let message = new_protocol::RestoreBlock::new(buf)?;
550 let decoded = RestoreBlock {
551 data: message.block_data().into_slice().to_owned().into(),
552 };
553 buf.advance(message.as_ref().len());
554 Ok(decoded)
555 }
556}
557
558impl Parse {
559 pub fn new(
560 opts: &CompilationOptions,
561 query: &str,
562 state: State,
563 annotations: Option<Arc<Annotations>>,
564 ) -> Parse {
565 Parse {
566 annotations,
567 allowed_capabilities: opts.allow_capabilities,
568 compilation_flags: opts.flags(),
569 implicit_limit: opts.implicit_limit,
570 output_format: opts.io_format,
571 expected_cardinality: opts.expected_cardinality,
572 command_text: query.into(),
573 state,
574 input_language: opts.input_language,
575 }
576 }
577}
578
579fn decode_capabilities(val: u64) -> Result<Capabilities, DecodeError> {
580 Capabilities::from_bits(val)
581 .ok_or_else(|| errors::InvalidCapabilities { capabilities: val }.build())
582}
583
584fn decode_compilation_flags(val: u64) -> Result<CompilationFlags, DecodeError> {
585 CompilationFlags::from_bits(val).ok_or_else(|| {
586 errors::InvalidCompilationFlags {
587 compilation_flags: val,
588 }
589 .build()
590 })
591}
592
593fn decode_dump_flags(val: u64) -> Result<DumpFlags, DecodeError> {
594 DumpFlags::from_bits(val).ok_or_else(|| errors::InvalidDumpFlags { dump_flags: val }.build())
595}
596
597impl Decode for Parse {
598 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
599 if buf.proto().is_multilingual() {
600 let message = new_protocol::Parse::new(buf)?;
601
602 let annotations = if !message.annotations().is_empty() {
604 let mut ann_map = HashMap::new();
605 for ann in message.annotations() {
606 ann_map.insert(
607 ann.name().to_string_lossy().to_string(),
608 ann.value().to_string_lossy().to_string(),
609 );
610 }
611 Some(Arc::new(ann_map))
612 } else {
613 None
614 };
615
616 let state = State {
618 typedesc_id: message.state_typedesc_id(),
619 data: message.state_data().into_slice().to_owned().into(),
620 };
621
622 let decoded = Parse {
623 annotations,
624 allowed_capabilities: decode_capabilities(message.allowed_capabilities())?,
625 compilation_flags: decode_compilation_flags(message.compilation_flags())?,
626 implicit_limit: match message.implicit_limit() {
627 0 => None,
628 val => Some(val),
629 },
630 output_format: message.output_format(),
631 expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
632 command_text: message.command_text().to_string_lossy().to_string(),
633 state,
634 input_language: message.input_language(),
635 };
636 buf.advance(message.as_ref().len());
637 Ok(decoded)
638 } else {
639 let message = new_protocol::Parse2::new(buf)?;
640
641 let annotations = if !message.annotations().is_empty() {
643 let mut ann_map = HashMap::new();
644 for ann in message.annotations() {
645 ann_map.insert(
646 ann.name().to_string_lossy().to_string(),
647 ann.value().to_string_lossy().to_string(),
648 );
649 }
650 Some(Arc::new(ann_map))
651 } else {
652 None
653 };
654
655 let state = State {
657 typedesc_id: message.state_typedesc_id(),
658 data: message.state_data().into_slice().to_owned().into(),
659 };
660
661 let decoded = Parse {
662 annotations,
663 allowed_capabilities: decode_capabilities(message.allowed_capabilities())?,
664 compilation_flags: decode_compilation_flags(message.compilation_flags())?,
665 implicit_limit: match message.implicit_limit() {
666 0 => None,
667 val => Some(val),
668 },
669 output_format: message.output_format(),
670 expected_cardinality: TryFrom::try_from(message.expected_cardinality())?,
671 command_text: message.command_text().to_string_lossy().to_string(),
672 state,
673 input_language: InputLanguage::EdgeQL, };
675 buf.advance(message.as_ref().len());
676 Ok(decoded)
677 }
678 }
679}
680
681impl Encode for Parse {
682 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
683 buf.reserve(52);
684 if let Some(annotations) = self.annotations.as_deref() {
685 buf.put_u16(
686 u16::try_from(annotations.len())
687 .ok()
688 .context(errors::TooManyHeaders)?,
689 );
690 for (name, value) in annotations {
691 buf.reserve(8);
692 name.encode(buf)?;
693 value.encode(buf)?;
694 }
695 } else {
696 buf.put_u16(0);
697 }
698 buf.reserve(50);
699 buf.put_u64(self.allowed_capabilities.bits());
700 buf.put_u64(self.compilation_flags.bits());
701 buf.put_u64(self.implicit_limit.unwrap_or(0));
702 if buf.proto().is_multilingual() {
703 buf.put_u8(self.input_language as u8);
704 }
705 buf.put_u8(self.output_format as u8);
706 buf.put_u8(self.expected_cardinality as u8);
707 self.command_text.encode(buf)?;
708 self.state.typedesc_id.encode(buf)?;
709 self.state.data.encode(buf)?;
710 Ok(())
711 }
712}