differential_equations_derive/
lib.rs

//! Derive macros for the `differential-equations` crate.
//!
//! [![GitHub](https://img.shields.io/badge/GitHub-differential--equations-blue)](https://github.com/Ryan-D-Gast/differential-equations)
//! [![Documentation](https://docs.rs/differential-equations/badge.svg)](https://docs.rs/differential-equations)
//!
//! # License
//!
//! ```text
//! Copyright 2025 Ryan D. Gast
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! you may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//!     http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.
//! ```
23
24use proc_macro::TokenStream;
25use syn::{parse_macro_input, DeriveInput, Data, Fields};
26
27mod field_analysis;
28mod code_generation;
29mod implementations;
30
31use field_analysis::analyze_field_type;
32use code_generation::*;
33use implementations::*;
34
35/// Derive macro for the `State` trait.
36/// 
37/// This macro generates implementations of:
38/// - The `State` trait for accessing state elements
39/// - Required arithmetic operations (Add, Sub, AddAssign, Mul<T>, Div<T>)
40/// - Clone, Copy, and Debug
41/// 
42/// It supports structs with fields of type T, [T; N], SMatrix<T, R, C>, or Complex<T> where T implements the required traits.
43/// 
44/// # Example
45/// ```rust
46/// use differential_equations_derive::State;
47/// use nalgebra::SMatrix;
48/// use num_complex::Complex;
49/// 
50/// #[derive(State)]
51/// struct MyState<T> {
52///    x: T,
53///    y: T,
54///    velocities: [T; 3],
55///    transformation: SMatrix<T, 2, 2>,
56///    impedance: Complex<T>,
57/// }
58/// ```
59/// 
60/// Fields can be:
61/// - Single values of type T
62/// - Arrays of type [T; N] where N is a compile-time constant
63/// - Matrices of type SMatrix<T, R, C> where R and C are compile-time constants
64/// - Complex numbers of type Complex<T>
65/// 
66#[proc_macro_derive(State)]
67pub fn derive_state(input: TokenStream) -> TokenStream {
68    // Parse the input tokens into a syntax tree
69    let input = parse_macro_input!(input as DeriveInput);
70    
71    // Get the struct name
72    let name = input.ident;
73    
74    // Extract field information
75    let fields = match input.data {
76        Data::Struct(data) => {
77            match data.fields {
78                Fields::Named(fields) => fields.named,
79                Fields::Unnamed(fields) => fields.unnamed,
80                Fields::Unit => panic!("Unit structs are not supported"),
81            }
82        },
83        _ => panic!("State can only be derived for structs"),
84    };
85
86    // Analyze field types and calculate total element count
87    let mut total_elements = 0usize;
88    let mut field_info = Vec::new();
89    
90    for field in &fields {
91        let field_type_info = analyze_field_type(&field.ty);
92        total_elements += field_type_info.element_count();
93        field_info.push(field_type_info);
94    }
95    
96    // Generate all the required components
97    let field_get_branches = generate_get_branches(&fields, &field_info);
98    let field_set_branches = generate_set_branches(&fields, &field_info);
99    let zeros_init = generate_zeros_init(&fields, &field_info);
100    let debug_fields = generate_debug_fields(&fields);
101    
102    // Generate all trait implementations
103    let clone_impl = generate_clone_impl(&name);
104    let copy_impl = generate_copy_impl(&name);
105    let debug_impl = generate_debug_impl(&name, &debug_fields);
106    let state_impl = generate_state_impl(&name, total_elements, &field_get_branches, &field_set_branches, &zeros_init, &fields);
107    let add_impl = generate_add_impl(&name, &fields, &field_info);
108    let sub_impl = generate_sub_impl(&name, &fields, &field_info);
109    let add_assign_impl = generate_add_assign_impl(&name, &fields, &field_info);
110    let neg_impl = generate_neg_impl(&name, &fields, &field_info);
111    let mul_impl = generate_mul_impl(&name, &fields, &field_info);
112    let div_impl = generate_div_impl(&name, &fields, &field_info);
113    
114    // Combine all implementations
115    let expanded = quote::quote! {
116        #clone_impl
117        #copy_impl
118        #debug_impl
119        #state_impl
120        #add_impl
121        #sub_impl
122        #add_assign_impl
123        #mul_impl
124        #div_impl
125    #neg_impl
126    };
127    
128    // Return the generated implementation
129    TokenStream::from(expanded)
130}