use schema_core::{ColumnName, GenericValue, TableName};
use sources_core::{RowKey, SourceError};
/// One decoded pgoutput message, as returned by `decode`.
///
/// Each variant corresponds to a message tag byte; any tag that `decode`
/// does not recognise is reported as [`Decoded::Other`].
#[derive(Debug)]
pub(crate) enum Decoded {
/// Relation metadata (tag `R`): table name and column descriptions.
Relation(Relation),
/// Row insert (tag `I`).
Insert {
/// Relation identifier read from the message (presumably matches
/// `Relation::oid` of an earlier relation message — confirm with caller).
rel: u32,
/// The inserted row.
new: Tuple,
},
/// Row update (tag `U`).
Update {
/// Relation identifier read from the message.
rel: u32,
/// Previous row image; `Some` only when the message carried a `K` or `O`
/// tuple before the new tuple, `None` when it started directly with `N`.
old: Option<Tuple>,
/// The row after the update.
new: Tuple,
},
/// Row delete (tag `D`).
Delete {
/// Relation identifier read from the message.
rel: u32,
/// The `K` or `O` tuple identifying the deleted row.
old: Tuple,
},
/// Truncate (tag `T`).
Truncate {
/// Relation identifiers listed in the message, in wire order.
rels: Vec<u32>,
},
/// Any message whose tag is not handled by `decode`.
Other,
}
/// Table metadata carried by a relation (`R`) message.
#[derive(Debug, Clone)]
pub(crate) struct Relation {
/// Relation identifier: the first 32-bit value of the relation message.
pub(crate) oid: u32,
/// Validated relation name; the namespace string that precedes it on the
/// wire is read but not stored.
pub(crate) table: TableName,
/// Column descriptions in the order they appear in the message.
pub(crate) columns: Vec<Column>,
}
/// Description of a single column within a [`Relation`].
#[derive(Debug, Clone)]
pub(crate) struct Column {
/// Validated column name.
pub(crate) name: ColumnName,
/// `true` when bit 0 of the column's flags byte is set; `row_key` only
/// uses columns with this flag.
pub(crate) is_key: bool,
/// Type identifier of the column, used by `typed_value` to choose a parser.
pub(crate) type_oid: u32,
}
/// Cells of one row in wire order; `row_key` pairs them positionally with
/// `Relation::columns`.
pub(crate) type Tuple = Vec<Cell>;
/// A single column value inside a [`Tuple`].
#[derive(Debug, Clone)]
pub(crate) enum Cell {
/// Cell kind `n`.
Null,
/// Cell kind `u`: no value was transmitted for this column.
Unchanged,
/// Cell kind `t` or `b`: the payload bytes converted to a string with lossy
/// UTF-8 decoding.
Text(String),
}
/// Builds the addressing key of a row from the key columns of `rel`.
///
/// Columns and cells are paired positionally; only columns flagged `is_key`
/// contribute. Text cells are converted with `typed_value`, while null and
/// unchanged cells are recorded as `GenericValue::Null`.
///
/// # Errors
///
/// Returns `SourceError::Decode` when no key column contributed a pair.
pub(crate) fn row_key(rel: &Relation, tuple: &Tuple) -> Result<RowKey, SourceError> {
    let pairs: Vec<_> = rel
        .columns
        .iter()
        .zip(tuple)
        .filter(|(column, _)| column.is_key)
        .map(|(column, cell)| {
            let value = match cell {
                Cell::Null | Cell::Unchanged => GenericValue::Null,
                Cell::Text(raw) => typed_value(raw, column.type_oid),
            };
            (column.name.clone(), value)
        })
        .collect();

    // Without at least one key column the change cannot be tied to a row.
    if pairs.is_empty() {
        return Err(SourceError::Decode(format!(
            "relation {} carries no key columns; set REPLICA IDENTITY so changes can be addressed",
            rel.table
        )));
    }

    Ok(RowKey(pairs))
}
/// Converts the textual form of a column value into a typed `GenericValue`.
///
/// The parser is chosen from `type_oid`. Whenever the OID is not one of the
/// handled types, or the text does not parse as that type, the value is kept
/// verbatim as `GenericValue::String`.
fn typed_value(text: &str, type_oid: u32) -> GenericValue {
    // PostgreSQL built-in type OIDs handled here.
    const BOOL: u32 = 16;
    const INT8: u32 = 20;
    const INT2: u32 = 21;
    const INT4: u32 = 23;
    const OID: u32 = 26;
    const FLOAT4: u32 = 700;
    const FLOAT8: u32 = 701;
    const NUMERIC: u32 = 1700;
    const UUID: u32 = 2950;

    // Fallback used by every branch: keep the raw text.
    let fallback = || GenericValue::String(text.to_owned());

    match type_oid {
        BOOL => match text {
            "t" => GenericValue::Bool(true),
            "f" => GenericValue::Bool(false),
            _ => fallback(),
        },
        INT2 => text
            .parse::<i16>()
            .map(GenericValue::SmallInt)
            .unwrap_or_else(|_| fallback()),
        INT4 => text
            .parse::<i32>()
            .map(GenericValue::Int)
            .unwrap_or_else(|_| fallback()),
        INT8 | OID => text
            .parse::<i64>()
            .map(GenericValue::BigInt)
            .unwrap_or_else(|_| fallback()),
        FLOAT4 => text
            .parse::<f32>()
            .map(GenericValue::Float)
            .unwrap_or_else(|_| fallback()),
        FLOAT8 => text
            .parse::<f64>()
            .map(GenericValue::Double)
            .unwrap_or_else(|_| fallback()),
        NUMERIC => rust_decimal::Decimal::from_str_exact(text)
            .map(GenericValue::Decimal)
            .unwrap_or_else(|_| fallback()),
        UUID => uuid::Uuid::parse_str(text)
            .map(GenericValue::Uuid)
            .unwrap_or_else(|_| fallback()),
        _ => fallback(),
    }
}
/// Decodes a single pgoutput logical-replication message.
///
/// The first byte of `data` is the message tag; the rest is the payload:
///
/// * `R` – relation metadata, parsed by `decode_relation`.
/// * `I` – insert: relation id, `N` marker, new tuple.
/// * `U` – update: relation id, optional `K`/`O` old tuple, `N` marker, new tuple.
/// * `D` – delete: relation id, `K`/`O` marker, old tuple.
/// * `T` – truncate: Int32 relation count, Int8 option bits, then one Int32
///   relation id per relation.
/// * anything else – reported as `Decoded::Other` without reading the payload.
///
/// NOTE(review): the relation id is read immediately after the tag, so this
/// appears to expect non-streamed messages (no leading transaction id) —
/// confirm against the protocol version requested by the caller.
///
/// # Errors
///
/// Returns `SourceError::Decode` when `data` is empty, the payload is
/// truncated, or a tuple marker is not one of the values allowed for the
/// message kind.
pub(crate) fn decode(data: &[u8]) -> Result<Decoded, SourceError> {
    let (&tag, rest) = data
        .split_first()
        .ok_or_else(|| SourceError::Decode("pgoutput: empty message".into()))?;
    let mut cur = Cursor::new(rest);
    match tag {
        b'R' => decode_relation(&mut cur),
        b'I' => {
            let rel = cur.u32()?;
            expect(&mut cur, b'N', "insert new-tuple marker")?;
            Ok(Decoded::Insert {
                rel,
                new: decode_tuple(&mut cur)?,
            })
        }
        b'U' => {
            let rel = cur.u32()?;
            let marker = cur.u8()?;
            let (old, new) = match marker {
                // An old image (key-only or full) precedes the new tuple.
                b'K' | b'O' => {
                    let old = decode_tuple(&mut cur)?;
                    expect(&mut cur, b'N', "update new-tuple marker")?;
                    (Some(old), decode_tuple(&mut cur)?)
                }
                // No old image was sent; only the new tuple follows.
                b'N' => (None, decode_tuple(&mut cur)?),
                other => {
                    return Err(SourceError::Decode(format!(
                        "pgoutput update: unexpected tuple marker {other:#x}"
                    )));
                }
            };
            Ok(Decoded::Update { rel, old, new })
        }
        b'D' => {
            let rel = cur.u32()?;
            let marker = cur.u8()?;
            match marker {
                b'K' | b'O' => {}
                other => {
                    return Err(SourceError::Decode(format!(
                        "pgoutput delete: unexpected tuple marker {other:#x}"
                    )));
                }
            }
            Ok(Decoded::Delete {
                rel,
                old: decode_tuple(&mut cur)?,
            })
        }
        b'T' => {
            // The relation count is an Int32 on the wire. Reading it as an
            // Int16 consumes only its (normally zero) high half, which makes
            // every truncate decode to an empty relation list.
            let nrels = cur.i32_len()?;
            let _flags = cur.u8()?;
            // Each relation id occupies four bytes, so the payload length
            // bounds how many can actually follow; this keeps a bogus count
            // from triggering a huge up-front allocation.
            let mut rels = Vec::with_capacity(nrels.min(rest.len() / 4));
            for _ in 0..nrels {
                rels.push(cur.u32()?);
            }
            Ok(Decoded::Truncate { rels })
        }
        _ => Ok(Decoded::Other),
    }
}
fn decode_relation(cur: &mut Cursor<'_>) -> Result<Decoded, SourceError> {
let oid = cur.u32()?;
let _namespace = cur.cstring()?;
let relname = cur.cstring()?;
let table = TableName::try_new(relname.clone()).map_err(|e| {
SourceError::Decode(format!("pgoutput relation: invalid table {relname:?}: {e}"))
})?;
let _replica_identity = cur.u8()?;
let ncols = cur.i16_count()?;
let mut columns = Vec::with_capacity(ncols);
for _ in 0..ncols {
let flags = cur.u8()?;
let colname = cur.cstring()?;
let type_oid = cur.u32()?;
let _type_modifier = cur.u32()?;
let name = ColumnName::try_new(colname.clone()).map_err(|e| {
SourceError::Decode(format!(
"pgoutput relation: invalid column {colname:?}: {e}"
))
})?;
columns.push(Column {
name,
is_key: (flags & 1) != 0,
type_oid,
});
}
Ok(Decoded::Relation(Relation {
oid,
table,
columns,
}))
}
fn decode_tuple(cur: &mut Cursor<'_>) -> Result<Tuple, SourceError> {
let ncols = cur.i16_count()?;
let mut cells = Vec::with_capacity(ncols);
for _ in 0..ncols {
let kind = cur.u8()?;
let cell = match kind {
b'n' => Cell::Null,
b'u' => Cell::Unchanged,
b't' | b'b' => {
let len = cur.i32_len()?;
Cell::Text(String::from_utf8_lossy(cur.take(len)?).into_owned())
}
other => {
return Err(SourceError::Decode(format!(
"pgoutput tuple: unknown cell kind {other:#x}"
)));
}
};
cells.push(cell);
}
Ok(cells)
}
/// Consumes one byte and checks that it equals `want`.
///
/// `what` names the expected marker and is only used in the error message.
///
/// # Errors
///
/// Returns `SourceError::Decode` when the input is exhausted or the byte read
/// differs from `want`.
fn expect(cur: &mut Cursor<'_>, want: u8, what: &str) -> Result<(), SourceError> {
    let got = cur.u8()?;
    if got != want {
        return Err(SourceError::Decode(format!(
            "pgoutput: expected {what} {want:#x}, got {got:#x}"
        )));
    }
    Ok(())
}
/// Forward-only reader over a borrowed byte slice.
struct Cursor<'a> {
/// The bytes being read.
buf: &'a [u8],
/// Offset into `buf` of the next unread byte.
pos: usize,
}
impl<'a> Cursor<'a> {
fn new(buf: &'a [u8]) -> Self {
Self { buf, pos: 0 }
}
fn take(&mut self, n: usize) -> Result<&'a [u8], SourceError> {
let end = self
.pos
.checked_add(n)
.ok_or_else(|| truncated("length overflow"))?;
let slice = self
.buf
.get(self.pos..end)
.ok_or_else(|| truncated("bytes"))?;
self.pos = end;
Ok(slice)
}
fn u8(&mut self) -> Result<u8, SourceError> {
let byte = self
.buf
.get(self.pos)
.copied()
.ok_or_else(|| truncated("u8"))?;
self.pos += 1;
Ok(byte)
}
fn u32(&mut self) -> Result<u32, SourceError> {
let arr: [u8; 4] = self.take(4)?.try_into().map_err(|_| truncated("u32"))?;
Ok(u32::from_be_bytes(arr))
}
fn i32_len(&mut self) -> Result<usize, SourceError> {
let arr: [u8; 4] = self.take(4)?.try_into().map_err(|_| truncated("i32"))?;
Ok(i32::from_be_bytes(arr).max(0) as usize)
}
fn i16_count(&mut self) -> Result<usize, SourceError> {
let arr: [u8; 2] = self.take(2)?.try_into().map_err(|_| truncated("i16"))?;
Ok(i16::from_be_bytes(arr).max(0) as usize)
}
fn cstring(&mut self) -> Result<String, SourceError> {
let rest = self
.buf
.get(self.pos..)
.ok_or_else(|| truncated("cstring"))?;
let nul = rest
.iter()
.position(|&b| b == 0)
.ok_or_else(|| SourceError::Decode("pgoutput: unterminated cstring".into()))?;
let text = rest.get(..nul).ok_or_else(|| truncated("cstring"))?;
let out = String::from_utf8_lossy(text).into_owned();
self.pos += nul + 1;
Ok(out)
}
}
/// Builds the decode error reported when the input ends before `what` could
/// be read in full.
fn truncated(what: &str) -> SourceError {
    let message = format!("pgoutput: truncated {what}");
    SourceError::Decode(message)
}
// Unit tests live in the separate `tests` module and are compiled only for
// test builds.
#[cfg(test)]
// The `unwrap_used` and `indexing_slicing` lints are relaxed for this module only.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests;