use crate::error::{Result, SQLRiteError};
use crate::sql::pager::cell::KIND_FTS_POSTING;
use crate::sql::pager::varint;
/// One record of the full-text-search index as stored in a pager cell.
///
/// Two flavours share this layout:
/// * a posting cell (built with `FtsPostingCell::posting`) carrying a
///   non-empty `term`, and
/// * the doc-lengths sidecar (built with `FtsPostingCell::doc_lengths`)
///   whose `term` is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FtsPostingCell {
    /// Identifier of this cell; serialized as a signed varint.
    pub cell_id: i64,
    /// Indexed term; empty for the doc-lengths sidecar.
    pub term: String,
    /// `(rowid, value)` pairs. For postings the value is presumably a per-row
    /// term count, and for the sidecar a per-row document length (TODO confirm).
    pub entries: Vec<(i64, u32)>,
}
impl FtsPostingCell {
    /// Builds a posting cell for `term` holding `(rowid, value)` pairs.
    ///
    /// NOTE(review): `value` is presumably the term's per-row frequency —
    /// confirm with the indexer. An empty `term` yields a cell that
    /// `is_doc_lengths` reports as the doc-lengths sidecar.
    pub fn posting(cell_id: i64, term: String, entries: Vec<(i64, u32)>) -> Self {
        Self {
            cell_id,
            term,
            entries,
        }
    }

    /// Builds the doc-lengths sidecar cell: an empty term plus `(rowid, length)`
    /// pairs (presumably per-document lengths used for scoring — confirm).
    pub fn doc_lengths(cell_id: i64, entries: Vec<(i64, u32)>) -> Self {
        Self {
            cell_id,
            term: String::new(),
            entries,
        }
    }

