use quote::{quote, ToTokens};
use std::str::FromStr;
use proc_macro2::{Ident, Span, TokenStream};
use shank_macro_impl::{
parsed_struct::{
ProcessedSeed, Seed, ACCOUNT_INFO_TY, FULL_ACCOUNT_INFO_TY,
FULL_PUBKEY_TY, PUBKEY_TY,
},
syn::{Error as ParseError, Result as ParseResult},
types::{Composite, ParsedReference, Primitive, RustType, TypeKind, Value},
};
use super::render_args_comments;
/// Renders the seed helper functions for an account's PDA.
///
/// Two functions are emitted, both generic over the lifetime `'a`:
/// * `seeds_fn_name(..)` returns an array with one `&'a [u8]` per seed.
/// * `seeds_fn_with_bump_name(.., bump)` returns the same seeds followed by
///   the caller-supplied `bump: &'a [u8; 1]`.
///
/// Literal seeds are embedded directly; every other seed becomes a function
/// parameter. When `include_comments` is `true`, doc comments assembled from
/// `render_args_comments` are prepended to both functions.
///
/// Returns `Ok(None)` when no seeds array items are rendered.
///
/// # Errors
/// Propagates seed rendering errors (e.g. unsupported seed types) and fails if
/// the generated lifetime or comment text cannot be lexed into tokens.
pub fn try_render_seeds_fn(
    processed_seeds: &[ProcessedSeed],
    seeds_fn_name: &Ident,
    seeds_fn_with_bump_name: &Ident,
    include_comments: bool,
) -> ParseResult<Option<TokenStream>> {
    let lifetime = "a";
    let RenderedSeedsParts {
        seed_array_items,
        seed_fn_args,
    } = try_render_seeds_parts(processed_seeds, lifetime)?;

    if seed_array_items.is_empty() {
        return Ok(None);
    }

    let len = seed_array_items.len();
    let len_with_bump = len + 1;
    let lifetime_toks = TokenStream::from_str(&format!("<'{}>", lifetime))?;

    // `bump` follows the other parameters, so it only needs a separating comma
    // when there are any.
    let bump = if seed_fn_args.is_empty() {
        quote! { bump: &'a [u8; 1] }
    } else {
        quote! { , bump: &'a [u8; 1] }
    };

    let (seeds_comments, seeds_with_bump_comments) = if include_comments {
        let args_comments = render_args_comments(processed_seeds, false).join("\n");
        // Lex the text so each `///` line becomes a doc attribute. Converting
        // the `String` via `ToTokens` would instead emit one string literal,
        // which is not valid in front of a fn item.
        let seeds_comments = TokenStream::from_str(&format!(
            r#"
/// Derives the seeds for this account.
///
{}"#,
            args_comments
        ))?;
        let seeds_with_bump_comments = TokenStream::from_str(&format!(
            r#"
/// Derives the seeds for this account allowing to provide a bump seed.
///
{}
/// * **bump**: the bump seed to pass when deriving the PDA"#,
            args_comments
        ))?;
        (seeds_comments, seeds_with_bump_comments)
    } else {
        (TokenStream::new(), TokenStream::new())
    };

    Ok(Some(quote! {
        #seeds_comments
        #[allow(unused, clippy::needless_lifetimes)]
        pub fn #seeds_fn_name#lifetime_toks(#(#seed_fn_args),*) -> [&'a [u8]; #len] {
            [#(#seed_array_items),*]
        }
        #seeds_with_bump_comments
        #[allow(unused, clippy::needless_lifetimes)]
        pub fn #seeds_fn_with_bump_name#lifetime_toks(#(#seed_fn_args),*#bump) -> [&'a [u8]; #len_with_bump] {
            [#(#seed_array_items),*, bump]
        }
    }))
}
/// Token fragments rendered from a list of processed seeds, shared by both
/// generated seed functions.
#[derive(Debug)]
struct RenderedSeedsParts {
    /// One expression per seed, in order, spliced into the returned array.
    seed_array_items: Vec<TokenStream>,
    /// Parameters of the generated functions. Literal seeds contribute none,
    /// so this can be shorter than `seed_array_items`.
    seed_fn_args: Vec<TokenStream>,
}
/// Renders the per-seed function parameters and seeds-array items.
///
/// `lifetime` (without the leading `'`) is forwarded to
/// `render_seed_function_arg` for each parameter type. Literal seeds only
/// produce an array item; every other seed produces a parameter and an item.
///
/// # Errors
/// Returns the first error hit while rendering parameters, and otherwise the
/// first error hit while rendering array items.
fn try_render_seeds_parts(
    processed_seeds: &[ProcessedSeed],
    lifetime: &str,
) -> ParseResult<RenderedSeedsParts> {
    // `transpose` turns `Ok(None)` (literal seeds) into `None`, which
    // `filter_map` drops, while `Ok(Some(_))` and `Err(_)` are kept so the
    // collect still short-circuits on the first error.
    let seed_fn_args = processed_seeds
        .iter()
        .map(|seed| render_seed_function_arg(seed, lifetime))
        .filter_map(Result::transpose)
        .collect::<ParseResult<Vec<TokenStream>>>()?;

    let seed_array_items = processed_seeds
        .iter()
        .map(render_seed_array_item)
        .collect::<ParseResult<Vec<TokenStream>>>()?;

    Ok(RenderedSeedsParts {
        seed_fn_args,
        seed_array_items,
    })
}
/// Renders the generated-function parameter for one seed, if it needs one.
///
/// * Literal seeds are embedded directly in the seeds array and yield `Ok(None)`.
/// * The program-id seed becomes a `program_id` parameter of its resolved type.
/// * A param seed becomes a parameter named after the seed, with its type first
///   passed through `adapt_seed_function_arg_type_kind`.
///
/// Each type is given `lifetime` via `try_with_lifetime` before rendering.
///
/// # Panics
/// Panics if a program-id or param seed carries no resolved `arg`.
fn render_seed_function_arg(
    seed: &ProcessedSeed,
    lifetime: &str,
) -> ParseResult<Option<TokenStream>> {
    let rendered = match &seed.seed {
        Seed::Literal(_) => None,
        Seed::ProgramId => {
            let program_id_ty = &seed.arg.as_ref().unwrap().ty;
            Some(
                program_id_ty
                    .try_with_lifetime(lifetime)?
                    .render_param("program_id"),
            )
        }
        Seed::Param(name, _, _) => {
            let param_ty = seed.arg.as_ref().unwrap().ty.clone();
            Some(
                adapt_seed_function_arg_type_kind(param_ty)
                    .try_with_lifetime(lifetime)?
                    .render_param(name),
            )
        }
    };
    Ok(rendered)
}
/// Adjusts a param seed's type before it is rendered as a function parameter.
///
/// A `u8` becomes a reference (`ParsedReference::Ref(None)`) to a one-element
/// array (`Composite::Array(1)`) of the owned `u8`. This lets
/// `seed_array_item` place the parameter into the seeds array as-is. Every
/// other type is returned unchanged.
fn adapt_seed_function_arg_type_kind(ty: RustType) -> RustType {
    match ty.kind {
        TypeKind::Primitive(Primitive::U8) => {
            let element = ty.clone().as_owned();
            let single_byte_array =
                TypeKind::Composite(Composite::Array(1), vec![element]);
            RustType {
                kind: single_byte_array,
                reference: ParsedReference::Ref(None),
                ..ty
            }
        }
        TypeKind::Primitive(_)
        | TypeKind::Value(_)
        | TypeKind::Composite(_, _)
        | TypeKind::Unit
        | TypeKind::Unknown => ty,
    }
}
/// Renders the seeds-array expression for a single seed.
///
/// Literal seeds become byte-string literals whose bytes equal the literal's
/// UTF-8 bytes (assuming `lit` holds the decoded literal value). Program-id and
/// param seeds are delegated to `seed_array_item` using the seed's resolved
/// argument type.
///
/// # Panics
/// Panics if a program-id or param seed carries no resolved `arg`.
///
/// # Errors
/// Propagates `seed_array_item` errors for unsupported seed types.
fn render_seed_array_item(seed: &ProcessedSeed) -> ParseResult<TokenStream> {
    match &seed.seed {
        Seed::Literal(lit) => {
            // Escape byte by byte. Quotes, backslashes and non-ASCII text
            // spliced verbatim into `b"..."` would fail to lex or be read as
            // escape sequences. `escape_default` only emits escapes that
            // byte-string literals accept (`\xNN`, `\n`, `\"`, ...).
            let escaped: String = lit
                .bytes()
                .flat_map(std::ascii::escape_default)
                .map(char::from)
                .collect();
            Ok(TokenStream::from_str(&format!("b\"{}\"", escaped))?)
        }
        Seed::ProgramId => {
            seed_array_item("program_id", &seed.arg.as_ref().unwrap().ty)
        }
        Seed::Param(name, _, _) => {
            seed_array_item(name.as_str(), &seed.arg.as_ref().unwrap().ty)
        }
    }
}
/// Renders the seeds-array expression for the argument `name` of type `ty`.
/// The expression must be usable as a `&[u8]` element of the returned array.
///
/// * `u8`: the bare identifier. `adapt_seed_function_arg_type_kind` turns the
///   matching parameter into a reference to a one-element array.
/// * `bool`: a borrowed constant `[1u8]` or `[0u8]`.
/// * `String`, `CString`, `str`: `name.as_bytes()`.
/// * Pubkey / AccountInfo types (short or fully qualified): `name.as_ref()`.
///
/// # Errors
/// Any other primitive, custom, composite, unit or unknown type is rejected.
fn seed_array_item(name: &str, ty: &RustType) -> ParseResult<TokenStream> {
    let ident = Ident::new(name, Span::call_site());
    match &ty.kind {
        TypeKind::Primitive(Primitive::Bool) => {
            // Borrowing a runtime-built `[u8; 1]` would reference a temporary
            // that is dropped when the generated fn returns, so it cannot
            // satisfy the `[&'a [u8]; N]` return type. Constant arrays are
            // promoted to `'static`, so each branch borrows one instead.
            Ok(quote! { if #ident { &[1u8] } else { &[0u8] } })
        }
        TypeKind::Primitive(Primitive::U8) => Ok(quote! { #ident }),
        TypeKind::Primitive(prim) => Err(ParseError::new(
            Span::call_site(),
            format!(
                "Unsupported primitive type: {}, only u8 is supported. Consider using String or str instead.",
                prim
            ),
        )),
        TypeKind::Value(Value::String)
        | TypeKind::Value(Value::CString)
        | TypeKind::Value(Value::Str) => Ok(quote! { #ident.as_bytes() }),
        TypeKind::Value(Value::Custom(x))
            if x == PUBKEY_TY
                || x == FULL_PUBKEY_TY
                || x == ACCOUNT_INFO_TY
                || x == FULL_ACCOUNT_INFO_TY =>
        {
            Ok(quote! { #ident.as_ref() })
        }
        TypeKind::Value(Value::Custom(x)) => Err(ParseError::new(
            ty.ident.span(),
            format!("Custom seed type {} not supported yet", x),
        )),
        TypeKind::Composite(k1, k2) => Err(ParseError::new(
            ty.ident.span(),
            format!(
                "Composite seed types aren't supported yet ({:?}, {:?})",
                k1, k2
            ),
        )),
        TypeKind::Unit => {
            Err(ParseError::new(ident.span(), "Seeds cannot be unit type"))
        }
        TypeKind::Unknown => Err(ParseError::new(
            ident.span(),
            "Seeds cannot be of unknown type",
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use shank_macro_impl::parsed_struct::Seeds;

    /// Processes and renders `seeds`, panicking on any failure since every
    /// fixture below is expected to be valid.
    fn render_seeds_parts(seeds: &[Seed]) -> RenderedSeedsParts {
        let processed_seeds = Seeds(seeds.to_vec())
            .process()
            .expect("should process seeds without error");
        try_render_seeds_parts(&processed_seeds, "a")
            .expect("Should render seeds without error")
    }

    /// Asserts `actual` prints the same as the tokens lexed from `expected`,
    /// so spacing inside `expected` is irrelevant.
    fn assert_tokenstream_eq(actual: &TokenStream, expected: &str) {
        let expected: TokenStream = expected.parse().unwrap();
        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// Renders a lone `seed` and checks that it yields exactly one array item
    /// equal to `expected_item`. It must also yield one function argument equal
    /// to `expected_arg`, or none when `expected_arg` is `None`.
    fn assert_renders_single_seed(
        seed: Seed,
        expected_item: &str,
        expected_arg: Option<&str>,
    ) {
        let RenderedSeedsParts {
            seed_array_items,
            seed_fn_args,
        } = render_seeds_parts(&[seed]);

        assert_eq!(seed_array_items.len(), 1);
        assert_eq!(seed_fn_args.len(), usize::from(expected_arg.is_some()));
        assert_tokenstream_eq(&seed_array_items[0], expected_item);
        if let Some(arg) = expected_arg {
            assert_tokenstream_eq(&seed_fn_args[0], arg);
        }
    }

    #[test]
    fn render_seed_literal() {
        assert_renders_single_seed(
            Seed::Literal("uno".to_string()),
            "b\"uno\"",
            None,
        );
    }

    #[test]
    fn process_seed_program_id() {
        assert_renders_single_seed(
            Seed::ProgramId,
            "program_id.as_ref()",
            Some("program_id : &'a ::solana_program::pubkey::Pubkey"),
        );
    }

    #[test]
    fn process_seed_custom_pubkey() {
        assert_renders_single_seed(
            Seed::Param("owner".to_string(), "The owner".to_string(), None),
            "owner.as_ref()",
            Some("owner : &'a ::solana_program::pubkey::Pubkey"),
        );
    }

    #[test]
    fn process_seed_explicit_custom_pubkey() {
        assert_renders_single_seed(
            Seed::Param(
                "owner".to_string(),
                "The owner".to_string(),
                Some("Pubkey".to_string()),
            ),
            "owner.as_ref()",
            Some("owner : &'a ::solana_program::pubkey::Pubkey"),
        );
    }
}