rpc_toolkit/
util.rs

1use std::fmt::{Debug, Display};
2use std::task::Waker;
3
4use futures::future::{BoxFuture, FusedFuture};
5use futures::stream::FusedStream;
6use futures::{Future, FutureExt, Stream, StreamExt};
7use imbl_value::Value;
8use serde::de::DeserializeOwned;
9use serde::ser::Error;
10use serde::{Deserialize, Serialize};
11use yajrc::RpcError;
12
13pub fn extract<T: DeserializeOwned>(value: &Value) -> Result<T, RpcError> {
14    imbl_value::from_value(value.clone()).map_err(invalid_params)
15}
16
17pub fn without<T: Serialize>(value: Value, remove: &T) -> Result<Value, imbl_value::Error> {
18    let to_remove = imbl_value::to_value(remove)?;
19    let (Value::Object(mut value), Value::Object(to_remove)) = (value, to_remove) else {
20        return Err(imbl_value::Error {
21            kind: imbl_value::ErrorKind::Serialization,
22            source: serde_json::Error::custom("params must be object"),
23        });
24    };
25    for k in to_remove.keys() {
26        value.remove(k);
27    }
28    Ok(Value::Object(value))
29}
30
31pub fn combine(v1: Value, v2: Value) -> Result<Value, imbl_value::Error> {
32    let (Value::Object(mut v1), Value::Object(v2)) = (v1, v2) else {
33        return Err(imbl_value::Error {
34            kind: imbl_value::ErrorKind::Serialization,
35            source: serde_json::Error::custom("params must be object"),
36        });
37    };
38    for (key, value) in v2 {
39        if v1.insert(key.clone(), value).is_some() {
40            return Err(imbl_value::Error {
41                kind: imbl_value::ErrorKind::Serialization,
42                source: serde_json::Error::custom(lazy_format::lazy_format!(
43                    "duplicate key: {key}"
44                )),
45            });
46        }
47    }
48    Ok(Value::Object(v1))
49}
50
51pub fn invalid_params(e: imbl_value::Error) -> RpcError {
52    RpcError {
53        data: Some(e.to_string().into()),
54        ..yajrc::INVALID_PARAMS_ERROR
55    }
56}
57
58pub fn invalid_request(e: imbl_value::Error) -> RpcError {
59    RpcError {
60        data: Some(e.to_string().into()),
61        ..yajrc::INVALID_REQUEST_ERROR
62    }
63}
64
65pub fn parse_error(e: impl Display) -> RpcError {
66    RpcError {
67        data: Some(e.to_string().into()),
68        ..yajrc::PARSE_ERROR
69    }
70}
71
72pub fn internal_error(e: impl Display) -> RpcError {
73    RpcError {
74        data: Some(e.to_string().into()),
75        ..yajrc::INTERNAL_ERROR
76    }
77}
78
79pub struct Flat<A, B>(pub A, pub B);
80impl<'de, A, B> Deserialize<'de> for Flat<A, B>
81where
82    A: DeserializeOwned,
83    B: DeserializeOwned,
84{
85    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
86    where
87        D: serde::Deserializer<'de>,
88    {
89        let v = Value::deserialize(deserializer)?;
90        let a = imbl_value::from_value(v.clone()).map_err(serde::de::Error::custom)?;
91        let b = imbl_value::from_value(v).map_err(serde::de::Error::custom)?;
92        Ok(Flat(a, b))
93    }
94}
95impl<A, B> Serialize for Flat<A, B>
96where
97    A: Serialize,
98    B: Serialize,
99{
100    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101    where
102        S: serde::Serializer,
103    {
104        #[derive(serde::Serialize)]
105        struct FlatStruct<'a, A, B> {
106            #[serde(flatten)]
107            a: &'a A,
108            #[serde(flatten)]
109            b: &'a B,
110        }
111        FlatStruct {
112            a: &self.0,
113            b: &self.1,
114        }
115        .serialize(serializer)
116    }
117}
118impl<A, B> clap::CommandFactory for Flat<A, B>
119where
120    A: clap::CommandFactory,
121    B: clap::Args,
122{
123    fn command() -> clap::Command {
124        B::augment_args(A::command())
125    }
126    fn command_for_update() -> clap::Command {
127        B::augment_args_for_update(A::command_for_update())
128    }
129}
130impl<A, B> clap::FromArgMatches for Flat<A, B>
131where
132    A: clap::FromArgMatches,
133    B: clap::FromArgMatches,
134{
135    fn from_arg_matches(matches: &clap::ArgMatches) -> Result<Self, clap::Error> {
136        Ok(Self(
137            A::from_arg_matches(matches)?,
138            B::from_arg_matches(matches)?,
139        ))
140    }
141    fn from_arg_matches_mut(matches: &mut clap::ArgMatches) -> Result<Self, clap::Error> {
142        Ok(Self(
143            A::from_arg_matches_mut(matches)?,
144            B::from_arg_matches_mut(matches)?,
145        ))
146    }
147    fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> {
148        self.0.update_from_arg_matches(matches)?;
149        self.1.update_from_arg_matches(matches)?;
150        Ok(())
151    }
152    fn update_from_arg_matches_mut(
153        &mut self,
154        matches: &mut clap::ArgMatches,
155    ) -> Result<(), clap::Error> {
156        self.0.update_from_arg_matches_mut(matches)?;
157        self.1.update_from_arg_matches_mut(matches)?;
158        Ok(())
159    }
160}
161
162pub fn poll_select_all<'a, T>(
163    futs: &mut Vec<BoxFuture<'a, T>>,
164    cx: &mut std::task::Context<'_>,
165) -> std::task::Poll<T> {
166    let item = futs
167        .iter_mut()
168        .enumerate()
169        .find_map(|(i, f)| match f.poll_unpin(cx) {
170            std::task::Poll::Pending => None,
171            std::task::Poll::Ready(e) => Some((i, e)),
172        });
173    match item {
174        Some((idx, res)) => {
175            drop(futs.swap_remove(idx));
176            std::task::Poll::Ready(res)
177        }
178        None => std::task::Poll::Pending,
179    }
180}
181
182pub struct JobRunner<'a, T> {
183    wakers: Vec<Waker>,
184    closed: bool,
185    running: Vec<BoxFuture<'a, T>>,
186}
187impl<'a, T> JobRunner<'a, T> {
188    pub fn new() -> Self {
189        JobRunner {
190            wakers: Vec::new(),
191            closed: false,
192            running: Vec::new(),
193        }
194    }
195    pub async fn next_result<
196        Src: Stream<Item = Fut> + Unpin,
197        Fut: Future<Output = T> + Send + 'a,
198    >(
199        &mut self,
200        job_source: &mut Src,
201    ) -> Option<T> {
202        let mut job_source = Some(job_source);
203        loop {
204            let next_job_fut = async {
205                if let Some(job_source) = &mut job_source {
206                    job_source.next().await
207                } else {
208                    futures::future::pending().await
209                }
210            };
211            tokio::select! {
212                job = next_job_fut => {
213                    if let Some(job) = job {
214                        self.running.push(job.boxed());
215                        while let Some(waker) = self.wakers.pop() {
216                            waker.wake();
217                        }
218                    } else {
219                        job_source.take();
220                        self.closed = true;
221                        if self.running.is_empty() {
222                            return None;
223                        }
224                    }
225                }
226                res = self.next() => {
227                    return res;
228                }
229            }
230        }
231    }
232}
233impl<'a, T> Stream for JobRunner<'a, T> {
234    type Item = T;
235    fn poll_next(
236        mut self: std::pin::Pin<&mut Self>,
237        cx: &mut std::task::Context<'_>,
238    ) -> std::task::Poll<Option<Self::Item>> {
239        if self.running.is_empty() {
240            self.wakers.push(cx.waker().clone());
241            return std::task::Poll::Pending;
242        }
243        match poll_select_all(&mut self.running, cx) {
244            std::task::Poll::Pending if self.closed && self.running.is_empty() => {
245                std::task::Poll::Ready(None)
246            }
247            a => a.map(Some),
248        }
249    }
250}
251
252#[pin_project::pin_project]
253pub struct StreamUntil<S, F> {
254    #[pin]
255    stream: S,
256    #[pin]
257    until: F,
258    done: bool,
259}
260impl<S, F> StreamUntil<S, F> {
261    pub fn new(stream: S, until: F) -> Self {
262        Self {
263            stream,
264            until,
265            done: false,
266        }
267    }
268}
269impl<S, F> Stream for StreamUntil<S, F>
270where
271    S: Stream,
272    F: Future,
273{
274    type Item = S::Item;
275    fn poll_next(
276        self: std::pin::Pin<&mut Self>,
277        cx: &mut std::task::Context<'_>,
278    ) -> std::task::Poll<Option<Self::Item>> {
279        let this = self.project();
280        *this.done = *this.done || this.until.poll(cx).is_ready();
281        if *this.done {
282            std::task::Poll::Ready(None)
283        } else {
284            this.stream.poll_next(cx)
285        }
286    }
287}
288impl<S, F> FusedStream for StreamUntil<S, F>
289where
290    S: FusedStream,
291    F: FusedFuture,
292{
293    fn is_terminated(&self) -> bool {
294        self.done || self.stream.is_terminated() || self.until.is_terminated()
295    }
296}
297
/// A zero-sized type marker that, unlike [`std::marker::PhantomData`], is `Send`, `Sync`,
/// `Clone` and `Default` for every `T`, whatever traits `T` itself implements.
pub struct PhantomData<T>(std::marker::PhantomData<T>);
impl<T> PhantomData<T> {
    /// Creates the marker; usable in `const` contexts.
    pub const fn new() -> Self {
        PhantomData(std::marker::PhantomData)
    }
}
impl<T> Default for PhantomData<T> {
    /// Same as [`PhantomData::new`]; no `T: Default` bound is required.
    fn default() -> Self {
        Self::new()
    }
}
impl<T> Clone for PhantomData<T> {
    /// Hand-written rather than derived so that no `T: Clone` bound is required.
    fn clone(&self) -> Self {
        Self::new()
    }
}
impl<T> Debug for PhantomData<T> {
    /// Formats exactly like the wrapped `std::marker::PhantomData<T>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}
// SAFETY: the wrapper's only field is a zero-sized `std::marker::PhantomData<T>`; it never
// stores, produces or grants access to a value of `T`, so moving it to another thread cannot
// move a `T`.
unsafe impl<T> Send for PhantomData<T> {}
// SAFETY: same reasoning as `Send` above. A shared `&PhantomData<T>` exposes no `T`, so
// sharing it between threads cannot share a `T`.
unsafe impl<T> Sync for PhantomData<T> {}