use pgrx::pg_sys;
use std::mem::size_of;
use std::ptr;
use std::slice;
/// Block number passed to `ReadBuffer` to ask for a newly allocated page at the end of the
/// relation (mirrors PostgreSQL's `P_NEW`, which is `InvalidBlockNumber`).
const P_NEW_BLOCK: pg_sys::BlockNumber = pg_sys::InvalidBlockNumber;
/// Maximum number of centroids stored on one centroid page.
const CENTROIDS_PER_PAGE: usize = 32;
/// Maximum number of (TID, vector) entries stored on one inverted-list page.
const VECTORS_PER_PAGE: usize = 64;
/// Serialises `centroids` onto freshly allocated pages of `index`, in order.
///
/// After the standard page header, each page holds up to
/// [`centroid_page_capacity`]`(dimensions)` back-to-back entries. Each entry contains:
/// - `cluster_id: u32` (the centroid's position in `centroids`);
/// - two `u32` words, currently always written as zero;
/// - the centroid's components as native-endian `f32`s.
///
/// Returns the block number following the last page written, or `start_page` when
/// `centroids` is empty.
///
/// Pages are obtained by extending the relation (`P_NEW_BLOCK`). [`read_centroids`] walks
/// consecutive block numbers, so it finds them only if nothing else extends `index`
/// concurrently.
///
/// # Panics
/// Panics if the centroids do not all have the same length. Also panics if a single
/// centroid is too large to fit on one page.
///
/// # Safety
/// `index` must be a valid, open relation that the caller may extend and modify.
pub unsafe fn write_centroids(
    index: pg_sys::Relation,
    centroids: &[Vec<f32>],
    start_page: u32,
) -> u32 {
    let dimensions = centroids.first().map_or(0, Vec::len);
    // Readers use a single fixed stride per entry. Ragged centroids would desynchronise the
    // layout, and longer ones could write past the end of the page.
    assert!(
        centroids.iter().all(|c| c.len() == dimensions),
        "all centroids must have the same number of dimensions"
    );
    let per_page = centroid_page_capacity(dimensions);
    // Without this guard a zero capacity would allocate pages forever.
    assert!(
        per_page > 0,
        "a centroid with {} dimensions does not fit on an index page",
        dimensions
    );
    let vector_bytes = dimensions * size_of::<f32>();
    let entry_size = CENTROID_ENTRY_HEADER_SIZE + vector_bytes;

    let mut next_page = start_page;
    let mut written = 0usize;
    for batch in centroids.chunks(per_page) {
        let buffer = pg_sys::ReadBuffer(index, P_NEW_BLOCK);
        let block = pg_sys::BufferGetBlockNumber(buffer);
        pg_sys::LockBuffer(buffer, pg_sys::BUFFER_LOCK_EXCLUSIVE as i32);
        let page = pg_sys::BufferGetPage(buffer);
        pg_sys::PageInit(page, pg_sys::BLCKSZ as pg_sys::Size, 0);
        let payload = page.cast::<u8>().add(size_of::<pg_sys::PageHeaderData>());

        for (slot, centroid) in batch.iter().enumerate() {
            let entry = payload.add(slot * entry_size);
            let cluster_id = (written + slot) as u32;
            ptr::write(entry.cast::<u32>(), cluster_id);
            ptr::write(entry.add(4).cast::<u32>(), 0);
            ptr::write(entry.add(8).cast::<u32>(), 0);
            ptr::copy_nonoverlapping(
                centroid.as_ptr().cast::<u8>(),
                entry.add(CENTROID_ENTRY_HEADER_SIZE),
                vector_bytes,
            );
        }
        written += batch.len();

        pg_sys::MarkBufferDirty(buffer);
        pg_sys::UnlockReleaseBuffer(buffer);
        next_page = block + 1;
    }
    next_page
}

