1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};
9
10#[proc_macro_derive(McpConfig, attributes(mcp))]
41pub fn derive_mcp_config(input: TokenStream) -> TokenStream {
42 let input = parse_macro_input!(input as DeriveInput);
43
44 match generate_mcp_config_impl(&input) {
45 Ok(tokens) => tokens.into(),
46 Err(err) => err.to_compile_error().into(),
47 }
48}
49
50#[proc_macro_derive(McpBackend, attributes(mcp_backend))]
81pub fn derive_mcp_backend(input: TokenStream) -> TokenStream {
82 let input = parse_macro_input!(input as DeriveInput);
83
84 match generate_mcp_backend_impl(&input) {
85 Ok(tokens) => tokens.into(),
86 Err(err) => err.to_compile_error().into(),
87 }
88}
89
90fn generate_mcp_config_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
91 let name = &input.ident;
92
93 let fields = match &input.data {
95 Data::Struct(data) => match &data.fields {
96 Fields::Named(fields) => &fields.named,
97 _ => {
98 return Err(syn::Error::new_spanned(
99 input,
100 "McpConfig can only be derived for structs with named fields",
101 ))
102 }
103 },
104 _ => {
105 return Err(syn::Error::new_spanned(
106 input,
107 "McpConfig can only be derived for structs",
108 ))
109 }
110 };
111
112 let mut server_info_field = None;
114 let mut logging_field = None;
115 let mut auto_populate_fields = Vec::new();
116
117 for field in fields {
118 if let Some(ident) = &field.ident {
119 for attr in &field.attrs {
121 if attr.path().is_ident("mcp") {
122 parse_mcp_attribute(
123 attr,
124 ident,
125 &mut server_info_field,
126 &mut logging_field,
127 &mut auto_populate_fields,
128 )?;
129 }
130 }
131
132 match ident.to_string().as_str() {
134 "server_info" => server_info_field = Some(ident.clone()),
135 "logging" => logging_field = Some(ident.clone()),
136 _ => {}
137 }
138 }
139 }
140
141 let server_info_impl = generate_server_info_impl(&server_info_field);
143 let logging_impl = generate_logging_impl(&logging_field);
144 let auto_populate_impl = generate_auto_populate_impl(&auto_populate_fields);
145
146 Ok(quote! {
147 impl pulseengine_mcp_cli::McpConfiguration for #name {
148 #server_info_impl
149 #logging_impl
150
151 fn validate(&self) -> std::result::Result<(), pulseengine_mcp_cli::CliError> {
152 Ok(())
154 }
155 }
156
157 impl #name {
158 pub fn with_auto_populate() -> Self
160 where
161 Self: Default,
162 {
163 let mut instance = Self::default();
164 instance.auto_populate();
165 instance
166 }
167
168 pub fn auto_populate(&mut self) {
170 #auto_populate_impl
171 }
172 }
173 })
174}
175
176fn generate_mcp_backend_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
177 let name = &input.ident;
178
179 let backend_config = parse_backend_attributes(input)?;
181
182 let fields = match &input.data {
184 Data::Struct(data) => match &data.fields {
185 Fields::Named(fields) => &fields.named,
186 _ => {
187 return Err(syn::Error::new_spanned(
188 input,
189 "McpBackend can only be derived for structs with named fields",
190 ))
191 }
192 },
193 _ => {
194 return Err(syn::Error::new_spanned(
195 input,
196 "McpBackend can only be derived for structs",
197 ))
198 }
199 };
200
201 let mut delegate_field = None;
203 let mut error_from_fields = Vec::new();
204
205 for field in fields {
206 if let Some(ident) = &field.ident {
207 for attr in &field.attrs {
208 if attr.path().is_ident("mcp_backend") {
209 parse_backend_field_attribute(
210 attr,
211 ident,
212 &mut delegate_field,
213 &mut error_from_fields,
214 )?;
215 }
216 }
217 }
218 }
219
220 let error_type = backend_config
222 .error_type
223 .as_ref()
224 .map(|s| syn::parse_str::<syn::Type>(s))
225 .transpose()?
226 .unwrap_or_else(|| syn::parse_str(&format!("{name}Error")).unwrap());
227
228 let config_type = backend_config
229 .config_type
230 .as_ref()
231 .map(|s| syn::parse_str::<syn::Type>(s))
232 .transpose()?
233 .unwrap_or_else(|| syn::parse_str(&format!("{name}Config")).unwrap());
234
235 let error_definition = if backend_config.error_type.is_none() {
237 generate_error_type_definition(name, &error_from_fields)
238 } else {
239 quote! {}
240 };
241
242 let trait_impl = if backend_config.simple_backend {
244 generate_simple_backend_impl(name, &error_type, &config_type, &delegate_field)
245 } else {
246 generate_full_backend_impl(name, &error_type, &config_type, &delegate_field)
247 };
248
249 Ok(quote! {
250 #error_definition
251 #trait_impl
252 })
253}
254
/// Type-level options collected from `#[mcp_backend(...)]` attributes.
#[derive(Default)]
struct BackendConfig {
    /// Set by `simple`: generate a `SimpleBackend` impl instead of `McpBackend`.
    simple_backend: bool,
    /// Set by `error = "..."`: user-supplied error type, kept as source text.
    error_type: Option<String>,
    /// Set by `config = "..."`: user-supplied config type, kept as source text.
    config_type: Option<String>,
}
261
262fn parse_backend_attributes(input: &DeriveInput) -> syn::Result<BackendConfig> {
263 let mut config = BackendConfig::default();
264
265 for attr in &input.attrs {
266 if attr.path().is_ident("mcp_backend") {
267 attr.parse_nested_meta(|meta| {
268 if meta.path.is_ident("simple") {
269 config.simple_backend = true;
270 Ok(())
271 } else if meta.path.is_ident("error") {
272 if let Ok(value) = meta.value() {
273 if let Ok(lit) = value.parse::<syn::LitStr>() {
274 config.error_type = Some(lit.value());
275 }
276 }
277 Ok(())
278 } else if meta.path.is_ident("config") {
279 if let Ok(value) = meta.value() {
280 if let Ok(lit) = value.parse::<syn::LitStr>() {
281 config.config_type = Some(lit.value());
282 }
283 }
284 Ok(())
285 } else {
286 Err(meta.error(format!(
287 "unsupported mcp_backend attribute: {}",
288 meta.path.get_ident().unwrap()
289 )))
290 }
291 })?;
292 }
293 }
294
295 Ok(config)
296}
297
298fn parse_backend_field_attribute(
299 attr: &Attribute,
300 field_ident: &syn::Ident,
301 delegate_field: &mut Option<syn::Ident>,
302 error_from_fields: &mut Vec<syn::Ident>,
303) -> syn::Result<()> {
304 attr.parse_nested_meta(|meta| {
305 if meta.path.is_ident("delegate") {
306 *delegate_field = Some(field_ident.clone());
307 Ok(())
308 } else if meta.path.is_ident("error_from") {
309 error_from_fields.push(field_ident.clone());
310 Ok(())
311 } else {
312 Err(meta.error(format!(
313 "unsupported mcp_backend field attribute: {}",
314 meta.path.get_ident().unwrap()
315 )))
316 }
317 })
318}
319
320fn generate_error_type_definition(
321 name: &syn::Ident,
322 error_from_fields: &[syn::Ident],
323) -> proc_macro2::TokenStream {
324 let error_name = syn::Ident::new(&format!("{name}Error"), name.span());
325
326 let from_implementations = error_from_fields.iter().map(|_field| {
327 quote! {
329 impl From<std::io::Error> for #error_name {
330 fn from(err: std::io::Error) -> Self {
331 Self::Internal(err.to_string())
332 }
333 }
334 }
335 });
336
337 quote! {
338 #[derive(Debug, thiserror::Error)]
339 pub enum #error_name {
340 #[error("Configuration error: {0}")]
341 Configuration(String),
342
343 #[error("Connection error: {0}")]
344 Connection(String),
345
346 #[error("Operation not supported: {0}")]
347 NotSupported(String),
348
349 #[error("Internal error: {0}")]
350 Internal(String),
351 }
352
353 impl #error_name {
354 pub fn configuration(msg: impl Into<String>) -> Self {
355 Self::Configuration(msg.into())
356 }
357
358 pub fn connection(msg: impl Into<String>) -> Self {
359 Self::Connection(msg.into())
360 }
361
362 pub fn not_supported(msg: impl Into<String>) -> Self {
363 Self::NotSupported(msg.into())
364 }
365
366 pub fn internal(msg: impl Into<String>) -> Self {
367 Self::Internal(msg.into())
368 }
369 }
370
371 impl From<pulseengine_mcp_server::backend::BackendError> for #error_name {
372 fn from(err: pulseengine_mcp_server::backend::BackendError) -> Self {
373 Self::Internal(err.to_string())
374 }
375 }
376
377 impl From<#error_name> for pulseengine_mcp_protocol::Error {
378 fn from(err: #error_name) -> Self {
379 match err {
380 #error_name::Configuration(msg) => Self::invalid_params(msg),
381 #error_name::Connection(msg) => Self::internal_error(format!("Connection failed: {msg}")),
382 #error_name::NotSupported(msg) => Self::method_not_found(msg),
383 #error_name::Internal(msg) => Self::internal_error(msg),
384 }
385 }
386 }
387
388 #(#from_implementations)*
389 }
390}
391
392fn generate_simple_backend_impl(
393 name: &syn::Ident,
394 error_type: &syn::Type,
395 config_type: &syn::Type,
396 delegate_field: &Option<syn::Ident>,
397) -> proc_macro2::TokenStream {
398 if let Some(delegate) = delegate_field {
399 quote! {
400 #[async_trait::async_trait]
401 impl pulseengine_mcp_server::backend::SimpleBackend for #name {
402 type Error = #error_type;
403 type Config = #config_type;
404
405 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
406 Err(Self::Error::not_supported("Backend initialization not implemented"))
408 }
409
410 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
411 self.#delegate.get_server_info()
412 }
413
414 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
415 self.#delegate.health_check().await.map_err(Into::into)
416 }
417
418 async fn list_tools(
419 &self,
420 request: pulseengine_mcp_protocol::PaginatedRequestParam,
421 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
422 self.#delegate.list_tools(request).await.map_err(Into::into)
423 }
424
425 async fn call_tool(
426 &self,
427 request: pulseengine_mcp_protocol::CallToolRequestParam,
428 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
429 self.#delegate.call_tool(request).await.map_err(Into::into)
430 }
431 }
432 }
433 } else {
434 quote! {
435 #[async_trait::async_trait]
436 impl pulseengine_mcp_server::backend::SimpleBackend for #name {
437 type Error = #error_type;
438 type Config = #config_type;
439
440 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
441 Err(Self::Error::not_supported("Backend initialization not implemented"))
443 }
444
445 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
446 pulseengine_mcp_protocol::ServerInfo {
448 protocol_version: pulseengine_mcp_protocol::ProtocolVersion::default(),
449 capabilities: pulseengine_mcp_protocol::ServerCapabilities::default(),
450 server_info: pulseengine_mcp_protocol::Implementation {
451 name: env!("CARGO_PKG_NAME").to_string(),
452 version: env!("CARGO_PKG_VERSION").to_string(),
453 },
454 instructions: None,
455 }
456 }
457
458 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
459 Ok(())
461 }
462
463 async fn list_tools(
464 &self,
465 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
466 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
467 Ok(pulseengine_mcp_protocol::ListToolsResult {
469 tools: vec![],
470 next_cursor: None,
471 })
472 }
473
474 async fn call_tool(
475 &self,
476 request: pulseengine_mcp_protocol::CallToolRequestParam,
477 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
478 Err(Self::Error::not_supported(format!("Tool not found: {}", request.name)))
480 }
481 }
482 }
483 }
484}
485
486fn generate_full_backend_impl(
487 name: &syn::Ident,
488 error_type: &syn::Type,
489 config_type: &syn::Type,
490 delegate_field: &Option<syn::Ident>,
491) -> proc_macro2::TokenStream {
492 if let Some(delegate) = delegate_field {
493 quote! {
494 #[async_trait::async_trait]
495 impl pulseengine_mcp_server::backend::McpBackend for #name {
496 type Error = #error_type;
497 type Config = #config_type;
498
499 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
500 Err(Self::Error::not_supported("Backend initialization not implemented"))
502 }
503
504 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
505 self.#delegate.get_server_info()
506 }
507
508 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
509 self.#delegate.health_check().await.map_err(Into::into)
510 }
511
512 async fn list_tools(
514 &self,
515 request: pulseengine_mcp_protocol::PaginatedRequestParam,
516 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
517 self.#delegate.list_tools(request).await.map_err(Into::into)
518 }
519
520 async fn call_tool(
521 &self,
522 request: pulseengine_mcp_protocol::CallToolRequestParam,
523 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
524 self.#delegate.call_tool(request).await.map_err(Into::into)
525 }
526
527 async fn list_resources(
528 &self,
529 request: pulseengine_mcp_protocol::PaginatedRequestParam,
530 ) -> std::result::Result<pulseengine_mcp_protocol::ListResourcesResult, Self::Error> {
531 self.#delegate.list_resources(request).await.map_err(Into::into)
532 }
533
534 async fn read_resource(
535 &self,
536 request: pulseengine_mcp_protocol::ReadResourceRequestParam,
537 ) -> std::result::Result<pulseengine_mcp_protocol::ReadResourceResult, Self::Error> {
538 self.#delegate.read_resource(request).await.map_err(Into::into)
539 }
540
541 async fn list_prompts(
542 &self,
543 request: pulseengine_mcp_protocol::PaginatedRequestParam,
544 ) -> std::result::Result<pulseengine_mcp_protocol::ListPromptsResult, Self::Error> {
545 self.#delegate.list_prompts(request).await.map_err(Into::into)
546 }
547
548 async fn get_prompt(
549 &self,
550 request: pulseengine_mcp_protocol::GetPromptRequestParam,
551 ) -> std::result::Result<pulseengine_mcp_protocol::GetPromptResult, Self::Error> {
552 self.#delegate.get_prompt(request).await.map_err(Into::into)
553 }
554 }
555 }
556 } else {
557 quote! {
558 #[async_trait::async_trait]
559 impl pulseengine_mcp_server::backend::McpBackend for #name {
560 type Error = #error_type;
561 type Config = #config_type;
562
563 async fn initialize(config: Self::Config) -> std::result::Result<Self, Self::Error> {
564 Err(Self::Error::not_supported("Backend initialization not implemented"))
566 }
567
568 fn get_server_info(&self) -> pulseengine_mcp_protocol::ServerInfo {
569 pulseengine_mcp_protocol::ServerInfo {
571 protocol_version: pulseengine_mcp_protocol::ProtocolVersion::default(),
572 capabilities: pulseengine_mcp_protocol::ServerCapabilities::default(),
573 server_info: pulseengine_mcp_protocol::Implementation {
574 name: env!("CARGO_PKG_NAME").to_string(),
575 version: env!("CARGO_PKG_VERSION").to_string(),
576 },
577 instructions: None,
578 }
579 }
580
581 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
582 Ok(())
584 }
585
586 async fn list_tools(
588 &self,
589 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
590 ) -> std::result::Result<pulseengine_mcp_protocol::ListToolsResult, Self::Error> {
591 Ok(pulseengine_mcp_protocol::ListToolsResult {
592 tools: vec![],
593 next_cursor: None,
594 })
595 }
596
597 async fn call_tool(
598 &self,
599 request: pulseengine_mcp_protocol::CallToolRequestParam,
600 ) -> std::result::Result<pulseengine_mcp_protocol::CallToolResult, Self::Error> {
601 Err(Self::Error::not_supported(format!("Tool not found: {}", request.name)))
602 }
603
604 async fn list_resources(
605 &self,
606 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
607 ) -> std::result::Result<pulseengine_mcp_protocol::ListResourcesResult, Self::Error> {
608 Ok(pulseengine_mcp_protocol::ListResourcesResult {
609 resources: vec![],
610 next_cursor: None,
611 })
612 }
613
614 async fn read_resource(
615 &self,
616 request: pulseengine_mcp_protocol::ReadResourceRequestParam,
617 ) -> std::result::Result<pulseengine_mcp_protocol::ReadResourceResult, Self::Error> {
618 Err(Self::Error::not_supported(format!("Resource not found: {}", request.uri)))
619 }
620
621 async fn list_prompts(
622 &self,
623 _request: pulseengine_mcp_protocol::PaginatedRequestParam,
624 ) -> std::result::Result<pulseengine_mcp_protocol::ListPromptsResult, Self::Error> {
625 Ok(pulseengine_mcp_protocol::ListPromptsResult {
626 prompts: vec![],
627 next_cursor: None,
628 })
629 }
630
631 async fn get_prompt(
632 &self,
633 request: pulseengine_mcp_protocol::GetPromptRequestParam,
634 ) -> std::result::Result<pulseengine_mcp_protocol::GetPromptResult, Self::Error> {
635 Err(Self::Error::not_supported(format!("Prompt not found: {}", request.name)))
636 }
637 }
638 }
639 }
640}
641
642fn parse_mcp_attribute(
643 attr: &Attribute,
644 field_ident: &syn::Ident,
645 _server_info_field: &mut Option<syn::Ident>,
646 logging_field: &mut Option<syn::Ident>,
647 auto_populate_fields: &mut Vec<syn::Ident>,
648) -> syn::Result<()> {
649 attr.parse_nested_meta(|meta| {
650 if meta.path.is_ident("auto_populate") {
651 auto_populate_fields.push(field_ident.clone());
652 Ok(())
653 } else if meta.path.is_ident("logging") {
654 *logging_field = Some(field_ident.clone());
655 if meta.input.peek(syn::token::Paren) {
657 let _content;
658 syn::parenthesized!(_content in meta.input);
659 }
662 Ok(())
663 } else {
664 Err(meta.error(format!(
665 "unsupported mcp attribute: {}",
666 meta.path.get_ident().unwrap()
667 )))
668 }
669 })
670}
671
672fn generate_server_info_impl(server_info_field: &Option<syn::Ident>) -> proc_macro2::TokenStream {
673 if let Some(field) = server_info_field {
674 quote! {
675 fn get_server_info(&self) -> &pulseengine_mcp_protocol::ServerInfo {
676 self.#field.as_ref().unwrap_or_else(|| {
677 static SERVER_INFO: std::sync::OnceLock<pulseengine_mcp_protocol::ServerInfo> = std::sync::OnceLock::new();
678 SERVER_INFO.get_or_init(|| {
679 pulseengine_mcp_cli::config::create_server_info(None, None)
680 })
681 })
682 }
683 }
684 } else {
685 quote! {
686 fn get_server_info(&self) -> &pulseengine_mcp_protocol::ServerInfo {
687 use std::sync::OnceLock;
688 static SERVER_INFO: OnceLock<pulseengine_mcp_protocol::ServerInfo> = OnceLock::new();
689 SERVER_INFO.get_or_init(|| {
690 pulseengine_mcp_cli::config::create_server_info(None, None)
691 })
692 }
693 }
694 }
695}
696
697fn generate_logging_impl(logging_field: &Option<syn::Ident>) -> proc_macro2::TokenStream {
698 if let Some(field) = logging_field {
699 quote! {
700 fn get_logging_config(&self) -> &pulseengine_mcp_cli::DefaultLoggingConfig {
701 self.#field.as_ref().unwrap_or_else(|| {
702 static LOGGING_CONFIG: std::sync::OnceLock<pulseengine_mcp_cli::DefaultLoggingConfig> = std::sync::OnceLock::new();
703 LOGGING_CONFIG.get_or_init(|| {
704 pulseengine_mcp_cli::DefaultLoggingConfig::default()
705 })
706 })
707 }
708
709 fn initialize_logging(&self) -> std::result::Result<(), pulseengine_mcp_cli::CliError> {
710 if let Some(config) = &self.#field {
712 config.initialize()
713 } else {
714 pulseengine_mcp_cli::DefaultLoggingConfig::default().initialize()
715 }
716 }
717 }
718 } else {
719 quote! {
720 fn get_logging_config(&self) -> &pulseengine_mcp_cli::DefaultLoggingConfig {
721 use std::sync::OnceLock;
722 static LOGGING_CONFIG: OnceLock<pulseengine_mcp_cli::DefaultLoggingConfig> = OnceLock::new();
723 LOGGING_CONFIG.get_or_init(|| {
724 pulseengine_mcp_cli::DefaultLoggingConfig::default()
725 })
726 }
727
728 fn initialize_logging(&self) -> std::result::Result<(), pulseengine_mcp_cli::CliError> {
729 use pulseengine_mcp_cli::config::DefaultLoggingConfig;
730 let default_config = DefaultLoggingConfig::default();
731 default_config.initialize()
732 }
733 }
734 }
735}
736
737fn generate_auto_populate_impl(auto_populate_fields: &[syn::Ident]) -> proc_macro2::TokenStream {
738 if auto_populate_fields.is_empty() {
739 return quote! {};
740 }
741
742 let implementations = auto_populate_fields.iter().map(|field| {
743 match field.to_string().as_str() {
744 "server_info" => quote! {
745 self.#field = Some(pulseengine_mcp_cli::config::create_server_info(None, None));
746 },
747 "logging" => quote! {
748 use std::env;
750 if let Ok(level) = env::var("MCP_LOG_LEVEL") {
751 }
754 },
755 _ => quote! {
756 let env_var = format!("MCP_{}", stringify!(#field).to_uppercase());
759 if let Ok(value) = std::env::var(&env_var) {
760 tracing::debug!("Found environment variable {}: {}", env_var, value);
762 }
763 },
764 }
765 });
766
767 quote! {
768 #(#implementations)*
769 }
770}
771
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a test fixture as a derive input; fixtures are expected to be valid.
    fn parse_fixture(tokens: proc_macro2::TokenStream) -> DeriveInput {
        syn::parse2(tokens).unwrap()
    }

    /// Runs the `McpConfig` expansion over a fixture.
    fn expand_config(tokens: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
        generate_mcp_config_impl(&parse_fixture(tokens))
    }

    /// Runs the `McpBackend` expansion over a fixture.
    fn expand_backend(tokens: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
        generate_mcp_backend_impl(&parse_fixture(tokens))
    }

    #[test]
    fn test_basic_mcp_config_derive() {
        let expanded = expand_config(quote::quote! {
            struct TestConfig {
                port: u16,
                server_info: ServerInfo,
            }
        });
        assert!(expanded.is_ok());
    }

    #[test]
    fn test_mcp_config_with_attributes() {
        let expanded = expand_config(quote::quote! {
            struct TestConfig {
                #[mcp(auto_populate)]
                server_info: ServerInfo,
                #[mcp(logging)]
                logging: LoggingConfig,
            }
        });
        assert!(expanded.is_ok());
    }

    #[test]
    fn test_basic_mcp_backend_derive() {
        let expanded = expand_backend(quote::quote! {
            struct TestBackend {
                config: BackendConfig,
            }
        });
        assert!(expanded.is_ok());
    }

    #[test]
    fn test_mcp_backend_with_simple() {
        let expanded = expand_backend(quote::quote! {
            #[mcp_backend(simple)]
            struct SimpleTestBackend {
                config: BackendConfig,
            }
        });
        assert!(expanded.is_ok());
    }

    #[test]
    fn test_mcp_backend_with_custom_error() {
        let expanded = expand_backend(quote::quote! {
            #[mcp_backend(error = "CustomError")]
            struct CustomErrorBackend {
                config: BackendConfig,
            }
        });
        assert!(expanded.is_ok());
    }

    #[test]
    fn test_mcp_backend_with_delegate() {
        let expanded = expand_backend(quote::quote! {
            struct DelegateBackend {
                #[mcp_backend(delegate)]
                inner: InnerBackend,
                config: BackendConfig,
            }
        });
        assert!(expanded.is_ok());
    }

    #[test]
    fn test_invalid_mcp_config() {
        let outcome = expand_config(quote::quote! {
            enum TestEnum {
                A, B, C
            }
        });
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("can only be derived for structs"));
    }

    #[test]
    fn test_invalid_mcp_backend() {
        let outcome = expand_backend(quote::quote! {
            enum TestEnum {
                A, B, C
            }
        });
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("can only be derived for structs"));
    }
}