bonsaidb_core/schema/
schematic.rs

1use std::any::TypeId;
2use std::collections::{hash_map, HashMap};
3use std::fmt::Debug;
4use std::marker::PhantomData;
5
6use derive_where::derive_where;
7
8use crate::document::{BorrowedDocument, DocumentId, KeyId};
9use crate::key::{ByteSource, Key, KeyDescription};
10use crate::schema::collection::Collection;
11use crate::schema::view::map::{self, MappedValue};
12use crate::schema::view::{
13    self, MapReduce, Serialized, SerializedView, ViewSchema, ViewUpdatePolicy,
14};
15use crate::schema::{CollectionName, Schema, SchemaName, View, ViewName};
16use crate::Error;
17
18/// A collection of defined collections and views.
19pub struct Schematic {
20    /// The name of the schema this was built from.
21    pub name: SchemaName,
22    contained_collections: HashMap<CollectionName, KeyDescription>,
23    collections_by_type_id: HashMap<TypeId, CollectionName>,
24    collection_encryption_keys: HashMap<CollectionName, KeyId>,
25    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
26    views: HashMap<TypeId, Box<dyn view::Serialized>>,
27    views_by_name: HashMap<ViewName, TypeId>,
28    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
29    eager_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
30}
31
32impl Schematic {
33    /// Returns an initialized version from `S`.
34    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
35        let mut schematic = Self {
36            name: S::schema_name(),
37            contained_collections: HashMap::new(),
38            collections_by_type_id: HashMap::new(),
39            collection_encryption_keys: HashMap::new(),
40            collection_id_generators: HashMap::new(),
41            views: HashMap::new(),
42            views_by_name: HashMap::new(),
43            views_by_collection: HashMap::new(),
44            eager_views_by_collection: HashMap::new(),
45        };
46        S::define_collections(&mut schematic)?;
47        Ok(schematic)
48    }
49
50    /// Adds the collection `C` and its views.
51    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
52        let name = C::collection_name();
53        match self.contained_collections.entry(name.clone()) {
54            hash_map::Entry::Vacant(entry) => {
55                self.collections_by_type_id
56                    .insert(TypeId::of::<C>(), name.clone());
57                if let Some(key) = C::encryption_key() {
58                    self.collection_encryption_keys.insert(name.clone(), key);
59                }
60                self.collection_id_generators
61                    .insert(name, Box::<KeyIdGenerator<C>>::default());
62                entry.insert(KeyDescription::for_key::<C::PrimaryKey>());
63                C::define_views(self)
64            }
65            hash_map::Entry::Occupied(_) => Err(Error::CollectionAlreadyDefined),
66        }
67    }
68
69    /// Adds the view `V`.
70    pub fn define_view<V: MapReduce + ViewSchema<View = V> + SerializedView + Clone + 'static>(
71        &mut self,
72        view: V,
73    ) -> Result<(), Error> {
74        self.define_view_with_schema(view.clone(), view)
75    }
76
77    /// Adds the view `V`.
78    pub fn define_view_with_schema<
79        V: SerializedView + 'static,
80        S: MapReduce + ViewSchema<View = V> + 'static,
81    >(
82        &mut self,
83        view: V,
84        schema: S,
85    ) -> Result<(), Error> {
86        let instance = ViewInstance { view, schema };
87        let name = instance.view_name();
88        if self.views_by_name.contains_key(&name) {
89            return Err(Error::ViewAlreadyRegistered(name));
90        }
91
92        let collection = instance.collection();
93        let eager = instance.update_policy().is_eager();
94        self.views.insert(TypeId::of::<V>(), Box::new(instance));
95        self.views_by_name.insert(name, TypeId::of::<V>());
96
97        if eager {
98            let unique_views = self
99                .eager_views_by_collection
100                .entry(collection.clone())
101                .or_insert_with(Vec::new);
102            unique_views.push(TypeId::of::<V>());
103        }
104        let views = self
105            .views_by_collection
106            .entry(collection)
107            .or_insert_with(Vec::new);
108        views.push(TypeId::of::<V>());
109
110        Ok(())
111    }
112
113    /// Returns `true` if this schema contains the collection `C`.
114    #[must_use]
115    pub fn contains_collection<C: Collection + 'static>(&self) -> bool {
116        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
117    }
118
119    /// Returns the description of the primary keyof the collection with the
120    /// given name, or `None` if the collection can't be found.
121    #[must_use]
122    pub fn collection_primary_key_description<'a>(
123        &'a self,
124        collection: &CollectionName,
125    ) -> Option<&'a KeyDescription> {
126        self.contained_collections.get(collection)
127    }
128
129    /// Returns the next id in sequence for the collection, if the primary key
130    /// type supports the operation and the next id would not overflow.
131    pub fn next_id_for_collection(
132        &self,
133        collection: &CollectionName,
134        id: Option<DocumentId>,
135    ) -> Result<DocumentId, Error> {
136        let generator = self
137            .collection_id_generators
138            .get(collection)
139            .ok_or(Error::CollectionNotFound)?;
140        generator.next_id(id)
141    }
142
143    /// Looks up a [`view::Serialized`] by name.
144    pub fn view_by_name(&self, name: &ViewName) -> Result<&'_ dyn view::Serialized, Error> {
145        self.views_by_name
146            .get(name)
147            .and_then(|type_id| self.views.get(type_id))
148            .map(AsRef::as_ref)
149            .ok_or(Error::ViewNotFound)
150    }
151
152    /// Looks up a [`view::Serialized`] through the the type `V`.
153    pub fn view<V: View + 'static>(&self) -> Result<&'_ dyn view::Serialized, Error> {
154        self.views
155            .get(&TypeId::of::<V>())
156            .map(AsRef::as_ref)
157            .ok_or(Error::ViewNotFound)
158    }
159
160    /// Iterates over all registered views.
161    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
162        self.views.values().map(AsRef::as_ref)
163    }
164
165    /// Iterates over all views that belong to `collection`.
166    pub fn views_in_collection(
167        &self,
168        collection: &CollectionName,
169    ) -> impl Iterator<Item = &'_ dyn view::Serialized> {
170        self.views_by_collection
171            .get(collection)
172            .into_iter()
173            .flat_map(|view_ids| {
174                view_ids
175                    .iter()
176                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
177            })
178    }
179
180    /// Iterates over all views that are eagerly updated that belong to
181    /// `collection`.
182    pub fn eager_views_in_collection(
183        &self,
184        collection: &CollectionName,
185    ) -> impl Iterator<Item = &'_ dyn view::Serialized> {
186        self.eager_views_by_collection
187            .get(collection)
188            .into_iter()
189            .flat_map(|view_ids| {
190                view_ids
191                    .iter()
192                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
193            })
194    }
195
196    /// Returns a collection's default encryption key, if one was defined.
197    #[must_use]
198    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
199        self.collection_encryption_keys.get(collection)
200    }
201
202    /// Returns a list of all collections contained in this schematic.
203    pub fn collections(&self) -> impl Iterator<Item = &CollectionName> {
204        self.contained_collections.keys()
205    }
206}
207
208impl Debug for Schematic {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        let mut views = self
211            .views
212            .values()
213            .map(|v| v.view_name())
214            .collect::<Vec<_>>();
215        views.sort();
216
217        f.debug_struct("Schematic")
218            .field("name", &self.name)
219            .field("contained_collections", &self.contained_collections)
220            .field("collections_by_type_id", &self.collections_by_type_id)
221            .field(
222                "collection_encryption_keys",
223                &self.collection_encryption_keys,
224            )
225            .field("collection_id_generators", &self.collection_id_generators)
226            .field("views", &views)
227            .field("views_by_name", &self.views_by_name)
228            .field("views_by_collection", &self.views_by_collection)
229            .field("eager_views_by_collection", &self.eager_views_by_collection)
230            .finish()
231    }
232}
233
234#[derive(Debug)]
235struct ViewInstance<V, S> {
236    view: V,
237    schema: S,
238}
239
240impl<V, S> Serialized for ViewInstance<V, S>
241where
242    V: SerializedView,
243    S: MapReduce + ViewSchema<View = V>,
244{
245    fn collection(&self) -> CollectionName {
246        <<V as View>::Collection as Collection>::collection_name()
247    }
248
249    fn key_description(&self) -> KeyDescription {
250        KeyDescription::for_key::<<V as View>::Key>()
251    }
252
253    fn update_policy(&self) -> ViewUpdatePolicy {
254        self.schema.update_policy()
255    }
256
257    fn version(&self) -> u64 {
258        self.schema.version()
259    }
260
261    fn view_name(&self) -> ViewName {
262        self.view.view_name()
263    }
264
265    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
266        let mappings = self.schema.map(document)?;
267
268        mappings
269            .iter()
270            .map(map::Map::serialized::<V>)
271            .collect::<Result<_, _>>()
272            .map_err(view::Error::key_serialization)
273    }
274
275    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
276        let mappings = mappings
277            .iter()
278            .map(|(key, value)| {
279                match <S::MappedKey<'_> as Key>::from_ord_bytes(ByteSource::Borrowed(key)) {
280                    Ok(key) => {
281                        let value = V::deserialize(value)?;
282                        Ok(MappedValue::new(key, value))
283                    }
284                    Err(err) => Err(view::Error::key_serialization(err)),
285                }
286            })
287            .collect::<Result<Vec<_>, view::Error>>()?;
288
289        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
290
291        V::serialize(&reduced_value).map_err(view::Error::from)
292    }
293}
294
295pub trait IdGenerator: Debug + Send + Sync {
296    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
297}
298
299#[derive_where(Default, Debug)]
300pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
301
302impl<C> IdGenerator for KeyIdGenerator<C>
303where
304    C: Collection,
305{
306    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
307        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
308        let key = if let Some(key) = key {
309            key
310        } else {
311            <C::PrimaryKey as Key<'_>>::first_value()
312                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
313        };
314        let next_value = key
315            .next_value()
316            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
317        DocumentId::new(&next_value)
318    }
319}
320
/// Builds a schematic from the `Basic` test schema and checks the registered
/// collection and view counts and lookups.
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    let schematic = Schematic::from_schema::<Basic>()?;

    // Exactly one collection is expected, registered under `Basic`'s type id.
    let collections = &schematic.collections_by_type_id;
    assert_eq!(collections.len(), 1);
    assert_eq!(
        collections[&TypeId::of::<Basic>()],
        Basic::collection_name()
    );

    // Six views are expected in total; `BasicCount` must be stored under its
    // own type id and report the same name the `View` trait gives it.
    let views = &schematic.views;
    assert_eq!(views.len(), 6);
    let basic_count = &views[&TypeId::of::<BasicCount>()];
    assert_eq!(basic_count.view_name(), View::view_name(&BasicCount));

    Ok(())
}