1use std::fmt::Debug;
2use std::sync::Arc;
3mod hashmap;
4mod manager;
5
6use crate::error::FlareError;
7
8use bytes::Bytes;
9pub use hashmap::HashMapShard;
10pub use hashmap::HashMapShardFactory;
11pub use manager::ShardManager;
12use scc::HashMap;
13
/// Numeric identifier of a shard; it is the key under which
/// `ShardManager2` registers shards (taken from `ShardMetadata::id`).
pub type ShardId = u64;
15
/// Description of a single shard: identity, placement, and free-form options.
///
/// With the `rkyv` feature enabled this type additionally derives rkyv's
/// `Archive`, `Serialize` and `Deserialize`, and the archived form derives
/// `Debug`.
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    rkyv(derive(Debug))
)]
#[derive(Debug, Default, Clone)]
pub struct ShardMetadata {
    /// Shard identifier; shards are registered under this value.
    pub id: u64,
    /// Collection this shard belongs to (presumably a collection name).
    pub collection: String,
    /// Partition index; widened to `u32` when converted to protobuf.
    pub partition_id: u16,
    /// Owning node, if any. NOTE(review): presumably a node id — confirm.
    pub owner: Option<u64>,
    /// Primary replica, if any. NOTE(review): presumably a node/shard id.
    pub primary: Option<u64>,
    /// Replica identifiers.
    pub replica: Vec<u64>,
    /// Owners of the replicas. NOTE(review): presumably parallel to `replica`.
    pub replica_owner: Vec<u64>,
    /// Shard implementation kind, as a free-form string.
    pub shard_type: String,
    /// Arbitrary string key/value options for the shard.
    pub options: std::collections::HashMap<String, String>,
}
33
34impl ShardMetadata {
35 pub fn into_proto(&self) -> flare_pb::ShardMetadata {
36 flare_pb::ShardMetadata {
37 id: self.id,
38 collection: self.collection.clone(),
39 partition_id: self.partition_id as u32,
40 owner: self.owner,
41 primary: self.primary,
42 replica: self.replica.clone(),
43 shard_type: self.shard_type.clone(),
44 options: self.options.clone(),
45 }
46 }
47}
48
49#[async_trait::async_trait]
50pub trait KvShard: Send + Sync {
51 type Key: Send + Clone;
52 type Entry: Send + Sync + Default;
53
54 fn meta(&self) -> &ShardMetadata;
55
56 async fn initialize(&self) -> Result<(), FlareError> {
57 Ok(())
58 }
59
60 async fn close(&self) -> Result<(), FlareError> {
61 Ok(())
62 }
63
64 fn watch_readiness(&self) -> tokio::sync::watch::Receiver<bool> {
65 let (_, rx) = tokio::sync::watch::channel(true);
66 rx
67 }
68
69 async fn get(
70 &self,
71 key: &Self::Key,
72 ) -> Result<Option<Self::Entry>, FlareError>;
73
74 async fn merge(
83 &self,
84 key: Self::Key,
85 value: Self::Entry,
86 ) -> Result<Self::Entry, FlareError> {
87 self.set(key.to_owned(), value).await?;
88 let item = self.get(&key).await?;
89 match item {
90 Some(entry) => Ok(entry),
91 None => Err(FlareError::InvalidArgument(
92 "Merged result is None".to_string(),
93 )),
94 }
95 }
96
97 async fn set(
98 &self,
99 key: Self::Key,
100 value: Self::Entry,
101 ) -> Result<(), FlareError>;
102
103 async fn delete(&self, key: &Self::Key) -> Result<(), FlareError>;
104}
105
/// Conversion between a shard entry and a raw byte vector.
pub trait ShardEntry: Send + Sync {
    /// Builds an entry from raw bytes.
    fn from_vec(v: Vec<u8>) -> Self;

