use std::collections::HashMap;
use std::sync::Arc;
use xlog_core::Schema;
use xlog_cuda::{CudaBuffer, CudaKernelProvider};
/// Name-keyed collection of device-resident relations, each tagged with a version number.
///
/// Every operation that stores a relation or hands out mutable access to one (`put`,
/// `get_mut`, `get_or_insert_empty_mut`, and the insertion path of `get_or_insert_empty`)
/// assigns that relation a fresh version taken from a single store-wide counter. The counter
/// is never reset. A version observed through [`RelationStore::version`] or
/// [`RelationStore::get_with_version`] is therefore never handed out again. This holds even
/// after the relation is removed (or the store is cleared) and a different buffer is stored
/// under the same name.
pub struct RelationStore {
    provider: Arc<CudaKernelProvider>,
    relations: HashMap<String, VersionedCudaBuffer>,
    /// Most recently issued version; `0` means no version has been issued yet.
    last_version: u64,
}

/// A stored relation together with the version it was last assigned.
struct VersionedCudaBuffer {
    buffer: CudaBuffer,
    version: u64,
}

impl RelationStore {
    /// Creates an empty store that creates missing relations through `provider`.
    pub fn new(provider: Arc<CudaKernelProvider>) -> Self {
        Self {
            provider,
            relations: HashMap::new(),
            last_version: 0,
        }
    }

    /// Advances the store-wide version counter and returns the newly issued version.
    ///
    /// Takes the counter field rather than `&mut self` so it can be called while a borrow
    /// into `relations` is alive.
    fn next_version(counter: &mut u64) -> u64 {
        // Saturating rather than wrapping: wrapping would re-issue old versions, and
        // exhausting a u64 is not reachable in practice.
        *counter = counter.saturating_add(1);
        *counter
    }

    /// Returns the relation stored under `name`, if any. Does not change its version.
    pub fn get(&self, name: &str) -> Option<&CudaBuffer> {
        self.relations.get(name).map(|e| &e.buffer)
    }

