1use serde::{
4 Deserialize, Serialize,
5 de::{self, MapAccess},
6};
7use serde_with::serde_conv;
8use std::{borrow::Cow, convert::Infallible};
9use std::{mem::MaybeUninit, ops::ControlFlow};
10
11pub use decimal::AsDecimal;
12#[cfg(feature = "solana-keypair")]
13pub use keypair::AsKeypair;
14#[cfg(feature = "solana-pubkey")]
15pub use pubkey::AsPubkey;
16#[cfg(feature = "solana-signature")]
17pub use signature::AsSignature;
18
/// Fills `buffer` front to back with values produced by `generator`.
///
/// `generator` is called with each slot index in ascending order. The first
/// `Err` stops the fill and is handed back as `ControlFlow::Break`; slots before
/// that index have been written, later ones are left untouched.
/// `ControlFlow::Continue(())` is returned only after every slot was written.
///
/// Taking a slice (rather than an array) keeps this helper independent of the
/// array length used by the caller.
fn try_from_fn_erased<T: Copy, E>(
    buffer: &mut [MaybeUninit<T>],
    mut generator: impl FnMut(usize) -> Result<T, E>,
) -> ControlFlow<E> {
    buffer
        .iter_mut()
        .enumerate()
        .try_for_each(|(index, slot)| match generator(index) {
            Ok(value) => {
                slot.write(value);
                ControlFlow::Continue(())
            }
            Err(error) => ControlFlow::Break(error),
        })
}
33
34fn try_from_fn<const N: usize, T: Copy, E, F>(cb: F) -> Result<[T; N], E>
35where
36 F: FnMut(usize) -> Result<T, E>,
37{
38 let mut array = [const { MaybeUninit::uninit() }; N];
39 match try_from_fn_erased(&mut array, cb) {
40 ControlFlow::Break(error) => Err(error),
41 ControlFlow::Continue(()) => Ok(array.map(|uninit| unsafe { uninit.assume_init() })),
42 }
43}
44
/// Serde adapter for Solana `Pubkey` values.
///
/// `AsPubkey` serializes a key as a `"$$p"` newtype struct wrapping its bytes.
/// Deserialization accepts:
///
/// * 32 raw bytes, or 64 raw bytes of which the last 32 are used (the
///   "keypair" form);
/// * a base58 string encoding either of the above;
/// * a sequence of 32 or 64 bytes;
/// * a map whose first key is `"public_key"` and whose value is any of the
///   non-map forms above.
#[cfg(feature = "solana-pubkey")]
pub(crate) mod pubkey {
    use std::marker::PhantomData;

    use super::*;
    use five8::BASE58_ENCODED_32_MAX_LEN;
    use solana_pubkey::Pubkey;

