use std::fs::OpenOptions;
use std::io::Write;
use std::path::Path;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use membase::{MmapOptions, MmapMut};
/// Fixed-size header stored at byte 0 of the database file.
///
/// `#[repr(C)]` fixes field order and layout: the struct is viewed as raw bytes when a new
/// file is created and accessed through a raw pointer into the memory map afterwards.
/// The fields add up to 32 bytes with no padding.
#[repr(C)]
struct DbHeader {
    /// File signature; new files get `b"MMAPDBV1"` and `open` rejects anything else.
    magic: [u8; 8],
    /// Format version; new files are written with version 1.
    version: u32,
    /// Number of `Record` entries in the on-disk index.
    record_count: u32,
    /// Total bytes of key/value data stored directly after the header.
    data_size: u64,
    /// Absolute byte offset of the on-disk index; 0 until an index has been written.
    index_offset: u64,
}
/// One entry of the on-disk index, locating a key and its value inside the file.
///
/// `#[repr(C)]` keeps the 24-byte layout stable because entries are copied to and
/// from the memory map as raw bytes.
#[repr(C)]
struct Record {
    /// Length of the key in bytes.
    key_length: u32,
    /// Length of the value in bytes.
    value_length: u32,
    /// Absolute byte offset of the key within the file.
    key_offset: u64,
    /// Absolute byte offset of the value within the file.
    value_offset: u64,
}
/// Key/value store backed by a single memory-mapped file.
struct MmapDatabase {
    /// Path the database was opened from; the file is reopened by path when it must grow.
    file_path: String,
    /// Writable memory mapping of the database file.
    mmap: MmapMut,
    /// In-memory index: key -> (absolute value offset, value length in bytes).
    index: Arc<RwLock<HashMap<String, (u64, u32)>>>,
    /// Raw pointer to the `DbHeader` at the start of `mmap`; it must be refreshed
    /// whenever `mmap` is replaced.
    header: *mut DbHeader,
    /// Byte offset at which key/value data begins (the size of `DbHeader`).
    data_offset: u64,
}
impl MmapDatabase {
    /// File signature stored in `DbHeader::magic`.
    const MAGIC: [u8; 8] = *b"MMAPDBV1";
    /// Size of the on-disk header; key/value data starts right after it.
    const HEADER_SIZE: usize = std::mem::size_of::<DbHeader>();
    /// Size of one serialized `Record` in the on-disk index.
    const RECORD_SIZE: usize = std::mem::size_of::<Record>();
    /// Size given to a newly created database file (1 MiB).
    const INITIAL_SIZE: u64 = 1024 * 1024;

    /// Opens the database file at `path`, creating and initializing it when it is empty or missing.
    ///
    /// A new (zero-length) file gets a version-1 header and is extended to 1 MiB. An existing
    /// file must be at least header-sized and start with `MMAPDBV1`; its on-disk index is
    /// loaded into memory after every entry has been bounds-checked.
    ///
    /// # Errors
    /// I/O or mapping failures; `"Invalid database file format"` for a short file or a wrong
    /// magic; `"Corrupt database index"` when an index entry points outside the file, a key is
    /// not valid UTF-8, or a key does not immediately precede its value.
    fn open(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let file_path = path.to_string();
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(Path::new(path))?;

        // Decide by length rather than `Path::exists`: this avoids a check-then-open race and
        // also initializes a zero-length file left behind by an interrupted creation.
        let file_len = file.metadata()?.len();
        let is_new = file_len == 0;
        if is_new {
            println!("Creating new database file...");
            let header = DbHeader {
                magic: Self::MAGIC,
                version: 1,
                record_count: 0,
                data_size: 0,
                index_offset: 0,
            };
            // SAFETY: `DbHeader` is repr(C), made only of integers and a byte array, and has no
            // padding, so all HEADER_SIZE bytes are initialized and readable as `u8`.
            let header_bytes = unsafe {
                std::slice::from_raw_parts(&header as *const DbHeader as *const u8, Self::HEADER_SIZE)
            };
            file.write_all(header_bytes)?;
            file.set_len(Self::INITIAL_SIZE)?;
            file.sync_all()?;
        } else if file_len < Self::HEADER_SIZE as u64 {
            // Too short to hold a header: reading it through the mapping would be out of bounds.
            return Err("Invalid database file format".into());
        }

        // SAFETY: the mapping is only accessed through this struct; concurrent modification of
        // the file by other processes is not guarded against.
        let mut mmap = unsafe { MmapOptions::new().write(true).map_mut(&file)? };
        // Derived from a mutable slice because the header is written through this pointer.
        // The mapping starts at file offset 0, so the header is suitably aligned.
        let header = mmap[0..Self::HEADER_SIZE].as_mut_ptr() as *mut DbHeader;

        let mut entries = HashMap::new();
        if !is_new {
            // SAFETY: `header` points at HEADER_SIZE in-bounds bytes of the live mapping.
            let (magic, record_count, index_offset) =
                unsafe { ((*header).magic, (*header).record_count, (*header).index_offset) };
            if magic != Self::MAGIC {
                return Err("Invalid database file format".into());
            }
            println!("Loading existing database...");
            if record_count > 0 && index_offset > 0 {
                let limit = mmap.len();
                let table_len = u64::from(record_count) * Self::RECORD_SIZE as u64;
                let table = Self::checked_span(index_offset, table_len, limit)?;
                for raw in mmap[table].chunks_exact(Self::RECORD_SIZE) {
                    // SAFETY: `raw` is exactly size_of::<Record>() readable bytes. The index
                    // follows variable-length data, so it is read without assuming alignment.
                    let record = unsafe { std::ptr::read_unaligned(raw.as_ptr() as *const Record) };
                    let key_span =
                        Self::checked_span(record.key_offset, u64::from(record.key_length), limit)?;
                    Self::checked_span(record.value_offset, u64::from(record.value_length), limit)?;
                    // `put` always writes the key immediately before its value; `flush_index`
                    // relies on that, so reject entries that break it.
                    if key_span.end as u64 != record.value_offset {
                        return Err("Corrupt database index".into());
                    }
                    let key = String::from_utf8(mmap[key_span].to_vec())
                        .map_err(|_| "Corrupt database index")?;
                    entries.insert(key, (record.value_offset, record.value_length));
                }
            }
        }

        Ok(MmapDatabase {
            file_path,
            mmap,
            index: Arc::new(RwLock::new(entries)),
            header,
            data_offset: Self::HEADER_SIZE as u64,
        })
    }

