option_chain_tool/lib.rs
1use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, TokenStream, TokenTree};
2
3/// A procedural macro for safe optional chaining in Rust.
4///
5/// The `opt!` macro provides a concise syntax for chaining operations on `Option` and `Result` types,
6/// similar to optional chaining in languages like TypeScript or Swift. It automatically handles
7/// unwrapping and propagates `None` values through the chain.
8///
9/// # Syntax
10///
11/// The macro supports several operators for different use cases:
12///
13/// - `?.` - Unwraps an `Option`, returns `None` if the value is `None`
14/// - `?Ok.` - Unwraps a `Result` to its `Ok` variant, returns `None` if `Err`
15/// - `?Err.` - Unwraps a `Result` to its `Err` variant, returns `None` if `Ok`
16/// - `.field` - Access a field without unwrapping (for required fields)
17///
18/// The macro returns `Some(value)` if all operations succeed, or `None` if any step fails.
19///
20/// # Examples
21///
22/// ## Basic Option chaining
23///
24/// ```ignore
25/// use option_chain_tool::opt;
26///
27/// struct User {
28/// profile: Option<Profile>,
29/// }
30///
31/// struct Profile {
32/// address: Option<Address>,
33/// }
34///
35/// struct Address {
36/// city: Option<String>,
37/// }
38///
39/// let user = User {
40/// profile: Some(Profile {
41/// address: Some(Address {
42/// city: Some("New York".to_string()),
43/// }),
44/// }),
45/// };
46///
47/// // Instead of: user.profile.as_ref().and_then(|p| p.address.as_ref()).and_then(|a| a.city.as_ref())
48/// let city: Option<&String> = opt!(user.profile?.address?.city?);
49/// assert_eq!(city, Some(&"New York".to_string()));
50/// ```
51///
52/// ## Chaining with method calls
53///
54/// ```ignore
55/// use option_chain_tool::opt;
56///
57/// impl Address {
58/// fn get_city(&self) -> Option<&String> {
59/// self.city.as_ref()
60/// }
61/// }
62///
63/// let city: Option<&String> = opt!(user.profile?.address?.get_city()?);
64/// ```
65///
66/// ## Accessing required fields
67///
68/// ```ignore
69/// use option_chain_tool::opt;
70///
71/// struct Address {
72/// city: Option<String>,
73/// street: String, // Required field
74/// }
75///
76/// // Access a required field in the chain (no ? after street)
77/// let street: Option<&String> = opt!(user.profile?.address?.street);
78/// ```
79///
80/// ## Working with Result types
81///
82/// ```ignore
83/// use option_chain_tool::opt;
84///
85/// struct Address {
86/// validation: Result<String, String>,
87/// }
88///
89/// // Extract the Ok variant
90/// let ok_value: Option<&String> = opt!(user.profile?.address?.validation?Ok);
91///
92/// // Extract the Err variant
93/// let err_value: Option<&String> = opt!(user.profile?.address?.validation?Err);
94/// ```
95///
96/// ## Complex chaining
97///
98/// ```ignore
99/// use option_chain_tool::opt;
100///
101/// // Combine multiple patterns in a single chain
102/// let value: Option<&String> = opt!(
103/// user
104/// .profile? // Unwrap Option<Profile>
105/// .address? // Unwrap Option<Address>
106/// .street // Access required field
107/// .validation?Ok // Unwrap Result to Ok variant
108/// );
109/// ```
110///
111/// # Returns
112///
113/// - `Some(value)` if all operations in the chain succeed
114/// - `None` if any operation in the chain returns `None` or encounters an unwrappable value
115///
116/// # Notes
117///
118/// The macro generates nested `if let` expressions that short-circuit on `None`, providing
119/// efficient and safe optional chaining without runtime panics.
120#[proc_macro]
121pub fn opt(input: TokenStream) -> TokenStream {
122 let resp = split_on_optional_variants(input);
123 // for r in resp.iter() {
124 // let tokens = r
125 // .tokens
126 // .clone()
127 // .into_iter()
128 // .collect::<TokenStream>()
129 // .to_string();
130 // dbg!(format!("Variant: {:?}, Tokens: {}", r.variant, tokens));
131 // }
132 // dbg!(resp.len());
133 let mut result = TokenStream::new();
134 let segments_len = resp.len();
135 for (index, segment) in resp.into_iter().rev().enumerate() {
136 if segments_len - 1 == index {
137 if result.is_empty() {
138 let mut ____v = TokenStream::new();
139 ____v.extend([TokenTree::Ident(Ident::new(
140 "____v",
141 proc_macro::Span::call_site(),
142 ))]);
143 result = some_wrapper(____v);
144 }
145 result = if_let(
146 segment.variant,
147 segment.tokens.into_iter().collect(),
148 result,
149 true,
150 );
151 continue;
152 }
153 {
154 let mut is_add_amp = true;
155 if index == 0 {
156 if ends_with_fn_call(&segment.tokens) {
157 is_add_amp = false;
158 }
159 }
160
161 let mut after_eq = TokenStream::new();
162 after_eq.extend([
163 TokenTree::Ident(Ident::new("____v", proc_macro::Span::call_site())),
164 TokenTree::Punct(Punct::new('.', Spacing::Joint)),
165 ]);
166 after_eq.extend(segment.tokens.into_iter());
167 if result.is_empty() {
168 let mut ____v = TokenStream::new();
169 ____v.extend([TokenTree::Ident(Ident::new(
170 "____v",
171 proc_macro::Span::call_site(),
172 ))]);
173 result = some_wrapper(____v);
174 }
175 result = if_let(segment.variant, after_eq, result, is_add_amp);
176 }
177 }
178
179 result
180}
181
/// Builds the token stream `Some(body)`.
///
/// `opt!` uses this for the innermost expression of the generated code, i.e. the value handed
/// back once every unwrap step of the chain has matched.
///
/// # Arguments
///
/// * `body` - Tokens placed between the parentheses of `Some(...)`
///
/// # Returns
///
/// The two tokens `Some` and `( body )`, both spanned at the macro call site.
///
/// # Example
///
/// ```ignore
/// // body:   ____v
/// // result: Some(____v)
/// ```
fn some_wrapper(body: TokenStream) -> TokenStream {
    let constructor = Ident::new("Some", proc_macro::Span::call_site());
    let arguments = Group::new(Delimiter::Parenthesis, body);
    vec![TokenTree::Ident(constructor), TokenTree::Group(arguments)]
        .into_iter()
        .collect()
}
210
/// Reports whether a segment's tokens end in a call, i.e. a trailing `( … )` group.
///
/// Only the final token is inspected. Any parenthesis-delimited group in that position counts,
/// so a trailing parenthesized expression is treated the same as an argument list.
///
/// # Arguments
///
/// * `tokens` - The token trees of one chain segment
///
/// # Returns
///
/// `true` when the last token is a `Group` delimited by parentheses; `false` otherwise,
/// including for an empty slice.
///
/// # Example
///
/// ```ignore
/// // foo.bar() -> true
/// // foo.bar   -> false
/// ```
fn ends_with_fn_call(tokens: &[TokenTree]) -> bool {
    match tokens.last() {
        Some(TokenTree::Group(group)) => group.delimiter() == Delimiter::Parenthesis,
        _ => false,
    }
}
245
246/// Generates an `if let` expression for pattern matching in the optional chain.
247///
248/// This function constructs an `if let` expression that attempts to unwrap a value
249/// according to the specified variant (`Some`, `Ok`, or `Err`). If the pattern matches,
250/// the body is executed; otherwise, `None` is returned.
251///
252/// # Arguments
253///
254/// * `variant` - The type of unwrapping to perform (Option, Ok, Err, Required, or Root)
255/// * `after_eq` - Token stream representing the expression to be matched
256/// * `body` - Token stream representing the code to execute if the pattern matches
257/// * `is_add_amp` - Whether to add a reference (`&`) before the expression being matched
258///
259/// # Returns
260///
261/// A token stream representing the complete `if let` expression with an `else` clause
262/// that returns `None`
263///
264/// # Panics
265///
266/// Panics if called with `OptionalVariant::Root`
267///
268/// # Example
269///
270/// ```ignore
271/// // Generates: if let Some(____v) = &expr { body } else { None }
272/// ```
273fn if_let(
274 variant: OptionalVariant,
275 after_eq: TokenStream,
276 body: TokenStream,
277 is_add_amp: bool,
278) -> TokenStream {
279 let mut ts = TokenStream::new();
280 ts.extend([TokenTree::Ident(Ident::new(
281 "if",
282 proc_macro::Span::call_site(),
283 ))]);
284 ts.extend([TokenTree::Ident(Ident::new(
285 "let",
286 proc_macro::Span::call_site(),
287 ))]);
288 match variant {
289 OptionalVariant::Option => {
290 ts.extend([TokenTree::Ident(Ident::new(
291 "Some",
292 proc_macro::Span::call_site(),
293 ))]);
294 }
295 OptionalVariant::Ok => {
296 ts.extend([TokenTree::Ident(Ident::new(
297 "Ok",
298 proc_macro::Span::call_site(),
299 ))]);
300 }
301 OptionalVariant::Err => {
302 ts.extend([TokenTree::Ident(Ident::new(
303 "Err",
304 proc_macro::Span::call_site(),
305 ))]);
306 }
307 OptionalVariant::Required => {
308 // panic!("if_let called with Required variant");
309 }
310 OptionalVariant::Root => {
311 panic!("if_let called with Root variant");
312 }
313 }
314 ts.extend([TokenTree::Group(Group::new(
315 Delimiter::Parenthesis,
316 TokenTree::Ident(Ident::new("____v", proc_macro::Span::call_site())).into(),
317 ))]);
318 ts.extend([TokenTree::Punct(Punct::new('=', Spacing::Alone))]);
319 if is_add_amp {
320 ts.extend([TokenTree::Punct(Punct::new('&', Spacing::Joint))]);
321 }
322 ts.extend(after_eq);
323 ts.extend([TokenTree::Group(Group::new(Delimiter::Brace, body))]);
324 ts.extend([TokenTree::Ident(Ident::new(
325 "else",
326 proc_macro::Span::call_site(),
327 ))]);
328 ts.extend([TokenTree::Group(Group::new(
329 Delimiter::Brace,
330 TokenTree::Ident(Ident::new("None", proc_macro::Span::call_site())).into(),
331 ))]);
332 ts
333}
334
/// How `opt!` unwraps the value of a single chain segment.
///
/// Each variant selects the pattern used in the generated `if let`: a `Some`/`Ok`/`Err`
/// constructor, or a bare binding for `Required`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OptionalVariant {
    /// Initial marker for the first segment; it survives splitting only for an empty input and
    /// is rejected by `if_let`
    Root,
    /// Matched with `Some(..)`: the `?.` operator or a trailing `?`
    Option,
    /// Matched with `Ok(..)`: the `?Ok.` operator or a trailing `?Ok`
    Ok,
    /// Matched with `Err(..)`: the `?Err.` operator or a trailing `?Err`
    Err,
    /// Bound directly without unwrapping: a final segment with no trailing operator
    Required,
}
352
353/// Represents a single segment in the optional chaining expression.
354///
355/// Each segment contains the tokens that make up that part of the chain,
356/// along with the variant indicating how it should be unwrapped.
357#[derive(Debug, Clone)]
358struct OptionalSegment {
359 /// The type of unwrapping operation for this segment
360 pub variant: OptionalVariant,
361 /// The token trees that make up this segment's expression
362 pub tokens: Vec<TokenTree>,
363}
364
365/// Parses the input token stream and splits it into segments based on optional chaining operators.
366///
367/// This function analyzes the input token stream to identify optional chaining operators
368/// (`?.`, `?Ok.`, `?Err.`) and splits the expression into segments, each with its corresponding
369/// variant type. The segments are then used to generate the nested `if let` expressions.
370///
371/// # Arguments
372///
373/// * `input` - The input token stream to parse
374///
375/// # Returns
376///
377/// A vector of `OptionalSegment` structs, where each segment represents a portion of the
378/// chaining expression along with its unwrapping variant
379///
380/// # Example
381///
382/// ```ignore
383/// // Input: user.profile?.address?.city?
384/// // Output: [
385/// // OptionalSegment { variant: Option, tokens: [user, .profile] },
386/// // OptionalSegment { variant: Option, tokens: [address] },
387/// // OptionalSegment { variant: Option, tokens: [city] }
388/// // ]
389/// ```
390fn split_on_optional_variants(input: TokenStream) -> Vec<OptionalSegment> {
391 let input_tokens: Vec<TokenTree> = input.clone().into_iter().collect();
392 let mut iter = input.into_iter().peekable();
393
394 let mut result: Vec<OptionalSegment> = Vec::new();
395 let mut current: Vec<TokenTree> = Vec::new();
396 let mut current_variant = OptionalVariant::Root;
397 while let Some(tt) = iter.next().as_ref() {
398 match &tt {
399 TokenTree::Punct(q) if q.as_char() == '?' => {
400 // Try to detect ?. / ?Ok. / ?Err.
401 let variant = match iter.peek() {
402 Some(TokenTree::Punct(dot)) if dot.as_char() == '.' => {
403 iter.next(); // consume '.'
404 Some(OptionalVariant::Option)
405 }
406
407 Some(TokenTree::Ident(ident))
408 if ident.to_string() == "Ok" || ident.to_string() == "Err" =>
409 {
410 let ident = ident.clone();
411 let v = if ident.to_string() == "Ok" {
412 OptionalVariant::Ok
413 } else {
414 OptionalVariant::Err
415 };
416
417 // consume Ident
418 iter.next();
419
420 // require trailing '.'
421 match &iter.next() {
422 Some(TokenTree::Punct(dot)) if dot.as_char() == '.' => Some(v),
423 other => {
424 // rollback-ish: treat as normal tokens
425 if let Some(o) = other {
426 current.push(o.clone());
427 }
428 None
429 }
430 }
431 }
432
433 _ => None,
434 };
435
436 if let Some(v) = variant {
437 if !current.is_empty() {
438 result.push(OptionalSegment {
439 variant: current_variant,
440 tokens: std::mem::take(&mut current),
441 });
442 }
443
444 current_variant = v;
445 continue;
446 }
447
448 // Not a recognized optional-chain operator
449 }
450
451 _ => current.push(tt.clone()),
452 }
453 }
454
455 result.push(OptionalSegment {
456 variant: current_variant,
457 tokens: current,
458 });
459
460 for i in 0..result.len() - 1 {
461 result[i].variant = result[i + 1].variant.clone();
462 }
463
464 // dbg!(last_token.to_string());
465 if input_tokens.last().is_none() {
466 return result;
467 }
468 let result_len = result.len();
469 match input_tokens.last().unwrap() {
470 TokenTree::Punct(p) if p.as_char() == '?' => {
471 result[result_len - 1].variant = OptionalVariant::Option;
472 }
473 TokenTree::Ident(p) if p.to_string() == "Ok" => {
474 result[result_len - 1].variant = OptionalVariant::Ok;
475 }
476 TokenTree::Ident(p) if p.to_string() == "Err" => {
477 result[result_len - 1].variant = OptionalVariant::Err;
478 }
479 _ => {
480 result[result_len - 1].variant = OptionalVariant::Required;
481 }
482 }
483 result
484}