markov_generator/
map.rs

1/*
2 * Copyright (C) 2024 taylor.fish <contact@taylor.fish>
3 *
4 * This file is part of markov-generator.
5 *
6 * markov-generator is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * markov-generator is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with markov-generator. If not, see <https://www.gnu.org/licenses/>.
18 */
19
20//! Options for customizing the map type used by [`Chain`].
21//!
22//! The types in this module that implement [`Map`] can be provided as the
23//! second type parameter to [`Chain`] in order to change the type of map it
24//! uses for internal storage. For example, <code>[Chain]<T, [BTree]></code>
25//! will use [`BTreeMap`]s, while <code>[Chain]<T, [Hash]></code> will use
26//! [`HashMap`]s.
27//!
28//! [`Chain`]: crate::Chain
29//! [Chain]: crate::Chain
30//! [`BTreeMap`]: alloc::collections::BTreeMap
31//! [Hash]: self::Hash
32//! [`HashMap`]: std::collections::HashMap
33
34use alloc::collections::VecDeque;
35use core::borrow::Borrow;
36use core::fmt::{self, Debug};
37
/// Represents a [`BTreeMap`].
///
/// [`Chain`] will use [`BTreeMap`]s internally when this type is provided as
/// its second type parameter.
///
/// This type is only meant to be named as a type parameter: its single field
/// is a private `()`, so values of it cannot be constructed outside this
/// module.
///
/// [`BTreeMap`]: alloc::collections::BTreeMap
/// [`Chain`]: crate::Chain
pub struct BTree(());
46
#[cfg(feature = "std")]
#[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
/// Represents a [`HashMap`].
///
/// [`Chain`] will use [`HashMap`]s internally when this type is provided as
/// its second type parameter.
///
/// This type is only meant to be named as a type parameter: its single field
/// is a private `()`, so values of it cannot be constructed outside this
/// module. It is only available when the `std` feature is enabled.
///
/// [`HashMap`]: std::collections::HashMap
/// [`Chain`]: crate::Chain
pub struct Hash(());
57
/// Crate-private traits and types that back the public `Map` trait.
///
/// The marker types in the parent module select a concrete map type through
/// [`MapFrom`] / [`MapFromSlice`]; the selected map types are then driven
/// through [`MapOps`] / [`MapOpsSlice`].
pub(crate) mod detail {
    use alloc::boxed::Box;
    use core::fmt::{self, Debug};
    #[cfg(feature = "serde")]
    use serde::{Deserialize, Serialize};

    /// Selects the concrete map type to use for keys of type `K`.
    pub trait MapFrom<K> {
        /// The map type that stores values of type `V` under keys of type
        /// `K`.
        type To<V>: MapOps<K, V>;
    }

    /// Selects the concrete map type to use for keys that are sequences of
    /// `K`, stored as [`OwnedSliceKey<K>`].
    pub trait MapFromSlice<K>: MapFrom<OwnedSliceKey<K>> {
        /// The map type that stores values of type `V` under
        /// [`OwnedSliceKey<K>`] keys and additionally supports the
        /// [`MapOpsSlice`] operations.
        type To<V>: MapOpsSlice<K, V>;
    }

    /// The operations required of a map from `K` to `V`.
    ///
    /// `debug`, `clone`, `serialize` and `deserialize` are trait methods with
    /// per-method bounds (rather than supertraits) so that they are only
    /// callable when `K` and `V` satisfy the corresponding bounds.
    pub trait MapOps<K, V>: Default {
        /// Iterator over shared references to the entries of the map.
        type Iter<'a>: Iterator<Item = (&'a K, &'a V)>
        where
            K: 'a,
            V: 'a,
            Self: 'a;

        /// Iterator over the entries of the map with mutable access to the
        /// values.
        type IterMut<'a>: Iterator<Item = (&'a K, &'a mut V)>
        where
            K: 'a,
            V: 'a,
            Self: 'a;

        /// Returns an iterator over the entries of the map.
        fn iter(&self) -> Self::Iter<'_>;
        /// Returns an iterator over the entries of the map that allows the
        /// values to be modified.
        // NOTE(review): `dead_code` is allowed here presumably because the
        // crate does not call this method in every configuration -- confirm.
        #[allow(dead_code)]
        fn iter_mut(&mut self) -> Self::IterMut<'_>;
        /// Returns a reference to the value stored under `k`, if any.
        // NOTE(review): `dead_code` is allowed here presumably because the
        // crate does not call this method in every configuration -- confirm.
        #[allow(dead_code)]
        fn get(&self, k: &K) -> Option<&V>;
        /// Returns a mutable reference to the value stored under `k`,
        /// inserting the value produced by `v` if there is no entry yet.
        fn get_or_insert_with<F>(&mut self, k: K, v: F) -> &mut V
        where
            F: FnOnce() -> V;

