rpc_toolkit/
cli.rs

1use std::collections::VecDeque;
2use std::ffi::OsString;
3
4use clap::{CommandFactory, FromArgMatches};
5use futures::Future;
6use imbl_value::imbl::OrdMap;
7use imbl_value::Value;
8use reqwest::header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE};
9use reqwest::{Client, Method};
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
13use url::Url;
14use yajrc::{Id, RpcError};
15
16use crate::util::{internal_error, invalid_params, parse_error, without, Flat, PhantomData};
17use crate::{
18    AnyHandler, CliBindings, CliBindingsAny, Empty, HandleAny, HandleAnyArgs, HandlerArgs,
19    HandlerArgsFor, HandlerFor, HandlerTypes, Name, ParentHandler, PrintCliResult,
20};
21
22type GenericRpcMethod<'a> = yajrc::GenericRpcMethod<&'a str, Value, Value>;
23type RpcRequest<'a> = yajrc::RpcRequest<GenericRpcMethod<'a>>;
24type RpcResponse<'a> = yajrc::RpcResponse<GenericRpcMethod<'static>>;
25
26pub struct CliApp<Context: crate::Context + Clone, Config: CommandFactory + FromArgMatches> {
27    _phantom: PhantomData<(Context, Config)>,
28    make_ctx: Box<dyn FnOnce(Config) -> Result<Context, RpcError> + Send + Sync>,
29    root_handler: ParentHandler<Context>,
30}
31impl<Context: crate::Context + Clone, Config: CommandFactory + FromArgMatches>
32    CliApp<Context, Config>
33{
34    pub fn new<MakeCtx: FnOnce(Config) -> Result<Context, RpcError> + Send + Sync + 'static>(
35        make_ctx: MakeCtx,
36        root_handler: ParentHandler<Context>,
37    ) -> Self {
38        Self {
39            _phantom: PhantomData::new(),
40            make_ctx: Box::new(make_ctx),
41            root_handler,
42        }
43    }
44    pub fn run(self, args: impl IntoIterator<Item = OsString>) -> Result<(), RpcError> {
45        let mut cmd = Config::command();
46        for (name, handler) in &self.root_handler.subcommands.1 {
47            if let (Name(name), Some(cli)) = (name, handler.cli()) {
48                cmd = cmd.subcommand(cli.cli_command().name(name));
49            }
50        }
51        let matches = cmd.get_matches_from(args);
52        let config = Config::from_arg_matches(&matches)?;
53        let ctx = (self.make_ctx)(config)?;
54        let root_handler = AnyHandler::new(self.root_handler);
55        let (method, params) = root_handler.cli_parse(&matches)?;
56        let res = root_handler.handle_sync(HandleAnyArgs {
57            context: ctx.clone(),
58            parent_method: VecDeque::new(),
59            method: method.clone(),
60            params: params.clone(),
61            inherited: crate::Empty {},
62        })?;
63        root_handler.cli_display(
64            HandleAnyArgs {
65                context: ctx,
66                parent_method: VecDeque::new(),
67                method,
68                params,
69                inherited: crate::Empty {},
70            },
71            res,
72        )?;
73        Ok(())
74    }
75}
76
77pub trait CallRemote<RemoteContext, Extra = Empty>: crate::Context {
78    fn call_remote(
79        &self,
80        method: &str,
81        params: Value,
82        extra: Extra,
83    ) -> impl Future<Output = Result<Value, RpcError>> + Send;
84}
85
86pub async fn call_remote_http(
87    client: &Client,
88    url: Url,
89    method: &str,
90    params: Value,
91) -> Result<Value, RpcError> {
92    let rpc_req = RpcRequest {
93        id: Some(Id::Number(0.into())),
94        method: GenericRpcMethod::new(method),
95        params,
96    };
97    let mut req = client.request(Method::POST, url);
98    let body;
99    #[cfg(feature = "cbor")]
100    {
101        req = req.header(CONTENT_TYPE, "application/cbor");
102        req = req.header(ACCEPT, "application/cbor, application/json");
103        body = serde_cbor::to_vec(&rpc_req)?;
104    }
105    #[cfg(not(feature = "cbor"))]
106    {
107        req = req.header(CONTENT_TYPE, "application/json");
108        req = req.header(ACCEPT, "application/json");
109        body = serde_json::to_vec(&rpc_req)?;
110    }
111    let res = req
112        .header(CONTENT_LENGTH, body.len())
113        .body(body)
114        .send()
115        .await?;
116
117    match res
118        .headers()
119        .get(CONTENT_TYPE)
120        .and_then(|v| v.to_str().ok())
121    {
122        Some("application/json") => {
123            serde_json::from_slice::<RpcResponse>(&*res.bytes().await.map_err(internal_error)?)
124                .map_err(parse_error)?
125                .result
126        }
127        #[cfg(feature = "cbor")]
128        Some("application/cbor") => {
129            serde_cbor::from_slice::<RpcResponse>(&*res.bytes().await.map_err(internal_error)?)
130                .map_err(parse_error)?
131                .result
132        }
133        _ => Err(internal_error("missing content type")),
134    }
135}
136
137pub async fn call_remote_socket(
138    connection: impl AsyncRead + AsyncWrite,
139    method: &str,
140    params: Value,
141) -> Result<Value, RpcError> {
142    let rpc_req = RpcRequest {
143        id: Some(Id::Number(0.into())),
144        method: GenericRpcMethod::new(method),
145        params,
146    };
147    let conn = connection;
148    tokio::pin!(conn);
149    let mut buf = serde_json::to_vec(&rpc_req).map_err(|e| RpcError {
150        data: Some(e.to_string().into()),
151        ..yajrc::INTERNAL_ERROR
152    })?;
153    buf.push(b'\n');
154    conn.write_all(&buf).await.map_err(|e| RpcError {
155        data: Some(e.to_string().into()),
156        ..yajrc::INTERNAL_ERROR
157    })?;
158    let mut line = String::new();
159    BufReader::new(conn).read_line(&mut line).await?;
160    serde_json::from_str::<RpcResponse>(&line)
161        .map_err(parse_error)?
162        .result
163}
164
165pub struct CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra = Empty> {
166    _phantom: PhantomData<(Context, RemoteContext, Extra)>,
167    handler: RemoteHandler,
168}
169impl<Context, RemoteContext, RemoteHandler, Extra>
170    CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra>
171{
172    pub fn new(handler: RemoteHandler) -> Self {
173        Self {
174            _phantom: PhantomData::new(),
175            handler: handler,
176        }
177    }
178}
179impl<Context, RemoteContext, RemoteHandler: Clone, Extra> Clone
180    for CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra>
181{
182    fn clone(&self) -> Self {
183        Self {
184            _phantom: PhantomData::new(),
185            handler: self.handler.clone(),
186        }
187    }
188}
189impl<Context, RemoteHandler, Extra> std::fmt::Debug
190    for CallRemoteHandler<Context, RemoteHandler, Extra>
191{
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        f.debug_tuple("CallRemoteHandler").finish()
194    }
195}
196
197impl<Context, RemoteContext, RemoteHandler, Extra> HandlerTypes
198    for CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra>
199where
200    RemoteHandler: HandlerTypes,
201    RemoteHandler::Params: Serialize,
202    RemoteHandler::Ok: DeserializeOwned,
203    RemoteHandler::Err: From<RpcError>,
204    Extra: Send + Sync + 'static,
205{
206    type Params = Flat<RemoteHandler::Params, Extra>;
207    type InheritedParams = RemoteHandler::InheritedParams;
208    type Ok = RemoteHandler::Ok;
209    type Err = RemoteHandler::Err;
210}
211
212impl<Context, RemoteContext, RemoteHandler, Extra> HandlerFor<Context>
213    for CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra>
214where
215    Context: CallRemote<RemoteContext, Extra>,
216    RemoteContext: crate::Context,
217    RemoteHandler: HandlerFor<RemoteContext>,
218    RemoteHandler::Params: Serialize,
219    RemoteHandler::Ok: DeserializeOwned,
220    RemoteHandler::Err: From<RpcError>,
221    Extra: Serialize + Send + Sync + 'static,
222{
223    async fn handle_async(
224        &self,
225        handle_args: HandlerArgsFor<Context, Self>,
226    ) -> Result<Self::Ok, Self::Err> {
227        let full_method = handle_args
228            .parent_method
229            .into_iter()
230            .chain(handle_args.method)
231            .collect::<Vec<_>>();
232        match handle_args
233            .context
234            .call_remote(
235                &full_method.join("."),
236                without(handle_args.raw_params.clone(), &handle_args.params.1)
237                    .map_err(invalid_params)?,
238                handle_args.params.1,
239            )
240            .await
241        {
242            Ok(a) => imbl_value::from_value(a)
243                .map_err(internal_error)
244                .map_err(Self::Err::from),
245            Err(e) => Err(Self::Err::from(e)),
246        }
247    }
248    fn metadata(&self, method: VecDeque<&'static str>) -> OrdMap<&'static str, Value> {
249        self.handler.metadata(method)
250    }
251    fn method_from_dots(&self, method: &str) -> Option<VecDeque<&'static str>> {
252        self.handler.method_from_dots(method)
253    }
254}
255impl<Context, RemoteContext, RemoteHandler, Extra> PrintCliResult<Context>
256    for CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra>
257where
258    Context: CallRemote<RemoteContext>,
259    RemoteHandler: PrintCliResult<Context>,
260    RemoteHandler::Params: Serialize,
261    RemoteHandler::Ok: DeserializeOwned,
262    RemoteHandler::Err: From<RpcError>,
263    Extra: Send + Sync + 'static,
264{
265    fn print(
266        &self,
267        HandlerArgs {
268            context,
269            parent_method,
270            method,
271            params,
272            inherited_params,
273            raw_params,
274        }: HandlerArgsFor<Context, Self>,
275        result: Self::Ok,
276    ) -> Result<(), Self::Err> {
277        self.handler.print(
278            HandlerArgs {
279                context,
280                parent_method,
281                method,
282                params: params.0,
283                inherited_params,
284                raw_params,
285            },
286            result,
287        )
288    }
289}
290impl<Context, RemoteContext, RemoteHandler, Extra> CliBindings<Context>
291    for CallRemoteHandler<Context, RemoteContext, RemoteHandler, Extra>
292where
293    Context: crate::Context,
294    RemoteHandler: CliBindings<Context>,
295    RemoteHandler::Params: Serialize,
296    RemoteHandler::Ok: DeserializeOwned,
297    RemoteHandler::Err: From<RpcError>,
298    Extra: Send + Sync + 'static,
299{
300    fn cli_command(&self) -> clap::Command {
301        self.handler.cli_command()
302    }
303    fn cli_parse(
304        &self,
305        matches: &clap::ArgMatches,
306    ) -> Result<(VecDeque<&'static str>, Value), clap::Error> {
307        self.handler.cli_parse(matches)
308    }
309    fn cli_display(
310        &self,
311        HandlerArgs {
312            context,
313            parent_method,
314            method,
315            params,
316            inherited_params,
317            raw_params,
318        }: HandlerArgsFor<Context, Self>,
319        result: Self::Ok,
320    ) -> Result<(), Self::Err> {
321        self.handler.cli_display(
322            HandlerArgs {
323                context,
324                parent_method,
325                method,
326                params: params.0,
327                inherited_params,
328                raw_params,
329            },
330            result,
331        )
332    }
333}