/// Reads `num_centroids` centroids of `dimensions` components, as written by
/// [`write_centroids`].
///
/// Reading starts at block `start_page` and continues through consecutive blocks. The
/// per-entry metadata (cluster id and the two reserved words) is skipped, and centroids are
/// returned in storage order.
///
/// # Panics
/// Panics if `num_centroids > 0` and a single centroid of `dimensions` components cannot fit
/// on a page. Such data cannot have been produced by [`write_centroids`].
///
/// # Safety
/// `index` must be a valid, open relation. Its blocks from `start_page` onward must hold
/// centroid pages written with the same `dimensions`.
pub unsafe fn read_centroids(
    index: pg_sys::Relation,
    start_page: u32,
    num_centroids: usize,
    dimensions: usize,
) -> Vec<Vec<f32>> {
    let mut centroids = Vec::with_capacity(num_centroids);
    if num_centroids == 0 {
        return centroids;
    }
    // Must use the same capacity rule as the writer, or entries land on the wrong page.
    let per_page = centroid_page_capacity(dimensions);
    assert!(
        per_page > 0,
        "a centroid with {} dimensions does not fit on an index page",
        dimensions
    );
    let entry_size = CENTROID_ENTRY_HEADER_SIZE + dimensions * size_of::<f32>();

    let mut block = start_page;
    while centroids.len() < num_centroids {
        let buffer = pg_sys::ReadBuffer(index, block);
        pg_sys::LockBuffer(buffer, pg_sys::BUFFER_LOCK_SHARE as i32);
        let payload = pg_sys::BufferGetPage(buffer)
            .cast::<u8>()
            .add(size_of::<pg_sys::PageHeaderData>()) as *const u8;

        let on_this_page = (num_centroids - centroids.len()).min(per_page);
        for slot in 0..on_this_page {
            // Entries are a multiple of 4 bytes long, so the f32 data stays 4-byte aligned.
            let values = payload
                .add(slot * entry_size + CENTROID_ENTRY_HEADER_SIZE)
                .cast::<f32>();
            centroids.push(slice::from_raw_parts(values, dimensions).to_vec());
        }

        pg_sys::UnlockReleaseBuffer(buffer);
        block += 1;
    }
    centroids
}

/// Bytes of fixed metadata in front of every centroid's vector on a centroid page. This is
/// the `u32` cluster id followed by two `u32` words that are written as zero.
const CENTROID_ENTRY_HEADER_SIZE: usize = 12;

