1use core::panic;
2use std::num::ParseFloatError;
3use std::num::ParseIntError;
4use std::str::from_utf8;
5use std::str::Utf8Error;
6
7use crate::Error;
8use binrw::prelude::*;
9use binrw::NullString;
10use byteorder::BigEndian;
11use byteorder::ByteOrder;
12use chrono::DateTime;
13use chrono::NaiveDateTime;
14use chrono::Utc;
15
16use crate::core::defs::Time;
17
18const POSTGRES_EPOCH: i64 = 946684800000; #[derive(Debug)]
21pub enum ParseError {
22 Utf8Error(Utf8Error),
23 ParseIntError(ParseIntError),
24 ParseFloatError(ParseFloatError),
25 ChronoParseError(chrono::ParseError),
26}
27
28impl From<Utf8Error> for ParseError {
29 fn from(value: Utf8Error) -> Self {
30 ParseError::Utf8Error(value)
31 }
32}
33
34impl From<ParseIntError> for ParseError {
35 fn from(value: ParseIntError) -> Self {
36 ParseError::ParseIntError(value)
37 }
38}
39
40impl From<ParseFloatError> for ParseError {
41 fn from(value: ParseFloatError) -> Self {
42 ParseError::ParseFloatError(value)
43 }
44}
45
46impl From<chrono::ParseError> for ParseError {
47 fn from(value: chrono::ParseError) -> Self {
48 ParseError::ChronoParseError(value)
49 }
50}
51
52impl From<ParseError> for Error {
53 fn from(value: ParseError) -> Self {
54 Error::ReplicationError(format!("{:?}", value))
55 }
56}
57
/// A raw byte string whose length is supplied by the parent struct.
///
/// The length arrives as the binrw import argument `size`; exactly that many bytes are read.
/// No UTF-8 validation or NUL handling is performed.
#[derive(BinRead, Debug)]
#[br(import(size: u32))]
pub struct StringWithSize {
    /// The raw bytes, `size` of them.
    #[br(count = size)]
    pub string: Vec<u8>,
}
64
65fn parse_timestamp_tz(ts: u64) -> DateTime<Utc> {
66 let timestamp = chrono::Duration::microseconds(ts as i64);
67 let postgres_epoch = chrono::DateTime::<chrono::Utc>::from_timestamp(946_684_800, 0).unwrap();
68 postgres_epoch + timestamp
69}
70
71fn parse_string(s: NullString) -> String {
72 s.to_string()
73}
74
/// Payload of a `B` (Begin) message; the leading tag byte is consumed by
/// `LogicalReplicationMessage`.
#[derive(BinRead, Debug)]
pub struct MessageBegin {
    /// Final LSN of the transaction.
    pub final_lsn: u64,
    /// Commit time. Read as a raw 64-bit value and converted by `parse_timestamp_tz`
    /// (microseconds relative to 2000-01-01T00:00:00Z).
    #[br(map = parse_timestamp_tz)]
    pub commit_timestamp: DateTime<Utc>,
    /// Transaction id (xid).
    pub transaction_id: u32,
}
82
/// Payload of an `M` (logical decoding message) message.
#[derive(BinRead, Debug)]
pub struct Message {
    /// NOTE(review): binrw's `Option<T>` appears to always read a `T` (yielding `Some`), so
    /// these 4 bytes are consumed unconditionally. The protocol only sends this xid for
    /// streamed transactions — confirm against the protocol version / streaming mode in use.
    pub transaction_id: Option<u32>,
    /// Message flags byte.
    pub flags: u8,
    /// LSN of the message.
    pub lsn: u64,
    /// NUL-terminated prefix.
    #[br(map = parse_string)]
    pub prefix: String,
    /// Number of content bytes that follow.
    pub length: u32,
    /// Raw message content, `length` bytes.
    #[br(count = length)]
    pub content: Vec<u8>,
}
94
/// Payload of a `C` (Commit) message.
#[derive(BinRead, Debug)]
pub struct MessageCommit {
    /// Commit flags byte.
    pub flags: u8,
    /// LSN of the commit.
    pub lsn: u64,
    /// End LSN of the transaction.
    pub end_lsn: u64,
    /// Commit time, converted from the raw wire value by `parse_timestamp_tz`.
    #[br(map = parse_timestamp_tz)]
    pub commit_timestamp: DateTime<Utc>,
}
103
104#[derive(BinRead, Debug)]
105pub struct MessageOrigin {
106 pub lsn: u64,
107 pub size: u32,
108 #[br(map = parse_string)]
109 pub name: String,
110}
111
/// Payload of an `R` (Relation) message: a table description plus its column list.
#[derive(BinRead, Debug)]
pub struct MessageRelation {
    /// NOTE(review): this field and `relation_oid` together read 4 bytes as two big-endian
    /// `u16`s. If those bytes are the relation's 32-bit OID (the non-streamed layout), this
    /// holds the OID's high 16 bits rather than a transaction id, and OIDs above 65535 are
    /// split across both fields — confirm and consider a single `u32`.
    pub transaction_id: u16,
    /// NOTE(review): low 16 bits of the 4-byte value described on `transaction_id`.
    pub relation_oid: u16,
    /// NUL-terminated namespace (schema) name.
    #[br(map = parse_string)]
    pub namespace: String,
    /// NUL-terminated relation name.
    #[br(map = parse_string)]
    pub relation_name: String,
    /// Replica identity setting byte.
    pub replica_identity_setting: u8,
    /// Number of `Column` entries that follow.
    pub number_of_columns: u16,
    /// Column descriptions, `number_of_columns` of them.
    #[br(count = number_of_columns)]
    pub columns: Vec<Column>,
}
127
128#[derive(BinRead, Debug)]
129pub struct MessageType {
130 pub transaction_id: Option<u32>,
131 pub type_oid: u32,
132 pub namespace_size: u32,
133 #[br(args(namespace_size))]
134 pub namespace: StringWithSize,
135 pub name_size: u32,
136 #[br(args(name_size))]
137 pub name: StringWithSize,
138}
139
/// Payload of an `I` (Insert) message.
#[derive(BinRead, Debug)]
pub struct MessageInsert {
    /// NOTE(review): this and `relation_oid` read 4 bytes as two `u16`s. If those bytes are the
    /// 32-bit relation OID, this is its high half, not a transaction id — confirm.
    pub transaction_id: u16,
    /// NOTE(review): low 16 bits of the 4-byte value described on `transaction_id`.
    pub relation_oid: u16,
    /// The inserted row, preceded on the wire by the marker byte `N`.
    #[br(magic = b'N')]
    pub new_tuple: TupleData,
}
147
/// Payload of a `U` (Update) message.
#[derive(BinRead, Debug)]
pub struct MessageUpdate {
    /// NOTE(review): this and `relation_oid` read 4 bytes as two `u16`s. If those bytes are the
    /// 32-bit relation OID, this is its high half, not a transaction id — confirm.
    pub transaction_id: u16,
    /// NOTE(review): low 16 bits of the 4-byte value described on `transaction_id`.
    pub relation_oid: u16,
    /// Optional key (`K`) or old-row (`O`) tuple. Read with binrw's `try`: if no `K`/`O`
    /// marker parses here the reader is rewound and the field is `None`.
    #[br(try)]
    pub key_or_old_tuple: Option<KeyOrOldTupleData>,
    /// The new row, preceded on the wire by the marker byte `N`.
    #[br(magic = b'N')]
    pub new_tuple: TupleData,
}
157
/// Payload of a `D` (Delete) message.
#[derive(BinRead, Debug)]
pub struct MessageDelete {
    /// NOTE(review): this and `relation_oid` read 4 bytes as two `u16`s. If those bytes are the
    /// 32-bit relation OID, this is its high half, not a transaction id — confirm.
    pub transaction_id: u16,
    /// NOTE(review): low 16 bits of the 4-byte value described on `transaction_id`.
    pub relation_oid: u16,
    /// Key (`K`) or old-row (`O`) tuple of the deleted row. Read with binrw's `try`, so a
    /// missing/unknown marker yields `None` instead of an error.
    #[br(try)]
    pub key_or_old_tuple: Option<KeyOrOldTupleData>,
}
165
/// Payload of a `T` (Truncate) message.
#[derive(BinRead, Debug)]
pub struct MessageTruncate {
    /// NOTE(review): binrw's `Option<T>` appears to always read 4 bytes here. The protocol only
    /// sends the xid for streamed transactions — confirm.
    pub transaction_id: Option<u32>,
    /// Number of relation OIDs that follow.
    pub number_of_relations: u32,
    /// Truncate option bits.
    pub option_bits: u8,
    /// OIDs of the truncated relations, `number_of_relations` of them.
    #[br(count = number_of_relations)]
    pub relation_oids: Vec<u32>,
}
174
/// Payload of an `S` (Stream Start) message.
#[derive(BinRead, Debug)]
pub struct MessageStreamStart {
    /// Transaction id of the streamed transaction.
    pub transaction_id: u32,
    /// Raw flag byte; presumably non-zero for the first segment of the transaction — confirm.
    pub is_first_segment: u8,
}
180
/// Payload of a `c` (Stream Commit) message.
#[derive(BinRead, Debug)]
pub struct MessageStreamCommit {
    /// Transaction id of the streamed transaction.
    pub transaction_id: u32,
    /// Commit flags byte.
    pub flags: u8,
    /// LSN of the commit.
    pub lsn: u64,
    /// End LSN of the transaction.
    pub end_lsn: u64,
    /// Raw commit timestamp. Unlike `MessageCommit`, this is not converted with
    /// `parse_timestamp_tz`.
    pub commit_timestamp: u64,
}
189
/// Payload of an `A` (Stream Abort) message.
#[derive(BinRead, Debug)]
pub struct MessageStreamAbort {
    /// Transaction id of the streamed transaction.
    pub transaction_id: u32,
    /// Id of the aborted subtransaction.
    pub subtransaction_id: u32,
    /// NOTE(review): per the PostgreSQL protocol docs, `lsn` and `abort_timestamp` are only sent
    /// when streaming is in parallel mode, but both are always read here — confirm.
    pub lsn: u64,
    /// Raw abort timestamp (not converted). See the note on `lsn`.
    pub abort_timestamp: u64,
}
197
198#[derive(BinRead, Debug)]
199pub struct MessageBeginPrepare {
200 pub prepare_lsn: u64,
201 pub end_lsn: u64,
202 pub prepare_timestamp: u64,
203 pub transaction_id: u32,
204 pub gid_size: u32,
205 #[br(args(gid_size))]
206 pub gid: StringWithSize,
207}
208
209#[derive(BinRead, Debug)]
210pub struct MessagePrepare {
211 pub flags: u8,
212 pub prepare_lsn: u64,
213 pub end_lsn: u64,
214 pub prepare_timestamp: u64,
215 pub transaction_id: u32,
216 pub gid_size: u32,
217 #[br(args(gid_size))]
218 pub gid: StringWithSize,
219}
220
221#[derive(BinRead, Debug)]
222pub struct MessageCommitPrepared {
223 pub flags: u8,
224 pub commit_lsn: u64,
225 pub end_lsn: u64,
226 pub commit_timestamp: u64,
227 pub transaction_id: u32,
228 pub gid_size: u32,
229 #[br(args(gid_size))]
230 pub gid: StringWithSize,
231}
232
233#[derive(BinRead, Debug)]
234pub struct MessageRollbackPrepared {
235 pub flags: u8,
236 pub prepare_end_lsn: u64,
237 pub rollback_end_lsn: u64,
238 pub prepare_timestamp: u64,
239 pub rollback_timestamp: u64,
240 pub transaction_id: u32,
241 pub gid_size: u32,
242 #[br(args(gid_size))]
243 pub gid: StringWithSize,
244}
245
246#[derive(BinRead, Debug)]
247pub struct MessageStreamPrepare {
248 pub flags: u8,
249 pub prepare_lsn: u64,
250 pub end_lsn: u64,
251 pub prepare_timestamp: u64,
252 pub transaction_id: u32,
253 pub gid_size: u32,
254 #[br(args(gid_size))]
255 pub gid: StringWithSize,
256}
257
/// One decoded `pgoutput` logical replication message.
///
/// The first byte selects the variant (each `magic` below). The remaining bytes are parsed as
/// that variant's payload. `#[br(big)]` makes every multi-byte number underneath big-endian.
#[derive(BinRead, Debug)]
#[br(big)]
pub enum LogicalReplicationMessage {
    /// `B`: start of a transaction.
    #[br(magic = b'B')]
    Begin(MessageBegin),
    /// `M`: logical decoding message.
    #[br(magic = b'M')]
    Message(Message),
    /// `C`: transaction commit.
    #[br(magic = b'C')]
    Commit(MessageCommit),
    /// `O`: replication origin.
    #[br(magic = b'O')]
    Origin(MessageOrigin),
    /// `R`: relation (table) description.
    #[br(magic = b'R')]
    Relation(MessageRelation),
    /// `Y`: data type description.
    #[br(magic = b'Y')]
    Type(MessageType),
    /// `I`: row insert.
    #[br(magic = b'I')]
    Insert(MessageInsert),
    /// `U`: row update.
    #[br(magic = b'U')]
    Update(MessageUpdate),
    /// `D`: row delete.
    #[br(magic = b'D')]
    Delete(MessageDelete),
    /// `T`: truncate of one or more relations.
    #[br(magic = b'T')]
    Truncate(MessageTruncate),
    /// `S`: start of a streamed-transaction segment.
    #[br(magic = b'S')]
    StreamStart(MessageStreamStart),
    /// `E`: end of a streamed-transaction segment; carries no payload.
    #[br(magic = b'E')]
    StreamStop,
    /// `c`: commit of a streamed transaction.
    #[br(magic = b'c')]
    StreamCommit(MessageStreamCommit),
    /// `A`: abort of a streamed (sub)transaction.
    #[br(magic = b'A')]
    StreamAbort(MessageStreamAbort),
    /// `b`: begin of a two-phase prepare.
    #[br(magic = b'b')]
    BeginPrepare(MessageBeginPrepare),
    /// `P`: two-phase prepare.
    #[br(magic = b'P')]
    Prepare(MessagePrepare),
    /// `K`: commit of a prepared transaction.
    #[br(magic = b'K')]
    CommitPrepared(MessageCommitPrepared),
    /// `r`: rollback of a prepared transaction.
    #[br(magic = b'r')]
    RollbackPrepared(MessageRollbackPrepared),
    /// `p`: prepare of a streamed transaction.
    #[br(magic = b'p')]
    StreamPrepare(MessageStreamPrepare),
}
300
/// One column description inside a `MessageRelation`.
#[derive(BinRead, Debug)]
pub struct Column {
    /// Column flags byte.
    pub flags: u8,
    /// NUL-terminated column name.
    #[br(map = parse_string)]
    pub name: String,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Type modifier. NOTE(review): if the wire value is a signed `Int32`, "no modifier" (-1)
    /// appears here as `u32::MAX` — confirm callers expect that.
    pub type_modifier: u32,
}
309
/// A row image: a column count followed by that many tagged column values.
#[derive(BinRead, Debug)]
#[br(big)]
pub struct TupleData {
    /// Number of `ColumnValue`s that follow.
    pub number_of_columns: u16,
    /// Column values in table column order, `number_of_columns` of them.
    #[br(count = number_of_columns)]
    pub columns: Vec<ColumnValue>,
}
317
318#[derive(BinRead, Debug)]
319pub enum ColumnValue {
320 #[br(magic = b'n')]
321 Null {
322 length: u32,
323 #[br(count = length)]
324 data: Vec<u8>,
325 },
326 #[br(magic = b'u')]
327 UnchangedToast {
328 length: u32,
329 #[br(count = length)]
330 data: Vec<u8>,
331 },
332 #[br(magic = b't')]
333 Text {
334 length: u32,
335 #[br(count = length)]
336 data: Vec<u8>,
337 },
338 #[br(magic = b'b')]
339 Binary {
340 length: u32,
341 #[br(count = length)]
342 data: Vec<u8>,
343 },
344}
345
346
347impl ColumnValue {
348 pub fn as_bytes<'r>(&'r self) -> Result<&'r [u8], ParseError> {
349 match self {
350 ColumnValue::Text { data, .. } => Ok(data),
351 ColumnValue::Binary { data, .. } => Ok(data),
352 _ => panic!("Invalid column type"),
353 }
354 }
355
356 pub fn as_str<'r>(&'r self) -> Result<&'r str, ParseError> {
357 Ok(from_utf8(self.as_bytes()?)?)
358 }
359}
360
/// A type identified by a single `i64` id.
///
/// Every `EntityRef` automatically implements `Decode` by decoding the column as an `i64`
/// and passing it to `from_entity_id`.
pub trait EntityRef {
    /// Returns the entity's id.
    fn id(&self) -> i64;
    /// Builds a reference from an id.
    fn from_entity_id(id: i64) -> Self;
}
365
366impl<T: EntityRef> Decode for T {
367 fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
368 Ok(match value {
369 ColumnValue::Text { .. } => {
370 Self::from_entity_id(Decode::decode(value)?)
371 }
372 ColumnValue::Binary { .. } => {
373 Self::from_entity_id(Decode::decode(value)?)
374 }
375 _ => panic!("Invalid column type"),
376 })
377 }
378}
379
380
/// Conversion from a replicated [`ColumnValue`] into a Rust value.
pub trait Decode: Sized {
    /// Decodes `value`, returning a [`ParseError`] when its contents cannot be parsed.
    fn decode(value: &ColumnValue) -> Result<Self, ParseError>;
}
384
385impl Decode for Time {
386 fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
387 Ok(match value {
388 ColumnValue::Text { .. } => {
389 let s = value.as_str()?;
390 let dt = NaiveDateTime::parse_from_str(
391 s,
392 if s.contains('+') {
393 "%Y-%m-%d %H:%M:%S%.f%#z"
397 } else {
398 "%Y-%m-%d %H:%M:%S%.f"
399 },
400 )?;
401
402 Time::from_naive_datetime(dt)
403 }
404 ColumnValue::Binary { .. } => {
405 let us: i64 = Decode::decode(value)?;
406 Time(POSTGRES_EPOCH + us / 1_000_000)
407 }
408 _ => panic!("Invalid column type"),
409 })
410 }
411}
412
413impl Decode for i64 {
414 fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
415 Ok(match value {
416 ColumnValue::Text { .. } => value.as_str()?.parse()?,
417 ColumnValue::Binary { data, .. } => BigEndian::read_int(&data, data.len()),
418 _ => panic!("Invalid column type"),
419 })
420 }
421}
422
423impl Decode for f32 {
424 fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
425 Ok(match value {
426 ColumnValue::Text { .. } => value.as_str()?.parse()?,
427 ColumnValue::Binary { data, .. } => BigEndian::read_f32(&data),
428 _ => panic!("Invalid column type"),
429 })
430 }
431}
432
433impl Decode for f64 {
434 fn decode(value: &ColumnValue) -> Result<Self, ParseError> {
435 Ok(match value {
436 ColumnValue::Text { .. } => value.as_str()?.parse()?,
437 ColumnValue::Binary { data, .. } => BigEndian::read_f64(&data),
438 _ => panic!("Invalid column type"),
439 })
440 }
441}
442
/// The optional "before" image attached to update/delete messages.
#[derive(BinRead, Debug)]
#[br(big)]
pub enum KeyOrOldTupleData {
    /// `K`: only the key columns of the old row.
    #[br(magic = b'K')]
    Key(TupleData),
    /// `O`: the full old row.
    #[br(magic = b'O')]
    Old(TupleData),
}
451
452pub fn decode(msg: &[u8]) -> Result<LogicalReplicationMessage, binrw::Error> {
453 LogicalReplicationMessage::read(&mut binrw::io::Cursor::new(msg))
454}
455