Skip to main content

shank_macro_impl/parsed_struct/
seed.rs

1use crate::types::{Primitive, RustType, TypeKind, Value};
2use std::convert::TryFrom;
3use syn::{Error as ParseError, Result as ParseResult};
4
/// Argument name given to the implicit program id seed.
const PROGRAM_ID_NAME: &str = "program_id";
/// Description given to the implicit program id seed argument.
const PROGRAM_ID_DESC: &str = "The id of the program";

/// Short type name that selects a `Pubkey` seed param (also the default
/// when a param has no type).
pub const PUBKEY_TY: &str = "Pubkey";
/// Fully qualified path used as the custom type kind of `Pubkey` args.
pub const FULL_PUBKEY_TY: &str = "::solana_program::pubkey::Pubkey";

/// Short type name that selects an `AccountInfo` seed param.
pub const ACCOUNT_INFO_TY: &str = "AccountInfo";
/// Fully qualified path used as the custom type kind of `AccountInfo` args.
pub const FULL_ACCOUNT_INFO_TY: &str =
    "::solana_program::account_info::AccountInfo";
12
/// One entry of a seeds definition.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Seed {
    /// A fixed string seed; it requires no function argument.
    Literal(String),
    /// The id of the program the seeds belong to.
    ProgramId,
    /// A caller-supplied seed, stored as `(name, desc, type)`.
    ///
    /// The type is the type name as written; `None` means no type was given.
    Param(String, String, Option<String>),
}
20
21impl Seed {
22    pub fn get_literal(&self) -> Option<String> {
23        match self {
24            Seed::Literal(lit) => Some(lit.to_string()),
25            _ => None,
26        }
27    }
28
29    pub fn get_program_id(&self) -> Option<Seed> {
30        match self {
31            Seed::ProgramId => Some(Seed::ProgramId),
32            _ => None,
33        }
34    }
35
36    pub fn get_param(&self) -> Option<Seed> {
37        match self {
38            Seed::Param(name, desc, ty) => {
39                Some(Seed::Param(name.to_owned(), desc.to_owned(), ty.clone()))
40            }
41            _ => None,
42        }
43    }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct SeedArg {
48    pub name: String,
49    pub desc: String,
50    pub ty: RustType,
51}
52impl SeedArg {
53    fn new(name: String, desc: String, ty: RustType) -> Self {
54        Self { name, desc, ty }
55    }
56}
57
58#[derive(Debug)]
59pub struct ProcessedSeed {
60    pub seed: Seed,
61    pub arg: Option<SeedArg>,
62}
63
64impl ProcessedSeed {
65    fn new(seed: Seed, arg: Option<SeedArg>) -> Self {
66        Self { seed, arg }
67    }
68}
69
70impl TryFrom<&Seed> for ProcessedSeed {
71    type Error = ParseError;
72    fn try_from(seed: &Seed) -> ParseResult<Self> {
73        // NOTE: We include lifetimes as part of the render step to guarantee that they match
74        // NOTE: All seeds need to be passed as references since we return an array containing
75        //       them and thus cannot take ownership.
76        match seed {
77            Seed::Literal(_) => Ok(ProcessedSeed::new(seed.clone(), None)),
78            Seed::ProgramId => {
79                let name = PROGRAM_ID_NAME.to_string();
80                let desc = PROGRAM_ID_DESC.to_string();
81                let ty = RustType::reference(
82                    PUBKEY_TY,
83                    TypeKind::Value(Value::Custom(FULL_PUBKEY_TY.to_string())),
84                    None,
85                );
86                Ok(ProcessedSeed::new(
87                    seed.clone(),
88                    Some(SeedArg::new(name, desc, ty)),
89                ))
90            }
91            Seed::Param(name, desc, maybe_kind) => {
92                let ty = match maybe_kind.as_ref().map(String::as_str) {
93                    Some(PUBKEY_TY) | None => {
94                        let kind = TypeKind::Value(Value::Custom(
95                            FULL_PUBKEY_TY.to_string(),
96                        ));
97                        RustType::reference(PUBKEY_TY, kind, None)
98                    }
99                    Some(ACCOUNT_INFO_TY) => {
100                        let kind = TypeKind::Value(Value::Custom(
101                            FULL_ACCOUNT_INFO_TY.to_string(),
102                        ));
103                        RustType::reference(ACCOUNT_INFO_TY, kind, None)
104                    }
105                    Some(ty_name) => {
106                        let ty = RustType::try_from(ty_name)?;
107                        match ty.get_primitive() {
108                            Some(Primitive::U8) => {
109                                // u8 is the only primitive we allow for seeds and it is the only
110                                // type that we don't pass by ref
111                                // Instead when passed to the seeds fn it is wrapped as &[u8]
112                                RustType::owned("u8", ty.kind)
113                            }
114                            Some(_) => {
115                                return Err(ParseError::new_spanned(
116                                    ty.ident,
117                                    "Only u8 primitive is allowed for seeds. All other primitives need to be passed as strings."));
118                            }
119                            None => ty.as_reference(None),
120                        }
121                    }
122                };
123                Ok(ProcessedSeed::new(
124                    seed.clone(),
125                    Some(SeedArg::new(name.to_owned(), desc.to_owned(), ty)),
126                ))
127            }
128        }
129    }
130}
131
#[cfg(test)]
mod tests {
    // Covers every branch of `ProcessedSeed::try_from`: literal, program id,
    // untyped/explicit Pubkey, AccountInfo, u8, and rejected primitives.
    use super::*;
    use assert_matches::assert_matches;

    #[test]
    fn process_seed_literal() {
        let seed = Seed::Literal("uno".to_string());
        let ProcessedSeed {
            seed: processed_seed,
            arg,
        } = ProcessedSeed::try_from(&seed)
            .expect("Should parse seed without error");

        assert_eq!(processed_seed, seed);
        assert!(arg.is_none());
    }

    #[test]
    fn process_seed_program_id() {
        let seed = Seed::ProgramId;
        let ProcessedSeed { arg, .. } = ProcessedSeed::try_from(&seed)
            .expect("Should parse seed without error");

        assert_matches!(arg, Some(SeedArg { name, desc, ty }) => {
            assert_eq!(name, PROGRAM_ID_NAME);
            assert_eq!(desc, PROGRAM_ID_DESC);
            assert_eq!(ty.ident.to_string().as_str(), "Pubkey");
            assert!(ty.kind.is_custom());
            assert_eq!(&format!("{:?}", ty.kind), "TypeKind::Value(Value::Custom(\"::solana_program::pubkey::Pubkey\"))")
        });
    }

    #[test]
    fn process_seed_pubkey() {
        let seed =
            Seed::Param("mypubkey".to_string(), "my desc".to_string(), None);
        let ProcessedSeed { arg, .. } = ProcessedSeed::try_from(&seed)
            .expect("Should parse seed without error");

        assert_matches!(arg, Some(SeedArg { name, desc, ty }) => {
            assert_eq!(name, "mypubkey");
            assert_eq!(desc, "my desc");
            assert_eq!(ty.ident.to_string().as_str(), "Pubkey");
            assert!(ty.kind.is_custom());
            assert_eq!(&format!("{:?}", ty.kind), "TypeKind::Value(Value::Custom(\"::solana_program::pubkey::Pubkey\"))")
        });
    }

    #[test]
    fn process_seed_explicit_pubkey() {
        let seed = Seed::Param(
            "mypubkey".to_string(),
            "my desc".to_string(),
            Some(PUBKEY_TY.to_string()),
        );
        let ProcessedSeed { arg, .. } = ProcessedSeed::try_from(&seed)
            .expect("Should parse seed without error");

        assert_matches!(arg, Some(SeedArg { name, desc, ty }) => {
            assert_eq!(name, "mypubkey");
            assert_eq!(desc, "my desc");
            assert_eq!(ty.ident.to_string().as_str(), "Pubkey");
            assert!(ty.kind.is_custom());
            assert_eq!(&format!("{:?}", ty.kind), "TypeKind::Value(Value::Custom(\"::solana_program::pubkey::Pubkey\"))")
        });
    }

    #[test]
    fn process_seed_account_info() {
        let seed = Seed::Param(
            "myaccount".to_string(),
            "account desc".to_string(),
            Some(ACCOUNT_INFO_TY.to_string()),
        );
        let ProcessedSeed { arg, .. } = ProcessedSeed::try_from(&seed)
            .expect("Should parse seed without error");

        assert_matches!(arg, Some(SeedArg { name, desc, ty }) => {
            assert_eq!(name, "myaccount");
            assert_eq!(desc, "account desc");
            assert_eq!(ty.ident.to_string().as_str(), "AccountInfo");
            assert!(ty.kind.is_custom());
            assert_eq!(&format!("{:?}", ty.kind), "TypeKind::Value(Value::Custom(\"::solana_program::account_info::AccountInfo\"))")
        });
    }

    #[test]
    fn process_seed_u8() {
        let seed = Seed::Param(
            "myu8".to_string(),
            "u8 desc".to_string(),
            Some("u8".to_string()),
        );
        let ProcessedSeed { arg, .. } = ProcessedSeed::try_from(&seed)
            .expect("Should parse seed without error");

        assert_matches!(arg, Some(SeedArg { name, desc, ty }) => {
            assert_eq!(name, "myu8");
            assert_eq!(desc, "u8 desc");
            assert_eq!(ty.ident.to_string().as_str(), "u8");
            assert!(ty.kind.is_primitive());
            assert_eq!(&format!("{:?}", ty.kind), "TypeKind::Primitive(Primitive::U8)")
        });
    }

    #[test]
    fn process_seed_non_u8_primitive_is_rejected() {
        // Only u8 may be used as a primitive seed; others must be strings.
        let seed = Seed::Param(
            "myu64".to_string(),
            "u64 desc".to_string(),
            Some("u64".to_string()),
        );

        assert!(ProcessedSeed::try_from(&seed).is_err());
    }
}