amadeus_streaming/
sort.rs

1#![allow(missing_docs)] // due to FnNamed
2
3use serde::{Deserialize, Serialize};
4use serde_closure::{traits, FnNamed};
5use std::{
6	cmp::Ordering, fmt::{self, Debug}, iter, ops
7};
8
// `NeverEqual<F, T>` wraps a comparator `f: F` over `&T` and forwards to it, except that a
// result of `Ordering::Equal` is reported as `Ordering::Less`. `Sort` uses it as the ordering
// of its underlying set, so elements that compare equal are never treated as duplicates
// (`Sort::push` asserts that every insertion succeeds).
//
// NOTE: the `|self, f: F|a=> &T, b=> &T| -> Ordering where ; where ...` header is the
// `FnNamed!` macro's own syntax (from serde_closure), not ordinary closure syntax. Plain `//`
// comments are used because doc comments may not be accepted inside this macro (compare the
// `allow(missing_docs) // due to FnNamed` at the top of the file).
FnNamed! {
	pub type NeverEqual<F, T> = |self, f: F|a=> &T, b=> &T| -> Ordering where ; where F: (for<'a> traits::Fn<(&'a T, &'a T), Output = Ordering>) {
		match (self.f).call((a, b)) {
			// Ties are turned into `Less`, so the wrapped comparator never says "equal".
			Ordering::Equal => Ordering::Less,
			ord => ord
		}
	}
}
17
/// This data structure tracks the `n` top values given a stream. It uses only `O(n)` space.
///
/// "Top" means smallest according to the comparator `F`: once `n` elements are held, pushing
/// a new element that is accepted evicts the greatest held element. Elements that compare
/// equal are never treated as duplicates (the comparator is wrapped in `NeverEqual`), so
/// repeated values each count towards the `n`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(
	serialize = "T: Serialize, F: Serialize + for<'a> traits::Fn<(&'a T, &'a T), Output = Ordering>",
	deserialize = "T: Deserialize<'de>, F: Deserialize<'de> + for<'a> traits::Fn<(&'a T, &'a T), Output = Ordering>"
))]
pub struct Sort<T, F> {
	/// The retained elements, ordered by `F`; `NeverEqual` ensures no insertion is rejected
	/// as a duplicate.
	top: BTreeSet<T, NeverEqual<F, T>>,
	/// Capacity: the maximum number of elements retained.
	n: usize,
}
28impl<T, F> Sort<T, F> {
29	/// Create an empty `Sort` data structure with the specified `n` capacity.
30	pub fn new(cmp: F, n: usize) -> Self {
31		Self {
32			top: BTreeSet::with_cmp(NeverEqual::new(cmp)),
33			n,
34		}
35	}
36
37	/// The `n` top elements we have capacity to track.
38	pub fn capacity(&self) -> usize {
39		self.n
40	}
41
42	/// The number of elements currently held.
43	pub fn len(&self) -> usize {
44		self.top.len()
45	}
46
47	/// If `.len() == 0`
48	pub fn is_empty(&self) -> bool {
49		self.top.is_empty()
50	}
51
52	/// Clears the `Sort` data structure, as if it was new.
53	pub fn clear(&mut self) {
54		self.top.clear();
55	}
56
57	/// An iterator visiting all elements in ascending order. The iterator element type is `&'_ T`.
58	pub fn iter(&self) -> std::collections::btree_set::Iter<'_, T> {
59		self.top.iter()
60	}
61}
62#[cfg_attr(not(nightly), serde_closure::desugar)]
63impl<T, F> Sort<T, F>
64where
65	F: traits::Fn(&T, &T) -> Ordering,
66{
67	/// "Visit" an element.
68	pub fn push(&mut self, item: T) {
69		let mut at_capacity = false;
70		if self.top.len() < self.n || {
71			at_capacity = true;
72			!matches!(self.top.partial_cmp(&item), Some(Ordering::Less))
73		} {
74			let x = self.top.insert(item);
75			assert!(x);
76			if at_capacity {
77				let _ = self.top.pop_last().unwrap();
78			}
79		}
80	}
81}
82impl<T, F> IntoIterator for Sort<T, F> {
83	type Item = T;
84	type IntoIter = std::collections::btree_set::IntoIter<T>;
85
86	fn into_iter(self) -> Self::IntoIter {
87		self.top.into_iter()
88	}
89}
90#[cfg_attr(not(nightly), serde_closure::desugar)]
91impl<T, F> iter::Sum<Sort<T, F>> for Option<Sort<T, F>>
92where
93	F: traits::Fn(&T, &T) -> Ordering,
94{
95	fn sum<I>(mut iter: I) -> Self
96	where
97		I: Iterator<Item = Sort<T, F>>,
98	{
99		let mut total = iter.next()?;
100		for sample in iter {
101			total += sample;
102		}
103		Some(total)
104	}
105}
106#[cfg_attr(not(nightly), serde_closure::desugar)]
107impl<T, F> ops::Add for Sort<T, F>
108where
109	F: traits::Fn(&T, &T) -> Ordering,
110{
111	type Output = Self;
112
113	fn add(mut self, other: Self) -> Self {
114		self += other;
115		self
116	}
117}
118#[cfg_attr(not(nightly), serde_closure::desugar)]
119impl<T, F> ops::AddAssign for Sort<T, F>
120where
121	F: traits::Fn(&T, &T) -> Ordering,
122{
123	fn add_assign(&mut self, other: Self) {
124		assert_eq!(self.n, other.n);
125		for t in other.top {
126			self.push(t);
127		}
128	}
129}
130impl<T, F> Debug for Sort<T, F>
131where
132	T: Debug,
133{
134	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135		f.debug_list().entries(self.iter()).finish()
136	}
137}
138
139use btree_set::BTreeSet;
140mod btree_set {
141	use serde::{Deserialize, Deserializer, Serialize, Serializer};
142	use serde_closure::traits;
143	use std::{
144		borrow::Borrow, cmp::Ordering, collections::btree_set, marker::PhantomData, mem::{self, ManuallyDrop, MaybeUninit}
145	};
146
147	#[derive(Clone, Serialize, Deserialize)]
148	#[serde(bound(
149		serialize = "T: Serialize, F: Serialize + for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering>",
150		deserialize = "T: Deserialize<'de>, F: Deserialize<'de> + for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering>"
151	))]
152	pub struct BTreeSet<T, F> {
153		set: std::collections::BTreeSet<Node<T, F>>,
154		cmp: F,
155	}
156	impl<T, F> BTreeSet<T, F> {
157		// pub fn new() -> BTreeSet<T, Cmp> {
158		//     Self::with_cmp(Cmp)
159		// }
160		pub fn with_cmp(cmp: F) -> Self {
161			// Sound due to repr(transparent)
162			let set = unsafe {
163				mem::transmute::<
164					btree_set::BTreeSet<TrivialOrd<Node<T, F>>>,
165					btree_set::BTreeSet<Node<T, F>>,
166				>(btree_set::BTreeSet::new())
167			};
168			Self { set, cmp }
169		}
170		pub fn cmp(&self) -> &F {
171			&self.cmp
172		}
173		pub fn cmp_mut(&mut self) -> &mut F {
174			&mut self.cmp
175		}
176		pub fn clear(&mut self) {
177			self.trivial_ord_mut().clear();
178		}
179		pub fn iter(&self) -> btree_set::Iter<'_, T> {
180			// Sound due to repr(transparent)
181			unsafe { mem::transmute(self.set.iter()) }
182		}
183		pub fn len(&self) -> usize {
184			self.set.len()
185		}
186		pub fn is_empty(&self) -> bool {
187			self.set.is_empty()
188		}
189		pub fn pop_last(&mut self) -> Option<T> {
190			#[cfg(nightly)]
191			return self.trivial_ord_mut().pop_last().map(|value| value.0.t);
192			#[cfg(not(nightly))]
193			todo!();
194		}
195		fn trivial_ord_mut(&mut self) -> &mut std::collections::BTreeSet<TrivialOrd<Node<T, F>>> {
196			let set: *mut std::collections::BTreeSet<Node<T, F>> = &mut self.set;
197			let set: *mut std::collections::BTreeSet<TrivialOrd<Node<T, F>>> = set.cast();
198			// Sound due to repr(transparent)
199			unsafe { &mut *set }
200		}
201	}
202	impl<T, F> BTreeSet<T, F>
203	where
204		F: for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering>,
205	{
206		pub fn insert(&mut self, value: T) -> bool {
207			self.set.insert(Node::new(value, &self.cmp))
208		}
209		pub fn remove(&mut self, value: &T) -> Option<T> {
210			let value: *const T = value;
211			let value: *const TrivialOrd<T> = value.cast();
212			let value = unsafe { &*value };
213			self.set.take(value).map(|node| node.t)
214		}
215	}
216	impl<T, F> IntoIterator for BTreeSet<T, F> {
217		type Item = T;
218		type IntoIter = btree_set::IntoIter<T>;
219
220		fn into_iter(self) -> Self::IntoIter {
221			// Sound due to repr(transparent)
222			unsafe { mem::transmute(self.set.into_iter()) }
223		}
224	}
225	#[cfg_attr(not(nightly), serde_closure::desugar)]
226	impl<T, F> PartialEq<T> for BTreeSet<T, F>
227	where
228		F: traits::Fn(&T, &T) -> Ordering,
229	{
230		fn eq(&self, other: &T) -> bool {
231			matches!(
232				self.cmp.call((self.iter().next().unwrap(), other)),
233				Ordering::Equal
234			) && matches!(
235				self.cmp.call((self.iter().last().unwrap(), other)),
236				Ordering::Equal
237			)
238		}
239	}
240	#[cfg_attr(not(nightly), serde_closure::desugar)]
241	impl<T, F> PartialOrd<T> for BTreeSet<T, F>
242	where
243		F: traits::Fn(&T, &T) -> Ordering,
244	{
245		fn partial_cmp(&self, other: &T) -> Option<Ordering> {
246			match (
247				self.cmp.call((self.iter().next().unwrap(), other)),
248				self.cmp.call((self.iter().last().unwrap(), other)),
249			) {
250				(Ordering::Less, Ordering::Less) => Some(Ordering::Less),
251				(Ordering::Equal, Ordering::Equal) => Some(Ordering::Equal),
252				(Ordering::Greater, Ordering::Greater) => Some(Ordering::Greater),
253				_ => None,
254			}
255		}
256	}
257
258	#[repr(transparent)]
259	struct TrivialOrd<T: ?Sized>(T);
260	impl<T: ?Sized> PartialEq for TrivialOrd<T> {
261		fn eq(&self, _other: &Self) -> bool {
262			unreachable!()
263		}
264	}
265	impl<T: ?Sized> Eq for TrivialOrd<T> {}
266	impl<T: ?Sized> PartialOrd for TrivialOrd<T> {
267		fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
268			unreachable!()
269		}
270	}
271	impl<T: ?Sized> Ord for TrivialOrd<T> {
272		fn cmp(&self, _other: &Self) -> Ordering {
273			unreachable!()
274		}
275	}
276
277	#[repr(transparent)]
278	struct Node<T, F: ?Sized> {
279		t: T,
280		marker: PhantomData<fn() -> F>,
281	}
282	impl<T, F: ?Sized> Node<T, F> {
283		fn new(t: T, f: &F) -> Self {
284			if mem::size_of_val(f) != 0 {
285				panic!("Closures with nonzero size not supported");
286			}
287			Self {
288				t,
289				marker: PhantomData,
290			}
291		}
292	}
293	impl<T, F: ?Sized> Borrow<T> for Node<T, F> {
294		fn borrow(&self) -> &T {
295			&self.t
296		}
297	}
298	impl<T, F: ?Sized> Borrow<TrivialOrd<T>> for Node<T, F> {
299		fn borrow(&self) -> &TrivialOrd<T> {
300			let self_: *const T = &self.t;
301			let self_: *const TrivialOrd<T> = self_.cast();
302			unsafe { &*self_ }
303		}
304	}
305	impl<T, F> PartialEq for Node<T, F>
306	where
307		F: for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering>,
308	{
309		fn eq(&self, other: &Self) -> bool {
310			matches!(self.cmp(other), Ordering::Equal)
311		}
312	}
313	impl<T, F> Eq for Node<T, F> where F: for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering> {}
314	impl<T, F> PartialOrd for Node<T, F>
315	where
316		F: for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering>,
317	{
318		fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
319			Some(self.cmp(other))
320		}
321	}
322	impl<T, F> Ord for Node<T, F>
323	where
324		F: for<'a> traits::FnMut<(&'a T, &'a T), Output = Ordering>,
325	{
326		fn cmp(&self, other: &Self) -> Ordering {
327			// This is safe as an F has already been materialized (so we know it isn't
328			// uninhabited) and its size is zero. Related:
329			// https://internals.rust-lang.org/t/is-synthesizing-zero-sized-values-safe/11506
330			#[allow(clippy::uninit_assumed_init)]
331			let mut cmp: ManuallyDrop<F> = unsafe { MaybeUninit::uninit().assume_init() };
332			cmp.call_mut((&self.t, &other.t))
333		}
334	}
335
336	impl<T, F: ?Sized> Clone for Node<T, F>
337	where
338		T: Clone,
339	{
340		fn clone(&self) -> Self {
341			Self {
342				t: self.t.clone(),
343				marker: PhantomData,
344			}
345		}
346	}
347	impl<T, F: ?Sized> Serialize for Node<T, F>
348	where
349		T: Serialize,
350	{
351		fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
352		where
353			S: Serializer,
354		{
355			self.t.serialize(serializer)
356		}
357	}
358	impl<'de, T, F: ?Sized> Deserialize<'de> for Node<T, F>
359	where
360		T: Deserialize<'de>,
361	{
362		fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
363		where
364			D: Deserializer<'de>,
365		{
366			T::deserialize(deserializer).map(|t| Self {
367				t,
368				marker: PhantomData,
369			})
370		}
371	}
372}