flare_dht/shard/
mod.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3mod hashmap;
4mod manager;
5
6use crate::error::FlareError;
7
8use bytes::Bytes;
9pub use hashmap::HashMapShard;
10pub use hashmap::HashMapShardFactory;
11pub use manager::ShardManager;
12use scc::HashMap;
13
/// Numeric identifier of a shard.
///
/// This is the key type of the shard table held by `ShardManager2`; the
/// manager registers each shard under the `id` found in its `ShardMetadata`.
pub type ShardId = u64;
15
/// Descriptor of a single shard.
///
/// The type is plain data: it can be defaulted, cloned, debug-printed and
/// compared for equality (`PartialEq`/`Eq`), which lets callers detect whether
/// two descriptors differ without comparing field by field. When the `rkyv`
/// feature is enabled it additionally derives the rkyv archive traits.
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    rkyv(derive(Debug))
)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ShardMetadata {
    /// Identifier of the shard; `ShardManager2` registers shards under this value.
    pub id: u64,
    /// Collection name. Presumably the collection this shard belongs to — confirm.
    pub collection: String,
    /// Partition number. Presumably the partition of `collection` served by
    /// this shard — confirm.
    pub partition_id: u16,
    /// Optional owner id.
    /// NOTE(review): what kind of entity the id names is not visible in this file.
    pub owner: Option<u64>,
    /// Optional primary id.
    /// NOTE(review): relation to `owner` is not visible in this file.
    pub primary: Option<u64>,
    /// Replica ids.
    pub replica: Vec<u64>,
    /// Replica owner ids.
    /// NOTE(review): presumably parallel to `replica` — confirm against callers.
    pub replica_owner: Vec<u64>,
    /// Free-form tag naming the kind of shard.
    pub shard_type: String,
    /// Free-form string options attached to the shard.
    pub options: std::collections::HashMap<String, String>,
}
33
34impl ShardMetadata {
35    pub fn into_proto(&self) -> flare_pb::ShardMetadata {
36        flare_pb::ShardMetadata {
37            id: self.id,
38            collection: self.collection.clone(),
39            partition_id: self.partition_id as u32,
40            owner: self.owner,
41            primary: self.primary,
42            replica: self.replica.clone(),
43            shard_type: self.shard_type.clone(),
44            options: self.options.clone(),
45        }
46    }
47}
48
49#[async_trait::async_trait]
50pub trait KvShard: Send + Sync {
51    type Key: Send + Clone;
52    type Entry: Send + Sync + Default;
53
54    fn meta(&self) -> &ShardMetadata;
55
56    async fn initialize(&self) -> Result<(), FlareError> {
57        Ok(())
58    }
59
60    async fn close(&self) -> Result<(), FlareError> {
61        Ok(())
62    }
63
64    fn watch_readiness(&self) -> tokio::sync::watch::Receiver<bool> {
65        let (_, rx) = tokio::sync::watch::channel(true);
66        rx
67    }
68
69    async fn get(
70        &self,
71        key: &Self::Key,
72    ) -> Result<Option<Self::Entry>, FlareError>;
73
74    // async fn modify<F, O>(
75    //     &self,
76    //     key: &Self::Key,
77    //     f: F,
78    // ) -> Result<O, FlareError>
79    // where
80    //     F: FnOnce(&mut Self::Entry) -> O + Send;
81
82    async fn merge(
83        &self,
84        key: Self::Key,
85        value: Self::Entry,
86    ) -> Result<Self::Entry, FlareError> {
87        self.set(key.to_owned(), value).await?;
88        let item = self.get(&key).await?;
89        match item {
90            Some(entry) => Ok(entry),
91            None => Err(FlareError::InvalidArgument(
92                "Merged result is None".to_string(),
93            )),
94        }
95    }
96
97    async fn set(
98        &self,
99        key: Self::Key,
100        value: Self::Entry,
101    ) -> Result<(), FlareError>;
102
103    async fn delete(&self, key: &Self::Key) -> Result<(), FlareError>;
104}
105
/// Conversion between a shard entry and a raw byte vector.
pub trait ShardEntry: Send + Sync {
    /// Returns the entry's bytes as an owned vector.
    fn to_vec(&self) -> Vec<u8>;
    /// Builds an entry from an owned byte vector.
    fn from_vec(v: Vec<u8>) -> Self;
}

/// Entry holding an opaque byte payload together with a small counter.
///
/// Entries can be compared with `==`; two entries are equal when both `rc`
/// and `value` match.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ByteEntry {
    /// Counter stored next to the payload.
    ///
    /// NOTE(review): its meaning is not visible in this file. `Default` and
    /// [`ShardEntry::from_vec`] produce `0`, while the `From` conversions
    /// produce `1` — confirm that this difference is intended.
    pub rc: u16,
    /// The payload bytes.
    pub value: Vec<u8>,
}

