1#[cfg(feature = "borsh")]
3use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
4use {
5 bytemuck_derive::{Pod, Zeroable},
6 solana_program_error::ProgramError,
7 solana_program_option::COption,
8 solana_pubkey::Pubkey,
9 solana_zk_sdk::encryption::pod::elgamal::PodElGamalPubkey,
10};
11#[cfg(feature = "serde-traits")]
12use {
13 serde::de::{Error, Unexpected, Visitor},
14 serde::{Deserialize, Deserializer, Serialize, Serializer},
15 std::{convert::TryFrom, fmt, str::FromStr},
16};
17
18#[cfg_attr(
21 feature = "borsh",
22 derive(BorshDeserialize, BorshSerialize, BorshSchema)
23)]
24#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
25#[repr(transparent)]
26pub struct OptionalNonZeroPubkey(pub Pubkey);
27impl TryFrom<Option<Pubkey>> for OptionalNonZeroPubkey {
28 type Error = ProgramError;
29 fn try_from(p: Option<Pubkey>) -> Result<Self, Self::Error> {
30 match p {
31 None => Ok(Self(Pubkey::default())),
32 Some(pubkey) => {
33 if pubkey == Pubkey::default() {
34 Err(ProgramError::InvalidArgument)
35 } else {
36 Ok(Self(pubkey))
37 }
38 }
39 }
40 }
41}
42impl TryFrom<COption<Pubkey>> for OptionalNonZeroPubkey {
43 type Error = ProgramError;
44 fn try_from(p: COption<Pubkey>) -> Result<Self, Self::Error> {
45 match p {
46 COption::None => Ok(Self(Pubkey::default())),
47 COption::Some(pubkey) => {
48 if pubkey == Pubkey::default() {
49 Err(ProgramError::InvalidArgument)
50 } else {
51 Ok(Self(pubkey))
52 }
53 }
54 }
55 }
56}
57impl From<OptionalNonZeroPubkey> for Option<Pubkey> {
58 fn from(p: OptionalNonZeroPubkey) -> Self {
59 if p.0 == Pubkey::default() {
60 None
61 } else {
62 Some(p.0)
63 }
64 }
65}
66impl From<OptionalNonZeroPubkey> for COption<Pubkey> {
67 fn from(p: OptionalNonZeroPubkey) -> Self {
68 if p.0 == Pubkey::default() {
69 COption::None
70 } else {
71 COption::Some(p.0)
72 }
73 }
74}
75
#[cfg(feature = "serde-traits")]
impl Serialize for OptionalNonZeroPubkey {
    /// Serializes the all-zero sentinel as `None` and any other key as
    /// `Some(<key.to_string()>)`.
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let OptionalNonZeroPubkey(pubkey) = self;
        if *pubkey == Pubkey::default() {
            return s.serialize_none();
        }

        let encoded = pubkey.to_string();
        s.serialize_some(&encoded)
    }
}
89
/// Serde visitor for [`OptionalNonZeroPubkey`].
///
/// Accepts an absent value (`null` / `None` / unit), decoded as the all-zero
/// sentinel, or a base58 string that must parse to a non-zero [`Pubkey`].
#[cfg(feature = "serde-traits")]
struct OptionalNonZeroPubkeyVisitor;

#[cfg(feature = "serde-traits")]
impl<'de> Visitor<'de> for OptionalNonZeroPubkeyVisitor {
    type Value = OptionalNonZeroPubkey;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a Pubkey in base58 or `null`")
    }

    /// Parses a base58 pubkey string; the all-zero key is rejected because it
    /// is reserved as the `None` sentinel.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: Error,
    {
        let pkey = Pubkey::from_str(v)
            .map_err(|_| Error::invalid_value(Unexpected::Str(v), &"a base58-encoded Pubkey"))?;

        OptionalNonZeroPubkey::try_from(Some(pkey))
            .map_err(|_| Error::custom("Failed to convert from pubkey"))
    }

    /// Unit is still accepted so that `null` buffered by `deserialize_any`
    /// (e.g. inside untagged or flattened containers) keeps working.
    fn visit_unit<E>(self) -> Result<Self::Value, E>
    where
        E: Error,
    {
        Ok(OptionalNonZeroPubkey::default())
    }

    /// `None` from `deserialize_option` maps to the all-zero sentinel.
    fn visit_none<E>(self) -> Result<Self::Value, E>
    where
        E: Error,
    {
        Ok(OptionalNonZeroPubkey::default())
    }

    /// `Some(..)` from `deserialize_option`: the payload is the key string.
    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(self)
    }
}

