1use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Attribute, ItemStruct, Result, Visibility, parse_quote};
6
7use crate::format::DocRec;
8
9#[derive(Debug, Clone)]
10pub struct CapConfig {
11 pub input: ItemStruct,
12}
13
14impl CapConfig {
15 pub fn new(mut input: ItemStruct, doc_rec: DocRec) -> Result<Self> {
16 if !matches!(input.vis, Visibility::Public(_)) {
18 return Err(syn::Error::new_spanned(
19 &input.vis,
20 "capability_config structs must be public",
21 ));
22 }
23
24 Self::validate_docs(&input, doc_rec)?;
26
27 let serde_crate: Attribute = parse_quote!(
33 #[serde(crate = "::pyroduct::format::serde")]
34 );
35
36 let serde_derive: Attribute = parse_quote!(
38 #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
39 );
40
41 input.attrs.insert(0, serde_crate);
43 input.attrs.insert(0, serde_derive);
44 Ok(Self { input })
47 }
48
49 fn validate_docs(input: &ItemStruct, doc_rec: DocRec) -> Result<()> {
50 let has_struct_doc = input.attrs.iter().any(|a| a.path().is_ident("doc"));
51
52 match doc_rec {
53 DocRec::StructDoc | DocRec::AllDoc if !has_struct_doc => {
54 return Err(syn::Error::new_spanned(
55 &input.ident,
56 "Configuration struct must be documented",
57 ));
58 }
59 _ => {}
60 }
61
62 if doc_rec == DocRec::AllDoc
63 && let syn::Fields::Named(fields) = &input.fields
64 {
65 for field in &fields.named {
66 let has_field_doc = field.attrs.iter().any(|a| a.path().is_ident("doc"));
67 if !has_field_doc {
68 let tokens = if let Some(ident) = &field.ident {
70 quote! { #ident }
71 } else {
72 quote! { #field }
73 };
74
75 return Err(syn::Error::new_spanned(
76 tokens,
77 "Configuration fields must be documented",
78 ));
79 }
80 }
81 }
82 Ok(())
83 }
84
85 pub fn expand(&self) -> TokenStream {
87 let input = &self.input;
88 quote! { #input }
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse2;

    /// Parses `code` as a struct, validates it, and returns the expansion.
    fn expand_config(code: TokenStream, doc_rec: DocRec) -> TokenStream {
        let item = parse2(code).expect("Failed to parse struct input");
        CapConfig::new(item, doc_rec)
            .expect("CapConfig validation failed")
            .expand()
    }

    /// Parses `code` as a struct and returns the validation error's message.
    /// Panics if validation unexpectedly succeeds.
    fn validation_error(code: TokenStream, doc_rec: DocRec) -> String {
        let item: ItemStruct = parse2(code).expect("Failed to parse struct input");
        CapConfig::new(item, doc_rec)
            .expect_err("expected CapConfig validation to fail")
            .to_string()
    }

    #[test]
    fn test_config_basic() {
        let code = quote! {
            pub struct MyConfig {
                pub host: String,
                pub port: u16,
            }
        };

        let output = expand_config(code, DocRec::NoReq);

        let expected = quote! {
            #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
            #[serde(crate = "::pyroduct::format::serde")]
            pub struct MyConfig {
                pub host: String,
                pub port: u16,
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_doc_rec_struct_missing() {
        let code = quote! {
            pub struct Undocumented {
                pub x: i32,
            }
        };

        assert_eq!(
            validation_error(code, DocRec::StructDoc),
            "Configuration struct must be documented"
        );
    }

    #[test]
    fn test_doc_rec_field_missing() {
        // Struct-level doc present, but `y` is undocumented.
        let code = quote! {
            /// A partially documented configuration.
            pub struct PartiallyDocumented {
                /// A documented field.
                pub x: i32,
                pub y: i32,
            }
        };
        let item: ItemStruct = parse2(code.clone()).unwrap();

        // `StructDoc` only looks at the struct itself.
        assert!(CapConfig::new(item, DocRec::StructDoc).is_ok());

        assert_eq!(
            validation_error(code, DocRec::AllDoc),
            "Configuration fields must be documented"
        );
    }

    #[test]
    fn test_doc_rec_all_doc_reports_struct_first() {
        // With both the struct and a field undocumented, the struct-level
        // error is the one surfaced first.
        let code = quote! {
            pub struct NothingDocumented {
                pub x: i32,
            }
        };

        assert_eq!(
            validation_error(code, DocRec::AllDoc),
            "Configuration struct must be documented"
        );
    }

    #[test]
    fn test_doc_rec_full_success() {
        let code = quote! {
            /// Server configuration.
            pub struct ServerConfig {
                /// Host name to bind to.
                pub host: String,
                /// Port to listen on.
                pub port: u16,
            }
        };
        let item: ItemStruct = parse2(code).unwrap();

        // A fully documented struct satisfies every policy.
        assert!(CapConfig::new(item.clone(), DocRec::NoReq).is_ok());
        assert!(CapConfig::new(item.clone(), DocRec::StructDoc).is_ok());
        assert!(CapConfig::new(item, DocRec::AllDoc).is_ok());
    }

    #[test]
    fn test_config_with_generics_allowed() {
        let code = quote! {
            #[derive(Clone, Debug)]
            pub struct GenericConfig<T> {
                pub options: T,
            }
        };

        let output = expand_config(code, DocRec::NoReq);

        // The serde attributes are prepended ahead of the user's own derive.
        let expected = quote! {
            #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
            #[serde(crate = "::pyroduct::format::serde")]
            #[derive(Clone, Debug)]
            pub struct GenericConfig<T> {
                pub options: T,
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_config_tuple_struct() {
        let code = quote! {
            pub struct TupleConfig(String, u32);
        };

        let output = expand_config(code, DocRec::NoReq);

        let expected = quote! {
            #[derive(::pyroduct::format::serde::Serialize, ::pyroduct::format::serde::Deserialize)]
            #[serde(crate = "::pyroduct::format::serde")]
            pub struct TupleConfig(String, u32);
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_validation_still_requires_pub() {
        let private = quote! {
            struct PrivateConfig { timeout: u64 }
        };
        assert_eq!(
            validation_error(private, DocRec::NoReq),
            "capability_config structs must be public"
        );

        // Restricted visibility is not plain `pub` and is rejected too.
        let crate_visible = quote! {
            pub(crate) struct CrateConfig { timeout: u64 }
        };
        assert_eq!(
            validation_error(crate_visible, DocRec::NoReq),
            "capability_config structs must be public"
        );
    }
}