1use std::fmt::{Debug, Display};
2use std::task::Waker;
3
4use futures::future::{BoxFuture, FusedFuture};
5use futures::stream::FusedStream;
6use futures::{Future, FutureExt, Stream, StreamExt};
7use imbl_value::Value;
8use serde::de::DeserializeOwned;
9use serde::ser::Error;
10use serde::{Deserialize, Serialize};
11use yajrc::RpcError;
12
13pub fn extract<T: DeserializeOwned>(value: &Value) -> Result<T, RpcError> {
14 imbl_value::from_value(value.clone()).map_err(invalid_params)
15}
16
17pub fn without<T: Serialize>(value: Value, remove: &T) -> Result<Value, imbl_value::Error> {
18 let to_remove = imbl_value::to_value(remove)?;
19 let (Value::Object(mut value), Value::Object(to_remove)) = (value, to_remove) else {
20 return Err(imbl_value::Error {
21 kind: imbl_value::ErrorKind::Serialization,
22 source: serde_json::Error::custom("params must be object"),
23 });
24 };
25 for k in to_remove.keys() {
26 value.remove(k);
27 }
28 Ok(Value::Object(value))
29}
30
31pub fn combine(v1: Value, v2: Value) -> Result<Value, imbl_value::Error> {
32 let (Value::Object(mut v1), Value::Object(v2)) = (v1, v2) else {
33 return Err(imbl_value::Error {
34 kind: imbl_value::ErrorKind::Serialization,
35 source: serde_json::Error::custom("params must be object"),
36 });
37 };
38 for (key, value) in v2 {
39 if v1.insert(key.clone(), value).is_some() {
40 return Err(imbl_value::Error {
41 kind: imbl_value::ErrorKind::Serialization,
42 source: serde_json::Error::custom(lazy_format::lazy_format!(
43 "duplicate key: {key}"
44 )),
45 });
46 }
47 }
48 Ok(Value::Object(v1))
49}
50
51pub fn invalid_params(e: imbl_value::Error) -> RpcError {
52 RpcError {
53 data: Some(e.to_string().into()),
54 ..yajrc::INVALID_PARAMS_ERROR
55 }
56}
57
58pub fn invalid_request(e: imbl_value::Error) -> RpcError {
59 RpcError {
60 data: Some(e.to_string().into()),
61 ..yajrc::INVALID_REQUEST_ERROR
62 }
63}
64
65pub fn parse_error(e: impl Display) -> RpcError {
66 RpcError {
67 data: Some(e.to_string().into()),
68 ..yajrc::PARSE_ERROR
69 }
70}
71
72pub fn internal_error(e: impl Display) -> RpcError {
73 RpcError {
74 data: Some(e.to_string().into()),
75 ..yajrc::INTERNAL_ERROR
76 }
77}
78
79pub struct Flat<A, B>(pub A, pub B);
80impl<'de, A, B> Deserialize<'de> for Flat<A, B>
81where
82 A: DeserializeOwned,
83 B: DeserializeOwned,
84{
85 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
86 where
87 D: serde::Deserializer<'de>,
88 {
89 let v = Value::deserialize(deserializer)?;
90 let a = imbl_value::from_value(v.clone()).map_err(serde::de::Error::custom)?;
91 let b = imbl_value::from_value(v).map_err(serde::de::Error::custom)?;
92 Ok(Flat(a, b))
93 }
94}
95impl<A, B> Serialize for Flat<A, B>
96where
97 A: Serialize,
98 B: Serialize,
99{
100 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101 where
102 S: serde::Serializer,
103 {
104 #[derive(serde::Serialize)]
105 struct FlatStruct<'a, A, B> {
106 #[serde(flatten)]
107 a: &'a A,
108 #[serde(flatten)]
109 b: &'a B,
110 }
111 FlatStruct {
112 a: &self.0,
113 b: &self.1,
114 }
115 .serialize(serializer)
116 }
117}
118impl<A, B> clap::CommandFactory for Flat<A, B>
119where
120 A: clap::CommandFactory,
121 B: clap::Args,
122{
123 fn command() -> clap::Command {
124 B::augment_args(A::command())
125 }
126 fn command_for_update() -> clap::Command {
127 B::augment_args_for_update(A::command_for_update())
128 }
129}
130impl<A, B> clap::FromArgMatches for Flat<A, B>
131where
132 A: clap::FromArgMatches,
133 B: clap::FromArgMatches,
134{
135 fn from_arg_matches(matches: &clap::ArgMatches) -> Result<Self, clap::Error> {
136 Ok(Self(
137 A::from_arg_matches(matches)?,
138 B::from_arg_matches(matches)?,
139 ))
140 }
141 fn from_arg_matches_mut(matches: &mut clap::ArgMatches) -> Result<Self, clap::Error> {
142 Ok(Self(
143 A::from_arg_matches_mut(matches)?,
144 B::from_arg_matches_mut(matches)?,
145 ))
146 }
147 fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> {
148 self.0.update_from_arg_matches(matches)?;
149 self.1.update_from_arg_matches(matches)?;
150 Ok(())
151 }
152 fn update_from_arg_matches_mut(
153 &mut self,
154 matches: &mut clap::ArgMatches,
155 ) -> Result<(), clap::Error> {
156 self.0.update_from_arg_matches_mut(matches)?;
157 self.1.update_from_arg_matches_mut(matches)?;
158 Ok(())
159 }
160}
161
162pub fn poll_select_all<'a, T>(
163 futs: &mut Vec<BoxFuture<'a, T>>,
164 cx: &mut std::task::Context<'_>,
165) -> std::task::Poll<T> {
166 let item = futs
167 .iter_mut()
168 .enumerate()
169 .find_map(|(i, f)| match f.poll_unpin(cx) {
170 std::task::Poll::Pending => None,
171 std::task::Poll::Ready(e) => Some((i, e)),
172 });
173 match item {
174 Some((idx, res)) => {
175 drop(futs.swap_remove(idx));
176 std::task::Poll::Ready(res)
177 }
178 None => std::task::Poll::Pending,
179 }
180}
181
182pub struct JobRunner<'a, T> {
183 wakers: Vec<Waker>,
184 closed: bool,
185 running: Vec<BoxFuture<'a, T>>,
186}
187impl<'a, T> JobRunner<'a, T> {
188 pub fn new() -> Self {
189 JobRunner {
190 wakers: Vec::new(),
191 closed: false,
192 running: Vec::new(),
193 }
194 }
195 pub async fn next_result<
196 Src: Stream<Item = Fut> + Unpin,
197 Fut: Future<Output = T> + Send + 'a,
198 >(
199 &mut self,
200 job_source: &mut Src,
201 ) -> Option<T> {
202 let mut job_source = Some(job_source);
203 loop {
204 let next_job_fut = async {
205 if let Some(job_source) = &mut job_source {
206 job_source.next().await
207 } else {
208 futures::future::pending().await
209 }
210 };
211 tokio::select! {
212 job = next_job_fut => {
213 if let Some(job) = job {
214 self.running.push(job.boxed());
215 while let Some(waker) = self.wakers.pop() {
216 waker.wake();
217 }
218 } else {
219 job_source.take();
220 self.closed = true;
221 if self.running.is_empty() {
222 return None;
223 }
224 }
225 }
226 res = self.next() => {
227 return res;
228 }
229 }
230 }
231 }
232}
233impl<'a, T> Stream for JobRunner<'a, T> {
234 type Item = T;
235 fn poll_next(
236 mut self: std::pin::Pin<&mut Self>,
237 cx: &mut std::task::Context<'_>,
238 ) -> std::task::Poll<Option<Self::Item>> {
239 if self.running.is_empty() {
240 self.wakers.push(cx.waker().clone());
241 return std::task::Poll::Pending;
242 }
243 match poll_select_all(&mut self.running, cx) {
244 std::task::Poll::Pending if self.closed && self.running.is_empty() => {
245 std::task::Poll::Ready(None)
246 }
247 a => a.map(Some),
248 }
249 }
250}
251
252#[pin_project::pin_project]
253pub struct StreamUntil<S, F> {
254 #[pin]
255 stream: S,
256 #[pin]
257 until: F,
258 done: bool,
259}
260impl<S, F> StreamUntil<S, F> {
261 pub fn new(stream: S, until: F) -> Self {
262 Self {
263 stream,
264 until,
265 done: false,
266 }
267 }
268}
269impl<S, F> Stream for StreamUntil<S, F>
270where
271 S: Stream,
272 F: Future,
273{
274 type Item = S::Item;
275 fn poll_next(
276 self: std::pin::Pin<&mut Self>,
277 cx: &mut std::task::Context<'_>,
278 ) -> std::task::Poll<Option<Self::Item>> {
279 let this = self.project();
280 *this.done = *this.done || this.until.poll(cx).is_ready();
281 if *this.done {
282 std::task::Poll::Ready(None)
283 } else {
284 this.stream.poll_next(cx)
285 }
286 }
287}
288impl<S, F> FusedStream for StreamUntil<S, F>
289where
290 S: FusedStream,
291 F: FusedFuture,
292{
293 fn is_terminated(&self) -> bool {
294 self.done || self.stream.is_terminated() || self.until.is_terminated()
295 }
296}
297
/// A zero-sized type marker for `T`.
///
/// Unlike [`std::marker::PhantomData`], this is `Send`, `Sync`, `Clone` and
/// `Default` regardless of `T`. That is appropriate because no value of `T`
/// is ever stored in or reachable through it.
pub struct PhantomData<T>(std::marker::PhantomData<T>);
impl<T> PhantomData<T> {
    /// Creates the marker.
    pub fn new() -> Self {
        PhantomData(std::marker::PhantomData)
    }
}
/// Same as [`PhantomData::new`]; no `T: Default` bound is needed.
impl<T> Default for PhantomData<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// Hand-written rather than derived so that no `T: Clone` bound is required.
impl<T> Clone for PhantomData<T> {
    fn clone(&self) -> Self {
        PhantomData::new()
    }
}
/// Delegates to the inner `std::marker::PhantomData`'s `Debug` output.
impl<T> Debug for PhantomData<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}
// SAFETY: the only field is a zero-sized `std::marker::PhantomData<T>`, and
// no method of this type produces or exposes a `T`. Sending the marker to
// another thread therefore cannot transfer any `T` across threads.
unsafe impl<T> Send for PhantomData<T> {}
// SAFETY: for the same reason, a shared reference to the marker gives no
// access to any `T`, so sharing it between threads cannot race on a `T`.
unsafe impl<T> Sync for PhantomData<T> {}