clap_maybe_deser/
lib.rs

1use core::fmt;
2use std::{marker::PhantomData, str::FromStr};
3
4use clap::{Args, FromArgMatches};
5use serde::de::DeserializeOwned;
6
7#[cfg(feature = "serde_json")]
8mod serde_json;
9#[cfg(feature = "stdin")]
10use clap_stdin::MaybeStdin;
11#[cfg(feature = "serde_json")]
12pub use serde_json::JsonDeserializer;
13
14pub trait CustomDeserializer {
15    const NAME: &'static str;
16    type Error: fmt::Display;
17
18    fn from_str<Data: DeserializeOwned>(s: &str) -> Result<Data, Self::Error>;
19}
20
/// A `Data` value parsed from one string argument with the `Deserializer` format.
///
/// `Deserializer` is only a type-level marker: no value of it is stored, just a `PhantomData`.
pub struct Deser<Data, Deserializer> {
    /// The deserialized value.
    pub data:      Data,
    _deserializer: PhantomData<Deserializer>,
}

// `Clone` and `Debug` are written by hand instead of derived: `#[derive]` would require
// `Deserializer: Clone` / `Deserializer: Debug` even though only a `PhantomData` of it is stored,
// which needlessly restricts which marker types can be used.
impl<Data: Clone, Deserializer> Clone for Deser<Data, Deserializer> {
    fn clone(&self) -> Self {
        Deser {
            data:          self.data.clone(),
            _deserializer: PhantomData,
        }
    }
}

impl<Data: fmt::Debug, Deserializer> fmt::Debug for Deser<Data, Deserializer> {
    /// Produces the same output a derived `Debug` impl would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Deser")
            .field("data", &self.data)
            .field("_deserializer", &self._deserializer)
            .finish()
    }
}

impl<Data, Deserializer> From<Data> for Deser<Data, Deserializer> {
    /// Wraps an already-built `Data` value.
    fn from(data: Data) -> Self {
        Deser {
            data,
            _deserializer: PhantomData,
        }
    }
}

impl<Data, Deserializer> fmt::Display for Deser<Data, Deserializer>
where
    Data: fmt::Display,
{
    /// Displays the inner value, forwarding the caller's formatter flags (width, fill,
    /// precision, ...) instead of discarding them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.data, f)
    }
}
44
45impl<Data, Deserializer> FromStr for Deser<Data, Deserializer>
46where
47    Data: DeserializeOwned,
48    Deserializer: CustomDeserializer,
49{
50    type Err = Deserializer::Error;
51
52    fn from_str(s: &str) -> Result<Self, Self::Err> {
53        let data = Deserializer::from_str(s)?;
54        Ok(Deser {
55            data,
56            _deserializer: PhantomData,
57        })
58    }
59}
60
/// A `Data` value taken either from one serialized argument (`--<NAME>`, parsed with the
/// `Deserializer` format) or from `Data`'s own individual flags.
///
/// `Deserializer` is only a type-level marker: no value of it is stored, just a `PhantomData`.
pub struct MaybeDeser<Data, Deserializer> {
    /// The resulting value, however it was obtained.
    pub data:      Data,
    _deserializer: PhantomData<Deserializer>,
}

// `Debug` (and `Clone`, provided for consistency with `Deser`) are written by hand: `#[derive]`
// would require the marker `Deserializer` to implement them too, although only a `PhantomData`
// of it is stored.
impl<Data: Clone, Deserializer> Clone for MaybeDeser<Data, Deserializer> {
    fn clone(&self) -> Self {
        MaybeDeser {
            data:          self.data.clone(),
            _deserializer: PhantomData,
        }
    }
}

impl<Data: fmt::Debug, Deserializer> fmt::Debug for MaybeDeser<Data, Deserializer> {
    /// Produces the same output a derived `Debug` impl would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MaybeDeser")
            .field("data", &self.data)
            .field("_deserializer", &self._deserializer)
            .finish()
    }
}

impl<Data, Deserializer> From<Data> for MaybeDeser<Data, Deserializer> {
    /// Wraps an already-built `Data` value.
    fn from(data: Data) -> Self {
        MaybeDeser {
            data,
            _deserializer: PhantomData,
        }
    }
}

