flow_value/
with.rs

1//! [serde_with](https://docs.rs/serde_with/latest/serde_with/) helpers.
2
3use serde::{
4    Deserialize, Serialize,
5    de::{self, MapAccess},
6};
7use serde_with::serde_conv;
8use std::{borrow::Cow, convert::Infallible};
9use std::{mem::MaybeUninit, ops::ControlFlow};
10
11pub use decimal::AsDecimal;
12#[cfg(feature = "solana-keypair")]
13pub use keypair::AsKeypair;
14#[cfg(feature = "solana-pubkey")]
15pub use pubkey::AsPubkey;
16#[cfg(feature = "solana-signature")]
17pub use signature::AsSignature;
18
/// Writes `generator(i)` into `buffer[i]` for each slot, in index order.
///
/// Stops at the first `Err` and returns it as `ControlFlow::Break`; the slot
/// at the failing index and every slot after it are left untouched. When this
/// returns `ControlFlow::Continue(())`, every slot of `buffer` has been
/// written. Taking a slice keeps this helper independent of any array length.
///
/// `T: Copy` means a partially written buffer needs no drop glue, so
/// abandoning it on error leaks nothing.
fn try_from_fn_erased<T: Copy, E>(
    buffer: &mut [MaybeUninit<T>],
    mut generator: impl FnMut(usize) -> Result<T, E>,
) -> ControlFlow<E> {
    buffer
        .iter_mut()
        .enumerate()
        .try_for_each(|(index, slot)| match generator(index) {
            Ok(value) => {
                slot.write(value);
                ControlFlow::Continue(())
            }
            Err(error) => ControlFlow::Break(error),
        })
}

