ampiato/replication/
pgoutput.rs

1use core::panic;
2use std::num::ParseFloatError;
3use std::num::ParseIntError;
4use std::str::from_utf8;
5use std::str::Utf8Error;
6
7use crate::Error;
8use binrw::prelude::*;
9use binrw::NullString;
10use byteorder::BigEndian;
11use byteorder::ByteOrder;
12use chrono::DateTime;
13use chrono::NaiveDateTime;
14use chrono::Utc;
15
16use crate::core::defs::Time;
17
// Offset of the PostgreSQL epoch (2000-01-01 00:00:00 UTC) from the UNIX epoch.
// NOTE(review): 946_684_800 s * 1_000 = 946_684_800_000, so this value is expressed in
// milliseconds (microseconds would be 946_684_800_000_000) — confirm the unit `Time` expects.
const POSTGRES_EPOCH: i64 = 946684800000;
19
/// Failure raised while turning a raw column payload into a typed Rust value.
///
/// Each variant wraps the underlying error of the conversion step that failed.
#[derive(Debug)]
pub enum ParseError {
    /// The payload bytes were not valid UTF-8.
    Utf8Error(Utf8Error),
    /// The text could not be parsed as an integer.
    ParseIntError(ParseIntError),
    /// The text could not be parsed as a floating-point number.
    ParseFloatError(ParseFloatError),
    /// The text did not match the expected date-time format.
    ChronoParseError(chrono::ParseError),
}
27
28impl From<Utf8Error> for ParseError {
29    fn from(value: Utf8Error) -> Self {
30        ParseError::Utf8Error(value)
31    }
32}
33
34impl From<ParseIntError> for ParseError {
35    fn from(value: ParseIntError) -> Self {
36        ParseError::ParseIntError(value)
37    }
38}
39
40impl From<ParseFloatError> for ParseError {
41    fn from(value: ParseFloatError) -> Self {
42        ParseError::ParseFloatError(value)
43    }
44}
45
46impl From<chrono::ParseError> for ParseError {
47    fn from(value: chrono::ParseError) -> Self {
48        ParseError::ChronoParseError(value)
49    }
50}
51
52impl From<ParseError> for Error {
53    fn from(value: ParseError) -> Self {
54        Error::ReplicationError(format!("{:?}", value))
55    }
56}
57
/// Raw byte string whose length is supplied by the caller rather than read from the stream.
///
/// The `size` import argument is the number of bytes to read; callers pass it with
/// `#[br(args(..))]`. The bytes are stored as-is, with no UTF-8 validation.
#[derive(BinRead, Debug)]
#[br(import(size: u32))]
pub struct StringWithSize {
    /// Exactly `size` bytes taken from the stream.
    #[br(count = size)]
    pub string: Vec<u8>,
}
64
65fn parse_timestamp_tz(ts: u64) -> DateTime<Utc> {
66    let timestamp = chrono::Duration::microseconds(ts as i64);
67    let postgres_epoch = chrono::DateTime::<chrono::Utc>::from_timestamp(946_684_800, 0).unwrap();
68    postgres_epoch + timestamp
69}
70
71fn parse_string(s: NullString) -> String {
72    s.to_string()
73}
74
/// Body of a `Begin` message (tag `'B'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct MessageBegin {
    /// 64-bit value read first; named as the final LSN of the transaction.
    pub final_lsn: u64,
    /// Read as a `u64` and converted by `parse_timestamp_tz`
    /// (microseconds relative to 2000-01-01 UTC).
    #[br(map = parse_timestamp_tz)]
    pub commit_timestamp: DateTime<Utc>,
    /// 32-bit transaction id.
    pub transaction_id: u32,
}
82
/// Body of a logical decoding `Message` (tag `'M'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct Message {
    /// NOTE(review): binrw's reader for `Option<T>` is understood to always read a `T` and
    /// yield `Some`; if so, these 4 bytes are consumed unconditionally. pgoutput is documented
    /// to send the transaction id only for streamed transactions — confirm.
    pub transaction_id: Option<u32>,
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit LSN field.
    pub lsn: u64,
    /// NUL-terminated prefix, converted to `String` by `parse_string`.
    #[br(map = parse_string)]
    pub prefix: String,
    /// Number of bytes in `content`.
    pub length: u32,
    /// Exactly `length` raw bytes.
    #[br(count = length)]
    pub content: Vec<u8>,
}
94
/// Body of a `Commit` message (tag `'C'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct MessageCommit {
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit LSN field read after the flags.
    pub lsn: u64,
    /// 64-bit end-LSN field.
    pub end_lsn: u64,
    /// Read as a `u64` and converted by `parse_timestamp_tz`
    /// (microseconds relative to 2000-01-01 UTC).
    #[br(map = parse_timestamp_tz)]
    pub commit_timestamp: DateTime<Utc>,
}
103
/// Body of an `Origin` message (tag `'O'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct MessageOrigin {
    /// 64-bit LSN field.
    pub lsn: u64,
    /// 32-bit value read between the LSN and the name.
    /// NOTE(review): the pgoutput format documents only an Int64 LSN followed by a
    /// NUL-terminated name; if so, this field swallows the first 4 bytes of the name — confirm.
    pub size: u32,
    /// NUL-terminated origin name, converted to `String` by `parse_string`.
    #[br(map = parse_string)]
    pub name: String,
}
111
/// Body of a `Relation` message (tag `'R'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the two leading `u16` fields consume 4 bytes in total. pgoutput documents
/// the relation OID as a single Int32 (with an optional Int32 transaction id before it for
/// streamed transactions only); if so, `transaction_id` holds the upper and `relation_oid`
/// the lower half of the OID — confirm against callers.
#[derive(BinRead, Debug)]
pub struct MessageRelation {
    /// First 16 bits of the body.
    pub transaction_id: u16,
    /// Next 16 bits of the body.
    pub relation_oid: u16,
    // No length prefix is read here: the namespace is parsed as a NUL-terminated string.
    // namespace_size: u32,
    /// NUL-terminated namespace, converted to `String` by `parse_string`.
    #[br(map = parse_string)]
    pub namespace: String,
    // No length prefix is read here: the relation name is parsed as a NUL-terminated string.
    // relation_name_size: u32,
    /// NUL-terminated relation name, converted to `String` by `parse_string`.
    #[br(map = parse_string)]
    pub relation_name: String,
    /// Single byte holding the replica identity setting.
    pub replica_identity_setting: u8,
    /// Number of entries in `columns`.
    pub number_of_columns: u16,
    /// Exactly `number_of_columns` column descriptors.
    #[br(count = number_of_columns)]
    pub columns: Vec<Column>,
}
127
/// Body of a `Type` message (tag `'Y'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): both names are read here as length-prefixed byte strings; the pgoutput
/// format documents them as NUL-terminated strings — confirm which layout is intended.
#[derive(BinRead, Debug)]
pub struct MessageType {
    /// NOTE(review): binrw's `Option<T>` reader is understood to always read a `T` and
    /// yield `Some`, so these 4 bytes are likely consumed unconditionally — confirm.
    pub transaction_id: Option<u32>,
    /// 32-bit OID of the data type.
    pub type_oid: u32,
    /// Byte length of `namespace`.
    pub namespace_size: u32,
    /// `namespace_size` raw bytes.
    #[br(args(namespace_size))]
    pub namespace: StringWithSize,
    /// Byte length of `name`.
    pub name_size: u32,
    /// `name_size` raw bytes.
    #[br(args(name_size))]
    pub name: StringWithSize,
}
139
/// Body of an `Insert` message (tag `'I'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the two `u16` fields together consume 4 bytes; pgoutput documents a single
/// Int32 relation OID at this position for non-streamed transactions — confirm.
#[derive(BinRead, Debug)]
pub struct MessageInsert {
    /// First 16 bits of the body.
    pub transaction_id: u16,
    /// Next 16 bits of the body.
    pub relation_oid: u16,
    /// New row contents; must be introduced by the byte `'N'`.
    #[br(magic = b'N')]
    pub new_tuple: TupleData,
}
147
/// Body of an `Update` message (tag `'U'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the two `u16` fields together consume 4 bytes; pgoutput documents a single
/// Int32 relation OID at this position for non-streamed transactions — confirm.
#[derive(BinRead, Debug)]
pub struct MessageUpdate {
    /// First 16 bits of the body.
    pub transaction_id: u16,
    /// Next 16 bits of the body.
    pub relation_oid: u16,
    /// Optional `'K'` (key) or `'O'` (old) tuple; `None` when neither can be parsed here.
    #[br(try)]
    pub key_or_old_tuple: Option<KeyOrOldTupleData>,
    /// New row contents; must be introduced by the byte `'N'`.
    #[br(magic = b'N')]
    pub new_tuple: TupleData,
}
157
/// Body of a `Delete` message (tag `'D'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the two `u16` fields together consume 4 bytes; pgoutput documents a single
/// Int32 relation OID at this position for non-streamed transactions — confirm.
#[derive(BinRead, Debug)]
pub struct MessageDelete {
    /// First 16 bits of the body.
    pub transaction_id: u16,
    /// Next 16 bits of the body.
    pub relation_oid: u16,
    /// `'K'` (key) or `'O'` (old) tuple of the deleted row; `None` when neither parses.
    #[br(try)]
    pub key_or_old_tuple: Option<KeyOrOldTupleData>,
}
165
/// Body of a `Truncate` message (tag `'T'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct MessageTruncate {
    /// NOTE(review): binrw's `Option<T>` reader is understood to always read a `T` and
    /// yield `Some`, so these 4 bytes are likely consumed unconditionally — confirm.
    pub transaction_id: Option<u32>,
    /// Number of entries in `relation_oids`.
    pub number_of_relations: u32,
    /// Single byte of truncate option bits.
    pub option_bits: u8,
    /// Exactly `number_of_relations` 32-bit relation OIDs.
    #[br(count = number_of_relations)]
    pub relation_oids: Vec<u32>,
}
174
/// Body of a `StreamStart` message (tag `'S'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct MessageStreamStart {
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Single byte; presumably non-zero for the first segment of the stream — confirm.
    pub is_first_segment: u8,
}
180
/// Body of a `StreamCommit` message (tag `'c'` in [`LogicalReplicationMessage`]).
#[derive(BinRead, Debug)]
pub struct MessageStreamCommit {
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit LSN field.
    pub lsn: u64,
    /// 64-bit end-LSN field.
    pub end_lsn: u64,
    /// Raw 64-bit timestamp; unlike `MessageCommit`, it is not converted to a `DateTime`.
    pub commit_timestamp: u64,
}
189
/// Body of a `StreamAbort` message (tag `'A'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): `lsn` and `abort_timestamp` are always read here; confirm the negotiated
/// protocol version actually sends them.
#[derive(BinRead, Debug)]
pub struct MessageStreamAbort {
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// 32-bit sub-transaction id.
    pub subtransaction_id: u32,
    /// 64-bit LSN field.
    pub lsn: u64,
    /// Raw 64-bit timestamp, left unconverted.
    pub abort_timestamp: u64,
}
197
/// Body of a `BeginPrepare` message (tag `'b'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the GID is read as a `u32` length followed by that many bytes; the pgoutput
/// format documents it as a NUL-terminated string — confirm which layout is intended.
#[derive(BinRead, Debug)]
pub struct MessageBeginPrepare {
    /// 64-bit prepare-LSN field.
    pub prepare_lsn: u64,
    /// 64-bit end-LSN field.
    pub end_lsn: u64,
    /// Raw 64-bit timestamp, left unconverted.
    pub prepare_timestamp: u64,
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Byte length of `gid`.
    pub gid_size: u32,
    /// `gid_size` raw bytes.
    #[br(args(gid_size))]
    pub gid: StringWithSize,
}
208
/// Body of a `Prepare` message (tag `'P'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the GID is read as a `u32` length followed by that many bytes; the pgoutput
/// format documents it as a NUL-terminated string — confirm which layout is intended.
#[derive(BinRead, Debug)]
pub struct MessagePrepare {
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit prepare-LSN field.
    pub prepare_lsn: u64,
    /// 64-bit end-LSN field.
    pub end_lsn: u64,
    /// Raw 64-bit timestamp, left unconverted.
    pub prepare_timestamp: u64,
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Byte length of `gid`.
    pub gid_size: u32,
    /// `gid_size` raw bytes.
    #[br(args(gid_size))]
    pub gid: StringWithSize,
}
220
/// Body of a `CommitPrepared` message (tag `'K'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the GID is read as a `u32` length followed by that many bytes; the pgoutput
/// format documents it as a NUL-terminated string — confirm which layout is intended.
#[derive(BinRead, Debug)]
pub struct MessageCommitPrepared {
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit commit-LSN field.
    pub commit_lsn: u64,
    /// 64-bit end-LSN field.
    pub end_lsn: u64,
    /// Raw 64-bit timestamp, left unconverted.
    pub commit_timestamp: u64,
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Byte length of `gid`.
    pub gid_size: u32,
    /// `gid_size` raw bytes.
    #[br(args(gid_size))]
    pub gid: StringWithSize,
}
232
/// Body of a `RollbackPrepared` message (tag `'r'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the GID is read as a `u32` length followed by that many bytes; the pgoutput
/// format documents it as a NUL-terminated string — confirm which layout is intended.
#[derive(BinRead, Debug)]
pub struct MessageRollbackPrepared {
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit field named as the end LSN of the prepare.
    pub prepare_end_lsn: u64,
    /// 64-bit field named as the end LSN of the rollback.
    pub rollback_end_lsn: u64,
    /// Raw 64-bit prepare timestamp, left unconverted.
    pub prepare_timestamp: u64,
    /// Raw 64-bit rollback timestamp, left unconverted.
    pub rollback_timestamp: u64,
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Byte length of `gid`.
    pub gid_size: u32,
    /// `gid_size` raw bytes.
    #[br(args(gid_size))]
    pub gid: StringWithSize,
}
245
/// Body of a `StreamPrepare` message (tag `'p'` in [`LogicalReplicationMessage`]).
///
/// NOTE(review): the GID is read as a `u32` length followed by that many bytes; the pgoutput
/// format documents it as a NUL-terminated string — confirm which layout is intended.
#[derive(BinRead, Debug)]
pub struct MessageStreamPrepare {
    /// Single flags byte.
    pub flags: u8,
    /// 64-bit prepare-LSN field.
    pub prepare_lsn: u64,
    /// 64-bit end-LSN field.
    pub end_lsn: u64,
    /// Raw 64-bit timestamp, left unconverted.
    pub prepare_timestamp: u64,
    /// 32-bit transaction id.
    pub transaction_id: u32,
    /// Byte length of `gid`.
    pub gid_size: u32,
    /// `gid_size` raw bytes.
    #[br(args(gid_size))]
    pub gid: StringWithSize,
}
257
/// A single decoded logical replication message.
///
/// The first byte of the input selects the variant (the `magic` on each arm); the remaining
/// bytes are parsed as that variant's body. The enum is marked big-endian.
#[derive(BinRead, Debug)]
#[br(big)]
pub enum LogicalReplicationMessage {
    /// Tag `'B'`.
    #[br(magic = b'B')]
    Begin(MessageBegin),
    /// Tag `'M'`.
    #[br(magic = b'M')]
    Message(Message),
    /// Tag `'C'`.
    #[br(magic = b'C')]
    Commit(MessageCommit),
    /// Tag `'O'`.
    #[br(magic = b'O')]
    Origin(MessageOrigin),
    /// Tag `'R'`.
    #[br(magic = b'R')]
    Relation(MessageRelation),
    /// Tag `'Y'`.
    #[br(magic = b'Y')]
    Type(MessageType),
    /// Tag `'I'`.
    #[br(magic = b'I')]
    Insert(MessageInsert),
    /// Tag `'U'`.
    #[br(magic = b'U')]
    Update(MessageUpdate),
    /// Tag `'D'`.
    #[br(magic = b'D')]
    Delete(MessageDelete),
    /// Tag `'T'`.
    #[br(magic = b'T')]
    Truncate(MessageTruncate),
    /// Tag `'S'`.
    #[br(magic = b'S')]
    StreamStart(MessageStreamStart),
    /// Tag `'E'`; carries no body.
    #[br(magic = b'E')]
    StreamStop,
    /// Tag `'c'` (lower case, distinct from `Commit`).
    #[br(magic = b'c')]
    StreamCommit(MessageStreamCommit),
    /// Tag `'A'`.
    #[br(magic = b'A')]
    StreamAbort(MessageStreamAbort),
    /// Tag `'b'` (lower case, distinct from `Begin`).
    #[br(magic = b'b')]
    BeginPrepare(MessageBeginPrepare),
    /// Tag `'P'`.
    #[br(magic = b'P')]
    Prepare(MessagePrepare),
    /// Tag `'K'`.
    #[br(magic = b'K')]
    CommitPrepared(MessageCommitPrepared),
    /// Tag `'r'` (lower case, distinct from `Relation`).
    #[br(magic = b'r')]
    RollbackPrepared(MessageRollbackPrepared),
    /// Tag `'p'` (lower case, distinct from `Prepare`).
    #[br(magic = b'p')]
    StreamPrepare(MessageStreamPrepare),
}
300
/// Description of one column inside a [`MessageRelation`].
#[derive(BinRead, Debug)]
pub struct Column {
    /// Single flags byte.
    pub flags: u8,
    /// NUL-terminated column name, converted to `String` by `parse_string`.
    #[br(map = parse_string)]
    pub name: String,
    /// 32-bit OID of the column's data type.
    pub type_oid: u32,
    /// 32-bit type modifier, read unsigned.
    /// NOTE(review): PostgreSQL type modifiers are signed (-1 means "none"), which would
    /// surface here as `u32::MAX` — confirm callers account for that.
    pub type_modifier: u32,
}
309
/// A row image: a 16-bit column count followed by that many column values.
///
/// Marked big-endian, so it can also be parsed on its own.
#[derive(BinRead, Debug)]
#[br(big)]
pub struct TupleData {
    /// Number of entries in `columns`.
    pub number_of_columns: u16,
    /// Exactly `number_of_columns` values, in wire order.
    #[br(count = number_of_columns)]
    pub columns: Vec<ColumnValue>,
}
317
318#[derive(BinRead, Debug)]
319pub enum ColumnValue {
320    #[br(magic = b'n')]
321    Null {
322        length: u32,
323        #[br(count = length)]
324        data: Vec<u8>,
325    },
326    #[br(magic = b'u')]
327    UnchangedToast {
328        length: u32,
329        #[br(count = length)]
330        data: Vec<u8>,
331    },
332    #[br(magic = b't')]
333    Text {
334        length: u32,
335        #[br(count = length)]
336        data: Vec<u8>,
337    },
338    #[br(magic = b'b')]
339    Binary {
340        length: u32,
341        #[br(count = length)]
342        data: Vec<u8>,
343    },
344}
345
346
347impl ColumnValue {
348    pub fn as_bytes<'r>(&'r self) -> Result<&'r [u8], ParseError> {
349        match self {
350            ColumnValue::Text { data, .. } => Ok(data),
351            ColumnValue::Binary { data, .. } => Ok(data),
352            _ => panic!("Invalid column type"),
353        }
354    }
355
356    pub fn as_str<'r>(&'r self) -> Result<&'r str, ParseError> {
357        Ok(from_utf8(self.as_bytes()?)?)
358    }
359}
360
/// A value that can be converted to and from a 64-bit integer id.
///
/// Every implementor automatically gets a [`Decode`] implementation that decodes an `i64`
/// and passes it to [`EntityRef::from_entity_id`].
pub trait EntityRef {
    /// Returns the integer id of this reference.
    fn id(&self) -> i64;
    /// Builds a reference from an integer id.
    fn from_entity_id(id: i64) -> Self;
}
365
366impl<T: EntityRef> Decode for T {
367    fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
368        Ok(match value {
369            ColumnValue::Text { .. } => {
370                Self::from_entity_id(Decode::decode(value)?)
371            }
372            ColumnValue::Binary { .. } => {
373                Self::from_entity_id(Decode::decode(value)?)
374            }
375            _ => panic!("Invalid column type"),
376        })
377    }
378}
379
380
/// Conversion from a raw [`ColumnValue`] into a typed Rust value.
pub trait Decode: Sized {
    /// Decodes `value`, returning a [`ParseError`] when the payload cannot be converted.
    ///
    /// The implementations in this file panic when `value` is neither `Text` nor `Binary`.
    fn decode(value: &ColumnValue) -> Result<Self, ParseError>;
}
384
385impl Decode for Time {
386    fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
387        Ok(match value {
388            ColumnValue::Text { .. } => {
389                let s = value.as_str()?;
390                let dt = NaiveDateTime::parse_from_str(
391                    s,
392                    if s.contains('+') {
393                        // Contains a time-zone specifier
394                        // This is given for timestamptz for some reason
395                        // Postgres already guarantees this to always be UTC
396                        "%Y-%m-%d %H:%M:%S%.f%#z"
397                    } else {
398                        "%Y-%m-%d %H:%M:%S%.f"
399                    },
400                )?;
401
402                Time::from_naive_datetime(dt)
403            }
404            ColumnValue::Binary { .. } => {
405                let us: i64 = Decode::decode(value)?;
406                Time(POSTGRES_EPOCH + us / 1_000_000)
407            }
408            _ => panic!("Invalid column type"),
409        })
410    }
411}
412
413impl Decode for i64 {
414    fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
415        Ok(match value {
416            ColumnValue::Text { .. } => value.as_str()?.parse()?,
417            ColumnValue::Binary { data, .. } => BigEndian::read_int(&data, data.len()),
418            _ => panic!("Invalid column type"),
419        })
420    }
421}
422
423impl Decode for f32 {
424    fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
425        Ok(match value {
426            ColumnValue::Text { .. } => value.as_str()?.parse()?,
427            ColumnValue::Binary { data, .. } => BigEndian::read_f32(&data),
428            _ => panic!("Invalid column type"),
429        })
430    }
431}
432
433impl Decode for f64 {
434    fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
435        Ok(match value {
436            ColumnValue::Text { .. } => value.as_str()?.parse()?,
437            ColumnValue::Binary { data, .. } => BigEndian::read_f64(&data),
438            _ => panic!("Invalid column type"),
439        })
440    }
441}
442
/// The optional tuple that precedes the new tuple of an update, or forms the body of a
/// delete. The leading tag byte selects the variant.
#[derive(BinRead, Debug)]
#[br(big)]
pub enum KeyOrOldTupleData {
    /// Tag `'K'`: tuple tagged as a key.
    #[br(magic = b'K')]
    Key(TupleData),
    /// Tag `'O'`: tuple tagged as the old row.
    #[br(magic = b'O')]
    Old(TupleData),
}
451
452pub fn decode(msg: &[u8]) -> Result<LogicalReplicationMessage, binrw::Error> {
453    LogicalReplicationMessage::read(&mut binrw::io::Cursor::new(msg))
454}
455