impl<Data, Deserializer> fmt::Display for MaybeDeser<Data, Deserializer>
where
    Data: fmt::Display,
{
    /// Displays the inner value, forwarding the caller's formatter flags (width, fill,
    /// precision, ...) instead of discarding them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.data, f)
    }
}
84
85impl<Data, Deserializer> FromArgMatches for MaybeDeser<Data, Deserializer>
86where
87    Data: DeserializeOwned + Args + Clone + Send + Sync + 'static,
88    Deserializer: CustomDeserializer,
89{
90    fn from_arg_matches(matches: &clap::ArgMatches) -> std::result::Result<Self, clap::Error> {
91        if let Some(data_str) = matches.get_one::<String>(Deserializer::NAME) {
92            let data: Data = Deserializer::from_str(data_str)
93                .map_err(|e: Deserializer::Error| clap::Error::raw(clap::error::ErrorKind::InvalidValue, e))?;
94            Ok(Self::from(data))
95        } else {
96            let fields = Data::from_arg_matches(matches)?;
97            Ok(Self::from(fields))
98        }
99    }
100
101    fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> std::result::Result<(), clap::Error> {
102        if let Some(data_str) = matches.get_one::<String>(Deserializer::NAME) {
103            let data: Data = Deserializer::from_str(data_str).map_err(|e: Deserializer::Error| {
104                clap::Error::raw(clap::error::ErrorKind::InvalidValue, e.to_string())
105            })?;
106            *self = Self::from(data);
107        } else {
108            *self = Self::from(Data::from_arg_matches(matches)?);
109        }
110        Ok(())
111    }
112}
113
114impl<Data, Deserializer> Args for MaybeDeser<Data, Deserializer>
115where
116    Data: DeserializeOwned + Args + Clone + Send + Sync + 'static,
117    Deserializer: CustomDeserializer,
118{
119    fn augment_args(cmd: clap::Command) -> clap::Command {
120        // Create a list of field names dynamically from T
121        let field_names = Data::augment_args(clap::Command::new(""))
122            .get_arguments()
123            .map(|arg| arg.get_id().clone())
124            .collect::<Vec<_>>();
125
126        let cmd = cmd.arg(
127            clap::Arg::new(Deserializer::NAME)
128                .long(Deserializer::NAME)
129                .num_args(1)
130                .help(format!(
131                    "{} input for the request. If this is provided, all other flags will be ignored.",
132                    Deserializer::NAME
133                ))
134                .conflicts_with_all(field_names),
135        );
136        Data::augment_args(cmd)
137    }
138
139    fn augment_args_for_update(cmd: clap::Command) -> clap::Command {
140        // Create a list of field names dynamically from T
141        let field_names = Data::augment_args_for_update(clap::Command::new(""))
142            .get_arguments()
143            .map(|arg| arg.get_id().clone())
144            .collect::<Vec<_>>();
145
146        let cmd = cmd.arg(
147            clap::Arg::new(Deserializer::NAME)
148                .long(Deserializer::NAME)
149                .num_args(1)
150                .help(format!(
151                    "{} input for the request. If this is provided, all other flags will be ignored.",
152                    Deserializer::NAME
153                ))
154                .conflicts_with_all(field_names),
155        );
156        Data::augment_args_for_update(cmd)
157    }
158}
159
/// Like `MaybeDeser`, but the serialized `--<NAME>` argument may also be read from stdin
/// (via `clap_stdin::MaybeStdin`).
///
/// `Deserializer` is only a type-level marker: no value of it is stored, just a `PhantomData`.
#[cfg(feature = "stdin")]
pub struct MaybeStdinDeser<Data, Deserializer> {
    /// The resulting value, however it was obtained.
    pub data:      Data,
    _deserializer: PhantomData<Deserializer>,
}

// `Debug` (and `Clone`, for consistency with `Deser`) are written by hand: `#[derive]` would
// require the marker `Deserializer` to implement them too, although only a `PhantomData` of it
// is stored.
#[cfg(feature = "stdin")]
impl<Data: Clone, Deserializer> Clone for MaybeStdinDeser<Data, Deserializer> {
    fn clone(&self) -> Self {
        MaybeStdinDeser {
            data:          self.data.clone(),
            _deserializer: PhantomData,
        }
    }
}

#[cfg(feature = "stdin")]
impl<Data: fmt::Debug, Deserializer> fmt::Debug for MaybeStdinDeser<Data, Deserializer> {
    /// Produces the same output a derived `Debug` impl would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MaybeStdinDeser")
            .field("data", &self.data)
            .field("_deserializer", &self._deserializer)
            .finish()
    }
}

#[cfg(feature = "stdin")]
impl<Data, Deserializer> From<Data> for MaybeStdinDeser<Data, Deserializer> {
    /// Wraps an already-built `Data` value.
    fn from(data: Data) -> Self {
        MaybeStdinDeser {
            data,
            _deserializer: PhantomData,
        }
    }
}

