1use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
10use crate::dropper::DropHandle;
11use crate::error::ShardError;
12use crate::keyspace::ShardConfig;
13use crate::shard::{self, ShardHandle, ShardPersistenceConfig, ShardRequest, ShardResponse};
14
/// Request-channel capacity handed to each spawned shard (first argument to
/// `shard::spawn_shard`); bounds how many requests a shard can have queued.
const SHARD_BUFFER: usize = 256;
18
/// Configuration applied to every shard spawned by an [`Engine`].
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Per-shard configuration; cloned for each shard, with its `shard_id`
    /// overwritten to the shard's index.
    pub shard: ShardConfig,
    /// Optional persistence settings, cloned into every shard.
    pub persistence: Option<ShardPersistenceConfig>,
    /// Optional shared schema registry (available with the `protobuf` feature).
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
32
/// A fixed set of spawned shards plus key-hash routing over them.
///
/// Cloning an `Engine` clones the shard handles (not the shards themselves),
/// so clones address the same underlying shards.
#[derive(Debug, Clone)]
pub struct Engine {
    /// Handles to the spawned shards, indexed by shard id.
    shards: Vec<ShardHandle>,
    /// Shared schema registry, if one was supplied via [`EngineConfig`].
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
44
45impl Engine {
46 pub fn new(shard_count: usize) -> Self {
51 Self::with_config(shard_count, EngineConfig::default())
52 }
53
54 pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
61 assert!(shard_count > 0, "shard count must be at least 1");
62
63 let drop_handle = DropHandle::spawn();
64
65 let shards = (0..shard_count)
66 .map(|i| {
67 let mut shard_config = config.shard.clone();
68 shard_config.shard_id = i as u16;
69 shard::spawn_shard(
70 SHARD_BUFFER,
71 shard_config,
72 config.persistence.clone(),
73 Some(drop_handle.clone()),
74 #[cfg(feature = "protobuf")]
75 config.schema_registry.clone(),
76 )
77 })
78 .collect();
79
80 Self {
81 shards,
82 #[cfg(feature = "protobuf")]
83 schema_registry: config.schema_registry,
84 }
85 }
86
87 pub fn with_available_cores() -> Self {
91 Self::with_available_cores_config(EngineConfig::default())
92 }
93
94 pub fn with_available_cores_config(config: EngineConfig) -> Self {
97 let cores = std::thread::available_parallelism()
98 .map(|n| n.get())
99 .unwrap_or(1);
100 Self::with_config(cores, config)
101 }
102
103 #[cfg(feature = "protobuf")]
105 pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
106 self.schema_registry.as_ref()
107 }
108
109 pub fn shard_count(&self) -> usize {
111 self.shards.len()
112 }
113
114 pub async fn send_to_shard(
118 &self,
119 shard_idx: usize,
120 request: ShardRequest,
121 ) -> Result<ShardResponse, ShardError> {
122 if shard_idx >= self.shards.len() {
123 return Err(ShardError::Unavailable);
124 }
125 self.shards[shard_idx].send(request).await
126 }
127
128 pub async fn route(
130 &self,
131 key: &str,
132 request: ShardRequest,
133 ) -> Result<ShardResponse, ShardError> {
134 let idx = self.shard_for_key(key);
135 self.shards[idx].send(request).await
136 }
137
138 pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
144 where
145 F: Fn() -> ShardRequest,
146 {
147 let mut receivers = Vec::with_capacity(self.shards.len());
149 for shard in &self.shards {
150 receivers.push(shard.dispatch(make_req()).await?);
151 }
152
153 let mut results = Vec::with_capacity(receivers.len());
155 for rx in receivers {
156 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
157 }
158 Ok(results)
159 }
160
161 pub async fn route_multi<F>(
167 &self,
168 keys: &[String],
169 make_req: F,
170 ) -> Result<Vec<ShardResponse>, ShardError>
171 where
172 F: Fn(String) -> ShardRequest,
173 {
174 let mut receivers = Vec::with_capacity(keys.len());
175 for key in keys {
176 let idx = self.shard_for_key(key);
177 let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
178 receivers.push(rx);
179 }
180
181 let mut results = Vec::with_capacity(receivers.len());
182 for rx in receivers {
183 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
184 }
185 Ok(results)
186 }
187
188 fn shard_for_key(&self, key: &str) -> usize {
190 shard_index(key, self.shards.len())
191 }
192}
193
/// Maps `key` deterministically onto `0..shard_count` using the standard
/// library's default hasher.
///
/// The same key always resolves to the same index for a given `shard_count`.
/// `shard_count` must be non-zero, otherwise the modulo divides by zero.
fn shard_index(key: &str, shard_count: usize) -> usize {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    let digest = state.finish();
    digest as usize % shard_count
}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Hashing is deterministic: one key, one shard.
    #[test]
    fn same_key_same_shard() {
        assert_eq!(shard_index("foo", 8), shard_index("foo", 8));
    }

    /// A batch of distinct keys should not all collapse onto one shard.
    #[test]
    fn keys_spread_across_shards() {
        let seen: std::collections::HashSet<_> = (0..100)
            .map(|i| shard_index(&format!("key:{i}"), 4))
            .collect();
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    /// With a single shard, every key must map to index 0.
    #[test]
    fn single_shard_always_zero() {
        for key in ["anything", "other"] {
            assert_eq!(shard_index(key, 1), 0);
        }
    }

    /// Set then Get through key routing returns the stored value.
    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let set_req = ShardRequest::Set {
            key: "greeting".into(),
            value: Bytes::from("hello"),
            expire: None,
            nx: false,
            xx: false,
        };
        let resp = engine.route("greeting", set_req).await.unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let get_req = ShardRequest::Get {
            key: "greeting".into(),
        };
        match engine.route("greeting", get_req).await.unwrap() {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    /// Deletes across shards count only the keys that actually existed.
    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        for key in ["a", "b", "c", "d"] {
            let req = ShardRequest::Set {
                key: key.to_string(),
                value: Bytes::from("v"),
                expire: None,
                nx: false,
                xx: false,
            };
            engine.route(key, req).await.unwrap();
        }

        let mut deleted = 0i64;
        for key in ["a", "b", "c", "d", "missing"] {
            let req = ShardRequest::Del {
                key: key.to_string(),
            };
            let resp = engine.route(key, req).await.unwrap();
            if matches!(resp, ShardResponse::Bool(true)) {
                deleted += 1;
            }
        }
        assert_eq!(deleted, 4);
    }

    /// Constructing an engine with zero shards violates the precondition.
    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}