1use alloc::collections::BTreeMap;
2use core::hash::Hash;
3use core::{fmt, mem};
4
5use crate::transaction::{Merge, NoOutput, Transaction};
6
/// Transaction precondition error for a keyed collection, identifying the entry at which
/// the precondition was not met.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct MapMismatch<K, E> {
    /// Key of the entry whose precondition was not met.
    pub key: K,
    /// The precondition error reported for that entry.
    pub mismatch: E,
}
16
/// Merge conflict for a keyed collection, identifying the key whose two values could not
/// be merged.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct MapConflict<K, C> {
    /// Key under which the two values conflicted.
    pub key: K,
    /// The conflict reported when merging the two values stored under `key`.
    pub conflict: C,
}
26
27impl<K: fmt::Debug, E: core::error::Error + 'static> core::error::Error for MapMismatch<K, E> {
28 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
29 Some(&self.mismatch)
30 }
31}
32
33impl<K: fmt::Debug, C: core::error::Error + 'static> core::error::Error for MapConflict<K, C> {
34 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
35 Some(&self.conflict)
36 }
37}
38
39impl<K: fmt::Debug, E> fmt::Display for MapMismatch<K, E> {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 let MapMismatch { key, mismatch: _ } = self;
42 write!(f, "transaction precondition not met at key {key:?}")
43 }
44}
45
46impl<K: fmt::Debug, C> fmt::Display for MapConflict<K, C> {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 let MapConflict { key, conflict: _ } = self;
49 write!(f, "transaction conflict at key {key:?}")
50 }
51}
52
53impl<K, V> Merge for BTreeMap<K, V>
54where
55 K: Clone + Ord + fmt::Debug + 'static,
56 V: Default + Merge,
57{
58 type MergeCheck = BTreeMap<K, <V as Merge>::MergeCheck>;
59 type Conflict = MapConflict<K, <V as Merge>::Conflict>;
60
61 fn check_merge<'a>(&'a self, mut map2: &'a Self) -> Result<Self::MergeCheck, Self::Conflict> {
62 let mut map1 = self;
63 if map1.len() > map2.len() {
64 mem::swap(&mut map1, &mut map2);
71 }
72 let mut checks = BTreeMap::new();
73 for (k, v1) in map1.iter() {
74 if let Some(v2) = map2.get(k) {
75 checks.insert(
76 k.clone(),
77 v1.check_merge(v2).map_err(|conflict| MapConflict {
78 key: k.clone(),
79 conflict,
80 })?,
81 );
82 }
83 }
84 Ok(checks)
85 }
86
87 fn commit_merge(&mut self, mut other: Self, mut check: Self::MergeCheck) {
88 if other.len() > self.len() {
89 mem::swap(self, &mut other);
90 }
91 for (k, v2) in other {
92 use alloc::collections::btree_map::Entry::*;
93 match self.entry(k) {
94 Occupied(mut entry) => {
95 let entry_check = check.remove(entry.key()).unwrap();
96 entry.get_mut().commit_merge(v2, entry_check);
97 }
98 Vacant(entry) => {
99 entry.insert(v2);
100 }
101 }
102 }
103 }
104}
105
/// Implements [`Merge`] for the `HashMap` type found in module `$module`.
///
/// The module must be named by a single identifier (for example a `use … as` alias),
/// because the macro matches `$module:ident`. The generated impl works like the
/// `BTreeMap` impl:
/// - keys found in only one map are carried over unchanged;
/// - values under shared keys are merged with `V`'s own [`Merge`] impl;
/// - any conflict is reported as a [`MapConflict`] naming the key.
macro_rules! hashmap_merge {
    ($module:ident) => {
        impl<K, V, S> Merge for $module::HashMap<K, V, S>
        where
            K: Clone + Eq + Hash + fmt::Debug + 'static,
            V: Default + Merge,
            S: core::hash::BuildHasher + Default + 'static,
        {
            /// Per-value merge checks, one for each key present in both maps.
            type MergeCheck = $module::HashMap<K, <V as Merge>::MergeCheck, S>;
            type Conflict = MapConflict<K, <V as Merge>::Conflict>;

            fn check_merge<'a>(
                &'a self,
                mut map2: &'a Self,
            ) -> Result<Self::MergeCheck, Self::Conflict> {
                // Only keys present in both maps need checking. Iterate over the smaller
                // map and look each key up in the larger one.
                // NOTE(review): `map1` is always the smaller map, so values are checked as
                // `smaller.check_merge(larger)`, whereas `commit_merge` merges the smaller
                // map's values into the larger map's. When the maps differ in size these
                // orders are opposite. That is only correct if `V`'s merge is
                // order-independent — confirm against the `Merge` trait contract.
                let mut map1 = self;
                if map1.len() > map2.len() {
                    mem::swap(&mut map1, &mut map2);
                }
                let mut checks = $module::HashMap::default();
                for (k, v1) in map1.iter() {
                    let Some(v2) = map2.get(k) else {
                        continue;
                    };
                    let entry_check = v1
                        .check_merge(v2)
                        .map_err(|conflict| MapConflict { key: k.clone(), conflict })?;
                    checks.insert(k.clone(), entry_check);
                }
                Ok(checks)
            }

            /// Moves every entry of `other` into `self`, merging values under shared keys.
            ///
            /// # Panics
            ///
            /// Panics if `check` has no entry for some key present in both maps. This
            /// should not happen when `check` came from `check_merge` on the same two maps.
            fn commit_merge(&mut self, mut other: Self, mut check: Self::MergeCheck) {
                // Drain the smaller map into the larger one, so fewer entries are moved.
                if other.len() > self.len() {
                    mem::swap(self, &mut other);
                }
                for (k, v2) in other {
                    use $module::Entry::*;
                    match self.entry(k) {
                        Occupied(mut entry) => {
                            let entry_check = check.remove(entry.key()).expect(
                                "merge check is missing an entry for a key present in both maps; \
                                 it must come from `check_merge` on these same maps",
                            );
                            entry.get_mut().commit_merge(v2, entry_check);
                        }
                        Vacant(entry) => {
                            entry.insert(v2);
                        }
                    }
                }
            }
        }
    };
}
162
// `hashmap_merge!` takes its module as a single identifier (`$module:ident`), so each
// `hash_map` module is first given a one-word alias.
//
// The standard library's `HashMap` gets a `Merge` impl only when the `std` feature is on.
#[cfg(feature = "std")]
use std::collections::hash_map as std_map;
#[cfg(feature = "std")]
hashmap_merge!(std_map);