#[cfg(feature = "serde-traits")]
impl<'de> Deserialize<'de> for OptionalNonZeroPubkey {
    /// Mirrors `Serialize`, which emits `serialize_none`/`serialize_some`.
    ///
    /// Requesting an option (rather than `deserialize_any`) also lets formats
    /// that are not self-describing round-trip the value, while
    /// self-describing formats such as JSON still accept `null` or a string.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_option(OptionalNonZeroPubkeyVisitor)
    }
}
130
131#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
134#[repr(transparent)]
135pub struct OptionalNonZeroElGamalPubkey(PodElGamalPubkey);
136impl OptionalNonZeroElGamalPubkey {
137 pub fn equals(&self, other: &PodElGamalPubkey) -> bool {
140 &self.0 == other
141 }
142}
143impl TryFrom<Option<PodElGamalPubkey>> for OptionalNonZeroElGamalPubkey {
144 type Error = ProgramError;
145 fn try_from(p: Option<PodElGamalPubkey>) -> Result<Self, Self::Error> {
146 match p {
147 None => Ok(Self(PodElGamalPubkey::default())),
148 Some(elgamal_pubkey) => {
149 if elgamal_pubkey == PodElGamalPubkey::default() {
150 Err(ProgramError::InvalidArgument)
151 } else {
152 Ok(Self(elgamal_pubkey))
153 }
154 }
155 }
156 }
157}
158impl From<OptionalNonZeroElGamalPubkey> for Option<PodElGamalPubkey> {
159 fn from(p: OptionalNonZeroElGamalPubkey) -> Self {
160 if p.0 == PodElGamalPubkey::default() {
161 None
162 } else {
163 Some(p.0)
164 }
165 }
166}
167
#[cfg(feature = "serde-traits")]
impl Serialize for OptionalNonZeroElGamalPubkey {
    /// Serializes the all-zero sentinel as `None` and any other key as
    /// `Some(<key.to_string()>)`.
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let OptionalNonZeroElGamalPubkey(elgamal_pubkey) = self;
        if *elgamal_pubkey == PodElGamalPubkey::default() {
            return s.serialize_none();
        }

        let encoded = elgamal_pubkey.to_string();
        s.serialize_some(&encoded)
    }
}
181
/// Serde visitor for [`OptionalNonZeroElGamalPubkey`].
///
/// Accepts an absent value (`null` / `None` / unit), decoded as the all-zero
/// sentinel, or a string parsed with `PodElGamalPubkey`'s `FromStr`
/// implementation, which must yield a non-zero key.
#[cfg(feature = "serde-traits")]
struct OptionalNonZeroElGamalPubkeyVisitor;

#[cfg(feature = "serde-traits")]
impl<'de> Visitor<'de> for OptionalNonZeroElGamalPubkeyVisitor {
    type Value = OptionalNonZeroElGamalPubkey;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("an ElGamal public key as base64 or `null`")
    }

    /// Parses the key string; the all-zero key is rejected because it is
    /// reserved as the `None` sentinel.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: Error,
    {
        let elgamal_pubkey: PodElGamalPubkey = FromStr::from_str(v).map_err(Error::custom)?;
        OptionalNonZeroElGamalPubkey::try_from(Some(elgamal_pubkey)).map_err(Error::custom)
    }

    /// Unit is still accepted so that `null` buffered by `deserialize_any`
    /// (e.g. inside untagged or flattened containers) keeps working.
    fn visit_unit<E>(self) -> Result<Self::Value, E>
    where
        E: Error,
    {
        Ok(OptionalNonZeroElGamalPubkey::default())
    }

    /// `None` from `deserialize_option` maps to the all-zero sentinel.
    fn visit_none<E>(self) -> Result<Self::Value, E>
    where
        E: Error,
    {
        Ok(OptionalNonZeroElGamalPubkey::default())
    }

    /// `Some(..)` from `deserialize_option`: the payload is the key string.
    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(self)
    }
}