        /// Formats the map for debugging purposes.
        fn debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
        where
            V: Debug,
            K: Debug;

        /// Returns a clone of the map.
        fn clone(&self) -> Self
        where
            K: Clone,
            V: Clone;

        /// Serializes the map with the given serializer.
        #[cfg(feature = "serde")]
        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
        where
            K: serde::Serialize,
            V: serde::Serialize,
            S: serde::Serializer;

        /// Deserializes a map from the given deserializer.
        #[cfg(feature = "serde")]
        fn deserialize<'de, D>(d: D) -> Result<Self, D::Error>
        where
            K: serde::Deserialize<'de>,
            V: serde::Deserialize<'de>,
            D: serde::Deserializer<'de>;
    }

    /// Extra operations for maps keyed by [`OwnedSliceKey<K>`] that accept
    /// any [`SliceKey<K>`] as the lookup key.
    pub trait MapOpsSlice<K, V>: MapOps<OwnedSliceKey<K>, V> {
        /// Returns a reference to the value stored under `k`, if any.
        fn slice_get<'a, S>(&'a self, k: &S) -> Option<&'a V>
        where
            S: SliceKey<K>;

        /// Returns a mutable reference to the value stored under `k`,
        /// inserting the value produced by `v` if there is no entry yet.
        ///
        /// `k` must be convertible into an [`OwnedSliceKey<K>`] so that it
        /// can be stored in the map when an insertion is needed.
        fn slice_get_or_insert_with<S, F>(&mut self, k: S, v: F) -> &mut V
        where
            S: SliceKey<K> + Into<OwnedSliceKey<K>>,
            F: FnOnce() -> V;
    }

    /// A sequence of `T` that can be indexed by position and used as a key.
    pub trait SliceKey<T> {
        /// Returns the element at position `i`, or [`None`] if the sequence
        /// has no such element.
        fn get(&self, i: usize) -> Option<&T>;
    }

    /// An owned sequence key backed by a boxed slice.
    ///
    /// With the `serde` feature, this is (de)serialized transparently as the
    /// inner boxed slice.
    #[derive(Clone, Debug)]
    #[cfg_attr(
        feature = "serde",
        derive(Serialize, Deserialize),
        serde(transparent)
    )]
    pub struct OwnedSliceKey<T>(pub Box<[T]>);

    /// Shorthand for types that implement both [`MapFrom<K>`] and
    /// [`MapFromSlice<K>`].
    pub trait Map<K>: MapFrom<K> + MapFromSlice<K> {}

    // Blanket impl: every type providing both map selections is a `Map`.
    impl<T, K> Map<K> for T where T: MapFrom<K> + MapFromSlice<K> {}

    // Alias for `Map`, used as the supertrait of the public `Map` trait in
    // the parent module. This module is `pub(crate)`, so the alias cannot be
    // named through this path from outside the crate.
    pub use Map as Sealed;
}
148
149pub(crate) use detail::{MapFrom, MapFromSlice, MapOps, MapOpsSlice};
150pub(crate) use detail::{OwnedSliceKey, SliceKey};
151
#[cfg(feature = "std")]
impl<K> MapFrom<K> for Hash
where
    K: Eq + std::hash::Hash,
{
    /// Plain keys are stored in a standard-library hash map.
    type To<V> = std::collections::HashMap<K, V>;
}
156
#[cfg(feature = "std")]
impl<K> MapFromSlice<K> for Hash
where
    K: Eq + std::hash::Hash,
{
    /// Sequence keys are stored in a standard-library hash map keyed by
    /// [`OwnedSliceKey<K>`].
    type To<V> = std::collections::HashMap<OwnedSliceKey<K>, V>;
}
161
#[cfg(feature = "std")]
impl<K, V> MapOps<K, V> for std::collections::HashMap<K, V>
where
    K: Eq + std::hash::Hash,
{
    type Iter<'a>
        = std::collections::hash_map::Iter<'a, K, V>
    where
        K: 'a,
        V: 'a;

    type IterMut<'a>
        = std::collections::hash_map::IterMut<'a, K, V>
    where
        K: 'a,
        V: 'a;

    /// Forwards to the inherent [`HashMap::iter`].
    ///
    /// [`HashMap::iter`]: std::collections::HashMap::iter
    fn iter(&self) -> Self::Iter<'_> {
        std::collections::HashMap::iter(self)
    }

    /// Forwards to the inherent [`HashMap::iter_mut`].
    ///
    /// [`HashMap::iter_mut`]: std::collections::HashMap::iter_mut
    fn iter_mut(&mut self) -> Self::IterMut<'_> {
        std::collections::HashMap::iter_mut(self)
    }

    /// Forwards to the inherent [`HashMap::get`].
    ///
    /// [`HashMap::get`]: std::collections::HashMap::get
    fn get(&self, k: &K) -> Option<&V> {
        std::collections::HashMap::get(self, k)
    }

    /// Returns the value stored under `k`, calling `v` to create it only if
    /// the entry is vacant.
    fn get_or_insert_with<F>(&mut self, k: K, v: F) -> &mut V
    where
        F: FnOnce() -> V,
    {
        use std::collections::hash_map::Entry;

        match self.entry(k) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => slot.insert(v()),
        }
    }

    /// Formats the map using its own [`Debug`] implementation.
    fn debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    where
        V: Debug,
        K: Debug,
    {
        <Self as Debug>::fmt(self, f)
    }

    /// Clones the map using its own [`Clone`] implementation.
    fn clone(&self) -> Self
    where
        K: Clone,
        V: Clone,
    {
        <Self as Clone>::clone(self)
    }

    /// Serializes the map using its own `serde::Serialize` implementation.
    #[cfg(feature = "serde")]
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        K: serde::Serialize,
        V: serde::Serialize,
        S: serde::Serializer,
    {
        <Self as serde::Serialize>::serialize(self, s)
    }

    /// Deserializes a map using its own `serde::Deserialize` implementation.
    #[cfg(feature = "serde")]
    fn deserialize<'de, D>(d: D) -> Result<Self, D::Error>
    where
        K: serde::Deserialize<'de>,
        V: serde::Deserialize<'de>,
        D: serde::Deserializer<'de>,
    {
        <Self as serde::Deserialize<'de>>::deserialize(d)
    }
}
234
#[cfg(feature = "std")]
impl<K, V> MapOpsSlice<K, V> for std::collections::HashMap<OwnedSliceKey<K>, V>
where
    K: Eq + std::hash::Hash,
{
    /// Looks up `k` without building an [`OwnedSliceKey`].
    ///
    /// `k` is passed to the map as a `dyn SliceKey<K>`, which
    /// [`OwnedSliceKey<K>`] can be borrowed as.
    fn slice_get<'a, S>(&'a self, k: &S) -> Option<&'a V>
    where
        S: SliceKey<K>,
    {
        Self::get::<dyn SliceKey<K> + '_>(self, k)
    }

    /// Returns the value stored under `k`, inserting the result of `v` if
    /// there is no entry yet.
    ///
    /// `k` is converted into an [`OwnedSliceKey`] only when no entry exists
    /// and an insertion is therefore required.
    fn slice_get_or_insert_with<S, F>(&mut self, k: S, v: F) -> &mut V
    where
        S: SliceKey<K> + Into<OwnedSliceKey<K>>,
        F: FnOnce() -> V,
    {
        if let Some(v) = Self::get_mut::<dyn SliceKey<K> + '_>(self, &k) {
            #[allow(
                unsafe_code,
                /* reason = "workaround for rust issue #51545" */
            )]
            // SAFETY: `v` is borrowed from `self` (and only `self`), so it is
            // sound to return it with the same lifetime as `self`. Due to
            // issues with Rust's borrow checker (#51545, #54663), this
            // requires a lifetime extension, performed here by converting to
            // a pointer and immediately dereferencing.
            return unsafe { core::ptr::NonNull::from(v).as_mut() };
        }
        // No existing entry: only now pay for the owned key.
        MapOps::get_or_insert_with(self, k.into(), v)
    }
}
267
268impl<K: Ord> MapFrom<K> for BTree {
269    type To<V> = alloc::collections::BTreeMap<K, V>;
270}
271
272impl<K: Ord> MapFromSlice<K> for BTree {
273    type To<V> = alloc::collections::BTreeMap<OwnedSliceKey<K>, V>;
274}
275
276impl<K: Ord, V> MapOps<K, V> for alloc::collections::BTreeMap<K, V> {
277    type Iter<'a>
278        = alloc::collections::btree_map::Iter<'a, K, V>
279    where
280        K: 'a,
281        V: 'a;
282
283    type IterMut<'a>
284        = alloc::collections::btree_map::IterMut<'a, K, V>
285    where
286        K: 'a,
287        V: 'a;
288
289    fn iter(&self) -> Self::Iter<'_> {
290        self.iter()
291    }
292
293    fn iter_mut(&mut self) -> Self::IterMut<'_> {
294        self.iter_mut()
295    }
296
297    fn get(&self, k: &K) -> Option<&V> {
298        self.get(k)
299    }
300
301    fn get_or_insert_with<F>(&mut self, k: K, v: F) -> &mut V
302    where
303        F: FnOnce() -> V,
304    {
305        self.entry(k).or_insert_with(v)
306    }
307
308    fn debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
309    where
310        V: Debug,
311        K: Debug,
312    {
313        Debug::fmt(self, f)
314    }
315
316    fn clone(&self) -> Self
317    where
318        K: Clone,
319        V: Clone,
320    {
321        Clone::clone(self)
322    }
323
324    #[cfg(feature = "serde")]
325    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
326    where
327        K: serde::Serialize,
328        V: serde::Serialize,
329        S: serde::Serializer,
330    {
331        serde::Serialize::serialize(self, s)
332    }
333
334    #[cfg(feature = "serde")]
335    fn deserialize<'de, D>(d: D) -> Result<Self, D::Error>
336    where
337        K: serde::Deserialize<'de>,
338        V: serde::Deserialize<'de>,
339        D: serde::Deserializer<'de>,
340    {
341        serde::Deserialize::deserialize(d)
342    }
343}
344
impl<K, V> MapOpsSlice<K, V>
    for alloc::collections::BTreeMap<OwnedSliceKey<K>, V>
