1use apfsds_protocol::{ConnMeta, ConnRecord};
4use parking_lot::RwLock;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use thiserror::Error;
8
9use crate::{BLinkTree, Segment, SegmentPtr};
10
11#[derive(Error, Debug)]
12pub enum StorageError {
13 #[error("Segment full")]
14 SegmentFull,
15
16 #[error("Record not found")]
17 NotFound,
18
19 #[error("Serialization error: {0}")]
20 SerializationError(String),
21}
22
/// Tuning knobs for a [`StorageEngine`].
#[derive(Clone, Debug)]
pub struct StorageConfig {
    /// Limit passed to `Segment::with_size_limit` for every new segment
    /// (default 10 MiB; presumably measured in bytes — confirm against `Segment`).
    pub segment_size_limit: usize,

    /// Once the number of sealed segments exceeds this value, segment
    /// rotation logs a "compaction threshold reached" message (default 10).
    pub compaction_threshold: usize,

    /// Cleanup period (default 300). Not read anywhere in this file;
    /// presumably seconds — confirm with the consumer.
    pub cleanup_interval: u64,
}
35
36impl Default for StorageConfig {
37 fn default() -> Self {
38 Self {
39 segment_size_limit: 10 * 1024 * 1024, compaction_threshold: 10,
41 cleanup_interval: 300, }
43 }
44}
45
46pub struct StorageEngine {
48 active_segment: RwLock<Segment>,
50
51 sealed_segments: RwLock<Vec<Segment>>,
53
54 index: Arc<BLinkTree>,
56
57 txid_counter: AtomicU64,
59
60 config: StorageConfig,
62}
63
64impl StorageEngine {
65 pub fn new(config: StorageConfig) -> Self {
67 let segment = Segment::with_size_limit(config.segment_size_limit);
68
69 Self {
70 active_segment: RwLock::new(segment),
71 sealed_segments: RwLock::new(Vec::new()),
72 index: Arc::new(BLinkTree::new()),
73 txid_counter: AtomicU64::new(1),
74 config,
75 }
76 }
77
78 pub fn next_txid(&self) -> u64 {
80 self.txid_counter.fetch_add(1, Ordering::SeqCst)
81 }
82
83 pub fn upsert(&self, conn_id: u64, metadata: ConnMeta) -> Result<u64, StorageError> {
85 let txid = self.next_txid();
86 let now = std::time::SystemTime::now()
87 .duration_since(std::time::UNIX_EPOCH)
88 .unwrap()
89 .as_millis() as u64;
90
91 let record = ConnRecord {
92 conn_id,
93 metadata,
94 created_at: now,
95 last_active: now,
96 access_count: 1,
97 txid,
98 };
99
100 let mut segment = self.active_segment.write();
102 let offset = segment.append(&record);
103
104 match offset {
105 Some(offset) => {
106 let ptr = SegmentPtr {
107 segment_id: segment.id,
108 offset,
109 };
110 self.index.insert(conn_id, ptr);
111 Ok(txid)
112 }
113 None => {
114 drop(segment);
116 self.rotate_segment()?;
117
118 let mut segment = self.active_segment.write();
120 let offset = segment.append(&record).ok_or(StorageError::SegmentFull)?;
121
122 let ptr = SegmentPtr {
123 segment_id: segment.id,
124 offset,
125 };
126 self.index.insert(conn_id, ptr);
127 Ok(txid)
128 }
129 }
130 }
131
132 pub fn get(&self, conn_id: u64) -> Option<ConnRecord> {
134 let ptr = self.index.search(conn_id)?;
135
136 let active = self.active_segment.read();
138 if ptr.segment_id == active.id {
139 return active.read_at(ptr.offset);
140 }
141 drop(active);
142
143 let sealed = self.sealed_segments.read();
145 for segment in sealed.iter() {
146 if ptr.segment_id == segment.id {
147 return segment.read_at(ptr.offset);
148 }
149 }
150
151 None
152 }
153
154 pub fn delete(&self, conn_id: u64) -> Option<SegmentPtr> {
156 self.index.remove(conn_id)
157 }
158
159 fn rotate_segment(&self) -> Result<(), StorageError> {
161 let mut active = self.active_segment.write();
162 let mut sealed = self.sealed_segments.write();
163
164 let mut old_segment = std::mem::replace(
166 &mut *active,
167 Segment::with_size_limit(self.config.segment_size_limit),
168 );
169 old_segment.seal();
170
171 sealed.push(old_segment);
172
173 if sealed.len() > self.config.compaction_threshold {
175 tracing::info!(
178 "Compaction threshold reached: {} sealed segments",
179 sealed.len()
180 );
181 }
182
183 Ok(())
184 }
185
186 pub fn stats(&self) -> StorageStats {
188 let active = self.active_segment.read();
189 let sealed = self.sealed_segments.read();
190
191 StorageStats {
192 active_segment_size: active.size(),
193 active_record_count: active.record_count(),
194 sealed_segment_count: sealed.len(),
195 total_indexed: self.index.len(),
196 }
197 }
198}
199
/// Size counters returned by [`StorageEngine::stats`].
#[derive(Clone, Debug)]
pub struct StorageStats {
    /// Value reported by `size()` on the active segment.
    pub active_segment_size: usize,
    /// Value reported by `record_count()` on the active segment.
    pub active_record_count: usize,
    /// Number of segments retired by rotation.
    pub sealed_segment_count: usize,
    /// Number of entries in the connection index.
    pub total_indexed: usize,
}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixed `ConnMeta` for loopback client 127.0.0.1.
    fn make_meta() -> ConnMeta {
        ConnMeta {
            client_addr: [127, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            nat_entry: (1234, 5678),
            assigned_pod: 1,
            stream_states: vec![],
        }
    }

    #[test]
    fn test_upsert_and_get() {
        let engine = StorageEngine::new(StorageConfig::default());

        engine.upsert(42, make_meta()).unwrap();

        let record = engine.get(42).unwrap();
        assert_eq!(record.conn_id, 42);
    }

    #[test]
    fn test_get_missing_returns_none() {
        let engine = StorageEngine::new(StorageConfig::default());
        assert!(engine.get(7).is_none());
    }

    #[test]
    fn test_txids_are_sequential() {
        let engine = StorageEngine::new(StorageConfig::default());

        // The counter starts at 1 and `fetch_add` hands out the previous value.
        let first = engine.upsert(1, make_meta()).unwrap();
        let second = engine.upsert(2, make_meta()).unwrap();
        assert_eq!(first, 1);
        assert_eq!(second, first + 1);

        // The stored record carries the txid returned by its write.
        assert_eq!(engine.get(2).unwrap().txid, second);
    }

    #[test]
    fn test_delete() {
        let engine = StorageEngine::new(StorageConfig::default());

        engine.upsert(42, make_meta()).unwrap();
        assert!(engine.get(42).is_some());

        engine.delete(42);
        assert!(engine.get(42).is_none());
    }

    #[test]
    fn test_stats() {
        let engine = StorageEngine::new(StorageConfig::default());

        for i in 0..10 {
            engine.upsert(i, make_meta()).unwrap();
        }

        let stats = engine.stats();
        assert_eq!(stats.total_indexed, 10);
        assert_eq!(stats.active_record_count, 10);
    }
}