#[cfg(feature = "serde-traits")]
impl<'de> Deserialize<'de> for OptionalNonZeroElGamalPubkey {
    /// Mirrors `Serialize`, which emits `serialize_none`/`serialize_some`.
    ///
    /// Requesting an option (rather than `deserialize_any`) also lets formats
    /// that are not self-describing round-trip the value, while
    /// self-describing formats such as JSON still accept `null` or a string.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_option(OptionalNonZeroElGamalPubkeyVisitor)
    }
}
218
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::bytemuck::pod_from_bytes,
        base64::{prelude::BASE64_STANDARD, Engine},
        solana_pubkey::PUBKEY_BYTES,
    };

    #[test]
    fn test_pod_non_zero_option() {
        assert_eq!(
            Some(Pubkey::new_from_array([1; PUBKEY_BYTES])),
            Option::<Pubkey>::from(
                *pod_from_bytes::<OptionalNonZeroPubkey>(&[1; PUBKEY_BYTES]).unwrap()
            )
        );
        assert_eq!(
            None,
            Option::<Pubkey>::from(
                *pod_from_bytes::<OptionalNonZeroPubkey>(&[0; PUBKEY_BYTES]).unwrap()
            )
        );
        assert_eq!(
            pod_from_bytes::<OptionalNonZeroPubkey>(&[]).unwrap_err(),
            ProgramError::InvalidArgument
        );
        assert_eq!(
            pod_from_bytes::<OptionalNonZeroPubkey>(&[0; 1]).unwrap_err(),
            ProgramError::InvalidArgument
        );
        assert_eq!(
            pod_from_bytes::<OptionalNonZeroPubkey>(&[1; 1]).unwrap_err(),
            ProgramError::InvalidArgument
        );
    }

    /// Pins the `Option`/`COption` conversions, including rejection of the
    /// all-zero key, which is reserved as the `None` sentinel.
    #[test]
    fn test_pod_non_zero_option_conversions() {
        let pubkey = Pubkey::new_from_array([1; PUBKEY_BYTES]);
        let unset = OptionalNonZeroPubkey(Pubkey::default());

        assert_eq!(
            OptionalNonZeroPubkey::try_from(Some(pubkey)).unwrap(),
            OptionalNonZeroPubkey(pubkey)
        );
        assert_eq!(
            OptionalNonZeroPubkey::try_from(Option::<Pubkey>::None).unwrap(),
            unset
        );
        assert_eq!(
            OptionalNonZeroPubkey::try_from(Some(Pubkey::default())).unwrap_err(),
            ProgramError::InvalidArgument
        );

        assert_eq!(
            OptionalNonZeroPubkey::try_from(COption::Some(pubkey)).unwrap(),
            OptionalNonZeroPubkey(pubkey)
        );
        assert_eq!(
            OptionalNonZeroPubkey::try_from(COption::<Pubkey>::None).unwrap(),
            unset
        );
        assert_eq!(
            OptionalNonZeroPubkey::try_from(COption::Some(Pubkey::default())).unwrap_err(),
            ProgramError::InvalidArgument
        );

        assert_eq!(
            COption::<Pubkey>::from(OptionalNonZeroPubkey(pubkey)),
            COption::Some(pubkey)
        );
        assert_eq!(COption::<Pubkey>::from(unset), COption::None);
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_non_zero_option_serde_some() {
        let optional_non_zero_pubkey_some =
            OptionalNonZeroPubkey(Pubkey::new_from_array([1; PUBKEY_BYTES]));
        let serialized_some = serde_json::to_string(&optional_non_zero_pubkey_some).unwrap();
        assert_eq!(
            &serialized_some,
            "\"4vJ9JU1bJJE96FWSJKvHsmmFADCg4gpZQff4P3bkLKi\""
        );

        let deserialized_some =
            serde_json::from_str::<OptionalNonZeroPubkey>(&serialized_some).unwrap();
        assert_eq!(optional_non_zero_pubkey_some, deserialized_some);
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_non_zero_option_serde_none() {
        let optional_non_zero_pubkey_none =
            OptionalNonZeroPubkey(Pubkey::new_from_array([0; PUBKEY_BYTES]));
        let serialized_none = serde_json::to_string(&optional_non_zero_pubkey_none).unwrap();
        assert_eq!(&serialized_none, "null");

        let deserialized_none =
            serde_json::from_str::<OptionalNonZeroPubkey>(&serialized_none).unwrap();
        assert_eq!(optional_non_zero_pubkey_none, deserialized_none);
    }

    /// A string encoding the all-zero key, or a malformed string, must not
    /// deserialize: the zero key is reserved for `null`.
    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_non_zero_option_serde_rejects_invalid_strings() {
        let zero_pubkey_json = format!("\"{}\"", Pubkey::default());
        assert!(serde_json::from_str::<OptionalNonZeroPubkey>(&zero_pubkey_json).is_err());
        assert!(serde_json::from_str::<OptionalNonZeroPubkey>("\"not a pubkey\"").is_err());
    }

    const OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN: usize = 32;

    fn elgamal_pubkey_from_bytes(bytes: &[u8]) -> PodElGamalPubkey {
        let string = BASE64_STANDARD.encode(bytes);
        std::str::FromStr::from_str(&string).unwrap()
    }

    #[test]
    fn test_pod_non_zero_elgamal_option() {
        assert_eq!(
            Some(elgamal_pubkey_from_bytes(
                &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]
            )),
            Option::<PodElGamalPubkey>::from(OptionalNonZeroElGamalPubkey(
                elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN])
            ))
        );
        assert_eq!(
            None,
            Option::<PodElGamalPubkey>::from(OptionalNonZeroElGamalPubkey(
                elgamal_pubkey_from_bytes(&[0; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN])
            ))
        );

        assert_eq!(
            OptionalNonZeroElGamalPubkey(elgamal_pubkey_from_bytes(
                &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]
            )),
            *pod_from_bytes::<OptionalNonZeroElGamalPubkey>(
                &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]
            )
            .unwrap()
        );
        assert!(pod_from_bytes::<OptionalNonZeroElGamalPubkey>(&[]).is_err());
    }

    /// Pins `TryFrom`, including rejection of the all-zero key, and `equals`.
    #[test]
    fn test_pod_non_zero_elgamal_option_conversions() {
        let elgamal_pubkey = elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]);

        assert_eq!(
            OptionalNonZeroElGamalPubkey::try_from(Some(elgamal_pubkey)).unwrap(),
            OptionalNonZeroElGamalPubkey(elgamal_pubkey)
        );
        assert_eq!(
            OptionalNonZeroElGamalPubkey::try_from(Option::<PodElGamalPubkey>::None).unwrap(),
            OptionalNonZeroElGamalPubkey::default()
        );
        assert_eq!(
            OptionalNonZeroElGamalPubkey::try_from(Some(PodElGamalPubkey::default()))
                .unwrap_err(),
            ProgramError::InvalidArgument
        );

        assert!(OptionalNonZeroElGamalPubkey(elgamal_pubkey).equals(&elgamal_pubkey));
        assert!(!OptionalNonZeroElGamalPubkey::default().equals(&elgamal_pubkey));
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_non_zero_elgamal_option_serde_some() {
        let optional_non_zero_elgamal_pubkey_some = OptionalNonZeroElGamalPubkey(
            elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]),
        );
        let serialized_some =
            serde_json::to_string(&optional_non_zero_elgamal_pubkey_some).unwrap();
        assert_eq!(
            &serialized_some,
            "\"AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE=\""
        );

        let deserialized_some =
            serde_json::from_str::<OptionalNonZeroElGamalPubkey>(&serialized_some).unwrap();
        assert_eq!(optional_non_zero_elgamal_pubkey_some, deserialized_some);
    }

    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_non_zero_elgamal_option_serde_none() {
        let optional_non_zero_elgamal_pubkey_none = OptionalNonZeroElGamalPubkey(
            elgamal_pubkey_from_bytes(&[0; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]),
        );
        let serialized_none =
            serde_json::to_string(&optional_non_zero_elgamal_pubkey_none).unwrap();
        assert_eq!(&serialized_none, "null");

        let deserialized_none =
            serde_json::from_str::<OptionalNonZeroElGamalPubkey>(&serialized_none).unwrap();
        assert_eq!(optional_non_zero_elgamal_pubkey_none, deserialized_none);
    }

    /// A string encoding the all-zero ElGamal key must not deserialize: the
    /// zero key is reserved for `null`.
    #[cfg(feature = "serde-traits")]
    #[test]
    fn test_pod_non_zero_elgamal_option_serde_rejects_zero_key() {
        let zero_key_json = format!("\"{}\"", PodElGamalPubkey::default());
        assert!(serde_json::from_str::<OptionalNonZeroElGamalPubkey>(&zero_key_json).is_err());
    }
}