    /// Converts an on-disk `(offset, length)` pair into a byte range, failing with
    /// `"Corrupt database index"` if the range overflows or extends past `limit`.
    fn checked_span(
        offset: u64,
        length: u64,
        limit: usize,
    ) -> Result<std::ops::Range<usize>, Box<dyn std::error::Error>> {
        let end = offset
            .checked_add(length)
            .filter(|&end| end <= limit as u64)
            .ok_or("Corrupt database index")?;
        // Both bounds are <= limit, so the casts to usize are lossless.
        Ok(offset as usize..end as usize)
    }

    /// Returns a copy of the value stored under `key`, or `None` if the key is absent.
    fn get(&self, key: &str) -> Option<Vec<u8>> {
        let &(offset, length) = self.index.read().unwrap().get(key)?;
        let start = offset as usize;
        Some(self.mmap[start..start + length as usize].to_vec())
    }

    /// Stores `value` under `key` and rewrites the on-disk index.
    ///
    /// If `key` already exists and its slot can hold `value`, the value is overwritten in
    /// place. Otherwise the key bytes followed by the value bytes are appended after the
    /// existing data, growing the file if needed.
    ///
    /// # Errors
    /// `"Key or value too large"` when a length does not fit the u32 fields of `Record`;
    /// I/O or mapping errors from growing or flushing the file.
    fn put(&mut self, key: &str, value: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
        let key_bytes = key.as_bytes();
        // `Record` stores lengths as u32; reject rather than silently truncate.
        if key_bytes.len() > u32::MAX as usize || value.len() > u32::MAX as usize {
            return Err("Key or value too large".into());
        }
        let value_len = value.len() as u32;

        // In-place overwrite when the existing slot is large enough.
        let existing = self.index.read().unwrap().get(key).copied();
        if let Some((offset, length)) = existing {
            if length >= value_len {
                let start = offset as usize;
                self.mmap[start..start + value.len()].copy_from_slice(value);
                if length != value_len {
                    self.index.write().unwrap().insert(key.to_string(), (offset, value_len));
                    // The on-disk index still records the old length. Rewrite it so a reopened
                    // database does not return stale trailing bytes.
                    self.flush_index()?;
                }
                return Ok(());
            }
        }

        // Append layout: key bytes immediately followed by value bytes, after existing data.
        // SAFETY: `header` points at the header of the live mapping.
        let data_size = unsafe { (*self.header).data_size };
        let key_offset = self.data_offset + data_size;
        let value_offset = key_offset + key_bytes.len() as u64;
        let data_end = value_offset + value.len() as u64;

        // `flush_index` writes the index right after the data, with at most one more entry
        // than the index holds now; the header bytes are included via `data_offset`.
        let entry_count = self.index.read().unwrap().len() + 1;
        let required = data_end as usize + entry_count * Self::RECORD_SIZE;
        let current_size = self.mmap.len();
        if required > current_size {
            let new_size = (current_size * 2).max(required);
            println!("Resizing database from {} to {} bytes", current_size, new_size);
            self.grow(new_size)?;
        }

        self.mmap[key_offset as usize..value_offset as usize].copy_from_slice(key_bytes);
        self.mmap[value_offset as usize..data_end as usize].copy_from_slice(value);
        // SAFETY: `grow` refreshed `header` if the mapping was replaced.
        unsafe {
            (*self.header).data_size = data_end - self.data_offset;
        }
        self.index.write().unwrap().insert(key.to_string(), (value_offset, value_len));
        self.flush_index()
    }