    /// Serializes the cell as a length-prefixed record.
    ///
    /// Layout: `varint_u64(body_len)` followed by the body, which is the
    /// `KIND_FTS_POSTING` tag byte, `varint_i64(cell_id)`,
    /// `varint_u64(term.len())`, the term's UTF-8 bytes,
    /// `varint_u64(entries.len())`, and then `varint_i64(rowid)`,
    /// `varint_u64(value)` for every entry.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn encode(&self) -> Result<Vec<u8>> {
        // Capacity is only a pre-allocation hint; the constants presumably
        // reflect worst-case varint widths (10 bytes for i64, 5 for u32).
        let pair_bytes = self.entries.len() * 15;
        let mut body = Vec::with_capacity(1 + 10 + 5 + self.term.len() + 5 + pair_bytes);
        body.push(KIND_FTS_POSTING);
        varint::write_i64(&mut body, self.cell_id);
        varint::write_u64(&mut body, self.term.len() as u64);
        body.extend_from_slice(self.term.as_bytes());
        varint::write_u64(&mut body, self.entries.len() as u64);
        for (rowid, value) in &self.entries {
            varint::write_i64(&mut body, *rowid);
            varint::write_u64(&mut body, *value as u64);
        }
        let mut out = Vec::with_capacity(body.len() + varint::MAX_VARINT_BYTES);
        varint::write_u64(&mut out, body.len() as u64);
        out.extend_from_slice(&body);
        Ok(out)
    }

    /// Decodes one cell starting at byte offset `pos` of `buf`.
    ///
    /// Returns the cell and the number of bytes it occupied (length prefix
    /// plus body), so callers can advance to the next record.
    ///
    /// # Errors
    ///
    /// Returns `SQLRiteError::Internal` when:
    /// * the length prefix overflows or the body runs past `buf`;
    /// * the kind tag is not `KIND_FTS_POSTING`;
    /// * the term is longer than the remaining body or is not valid UTF-8;
    /// * the entry count exceeds 2^28;
    /// * a value exceeds `u32::MAX`;
    /// * bytes remain after the last entry.
    ///
    /// Errors from the `varint` readers are propagated unchanged.
    pub fn decode(buf: &[u8], pos: usize) -> Result<(FtsPostingCell, usize)> {
        let (body_len, len_bytes) = varint::read_u64(buf, pos)?;
        let body_start = pos + len_bytes;
        // Convert checked: `as usize` would silently truncate a corrupt length
        // on targets with a 32-bit `usize` and mis-frame the cell.
        let body_end = Self::len_to_usize(body_len)
            .and_then(|len| body_start.checked_add(len))
            .ok_or_else(|| SQLRiteError::Internal("FTS cell length overflow".to_string()))?;
        if body_end > buf.len() {
            return Err(SQLRiteError::Internal(format!(
                "FTS cell extends past buffer: needs {body_start}..{body_end}, have {}",
                buf.len()
            )));
        }
        let body = &buf[body_start..body_end];
        if body.first().copied() != Some(KIND_FTS_POSTING) {
            return Err(SQLRiteError::Internal(format!(
                "FtsPostingCell::decode called on non-FTS entry (kind_tag = {:#x})",
                body.first().copied().unwrap_or(0)
            )));
        }

        let mut cur = 1usize;
        let (cell_id, n) = varint::read_i64(body, cur)?;
        cur += n;

        let (term_len, n) = varint::read_u64(body, cur)?;
        cur += n;
        let remaining = body.len().saturating_sub(cur);
        // Inside the match arms `term_len` still names the raw u64 from disk.
        let term_len = match Self::len_to_usize(term_len) {
            Some(len) if len <= remaining => len,
            _ => {
                return Err(SQLRiteError::Internal(format!(
                    "FTS cell {cell_id}: term_len {term_len} exceeds remaining body \
                     ({remaining}) — corrupt cell?"
                )))
            }
        };
        let term_bytes = &body[cur..cur + term_len];
        cur += term_len;
        let term = std::str::from_utf8(term_bytes)
            .map_err(|e| {
                SQLRiteError::Internal(format!("FTS cell {cell_id}: term not valid UTF-8: {e}"))
            })?
            .to_string();

        let (count, n) = varint::read_u64(body, cur)?;
        cur += n;
        if count > 1 << 28 {
            return Err(SQLRiteError::Internal(format!(
                "FTS cell {cell_id}: claims {count} entries (>2^28) — corrupt cell?"
            )));
        }
        // Never trust the on-disk count for the pre-allocation. A pair needs at
        // least one byte per varint, so the remaining body bounds how many
        // entries can really follow. The cap is only a capacity hint: a corrupt
        // count still fails in the reads below, just without first trying a
        // multi-GiB allocation (whose failure would abort the process).
        let max_entries_in_body = (body.len().saturating_sub(cur) / 2) as u64;
        let mut entries = Vec::with_capacity(count.min(max_entries_in_body) as usize);
        for _ in 0..count {
            let (rowid, n) = varint::read_i64(body, cur)?;
            cur += n;
            let (value_u64, n) = varint::read_u64(body, cur)?;
            cur += n;
            if value_u64 > u32::MAX as u64 {
                return Err(SQLRiteError::Internal(format!(
                    "FTS cell {cell_id}: value {value_u64} exceeds u32::MAX — corrupt cell?"
                )));
            }
            entries.push((rowid, value_u64 as u32));
        }

        if cur != body.len() {
            return Err(SQLRiteError::Internal(format!(
                "FTS cell {cell_id} had {} trailing bytes",
                body.len() - cur
            )));
        }
        Ok((
            FtsPostingCell {
                cell_id,
                term,
                entries,
            },
            // Equals `len_bytes + body_len`: the prefix plus the body.
            body_end - pos,
        ))
    }

    /// Returns `true` when this cell is the doc-lengths sidecar, i.e. its term
    /// is empty. A posting built with an empty term also reports `true`.
    pub fn is_doc_lengths(&self) -> bool {
        self.term.is_empty()
    }

    /// Converts an on-disk `u64` length to `usize`. Returns `None` when it does
    /// not fit, which is possible only where `usize` is narrower than 64 bits.
    fn len_to_usize(len: u64) -> Option<usize> {
        // Fully qualified so no `TryFrom` import / edition prelude is needed.
        <usize as std::convert::TryFrom<u64>>::try_from(len).ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a hand-built cell body in the varint length prefix that
    /// `FtsPostingCell::decode` expects.
    fn frame(body: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        varint::write_u64(&mut out, body.len() as u64);
        out.extend_from_slice(body);
        out
    }

    /// Starts a hand-built body: the kind tag, `cell_id = 1`, then `term` with
    /// its length prefix. Callers append the entry section themselves.
    fn body_with_term(term: &[u8]) -> Vec<u8> {
        let mut body = vec![KIND_FTS_POSTING];
        varint::write_i64(&mut body, 1);
        varint::write_u64(&mut body, term.len() as u64);
        body.extend_from_slice(term);
        body
    }

    /// Encodes `cell`, decodes it from offset 0, and checks that the decoder
    /// consumed exactly the encoded bytes and reproduced the original cell.
    fn round_trip(cell: &FtsPostingCell) {
        let bytes = cell.encode().expect("encode");
        let (decoded, consumed) = FtsPostingCell::decode(&bytes, 0).expect("decode");
        assert_eq!(
            consumed,
            bytes.len(),
            "decode should consume the whole cell"
        );
        assert_eq!(&decoded, cell);
    }

    #[test]
    fn posting_cell_round_trips() {
        let cell = FtsPostingCell::posting(7, "rust".to_string(), vec![(1, 2), (3, 1), (5, 7)]);
        round_trip(&cell);
    }

    #[test]
    fn doc_lengths_sidecar_round_trips() {
        let cell = FtsPostingCell::doc_lengths(1, vec![(1, 12), (2, 20), (3, 0), (4, 7)]);
        assert!(cell.is_doc_lengths());
        round_trip(&cell);
    }

    #[test]
    fn empty_postings_round_trips() {
        let cell = FtsPostingCell::posting(2, "ghost".to_string(), vec![]);
        round_trip(&cell);
    }

    #[test]
    fn negative_and_large_rowids_round_trip() {
        round_trip(&FtsPostingCell::posting(
            3,
            "x".to_string(),
            vec![(-1, 1), (i64::MAX, 99), (i64::MIN, 1)],
        ));
    }

    #[test]
    fn long_term_round_trips() {
        let term = "a".repeat(1024);
        let cell = FtsPostingCell::posting(4, term, vec![(1, 1)]);
        round_trip(&cell);
    }

    #[test]
    fn long_posting_list_round_trips() {
        let entries: Vec<(i64, u32)> = (0..5000_i64).map(|i| (i, ((i * 3) as u32) + 1)).collect();
        let cell = FtsPostingCell::posting(5, "common".to_string(), entries);
        round_trip(&cell);
    }

    #[test]
    fn decode_reads_consecutive_cells_at_offsets() {
        let first = FtsPostingCell::posting(1, "alpha".to_string(), vec![(1, 1)]);
        let second = FtsPostingCell::doc_lengths(2, vec![(1, 5), (2, 9)]);
        // Unrelated leading bytes force a non-zero starting offset.
        let mut buf = vec![0xAA, 0xBB];
        let start = buf.len();
        buf.extend_from_slice(&first.encode().expect("encode first"));
        buf.extend_from_slice(&second.encode().expect("encode second"));

        let (a, used_a) = FtsPostingCell::decode(&buf, start).expect("decode first");
        assert_eq!(a, first);
        let (b, used_b) = FtsPostingCell::decode(&buf, start + used_a).expect("decode second");
        assert_eq!(b, second);
        assert_eq!(start + used_a + used_b, buf.len());
    }

    #[test]
    fn decode_rejects_wrong_kind_tag() {
        // One-byte body whose tag is not KIND_FTS_POSTING.
        let bad = frame(&[0x01]);
        let err = FtsPostingCell::decode(&bad, 0).unwrap_err();
        assert!(format!("{err}").contains("non-FTS entry"));
    }

    #[test]
    fn decode_rejects_empty_body() {
        // A zero-length body carries no kind tag at all.
        let err = FtsPostingCell::decode(&frame(&[]), 0).unwrap_err();
        assert!(format!("{err}").contains("non-FTS entry"));
    }

    #[test]
    fn decode_rejects_truncated_buffer() {
        let cell = FtsPostingCell::posting(1, "rust".to_string(), vec![(1, 2), (5, 3)]);
        let bytes = cell.encode().expect("encode");
        for chop in 1..=3 {
            let truncated = &bytes[..bytes.len() - chop];
            assert!(
                FtsPostingCell::decode(truncated, 0).is_err(),
                "expected error chopping {chop} byte(s) from end of {} byte cell",
                bytes.len()
            );
        }
    }

    #[test]
    fn decode_rejects_invalid_utf8_term() {
        let mut body = body_with_term(&[0xFF, 0xFE]);
        varint::write_u64(&mut body, 0); // entry count
        let err = FtsPostingCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").to_lowercase().contains("utf-8"));
    }

    #[test]
    fn decode_rejects_implausible_count() {
        let mut body = body_with_term(b"term");
        varint::write_u64(&mut body, 1u64 << 29);
        let err = FtsPostingCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").to_lowercase().contains("corrupt"));
    }

    #[test]
    fn decode_rejects_value_above_u32_max() {
        let mut body = body_with_term(b"t");
        varint::write_u64(&mut body, 1); // one entry
        varint::write_i64(&mut body, 10); // rowid
        varint::write_u64(&mut body, u64::from(u32::MAX) + 1); // value
        let err = FtsPostingCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").contains("exceeds u32::MAX"));
    }

    #[test]
    fn decode_rejects_trailing_bytes_inside_body() {
        let mut body = body_with_term(b"t");
        varint::write_u64(&mut body, 0); // no entries
        body.push(0x00); // garbage after the last field, still inside body_len
        let err = FtsPostingCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").contains("trailing bytes"));
    }
}