    /// Produces the raw bytes representing this entry.
    fn to_vec(&self) -> Vec<u8>;
}
110
/// Shard entry holding an opaque byte payload plus a small `rc` counter.
#[derive(Debug, Default, Clone)]
pub struct ByteEntry {
    /// Counter stored next to the payload. NOTE(review): presumably a
    /// reference count — confirm; the `From` conversions set it to 1.
    pub rc: u16,
    /// Raw payload bytes.
    pub value: Vec<u8>,
}
117
118impl ShardEntry for ByteEntry {
119 fn to_vec(&self) -> Vec<u8> {
120 self.value.clone()
121 }
122
123 fn from_vec(v: Vec<u8>) -> Self {
124 ByteEntry { rc: 0, value: v }
125 }
126}
127
128impl From<Vec<u8>> for ByteEntry {
129 #[inline]
130 fn from(v: Vec<u8>) -> Self {
131 ByteEntry { rc: 1, value: v }
132 }
133}
134
135impl From<&Vec<u8>> for ByteEntry {
136 #[inline]
137 fn from(v: &Vec<u8>) -> Self {
138 ByteEntry {
139 rc: 1,
140 value: v.clone(),
141 }
142 }
143}
144
145impl From<Bytes> for ByteEntry {
146 #[inline]
147 fn from(v: Bytes) -> Self {
148 ByteEntry {
149 rc: 1,
150 value: v.to_vec(),
151 }
152 }
153}
154
155impl From<&Bytes> for ByteEntry {
156 #[inline]
157 fn from(v: &Bytes) -> Self {
158 ByteEntry {
159 rc: 1,
160 value: v.to_vec(),
161 }
162 }
163}
164
165pub struct ShardManager2<K, V>
166where
167 K: Send + Clone,
168 V: Send + Sync + Default,
169{
170 pub shard_factory: Box<dyn ShardFactory2<Key = K, Entry = V>>,
171 pub shards: HashMap<ShardId, Arc<dyn KvShard<Key = K, Entry = V>>>,
172}
173
174impl<K, V> ShardManager2<K, V>
175where
176 K: Send + Clone,
177 V: Send + Sync + Default,
178{
179 pub fn new(
180 shard_factory: Box<dyn ShardFactory2<Key = K, Entry = V>>,
181 ) -> Self {
182 Self {
183 shards: HashMap::new(),
184 shard_factory,
185 }
186 }
187
188 #[inline]
189 pub fn get_shard(
190 &self,
191 shard_id: ShardId,
192 ) -> Result<Arc<dyn KvShard<Key = K, Entry = V>>, FlareError> {
193 self.shards
194 .get(&shard_id)
195 .map(|shard| shard.get().to_owned())
196 .ok_or_else(|| FlareError::NoShardFound(shard_id))
197 }
198
199 #[inline]
200 pub fn get_any_shard(
201 &self,
202 shard_ids: &Vec<ShardId>,
203 ) -> Result<Arc<dyn KvShard<Key = K, Entry = V>>, FlareError> {
204 for id in shard_ids.iter() {
205 if let Some(shard) =
206 self.shards.get(id).map(|shard| shard.get().to_owned())
207 {
208 return Ok(shard);
209 }
210 }
211 Err(FlareError::NoShardsFound(shard_ids.clone()))
212 }
213
214 #[inline]
215 pub async fn create_shard(&self, shard_metadata: ShardMetadata) {
216 let shard = self.shard_factory.create_shard(shard_metadata).await;
217 let shard_id = shard.meta().id;
218 shard.initialize().await.unwrap();
219 self.shards.upsert(shard_id, shard);
220 }
221
222 #[inline]
223 pub fn contains(&self, shard_id: ShardId) -> bool {
224 self.shards.contains(&shard_id)
225 }
226
227 pub async fn sync_shards(&self, shard_meta: &Vec<ShardMetadata>) {
228 for s in shard_meta {
229 if self.contains(s.id) {
230 continue;
231 }
232 self.create_shard(s.to_owned()).await;
233 }
234 }
235
236 pub async fn remove_shard(&self, shard_id: ShardId) {
237 if let Some((_, v)) = self.shards.remove(&shard_id) {
238 let _ = v.close().await;
239 }
240 }
241
242 pub async fn close(&self) {
243 let mut iter = self.shards.first_entry_async().await;
244 while let Some(entry) = iter {
245 entry.close().await.expect("close shard failed");
246 iter = entry.next_async().await;
247 }
248 }
249}
250
251pub trait ShardFactory<T>: Send + Sync
252where
253 T: KvShard,
254{
255 fn create_shard(&self, shard_metadata: ShardMetadata) -> Arc<T>;
256}
257
258#[async_trait::async_trait]
259pub trait ShardFactory2: Send + Sync {
260 type Key;
261 type Entry;
262 async fn create_shard(
263 &self,
264 shard_metadata: ShardMetadata,
265 ) -> Arc<dyn KvShard<Key = Self::Key, Entry = Self::Entry>>;
266}