codama_attributes/codama_directives/seed_directive.rs

1use crate::{
2    utils::{FromMeta, SetOnce},
3    Attribute, AttributeContext, CodamaAttribute, CodamaDirective,
4};
5use codama_errors::CodamaError;
6use codama_nodes::{ConstantPdaSeedNode, PdaSeedNode, TypeNode, ValueNode, VariablePdaSeedNode};
7use codama_syn_helpers::{extensions::*, Meta};
8
9#[derive(Debug, PartialEq)]
10pub struct SeedDirective {
11    pub seed: SeedDirectiveType,
12}
13
14#[derive(Debug, PartialEq)]
15pub enum SeedDirectiveType {
16    Linked(String),
17    Defined(PdaSeedNode),
18}
19
20impl SeedDirective {
21    pub fn parse(meta: &Meta, ctx: &AttributeContext) -> syn::Result<Self> {
22        let pl = meta.assert_directive("seed")?.as_path_list()?;
23
24        let constant_seed = pl
25            .parse_metas()?
26            .iter()
27            .find_map(|m| match m.path_str().as_str() {
28                "name" => Some(false),
29                "value" => Some(true),
30                _ => None,
31            })
32            .ok_or_else(|| meta.error("seed must at least specify `name` for variable seeds or `type` and `value` for constant seeds"))?;
33
34        let mut name = SetOnce::<String>::new("name");
35        let mut r#type = SetOnce::<TypeNode>::new("type");
36        let mut value = SetOnce::<ValueNode>::new("value");
37
38        pl.each(|ref meta| match (meta.path_str().as_str(), constant_seed) {
39            ("name", true) => Err(meta.error("constant seeds cannot specify name")),
40            ("name", false) => name.set(meta.as_value()?.as_expr()?.as_string()?, meta),
41            ("value", true) => value.set(ValueNode::from_meta(meta.as_value()?)?, meta),
42            ("value", false) => Err(meta.error("variable seeds cannot specify value")),
43            ("type", _) => r#type.set(TypeNode::from_meta(meta.as_value()?)?, meta),
44            _ => Err(meta.error("unrecognized attribute")),
45        })?;
46
47        // Resolve linked seed if possible.
48        if !constant_seed && !r#type.is_set() {
49            let name = name.take(meta)?;
50            if !has_matching_field(ctx, &name) {
51                let message = format!("Could not find field \"{name}\". Either specify a `type` for the seed or use a name that matches a struct or variant field.");
52                return Err(meta.error(message));
53            }
54            return Ok(Self {
55                seed: SeedDirectiveType::Linked(name),
56            });
57        }
58
59        match constant_seed {
60            true => Ok(Self {
61                seed: SeedDirectiveType::Defined(
62                    ConstantPdaSeedNode::new(r#type.take(meta)?, value.take(meta)?).into(),
63                ),
64            }),
65            false => Ok(Self {
66                seed: SeedDirectiveType::Defined(
67                    VariablePdaSeedNode::new(name.take(meta)?, r#type.take(meta)?).into(),
68                ),
69            }),
70        }
71    }
72}
73
74fn has_matching_field(ctx: &AttributeContext, name: &str) -> bool {
75    let Some(fields) = ctx.get_named_fields() else {
76        return false;
77    };
78
79    fields
80        .named
81        .iter()
82        .any(|f| f.ident.as_ref().is_some_and(|id| id == name))
83}
84
85impl<'a> TryFrom<&'a CodamaAttribute<'a>> for &'a SeedDirective {
86    type Error = CodamaError;
87
88    fn try_from(attribute: &'a CodamaAttribute) -> Result<Self, Self::Error> {
89        match attribute.directive {
90            CodamaDirective::Seed(ref a) => Ok(a),
91            _ => Err(CodamaError::InvalidCodamaDirective {
92                expected: "seed".to_string(),
93                actual: attribute.directive.name().to_string(),
94            }),
95        }
96    }
97}
98
99impl<'a> TryFrom<&'a Attribute<'a>> for &'a SeedDirective {
100    type Error = CodamaError;
101
102    fn try_from(attribute: &'a Attribute) -> Result<Self, Self::Error> {
103        <&CodamaAttribute>::try_from(attribute)?.try_into()
104    }
105}
106
#[cfg(test)]
mod tests {
    use codama_nodes::{NumberFormat::U8, NumberTypeNode, NumberValueNode, PublicKeyTypeNode};

    use super::*;

    #[test]
    fn defined_constant() {
        let meta: Meta = syn::parse_quote! { seed(type = number(u8), value = 42) };
        let item = syn::parse_quote! { struct Foo; };
        let ctx = AttributeContext::Item(&item);
        let directive = SeedDirective::parse(&meta, &ctx).unwrap();
        assert_eq!(
            directive,
            SeedDirective {
                seed: SeedDirectiveType::Defined(
                    ConstantPdaSeedNode::new(NumberTypeNode::le(U8), NumberValueNode::new(42u8))
                        .into()
                ),
            }
        );
    }

    #[test]
    fn defined_variable() {
        let meta: Meta = syn::parse_quote! { seed(name = "authority", type = public_key) };
        let item = syn::parse_quote! { struct Foo; };
        let ctx = AttributeContext::Item(&item);
        let directive = SeedDirective::parse(&meta, &ctx).unwrap();
        assert_eq!(
            directive,
            SeedDirective {
                seed: SeedDirectiveType::Defined(
                    VariablePdaSeedNode::new("authority", PublicKeyTypeNode::new()).into()
                ),
            }
        );
    }

    // An explicit `type` defines the seed even when a field with the same name exists.
    #[test]
    fn defined_variable_with_matching_field() {
        let meta: Meta = syn::parse_quote! { seed(name = "authority", type = public_key) };
        let item = syn::parse_quote! { struct Foo { authority: PubKey } };
        let ctx = AttributeContext::Item(&item);
        let directive = SeedDirective::parse(&meta, &ctx).unwrap();
        assert_eq!(
            directive,
            SeedDirective {
                seed: SeedDirectiveType::Defined(
                    VariablePdaSeedNode::new("authority", PublicKeyTypeNode::new()).into()
                ),
            }
        );
    }

    #[test]
    fn linked_seed() {
        let meta: Meta = syn::parse_quote! { seed(name = "authority") };
        let item = syn::parse_quote! { struct Foo { authority: PubKey } };
        let ctx = AttributeContext::Item(&item);
        let directive = SeedDirective::parse(&meta, &ctx).unwrap();
        assert_eq!(
            directive,
            SeedDirective {
                seed: SeedDirectiveType::Linked("authority".to_string()),
            }
        );
    }

    #[test]
    fn linked_seed_in_variant() {
        let meta: Meta = syn::parse_quote! { seed(name = "authority") };
        let item: syn::Variant = syn::parse_quote! { Foo { authority: PubKey } };
        let ctx = AttributeContext::Variant(&item);
        let directive = SeedDirective::parse(&meta, &ctx).unwrap();
        assert_eq!(
            directive,
            SeedDirective {
                seed: SeedDirectiveType::Linked("authority".to_string()),
            }
        );
    }

    #[test]
    fn cannot_identify_seed_type() {
        let meta: Meta = syn::parse_quote! { seed(type = public_key) };
        let item = syn::parse_quote! { struct Foo; };
        let ctx = AttributeContext::Item(&item);
        let error = SeedDirective::parse(&meta, &ctx).unwrap_err();
        assert_eq!(error.to_string(), "seed must at least specify `name` for variable seeds or `type` and `value` for constant seeds");
    }

    #[test]
    fn cannot_find_linked_field() {
        let meta: Meta = syn::parse_quote! { seed(name = "authority") };
        let item = syn::parse_quote! { struct Foo { owner: PubKey } };
        let ctx = AttributeContext::Item(&item);
        let error = SeedDirective::parse(&meta, &ctx).unwrap_err();
        assert_eq!(
            error.to_string(),
            "Could not find field \"authority\". Either specify a `type` for the seed or use a name that matches a struct or variant field."
        );
    }

    // Tuple structs expose no named fields, so a linked seed can never resolve.
    #[test]
    fn cannot_link_seed_without_named_fields() {
        let meta: Meta = syn::parse_quote! { seed(name = "authority") };
        let item = syn::parse_quote! { struct Foo(PubKey); };
        let ctx = AttributeContext::Item(&item);
        let error = SeedDirective::parse(&meta, &ctx).unwrap_err();
        assert_eq!(
            error.to_string(),
            "Could not find field \"authority\". Either specify a `type` for the seed or use a name that matches a struct or variant field."
        );
    }

    #[test]
    fn value_with_name() {
        let meta: Meta = syn::parse_quote! { seed(name = "amount", value = 42) };
        let item = syn::parse_quote! { struct Foo; };
        let ctx = AttributeContext::Item(&item);
        let error = SeedDirective::parse(&meta, &ctx).unwrap_err();
        assert_eq!(error.to_string(), "variable seeds cannot specify value");
    }

    #[test]
    fn name_with_value() {
        let meta: Meta = syn::parse_quote! { seed(value = 42, name = "amount") };
        let item = syn::parse_quote! { struct Foo; };
        let ctx = AttributeContext::Item(&item);
        let error = SeedDirective::parse(&meta, &ctx).unwrap_err();
        assert_eq!(error.to_string(), "constant seeds cannot specify name");
    }

    #[test]
    fn unrecognized_attribute() {
        let meta: Meta =
            syn::parse_quote! { seed(name = "authority", type = public_key, foo = 42) };
        let item = syn::parse_quote! { struct Foo; };
        let ctx = AttributeContext::Item(&item);
        let error = SeedDirective::parse(&meta, &ctx).unwrap_err();
        assert_eq!(error.to_string(), "unrecognized attribute");
    }
}