impl ShardEntry for ByteEntry {
    /// Returns a copy of the payload; `rc` is not included.
    fn to_vec(&self) -> Vec<u8> {
        self.value.clone()
    }

    /// Wraps `v` as the payload with `rc` set to `0`.
    fn from_vec(v: Vec<u8>) -> Self {
        Self { rc: 0, value: v }
    }
}

/// Takes ownership of the vector as the payload, with `rc` set to `1`.
impl From<Vec<u8>> for ByteEntry {
    #[inline]
    fn from(v: Vec<u8>) -> Self {
        Self { rc: 1, value: v }
    }
}

/// Clones the borrowed vector into the payload, with `rc` set to `1`.
impl From<&Vec<u8>> for ByteEntry {
    #[inline]
    fn from(v: &Vec<u8>) -> Self {
        Self {
            rc: 1,
            value: v.clone(),
        }
    }
}
144
145impl From<Bytes> for ByteEntry {
146    #[inline]
147    fn from(v: Bytes) -> Self {
148        ByteEntry {
149            rc: 1,
150            value: v.to_vec(),
151        }
152    }
153}
154
155impl From<&Bytes> for ByteEntry {
156    #[inline]
157    fn from(v: &Bytes) -> Self {
158        ByteEntry {
159            rc: 1,
160            value: v.to_vec(),
161        }
162    }
163}
164
165pub struct ShardManager2<K, V>
166where
167    K: Send + Clone,
168    V: Send + Sync + Default,
169{
170    pub shard_factory: Box<dyn ShardFactory2<Key = K, Entry = V>>,
171    pub shards: HashMap<ShardId, Arc<dyn KvShard<Key = K, Entry = V>>>,
172}
173
174impl<K, V> ShardManager2<K, V>
175where
176    K: Send + Clone,
177    V: Send + Sync + Default,
178{
179    pub fn new(
180        shard_factory: Box<dyn ShardFactory2<Key = K, Entry = V>>,
181    ) -> Self {
182        Self {
183            shards: HashMap::new(),
184            shard_factory,
185        }
186    }
187
188    #[inline]
189    pub fn get_shard(
190        &self,
191        shard_id: ShardId,
192    ) -> Result<Arc<dyn KvShard<Key = K, Entry = V>>, FlareError> {
193        self.shards
194            .get(&shard_id)
195            .map(|shard| shard.get().to_owned())
196            .ok_or_else(|| FlareError::NoShardFound(shard_id))
197    }
198
199    #[inline]
200    pub fn get_any_shard(
201        &self,
202        shard_ids: &Vec<ShardId>,
203    ) -> Result<Arc<dyn KvShard<Key = K, Entry = V>>, FlareError> {
204        for id in shard_ids.iter() {
205            if let Some(shard) =
206                self.shards.get(id).map(|shard| shard.get().to_owned())
207            {
208                return Ok(shard);
209            }
210        }
211        Err(FlareError::NoShardsFound(shard_ids.clone()))
212    }
213
214    #[inline]
215    pub async fn create_shard(&self, shard_metadata: ShardMetadata) {
216        let shard = self.shard_factory.create_shard(shard_metadata).await;
217        let shard_id = shard.meta().id;
218        shard.initialize().await.unwrap();
219        self.shards.upsert(shard_id, shard);
220    }
221
222    #[inline]
223    pub fn contains(&self, shard_id: ShardId) -> bool {
224        self.shards.contains(&shard_id)
225    }
226
227    pub async fn sync_shards(&self, shard_meta: &Vec<ShardMetadata>) {
228        for s in shard_meta {
229            if self.contains(s.id) {
230                continue;
231            }
232            self.create_shard(s.to_owned()).await;
233        }
234    }
235
236    pub async fn remove_shard(&self, shard_id: ShardId) {
237        if let Some((_, v)) = self.shards.remove(&shard_id) {
238            let _ = v.close().await;
239        }
240    }
241
242    pub async fn close(&self) {
243        let mut iter = self.shards.first_entry_async().await;
244        while let Some(entry) = iter {
245            entry.close().await.expect("close shard failed");
246            iter = entry.next_async().await;
247        }
248    }
249}
250
251pub trait ShardFactory<T>: Send + Sync
252where
253    T: KvShard,
254{
255    fn create_shard(&self, shard_metadata: ShardMetadata) -> Arc<T>;
256}
257
/// Asynchronous factory producing type-erased shards.
///
/// Unlike `ShardFactory`, the shard's concrete type is hidden behind
/// `dyn KvShard`; only its key and entry types are fixed, through the
/// associated types below.
#[async_trait::async_trait]
pub trait ShardFactory2: Send + Sync {
    /// Key type of the shards this factory produces.
    type Key;
    /// Entry type of the shards this factory produces.
    type Entry;
    /// Builds a shard described by `shard_metadata`.
    async fn create_shard(
        &self,
        shard_metadata: ShardMetadata,
    ) -> Arc<dyn KvShard<Key = Self::Key, Entry = Self::Entry>>;
}