/// Number of centroids with `dimensions` components stored on one centroid page.
///
/// The result is `CENTROIDS_PER_PAGE`, reduced when that many entries would not fit after
/// the page header (pages are initialised with no special space). It returns 0 when not
/// even one centroid fits.
fn centroid_page_capacity(dimensions: usize) -> usize {
    let entry_size = CENTROID_ENTRY_HEADER_SIZE + dimensions * size_of::<f32>();
    let available = pg_sys::BLCKSZ as usize - size_of::<pg_sys::PageHeaderData>();
    (available / entry_size).min(CENTROIDS_PER_PAGE)
}
/// One decoded entry of an inverted-list page, as returned by `read_inverted_list`.
#[derive(Clone, Debug)]
pub struct InvertedListEntry {
    /// Tuple identifier stored alongside the vector (presumably the heap TID of the indexed
    /// row; confirm against the build path).
    pub tid: pg_sys::ItemPointerData,
    /// The vector components stored for `tid`.
    pub vector: Vec<f32>,
}
/// Writes an inverted list onto a single freshly allocated index page.
///
/// After the standard page header, entries are stored back to back. Each entry is an
/// `ItemPointerData` followed by `dimensions` native-endian `f32`s, where `dimensions` is the
/// length of the first vector. The page is zeroed by `PageInit`, so unused slots read back
/// as all-zero.
///
/// Only the first `min(VECTORS_PER_PAGE, entries that fit on the page)` entries are stored.
/// Any further entries are dropped, because this format has no page chaining.
///
/// Returns the new page's block number, or `0` when `list` is empty (no page is allocated
/// in that case).
///
/// NOTE(review): the page is not WAL-logged here — confirm crash safety is handled elsewhere.
///
/// # Panics
/// Panics if the stored vectors do not all have the same length. Also panics if a single
/// entry is larger than a page.
///
/// # Safety
/// `index` must be a valid, open relation that the caller may extend and modify.
pub unsafe fn write_inverted_list(
    index: pg_sys::Relation,
    list: &[(pg_sys::ItemPointerData, Vec<f32>)],
) -> u32 {
    if list.is_empty() {
        return 0;
    }
    let dimensions = list[0].1.len();
    let tid_size = size_of::<pg_sys::ItemPointerData>();
    let vector_bytes = dimensions * size_of::<f32>();
    let entry_size = tid_size + vector_bytes;
    let available = pg_sys::BLCKSZ as usize - size_of::<pg_sys::PageHeaderData>();
    // VECTORS_PER_PAGE alone overflows the page once dimensions exceed ~30, so also bound
    // the count by the space that actually exists.
    let capacity = (available / entry_size).min(VECTORS_PER_PAGE);
    assert!(
        capacity > 0,
        "an inverted-list entry of {} bytes does not fit on an index page",
        entry_size
    );
    let batch = &list[..list.len().min(capacity)];
    // The stride is fixed from the first vector; a longer vector would overwrite its
    // neighbour (or run past the page). Validate before taking the buffer lock.
    assert!(
        batch.iter().all(|(_, vector)| vector.len() == dimensions),
        "all vectors in an inverted list must have the same number of dimensions"
    );

    let buffer = pg_sys::ReadBuffer(index, P_NEW_BLOCK);
    let page_num = pg_sys::BufferGetBlockNumber(buffer);
    pg_sys::LockBuffer(buffer, pg_sys::BUFFER_LOCK_EXCLUSIVE as i32);
    let page = pg_sys::BufferGetPage(buffer);
    pg_sys::PageInit(page, pg_sys::BLCKSZ as pg_sys::Size, 0);
    let payload = page.cast::<u8>().add(size_of::<pg_sys::PageHeaderData>());

    for (slot, (tid, vector)) in batch.iter().enumerate() {
        let entry = payload.add(slot * entry_size);
        // An entry is 6 + 4*d bytes, so every other vector starts 2 bytes off a 4-byte
        // boundary. A typed `f32` write there is UB, so write unaligned / byte-wise instead.
        ptr::write_unaligned(entry.cast::<pg_sys::ItemPointerData>(), *tid);
        ptr::copy_nonoverlapping(
            vector.as_ptr().cast::<u8>(),
            entry.add(tid_size),
            vector_bytes,
        );
    }

    pg_sys::MarkBufferDirty(buffer);
    pg_sys::UnlockReleaseBuffer(buffer);
    page_num
}
/// Reads the inverted-list page at `start_page`, as written by `write_inverted_list`.
///
/// Each entry is an `ItemPointerData` followed by `dimensions` `f32`s. Reading stops at the
/// first unused slot, or when no further full entry fits on the page. A `start_page` of `0`
/// means "no list" (the writer returns 0 for an empty list) and yields an empty vector.
///
/// # Safety
/// `index` must be a valid, open relation. When `start_page` is non-zero it must be an
/// inverted-list page written with the same `dimensions`.
pub unsafe fn read_inverted_list(
    index: pg_sys::Relation,
    start_page: u32,
    dimensions: usize,
) -> Vec<InvertedListEntry> {
    if start_page == 0 {
        return Vec::new();
    }
    let tid_size = size_of::<pg_sys::ItemPointerData>();
    let vector_bytes = dimensions * size_of::<f32>();
    let entry_size = tid_size + vector_bytes;
    let available = pg_sys::BLCKSZ as usize - size_of::<pg_sys::PageHeaderData>();
    let max_entries = available / entry_size;

    let buffer = pg_sys::ReadBuffer(index, start_page);
    pg_sys::LockBuffer(buffer, pg_sys::BUFFER_LOCK_SHARE as i32);
    let payload = pg_sys::BufferGetPage(buffer)
        .cast::<u8>()
        .add(size_of::<pg_sys::PageHeaderData>()) as *const u8;

    let mut entries = Vec::new();
    for slot in 0..max_entries {
        let entry = payload.add(slot * entry_size);
        let tid = ptr::read_unaligned(entry.cast::<pg_sys::ItemPointerData>());
        // The writer's PageInit zeroes the page, so the first unused slot has
        // ip_posid == 0 (InvalidOffsetNumber). Block 0 is a legitimate heap block, so the
        // block number must not be used as the terminator.
        if tid.ip_posid == 0 {
            break;
        }
        // Vectors are not 4-byte aligned on every slot, so copy bytes rather than building
        // an `f32` slice over the page.
        let mut vector = vec![0.0f32; dimensions];
        ptr::copy_nonoverlapping(
            entry.add(tid_size),
            vector.as_mut_ptr().cast::<u8>(),
            vector_bytes,
        );
        entries.push(InvertedListEntry { tid, vector });
    }

    pg_sys::UnlockReleaseBuffer(buffer);
    entries
}
/// Fetches attribute `attno` from `tuple` and decodes it as a vector.
///
/// Returns `None` when the attribute is SQL NULL. Otherwise the result of
/// `extract_vector_from_datum` is returned.
///
/// # Safety
/// `tuple` and `tuple_desc` must be valid and describe each other. `attno` must be an
/// attribute number accepted by `heap_getattr`.
pub unsafe fn extract_vector_from_tuple(
    tuple: *mut pg_sys::HeapTupleData,
    tuple_desc: pg_sys::TupleDesc,
    attno: i16,
) -> Option<Vec<f32>> {
    let mut attr_is_null = false;
    let value = pg_sys::heap_getattr(tuple, i32::from(attno), tuple_desc, &mut attr_is_null);
    if attr_is_null {
        None
    } else {
        extract_vector_from_datum(value)
    }
}
unsafe fn extract_vector_from_datum(datum: pg_sys::Datum) -> Option<Vec<f32>> {
if datum.is_null() {
return None;
}
let varlena = pg_sys::pg_detoast_datum_packed(datum.cast_mut_ptr());
let varlena_ptr = varlena as *const u8;
let header = ptr::read(varlena_ptr as *const u32);
let _data_size = (header >> 2) as usize;
let data_ptr = varlena_ptr.add(4);
let dimensions = ptr::read(data_ptr as *const u32) as usize;
let vector_ptr = data_ptr.add(4) as *const f32;
let vector = slice::from_raw_parts(vector_ptr, dimensions).to_vec();
Some(vector)
}
/// Builds a `palloc`-allocated varlena datum holding `vector`.
///
/// The layout is:
/// - a 4-byte varlena header holding `total_size << 2` (the little-endian 4-byte form);
/// - a `u32` dimension count;
/// - the `f32` components.
///
/// # Safety
/// Must be called where `palloc` is usable (inside a PostgreSQL backend). The memory
/// belongs to whatever memory context `palloc` allocates from.
pub unsafe fn create_vector_datum(vector: &[f32]) -> pg_sys::Datum {
    let dim_count = vector.len() as u32;
    let total_size = 4 + 4 + dim_count as usize * 4;
    let buf = pg_sys::palloc(total_size).cast::<u8>();

    ptr::write(buf.cast::<u32>(), (total_size as u32) << 2);
    ptr::write(buf.add(4).cast::<u32>(), dim_count);
    ptr::copy_nonoverlapping(vector.as_ptr(), buf.add(8).cast::<f32>(), vector.len());

    pg_sys::Datum::from(buf.cast::<::std::os::raw::c_void>())
}
/// C-ABI callback signature taking a heap tuple pointer and an untyped `context` pointer.
///
/// NOTE(review): not referenced in this part of the file — confirm it is still used.
pub type HeapScanCallback = unsafe extern "C" fn(
    tuple: *mut pg_sys::HeapTupleData,
    context: *mut std::ffi::c_void,
);
/// Presumably intended to scan `_heap` and call `_callback` with each tuple's TID and vector.
///
/// NOTE(review): this is currently a no-op stub. The body is empty, every parameter is
/// unused (hence the `_` prefixes), and the callback is never invoked, so callers receive
/// no vectors.
/// TODO: implement it (e.g. via the table AM's index-build scan) or remove it.
pub unsafe fn scan_heap_for_vectors(
_heap: pg_sys::Relation,
_index_info: *mut pg_sys::IndexInfo,
_callback: impl Fn(pg_sys::ItemPointerData, Vec<f32>),
) {
}
// Unit tests for the page serialization routines.
// NOTE(review): both tests are empty placeholders that assert nothing, so they always pass.
// The routines under test call into PostgreSQL (buffers, palloc) and likely need a
// backend-backed test harness (e.g. pgrx `#[pg_test]`) to be exercised.
#[cfg(test)]
mod tests {
// TODO: round-trip write_centroids/read_centroids, including dimensions above 60.
#[test]
fn test_centroid_serialization() {
}
// TODO: round-trip write/read_inverted_list, including TIDs on heap block 0.
#[test]
fn test_inverted_list_serialization() {
}
}