where
    K: Ord,
{
    /// Looks up `k` without building an [`OwnedSliceKey`].
    ///
    /// `k` is passed to the map as a `dyn SliceKey<K>`, which
    /// [`OwnedSliceKey<K>`] can be borrowed as.
    fn slice_get<'a, S>(&'a self, k: &S) -> Option<&'a V>
    where
        S: SliceKey<K>,
    {
        Self::get::<dyn SliceKey<K> + '_>(self, k)
    }

    /// Returns the value stored under `k`, inserting the result of `v` if
    /// there is no entry yet.
    ///
    /// `k` is converted into an [`OwnedSliceKey`] only when no entry exists
    /// and an insertion is therefore required.
    fn slice_get_or_insert_with<S, F>(&mut self, k: S, v: F) -> &mut V
    where
        S: SliceKey<K> + Into<OwnedSliceKey<K>>,
        F: FnOnce() -> V,
    {
        if let Some(v) = Self::get_mut::<dyn SliceKey<K> + '_>(self, &k) {
            #[allow(
                unsafe_code,
                /* reason = "workaround for rust issue #51545" */
            )]
            // SAFETY: `v` is borrowed from `self` (and only `self`), so it is
            // sound to return it with the same lifetime as `self`. Due to
            // issues with Rust's borrow checker (#51545, #54663), this
            // requires a lifetime extension, performed here by converting to
            // a pointer and immediately dereferencing.
            return unsafe { core::ptr::NonNull::from(v).as_mut() };
        }
        // No existing entry: only now pay for the owned key.
        MapOps::get_or_insert_with(self, k.into(), v)
    }
}
377
/// Wrapper around a borrowed map that is formatted through
/// [`MapOps::debug`].
///
/// `K` and `V` are the key and value types of the wrapped map `T`.
pub(crate) struct MapDebug<'a, K, V, T>(
    // The map to format.
    &'a T,
    // Ties `K` and `V` to this type without storing a value of either.
    core::marker::PhantomData<fn() -> (K, V)>,
);
382
383impl<'a, K, V, T> MapDebug<'a, K, V, T>
384where
385    T: MapOps<K, V>,
386{
387    pub fn new(map: &'a T) -> Self {
388        Self(map, core::marker::PhantomData)
389    }
390}
391
392impl<K, V, T> Debug for MapDebug<'_, K, V, T>
393where
394    K: Debug,
395    V: Debug,
396    T: MapOps<K, V>,
397{
398    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399        self.0.debug(f)
400    }
401}
402
#[cfg(feature = "std")]
impl<T> std::hash::Hash for OwnedSliceKey<T>
where
    T: std::hash::Hash,
{
    /// Feeds every element of the key, in order, into `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        for element in self.0.iter() {
            std::hash::Hash::hash(element, state);
        }
    }
}
409
410impl<T: Ord> Ord for OwnedSliceKey<T> {
411    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
412        self.0.iter().cmp(other.0.iter())
413    }
414}
415
416impl<T: Ord> PartialOrd for OwnedSliceKey<T> {
417    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
418        Some(self.cmp(other))
419    }
420}
421
422impl<T: PartialEq> PartialEq for OwnedSliceKey<T> {
423    fn eq(&self, other: &Self) -> bool {
424        self.0.iter().eq(other.0.iter())
425    }
426}
427
428impl<T: Eq> Eq for OwnedSliceKey<T> {}
429
430fn slice_key_iter<'a, T, S>(key: &'a S) -> impl Iterator<Item = &'a T>
431where
432    T: 'a,
433    S: SliceKey<T> + ?Sized,
434{
435    (0..).map_while(|i| key.get(i))
436}
437
438impl<T, S> SliceKey<T> for &S
439where
440    S: SliceKey<T> + ?Sized,
441{
442    fn get(&self, i: usize) -> Option<&T> {
443        S::get(*self, i)
444    }
445}
446
447impl<T, S> SliceKey<T> for &mut S
448where
449    S: SliceKey<T> + ?Sized,
450{
451    fn get(&self, i: usize) -> Option<&T> {
452        S::get(*self, i)
453    }
454}
455
456impl<T> SliceKey<T> for OwnedSliceKey<T> {
457    fn get(&self, i: usize) -> Option<&T> {
458        self.0.get(i)
459    }
460}
461
462impl<T: Clone> From<&[T]> for OwnedSliceKey<T> {
463    fn from(v: &[T]) -> Self {
464        Self(Vec::from(v).into_boxed_slice())
465    }
466}
467
468impl<T, R: Borrow<T>> SliceKey<T> for VecDeque<R> {
469    fn get(&self, i: usize) -> Option<&T> {
470        self.get(i).map(|r| r.borrow())
471    }
472}
473
474impl<T: Clone> From<&VecDeque<T>> for OwnedSliceKey<T> {
475    fn from(v: &VecDeque<T>) -> Self {
476        Self(v.iter().cloned().collect())
477    }
478}
479
480impl<T> From<&mut VecDeque<T>> for OwnedSliceKey<T> {
481    fn from(v: &mut VecDeque<T>) -> Self {
482        Self(v.drain(..).collect())
483    }
484}
485
486impl<T, R: Borrow<T>> SliceKey<T> for [R] {
487    fn get(&self, i: usize) -> Option<&T> {
488        self.get(i).map(|r| r.borrow())
489    }
490}
491
#[cfg(feature = "std")]
impl<T> std::hash::Hash for dyn SliceKey<T> + '_
where
    T: std::hash::Hash,
{
    /// Feeds every element of the key, in order, into `state`.
    ///
    /// This mirrors the `Hash` impl of [`OwnedSliceKey`], which hashes its
    /// elements the same way.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        for element in slice_key_iter(self) {
            std::hash::Hash::hash(element, state);
        }
    }
}
498
499impl<T: Ord> Ord for dyn SliceKey<T> + '_ {
500    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
501        slice_key_iter(self).cmp(slice_key_iter(other))
502    }
503}
504
505impl<T: Ord> PartialOrd for dyn SliceKey<T> + '_ {
506    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
507        Some(self.cmp(other))
508    }
509}
510
511impl<T: PartialEq> PartialEq for dyn SliceKey<T> + '_ {
512    fn eq(&self, other: &Self) -> bool {
513        slice_key_iter(self).eq(slice_key_iter(other))
514    }
515}
516
517impl<T: Eq> Eq for dyn SliceKey<T> + '_ {}
518
519impl<'b, T: 'b> Borrow<dyn SliceKey<T> + 'b> for OwnedSliceKey<T> {
520    fn borrow(&self) -> &(dyn SliceKey<T> + 'b) {
521        self
522    }
523}
524
525/// Represents a possible map type for [`Chain`](crate::Chain).
526pub trait Map<K>: detail::Sealed<K> {}
527
528#[cfg(feature = "std")]
529impl<K: Eq + std::hash::Hash> Map<K> for Hash {}
530
531impl<K: Ord> Map<K> for BTree {}