    /// Resizes the database file to `new_size` bytes and remaps it, refreshing `self.header`.
    ///
    /// The old mapping is released before the file is resized. While the file is unmapped,
    /// `self.header` points at a copy of the header in a small anonymous mapping, so it never
    /// dangles. The file is remapped even if resizing fails, so the database keeps working
    /// at its old size.
    fn grow(&mut self, new_size: usize) -> Result<(), Box<dyn std::error::Error>> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&self.file_path)?;

        let mut placeholder = unsafe { MmapMut::map_anon(Self::HEADER_SIZE)? };
        placeholder[0..Self::HEADER_SIZE].copy_from_slice(&self.mmap[0..Self::HEADER_SIZE]);
        self.header = placeholder[0..Self::HEADER_SIZE].as_mut_ptr() as *mut DbHeader;
        drop(std::mem::replace(&mut self.mmap, placeholder));

        let resized = file.set_len(new_size as u64).and_then(|_| file.sync_all());
        // SAFETY: same conditions as the mapping created in `open`.
        self.mmap = unsafe { MmapOptions::new().write(true).map_mut(&file)? };
        self.header = self.mmap[0..Self::HEADER_SIZE].as_mut_ptr() as *mut DbHeader;
        resized?;
        Ok(())
    }

    /// Rewrites the on-disk index directly after the key/value data, then flushes the mapping.
    ///
    /// Entries are written first. The header's `index_offset` and `record_count` are updated
    /// afterwards, so the header never points at a partially written index.
    ///
    /// # Errors
    /// Fails if the mapping is too small to hold the index, or if flushing fails.
    fn flush_index(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let index = self.index.read().unwrap();
        // SAFETY: `header` points at the header of the live mapping.
        let index_offset = self.data_offset + unsafe { (*self.header).data_size };
        let start = index_offset as usize;
        let end = start + index.len() * Self::RECORD_SIZE;
        if end > self.mmap.len() {
            return Err("Database file too small for index".into());
        }

        let slots = self.mmap[start..end].chunks_exact_mut(Self::RECORD_SIZE);
        for ((key, &(value_offset, value_length)), slot) in index.iter().zip(slots) {
            let record = Record {
                key_length: key.len() as u32,
                value_length,
                // Every key is stored immediately before its value (see `put`/`open`).
                key_offset: value_offset - key.len() as u64,
                value_offset,
            };
            // SAFETY: `slot` is exactly size_of::<Record>() writable bytes inside the mapping.
            // The index follows variable-length data, so no alignment is assumed.
            unsafe { std::ptr::write_unaligned(slot.as_mut_ptr() as *mut Record, record) };
        }

        // SAFETY: `header` points at the header of the live mapping.
        unsafe {
            (*self.header).index_offset = index_offset;
            (*self.header).record_count = index.len() as u32;
        }
        self.mmap.flush()?;
        Ok(())
    }

    /// Flushes the memory mapping to the underlying file.
    fn flush(&self) -> Result<(), Box<dyn std::error::Error>> {
        self.mmap.flush()?;
        Ok(())
    }

    /// Number of index entries recorded in the header.
    fn record_count(&self) -> u32 {
        // SAFETY: `header` points at the header of the live mapping.
        unsafe { (*self.header).record_count }
    }

    /// Bytes of key/value data recorded in the header (excluding the header and index).
    fn data_size(&self) -> u64 {
        // SAFETY: `header` points at the header of the live mapping.
        unsafe { (*self.header).data_size }
    }
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("membase Simple Database Example");
println!("===============================");
let mut db = MmapDatabase::open("example_db.mmdb")?;
println!("Database opened successfully");
println!(" Record count: {}", db.record_count());
println!(" Data size: {} bytes", db.data_size());
println!("\nStoring values...");
db.put("name", b"membase Database")?;
db.put("version", b"1.0.0")?;
db.put("author", b"membase Team")?;
db.put("description", b"A high-performance memory-mapped database example")?;
db.flush()?;
println!("\nRetrieving values:");
for key in &["name", "version", "author", "description"] {
if let Some(value) = db.get(key) {
let value_str = String::from_utf8_lossy(&value);
println!(" {}: {}", key, value_str);
} else {
println!(" {}: <not found>", key);
}
}
println!("\nUpdating a value...");
db.put("version", b"1.0.1")?;
if let Some(value) = db.get("version") {
let value_str = String::from_utf8_lossy(&value);
println!(" version: {}", value_str);
}
println!("\nStoring a large value...");
let large_value = vec![42u8; 1024 * 1024]; db.put("large_value", &large_value)?;
if let Some(value) = db.get("large_value") {
println!(" large_value: {} bytes, first byte: {}", value.len(), value[0]);
}
println!("\nFinal database statistics:");
println!(" Record count: {}", db.record_count());
println!(" Data size: {} bytes", db.data_size());
println!("\nExample completed successfully!");
Ok(())
}