1use crate::partition::PartitionError;
8use crate::Result;
9use redb::ReadableTable;
10use std::marker::PhantomData;
11
12fn build_segment_prefix(base_key: &[u8], shard: u16) -> Result<Vec<u8>> {
15 let mut prefix = Vec::with_capacity(4 + base_key.len() + 2);
16
17 prefix.extend_from_slice(&(base_key.len() as u32).to_be_bytes());
19
20 prefix.extend_from_slice(base_key);
22
23 prefix.extend_from_slice(&shard.to_be_bytes());
25
26 Ok(prefix)
27}
28
/// One segment discovered while scanning a `(base_key, shard)` key range.
#[derive(Debug, Clone)]
pub struct SegmentInfo {
    /// Numeric id of the segment.
    pub segment_id: u16,
    /// The full encoded table key under which the segment is stored.
    pub segment_key: Vec<u8>,
    /// The stored value, if it was loaded together with the key.
    pub segment_data: Option<Vec<u8>>,
}

impl SegmentInfo {
    /// Builds a `SegmentInfo` that carries only the id and key, with no payload.
    pub fn new(segment_id: u16, segment_key: Vec<u8>) -> Self {
        SegmentInfo {
            segment_data: None,
            segment_key,
            segment_id,
        }
    }

    /// Builds a `SegmentInfo` that also carries the segment's stored value.
    pub fn with_data(segment_id: u16, segment_key: Vec<u8>, segment_data: Vec<u8>) -> Self {
        SegmentInfo {
            segment_data: Some(segment_data),
            ..SegmentInfo::new(segment_id, segment_key)
        }
    }
}
59
60pub fn enumerate_segments<'a, T>(
74 table: &'a T,
75 base_key: &[u8],
76 shard: u16,
77) -> Result<SegmentIterator<'a>>
78where
79 T: ReadableTable<&'static [u8], &'static [u8]>,
80{
81 let (start_key, end_key) = build_segment_scan_range(base_key, shard)?;
82 let range = table
83 .range(start_key.as_slice()..end_key.as_slice())
84 .map_err(|e| {
85 crate::error::Error::Partition(PartitionError::SegmentScanFailed(format!(
86 "Failed to create range iterator: {}",
87 e
88 )))
89 })?;
90
91 Ok(SegmentIterator {
92 range,
93 base_key: base_key.to_vec(),
94 shard,
95 _phantom: PhantomData,
96 })
97}
98
99pub fn find_head_segment<T>(table: &T, base_key: &[u8], shard: u16) -> Result<Option<u16>>
113where
114 T: ReadableTable<&'static [u8], &'static [u8]>,
115{
116 let mut iter = enumerate_segments(table, base_key, shard)?;
117 let mut head_segment = None;
118
119 while let Some(segment_result) = iter.next() {
120 let segment_info = segment_result?;
121 head_segment = Some(segment_info.segment_id);
122 }
123
124 Ok(head_segment)
125}
126
127fn build_segment_scan_range(base_key: &[u8], shard: u16) -> Result<(Vec<u8>, Vec<u8>)> {
139 let start_key = build_segment_prefix(base_key, shard)?;
140
141 let mut end_key = start_key.clone();
144 if let Some(last_byte) = end_key.last_mut() {
145 *last_byte = last_byte.saturating_add(1);
146 } else {
147 return Err(crate::error::Error::Partition(
148 PartitionError::SegmentScanFailed(
149 "Prefix key is empty, cannot create range".to_string(),
150 ),
151 ));
152 }
153
154 Ok((start_key, end_key))
155}
156
157fn extract_segment_id(encoded_key: &[u8]) -> Result<u16> {
168 if encoded_key.len() < 6 {
169 return Err(crate::error::Error::Partition(
171 PartitionError::SegmentScanFailed(
172 "Encoded key too short to contain segment ID".to_string(),
173 ),
174 ));
175 }
176
177 let segment_bytes = &encoded_key[encoded_key.len() - 2..];
178 Ok(u16::from_be_bytes([segment_bytes[0], segment_bytes[1]]))
179}
180
/// Reports whether `encoded_key` encodes `expected_base_key` and
/// `expected_shard`.
///
/// The key is read as `[len: u32 BE][base_key: len bytes][shard: u16 BE]`,
/// followed by at least two further bytes (the segment id read by
/// `extract_segment_id`). A key too short for the length its header announces
/// is rejected (`false`). A malformed or hostile length header never panics.
fn validate_key_match(encoded_key: &[u8], expected_base_key: &[u8], expected_shard: u16) -> bool {
    /// Splits out the base key and shard, or returns `None` if the key is too
    /// short. Offsets use checked arithmetic so that a header such as
    /// `0xFFFF_FFFF` cannot overflow `usize` on 32-bit targets, which would
    /// otherwise panic (debug) or produce an invalid slice range (release).
    fn parse(encoded_key: &[u8]) -> Option<(&[u8], u16)> {
        let header = encoded_key.get(..4)?;
        // u32 -> usize is lossless wherever usize is at least 32 bits wide.
        let key_len = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;

        let shard_start = key_len.checked_add(4)?;
        // The shard (2 bytes) and the segment id (2 bytes) must both follow.
        let required_len = shard_start.checked_add(4)?;
        if encoded_key.len() < required_len {
            return None;
        }

        let base_key = &encoded_key[4..shard_start];
        let shard = u16::from_be_bytes([encoded_key[shard_start], encoded_key[shard_start + 1]]);
        Some((base_key, shard))
    }

    match parse(encoded_key) {
        Some((base_key, shard)) => base_key == expected_base_key && shard == expected_shard,
        None => false,
    }
}
224
225pub struct SegmentIterator<'a> {
230 range: redb::Range<'a, &'static [u8], &'static [u8]>,
231 base_key: Vec<u8>,
232 shard: u16,
233 _phantom: PhantomData<()>,
234}
235
236impl<'a> Iterator for SegmentIterator<'a> {
237 type Item = Result<SegmentInfo>;
238
239 fn next(&mut self) -> Option<Self::Item> {
240 loop {
241 match self.range.next() {
242 Some(Ok((key_guard, value_guard))) => {
243 let key = key_guard.value();
244 let value = value_guard.value();
245
246 if !validate_key_match(key, &self.base_key, self.shard) {
248 continue; }
250
251 match extract_segment_id(key) {
253 Ok(segment_id) => {
254 let segment_info =
255 SegmentInfo::with_data(segment_id, key.to_vec(), value.to_vec());
256 return Some(Ok(segment_info));
257 }
258 Err(e) => return Some(Err(e)),
259 }
260 }
261 Some(Err(e)) => {
262 return Some(Err(PartitionError::SegmentScanFailed(format!(
263 "Database error during iteration: {}",
264 e
265 ))
266 .into()));
267 }
268 None => return None,
269 }
270 }
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    use redb::{Database, ReadableDatabase, TableDefinition};

    const TEST_TABLE: TableDefinition<&[u8], &[u8]> = TableDefinition::new("test_scan");

    /// Hand-encodes `[len u32 BE][base_key][shard u16 BE][segment u16 BE]`,
    /// the layout the scan helpers above parse.
    fn hand_encoded_key(base_key: &[u8], shard: u16, segment: u16) -> Vec<u8> {
        let mut key = Vec::new();
        key.extend_from_slice(&(base_key.len() as u32).to_be_bytes());
        key.extend_from_slice(base_key);
        key.extend_from_slice(&shard.to_be_bytes());
        key.extend_from_slice(&segment.to_be_bytes());
        key
    }

    /// Commits a single write transaction that opens `TEST_TABLE` (creating it)
    /// and stores a `segment_<id>` payload for every id in `segments`.
    fn write_segments(db: &Database, base_key: &[u8], shard: u16, segments: &[u16]) {
        let write_txn = db.begin_write().unwrap();
        {
            let mut table = write_txn.open_table(TEST_TABLE).unwrap();
            for &segment in segments {
                let segment_key =
                    crate::partition::table::encode_segment_key(base_key, shard, segment).unwrap();
                let segment_data = format!("segment_{}", segment).into_bytes();
                table.insert(&*segment_key, &*segment_data).unwrap();
            }
        }
        write_txn.commit().unwrap();
    }

    #[test]
    fn test_build_segment_scan_range() {
        let (start, end) = build_segment_scan_range(b"test_key", 42).unwrap();

        // The lower bound is exactly the shared segment prefix.
        assert_eq!(start, build_segment_prefix(b"test_key", 42).unwrap());

        // The upper bound differs only in its last byte, which is one greater.
        assert_eq!(end.len(), start.len());
        let (end_last, end_head) = end.split_last().unwrap();
        let (start_last, start_head) = start.split_last().unwrap();
        assert_eq!(end_head, start_head);
        assert_eq!(*end_last, *start_last + 1);
    }

    #[test]
    fn test_extract_segment_id() {
        let key = hand_encoded_key(b"test", 42, 123);
        assert_eq!(extract_segment_id(&key).unwrap(), 123);
    }

    #[test]
    fn test_extract_segment_id_invalid() {
        assert!(extract_segment_id(b"short").is_err());
    }

    #[test]
    fn test_validate_key_match() {
        let key = hand_encoded_key(b"test_key", 42, 123);

        assert!(validate_key_match(&key, b"test_key", 42));
        assert!(!validate_key_match(&key, b"wrong_key", 42));
        assert!(!validate_key_match(&key, b"test_key", 99));
    }

    #[test]
    fn test_enumerate_segments() {
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        let db = Database::create(temp_file.path()).unwrap();
        write_segments(&db, b"test_key", 42, &[0, 1, 2]);

        let read_txn = db.begin_read().unwrap();
        let table = read_txn.open_table(TEST_TABLE).unwrap();

        let segments: Vec<SegmentInfo> = enumerate_segments(&table, b"test_key", 42)
            .unwrap()
            .map(|segment| segment.unwrap())
            .collect();

        assert_eq!(segments.len(), 3);
        let ids: Vec<u16> = segments.iter().map(|segment| segment.segment_id).collect();
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[test]
    fn test_find_head_segment() {
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        let db = Database::create(temp_file.path()).unwrap();
        write_segments(&db, b"test_key", 42, &[0, 2, 5]);

        let read_txn = db.begin_read().unwrap();
        let table = read_txn.open_table(TEST_TABLE).unwrap();

        assert_eq!(find_head_segment(&table, b"test_key", 42).unwrap(), Some(5));
    }

    #[test]
    fn test_find_head_segment_empty() {
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        let db = Database::create(temp_file.path()).unwrap();
        // Only creates the table; nothing is inserted.
        write_segments(&db, b"nonexistent", 0, &[]);

        let read_txn = db.begin_read().unwrap();
        let table = read_txn.open_table(TEST_TABLE).unwrap();

        assert_eq!(find_head_segment(&table, b"nonexistent", 0).unwrap(), None);
    }

    #[test]
    fn test_segment_info() {
        let bare = SegmentInfo::new(42, b"test_key".to_vec());
        assert_eq!(bare.segment_id, 42);
        assert_eq!(bare.segment_key, b"test_key");
        assert!(bare.segment_data.is_none());

        let loaded = SegmentInfo::with_data(42, b"test_key".to_vec(), b"data".to_vec());
        assert_eq!(loaded.segment_data, Some(b"data".to_vec()));
    }
}