// `hashbrown`'s `HashMap` gets a `Merge` impl unconditionally.
use hashbrown::hash_map as hb_map;
hashmap_merge!(hb_map);
170
/// Implements [`Transaction`] and [`Merge`] for one tuple arity, together with that
/// arity's element-indexed error enums.
///
/// Invoked as `impl_transaction_for_tuple!(N: 0, 1, …, N-1)`. `paste` glues each literal
/// index onto identifiers, so element `i` gets:
/// - type parameter `Tr{i}` (for `Transaction`) or `T{i}` (for `Merge`);
/// - error variant `At{i}` in the enums `TupleError{N}` / `TupleConflict{N}`.
///
/// Every element transaction must have `Output = NoOutput`.
macro_rules! impl_transaction_for_tuple {
    ( $count:literal : $( $name:literal ),* ) => {
        paste::paste! {
            // Precondition error: variant `At{i}` wraps element `i`'s mismatch.
            #[doc = concat!("Transaction precondition error type for tuples of length ", $count, ".")]
            #[derive(Clone, Copy, Debug, Eq, PartialEq)]
            #[expect(clippy::exhaustive_enums)]
            pub enum [< TupleError $count >]<$( [<E $name>], )*> {
                $(
                    #[doc = concat!("Error at tuple element ", $name, ".")]
                    [<At $name>]([<E $name>]),
                )*
            }

            // Merge conflict: variant `At{i}` wraps element `i`'s conflict.
            #[doc = concat!("Transaction conflict error type for tuples of length ", $count, ".")]
            #[derive(Clone, Copy, Debug, Eq, PartialEq)]
            #[expect(clippy::exhaustive_enums)]
            pub enum [< TupleConflict $count >]<$( [<C $name>], )*> {
                $(
                    #[doc = concat!("Conflict at tuple element ", $name, ".")]
                    [<At $name>]([<C $name>]),
                )*
            }

            impl<$( [<Tr $name>] ),*> Transaction for ($( [<Tr $name>], )*)
            where
                $( [<Tr $name>]: Transaction<Output = NoOutput> ),*
            {
                type Target = ($( [<Tr $name>]::Target, )*);
                type Context<'a> = ($( [<Tr $name>]::Context<'a>, )*);
                type CommitCheck = ($( <[<Tr $name>] as Transaction>::CommitCheck, )*);
                type Output = NoOutput;
                type Mismatch =
                    [< TupleError $count >]<$( <[<Tr $name>] as Transaction>::Mismatch, )*>;

                // Checks element by element, in index order, and stops at the first mismatch.
                #[allow(unused_variables, reason = "empty tuple case")]
                fn check(
                    &self,
                    target: &($( [<Tr $name>]::Target, )*),
                    context: ($( [<Tr $name>]::Context<'_>, )*),
                ) -> Result<Self::CommitCheck, Self::Mismatch> {
                    let ($( [<tx_ $name>], )*) = self;
                    let ($( [<tgt_ $name>], )*) = target;
                    let ($( [<ctx_ $name>], )*) = context;
                    $(
                        let [<chk_ $name>] = [<tx_ $name>]
                            .check([<tgt_ $name>], [<ctx_ $name>])
                            .map_err([< TupleError $count >]::[<At $name>])?;
                    )*
                    Ok(($( [<chk_ $name>], )*))
                }

                // Commits element by element, in index order, and stops at the first error.
                // Every element shares the same `outputs` sink.
                fn commit(
                    self,
                    #[allow(unused_variables, reason = "empty tuple case")]
                    target: &mut ($( [<Tr $name>]::Target, )*),
                    check: Self::CommitCheck,
                    outputs: &mut dyn FnMut(Self::Output),
                ) -> Result<(), super::CommitError> {
                    let ($( [<tx_ $name>], )*) = self;
                    let ($( [<chk_ $name>], )*) = check;
                    let ($( [<tgt_ $name>], )*) = target;
                    $( [<tx_ $name>].commit([<tgt_ $name>], [<chk_ $name>], outputs)?; )*
                    Ok(())
                }
            }

            impl<$( [<T $name>] ),*> Merge for ($( [<T $name>], )*)
            where
                $( [<T $name>]: Merge ),*
            {
                type MergeCheck = ($( <[<T $name>] as Merge>::MergeCheck, )*);
                type Conflict =
                    [< TupleConflict $count >]<$( <[<T $name>] as Merge>::Conflict, )*>;

                // Pairs up same-index elements; stops at the first conflicting pair.
                fn check_merge(&self, other: &Self) -> Result<Self::MergeCheck, Self::Conflict> {
                    let ($( [<ours_ $name>], )*) = self;
                    let ($( [<theirs_ $name>], )*) = other;
                    $(
                        let [<merge_check_ $name>] = [<ours_ $name>]
                            .check_merge([<theirs_ $name>])
                            .map_err([< TupleConflict $count >]::[<At $name>])?;
                    )*
                    Ok(($( [<merge_check_ $name>], )*))
                }

                fn commit_merge(&mut self, other: Self, check: Self::MergeCheck) {
                    let ($( [<ours_ $name>], )*) = self;
                    let ($( [<theirs_ $name>], )*) = other;
                    let ($( [<merge_check_ $name>], )*) = check;
                    $( [<ours_ $name>].commit_merge([<theirs_ $name>], [<merge_check_ $name>]); )*
                }
            }

            // The error enums are transparent wrappers: both `source` and `Display` forward
            // directly to the wrapped element error.
            impl<$( [<E $name>]: core::error::Error, )*> core::error::Error
                for [< TupleError $count >]<$( [<E $name>], )*>
            {
                fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
                    match *self {
                        $( Self::[<At $name>](ref inner) => core::error::Error::source(inner), )*
                    }
                }
            }

            impl<$( [<C $name>]: core::error::Error, )*> core::error::Error
                for [< TupleConflict $count >]<$( [<C $name>], )*>
            {
                fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
                    match *self {
                        $( Self::[<At $name>](ref inner) => core::error::Error::source(inner), )*
                    }
                }
            }

            impl<$( [<E $name>]: fmt::Display, )*> fmt::Display
                for [< TupleError $count >]<$( [<E $name>], )*>
            {
                fn fmt(&self, #[allow(unused)] f: &mut fmt::Formatter<'_>) -> fmt::Result {
                    match *self {
                        $( Self::[<At $name>](ref inner) => fmt::Display::fmt(inner, f), )*
                    }
                }
            }

            impl<$( [<C $name>]: fmt::Display, )*> fmt::Display
                for [< TupleConflict $count >]<$( [<C $name>], )*>
            {
                fn fmt(&self, #[allow(unused)] f: &mut fmt::Formatter<'_>) -> fmt::Result {
                    match *self {
                        $( Self::[<At $name>](ref inner) => fmt::Display::fmt(inner, f), )*
                    }
                }
            }
        }
    };
}
318
// Tuple arities 1 through 6. Each invocation gives the arity, then the element indices
// `0..arity`. The unit tuple `()` has hand-written impls instead.
impl_transaction_for_tuple!(1: 0);
impl_transaction_for_tuple!(2: 0, 1);
impl_transaction_for_tuple!(3: 0, 1, 2);
impl_transaction_for_tuple!(4: 0, 1, 2, 3);
impl_transaction_for_tuple!(5: 0, 1, 2, 3, 4);
impl_transaction_for_tuple!(6: 0, 1, 2, 3, 4, 5);
325
326impl Transaction for () {
331 type Target = ();
332 type Context<'a> = ();
333 type CommitCheck = ();
334 type Output = core::convert::Infallible;
335 type Mismatch = core::convert::Infallible;
336
337 fn check(&self, (): &(), (): Self::Context<'_>) -> Result<Self::CommitCheck, Self::Mismatch> {
338 Ok(())
339 }
340
341 fn commit(
342 self,
343 (): &mut (),
344 (): Self::CommitCheck,
345 _: &mut dyn FnMut(Self::Output),
346 ) -> Result<(), super::CommitError> {
347 Ok(())
348 }
349}
350
351impl Merge for () {
355 type MergeCheck = ();
356
357 type Conflict = core::convert::Infallible;
358
359 fn check_merge(&self, (): &Self) -> Result<Self::MergeCheck, Self::Conflict> {
360 Ok(())
361 }
362
363 fn commit_merge(&mut self, (): Self, (): Self::MergeCheck) {}
364}