amadeus_core/
util.rs

1use derive_new::new;
2use futures::{ready, Stream};
3use pin_project::pin_project;
4use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
5use std::{
6	any::{Any, TypeId}, error, fmt, hash::{Hash, Hasher}, io, marker::PhantomData, pin::Pin, sync::Arc, task::{Context, Poll}
7};
8
9use crate::par_stream::{DistributedStream, ParallelStream};
10
11pub struct ResultExpand<T, E>(pub Result<T, E>);
12impl<T, E> IntoIterator for ResultExpand<T, E>
13where
14	T: IntoIterator,
15{
16	type Item = Result<T::Item, E>;
17	type IntoIter = ResultExpandIter<T::IntoIter, E>;
18
19	fn into_iter(self) -> Self::IntoIter {
20		ResultExpandIter::new(self.0.map(IntoIterator::into_iter))
21	}
22}
23#[pin_project(project = ResultExpandIterProj)]
24pub enum ResultExpandIter<T, E> {
25	Ok(#[pin] T),
26	Err(Option<E>),
27}
28impl<T, E> ResultExpandIter<T, E> {
29	pub fn new(t: Result<T, E>) -> Self {
30		match t {
31			Ok(t) => Self::Ok(t),
32			Err(e) => Self::Err(Some(e)),
33		}
34	}
35}
36impl<T, E> Iterator for ResultExpandIter<T, E>
37where
38	T: Iterator,
39{
40	type Item = Result<T::Item, E>;
41
42	fn next(&mut self) -> Option<Self::Item> {
43		match self {
44			Self::Ok(t) => t.next().map(Ok),
45			Self::Err(e) => e.take().map(Err),
46		}
47	}
48}
49impl<T, E> Stream for ResultExpandIter<T, E>
50where
51	T: Stream,
52{
53	type Item = Result<T::Item, E>;
54
55	fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
56		let ret = match self.project() {
57			ResultExpandIterProj::Ok(t) => ready!(t.poll_next(cx)).map(Ok),
58			ResultExpandIterProj::Err(e) => e.take().map(Err),
59		};
60		Poll::Ready(ret)
61	}
62}
63
64#[derive(Clone, Serialize, Deserialize)]
65#[serde(transparent)]
66pub struct IoError(#[serde(with = "crate::misc_serde")] Arc<io::Error>);
67impl PartialEq for IoError {
68	fn eq(&self, other: &Self) -> bool {
69		self.0.to_string() == other.0.to_string()
70	}
71}
72impl error::Error for IoError {}
73impl fmt::Display for IoError {
74	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75		fmt::Display::fmt(&self.0, f)
76	}
77}
78impl fmt::Debug for IoError {
79	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80		fmt::Debug::fmt(&self.0, f)
81	}
82}
83impl From<io::Error> for IoError {
84	fn from(err: io::Error) -> Self {
85		Self(Arc::new(err))
86	}
87}
88impl From<IoError> for io::Error {
89	fn from(err: IoError) -> Self {
90		Arc::try_unwrap(err.0).unwrap()
91	}
92}
93
94#[pin_project]
95#[derive(new)]
96#[repr(transparent)]
97pub struct DistParStream<S>(#[pin] S);
98impl<S> ParallelStream for DistParStream<S>
99where
100	S: DistributedStream,
101{
102	type Item = S::Item;
103	type Task = S::Task;
104
105	fn size_hint(&self) -> (usize, Option<usize>) {
106		self.0.size_hint()
107	}
108	fn next_task(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Task>> {
109		self.project().0.next_task(cx)
110	}
111}
112
// Workaround to avoid triggering https://github.com/rust-lang/rust/issues/48214 in
// amadeus-derive; see https://github.com/taiki-e/pin-project/issues/102#issuecomment-540472282
#[doc(hidden)]
#[repr(transparent)]
pub struct Wrapper<'a, T: ?Sized>(PhantomData<&'a ()>, T);
impl<'a, T: ?Sized> Wrapper<'a, T> {
	/// Wraps `t`.
	pub fn new(t: T) -> Self
	where
		T: Sized,
	{
		Wrapper(PhantomData, t)
	}
	/// Consumes the wrapper and returns the wrapped value.
	pub fn into_inner(self) -> T
	where
		T: Sized,
	{
		let Wrapper(_, inner) = self;
		inner
	}
}
impl<'a, T: ?Sized> Hash for Wrapper<'a, T>
where
	T: Hash,
{
	/// Hashes exactly as the wrapped value does.
	fn hash<H: Hasher>(&self, state: &mut H) {
		Hash::hash(&self.1, state);
	}
}
139impl<'a, T: ?Sized> Serialize for Wrapper<'a, T>
140where
141	T: Serialize,
142{
143	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
144	where
145		S: Serializer,
146	{
147		self.1.serialize(serializer)
148	}
149}
150impl<'a, 'de, T: ?Sized> Deserialize<'de> for Wrapper<'a, T>
151where
152	T: Deserialize<'de>,
153{
154	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
155	where
156		D: Deserializer<'de>,
157	{
158		T::deserialize(deserializer).map(Wrapper::new)
159	}
160}
161impl<'a, T: ?Sized> fmt::Debug for Wrapper<'a, T>
162where
163	T: fmt::Debug,
164{
165	fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
166		fmt::Debug::fmt(&self.1, f)
167	}
168}
169
/// Reinterprets the bits of `a` as a `B`, like `core::mem::transmute`, but
/// checks at runtime (instead of compile time) that `A` and `B` have the same
/// size *and* alignment. Useful in generic code where the compiler cannot prove
/// the layouts match.
///
/// # Panics
///
/// Panics if `A` and `B` differ in size or in alignment.
///
/// # Safety
///
/// The caller has the same obligations as with `core::mem::transmute`: every
/// bit pattern `a` can hold must be a valid `B`, and any invariants `B` relies
/// on (ownership, lifetimes, aliasing) must hold. The runtime check only guards
/// against layout mismatches, not against invalid values.
#[allow(unsafe_code)]
#[inline(always)]
pub unsafe fn transmute<A, B>(a: A) -> B {
	use std::mem;
	assert_eq!(
		(mem::size_of::<A>(), mem::align_of::<A>()),
		(mem::size_of::<B>(), mem::align_of::<B>())
	);
	// Ownership of `a`'s bytes moves into the returned `B`, so `a`'s destructor
	// must not run.
	let a = mem::ManuallyDrop::new(a);
	// SAFETY: size and alignment were checked equal above; that the bits form a
	// valid `B` is the caller's obligation (see `# Safety`).
	mem::transmute_copy(&*a)
}
187
188#[allow(unsafe_code)]
189#[inline(always)]
190pub fn type_coerce<A, B>(a: A) -> Option<B>
191where
192	A: 'static,
193	B: 'static,
194{
195	if type_eq::<A, B>() {
196		Some(unsafe { transmute(a) })
197	} else {
198		None
199	}
200}
/// Borrows `a` as a `&B` if `A` and `B` are the same type, otherwise `None`.
#[inline(always)]
pub fn type_coerce_ref<A, B>(a: &A) -> Option<&B>
where
	A: 'static,
	B: 'static,
{
	let any: &dyn Any = a;
	any.downcast_ref::<B>()
}
/// Mutably borrows `a` as a `&mut B` if `A` and `B` are the same type,
/// otherwise `None`.
#[inline(always)]
pub fn type_coerce_mut<A, B>(a: &mut A) -> Option<&mut B>
where
	A: 'static,
	B: 'static,
{
	let any: &mut dyn Any = a;
	any.downcast_mut::<B>()
}
217
/// Returns `true` when `A` and `B` are the same type, compared by `TypeId`.
#[inline(always)]
pub fn type_eq<A: ?Sized, B: ?Sized>() -> bool
where
	A: 'static,
	B: 'static,
{
	let (lhs, rhs) = (TypeId::of::<A>(), TypeId::of::<B>());
	lhs == rhs
}
226
/// Converts `x` to `f64`, asserting that the conversion is exact.
///
/// # Panics
///
/// Panics if `x` has no exact `f64` representation (e.g. `2^53 + 1` or
/// `u64::MAX`).
#[allow(
	clippy::cast_possible_truncation,
	clippy::cast_sign_loss,
	clippy::cast_precision_loss
)]
pub fn u64_to_f64(x: u64) -> f64 {
	let ret = x as f64;
	// Compare in u128: casting `ret` back to u64 would saturate 2^64 (which is
	// what u64::MAX rounds to) into u64::MAX and wrongly accept it.
	assert_eq!(
		u128::from(x),
		ret as u128,
		"{} is not exactly representable as f64",
		x
	);
	ret
}
/// Converts `x` to `u64`, asserting that the conversion is exact.
///
/// # Panics
///
/// Panics if `x` is fractional, negative (other than `-0.0`), NaN, infinite,
/// or at least `2^64`.
#[allow(
	clippy::cast_possible_truncation,
	clippy::cast_sign_loss,
	clippy::cast_precision_loss,
	clippy::float_cmp
)]
pub fn f64_to_u64(x: f64) -> u64 {
	let ret = x as u64;
	// The round-trip `x == ret as f64` misses exactly 2^64: `as` saturates it to
	// u64::MAX, which rounds back up to 2^64. The lossless u128 comparison
	// rejects that case.
	assert!(
		x == ret as f64 && x as u128 == u128::from(ret),
		"{} is not exactly representable as u64",
		x
	);
	ret
}