postgres_proto_rs/messages/
frontend.rs

1use bytes::{Buf, BufMut, BytesMut};
2use std::{collections::HashMap, io::Cursor, mem};
3
4use crate::messages::{BytesMutReader, Error, Message};
5
//----------------------------------------------------------------
// Startup Messages
//
// Startup-phase packets carry no type byte: after the i32 length word comes an
// i32 code that tells the receiver which kind of startup packet it is.

/// Protocol version code: `3` in the high 16 bits, `0` in the low 16 bits.
pub const PROTOCOL_VERSION_CODE: i32 = 3 << 16;
/// `SSLRequest` code: `1234` in the high 16 bits, `5679` in the low 16 bits.
pub const SSL_REQUEST_CODE: i32 = (1234 << 16) | 5679;
/// `CancelRequest` code: `1234` in the high 16 bits, `5678` in the low 16 bits.
pub const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;
/// GSS encryption request code: `1234` in the high 16 bits, `5680` in the low 16 bits.
pub const GSS_ENC_REQ_CODE: i32 = (1234 << 16) | 5680;
13
14#[derive(Debug)]
15pub enum StartupMessageType {
16    StartupParameters(StartupParameters),
17    SSLRequest(SSLRequest),
18    CancelRequest(CancelRequest),
19    GssEncReq(GssEncReq),
20}
21
22impl StartupMessageType {
23    pub fn get_bytes(&self) -> &BytesMut {
24        match self {
25            StartupMessageType::StartupParameters(startup_params) => {
26                return startup_params.get_bytes();
27            }
28            StartupMessageType::SSLRequest(ssl_request) => {
29                return ssl_request.get_bytes();
30            }
31            StartupMessageType::CancelRequest(cancel_request) => {
32                return cancel_request.get_bytes();
33            }
34            StartupMessageType::GssEncReq(gss_enc_req) => {
35                return gss_enc_req.get_bytes();
36            }
37        }
38    }
39
40    pub fn new_from_bytes(code: i32, message_bytes: BytesMut) -> Result<Self, Error> {
41        match code {
42            PROTOCOL_VERSION_CODE => {
43                let startup_message = StartupParameters::new_from_bytes(message_bytes)?;
44                Ok(StartupMessageType::StartupParameters(startup_message))
45            }
46            SSL_REQUEST_CODE => {
47                let ssl_request = SSLRequest::new_from_bytes(message_bytes)?;
48                Ok(StartupMessageType::SSLRequest(ssl_request))
49            }
50            CANCEL_REQUEST_CODE => {
51                let cancel_request = CancelRequest::new_from_bytes(message_bytes)?;
52                Ok(StartupMessageType::CancelRequest(cancel_request))
53            }
54            GSS_ENC_REQ_CODE => {
55                let gss_enc_req = GssEncReq::new_from_bytes(message_bytes)?;
56                Ok(StartupMessageType::GssEncReq(gss_enc_req))
57            }
58            _ => return Err(Error::InvalidProtocol),
59        }
60    }
61}
62
63pub trait StartupMessage: Message {}
64
65#[derive(Debug)]
66pub struct StartupParameters {
67    message_bytes: BytesMut,
68}
69
/// Decoded contents of a `StartupParameters` packet.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StartupParametersParams {
    /// Parameter name to value; order on the wire is not preserved.
    pub parameters: HashMap<String, String>,
}
73
74impl StartupParameters {
75    pub fn new(parameters: HashMap<String, String>) -> Result<Self, Error> {
76        if !parameters.contains_key("user") {
77            return Err(Error::ParseError("Missing user parameter".to_string()));
78        };
79
80        let mut data_bytes = BytesMut::new();
81
82        data_bytes.put_i32(PROTOCOL_VERSION_CODE);
83
84        for (key, value) in &parameters {
85            data_bytes.put(key.as_bytes());
86            data_bytes.put_u8(b'\0');
87            data_bytes.put(value.as_bytes());
88            data_bytes.put_u8(b'\0');
89        }
90        data_bytes.put_u8(b'\0');
91
92        let mut message_bytes = BytesMut::with_capacity(data_bytes.len() + mem::size_of::<i32>());
93
94        message_bytes.put_i32(data_bytes.len() as i32 + mem::size_of::<i32>() as i32);
95        message_bytes.put(data_bytes);
96
97        Ok(Self { message_bytes })
98    }
99
100    pub fn get_params(&self) -> StartupParametersParams {
101        let mut cursor = Cursor::new(&self.message_bytes);
102
103        let _len = cursor.get_i32();
104        let _protocol_version = cursor.get_i32();
105
106        let mut parameters = HashMap::new();
107
108        loop {
109            let key = match cursor.read_string() {
110                Ok(s) => {
111                    if s.len() == 0 {
112                        break;
113                    } else {
114                        s
115                    }
116                }
117                Err(_) => break, // TODO: handle error
118            };
119
120            let value = cursor.read_string().unwrap();
121
122            parameters.insert(key, value);
123        }
124
125        StartupParametersParams { parameters }
126    }
127}
128
129impl StartupMessage for StartupParameters {}
130
131impl Message for StartupParameters {
132    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
133        Ok(Self {
134            message_bytes: message_bytes,
135        })
136    }
137
138    fn get_bytes(&self) -> &BytesMut {
139        return &self.message_bytes;
140    }
141}
142
143#[derive(Debug)]
144pub struct SSLRequest {
145    message_bytes: BytesMut,
146}
147
148impl SSLRequest {
149    pub fn new() -> Self {
150        let mut message_bytes =
151            BytesMut::with_capacity(mem::size_of::<i32>() + mem::size_of::<i32>());
152
153        message_bytes.put_i32(8);
154        message_bytes.put_i32(SSL_REQUEST_CODE);
155
156        Self { message_bytes }
157    }
158}
159
160impl StartupMessage for SSLRequest {}
161
162impl Message for SSLRequest {
163    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
164        Ok(Self { message_bytes })
165    }
166
167    fn get_bytes(&self) -> &BytesMut {
168        return &self.message_bytes;
169    }
170}
171
172#[derive(Debug)]
173pub struct CancelRequest {
174    message_bytes: BytesMut,
175}
176
/// Decoded fields of a `CancelRequest` packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CancelRequestParams {
    /// Process id carried by the request.
    pub process_id: i32,
    /// Secret key carried by the request.
    pub secret_key: i32,
}
181
182impl CancelRequest {
183    pub fn new(process_id: i32, secret_key: i32) -> Self {
184        let mut message_bytes = BytesMut::with_capacity(mem::size_of::<i32>() * 4);
185
186        message_bytes.put_i32(16);
187        message_bytes.put_i32(CANCEL_REQUEST_CODE);
188        message_bytes.put_i32(process_id);
189        message_bytes.put_i32(secret_key);
190
191        Self { message_bytes }
192    }
193
194    pub fn get_params(&self) -> CancelRequestParams {
195        let mut cursor = Cursor::new(&self.message_bytes);
196
197        let _len = cursor.get_i32();
198        let _code = cursor.get_i32();
199        let process_id = cursor.get_i32();
200        let secret_key = cursor.get_i32();
201
202        CancelRequestParams {
203            process_id,
204            secret_key,
205        }
206    }
207}
208
209impl StartupMessage for CancelRequest {}
210
211impl Message for CancelRequest {
212    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
213        if message_bytes.len() != mem::size_of::<i32>() * 4 {
214            return Err(Error::InvalidBytes);
215        }
216
217        Ok(Self {
218            message_bytes: message_bytes,
219        })
220    }
221
222    fn get_bytes(&self) -> &BytesMut {
223        return &self.message_bytes;
224    }
225}
226
227#[derive(Debug)]
228pub struct GssEncReq {
229    message_bytes: BytesMut,
230}
231
232impl GssEncReq {
233    pub fn new() -> Self {
234        let mut message_bytes =
235            BytesMut::with_capacity(mem::size_of::<i32>() + mem::size_of::<i32>());
236
237        message_bytes.put_i32(8);
238        message_bytes.put_i32(GSS_ENC_REQ_CODE);
239
240        Self { message_bytes }
241    }
242}
243
244impl StartupMessage for GssEncReq {}
245
246impl Message for GssEncReq {
247    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
248        Ok(Self { message_bytes })
249    }
250
251    fn get_bytes(&self) -> &BytesMut {
252        return &self.message_bytes;
253    }
254}
255
256//----------------------------------------------------------------
257// Frontend Messages
258#[derive(Debug)]
259pub enum FrontendMessageType {
260    Query(Query),
261    Bind(Bind),
262    Close(Close),
263    Describe(Describe),
264    Execute(Execute),
265    FunctionCall(FunctionCall),
266    CopyFail(CopyFail),
267    CopyData(CopyData),
268    CopyDone(CopyDone),
269    Flush(Flush),
270    Parse(Parse),
271    Sync(Sync),
272    Terminate(Terminate),
273}
274
275impl FrontendMessageType {
276    pub fn get_bytes(&self) -> &BytesMut {
277        match self {
278            Self::Query(query) => query.get_bytes(),
279            Self::Bind(bind) => bind.get_bytes(),
280            Self::Close(close) => close.get_bytes(),
281            Self::Describe(describe) => describe.get_bytes(),
282            Self::Execute(execute) => execute.get_bytes(),
283            Self::FunctionCall(function_call) => function_call.get_bytes(),
284            Self::CopyFail(copy_fail) => copy_fail.get_bytes(),
285            Self::CopyData(copy_data) => copy_data.get_bytes(),
286            Self::CopyDone(copy_done) => copy_done.get_bytes(),
287            Self::Flush(flush) => flush.get_bytes(),
288            Self::Parse(parse) => parse.get_bytes(),
289            Self::Sync(sync) => sync.get_bytes(),
290            Self::Terminate(terminate) => terminate.get_bytes(),
291        }
292    }
293
294    pub fn new_from_bytes(msg_type: u8, message_bytes: BytesMut) -> Result<Self, Error> {
295        match msg_type as char {
296            'Q' => {
297                let query = Query::new_from_bytes(message_bytes)?;
298                Ok(Self::Query(query))
299            }
300            'B' => {
301                let bind = Bind::new_from_bytes(message_bytes)?;
302                Ok(Self::Bind(bind))
303            }
304            'C' => {
305                let close = Close::new_from_bytes(message_bytes)?;
306                Ok(Self::Close(close))
307            }
308            'D' => {
309                let describe = Describe::new_from_bytes(message_bytes)?;
310                Ok(Self::Describe(describe))
311            }
312            'E' => {
313                let execute = Execute::new_from_bytes(message_bytes)?;
314                Ok(Self::Execute(execute))
315            }
316            'F' => {
317                let function_call = FunctionCall::new_from_bytes(message_bytes)?;
318                Ok(Self::FunctionCall(function_call))
319            }
320            'f' => {
321                let copy_fail = CopyFail::new_from_bytes(message_bytes)?;
322                Ok(Self::CopyFail(copy_fail))
323            }
324            'd' => {
325                let copy_data = CopyData::new_from_bytes(message_bytes)?;
326                Ok(Self::CopyData(copy_data))
327            }
328            'c' => {
329                let copy_done = CopyDone::new_from_bytes(message_bytes)?;
330                Ok(Self::CopyDone(copy_done))
331            }
332            'H' => {
333                let flush = Flush::new_from_bytes(message_bytes)?;
334                Ok(Self::Flush(flush))
335            }
336            'P' => {
337                let parse = Parse::new_from_bytes(message_bytes)?;
338                Ok(Self::Parse(parse))
339            }
340            'S' => {
341                let sync = Sync::new_from_bytes(message_bytes)?;
342                Ok(Self::Sync(sync))
343            }
344            'X' => {
345                let terminate = Terminate::new_from_bytes(message_bytes)?;
346                Ok(Self::Terminate(terminate))
347            }
348            _ => Err(Error::InvalidProtocol),
349        }
350    }
351}
352
353pub trait FrontendMessage: Message {}
354
355#[derive(Debug)]
356pub struct Terminate {
357    message_bytes: BytesMut,
358}
359
360impl Terminate {
361    pub fn new(message_bytes: BytesMut) -> Self {
362        Self { message_bytes }
363    }
364}
365
366impl FrontendMessage for Terminate {}
367
368impl Message for Terminate {
369    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
370        Ok(Self { message_bytes })
371    }
372
373    fn get_bytes(&self) -> &BytesMut {
374        &self.message_bytes
375    }
376}
377
378#[derive(Debug)]
379pub struct Sync {
380    message_bytes: BytesMut,
381}
382
383impl Sync {
384    pub fn new(message_bytes: BytesMut) -> Self {
385        Self { message_bytes }
386    }
387}
388
389impl FrontendMessage for Sync {}
390
391impl Message for Sync {
392    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
393        Ok(Self { message_bytes })
394    }
395
396    fn get_bytes(&self) -> &BytesMut {
397        &self.message_bytes
398    }
399}
400
401#[derive(Debug)]
402pub struct Parse {
403    message_bytes: BytesMut,
404}
405
406impl Parse {
407    pub fn new(message_bytes: BytesMut) -> Self {
408        Self { message_bytes }
409    }
410}
411
412impl FrontendMessage for Parse {}
413
414impl Message for Parse {
415    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
416        Ok(Self { message_bytes })
417    }
418
419    fn get_bytes(&self) -> &BytesMut {
420        &self.message_bytes
421    }
422}
423
424#[derive(Debug)]
425pub struct Flush {
426    message_bytes: BytesMut,
427}
428
429impl Flush {
430    pub fn new(message_bytes: BytesMut) -> Self {
431        Self { message_bytes }
432    }
433}
434
435impl FrontendMessage for Flush {}
436
437impl Message for Flush {
438    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
439        Ok(Self { message_bytes })
440    }
441
442    fn get_bytes(&self) -> &BytesMut {
443        &self.message_bytes
444    }
445}
446
447#[derive(Debug)]
448pub struct CopyDone {
449    message_bytes: BytesMut,
450}
451
452impl CopyDone {
453    pub fn new(message_bytes: BytesMut) -> Self {
454        Self { message_bytes }
455    }
456}
457
458impl FrontendMessage for CopyDone {}
459
460impl Message for CopyDone {
461    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
462        Ok(Self { message_bytes })
463    }
464
465    fn get_bytes(&self) -> &BytesMut {
466        &self.message_bytes
467    }
468}
469
470#[derive(Debug)]
471pub struct CopyData {
472    message_bytes: BytesMut,
473}
474
475impl CopyData {
476    pub fn new(message_bytes: BytesMut) -> Self {
477        Self { message_bytes }
478    }
479}
480
481impl FrontendMessage for CopyData {}
482
483impl Message for CopyData {
484    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
485        Ok(Self { message_bytes })
486    }
487
488    fn get_bytes(&self) -> &BytesMut {
489        &self.message_bytes
490    }
491}
492
493#[derive(Debug)]
494pub struct CopyFail {
495    message_bytes: BytesMut,
496}
497
498impl CopyFail {
499    pub fn new(message_bytes: BytesMut) -> Self {
500        Self { message_bytes }
501    }
502}
503
504impl FrontendMessage for CopyFail {}
505
506impl Message for CopyFail {
507    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
508        Ok(Self { message_bytes })
509    }
510
511    fn get_bytes(&self) -> &BytesMut {
512        &self.message_bytes
513    }
514}
515
516#[derive(Debug)]
517pub struct FunctionCall {
518    message_bytes: BytesMut,
519}
520
521impl FunctionCall {
522    pub fn new(message_bytes: BytesMut) -> Self {
523        Self { message_bytes }
524    }
525}
526
527impl FrontendMessage for FunctionCall {}
528
529impl Message for FunctionCall {
530    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
531        Ok(Self { message_bytes })
532    }
533
534    fn get_bytes(&self) -> &BytesMut {
535        &self.message_bytes
536    }
537}
538
539#[derive(Debug)]
540pub struct Execute {
541    message_bytes: BytesMut,
542}
543
544impl Execute {
545    pub fn new(message_bytes: BytesMut) -> Self {
546        Self { message_bytes }
547    }
548}
549
550impl FrontendMessage for Execute {}
551
552impl Message for Execute {
553    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
554        Ok(Self { message_bytes })
555    }
556
557    fn get_bytes(&self) -> &BytesMut {
558        &self.message_bytes
559    }
560}
561
562#[derive(Debug)]
563pub struct Describe {
564    message_bytes: BytesMut,
565}
566
567impl Describe {
568    pub fn new(message_bytes: BytesMut) -> Self {
569        Self { message_bytes }
570    }
571}
572
573impl FrontendMessage for Describe {}
574
575impl Message for Describe {
576    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
577        Ok(Self { message_bytes })
578    }
579
580    fn get_bytes(&self) -> &BytesMut {
581        &self.message_bytes
582    }
583}
584
585#[derive(Debug)]
586pub struct Close {
587    message_bytes: BytesMut,
588}
589
590impl Close {
591    pub fn new(message_bytes: BytesMut) -> Self {
592        Self { message_bytes }
593    }
594}
595
596impl FrontendMessage for Close {}
597
598impl Message for Close {
599    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
600        Ok(Self { message_bytes })
601    }
602
603    fn get_bytes(&self) -> &BytesMut {
604        &self.message_bytes
605    }
606}
607
608#[derive(Debug)]
609pub struct Bind {
610    message_bytes: BytesMut,
611}
612
613impl Bind {
614    pub fn new(message_bytes: BytesMut) -> Self {
615        Self { message_bytes }
616    }
617}
618
619impl FrontendMessage for Bind {}
620
621impl Message for Bind {
622    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
623        Ok(Self { message_bytes })
624    }
625
626    fn get_bytes(&self) -> &BytesMut {
627        &self.message_bytes
628    }
629}
630
631#[derive(Debug)]
632pub struct Query {
633    message_bytes: BytesMut,
634}
635
636pub struct QueryParams {
637    pub query_string: String,
638}
639
640impl Query {
641    pub fn new(query_string: String) -> Self {
642        let mut message_bytes = BytesMut::with_capacity(
643            mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
644        );
645
646        let msg_len = (query_string.len() + 1 + mem::size_of::<i32>()) as i32;
647
648        message_bytes.put_u8(b'Q');
649        message_bytes.put_i32(msg_len);
650        message_bytes.put(&query_string.as_bytes()[..]);
651        message_bytes.put_u8(0);
652
653        Self { message_bytes }
654    }
655
656    pub fn get_params(&self) -> QueryParams {
657        let mut cursor = Cursor::new(&self.message_bytes);
658
659        let _code = cursor.get_u8();
660        let _len = cursor.get_i32();
661
662        let query_string = cursor.read_string().unwrap();
663
664        QueryParams {
665            query_string: query_string,
666        }
667    }
668}
669
670impl FrontendMessage for Query {}
671
672impl Message for Query {
673    fn new_from_bytes(message_bytes: BytesMut) -> Result<Self, Error> {
674        Ok(Self { message_bytes })
675    }
676
677    fn get_bytes(&self) -> &BytesMut {
678        return &self.message_bytes;
679    }
680}