1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};
9
10#[proc_macro_derive(McpConfig, attributes(mcp))]
41pub fn derive_mcp_config(input: TokenStream) -> TokenStream {
42 let input = parse_macro_input!(input as DeriveInput);
43
44 match generate_mcp_config_impl(&input) {
45 Ok(tokens) => tokens.into(),
46 Err(err) => err.to_compile_error().into(),
47 }
48}
49
50#[proc_macro_derive(McpBackend, attributes(mcp_backend))]
81pub fn derive_mcp_backend(input: TokenStream) -> TokenStream {
82 let input = parse_macro_input!(input as DeriveInput);
83
84 match generate_mcp_backend_impl(&input) {
85 Ok(tokens) => tokens.into(),
86 Err(err) => err.to_compile_error().into(),
87 }
88}
89
90fn generate_mcp_config_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
91 let name = &input.ident;
92
93 let fields = match &input.data {
95 Data::Struct(data) => match &data.fields {
96 Fields::Named(fields) => &fields.named,
97 _ => {
98 return Err(syn::Error::new_spanned(
99 input,
100 "McpConfig can only be derived for structs with named fields",
101 ))
102 }
103 },
104 _ => {
105 return Err(syn::Error::new_spanned(
106 input,
107 "McpConfig can only be derived for structs",
108 ))
109 }
110 };
111
112 let mut server_info_field = None;
114 let mut logging_field = None;
115 let mut auto_populate_fields = Vec::new();
116
117 for field in fields {
118 if let Some(ident) = &field.ident {
119 for attr in &field.attrs {
121 if attr.path().is_ident("mcp") {
122 parse_mcp_attribute(
123 attr,
124 ident,
125 &mut server_info_field,
126 &mut logging_field,
127 &mut auto_populate_fields,
128 )?;
129 }
130 }
131
132 match ident.to_string().as_str() {
134 "server_info" => server_info_field = Some(ident.clone()),
135 "logging" => logging_field = Some(ident.clone()),
136 _ => {}
137 }
138 }
139 }
140
141 let server_info_impl = generate_server_info_impl(&server_info_field);
143 let logging_impl = generate_logging_impl(&logging_field);
144 let auto_populate_impl = generate_auto_populate_impl(&auto_populate_fields);
145
146 Ok(quote! {
147 impl pulseengine_mcp_cli::McpConfiguration for #name {
148 #server_info_impl
149 #logging_impl
150
151 fn validate(&self) -> std::result::Result<(), pulseengine_mcp_cli::CliError> {
152 Ok(())
154 }
155 }
156
157 impl #name {
158 pub fn with_auto_populate() -> Self
160 where
161 Self: Default,
162 {
163 let mut instance = Self::default();
164 instance.auto_populate();
165 instance
166 }
167
168 pub fn auto_populate(&mut self) {
170 #auto_populate_impl
171 }
172 }
173 })
174}
175
176fn generate_mcp_backend_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
177 let name = &input.ident;
178
179 let backend_config = parse_backend_attributes(input)?;
181
182 let fields = match &input.data {
184 Data::Struct(data) => match &data.fields {
185 Fields::Named(fields) => &fields.named,
186 _ => {
187 return Err(syn::Error::new_spanned(
188 input,
189 "McpBackend can only be derived for structs with named fields",
190 ))
191 }
192 },
193 _ => {
194 return Err(syn::Error::new_spanned(
195 input,
196 "McpBackend can only be derived for structs",
197 ))
198 }
199 };
200
201 let mut delegate_field = None;
203 let mut error_from_fields = Vec::new();
204
205 for field in fields {
206 if let Some(ident) = &field.ident {
207 for attr in &field.attrs {
208 if attr.path().is_ident("mcp_backend") {
209 parse_backend_field_attribute(attr, ident, &mut delegate_field, &mut error_from_fields)?;
210 }
211 }
212 }
213 }
214
215 let error_type = backend_config.error_type.as_ref()
217 .map(|s| syn::parse_str::<syn::Type>(s))
218 .transpose()?
219 .unwrap_or_else(|| syn::parse_str(&format!("{name}Error")).unwrap());
220
221 let config_type = backend_config.config_type.as_ref()
222 .map(|s| syn::parse_str::<syn::Type>(s))
223 .transpose()?
224 .unwrap_or_else(|| syn::parse_str(&format!("{name}Config")).unwrap());
225
226 let error_definition = if backend_config.error_type.is_none() {
228 generate_error_type_definition(name, &error_from_fields)
229 } else {
230 quote! {}
231 };
232
233 let trait_impl = if backend_config.simple_backend {
235 generate_simple_backend_impl(name, &error_type, &config_type, &delegate_field)
236 } else {
237 generate_full_backend_impl(name, &error_type, &config_type, &delegate_field)
238 };
239
240 Ok(quote! {
241 #error_definition
242 #trait_impl
243 })
244}
245
/// Container-level options parsed from `#[mcp_backend(...)]` on the deriving struct.
///
/// The `Default` value means "no options given": the full `McpBackend` trait is
/// implemented, an error enum is generated, and `<Name>Config` is the config type.
#[derive(Default)]
struct BackendConfig {
    /// Set by the bare `simple` flag; selects `SimpleBackend` instead of `McpBackend`.
    simple_backend: bool,
    /// Text of `error = "..."`, later parsed as a `syn::Type`. `None` means an error
    /// enum is generated for the backend.
    error_type: Option<String>,
    /// Text of `config = "..."`, later parsed as a `syn::Type`.
    config_type: Option<String>,
}
252
253fn parse_backend_attributes(input: &DeriveInput) -> syn::Result<BackendConfig> {
254 let mut config = BackendConfig::default();
255
256 for attr in &input.attrs {
257 if attr.path().is_ident("mcp_backend") {
258 attr.parse_nested_meta(|meta| {
259 if meta.path.is_ident("simple") {
260 config.simple_backend = true;
261 Ok(())
262 } else if meta.path.is_ident("error") {
263 if let Ok(value) = meta.value() {
264 if let Ok(lit) = value.parse::<syn::LitStr>() {
265 config.error_type = Some(lit.value());
266 }
267 }
268 Ok(())
269 } else if meta.path.is_ident("config") {
270 if let Ok(value) = meta.value() {
271 if let Ok(lit) = value.parse::<syn::LitStr>() {
272 config.config_type = Some(lit.value());
273 }
274 }
275 Ok(())
276 } else {
277 Err(meta.error(format!("unsupported mcp_backend attribute: {}", meta.path.get_ident().unwrap())))
278 }
279 })?;
280 }
281 }
282
283 Ok(config)
284}
285
286fn parse_backend_field_attribute(
287 attr: &Attribute,
288 field_ident: &syn::Ident,
289 delegate_field: &mut Option<syn::Ident>,
290 error_from_fields: &mut Vec<syn::Ident>,
291) -> syn::Result<()> {
292 attr.parse_nested_meta(|meta| {
293 if meta.path.is_ident("delegate") {
294 *delegate_field = Some(field_ident.clone());
295 Ok(())
296 } else if meta.path.is_ident("error_from") {
297 error_from_fields.push(field_ident.clone());
298 Ok(())
299 } else {
300 Err(meta.error(format!("unsupported mcp_backend field attribute: {}", meta.path.get_ident().unwrap())))
301 }
302 })
303}
304
305fn generate_error_type_definition(
306 name: &syn::Ident,
307 error_from_fields: &[syn::Ident],
308) -> proc_macro2::TokenStream {
309 let error_name = syn::Ident::new(&format!("{name}Error"), name.span());
310
311 let from_implementations = error_from_fields.iter().map(|_field| {
312 quote! {
314 impl From<std::io::Error> for #error_name {
315 fn from(err: std::io::Error) -> Self {
316 Self::Internal(err.to_string())
317 }
318 }
319 }
320 });
321
322 quote! {
323 #[derive(Debug, thiserror::Error)]
324 pub enum #error_name {
325 #[error("Configuration error: {0}")]
326 Configuration(String),
327
328 #[error("Connection error: {0}")]
329 Connection(String),
330
331 #[error("Operation not supported: {0}")]
332 NotSupported(String),
333
334 #[error("Internal error: {0}")]
335 Internal(String),
336 }
337
338 impl #error_name {
339 pub fn configuration(msg: impl Into<String>) -> Self {
340 Self::Configuration(msg.into())
341 }
342
343 pub fn connection(msg: impl Into<String>) -> Self {
344 Self::Connection(msg.into())
345 }
346
347 pub fn not_supported(msg: impl Into<String>) -> Self {
348 Self::NotSupported(msg.into())
349 }
350
351 pub fn internal(msg: impl Into<String>) -> Self {
352 Self::Internal(msg.into())
353 }
354 }
355
356 impl From<pulseengine_mcp_server::backend::BackendError> for #error_name {
357 fn from(err: pulseengine_mcp_server::backend::BackendError) -> Self {
358 Self::Internal(err.to_string())
359 }
360 }
361
362 impl From<#error_name> for pulseengine_mcp_protocol::Error {
363 fn from(err: #error_name) -> Self {
364 match err {
365 #error_name::Configuration(msg) => Self::invalid_params(msg),
366 #error_name::Connection(msg) => Self::internal_error(format!("Connection failed: {msg}")),
367 #error_name::NotSupported(msg) => Self::method_not_found(msg),
368 #error_name::Internal(msg) => Self::internal_error(msg),
369 }
370 }
371 }
372
373 #(#from_implementations)*
374 }
375}
376
377fn generate_simple_backend_impl(
378 name: &syn::Ident,
379 error_type: &syn::Type,
380 config_type: &syn::Type,
381 delegate_field: &Option<syn::Ident>,
382) -> proc_macro2::TokenStream {
383 if let Some(delegate) = delegate_field {
384 quote! {
385 #[async_trait::async_trait]
386 impl pulseengine_mcp_server::backend::SimpleBackend for #name {
387 type Error = #error_type;
388 type Config = #config_type;
389
390 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
391 Err(Self::Error::not_supported("Backend initialization not implemented"))
393 }
394
395 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
396 self.#delegate.get_server_info()
397 }
398
399 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
400 self.#delegate.health_check().await.map_err(Into::into)
401 }
402
403 async fn list_tools(
404 &self,
405 request: pulseengine_mcp_protocol::PaginatedRequestParam,
406 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
407 self.#delegate.list_tools(request).await.map_err(Into::into)
408 }
409
410 async fn call_tool(
411 &self,
412 request: pulseengine_mcp_protocol::CallToolRequestParam,
413 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
414 self.#delegate.call_tool(request).await.map_err(Into::into)
415 }
416 }
417 }
418 } else {
419 quote! {
420 #[async_trait::async_trait]
421 impl pulseengine_mcp_server::backend::SimpleBackend for #name {
422 type Error = #error_type;
423 type Config = #config_type;
424
425 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
426 Err(Self::Error::not_supported("Backend initialization not implemented"))
428 }
429
430 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
431 pulseengine_mcp_protocol::ServerInfo {
433 protocol_version: pulseengine_mcp_protocol::ProtocolVersion::default(),
434 capabilities: pulseengine_mcp_protocol::ServerCapabilities::default(),
435 server_info: pulseengine_mcp_protocol::Implementation {
436 name: env!("CARGO_PKG_NAME").to_string(),
437 version: env!("CARGO_PKG_VERSION").to_string(),
438 },
439 instructions: None,
440 }
441 }
442
443 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
444 Ok(())
446 }
447
448 async fn list_tools(
449 &self,
450 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
451 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
452 Ok(pulseengine_mcp_protocol::ListToolsResult {
454 tools: vec![],
455 next_cursor: None,
456 })
457 }
458
459 async fn call_tool(
460 &self,
461 request: pulseengine_mcp_protocol::CallToolRequestParam,
462 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
463 Err(Self::Error::not_supported(format!("Tool not found: {}", request.name)))
465 }
466 }
467 }
468 }
469}
470
471fn generate_full_backend_impl(
472 name: &syn::Ident,
473 error_type: &syn::Type,
474 config_type: &syn::Type,
475 delegate_field: &Option<syn::Ident>,
476) -> proc_macro2::TokenStream {
477 if let Some(delegate) = delegate_field {
478 quote! {
479 #[async_trait::async_trait]
480 impl pulseengine_mcp_server::backend::McpBackend for #name {
481 type Error = #error_type;
482 type Config = #config_type;
483
484 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
485 Err(Self::Error::not_supported("Backend initialization not implemented"))
487 }
488
489 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
490 self.#delegate.get_server_info()
491 }
492
493 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
494 self.#delegate.health_check().await.map_err(Into::into)
495 }
496
497 async fn list_tools(
499 &self,
500 request: pulseengine_mcp_protocol::PaginatedRequestParam,
501 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
502 self.#delegate.list_tools(request).await.map_err(Into::into)
503 }
504
505 async fn call_tool(
506 &self,
507 request: pulseengine_mcp_protocol::CallToolRequestParam,
508 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
509 self.#delegate.call_tool(request).await.map_err(Into::into)
510 }
511
512 async fn list_resources(
513 &self,
514 request: pulseengine_mcp_protocol::PaginatedRequestParam,
515 ) -> std::result::Result<pulseengine_mcp_protocol::ListResourcesResult, Self::Error> {
516 self.#delegate.list_resources(request).await.map_err(Into::into)
517 }
518
519 async fn read_resource(
520 &self,
521 request: pulseengine_mcp_protocol::ReadResourceRequestParam,
522 ) -> std::result::Result<pulseengine_mcp_protocol::ReadResourceResult, Self::Error> {
523 self.#delegate.read_resource(request).await.map_err(Into::into)
524 }
525
526 async fn list_prompts(
527 &self,
528 request: pulseengine_mcp_protocol::PaginatedRequestParam,
529 ) -> std::result::Result<pulseengine_mcp_protocol::ListPromptsResult, Self::Error> {
530 self.#delegate.list_prompts(request).await.map_err(Into::into)
531 }
532
533 async fn get_prompt(
534 &self,
535 request: pulseengine_mcp_protocol::GetPromptRequestParam,
536 ) -> std::result::Result<pulseengine_mcp_protocol::GetPromptResult, Self::Error> {
537 self.#delegate.get_prompt(request).await.map_err(Into::into)
538 }
539 }
540 }
541 } else {
542 quote! {
543 #[async_trait::async_trait]
544 impl pulseengine_mcp_server::backend::McpBackend for #name {
545 type Error = #error_type;
546 type Config = #config_type;
547
548 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
549 Err(Self::Error::not_supported("Backend initialization not implemented"))
551 }
552
553 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
554 pulseengine_mcp_protocol::ServerInfo {
556 protocol_version: pulseengine_mcp_protocol::ProtocolVersion::default(),
557 capabilities: pulseengine_mcp_protocol::ServerCapabilities::default(),
558 server_info: pulseengine_mcp_protocol::Implementation {
559 name: env!("CARGO_PKG_NAME").to_string(),
560 version: env!("CARGO_PKG_VERSION").to_string(),
561 },
562 instructions: None,
563 }
564 }
565
566 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
567 Ok(())
569 }
570
571 async fn list_tools(
573 &self,
574 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
575 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
576 Ok(pulseengine_mcp_protocol::ListToolsResult {
577 tools: vec![],
578 next_cursor: None,
579 })
580 }
581
582 async fn call_tool(
583 &self,
584 request: pulseengine_mcp_protocol::CallToolRequestParam,
585 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
586 Err(Self::Error::not_supported(format!("Tool not found: {}", request.name)))
587 }
588
589 async fn list_resources(
590 &self,
591 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
592 ) -> std::result::Result<pulseengine_mcp_protocol::ListResourcesResult, Self::Error> {
593 Ok(pulseengine_mcp_protocol::ListResourcesResult {
594 resources: vec![],
595 next_cursor: None,
596 })
597 }
598
599 async fn read_resource(
600 &self,
601 request: pulseengine_mcp_protocol::ReadResourceRequestParam,
602 ) -> std::result::Result<pulseengine_mcp_protocol::ReadResourceResult, Self::Error> {
603 Err(Self::Error::not_supported(format!("Resource not found: {}", request.uri)))
604 }
605
606 async fn list_prompts(
607 &self,
608 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
609 ) -> std::result::Result<pulseengine_mcp_protocol::ListPromptsResult, Self::Error> {
610 Ok(pulseengine_mcp_protocol::ListPromptsResult {
611 prompts: vec![],
612 next_cursor: None,
613 })
614 }
615
616 async fn get_prompt(
617 &self,
618 request: pulseengine_mcp_protocol::GetPromptRequestParam,
619 ) -> std::result::Result<pulseengine_mcp_protocol::GetPromptResult, Self::Error> {
620 Err(Self::Error::not_supported(format!("Prompt not found: {}", request.name)))
621 }
622 }
623 }
624 }
625}
626
627fn parse_mcp_attribute(
628 attr: &Attribute,
629 field_ident: &syn::Ident,
630 _server_info_field: &mut Option<syn::Ident>,
631 logging_field: &mut Option<syn::Ident>,
632 auto_populate_fields: &mut Vec<syn::Ident>,
633) -> syn::Result<()> {
634 attr.parse_nested_meta(|meta| {
635 if meta.path.is_ident("auto_populate") {
636 auto_populate_fields.push(field_ident.clone());
637 Ok(())
638 } else if meta.path.is_ident("logging") {
639 *logging_field = Some(field_ident.clone());
640 if meta.input.peek(syn::token::Paren) {
642 let _content;
643 syn::parenthesized!(_content in meta.input);
644 }
647 Ok(())
648 } else {
649 Err(meta.error(format!(
650 "unsupported mcp attribute: {}",
651 meta.path.get_ident().unwrap()
652 )))
653 }
654 })
655}
656
657fn generate_server_info_impl(server_info_field: &Option<syn::Ident>) -> proc_macro2::TokenStream {
658 if let Some(field) = server_info_field {
659 quote! {
660 fn get_server_info(&self) -> &pulseengine_mcp_protocol::ServerInfo {
661 self.#field.as_ref().unwrap_or_else(|| {
662 static SERVER_INFO: std::sync::OnceLock<pulseengine_mcp_protocol::ServerInfo> = std::sync::OnceLock::new();
663 SERVER_INFO.get_or_init(|| {
664 pulseengine_mcp_cli::config::create_server_info(None, None)
665 })
666 })
667 }
668 }
669 } else {
670 quote! {
671 fn get_server_info(&self) -> &pulseengine_mcp_protocol::ServerInfo {
672 use std::sync::OnceLock;
673 static SERVER_INFO: OnceLock<pulseengine_mcp_protocol::ServerInfo> = OnceLock::new();
674 SERVER_INFO.get_or_init(|| {
675 pulseengine_mcp_cli::config::create_server_info(None, None)
676 })
677 }
678 }
679 }
680}
681
682fn generate_logging_impl(logging_field: &Option<syn::Ident>) -> proc_macro2::TokenStream {
683 if let Some(field) = logging_field {
684 quote! {
685 fn get_logging_config(&self) -> &pulseengine_mcp_cli::DefaultLoggingConfig {
686 self.#field.as_ref().unwrap_or_else(|| {
687 static LOGGING_CONFIG: std::sync::OnceLock<pulseengine_mcp_cli::DefaultLoggingConfig> = std::sync::OnceLock::new();
688 LOGGING_CONFIG.get_or_init(|| {
689 pulseengine_mcp_cli::DefaultLoggingConfig::default()
690 })
691 })
692 }
693
694 fn initialize_logging(&self) -> std::result::Result<(), pulseengine_mcp_cli::CliError> {
695 if let Some(config) = &self.#field {
697 config.initialize()
698 } else {
699 pulseengine_mcp_cli::DefaultLoggingConfig::default().initialize()
700 }
701 }
702 }
703 } else {
704 quote! {
705 fn get_logging_config(&self) -> &pulseengine_mcp_cli::DefaultLoggingConfig {
706 use std::sync::OnceLock;
707 static LOGGING_CONFIG: OnceLock<pulseengine_mcp_cli::DefaultLoggingConfig> = OnceLock::new();
708 LOGGING_CONFIG.get_or_init(|| {
709 pulseengine_mcp_cli::DefaultLoggingConfig::default()
710 })
711 }
712
713 fn initialize_logging(&self) -> std::result::Result<(), pulseengine_mcp_cli::CliError> {
714 use pulseengine_mcp_cli::config::DefaultLoggingConfig;
715 let default_config = DefaultLoggingConfig::default();
716 default_config.initialize()
717 }
718 }
719 }
720}
721
722fn generate_auto_populate_impl(auto_populate_fields: &[syn::Ident]) -> proc_macro2::TokenStream {
723 if auto_populate_fields.is_empty() {
724 return quote! {};
725 }
726
727 let implementations = auto_populate_fields.iter().map(|field| {
728 match field.to_string().as_str() {
729 "server_info" => quote! {
730 self.#field = Some(pulseengine_mcp_cli::config::create_server_info(None, None));
731 },
732 "logging" => quote! {
733 use std::env;
735 if let Ok(level) = env::var("MCP_LOG_LEVEL") {
736 }
739 },
740 _ => quote! {
741 let env_var = format!("MCP_{}", stringify!(#field).to_uppercase());
744 if let Ok(value) = std::env::var(&env_var) {
745 tracing::debug!("Found environment variable {}: {}", env_var, value);
747 }
748 },
749 }
750 });
751
752 quote! {
753 #(#implementations)*
754 }
755}
756
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses derive-macro input tokens, panicking if a test fixture is malformed.
    fn parse(tokens: proc_macro2::TokenStream) -> DeriveInput {
        syn::parse2(tokens).expect("test fixture should parse as DeriveInput")
    }

    #[test]
    fn test_basic_mcp_config_derive() {
        let input = parse(quote::quote! {
            struct TestConfig {
                port: u16,
                server_info: ServerInfo,
            }
        });

        assert!(generate_mcp_config_impl(&input).is_ok());
    }

    #[test]
    fn test_mcp_config_with_attributes() {
        let input = parse(quote::quote! {
            struct TestConfig {
                #[mcp(auto_populate)]
                server_info: ServerInfo,
                #[mcp(logging)]
                logging: LoggingConfig,
            }
        });

        assert!(generate_mcp_config_impl(&input).is_ok());
    }

    #[test]
    fn test_mcp_config_rejects_tuple_struct() {
        let input = parse(quote::quote! {
            struct TupleConfig(u16);
        });

        let err = generate_mcp_config_impl(&input).unwrap_err();
        assert!(err.to_string().contains("structs with named fields"));
    }

    #[test]
    fn test_unsupported_mcp_attribute_is_rejected() {
        let input = parse(quote::quote! {
            struct TestConfig {
                #[mcp(bogus)]
                port: u16,
            }
        });

        let err = generate_mcp_config_impl(&input).unwrap_err();
        assert!(err.to_string().contains("unsupported mcp attribute: bogus"));
    }

    #[test]
    fn test_basic_mcp_backend_derive() {
        let input = parse(quote::quote! {
            struct TestBackend {
                config: BackendConfig,
            }
        });

        let tokens = generate_mcp_backend_impl(&input).unwrap().to_string();
        // Without `error = "..."` a dedicated error enum is generated.
        assert!(tokens.contains("enum TestBackendError"));
        assert!(tokens.contains("McpBackend for TestBackend"));
    }

    #[test]
    fn test_mcp_backend_with_simple() {
        let input = parse(quote::quote! {
            #[mcp_backend(simple)]
            struct SimpleTestBackend {
                config: BackendConfig,
            }
        });

        let tokens = generate_mcp_backend_impl(&input).unwrap().to_string();
        assert!(tokens.contains("SimpleBackend for SimpleTestBackend"));
    }

    #[test]
    fn test_mcp_backend_with_custom_error() {
        let input = parse(quote::quote! {
            #[mcp_backend(error = "CustomError")]
            struct CustomErrorBackend {
                config: BackendConfig,
            }
        });

        let tokens = generate_mcp_backend_impl(&input).unwrap().to_string();
        // A user-supplied error type replaces the generated enum entirely.
        assert!(!tokens.contains("enum CustomErrorBackendError"));
        assert!(tokens.contains("type Error = CustomError"));
    }

    #[test]
    fn test_unsupported_mcp_backend_attribute_is_rejected() {
        let input = parse(quote::quote! {
            #[mcp_backend(bogus)]
            struct TestBackend {
                config: BackendConfig,
            }
        });

        let err = generate_mcp_backend_impl(&input).unwrap_err();
        assert!(err.to_string().contains("unsupported mcp_backend attribute: bogus"));
    }

    #[test]
    fn test_mcp_backend_with_delegate() {
        let input = parse(quote::quote! {
            struct DelegateBackend {
                #[mcp_backend(delegate)]
                inner: InnerBackend,
                config: BackendConfig,
            }
        });

        assert!(generate_mcp_backend_impl(&input).is_ok());
    }

    #[test]
    fn test_invalid_mcp_config() {
        let input = parse(quote::quote! {
            enum TestEnum {
                A, B, C
            }
        });

        let err = generate_mcp_config_impl(&input).unwrap_err();
        assert!(err.to_string().contains("can only be derived for structs"));
    }

    #[test]
    fn test_invalid_mcp_backend() {
        let input = parse(quote::quote! {
            enum TestEnum {
                A, B, C
            }
        });

        let err = generate_mcp_backend_impl(&input).unwrap_err();
        assert!(err.to_string().contains("can only be derived for structs"));
    }
}