#[cfg(feature = "stdin")]
impl<Data, Deserializer> fmt::Display for MaybeStdinDeser<Data, Deserializer>
where
    Data: fmt::Display,
{
    /// Displays the inner value, forwarding the caller's formatter flags (width, fill,
    /// precision, ...) instead of discarding them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.data, f)
    }
}
186
#[cfg(feature = "stdin")]
impl<Data, Deserializer> FromArgMatches for MaybeStdinDeser<Data, Deserializer>
where
    Data: DeserializeOwned + Args + Clone + Send + Sync + 'static,
    Deserializer: CustomDeserializer,
{
    /// Builds the value from the serialized `--<NAME>` argument (literal or stdin) when it was
    /// given, otherwise from `Data`'s own flags.
    ///
    /// # Errors
    ///
    /// Returns an [`clap::error::ErrorKind::InvalidValue`] error carrying the deserializer's
    /// message when the serialized input is rejected; otherwise propagates the error from
    /// `Data::from_arg_matches`.
    fn from_arg_matches(matches: &clap::ArgMatches) -> std::result::Result<Self, clap::Error> {
        match matches.get_one::<MaybeStdin<String>>(Deserializer::NAME) {
            Some(maybe_stdin) => {
                let data_str = maybe_stdin.as_ref();
                Deserializer::from_str::<Data>(data_str)
                    .map(Self::from)
                    .map_err(|e| clap::Error::raw(clap::error::ErrorKind::InvalidValue, e))
            }
            None => Data::from_arg_matches(matches).map(Self::from),
        }
    }

    /// Applies `matches` on top of the current value.
    ///
    /// A serialized `--<NAME>` payload replaces `data` wholesale (it conflicts with every field
    /// flag). Otherwise only the flags actually present are applied, by delegating to
    /// `Data::update_from_arg_matches`. Rebuilding with `Data::from_arg_matches` would reset
    /// unspecified fields and fail on fields that are required when building from scratch.
    ///
    /// # Errors
    ///
    /// Same as [`FromArgMatches::from_arg_matches`] above; on error `self` is left unchanged by
    /// the serialized branch.
    fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> std::result::Result<(), clap::Error> {
        if let Some(maybe_stdin) = matches.get_one::<MaybeStdin<String>>(Deserializer::NAME) {
            let data_str = maybe_stdin.as_ref();
            self.data = Deserializer::from_str::<Data>(data_str)
                .map_err(|e| clap::Error::raw(clap::error::ErrorKind::InvalidValue, e))?;
            Ok(())
        } else {
            self.data.update_from_arg_matches(matches)
        }
    }
}
218
#[cfg(feature = "stdin")]
impl<Data, Deserializer> Args for MaybeStdinDeser<Data, Deserializer>
where
    Data: DeserializeOwned + Args + Clone + Send + Sync + 'static,
    Deserializer: CustomDeserializer,
{
    /// Adds the stdin-capable serialized `--<NAME>` flag, then `Data`'s own flags.
    fn augment_args(cmd: clap::Command) -> clap::Command {
        // Throwaway command used only to discover the ids of Data's arguments.
        let probe = Data::augment_args(clap::Command::new(""));
        let cmd = cmd.arg(stdin_deserializer_arg::<Deserializer>(&probe));
        Data::augment_args(cmd)
    }

    /// Same as `augment_args`, but uses `Data`'s update-mode arguments.
    fn augment_args_for_update(cmd: clap::Command) -> clap::Command {
        let probe = Data::augment_args_for_update(clap::Command::new(""));
        let cmd = cmd.arg(stdin_deserializer_arg::<Deserializer>(&probe));
        Data::augment_args_for_update(cmd)
    }
}

/// Builds the single-value `--<NAME>` argument, parsed through `MaybeStdin<String>` and declared
/// as conflicting with every argument registered on `probe`.
#[cfg(feature = "stdin")]
fn stdin_deserializer_arg<Deserializer: CustomDeserializer>(probe: &clap::Command) -> clap::Arg {
    let field_ids: Vec<_> = probe.get_arguments().map(|arg| arg.get_id().clone()).collect();
    clap::Arg::new(Deserializer::NAME)
        .long(Deserializer::NAME)
        .num_args(1)
        .help(format!(
            "{} input for the request. If this is provided, all other flags will be ignored.",
            Deserializer::NAME
        ))
        .value_parser(MaybeStdin::<String>::from_str)
        .conflicts_with_all(field_ids)
}