    /// Serde-facing wrapper around a `Pubkey`: borrowed while serializing,
    /// owned after deserializing.
    struct CustomPubkey<'a>(Cow<'a, Pubkey>);

    /// Newtype-struct name under which public keys are (de)serialized.
    pub(crate) const TOKEN: &str = "$$p";

    impl Serialize for CustomPubkey<'_> {
        /// Emits the key as a `TOKEN` newtype struct wrapping `crate::Bytes`.
        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            s.serialize_newtype_struct(TOKEN, &crate::Bytes((*self.0).as_ref()))
        }
    }

    impl<'de> Deserialize<'de> for CustomPubkey<'_> {
        /// Reads a key from any supported representation, including the map form.
        fn deserialize<D>(d: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            d.deserialize_newtype_struct(TOKEN, Visitor { map: true })
                .map(|pk| CustomPubkey(Cow::Owned(pk)))
        }
    }

    /// Visitor producing a `Pubkey`.
    struct Visitor {
        /// Whether the `{"public_key": ...}` map form is accepted. It is turned
        /// off while reading the value inside such a map, so maps cannot nest.
        map: bool,
    }

    impl<'de> serde::de::Visitor<'de> for Visitor {
        type Value = Pubkey;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            if self.map {
                formatter.write_str("pubkey, keypair, base58 string, or adapter wallet")
            } else {
                formatter.write_str("pubkey, keypair, or base58 string")
            }
        }

        /// Accepts 32 bytes (the key itself) or 64 bytes (key in the last 32).
        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            match v.len() {
                // The `unwrap`s cannot fail: the slice length was just matched.
                32 => Ok(Pubkey::new_from_array(v.try_into().unwrap())),
                64 => Ok(Pubkey::new_from_array(v[32..].try_into().unwrap())),
                l => Err(serde::de::Error::invalid_length(l, &"32 or 64")),
            }
        }

        /// Decodes a base58 string. Strings longer than the maximum encoding of
        /// 32 bytes are decoded as 64 bytes, of which the last 32 are the key.
        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            let invalid = || -> E {
                serde::de::Error::invalid_value(
                    serde::de::Unexpected::Str(v),
                    &"pubkey or keypair encoded in bs58",
                )
            };

            if v.len() > BASE58_ENCODED_32_MAX_LEN {
                let mut buf = [0u8; 64];
                five8::decode_64(v, &mut buf).map_err(|_| invalid())?;
                // 64 - 32 = 32 bytes remain, so the conversion cannot fail.
                Ok(Pubkey::new_from_array(buf[32..].try_into().unwrap()))
            } else {
                let mut buf = [0u8; 32];
                five8::decode_32(v, &mut buf).map_err(|_| invalid())?;
                Ok(Pubkey::new_from_array(buf))
            }
        }

        /// Reads a sequence of 32 or 64 `u8` elements.
        ///
        /// When the sequence reports its length, only 32 and 64 are accepted.
        /// Without a length hint, 32 elements are read and, if a 33rd exists,
        /// the input is treated as 64 elements whose second half is the key.
        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: serde::de::SeqAccess<'de>,
        {
            match seq.size_hint() {
                Some(32) => {
                    let buffer: [u8; 32] = try_from_fn(|i| {
                        seq.next_element()?
                            .ok_or_else(|| de::Error::invalid_length(i, &"32"))
                    })?;
                    Ok(Pubkey::new_from_array(buffer))
                }
                Some(64) => {
                    // Skip the first half; only the trailing 32 bytes are the key.
                    for _ in 0..32 {
                        seq.next_element::<u8>()?;
                    }
                    let buffer: [u8; 32] = try_from_fn(|i| {
                        seq.next_element()?
                            .ok_or_else(|| de::Error::invalid_length(i + 32, &"64"))
                    })?;
                    Ok(Pubkey::new_from_array(buffer))
                }
                Some(n) => Err(de::Error::invalid_length(n, &"32 or 64")),
                None => {
                    let first_half: [u8; 32] = try_from_fn(|i| {
                        seq.next_element()?
                            .ok_or_else(|| de::Error::invalid_length(i, &"32"))
                    })?;

                    // Exactly 32 elements: that is the key itself.
                    let Some(byte_32) = seq.next_element::<u8>()? else {
                        return Ok(Pubkey::new_from_array(first_half));
                    };

                    // A 33rd element exists, so this is the 64-element form and
                    // the key is elements 32..64.
                    let mut result = [0u8; 32];
                    result[0] = byte_32;
                    let rest: [u8; 31] = try_from_fn(|i| {
                        // 33 elements were already consumed before this closure runs,
                        // so the observed length on failure is `i + 33`.
                        seq.next_element()?
                            .ok_or_else(|| de::Error::invalid_length(i + 33, &"64"))
                    })?;
                    result[1..].copy_from_slice(&rest);
                    Ok(Pubkey::new_from_array(result))
                }
            }
        }

        /// Unwraps the newtype struct and dispatches on the inner value.
        fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            d.deserialize_any(self)
        }

        /// Reads the `{"public_key": <key>}` form when maps are allowed.
        ///
        /// The first key must be `"public_key"`; its value may be any non-map
        /// representation. Entries after the first are not read here.
        fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
        where
            A: MapAccess<'de>,
        {
            if !self.map {
                return Err(de::Error::invalid_type(de::Unexpected::Map, &self));
            }

            // `next_value` must only be called after `next_key` yielded a key;
            // an empty map is reported as a missing field instead.
            if map.next_key::<Const<public_key>>()?.is_none() {
                return Err(de::Error::missing_field(<public_key as Key>::KEY));
            }

            let CustomPubkeyNoMap(pubkey) = map.next_value::<CustomPubkeyNoMap>()?;
            Ok(pubkey)
        }
    }

    /// A `Pubkey` read from any representation except the map form.
    struct CustomPubkeyNoMap(Pubkey);

    impl<'de> Deserialize<'de> for CustomPubkeyNoMap {
        fn deserialize<D>(d: D) -> Result<Self, D::Error>
        where
            D: de::Deserializer<'de>,
        {
            d.deserialize_any(Visitor { map: false })
                .map(CustomPubkeyNoMap)
        }
    }

    // The type is named after the literal map key it stands for.
    #[allow(non_camel_case_types)]
    struct public_key;

    impl Key for public_key {
        const KEY: &'static str = "public_key";
        fn new() -> Self {
            Self
        }
    }

    /// A marker type standing for one fixed string key.
    trait Key {
        /// The exact string this marker represents.
        const KEY: &'static str;
        /// Creates the marker value.
        fn new() -> Self;
    }

    /// Deserializes successfully only from the string `K::KEY`.
    struct Const<K>(K);

    impl<'de, K> Deserialize<'de> for Const<K>
    where
        K: Key,
    {
        fn deserialize<D>(d: D) -> Result<Self, D::Error>
        where
            D: de::Deserializer<'de>,
        {
            d.deserialize_str(StrVisitor::<K>(PhantomData))
        }
    }

    /// Visitor behind [`Const`]; accepts only a string equal to `K::KEY`.
    struct StrVisitor<K: Key>(PhantomData<fn() -> K>);

    impl<K: Key> de::Visitor<'_> for StrVisitor<K> {
        type Value = Const<K>;

        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str(K::KEY)
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            if v == K::KEY {
                Ok(Const(K::new()))
            } else {
                Err(de::Error::invalid_value(de::Unexpected::Str(v), &K::KEY))
            }
        }
    }

    /// Borrows a `Pubkey` into its serde wrapper (serialization side of `AsPubkey`).
    fn to_custom_pubkey(pk: &Pubkey) -> CustomPubkey<'_> {
        CustomPubkey(Cow::Borrowed(pk))
    }

    /// Extracts the owned `Pubkey` from its serde wrapper; never fails.
    fn from_custom_pubkey(pk: CustomPubkey<'static>) -> Result<Pubkey, Infallible> {
        Ok(pk.0.into_owned())
    }

    serde_conv!(pub AsPubkey, Pubkey, to_custom_pubkey, from_custom_pubkey);

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::Value;
        use serde_with::{DeserializeAs, SerializeAs};
        use solana_keypair::Keypair;
        use solana_signer::Signer;

        #[test]
        fn test_pubkey() {
            let key = Pubkey::new_unique();
            let value = AsPubkey::serialize_as(&key, crate::ser::Serializer).unwrap();
            assert!(matches!(value, Value::B32(_)));
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);

            let value = Value::Map(crate::map! { "public_key" => key });
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);

            let value = Value::String(key.to_string());
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);

            let value = Value::Array(key.to_bytes().map(Value::from).to_vec());
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);

            let keypair = Keypair::new();
            let key = keypair.pubkey();
            let value = Value::B64(keypair.to_bytes());
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);

            let value = Value::String(keypair.to_base58_string());
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);

            let value = Value::Array(keypair.to_bytes().map(Value::from).to_vec());
            let de_key = AsPubkey::deserialize_as(value).unwrap();
            assert_eq!(key, de_key);
        }

        #[test]
        fn test_pubkey_empty_map_is_error() {
            // An empty map must surface as an error rather than reaching
            // `next_value` without a preceding key.
            let empty = serde::de::value::MapDeserializer::<_, serde::de::value::Error>::new(
                std::iter::empty::<(&str, &str)>(),
            );
            let result: Result<Pubkey, _> = AsPubkey::deserialize_as(empty);
            assert!(result.is_err());
        }
    }
}
311
/// Serde adapter for Solana `Signature` values.
///
/// `AsSignature` serializes a signature as a `"$$s"` newtype struct wrapping its
/// bytes, and deserializes it from 64 raw bytes, a base58 string, or a sequence
/// of 64 bytes.
#[cfg(feature = "solana-signature")]
pub(crate) mod signature {
    use super::*;
    use solana_signature::Signature;

