1use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Attribute, ItemStruct, Result, Visibility, parse_quote};
6
7use crate::format::DocRec;
8
9#[derive(Debug, Clone)]
10pub struct CapConfig {
11 pub input: ItemStruct,
12}
13
14impl CapConfig {
15 pub fn new(mut input: ItemStruct, doc_rec: DocRec) -> Result<Self> {
16 if !matches!(input.vis, Visibility::Public(_)) {
18 return Err(syn::Error::new_spanned(
19 &input.vis,
20 "capability_config structs must be public",
21 ));
22 }
23
24 Self::validate_docs(&input, doc_rec)?;
26
27 let serde_crate: Attribute = parse_quote!(
33 #[serde(crate = "::pyroduct::format::serde")]
34 );
35
36 let serde_derive: Attribute = parse_quote!(
38 #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
39 );
40
41 input.attrs.insert(0, serde_crate);
43 input.attrs.insert(0, serde_derive);
44 Ok(Self { input })
47 }
48
49 fn validate_docs(input: &ItemStruct, doc_rec: DocRec) -> Result<()> {
50 let has_struct_doc = input.attrs.iter().any(|a| a.path().is_ident("doc"));
51
52 match doc_rec {
53 DocRec::StructDoc | DocRec::AllDoc => {
54 if !has_struct_doc {
55 return Err(syn::Error::new_spanned(
56 &input.ident,
57 "Configuration struct must be documented",
58 ));
59 }
60 }
61 _ => {}
62 }
63
64 if doc_rec == DocRec::AllDoc {
65 if let syn::Fields::Named(fields) = &input.fields {
66 for field in &fields.named {
67 let has_field_doc = field.attrs.iter().any(|a| a.path().is_ident("doc"));
68 if !has_field_doc {
69 let tokens = if let Some(ident) = &field.ident {
71 quote! { #ident }
72 } else {
73 quote! { #field }
74 };
75
76 return Err(syn::Error::new_spanned(
77 tokens,
78 "Configuration fields must be documented",
79 ));
80 }
81 }
82 }
83 }
84 Ok(())
85 }
86
87 pub fn expand(&self) -> TokenStream {
89 let input = &self.input;
90 quote! { #input }
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse2;

    /// Parses `code` as a struct, validates it with [`CapConfig::new`] and expands it.
    ///
    /// Panics if parsing or validation fails, so use it only for inputs that are
    /// expected to succeed.
    fn expand_config(code: TokenStream, doc_rec: DocRec) -> TokenStream {
        let item = parse2(code).expect("Failed to parse struct input");
        CapConfig::new(item, doc_rec)
            .expect("CapConfig validation failed")
            .expand()
    }

    #[test]
    fn test_config_basic() {
        let code = quote! {
            pub struct MyConfig {
                pub host: String,
                pub port: u16,
            }
        };

        let output = expand_config(code, DocRec::NoReq);

        let expected = quote! {
            #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
            #[serde(crate = "::pyroduct::format::serde")]
            pub struct MyConfig {
                pub host: String,
                pub port: u16,
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_doc_rec_struct_missing() {
        let code = quote! {
            pub struct Undocumented {
                pub x: i32,
            }
        };
        let item = parse2(code).unwrap();

        let err = CapConfig::new(item, DocRec::StructDoc).unwrap_err();
        assert_eq!(err.to_string(), "Configuration struct must be documented");
    }

    #[test]
    fn test_doc_rec_field_missing() {
        let code = quote! {
            /// A configuration whose second field lacks documentation.
            pub struct PartiallyDocumented {
                /// Documented field.
                pub x: i32,
                pub y: i32,
            }
        };
        let item: ItemStruct = parse2(code).unwrap();

        // Only the struct-level doc is required here, so the undocumented field is accepted.
        assert!(CapConfig::new(item.clone(), DocRec::StructDoc).is_ok());

        let err = CapConfig::new(item, DocRec::AllDoc).unwrap_err();
        assert_eq!(err.to_string(), "Configuration fields must be documented");
    }

    #[test]
    fn test_doc_rec_all_reports_struct_first() {
        // Under `AllDoc`, a missing struct doc is reported before any field checks.
        let code = quote! {
            pub struct Bare {
                /// Documented field.
                pub x: i32,
            }
        };
        let item = parse2(code).unwrap();

        let err = CapConfig::new(item, DocRec::AllDoc).unwrap_err();
        assert_eq!(err.to_string(), "Configuration struct must be documented");
    }

    #[test]
    fn test_doc_rec_full_success() {
        let code = quote! {
            /// Server connection settings.
            pub struct ServerConfig {
                /// Host name to bind to.
                pub host: String,
                /// Port to listen on.
                pub port: u16,
            }
        };
        let item = parse2(code).unwrap();
        assert!(CapConfig::new(item, DocRec::AllDoc).is_ok());
    }

    #[test]
    fn test_config_with_generics_allowed() {
        let code = quote! {
            #[derive(Clone, Debug)]
            pub struct GenericConfig<T> {
                pub options: T,
            }
        };

        let output = expand_config(code, DocRec::NoReq);

        // The serde attributes are prepended ahead of the user's own derive.
        let expected = quote! {
            #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
            #[serde(crate = "::pyroduct::format::serde")]
            #[derive(Clone, Debug)]
            pub struct GenericConfig<T> {
                pub options: T,
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_config_tuple_struct() {
        let code = quote! {
            pub struct TupleConfig(String, u32);
        };

        let output = expand_config(code, DocRec::NoReq);

        let expected = quote! {
            #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
            #[serde(crate = "::pyroduct::format::serde")]
            pub struct TupleConfig(String, u32);
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_validation_still_requires_pub() {
        let code_vis = quote! {
            struct PrivateConfig { timeout: u64 }
        };
        let item_vis = parse2(code_vis).unwrap();
        let res_vis = CapConfig::new(item_vis, DocRec::NoReq);
        assert!(res_vis.is_err());
        assert_eq!(
            res_vis.unwrap_err().to_string(),
            "capability_config structs must be public"
        );
    }
}