// option_chain_tool/lib.rs

1use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, TokenStream, TokenTree};
2
3/// A procedural macro for safe optional chaining in Rust.
4///
5/// The `opt!` macro provides a concise syntax for chaining operations on `Option` and `Result` types,
6/// similar to optional chaining in languages like TypeScript or Swift. It automatically handles
7/// unwrapping and propagates `None` values through the chain.
8///
9/// # Syntax
10///
11/// The macro supports several operators for different use cases:
12///
13/// - `?.` - Unwraps an `Option`, returns `None` if the value is `None`
14/// - `?Ok.` - Unwraps a `Result` to its `Ok` variant, returns `None` if `Err`
15/// - `?Err.` - Unwraps a `Result` to its `Err` variant, returns `None` if `Ok`
16/// - `.field` - Access a field without unwrapping (for required fields)
17///
18/// The macro returns `Some(value)` if all operations succeed, or `None` if any step fails.
19///
20/// # Examples
21///
22/// ## Basic Option chaining
23///
24/// ```ignore
25/// use option_chain_tool::opt;
26///
27/// struct User {
28///     profile: Option<Profile>,
29/// }
30///
31/// struct Profile {
32///     address: Option<Address>,
33/// }
34///
35/// struct Address {
36///     city: Option<String>,
37/// }
38///
39/// let user = User {
40///     profile: Some(Profile {
41///         address: Some(Address {
42///             city: Some("New York".to_string()),
43///         }),
44///     }),
45/// };
46///
47/// // Instead of: user.profile.as_ref().and_then(|p| p.address.as_ref()).and_then(|a| a.city.as_ref())
48/// let city: Option<&String> = opt!(user.profile?.address?.city?);
49/// assert_eq!(city, Some(&"New York".to_string()));
50/// ```
51///
52/// ## Chaining with method calls
53///
54/// ```ignore
55/// use option_chain_tool::opt;
56///
57/// impl Address {
58///     fn get_city(&self) -> Option<&String> {
59///         self.city.as_ref()
60///     }
61/// }
62///
63/// let city: Option<&String> = opt!(user.profile?.address?.get_city()?);
64/// ```
65///
66/// ## Accessing required fields
67///
68/// ```ignore
69/// use option_chain_tool::opt;
70///
71/// struct Address {
72///     city: Option<String>,
73///     street: String, // Required field
74/// }
75///
76/// // Access a required field in the chain (no ? after street)
77/// let street: Option<&String> = opt!(user.profile?.address?.street);
78/// ```
79///
80/// ## Working with Result types
81///
82/// ```ignore
83/// use option_chain_tool::opt;
84///
85/// struct Address {
86///     validation: Result<String, String>,
87/// }
88///
89/// // Extract the Ok variant
90/// let ok_value: Option<&String> = opt!(user.profile?.address?.validation?Ok);
91///
92/// // Extract the Err variant
93/// let err_value: Option<&String> = opt!(user.profile?.address?.validation?Err);
94/// ```
95///
96/// ## Complex chaining
97///
98/// ```ignore
99/// use option_chain_tool::opt;
100///
101/// // Combine multiple patterns in a single chain
102/// let value: Option<&String> = opt!(
103///     user
104///         .profile?        // Unwrap Option<Profile>
105///         .address?        // Unwrap Option<Address>
106///         .street          // Access required field
107///         .validation?Ok   // Unwrap Result to Ok variant
108/// );
109/// ```
110///
111/// # Returns
112///
113/// - `Some(value)` if all operations in the chain succeed
114/// - `None` if any operation in the chain returns `None` or encounters an unwrappable value
115///
116/// # Notes
117///
118/// The macro generates nested `if let` expressions that short-circuit on `None`, providing
119/// efficient and safe optional chaining without runtime panics.
120#[proc_macro]
121pub fn opt(input: TokenStream) -> TokenStream {
122    let resp = split_on_optional_variants(input);
123    // for r in resp.iter() {
124    //     let tokens = r
125    //         .tokens
126    //         .clone()
127    //         .into_iter()
128    //         .collect::<TokenStream>()
129    //         .to_string();
130    //     dbg!(format!("Variant: {:?}, Tokens: {}", r.variant, tokens));
131    // }
132    // dbg!(resp.len());
133    let mut result = TokenStream::new();
134    let segments_len = resp.len();
135    for (index, segment) in resp.into_iter().rev().enumerate() {
136        if segments_len - 1 == index {
137            if result.is_empty() {
138                let mut ____v = TokenStream::new();
139                ____v.extend([TokenTree::Ident(Ident::new(
140                    "____v",
141                    proc_macro::Span::call_site(),
142                ))]);
143                result = some_wrapper(____v);
144            }
145            result = if_let(
146                segment.variant,
147                segment.tokens.into_iter().collect(),
148                result,
149                true,
150            );
151            continue;
152        }
153        {
154            let mut is_add_amp = true;
155            if index == 0 {
156                if ends_with_fn_call(&segment.tokens) {
157                    is_add_amp = false;
158                }
159            }
160
161            let mut after_eq = TokenStream::new();
162            after_eq.extend([
163                TokenTree::Ident(Ident::new("____v", proc_macro::Span::call_site())),
164                TokenTree::Punct(Punct::new('.', Spacing::Joint)),
165            ]);
166            after_eq.extend(segment.tokens.into_iter());
167            if result.is_empty() {
168                let mut ____v = TokenStream::new();
169                ____v.extend([TokenTree::Ident(Ident::new(
170                    "____v",
171                    proc_macro::Span::call_site(),
172                ))]);
173                result = some_wrapper(____v);
174            }
175            result = if_let(segment.variant, after_eq, result, is_add_amp);
176        }
177    }
178
179    result
180}
181
/// Builds the expression `Some(<body>)`.
///
/// This produces the value at the heart of the generated chain: once every
/// step has matched, the bound value is returned wrapped in `Some`.
///
/// # Arguments
///
/// * `body` - Tokens to place between the parentheses of `Some(...)`
///
/// # Returns
///
/// A token stream made of the `Some` identifier followed by a parenthesized
/// group that contains `body`.
///
/// # Example
///
/// ```ignore
/// // body:   ____v
/// // result: Some(____v)
/// ```
fn some_wrapper(body: TokenStream) -> TokenStream {
    let constructor = TokenTree::Ident(Ident::new("Some", proc_macro::Span::call_site()));
    let arguments = TokenTree::Group(Group::new(Delimiter::Parenthesis, body));
    vec![constructor, arguments].into_iter().collect()
}
210
/// Reports whether `tokens` ends in a parenthesized group, i.e. looks like a
/// call such as `foo.bar()`.
///
/// Only the final token is inspected: a trailing `(...)` group counts as a
/// call, anything else (including an empty slice) does not.
///
/// # Arguments
///
/// * `tokens` - Tokens of one chain segment
///
/// # Returns
///
/// `true` when the last token is a group delimited by parentheses, `false`
/// otherwise.
///
/// # Example
///
/// ```ignore
/// // foo.bar()  -> true
/// // foo.bar    -> false
/// ```
fn ends_with_fn_call(tokens: &[TokenTree]) -> bool {
    matches!(
        tokens.last(),
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis
    )
}
245
246/// Generates an `if let` expression for pattern matching in the optional chain.
247///
248/// This function constructs an `if let` expression that attempts to unwrap a value
249/// according to the specified variant (`Some`, `Ok`, or `Err`). If the pattern matches,
250/// the body is executed; otherwise, `None` is returned.
251///
252/// # Arguments
253///
254/// * `variant` - The type of unwrapping to perform (Option, Ok, Err, Required, or Root)
255/// * `after_eq` - Token stream representing the expression to be matched
256/// * `body` - Token stream representing the code to execute if the pattern matches
257/// * `is_add_amp` - Whether to add a reference (`&`) before the expression being matched
258///
259/// # Returns
260///
261/// A token stream representing the complete `if let` expression with an `else` clause
262/// that returns `None`
263///
264/// # Panics
265///
266/// Panics if called with `OptionalVariant::Root`
267///
268/// # Example
269///
270/// ```ignore
271/// // Generates: if let Some(____v) = &expr { body } else { None }
272/// ```
273fn if_let(
274    variant: OptionalVariant,
275    after_eq: TokenStream,
276    body: TokenStream,
277    is_add_amp: bool,
278) -> TokenStream {
279    let mut ts = TokenStream::new();
280    ts.extend([TokenTree::Ident(Ident::new(
281        "if",
282        proc_macro::Span::call_site(),
283    ))]);
284    ts.extend([TokenTree::Ident(Ident::new(
285        "let",
286        proc_macro::Span::call_site(),
287    ))]);
288    match variant {
289        OptionalVariant::Option => {
290            ts.extend([TokenTree::Ident(Ident::new(
291                "Some",
292                proc_macro::Span::call_site(),
293            ))]);
294        }
295        OptionalVariant::Ok => {
296            ts.extend([TokenTree::Ident(Ident::new(
297                "Ok",
298                proc_macro::Span::call_site(),
299            ))]);
300        }
301        OptionalVariant::Err => {
302            ts.extend([TokenTree::Ident(Ident::new(
303                "Err",
304                proc_macro::Span::call_site(),
305            ))]);
306        }
307        OptionalVariant::Required => {
308            // panic!("if_let called with Required variant");
309        }
310        OptionalVariant::Root => {
311            panic!("if_let called with Root variant");
312        }
313    }
314    ts.extend([TokenTree::Group(Group::new(
315        Delimiter::Parenthesis,
316        TokenTree::Ident(Ident::new("____v", proc_macro::Span::call_site())).into(),
317    ))]);
318    ts.extend([TokenTree::Punct(Punct::new('=', Spacing::Alone))]);
319    if is_add_amp {
320        ts.extend([TokenTree::Punct(Punct::new('&', Spacing::Joint))]);
321    }
322    ts.extend(after_eq);
323    ts.extend([TokenTree::Group(Group::new(Delimiter::Brace, body))]);
324    ts.extend([TokenTree::Ident(Ident::new(
325        "else",
326        proc_macro::Span::call_site(),
327    ))]);
328    ts.extend([TokenTree::Group(Group::new(
329        Delimiter::Brace,
330        TokenTree::Ident(Ident::new("None", proc_macro::Span::call_site())).into(),
331    ))]);
332    ts
333}
334
/// How a single segment of an `opt!` chain is unwrapped.
///
/// Every segment produced while parsing the macro input carries one of these
/// tags, and code generation turns it into the matching pattern constructor.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OptionalVariant {
    /// Initial tag of the first segment, before any operator has been seen.
    /// Code generation (`if_let`) panics if it is asked to emit code for it.
    Root,
    /// Match `Some(..)`: written as `?.` inside the chain or a trailing `?`.
    Option,
    /// Match `Ok(..)`: written as `?Ok.` inside the chain or a trailing `?Ok`.
    Ok,
    /// Match `Err(..)`: written as `?Err.` inside the chain or a trailing `?Err`.
    Err,
    /// Bind the value without matching a constructor (no `?` operator).
    Required,
}
352
353/// Represents a single segment in the optional chaining expression.
354///
355/// Each segment contains the tokens that make up that part of the chain,
356/// along with the variant indicating how it should be unwrapped.
357#[derive(Debug, Clone)]
358struct OptionalSegment {
359    /// The type of unwrapping operation for this segment
360    pub variant: OptionalVariant,
361    /// The token trees that make up this segment's expression
362    pub tokens: Vec<TokenTree>,
363}
364
365/// Parses the input token stream and splits it into segments based on optional chaining operators.
366///
367/// This function analyzes the input token stream to identify optional chaining operators
368/// (`?.`, `?Ok.`, `?Err.`) and splits the expression into segments, each with its corresponding
369/// variant type. The segments are then used to generate the nested `if let` expressions.
370///
371/// # Arguments
372///
373/// * `input` - The input token stream to parse
374///
375/// # Returns
376///
377/// A vector of `OptionalSegment` structs, where each segment represents a portion of the
378/// chaining expression along with its unwrapping variant
379///
380/// # Example
381///
382/// ```ignore
383/// // Input: user.profile?.address?.city?
384/// // Output: [
385/// //   OptionalSegment { variant: Option, tokens: [user, .profile] },
386/// //   OptionalSegment { variant: Option, tokens: [address] },
387/// //   OptionalSegment { variant: Option, tokens: [city] }
388/// // ]
389/// ```
390fn split_on_optional_variants(input: TokenStream) -> Vec<OptionalSegment> {
391    let input_tokens: Vec<TokenTree> = input.clone().into_iter().collect();
392    let mut iter = input.into_iter().peekable();
393
394    let mut result: Vec<OptionalSegment> = Vec::new();
395    let mut current: Vec<TokenTree> = Vec::new();
396    let mut current_variant = OptionalVariant::Root;
397    while let Some(tt) = iter.next().as_ref() {
398        match &tt {
399            TokenTree::Punct(q) if q.as_char() == '?' => {
400                // Try to detect ?. / ?Ok. / ?Err.
401                let variant = match iter.peek() {
402                    Some(TokenTree::Punct(dot)) if dot.as_char() == '.' => {
403                        iter.next(); // consume '.'
404                        Some(OptionalVariant::Option)
405                    }
406
407                    Some(TokenTree::Ident(ident))
408                        if ident.to_string() == "Ok" || ident.to_string() == "Err" =>
409                    {
410                        let ident = ident.clone();
411                        let v = if ident.to_string() == "Ok" {
412                            OptionalVariant::Ok
413                        } else {
414                            OptionalVariant::Err
415                        };
416
417                        // consume Ident
418                        iter.next();
419
420                        // require trailing '.'
421                        match &iter.next() {
422                            Some(TokenTree::Punct(dot)) if dot.as_char() == '.' => Some(v),
423                            other => {
424                                // rollback-ish: treat as normal tokens
425                                if let Some(o) = other {
426                                    current.push(o.clone());
427                                }
428                                None
429                            }
430                        }
431                    }
432
433                    _ => None,
434                };
435
436                if let Some(v) = variant {
437                    if !current.is_empty() {
438                        result.push(OptionalSegment {
439                            variant: current_variant,
440                            tokens: std::mem::take(&mut current),
441                        });
442                    }
443
444                    current_variant = v;
445                    continue;
446                }
447
448                // Not a recognized optional-chain operator
449            }
450
451            _ => current.push(tt.clone()),
452        }
453    }
454
455    result.push(OptionalSegment {
456        variant: current_variant,
457        tokens: current,
458    });
459
460    for i in 0..result.len() - 1 {
461        result[i].variant = result[i + 1].variant.clone();
462    }
463
464    // dbg!(last_token.to_string());
465    if input_tokens.last().is_none() {
466        return result;
467    }
468    let result_len = result.len();
469    match input_tokens.last().unwrap() {
470        TokenTree::Punct(p) if p.as_char() == '?' => {
471            result[result_len - 1].variant = OptionalVariant::Option;
472        }
473        TokenTree::Ident(p) if p.to_string() == "Ok" => {
474            result[result_len - 1].variant = OptionalVariant::Ok;
475        }
476        TokenTree::Ident(p) if p.to_string() == "Err" => {
477            result[result_len - 1].variant = OptionalVariant::Err;
478        }
479        _ => {
480            result[result_len - 1].variant = OptionalVariant::Required;
481        }
482    }
483    result
484}