    /// Newtype-struct name under which signatures are (de)serialized.
    pub(crate) const TOKEN: &str = "$$s";

    /// Serde-facing wrapper around a `Signature`: borrowed while serializing,
    /// owned after deserializing.
    struct CustomSignature<'a>(Cow<'a, Signature>);

    impl Serialize for CustomSignature<'_> {
        /// Emits the signature as a `TOKEN` newtype struct wrapping `crate::Bytes`.
        fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            let payload = crate::Bytes((*self.0).as_ref());
            serializer.serialize_newtype_struct(TOKEN, &payload)
        }
    }

    impl<'de> Deserialize<'de> for CustomSignature<'_> {
        /// Reads a signature from any representation [`Visitor`] understands.
        fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
            let signature = deserializer.deserialize_newtype_struct(TOKEN, Visitor)?;
            Ok(CustomSignature(Cow::Owned(signature)))
        }
    }

    /// Visitor producing a `Signature`.
    struct Visitor;

    impl<'de> serde::de::Visitor<'de> for Visitor {
        type Value = Signature;

        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("signature or bs58 string")
        }

        /// Accepts exactly 64 raw bytes.
        fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
            match <[u8; 64]>::try_from(v) {
                Ok(raw) => Ok(Signature::from(raw)),
                Err(_) => Err(de::Error::invalid_length(v.len(), &"64")),
            }
        }

        /// Decodes a base58 string into 64 bytes.
        fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
            let mut raw = [0u8; 64];
            match five8::decode_64(v, &mut raw) {
                Ok(_) => Ok(Signature::from(raw)),
                Err(_) => Err(de::Error::custom("invalid base58")),
            }
        }

        /// Reads the first 64 `u8` elements of a sequence.
        fn visit_seq<A: serde::de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> Result<Self::Value, A::Error> {
            try_from_fn::<64, u8, _, _>(|index| {
                seq.next_element()?
                    .ok_or_else(|| de::Error::invalid_length(index, &"64"))
            })
            .map(Signature::from)
        }

        /// Unwraps the newtype struct and dispatches on the inner value.
        fn visit_newtype_struct<D: serde::Deserializer<'de>>(
            self,
            d: D,
        ) -> Result<Self::Value, D::Error> {
            d.deserialize_any(self)
        }
    }

    /// Borrows a `Signature` into its serde wrapper (serialization side of
    /// `AsSignature`).
    fn to_custom_signature(pk: &Signature) -> CustomSignature<'_> {
        let borrowed = Cow::Borrowed(pk);
        CustomSignature(borrowed)
    }

    /// Extracts the owned `Signature` from its serde wrapper; never fails.
    fn from_custom_signature(pk: CustomSignature<'static>) -> Result<Signature, Infallible> {
        let CustomSignature(inner) = pk;
        Ok(inner.into_owned())
    }

    serde_conv!(pub AsSignature, Signature, to_custom_signature, from_custom_signature);

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::Value;
        use serde_with::{DeserializeAs, SerializeAs};
        use solana_signature::Signature;

        #[test]
        fn test_signature() {
            let original = Signature::default();

            // Canonical form produced by the crate's own serializer.
            let encoded = AsSignature::serialize_as(&original, crate::ser::Serializer).unwrap();
            assert!(matches!(encoded, Value::B64(_)));

            let as_string = Value::String(original.to_string());
            let as_array = Value::Array(
                original
                    .as_ref()
                    .iter()
                    .map(|byte| Value::from(*byte))
                    .collect::<Vec<_>>(),
            );

            // Every accepted representation must decode back to the same value.
            for input in [encoded, as_string, as_array] {
                let decoded: Signature = AsSignature::deserialize_as(input).unwrap();
                assert_eq!(original, decoded);
            }
        }
    }
}
426
/// Serde adapter for Solana `Keypair` values.
///
/// `AsKeypair` serializes a keypair as a `"$$k"` newtype struct wrapping its 64
/// bytes, and deserializes it from 64 raw bytes, a base58 string, or a sequence
/// of 64 bytes. The bytes are validated by `Keypair::try_from` after decoding.
#[cfg(feature = "solana-keypair")]
pub(crate) mod keypair {
    use super::*;
    use solana_keypair::Keypair;