/// Builds a `[T; N]` by calling `cb(i)` for every `i` in `0..N`, in order.
///
/// Returns the first error produced by `cb`; indices after it are never
/// visited. Equivalent in spirit to the unstable `core::array::try_from_fn`.
fn try_from_fn<const N: usize, T: Copy, E, F>(cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    let mut slots: [MaybeUninit<T>; N] = [const { MaybeUninit::uninit() }; N];
    if let ControlFlow::Break(error) = try_from_fn_erased(&mut slots, cb) {
        return Err(error);
    }
    // SAFETY: `try_from_fn_erased` only returns `Continue` after writing every
    // slot of the slice it was handed, and that slice is the whole of `slots`.
    Ok(slots.map(|slot| unsafe { slot.assume_init() }))
}
44
45#[cfg(feature = "solana-pubkey")]
46pub(crate) mod pubkey {
47    use std::marker::PhantomData;
48
49    use super::*;
50    use five8::BASE58_ENCODED_32_MAX_LEN;
51    use solana_pubkey::Pubkey;
52
53    struct CustomPubkey<'a>(Cow<'a, Pubkey>);
54
55    pub(crate) const TOKEN: &str = "$$p";
56
57    impl Serialize for CustomPubkey<'_> {
58        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
59        where
60            S: serde::Serializer,
61        {
62            s.serialize_newtype_struct(TOKEN, &crate::Bytes((*self.0).as_ref()))
63        }
64    }
65
66    impl<'de> Deserialize<'de> for CustomPubkey<'_> {
67        fn deserialize<D>(d: D) -> Result<Self, D::Error>
68        where
69            D: serde::Deserializer<'de>,
70        {
71            d.deserialize_newtype_struct(TOKEN, Visitor { map: true })
72                .map(|pk| CustomPubkey(Cow::Owned(pk)))
73        }
74    }
75
76    struct Visitor {
77        map: bool,
78    }
79
80    impl<'de> serde::de::Visitor<'de> for Visitor {
81        type Value = Pubkey;
82
83        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
84            if self.map {
85                formatter.write_str("pubkey, keypair, base58 string, or adapter wallet")
86            } else {
87                formatter.write_str("pubkey, keypair, or base58 string")
88            }
89        }
90
91        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
92        where
93            E: serde::de::Error,
94        {
95            match v.len() {
96                32 => Ok(Pubkey::new_from_array(v.try_into().unwrap())),
97                // see ed25519-dalek's Keypair
98                64 => Ok(Pubkey::new_from_array(v[32..].try_into().unwrap())),
99                l => Err(serde::de::Error::invalid_length(l, &"32 or 64")),
100            }
101        }
102
103        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
104        where
105            E: serde::de::Error,
106        {
107            if v.len() > BASE58_ENCODED_32_MAX_LEN {
108                let mut buf = [0u8; 64];
109                five8::decode_64(v, &mut buf).map_err(|_| {
110                    serde::de::Error::invalid_value(
111                        serde::de::Unexpected::Str(v),
112                        &"pubkey or keypair encoded in bs58",
113                    )
114                })?;
115                Ok(Pubkey::new_from_array(buf[32..].try_into().unwrap()))
116            } else {
117                let mut buf = [0u8; 32];
118                five8::decode_32(v, &mut buf).map_err(|_| {
119                    serde::de::Error::invalid_value(
120                        serde::de::Unexpected::Str(v),
121                        &"pubkey or keypair encoded in bs58",
122                    )
123                })?;
124                Ok(Pubkey::new_from_array(buf))
125            }
126        }
127
128        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
129        where
130            A: serde::de::SeqAccess<'de>,
131        {
132            let hint = seq.size_hint();
133            match hint {
134                Some(n) => {
135                    if n == 32 {
136                        let buffer: [u8; 32] = try_from_fn(|i| {
137                            seq.next_element()?
138                                .ok_or_else(|| de::Error::invalid_length(i, &"32"))
139                        })?;
140                        Ok(Pubkey::new_from_array(buffer))
141                    } else if n == 64 {
142                        for _ in 0..32 {
143                            seq.next_element::<u8>()?;
144                        }
145                        let buffer: [u8; 32] = try_from_fn(|i| {
146                            seq.next_element()?
147                                .ok_or_else(|| de::Error::invalid_length(i + 32, &"64"))
148                        })?;
149                        Ok(Pubkey::new_from_array(buffer))
150                    } else {
151                        Err(de::Error::invalid_length(n, &"32 or 64"))
152                    }
153                }
154                None => {
155                    let buffer: [u8; 32] = try_from_fn(|i| {
156                        seq.next_element()?
157                            .ok_or_else(|| de::Error::invalid_length(i, &"32"))
158                    })?;
159                    let next = seq.next_element::<u8>()?;
160                    if let Some(x) = next {
161                        let mut result = [0u8; 32];
162                        result[0] = x;
163                        let buffer: [u8; 31] = try_from_fn(|i| {
164                            seq.next_element()?
165                                .ok_or_else(|| de::Error::invalid_length(i, &"64"))
166                        })?;
167                        result[1..].copy_from_slice(&buffer);
168                        Ok(Pubkey::new_from_array(result))
169                    } else {
170                        Ok(Pubkey::new_from_array(buffer))
171                    }
172                }
173            }
174        }
175
176        fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
177        where
178            D: serde::Deserializer<'de>,
179        {
180            d.deserialize_any(self)
181        }
182
183        fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
184        where
185            A: MapAccess<'de>,
186        {
187            if self.map {
188                map.next_key::<Const<public_key>>()?;
189                let value = map.next_value::<CustomPubkeyNoMap>()?;
190                Ok(value.0)
191            } else {
192                Err(de::Error::invalid_type(de::Unexpected::Map, &self))
193            }
194        }
195    }
196
197    struct CustomPubkeyNoMap(Pubkey);
198
199    impl<'de> Deserialize<'de> for CustomPubkeyNoMap {
200        fn deserialize<D>(d: D) -> Result<Self, D::Error>
201        where
202            D: de::Deserializer<'de>,
203        {
204            d.deserialize_any(Visitor { map: false })
205                .map(CustomPubkeyNoMap)
206        }
207    }
208
209    #[allow(non_camel_case_types)]
210    struct public_key;
211
212    impl Key for public_key {
213        const KEY: &'static str = "public_key";
214        fn new() -> Self {
215            Self
216        }
217    }
218
219    trait Key {
220        const KEY: &'static str;
221        fn new() -> Self;
222    }
223
224    struct Const<K>(K);
225
226    impl<'de, K> Deserialize<'de> for Const<K>
227    where
228        K: Key,
229    {
230        fn deserialize<D>(d: D) -> Result<Self, D::Error>
231        where
232            D: de::Deserializer<'de>,
233        {
234            d.deserialize_str(StrVisitor::<K>(PhantomData))
235        }
236    }
237
238    struct StrVisitor<K: Key>(PhantomData<fn() -> K>);
239
240    impl<K: Key> de::Visitor<'_> for StrVisitor<K> {
241        type Value = Const<K>;
242
243        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
244            f.write_str(K::KEY)
245        }
246
247        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
248        where
249            E: de::Error,
250        {
251            if v == K::KEY {
252                Ok(Const(K::new()))
253            } else {
254                Err(de::Error::invalid_value(de::Unexpected::Str(v), &K::KEY))
255            }
256        }
257    }
258
259    fn to_custom_pubkey(pk: &Pubkey) -> CustomPubkey<'_> {
260        CustomPubkey(Cow::Borrowed(pk))
261    }
262    fn from_custom_pubkey(pk: CustomPubkey<'static>) -> Result<Pubkey, Infallible> {
263        Ok(pk.0.into_owned())
264    }
265    serde_conv!(pub AsPubkey, Pubkey, to_custom_pubkey, from_custom_pubkey);
266
267    #[cfg(test)]
268    mod tests {
269        use super::*;
270        use crate::Value;
271        use serde_with::{DeserializeAs, SerializeAs};
272        use solana_keypair::Keypair;
273        use solana_signer::Signer;
274
275        #[test]
276        fn test_pubkey() {
277            let key = Pubkey::new_unique();
278            let value = AsPubkey::serialize_as(&key, crate::ser::Serializer).unwrap();
279            assert!(matches!(value, Value::B32(_)));
280            let de_key = AsPubkey::deserialize_as(value).unwrap();
281            assert_eq!(key, de_key);
282
283            let value = Value::Map(crate::map! { "public_key" => key });
284            let de_key = AsPubkey::deserialize_as(value).unwrap();
285            assert_eq!(key, de_key);
286
287            let value = Value::String(key.to_string());
288            let de_key = AsPubkey::deserialize_as(value).unwrap();
289            assert_eq!(key, de_key);
290
291            let value = Value::Array(key.to_bytes().map(Value::from).to_vec());
292            let de_key = AsPubkey::deserialize_as(value).unwrap();
293            assert_eq!(key, de_key);
294
295            let keypair = Keypair::new();
296            let key = keypair.pubkey();
297            let value = Value::B64(keypair.to_bytes());
298            let de_key = AsPubkey::deserialize_as(value).unwrap();
299            assert_eq!(key, de_key);
300
301            let value = Value::String(keypair.to_base58_string());
302            let de_key = AsPubkey::deserialize_as(value).unwrap();
303            assert_eq!(key, de_key);
304
305            let value = Value::Array(keypair.to_bytes().map(Value::from).to_vec());
306            let de_key = AsPubkey::deserialize_as(value).unwrap();
307            assert_eq!(key, de_key);
308        }
309    }
310}
311
312#[cfg(feature = "solana-signature")]
313pub(crate) mod signature {
314    use super::*;
315    use solana_signature::Signature;
316
317    struct CustomSignature<'a>(Cow<'a, Signature>);
318
319    pub(crate) const TOKEN: &str = "$$s";
320
321    impl Serialize for CustomSignature<'_> {
322        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
323        where
324            S: serde::Serializer,
325        {
326            s.serialize_newtype_struct(TOKEN, &crate::Bytes((*self.0).as_ref()))
327        }
328    }
329
330    impl<'de> Deserialize<'de> for CustomSignature<'_> {
331        fn deserialize<D>(d: D) -> Result<Self, D::Error>
332        where
333            D: serde::Deserializer<'de>,
334        {
335            d.deserialize_newtype_struct(TOKEN, Visitor)
336                .map(|pk| CustomSignature(Cow::Owned(pk)))
337        }
338    }
339
340    struct Visitor;
341
342    impl<'de> serde::de::Visitor<'de> for Visitor {
343        type Value = Signature;
344
345        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
346            formatter.write_str("signature or bs58 string")
347        }
348
349        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
350        where
351            E: serde::de::Error,
352        {
353            let buffer: [u8; 64] = v
354                .try_into()
355                .map_err(|_| de::Error::invalid_length(v.len(), &"64"))?;
356            Ok(Signature::from(buffer))
357        }
358
359        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
360        where
361            E: serde::de::Error,
362        {
363            let mut buffer = [0u8; 64];
364            five8::decode_64(v, &mut buffer).map_err(|_| de::Error::custom("invalid base58"))?;
365            Ok(Signature::from(buffer))
366        }
367
368        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
369        where
370            A: serde::de::SeqAccess<'de>,
371        {
372            let buffer: [u8; 64] = try_from_fn(|i| {
373                seq.next_element()?
374                    .ok_or_else(|| de::Error::invalid_length(i, &"64"))
375            })?;
376
377            Ok(Signature::from(buffer))
378        }
379
380        fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
381        where
382            D: serde::Deserializer<'de>,
383        {
384            d.deserialize_any(self)
385        }
386    }
387
388    fn to_custom_signature(pk: &Signature) -> CustomSignature<'_> {
389        CustomSignature(Cow::Borrowed(pk))
390    }
391    fn from_custom_signature(pk: CustomSignature<'static>) -> Result<Signature, Infallible> {
392        Ok(pk.0.into_owned())
393    }
394    serde_conv!(pub AsSignature, Signature, to_custom_signature, from_custom_signature);
395
396    #[cfg(test)]
397    mod tests {
398        use super::*;
399        use crate::Value;
400        use serde_with::{DeserializeAs, SerializeAs};
401        use solana_signature::Signature;
402
403        #[test]
404        fn test_signature() {
405            let sig = Signature::default();
406            let value = AsSignature::serialize_as(&sig, crate::ser::Serializer).unwrap();
407            assert!(matches!(value, Value::B64(_)));
408            let de_sig = AsSignature::deserialize_as(value).unwrap();
409            assert_eq!(sig, de_sig);
410
411            let value = Value::String(sig.to_string());
412            let de_sig = AsSignature::deserialize_as(value).unwrap();
413            assert_eq!(sig, de_sig);
414
415            let value = Value::Array(
416                sig.as_ref()
417                    .iter()
418                    .map(|i| Value::from(*i))
419                    .collect::<Vec<_>>(),
420            );
421            let de_sig = AsSignature::deserialize_as(value).unwrap();
422            assert_eq!(sig, de_sig);
423        }
424    }
425}
426
427#[cfg(feature = "solana-keypair")]
428pub(crate) mod keypair {
429    use super::*;
430    use solana_keypair::Keypair;
431
432    struct CustomKeypair([u8; 64]);
433
434    pub(crate) const TOKEN: &str = "$$k";
435
436    impl Serialize for CustomKeypair {
437        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
438        where
439            S: serde::Serializer,
440        {
441            s.serialize_newtype_struct(TOKEN, &crate::Bytes(&self.0))
442        }
443    }
444
445    impl<'de> Deserialize<'de> for CustomKeypair {
446        fn deserialize<D>(d: D) -> Result<Self, D::Error>
447        where
448            D: serde::Deserializer<'de>,
449        {
450            d.deserialize_newtype_struct(TOKEN, Visitor)
451        }
452    }
453
454    struct Visitor;
455
456    impl<'de> serde::de::Visitor<'de> for Visitor {
457        type Value = CustomKeypair;
458
459        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
460            formatter.write_str("keypair or bs58 string")
461        }
462
463        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
464        where
465            E: serde::de::Error,
466        {
467            let buffer: [u8; 64] = v
468                .try_into()
469                .map_err(|_| de::Error::invalid_length(v.len(), &"64"))?;
470            Ok(CustomKeypair(buffer))
471        }
472
473        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
474        where
475            E: serde::de::Error,
476        {
477            let mut buffer = [0u8; 64];
478            five8::decode_64(v, &mut buffer).map_err(|_| de::Error::custom("invalid base58"))?;
479            Ok(CustomKeypair(buffer))
480        }
481
482        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
483        where
484            A: serde::de::SeqAccess<'de>,
485        {
486            let buffer: [u8; 64] = try_from_fn(|i| {
487                seq.next_element()?
488                    .ok_or_else(|| de::Error::invalid_length(i, &"64"))
489            })?;
490
491            Ok(CustomKeypair(buffer))
492        }
493
494        fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
495        where
496            D: serde::Deserializer<'de>,
497        {
498            d.deserialize_any(self)
499        }
500    }
501
502    fn to_custom_keypair(k: &'_ Keypair) -> CustomKeypair {
503        CustomKeypair(k.to_bytes())
504    }
505    fn from_custom_keypair(k: CustomKeypair) -> Result<Keypair, String> {
506        Keypair::try_from(&k.0[..]).map_err(|error| error.to_string())
507    }
508    serde_conv!(pub AsKeypair, Keypair, to_custom_keypair, from_custom_keypair);
509
510    #[cfg(test)]
511    mod tests {
512        use super::*;
513        use crate::Value;
514        use serde_with::{DeserializeAs, SerializeAs};
515
516        #[test]
517        fn test_keypair() {
518            let key = Keypair::new();
519            let value = AsKeypair::serialize_as(&key, crate::ser::Serializer).unwrap();
520            assert!(matches!(value, Value::B64(_)));
521            let de_key = AsKeypair::deserialize_as(value).unwrap();
522            assert_eq!(key, de_key);
523
524            let value = Value::String(key.to_base58_string());
525            let de_key = AsKeypair::deserialize_as(value).unwrap();
526            assert_eq!(key, de_key);
527
528            let value = Value::Array(key.to_bytes().map(Value::from).to_vec());
529            let de_key = AsKeypair::deserialize_as(value).unwrap();
530            assert_eq!(key, de_key);
531        }
532    }
533}
534
535pub(crate) mod decimal {
536    use super::*;
537    use rust_decimal::Decimal;
538
539    struct CustomDecimal<'a>(Cow<'a, Decimal>);
540
541    pub(crate) const TOKEN: &str = "$$d";
542
543    impl Serialize for CustomDecimal<'_> {
544        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
545        where
546            S: serde::Serializer,
547        {
548            s.serialize_newtype_struct(TOKEN, &crate::Bytes(&(*self.0).serialize()))
549        }
550    }
551
552    impl<'de> Deserialize<'de> for CustomDecimal<'_> {
553        fn deserialize<D>(d: D) -> Result<Self, D::Error>
554        where
555            D: de::Deserializer<'de>,
556        {
557            d.deserialize_newtype_struct(TOKEN, Visitor)
558                .map(|d| CustomDecimal(Cow::Owned(d)))
559        }
560    }
561
562    struct Visitor;
563
564    impl<'de> serde::de::Visitor<'de> for Visitor {
565        type Value = Decimal;
566
567        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
568            formatter.write_str("decimal")
569        }
570
571        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
572        where
573            E: serde::de::Error,
574        {
575            let buf: [u8; 16] = v
576                .try_into()
577                .map_err(|_| de::Error::invalid_length(v.len(), &"16"))?;
578            Ok(Decimal::deserialize(buf))
579        }
580
581        fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
582        where
583            E: serde::de::Error,
584        {
585            Ok(Decimal::from(v))
586        }
587
588        fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
589        where
590            E: serde::de::Error,
591        {
592            Ok(Decimal::from(v))
593        }
594
595        fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
596        where
597            E: serde::de::Error,
598        {
599            // TODO: this is lossy
600            Decimal::try_from(v).map_err(serde::de::Error::custom)
601        }
602
603        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
604        where
605            E: serde::de::Error,
606        {
607            let v = v.trim();
608            if v.bytes().any(|c| c == b'e' || c == b'E') {
609                Decimal::from_scientific(v).map_err(serde::de::Error::custom)
610            } else {
611                v.parse().map_err(serde::de::Error::custom)
612            }
613        }
614
615        fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
616        where
617            D: serde::Deserializer<'de>,
618        {
619            d.deserialize_any(self)
620        }
621    }
622
623    fn to_custom_decimal(d: &Decimal) -> CustomDecimal<'_> {
624        CustomDecimal(Cow::Borrowed(d))
625    }
626    fn from_custom_decimal(d: CustomDecimal<'static>) -> Result<Decimal, Infallible> {
627        Ok(d.0.into_owned())
628    }
629    serde_conv!(pub AsDecimal, Decimal, to_custom_decimal, from_custom_decimal);
630
631    #[cfg(test)]
632    mod tests {
633        use super::*;
634        use crate::Value;
635        use rust_decimal_macros::dec;
636        use serde_with::{DeserializeAs, SerializeAs};
637
638        fn de<'de, D: serde::Deserializer<'de>>(d: D) -> Decimal {
639            AsDecimal::deserialize_as(d).unwrap()
640        }
641
642        #[test]
643        fn test_decimal() {
644            assert_eq!(
645                AsDecimal::serialize_as(&Decimal::MAX, crate::ser::Serializer).unwrap(),
646                Value::Decimal(Decimal::MAX)
647            );
648            assert_eq!(de(Value::U64(100)), dec!(100));
649            assert_eq!(de(Value::I64(-1)), dec!(-1));
650            assert_eq!(de(Value::Decimal(Decimal::MAX)), Decimal::MAX);
651            assert_eq!(de(Value::F64(1231.2221)), dec!(1231.2221));
652            assert_eq!(de(Value::String("1234.0".to_owned())), dec!(1234));
653            assert_eq!(de(Value::String("  1234.0".to_owned())), dec!(1234));
654            assert_eq!(de(Value::String("1e5".to_owned())), dec!(100000));
655            assert_eq!(de(Value::String("  1e5".to_owned())), dec!(100000));
656        }
657    }
658}