    /// Returns mutable access to the relation stored under `name`, if any.
    ///
    /// The relation is assigned a new version before the reference is returned, because the
    /// caller may modify the buffer through it.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut CudaBuffer> {
        let entry = self.relations.get_mut(name)?;
        entry.version = Self::next_version(&mut self.last_version);
        Some(&mut entry.buffer)
    }

    /// Returns the relation stored under `name` together with its current version.
    pub fn get_with_version(&self, name: &str) -> Option<(&CudaBuffer, u64)> {
        self.relations.get(name).map(|e| (&e.buffer, e.version))
    }

    /// Returns the current version of the relation stored under `name`, if any.
    pub fn version(&self, name: &str) -> Option<u64> {
        self.relations.get(name).map(|e| e.version)
    }

    /// Stores `buffer` under `name` with a fresh version.
    ///
    /// Any previous relation of that name is replaced; the old buffer is dropped.
    pub fn put(&mut self, name: &str, buffer: CudaBuffer) {
        let version = Self::next_version(&mut self.last_version);
        self.relations
            .insert(name.to_string(), VersionedCudaBuffer { buffer, version });
    }

    /// Inserts an empty relation with `schema` under `name` unless one is already present.
    ///
    /// An existing relation is left untouched. Its schema is not compared against `schema`,
    /// and its version is not changed.
    ///
    /// # Errors
    /// Propagates the error from `CudaKernelProvider::create_empty_buffer`.
    fn ensure_relation(&mut self, name: &str, schema: &Schema) -> xlog_core::Result<()> {
        if !self.relations.contains_key(name) {
            let buffer = self.provider.create_empty_buffer(schema.clone())?;
            let version = Self::next_version(&mut self.last_version);
            self.relations
                .insert(name.to_string(), VersionedCudaBuffer { buffer, version });
        }
        Ok(())
    }

    /// Returns the relation stored under `name`, first inserting an empty one with `schema`
    /// if it is missing.
    ///
    /// A newly inserted relation receives a fresh version. An existing relation keeps its
    /// version, and its schema is not checked against `schema`.
    ///
    /// # Errors
    /// Fails if creating the empty buffer fails; the store is then left unchanged.
    pub fn get_or_insert_empty(
        &mut self,
        name: &str,
        schema: &Schema,
    ) -> xlog_core::Result<&CudaBuffer> {
        self.ensure_relation(name, schema)?;
        Ok(self
            .get(name)
            .expect("Relation must exist after insertion"))
    }

    /// Mutable counterpart of [`RelationStore::get_or_insert_empty`].
    ///
    /// As with [`RelationStore::get_mut`], the returned relation is assigned a new version.
    /// This applies whether or not it was just inserted.
    ///
    /// # Errors
    /// Fails if creating the empty buffer fails; the store is then left unchanged.
    pub fn get_or_insert_empty_mut(
        &mut self,
        name: &str,
        schema: &Schema,
    ) -> xlog_core::Result<&mut CudaBuffer> {
        self.ensure_relation(name, schema)?;
        Ok(self
            .get_mut(name)
            .expect("Relation must exist after insertion"))
    }

    /// Returns `true` if a relation is stored under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.relations.contains_key(name)
    }

    /// Removes and returns the relation stored under `name`, if any.
    ///
    /// The version counter is not rewound. A later `put` under the same name therefore gets
    /// a version that was never issued before.
    pub fn remove(&mut self, name: &str) -> Option<CudaBuffer> {
        self.relations.remove(name).map(|e| e.buffer)
    }

    /// Removes every relation.
    ///
    /// The version counter is kept, so no previously issued version is reused.
    pub fn clear(&mut self) {
        self.relations.clear();
    }

    /// Number of stored relations.
    pub fn len(&self) -> usize {
        self.relations.len()
    }

    /// Returns `true` if no relations are stored.
    pub fn is_empty(&self) -> bool {
        self.relations.is_empty()
    }

    /// Iterates over the names of all stored relations, in unspecified order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.relations.keys().map(|s| s.as_str())
    }
}
#[cfg(test)]
mod tests {
    // Every test needs a CUDA device. When none is available, `setup_store` returns `None`
    // and the test returns early, logging why it was skipped.
    use super::*;
    use std::sync::Arc;
    use xlog_core::{MemoryBudget, ScalarType};
    use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};

    /// Builds a kernel provider on device 0 with a 1 GiB memory budget.
    ///
    /// Returns `None`, after logging the reason, if the device or the provider cannot be
    /// created.
    fn setup_provider() -> Option<Arc<CudaKernelProvider>> {
        let device = match CudaDevice::new(0) {
            Ok(d) => Arc::new(d),
            Err(e) => {
                eprintln!("Skipping: CUDA runtime unavailable: {}", e);
                return None;
            }
        };
        let memory = Arc::new(GpuMemoryManager::new(
            device.clone(),
            MemoryBudget::with_limit(1024 * 1024 * 1024),
        ));
        match CudaKernelProvider::new(device, memory) {
            Ok(provider) => Some(Arc::new(provider)),
            Err(e) => {
                // Previously swallowed via `.ok()`, which made tests pass silently.
                eprintln!("Skipping: CUDA kernel provider unavailable: {}", e);
                None
            }
        }
    }

    /// Returns an empty store plus a handle to the provider backing it.
    fn setup_store() -> Option<(RelationStore, Arc<CudaKernelProvider>)> {
        let provider = setup_provider()?;
        Some((RelationStore::new(Arc::clone(&provider)), provider))
    }

    /// Two-column schema (`a: U32`, `b: U64`) shared by several tests.
    fn test_schema() -> Schema {
        Schema::new(vec![
            ("a".to_string(), ScalarType::U32),
            ("b".to_string(), ScalarType::U64),
        ])
    }

    /// Creates an empty relation buffer with a zero-column schema.
    fn empty_unit_buffer(provider: &CudaKernelProvider) -> CudaBuffer {
        provider
            .create_empty_buffer(Schema::new(vec![]))
            .expect("empty")
    }

    /// Reads the buffer's device-side row counter back to the host.
    fn device_row_count(provider: &CudaKernelProvider, buffer: &CudaBuffer) -> u32 {
        let mut host_rows = [0u32];
        provider
            .device()
            .inner()
            .dtoh_sync_copy_into(buffer.num_rows_device(), &mut host_rows)
            .expect("dtoh row count");
        host_rows[0]
    }

    /// Builds a buffer with `schema` holding `rows` zero-filled rows.
    fn make_buffer(provider: &CudaKernelProvider, schema: Schema, rows: usize) -> CudaBuffer {
        if rows == 0 {
            return provider.create_empty_buffer(schema).expect("empty buffer");
        }
        if schema.arity() == 0 {
            // No column data to upload; only the device row counter carries the row count.
            let rows_u32 = u32::try_from(rows).expect("row count fits u32");
            let mut d_num_rows = provider.memory().alloc::<u32>(1).expect("alloc");
            provider
                .device()
                .inner()
                .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
                .expect("htod row count");
            return CudaBuffer::from_columns(Vec::new(), u64::from(rows_u32), d_num_rows, schema);
        }
        let columns: Vec<Vec<u8>> = (0..schema.arity())
            .map(|col_idx| {
                let size = schema
                    .column_type(col_idx)
                    .map(|t| t.size_bytes())
                    .unwrap_or(4);
                vec![0u8; rows * size]
            })
            .collect();
        let slices: Vec<&[u8]> = columns.iter().map(Vec::as_slice).collect();
        provider
            .create_buffer_from_slices(&slices, schema)
            .expect("buffer")
    }

    #[test]
    fn test_new_store_is_empty() {
        let Some((store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_put_and_get() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("test_rel", empty_unit_buffer(&provider));
        assert!(store.contains("test_rel"));
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
        assert!(store.get("test_rel").is_some());
    }

    #[test]
    fn test_get_nonexistent() {
        let Some((store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.get("nonexistent").is_none());
    }

    #[test]
    fn test_contains() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        assert!(!store.contains("test"));
        store.put("test", empty_unit_buffer(&provider));
        assert!(store.contains("test"));
        assert!(!store.contains("other"));
    }

    #[test]
    fn test_remove() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("test", empty_unit_buffer(&provider));
        assert!(store.contains("test"));
        assert!(store.remove("test").is_some());
        assert!(!store.contains("test"));
        assert!(store.is_empty());
    }

    #[test]
    fn test_remove_nonexistent() {
        let Some((mut store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.remove("nonexistent").is_none());
    }

    #[test]
    fn test_clear() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        for name in ["rel1", "rel2", "rel3"] {
            store.put(name, empty_unit_buffer(&provider));
        }
        assert_eq!(store.len(), 3);
        store.clear();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_get_or_insert_empty_existing() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();
        store.put("existing", make_buffer(&provider, schema.clone(), 100));
        let result = store.get_or_insert_empty("existing", &schema).unwrap();
        assert_eq!(device_row_count(&provider, result), 100);
        assert_eq!(result.schema(), &schema);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_get_or_insert_empty_nonexistent() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();
        assert!(store.is_empty());
        let result = store.get_or_insert_empty("nonexistent", &schema).unwrap();
        assert_eq!(device_row_count(&provider, result), 0);
        assert_eq!(result.schema(), &schema);
        assert!(result.is_empty());
        assert!(store.contains("nonexistent"));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_get_mut() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("test", make_buffer(&provider, Schema::new(vec![]), 10));
        {
            let buf_mut = store.get_mut("test").unwrap();
            buf_mut.row_cap = 50;
            provider
                .device()
                .inner()
                .htod_sync_copy_into(&[50u32], &mut buf_mut.d_num_rows)
                .expect("htod row count");
        }
        assert_eq!(device_row_count(&provider, store.get("test").unwrap()), 50);
    }

    #[test]
    fn test_get_mut_nonexistent() {
        let Some((mut store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.get_mut("nonexistent").is_none());
    }

    #[test]
    fn test_get_or_insert_empty_mut() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();
        {
            let buf_mut = store.get_or_insert_empty_mut("new_rel", &schema).unwrap();
            assert_eq!(device_row_count(&provider, buf_mut), 0);
            buf_mut.row_cap = 42;
            provider
                .device()
                .inner()
                .htod_sync_copy_into(&[42u32], &mut buf_mut.d_num_rows)
                .expect("htod row count");
        }
        assert!(store.contains("new_rel"));
        assert_eq!(
            device_row_count(&provider, store.get("new_rel").unwrap()),
            42
        );
    }

    #[test]
    fn test_put_replaces_existing() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let buffer1 = make_buffer(&provider, Schema::new(vec![]), 10);
        let buffer2 = make_buffer(&provider, Schema::new(vec![]), 20);
        store.put("test", buffer1);
        assert_eq!(device_row_count(&provider, store.get("test").unwrap()), 10);
        store.put("test", buffer2);
        assert_eq!(device_row_count(&provider, store.get("test").unwrap()), 20);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_names_iterator() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        for name in ["alpha", "beta", "gamma"] {
            store.put(name, empty_unit_buffer(&provider));
        }
        let mut names: Vec<&str> = store.names().collect();
        names.sort();
        assert_eq!(names, vec!["alpha", "beta", "gamma"]);
    }

    #[test]
    fn test_multiple_operations() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        for name in ["a", "b", "c"] {
            store.put(name, empty_unit_buffer(&provider));
        }
        assert_eq!(store.len(), 3);
        store.remove("b");
        assert_eq!(store.len(), 2);
        assert!(!store.contains("b"));
        store.put("a", make_buffer(&provider, Schema::new(vec![]), 50));
        assert_eq!(store.len(), 2);
        assert_eq!(device_row_count(&provider, store.get("a").unwrap()), 50);
        store.clear();
        assert!(store.is_empty());
    }

    #[test]
    fn test_version_tracking() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        assert_eq!(store.version("rel"), None);
        assert!(store.get_with_version("rel").is_none());

        store.put("rel", empty_unit_buffer(&provider));
        let v1 = store.version("rel").expect("version after put");

        // Shared reads leave the version unchanged.
        assert!(store.get("rel").is_some());
        assert_eq!(store.version("rel"), Some(v1));

        // Mutable access assigns a newer version.
        assert!(store.get_mut("rel").is_some());
        let v2 = store.version("rel").expect("version after get_mut");
        assert!(v2 > v1);
        let (_, seen) = store.get_with_version("rel").expect("relation present");
        assert_eq!(seen, v2);

        // Replacing the buffer assigns a newer version.
        store.put("rel", empty_unit_buffer(&provider));
        let v3 = store.version("rel").expect("version after second put");
        assert!(v3 > v2);
    }

    #[test]
    fn test_get_or_insert_empty_versions() {
        let Some((mut store, _provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();
        store.get_or_insert_empty("rel", &schema).unwrap();
        let v1 = store.version("rel").expect("version after insertion");

        // Shared access to an existing relation does not change its version.
        store.get_or_insert_empty("rel", &schema).unwrap();
        assert_eq!(store.version("rel"), Some(v1));

        // Mutable access does.
        store.get_or_insert_empty_mut("rel", &schema).unwrap();
        assert!(store.version("rel").expect("version after mutable access") > v1);
    }
}