    /// Newtype-struct name under which keypairs are (de)serialized.
    pub(crate) const TOKEN: &str = "$$k";

    /// Serde-facing form of a keypair: its 64 bytes, not yet validated.
    struct CustomKeypair([u8; 64]);

    impl Serialize for CustomKeypair {
        /// Emits the bytes as a `TOKEN` newtype struct wrapping `crate::Bytes`.
        fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            let payload = crate::Bytes(&self.0);
            serializer.serialize_newtype_struct(TOKEN, &payload)
        }
    }

    impl<'de> Deserialize<'de> for CustomKeypair {
        /// Reads 64 bytes from any representation [`Visitor`] understands.
        fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
            deserializer.deserialize_newtype_struct(TOKEN, Visitor)
        }
    }

    /// Visitor producing the raw bytes of a keypair.
    struct Visitor;

    impl<'de> serde::de::Visitor<'de> for Visitor {
        type Value = CustomKeypair;

        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("keypair or bs58 string")
        }

        /// Accepts exactly 64 raw bytes.
        fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
            match <[u8; 64]>::try_from(v) {
                Ok(raw) => Ok(CustomKeypair(raw)),
                Err(_) => Err(de::Error::invalid_length(v.len(), &"64")),
            }
        }

        /// Decodes a base58 string into 64 bytes.
        fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
            let mut raw = [0u8; 64];
            match five8::decode_64(v, &mut raw) {
                Ok(_) => Ok(CustomKeypair(raw)),
                Err(_) => Err(de::Error::custom("invalid base58")),
            }
        }

        /// Reads the first 64 `u8` elements of a sequence.
        fn visit_seq<A: serde::de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> Result<Self::Value, A::Error> {
            try_from_fn::<64, u8, _, _>(|index| {
                seq.next_element()?
                    .ok_or_else(|| de::Error::invalid_length(index, &"64"))
            })
            .map(CustomKeypair)
        }

        /// Unwraps the newtype struct and dispatches on the inner value.
        fn visit_newtype_struct<D: serde::Deserializer<'de>>(
            self,
            d: D,
        ) -> Result<Self::Value, D::Error> {
            d.deserialize_any(self)
        }
    }

    /// Copies a keypair's bytes into its serde form (serialization side of
    /// `AsKeypair`).
    fn to_custom_keypair(k: &'_ Keypair) -> CustomKeypair {
        let bytes = k.to_bytes();
        CustomKeypair(bytes)
    }

    /// Rebuilds a `Keypair` from decoded bytes, reporting a rejection by
    /// `Keypair::try_from` as its display string.
    fn from_custom_keypair(k: CustomKeypair) -> Result<Keypair, String> {
        let CustomKeypair(bytes) = k;
        match Keypair::try_from(bytes.as_slice()) {
            Ok(keypair) => Ok(keypair),
            Err(error) => Err(error.to_string()),
        }
    }

    serde_conv!(pub AsKeypair, Keypair, to_custom_keypair, from_custom_keypair);

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::Value;
        use serde_with::{DeserializeAs, SerializeAs};

        #[test]
        fn test_keypair() {
            let original = Keypair::new();

            // Canonical form produced by the crate's own serializer.
            let encoded = AsKeypair::serialize_as(&original, crate::ser::Serializer).unwrap();
            assert!(matches!(encoded, Value::B64(_)));

            let as_string = Value::String(original.to_base58_string());
            let as_array = Value::Array(original.to_bytes().map(Value::from).to_vec());

            // Every accepted representation must decode back to the same keypair.
            for input in [encoded, as_string, as_array] {
                let decoded: Keypair = AsKeypair::deserialize_as(input).unwrap();
                assert_eq!(original, decoded);
            }
        }
    }
}
534
535pub(crate) mod decimal {
536 use super::*;
537 use rust_decimal::Decimal;
538
539 struct CustomDecimal<'a>(Cow<'a, Decimal>);
540
541 pub(crate) const TOKEN: &str = "$$d";
542
543 impl Serialize for CustomDecimal<'_> {
544 fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
545 where
546 S: serde::Serializer,
547 {
548 s.serialize_newtype_struct(TOKEN, &crate::Bytes(&(*self.0).serialize()))
549 }
550 }
551
552 impl<'de> Deserialize<'de> for CustomDecimal<'_> {
553 fn deserialize<D>(d: D) -> Result<Self, D::Error>
554 where
555 D: de::Deserializer<'de>,
556 {
557 d.deserialize_newtype_struct(TOKEN, Visitor)
558 .map(|d| CustomDecimal(Cow::Owned(d)))
559 }
560 }
561
562 struct Visitor;
563
564 impl<'de> serde::de::Visitor<'de> for Visitor {
565 type Value = Decimal;
566
567 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
568 formatter.write_str("decimal")
569 }
570
571 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
572 where
573 E: serde::de::Error,
574 {
575 let buf: [u8; 16] = v
576 .try_into()
577 .map_err(|_| de::Error::invalid_length(v.len(), &"16"))?;
578 Ok(Decimal::deserialize(buf))
579 }
580
581 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
582 where
583 E: serde::de::Error,
584 {
585 Ok(Decimal::from(v))
586 }
587
588 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
589 where
590 E: serde::de::Error,
591 {
592 Ok(Decimal::from(v))
593 }
594
595 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
596 where
597 E: serde::de::Error,
598 {
599 Decimal::try_from(v).map_err(serde::de::Error::custom)
601 }
602
603 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
604 where
605 E: serde::de::Error,
606 {
607 let v = v.trim();
608 if v.bytes().any(|c| c == b'e' || c == b'E') {
609 Decimal::from_scientific(v).map_err(serde::de::Error::custom)
610 } else {
611 v.parse().map_err(serde::de::Error::custom)
612 }
613 }
614
615 fn visit_newtype_struct<D>(self, d: D) -> Result<Self::Value, D::Error>
616 where
617 D: serde::Deserializer<'de>,
618 {
619 d.deserialize_any(self)
620 }
621 }
622
623 fn to_custom_decimal(d: &Decimal) -> CustomDecimal<'_> {
624 CustomDecimal(Cow::Borrowed(d))
625 }
626 fn from_custom_decimal(d: CustomDecimal<'static>) -> Result<Decimal, Infallible> {
627 Ok(d.0.into_owned())
628 }
629 serde_conv!(pub AsDecimal, Decimal, to_custom_decimal, from_custom_decimal);
630
631 #[cfg(test)]
632 mod tests {
633 use super::*;
634 use crate::Value;
635 use rust_decimal_macros::dec;
636 use serde_with::{DeserializeAs, SerializeAs};
637
638 fn de<'de, D: serde::Deserializer<'de>>(d: D) -> Decimal {
639 AsDecimal::deserialize_as(d).unwrap()
640 }
641
642 #[test]
643 fn test_decimal() {
644 assert_eq!(
645 AsDecimal::serialize_as(&Decimal::MAX, crate::ser::Serializer).unwrap(),
646 Value::Decimal(Decimal::MAX)
647 );
648 assert_eq!(de(Value::U64(100)), dec!(100));
649 assert_eq!(de(Value::I64(-1)), dec!(-1));
650 assert_eq!(de(Value::Decimal(Decimal::MAX)), Decimal::MAX);
651 assert_eq!(de(Value::F64(1231.2221)), dec!(1231.2221));
652 assert_eq!(de(Value::String("1234.0".to_owned())), dec!(1234));
653 assert_eq!(de(Value::String(" 1234.0".to_owned())), dec!(1234));
654 assert_eq!(de(Value::String("1e5".to_owned())), dec!(100000));
655 assert_eq!(de(Value::String(" 1e5".to_owned())), dec!(100000));
656 }
657 }
658}