Skip to main content

cuda_rust_wasm/parser/
cuda_parser.rs

1//! CUDA source code parser using nom combinators
2//!
3//! Parses a subset of CUDA C++ sufficient for common GPU kernels.
4
5use nom::{
6    IResult,
7    branch::alt,
8    bytes::complete::{tag, take_while, take_while1, take_until},
9    character::complete::{char, multispace0, multispace1, digit1, alpha1, one_of},
10    combinator::{opt, map, value, recognize},
11    multi::{separated_list0, many1},
12    sequence::{pair, tuple, delimited, preceded},
13};
14
15use crate::{Result, parse_error};
16use super::ast::*;
17
18// ═══════════════════════════════════════════════════════════════
19//  Utility combinators
20// ═══════════════════════════════════════════════════════════════
21
22/// Whitespace + comment skipper
23fn ws(input: &str) -> IResult<&str, ()> {
24    let mut rest = input;
25    loop {
26        let (r, _) = multispace0(rest)?;
27        rest = r;
28        if rest.starts_with("//") {
29            let end = rest.find('\n').unwrap_or(rest.len());
30            rest = &rest[end..];
31        } else if rest.starts_with("/*") {
32            if let Some(end) = rest.find("*/") {
33                rest = &rest[end + 2..];
34            } else {
35                return Err(nom::Err::Error(nom::error::Error::new(rest, nom::error::ErrorKind::Tag)));
36            }
37        } else {
38            break;
39        }
40    }
41    Ok((rest, ()))
42}
43
44/// Parse with surrounding whitespace/comments
45fn ws_around<'a, F, O>(inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
46where
47    F: FnMut(&'a str) -> IResult<&'a str, O>,
48{
49    delimited(ws, inner, ws)
50}
51
52/// Parse an identifier: [a-zA-Z_][a-zA-Z0-9_]*
53fn identifier(input: &str) -> IResult<&str, &str> {
54    recognize(pair(
55        alt((alpha1, tag("_"))),
56        take_while(|c: char| c.is_alphanumeric() || c == '_'),
57    ))(input)
58}
59
60/// Parse an identifier with surrounding whitespace
61fn ws_ident(input: &str) -> IResult<&str, &str> {
62    let (input, _) = ws(input)?;
63    identifier(input)
64}
65
66/// Parse a specific tag with surrounding whitespace
67fn ws_tag<'a>(t: &'a str) -> impl FnMut(&'a str) -> IResult<&'a str, &'a str> {
68    delimited(ws, tag(t), ws)
69}
70
71/// Helper to call `tag` with explicit type annotation, avoiding turbofish issues
72fn t<'a>(s: &'a str) -> impl FnMut(&'a str) -> IResult<&'a str, &'a str> {
73    tag(s)
74}
75
76/// Check that the character after a keyword is not alphanumeric/underscore
77fn keyword<'a>(kw: &'a str) -> impl FnMut(&'a str) -> IResult<&'a str, &'a str> {
78    move |input: &'a str| {
79        let (rest, matched) = tag(kw)(input)?;
80        // Make sure it's not just a prefix of a longer identifier
81        if let Some(c) = rest.chars().next() {
82            if c.is_alphanumeric() || c == '_' {
83                return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
84            }
85        }
86        Ok((rest, matched))
87    }
88}
89
90// ═══════════════════════════════════════════════════════════════
91//  Literal parsing
92// ═══════════════════════════════════════════════════════════════
93
94fn parse_float_literal(input: &str) -> IResult<&str, Expression> {
95    // Try: digits.digits[e/E[+-]digits][f/F]  OR  .digits[e/E[+-]digits][f/F]  OR  digits e/E [+-] digits [f/F]  OR  digits f/F
96    let (rest, text) = recognize(alt((
97        // 1.0, 1.0f, 1.0e5, 1.0e5f, 1., 1.f
98        recognize(tuple((
99            digit1,
100            char('.'),
101            opt(digit1),
102            opt(recognize(tuple((one_of("eE"), opt(one_of("+-")), digit1)))),
103            opt(one_of("fF")),
104        ))),
105        // .5, .5f, .5e3
106        recognize(tuple((
107            char('.'),
108            digit1,
109            opt(recognize(tuple((one_of("eE"), opt(one_of("+-")), digit1)))),
110            opt(one_of("fF")),
111        ))),
112        // 1e5, 1e5f
113        recognize(tuple((
114            digit1,
115            one_of("eE"),
116            opt(one_of("+-")),
117            digit1,
118            opt(one_of("fF")),
119        ))),
120        // 1f  (integer with f suffix = float)
121        recognize(tuple((digit1, one_of("fF")))),
122    )))(input)?;
123
124    let clean = text.trim_end_matches(|c| c == 'f' || c == 'F');
125    let val: f64 = clean.parse().unwrap_or(0.0);
126    Ok((rest, Expression::Literal(Literal::Float(val))))
127}
128
129fn parse_hex_literal(input: &str) -> IResult<&str, Expression> {
130    let (rest, text) = recognize(tuple((
131        tag("0"),
132        one_of("xX"),
133        take_while1(|c: char| c.is_ascii_hexdigit()),
134        opt(one_of("uUlL")),
135    )))(input)?;
136    let clean = text.trim_start_matches("0x").trim_start_matches("0X");
137    let clean = clean.trim_end_matches(|c| c == 'u' || c == 'U' || c == 'l' || c == 'L');
138    let val = u64::from_str_radix(clean, 16).unwrap_or(0);
139    if text.contains('u') || text.contains('U') {
140        Ok((rest, Expression::Literal(Literal::UInt(val))))
141    } else {
142        Ok((rest, Expression::Literal(Literal::Int(val as i64))))
143    }
144}
145
146fn parse_int_literal(input: &str) -> IResult<&str, Expression> {
147    let (rest, text) = recognize(pair(
148        digit1,
149        opt(recognize(many1(one_of("uUlL")))),
150    ))(input)?;
151    let clean = text.trim_end_matches(|c| c == 'u' || c == 'U' || c == 'l' || c == 'L');
152    if text.contains('u') || text.contains('U') {
153        let val: u64 = clean.parse().unwrap_or(0);
154        Ok((rest, Expression::Literal(Literal::UInt(val))))
155    } else {
156        let val: i64 = clean.parse().unwrap_or(0);
157        Ok((rest, Expression::Literal(Literal::Int(val))))
158    }
159}
160
161fn parse_literal(input: &str) -> IResult<&str, Expression> {
162    alt((
163        parse_hex_literal,
164        parse_float_literal,
165        parse_int_literal,
166    ))(input)
167}
168
169// ═══════════════════════════════════════════════════════════════
170//  Type parsing
171// ═══════════════════════════════════════════════════════════════
172
173fn parse_base_type(input: &str) -> IResult<&str, Type> {
174    let (input, _) = ws(input)?;
175    alt((
176        value(Type::Void, keyword("void")),
177        value(Type::Bool, keyword("bool")),
178        // Vector types before scalar to avoid partial match
179        value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F32)), size: 2 }), keyword("float2")),
180        value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F32)), size: 3 }), keyword("float3")),
181        value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F32)), size: 4 }), keyword("float4")),
182        value(Type::Vector(VectorType { element: Box::new(Type::Int(IntType::I32)), size: 2 }), keyword("int2")),
183        value(Type::Vector(VectorType { element: Box::new(Type::Int(IntType::I32)), size: 3 }), keyword("int3")),
184        value(Type::Vector(VectorType { element: Box::new(Type::Int(IntType::I32)), size: 4 }), keyword("int4")),
185        value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F64)), size: 2 }), keyword("double2")),
186        value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F64)), size: 3 }), keyword("double3")),
187        value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F64)), size: 4 }), keyword("double4")),
188        value(Type::Named("dim3".to_string()), keyword("dim3")),
189        value(Type::Float(FloatType::F64), keyword("double")),
190        value(Type::Float(FloatType::F32), keyword("float")),
191        // "unsigned int", "unsigned" alone
192        map(preceded(keyword("unsigned"), opt(preceded(multispace1, keyword("int")))), |_| Type::Int(IntType::U32)),
193        // "long long"
194        map(pair(keyword("long"), opt(preceded(multispace1, keyword("long")))), |(_, ll)| {
195            if ll.is_some() { Type::Int(IntType::I64) } else { Type::Int(IntType::I64) }
196        }),
197        value(Type::Int(IntType::I16), keyword("short")),
198        value(Type::Int(IntType::I8), keyword("char")),
199        value(Type::Int(IntType::I32), keyword("int")),
200        // size_t mapped to U64
201        value(Type::Int(IntType::U64), keyword("size_t")),
202        // Named/user-defined type (fallback)
203        map(identifier, |name: &str| Type::Named(name.to_string())),
204    ))(input)
205}
206
207/// Parse a full type including const, pointer, and array suffixes
208fn parse_type(input: &str) -> IResult<&str, (Type, Vec<ParamQualifier>)> {
209    let (input, _) = ws(input)?;
210
211    // Collect leading qualifiers: const, volatile, __restrict__
212    let mut qualifiers = Vec::new();
213    let mut rest = input;
214    loop {
215        let (r, _) = ws(rest)?;
216        if let Ok((r2, _)) = keyword("const")(r) {
217            qualifiers.push(ParamQualifier::Const);
218            rest = r2;
219        } else if let Ok((r2, _)) = keyword("volatile")(r) {
220            qualifiers.push(ParamQualifier::Volatile);
221            rest = r2;
222        } else if let Ok((r2, _)) = t("__restrict__")(r) {
223            qualifiers.push(ParamQualifier::Restrict);
224            rest = r2;
225        } else {
226            break;
227        }
228    }
229
230    // Parse the base type
231    let (rest, mut ty) = parse_base_type(rest)?;
232
233    // Trailing qualifiers/pointers
234    let mut rest = rest;
235    loop {
236        let (r, _) = ws(rest)?;
237        if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('*')(r) {
238            ty = Type::Pointer(Box::new(ty));
239            rest = r2;
240            // After *, may have const/__restrict__
241            let (r3, _) = ws(rest)?;
242            if let Ok((r4, _)) = keyword("const")(r3) {
243                qualifiers.push(ParamQualifier::Const);
244                rest = r4;
245            } else if let Ok((r4, _)) = t("__restrict__")(r3) {
246                qualifiers.push(ParamQualifier::Restrict);
247                rest = r4;
248            } else if let Ok((r4, _)) = keyword("restrict")(r3) {
249                qualifiers.push(ParamQualifier::Restrict);
250                rest = r4;
251            } else {
252                rest = r3;
253            }
254        } else {
255            rest = r;
256            break;
257        }
258    }
259
260    Ok((rest, (ty, qualifiers)))
261}
262
263// ═══════════════════════════════════════════════════════════════
264//  Expression parsing (precedence climbing)
265// ═══════════════════════════════════════════════════════════════
266
267/// Primary expression: literals, variables, parenthesised, casts, CUDA builtins
268fn parse_primary(input: &str) -> IResult<&str, Expression> {
269    let (input, _) = ws(input)?;
270    alt((
271        parse_cuda_builtin,
272        parse_sizeof_expr,
273        parse_cast_or_paren,
274        parse_literal,
275        parse_ident_or_call,
276    ))(input)
277}
278
279/// Parse sizeof(type) or sizeof(expr)
280fn parse_sizeof_expr(input: &str) -> IResult<&str, Expression> {
281    let (input, _) = keyword("sizeof")(input)?;
282    let (input, _) = ws(input)?;
283    let (input, _) = char('(')(input)?;
284    let (input, _) = ws(input)?;
285    // Try to parse a type first, fall back to expression
286    if let Ok((rest, (ty, _))) = parse_type(input) {
287        let (rest, _) = ws(rest)?;
288        let (rest, _) = char(')')(rest)?;
289        // Return as a call for simplicity
290        Ok((rest, Expression::Call {
291            name: "sizeof".to_string(),
292            args: vec![Expression::Var(format!("{:?}", ty))],
293        }))
294    } else {
295        let (input, expr) = parse_expr(input)?;
296        let (input, _) = ws(input)?;
297        let (input, _) = char(')')(input)?;
298        Ok((input, Expression::Call {
299            name: "sizeof".to_string(),
300            args: vec![expr],
301        }))
302    }
303}
304
305/// threadIdx.x, blockIdx.y, blockDim.z, gridDim.x
306fn parse_cuda_builtin(input: &str) -> IResult<&str, Expression> {
307    let (input, builtin) = alt((
308        tag("threadIdx"),
309        tag("blockIdx"),
310        tag("blockDim"),
311        tag("gridDim"),
312    ))(input)?;
313    // Ensure not part of a longer ident
314    if let Some(c) = input.chars().next() {
315        if c.is_alphanumeric() || c == '_' {
316            return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
317        }
318    }
319    let (input, _) = ws(input)?;
320    let (input, _) = char('.')(input)?;
321    let (input, _) = ws(input)?;
322    let (input, dim_str) = alt((tag("x"), tag("y"), tag("z")))(input)?;
323    let dim = match dim_str {
324        "x" => Dimension::X,
325        "y" => Dimension::Y,
326        "z" => Dimension::Z,
327        _ => unreachable!(),
328    };
329    let expr = match builtin {
330        "threadIdx" => Expression::ThreadIdx(dim),
331        "blockIdx" => Expression::BlockIdx(dim),
332        "blockDim" => Expression::BlockDim(dim),
333        "gridDim" => Expression::GridDim(dim),
334        _ => unreachable!(),
335    };
336    Ok((input, expr))
337}
338
339/// Try cast `(type)expr` or parenthesised expression `(expr)`
340fn parse_cast_or_paren(input: &str) -> IResult<&str, Expression> {
341    let (input, _) = char('(')(input)?;
342    let (input, _) = ws(input)?;
343
344    // Try to parse as a type cast: (type)expr
345    // We speculatively try parsing a type. If it succeeds and is immediately
346    // followed by ')', treat it as a cast.
347    let checkpoint = input;
348    if let Ok((after_ty, (ty, _))) = parse_type(checkpoint) {
349        let (after_ty, _) = ws(after_ty)?;
350        if let Ok((after_close, _)) = char::<&str, nom::error::Error<&str>>(')')(after_ty) {
351            // Check that it's actually a cast (what follows looks like an expression start)
352            let (peek_rest, _) = ws(after_close)?;
353            let looks_like_expr = peek_rest.starts_with('(')
354                || peek_rest.starts_with(|c: char| c.is_alphanumeric() || c == '_' || c == '-' || c == '!' || c == '~' || c == '.');
355            if looks_like_expr {
356                // It's a cast only if the type is a real type (not just a variable name being subtracted)
357                let is_real_type = matches!(ty,
358                    Type::Void | Type::Bool | Type::Int(_) | Type::Float(_) | Type::Pointer(_)
359                    | Type::Vector(_) | Type::Array(_, _));
360                if is_real_type {
361                    let (rest, expr) = parse_unary(after_close)?;
362                    return Ok((rest, Expression::Cast { ty, expr: Box::new(expr) }));
363                }
364            }
365        }
366    }
367
368    // Otherwise, parenthesised expression
369    let (input, expr) = parse_expr(checkpoint)?;
370    let (input, _) = ws(input)?;
371    let (input, _) = char(')')(input)?;
372    Ok((input, expr))
373}
374
375/// Identifier, function call, or __syncthreads()
376fn parse_ident_or_call(input: &str) -> IResult<&str, Expression> {
377    // __syncthreads()
378    if let Ok((rest, _)) = tag::<&str, &str, nom::error::Error<&str>>("__syncthreads")(input) {
379        let (rest, _) = ws(rest)?;
380        if let Ok((rest, _)) = char::<&str, nom::error::Error<&str>>('(')(rest) {
381            let (rest, _) = ws(rest)?;
382            let (rest, _) = char(')')(rest)?;
383            // We'll handle this as a special call that gets turned into SyncThreads statement
384            return Ok((rest, Expression::Call { name: "__syncthreads".to_string(), args: vec![] }));
385        }
386    }
387
388    let (input, name) = identifier(input)?;
389    let (input, _) = ws(input)?;
390
391    // Check for function call
392    if let Ok((rest, _)) = char::<&str, nom::error::Error<&str>>('(')(input) {
393        let (rest, _) = ws(rest)?;
394        let (rest, args) = separated_list0(
395            delimited(ws, char(','), ws),
396            parse_expr,
397        )(rest)?;
398        let (rest, _) = ws(rest)?;
399        let (rest, _) = char(')')(rest)?;
400
401        // Detect warp primitives
402        let expr = match name {
403            "__shfl_sync" => Expression::WarpPrimitive { op: WarpOp::Shuffle, args },
404            "__shfl_xor_sync" => Expression::WarpPrimitive { op: WarpOp::ShuffleXor, args },
405            "__shfl_up_sync" => Expression::WarpPrimitive { op: WarpOp::ShuffleUp, args },
406            "__shfl_down_sync" => Expression::WarpPrimitive { op: WarpOp::ShuffleDown, args },
407            "__ballot_sync" => Expression::WarpPrimitive { op: WarpOp::Ballot, args },
408            "__activemask" => Expression::WarpPrimitive { op: WarpOp::ActiveMask, args },
409            _ => Expression::Call { name: name.to_string(), args },
410        };
411        return Ok((rest, expr));
412    }
413
414    Ok((input, Expression::Var(name.to_string())))
415}
416
417/// Postfix: a[i], a.field, a->field, a++, a--
418fn parse_postfix(input: &str) -> IResult<&str, Expression> {
419    let (mut rest, mut expr) = parse_primary(input)?;
420
421    loop {
422        let (r, _) = ws(rest)?;
423
424        // Array index: expr[index]
425        if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('[')(r) {
426            let (r2, _) = ws(r2)?;
427            let (r2, index) = parse_expr(r2)?;
428            let (r2, _) = ws(r2)?;
429            let (r2, _) = char(']')(r2)?;
430            expr = Expression::Index {
431                array: Box::new(expr),
432                index: Box::new(index),
433            };
434            rest = r2;
435            continue;
436        }
437
438        // Member access: expr.field (but not after CUDA builtins which already consumed the dot)
439        if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('.')(r) {
440            if let Ok((r3, field)) = identifier(r2) {
441                // Make sure it's not a float literal like ".5"
442                expr = Expression::Member {
443                    object: Box::new(expr),
444                    field: field.to_string(),
445                };
446                rest = r3;
447                continue;
448            }
449        }
450
451        // Arrow: expr->field
452        if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("->")(r) {
453            let (r3, field) = identifier(r2)?;
454            expr = Expression::Member {
455                object: Box::new(expr),
456                field: field.to_string(),
457            };
458            rest = r3;
459            continue;
460        }
461
462        // Post-increment: expr++
463        if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("++")(r) {
464            expr = Expression::Unary {
465                op: UnaryOp::PostInc,
466                expr: Box::new(expr),
467            };
468            rest = r2;
469            continue;
470        }
471
472        // Post-decrement: expr--
473        if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("--")(r) {
474            expr = Expression::Unary {
475                op: UnaryOp::PostDec,
476                expr: Box::new(expr),
477            };
478            rest = r2;
479            continue;
480        }
481
482        rest = r;
483        break;
484    }
485
486    Ok((rest, expr))
487}
488
489/// Unary prefix: ++x, --x, -x, !x, ~x, *x, &x
490fn parse_unary(input: &str) -> IResult<&str, Expression> {
491    let (input, _) = ws(input)?;
492
493    // Pre-increment
494    if let Ok((rest, _)) = tag::<&str, &str, nom::error::Error<&str>>("++")(input) {
495        let (rest, expr) = parse_unary(rest)?;
496        return Ok((rest, Expression::Unary { op: UnaryOp::PreInc, expr: Box::new(expr) }));
497    }
498    // Pre-decrement
499    if let Ok((rest, _)) = tag::<&str, &str, nom::error::Error<&str>>("--")(input) {
500        let (rest, expr) = parse_unary(rest)?;
501        return Ok((rest, Expression::Unary { op: UnaryOp::PreDec, expr: Box::new(expr) }));
502    }
503    // Unary minus (not --> )
504    if input.starts_with('-') && !input.starts_with("--") && !input.starts_with("->") {
505        let (rest, _) = char('-')(input)?;
506        let (rest, expr) = parse_unary(rest)?;
507        return Ok((rest, Expression::Unary { op: UnaryOp::Neg, expr: Box::new(expr) }));
508    }
509    // Logical NOT
510    if input.starts_with('!') && !input.starts_with("!=") {
511        let (rest, _) = char('!')(input)?;
512        let (rest, expr) = parse_unary(rest)?;
513        return Ok((rest, Expression::Unary { op: UnaryOp::Not, expr: Box::new(expr) }));
514    }
515    // Bitwise NOT
516    if input.starts_with('~') {
517        let (rest, _) = char('~')(input)?;
518        let (rest, expr) = parse_unary(rest)?;
519        return Ok((rest, Expression::Unary { op: UnaryOp::BitNot, expr: Box::new(expr) }));
520    }
521    // Dereference
522    if input.starts_with('*') && !input.starts_with("*=") {
523        let (rest, _) = char('*')(input)?;
524        let (rest, expr) = parse_unary(rest)?;
525        return Ok((rest, Expression::Unary { op: UnaryOp::Deref, expr: Box::new(expr) }));
526    }
527    // Address-of
528    if input.starts_with('&') && !input.starts_with("&&") && !input.starts_with("&=") {
529        let (rest, _) = char('&')(input)?;
530        let (rest, expr) = parse_unary(rest)?;
531        return Ok((rest, Expression::Unary { op: UnaryOp::AddrOf, expr: Box::new(expr) }));
532    }
533
534    parse_postfix(input)
535}
536
/// Entry point for parsing a complete expression.
///
/// Delegates directly to `parse_assignment`, which in turn descends through one
/// parser function per operator level (`parse_ternary`, `parse_logical_or`, ...).
fn parse_expr(input: &str) -> IResult<&str, Expression> {
    parse_assignment(input)
}
541
542fn parse_assignment(input: &str) -> IResult<&str, Expression> {
543    let (mut rest, mut left) = parse_ternary(input)?;
544
545    loop {
546        let (r, _) = ws(rest)?;
547
548        // Try compound assignments first (longer tokens before shorter)
549        let compound_op: Option<(usize, BinaryOp)> = if r.starts_with("<<=") {
550            Some((3, BinaryOp::Shl))
551        } else if r.starts_with(">>=") {
552            Some((3, BinaryOp::Shr))
553        } else if r.starts_with("+=") {
554            Some((2, BinaryOp::Add))
555        } else if r.starts_with("-=") {
556            Some((2, BinaryOp::Sub))
557        } else if r.starts_with("*=") {
558            Some((2, BinaryOp::Mul))
559        } else if r.starts_with("/=") {
560            Some((2, BinaryOp::Div))
561        } else if r.starts_with("%=") {
562            Some((2, BinaryOp::Mod))
563        } else if r.starts_with("&=") {
564            Some((2, BinaryOp::And))
565        } else if r.starts_with("|=") {
566            Some((2, BinaryOp::Or))
567        } else if r.starts_with("^=") {
568            Some((2, BinaryOp::Xor))
569        } else {
570            None
571        };
572
573        if let Some((len, op)) = compound_op {
574            let r2 = &r[len..];
575            let (r2, right) = parse_assignment(r2)?;
576            // Desugar: a += b  =>  a = a + b
577            left = Expression::Binary {
578                op: BinaryOp::Assign,
579                left: Box::new(left.clone()),
580                right: Box::new(Expression::Binary {
581                    op,
582                    left: Box::new(left),
583                    right: Box::new(right),
584                }),
585            };
586            rest = r2;
587            continue;
588        }
589
590        // Simple assignment: = but not ==
591        if r.starts_with('=') && !r.starts_with("==") {
592            let r2 = &r[1..];
593            let (r2, right) = parse_assignment(r2)?;
594            left = Expression::Binary { op: BinaryOp::Assign, left: Box::new(left), right: Box::new(right) };
595            rest = r2;
596            continue;
597        }
598
599        break;
600    }
601
602    Ok((rest, left))
603}
604
605fn parse_ternary(input: &str) -> IResult<&str, Expression> {
606    let (rest, cond) = parse_logical_or(input)?;
607    let (r, _) = ws(rest)?;
608    if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('?')(r) {
609        let (r2, then_expr) = parse_expr(r2)?;
610        let (r2, _) = ws(r2)?;
611        let (r2, _) = char(':')(r2)?;
612        let (r2, else_expr) = parse_ternary(r2)?;
613        // Represent ternary as an if-like construct using Call for now
614        // Actually, let's return it as a special call  __ternary__(cond, then, else)
615        // The AST doesn't have ternary, so we use a synthetic representation
616        Ok((r2, Expression::Call {
617            name: "__ternary__".to_string(),
618            args: vec![cond, then_expr, else_expr],
619        }))
620    } else {
621        Ok((rest, cond))
622    }
623}
624
625// ── Binary operator levels ──────────────────────────────────
626
627fn parse_logical_or(input: &str) -> IResult<&str, Expression> {
628    let (mut rest, mut left) = parse_logical_and(input)?;
629    loop {
630        let (r, _) = ws(rest)?;
631        if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("||")(r) {
632            let (r2, right) = parse_logical_and(r2)?;
633            left = Expression::Binary { op: BinaryOp::LogicalOr, left: Box::new(left), right: Box::new(right) };
634            rest = r2;
635        } else {
636            rest = r;
637            break;
638        }
639    }
640    Ok((rest, left))
641}
642
643fn parse_logical_and(input: &str) -> IResult<&str, Expression> {
644    let (mut rest, mut left) = parse_bitwise_or(input)?;
645    loop {
646        let (r, _) = ws(rest)?;
647        if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("&&")(r) {
648            let (r2, right) = parse_bitwise_or(r2)?;
649            left = Expression::Binary { op: BinaryOp::LogicalAnd, left: Box::new(left), right: Box::new(right) };
650            rest = r2;
651        } else {
652            rest = r;
653            break;
654        }
655    }
656    Ok((rest, left))
657}
658
659fn parse_bitwise_or(input: &str) -> IResult<&str, Expression> {
660    let (mut rest, mut left) = parse_bitwise_xor(input)?;
661    loop {
662        let (r, _) = ws(rest)?;
663        // | but not ||
664        if r.starts_with('|') && !r.starts_with("||") && !r.starts_with("|=") {
665            let (r2, _) = char('|')(r)?;
666            let (r2, right) = parse_bitwise_xor(r2)?;
667            left = Expression::Binary { op: BinaryOp::Or, left: Box::new(left), right: Box::new(right) };
668            rest = r2;
669        } else {
670            rest = r;
671            break;
672        }
673    }
674    Ok((rest, left))
675}
676
677fn parse_bitwise_xor(input: &str) -> IResult<&str, Expression> {
678    let (mut rest, mut left) = parse_bitwise_and(input)?;
679    loop {
680        let (r, _) = ws(rest)?;
681        if r.starts_with('^') && !r.starts_with("^=") {
682            let (r2, _) = char('^')(r)?;
683            let (r2, right) = parse_bitwise_and(r2)?;
684            left = Expression::Binary { op: BinaryOp::Xor, left: Box::new(left), right: Box::new(right) };
685            rest = r2;
686        } else {
687            rest = r;
688            break;
689        }
690    }
691    Ok((rest, left))
692}
693
694fn parse_bitwise_and(input: &str) -> IResult<&str, Expression> {
695    let (mut rest, mut left) = parse_equality(input)?;
696    loop {
697        let (r, _) = ws(rest)?;
698        // & but not && or &=
699        if r.starts_with('&') && !r.starts_with("&&") && !r.starts_with("&=") {
700            let (r2, _) = char('&')(r)?;
701            let (r2, right) = parse_equality(r2)?;
702            left = Expression::Binary { op: BinaryOp::And, left: Box::new(left), right: Box::new(right) };
703            rest = r2;
704        } else {
705            rest = r;
706            break;
707        }
708    }
709    Ok((rest, left))
710}
711
712fn parse_equality(input: &str) -> IResult<&str, Expression> {
713    let (mut rest, mut left) = parse_relational(input)?;
714    loop {
715        let (r, _) = ws(rest)?;
716        if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("==")(r) {
717            let (r2, right) = parse_relational(r2)?;
718            left = Expression::Binary { op: BinaryOp::Eq, left: Box::new(left), right: Box::new(right) };
719            rest = r2;
720        } else if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("!=")(r) {
721            let (r2, right) = parse_relational(r2)?;
722            left = Expression::Binary { op: BinaryOp::Ne, left: Box::new(left), right: Box::new(right) };
723            rest = r2;
724        } else {
725            rest = r;
726            break;
727        }
728    }
729    Ok((rest, left))
730}
731
732fn parse_relational(input: &str) -> IResult<&str, Expression> {
733    let (mut rest, mut left) = parse_shift(input)?;
734    loop {
735        let (r, _) = ws(rest)?;
736        if r.starts_with("<=") {
737            let r2 = &r[2..];
738            let (r2, right) = parse_shift(r2)?;
739            left = Expression::Binary { op: BinaryOp::Le, left: Box::new(left), right: Box::new(right) };
740            rest = r2;
741        } else if r.starts_with(">=") {
742            let r2 = &r[2..];
743            let (r2, right) = parse_shift(r2)?;
744            left = Expression::Binary { op: BinaryOp::Ge, left: Box::new(left), right: Box::new(right) };
745            rest = r2;
746        } else if r.starts_with('<') && !r.starts_with("<<") {
747            let r2 = &r[1..];
748            let (r2, right) = parse_shift(r2)?;
749            left = Expression::Binary { op: BinaryOp::Lt, left: Box::new(left), right: Box::new(right) };
750            rest = r2;
751        } else if r.starts_with('>') && !r.starts_with(">>") {
752            let r2 = &r[1..];
753            let (r2, right) = parse_shift(r2)?;
754            left = Expression::Binary { op: BinaryOp::Gt, left: Box::new(left), right: Box::new(right) };
755            rest = r2;
756        } else {
757            rest = r;
758            break;
759        }
760    }
761    Ok((rest, left))
762}
763
764fn parse_shift(input: &str) -> IResult<&str, Expression> {
765    let (mut rest, mut left) = parse_additive(input)?;
766    loop {
767        let (r, _) = ws(rest)?;
768        if r.starts_with("<<=") {
769            rest = r;
770            break;
771        } else if r.starts_with(">>=") {
772            rest = r;
773            break;
774        } else if r.starts_with("<<") {
775            let r2 = &r[2..];
776            let (r2, right) = parse_additive(r2)?;
777            left = Expression::Binary { op: BinaryOp::Shl, left: Box::new(left), right: Box::new(right) };
778            rest = r2;
779        } else if r.starts_with(">>") {
780            let r2 = &r[2..];
781            let (r2, right) = parse_additive(r2)?;
782            left = Expression::Binary { op: BinaryOp::Shr, left: Box::new(left), right: Box::new(right) };
783            rest = r2;
784        } else {
785            rest = r;
786            break;
787        }
788    }
789    Ok((rest, left))
790}
791
792fn parse_additive(input: &str) -> IResult<&str, Expression> {
793    let (mut rest, mut left) = parse_multiplicative(input)?;
794    loop {
795        let (r, _) = ws(rest)?;
796        if r.starts_with('+') && !r.starts_with("++") && !r.starts_with("+=") {
797            let (r2, _) = char('+')(r)?;
798            let (r2, right) = parse_multiplicative(r2)?;
799            left = Expression::Binary { op: BinaryOp::Add, left: Box::new(left), right: Box::new(right) };
800            rest = r2;
801        } else if r.starts_with('-') && !r.starts_with("--") && !r.starts_with("-=") && !r.starts_with("->") {
802            let (r2, _) = char('-')(r)?;
803            let (r2, right) = parse_multiplicative(r2)?;
804            left = Expression::Binary { op: BinaryOp::Sub, left: Box::new(left), right: Box::new(right) };
805            rest = r2;
806        } else {
807            rest = r;
808            break;
809        }
810    }
811    Ok((rest, left))
812}
813
814fn parse_multiplicative(input: &str) -> IResult<&str, Expression> {
815    let (mut rest, mut left) = parse_unary(input)?;
816    loop {
817        let (r, _) = ws(rest)?;
818        if r.starts_with('*') && !r.starts_with("*=") {
819            let (r2, _) = char('*')(r)?;
820            let (r2, right) = parse_unary(r2)?;
821            left = Expression::Binary { op: BinaryOp::Mul, left: Box::new(left), right: Box::new(right) };
822            rest = r2;
823        } else if r.starts_with('/') && !r.starts_with("/=") && !r.starts_with("//") && !r.starts_with("/*") {
824            let (r2, _) = char('/')(r)?;
825            let (r2, right) = parse_unary(r2)?;
826            left = Expression::Binary { op: BinaryOp::Div, left: Box::new(left), right: Box::new(right) };
827            rest = r2;
828        } else if r.starts_with('%') && !r.starts_with("%=") {
829            let (r2, _) = char('%')(r)?;
830            let (r2, right) = parse_unary(r2)?;
831            left = Expression::Binary { op: BinaryOp::Mod, left: Box::new(left), right: Box::new(right) };
832            rest = r2;
833        } else {
834            rest = r;
835            break;
836        }
837    }
838    Ok((rest, left))
839}
840
841// ═══════════════════════════════════════════════════════════════
842//  Statement parsing
843// ═══════════════════════════════════════════════════════════════
844
845fn parse_block(input: &str) -> IResult<&str, Block> {
846    let (input, _) = ws(input)?;
847    let (input, _) = char('{')(input)?;
848    let (input, stmts) = parse_statement_list(input)?;
849    let (input, _) = ws(input)?;
850    let (input, _) = char('}')(input)?;
851    Ok((input, Block { statements: stmts }))
852}
853
854fn parse_statement_list(input: &str) -> IResult<&str, Vec<Statement>> {
855    let mut stmts = Vec::new();
856    let mut rest = input;
857    loop {
858        let (r, _) = ws(rest)?;
859        if r.starts_with('}') || r.is_empty() {
860            rest = r;
861            break;
862        }
863        match parse_statement(r) {
864            Ok((r2, stmt)) => {
865                stmts.push(stmt);
866                rest = r2;
867            }
868            Err(_) => {
869                // Skip unrecognized token and try again (error recovery)
870                if let Some(pos) = r.find(|c: char| c == ';' || c == '}') {
871                    if r.as_bytes()[pos] == b';' {
872                        rest = &r[pos + 1..];
873                    } else {
874                        rest = &r[pos..];
875                    }
876                } else {
877                    break;
878                }
879            }
880        }
881    }
882    Ok((rest, stmts))
883}
884
885fn parse_statement(input: &str) -> IResult<&str, Statement> {
886    let (input, _) = ws(input)?;
887    alt((
888        parse_syncthreads_stmt,
889        parse_return_stmt,
890        parse_break_stmt,
891        parse_continue_stmt,
892        parse_if_stmt,
893        parse_for_stmt,
894        parse_while_stmt,
895        parse_do_while_stmt,
896        parse_block_stmt,
897        parse_var_decl_stmt,
898        parse_expr_stmt,
899    ))(input)
900}
901
902fn parse_syncthreads_stmt(input: &str) -> IResult<&str, Statement> {
903    let (input, _) = tag("__syncthreads")(input)?;
904    let (input, _) = ws(input)?;
905    let (input, _) = char('(')(input)?;
906    let (input, _) = ws(input)?;
907    let (input, _) = char(')')(input)?;
908    let (input, _) = ws(input)?;
909    let (input, _) = char(';')(input)?;
910    Ok((input, Statement::SyncThreads))
911}
912
913fn parse_return_stmt(input: &str) -> IResult<&str, Statement> {
914    let (input, _) = keyword("return")(input)?;
915    let (input, _) = ws(input)?;
916    if let Ok((rest, _)) = char::<&str, nom::error::Error<&str>>(';')(input) {
917        return Ok((rest, Statement::Return(None)));
918    }
919    let (input, expr) = parse_expr(input)?;
920    let (input, _) = ws(input)?;
921    let (input, _) = char(';')(input)?;
922    Ok((input, Statement::Return(Some(expr))))
923}
924
925fn parse_break_stmt(input: &str) -> IResult<&str, Statement> {
926    let (input, _) = keyword("break")(input)?;
927    let (input, _) = ws(input)?;
928    let (input, _) = char(';')(input)?;
929    Ok((input, Statement::Break))
930}
931
932fn parse_continue_stmt(input: &str) -> IResult<&str, Statement> {
933    let (input, _) = keyword("continue")(input)?;
934    let (input, _) = ws(input)?;
935    let (input, _) = char(';')(input)?;
936    Ok((input, Statement::Continue))
937}
938
939fn parse_if_stmt(input: &str) -> IResult<&str, Statement> {
940    let (input, _) = keyword("if")(input)?;
941    let (input, _) = ws(input)?;
942    let (input, _) = char('(')(input)?;
943    let (input, condition) = parse_expr(input)?;
944    let (input, _) = ws(input)?;
945    let (input, _) = char(')')(input)?;
946    let (input, then_branch) = parse_statement(input)?;
947    let (input, _) = ws(input)?;
948    let (input, else_branch) = opt(preceded(
949        pair(keyword("else"), ws),
950        parse_statement,
951    ))(input)?;
952
953    Ok((input, Statement::If {
954        condition,
955        then_branch: Box::new(then_branch),
956        else_branch: else_branch.map(Box::new),
957    }))
958}
959
960fn parse_for_stmt(input: &str) -> IResult<&str, Statement> {
961    let (input, _) = keyword("for")(input)?;
962    let (input, _) = ws(input)?;
963    let (input, _) = char('(')(input)?;
964
965    // Init: either a var decl or an expression statement, or empty
966    let (input, _) = ws(input)?;
967    let (input, init) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(';')(input) {
968        (r, None)
969    } else if let Ok((r, stmt)) = parse_var_decl_stmt(input) {
970        // var_decl_stmt already consumes the semicolon
971        (r, Some(Box::new(stmt)))
972    } else {
973        let (r, expr) = parse_expr(input)?;
974        let (r, _) = ws(r)?;
975        let (r, _) = char(';')(r)?;
976        (r, Some(Box::new(Statement::Expr(expr))))
977    };
978
979    // Condition
980    let (input, _) = ws(input)?;
981    let (input, condition) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(';')(input) {
982        (r, None)
983    } else {
984        let (r, expr) = parse_expr(input)?;
985        let (r, _) = ws(r)?;
986        let (r, _) = char(';')(r)?;
987        (r, Some(expr))
988    };
989
990    // Update
991    let (input, _) = ws(input)?;
992    let (input, update) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(')')(input) {
993        (r, None)
994    } else {
995        let (r, expr) = parse_expr(input)?;
996        let (r, _) = ws(r)?;
997        let (r, _) = char(')')(r)?;
998        (r, Some(expr))
999    };
1000
1001    let (input, body) = parse_statement(input)?;
1002
1003    Ok((input, Statement::For {
1004        init,
1005        condition,
1006        update,
1007        body: Box::new(body),
1008    }))
1009}
1010
1011fn parse_while_stmt(input: &str) -> IResult<&str, Statement> {
1012    let (input, _) = keyword("while")(input)?;
1013    let (input, _) = ws(input)?;
1014    let (input, _) = char('(')(input)?;
1015    let (input, condition) = parse_expr(input)?;
1016    let (input, _) = ws(input)?;
1017    let (input, _) = char(')')(input)?;
1018    let (input, body) = parse_statement(input)?;
1019
1020    Ok((input, Statement::While {
1021        condition,
1022        body: Box::new(body),
1023    }))
1024}
1025
1026fn parse_do_while_stmt(input: &str) -> IResult<&str, Statement> {
1027    let (input, _) = keyword("do")(input)?;
1028    let (input, body) = parse_statement(input)?;
1029    let (input, _) = ws(input)?;
1030    let (input, _) = keyword("while")(input)?;
1031    let (input, _) = ws(input)?;
1032    let (input, _) = char('(')(input)?;
1033    let (input, condition) = parse_expr(input)?;
1034    let (input, _) = ws(input)?;
1035    let (input, _) = char(')')(input)?;
1036    let (input, _) = ws(input)?;
1037    let (input, _) = char(';')(input)?;
1038
1039    Ok((input, Statement::While {
1040        condition,
1041        body: Box::new(body),
1042    }))
1043}
1044
1045fn parse_block_stmt(input: &str) -> IResult<&str, Statement> {
1046    let (input, block) = parse_block(input)?;
1047    Ok((input, Statement::Block(block)))
1048}
1049
1050/// Try to detect if the next tokens look like a variable declaration.
1051/// This is the key heuristic: we look for patterns like:
1052///   type name [= init] ;
1053///   type name [ size ] [= init] ;
1054///   extern __shared__ type name [] ;
1055///   __shared__ type name [ size ] ;
1056fn parse_var_decl_stmt(input: &str) -> IResult<&str, Statement> {
1057    let (input, _) = ws(input)?;
1058
1059    // Storage class qualifiers
1060    let mut storage = StorageClass::Auto;
1061    let mut rest = input;
1062    let mut has_extern = false;
1063
1064    // extern keyword
1065    if let Ok((r, _)) = keyword("extern")(rest) {
1066        has_extern = true;
1067        rest = r;
1068        let (r, _) = ws(rest)?;
1069        rest = r;
1070    }
1071
1072    // __shared__, __constant__, register, static
1073    if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__shared__")(rest) {
1074        storage = StorageClass::Shared;
1075        rest = r;
1076    } else if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__constant__")(rest) {
1077        storage = StorageClass::Constant;
1078        rest = r;
1079    } else if let Ok((r, _)) = keyword("register")(rest) {
1080        storage = StorageClass::Register;
1081        rest = r;
1082    } else if let Ok((r, _)) = keyword("static")(rest) {
1083        // Keep Auto for static locals
1084        rest = r;
1085    } else if has_extern {
1086        // Just "extern" without __shared__/__constant__ - not a var decl we handle
1087        return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
1088    }
1089
1090    // Parse the type
1091    let (rest, (mut ty, qualifiers)) = parse_type(rest)?;
1092    let (rest, _) = ws(rest)?;
1093
1094    // Need an identifier here. Make sure it's not a keyword or `(` (which would be a function call).
1095    let (rest, name) = identifier(rest)?;
1096
1097    // Check that name is not a keyword that starts a statement
1098    let kw_set = ["if", "else", "for", "while", "do", "return", "break", "continue",
1099                   "switch", "case", "default", "goto", "__syncthreads"];
1100    if kw_set.contains(&name) {
1101        return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
1102    }
1103
1104    let (rest, _) = ws(rest)?;
1105
1106    // Array suffix: [size] or []
1107    if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('[')(rest) {
1108        let (r, _) = ws(r)?;
1109        if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(']')(r) {
1110            ty = Type::Array(Box::new(ty), None);
1111            let (r, _) = ws(r)?;
1112            // Optional additional dimensions: [16][16]
1113            let mut r = r;
1114            while let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('[')(r) {
1115                let (r2, _) = ws(r2)?;
1116                let (r2, size_expr) = parse_expr(r2)?;
1117                let (r2, _) = ws(r2)?;
1118                let (r2, _) = char(']')(r2)?;
1119                let (r2, _) = ws(r2)?;
1120                r = r2;
1121                // We wrap in nested Array types
1122                // (simplification: we don't track multi-dim precisely)
1123            }
1124            let (r, _) = ws(r)?;
1125            let (r, _) = char(';')(r)?;
1126            return Ok((r, Statement::VarDecl {
1127                name: name.to_string(),
1128                ty,
1129                init: None,
1130                storage,
1131            }));
1132        } else {
1133            // [size] - possibly multi-dimensional: [16][16]
1134            let (r, size_expr) = parse_expr(r)?;
1135            let (r, _) = ws(r)?;
1136            let (r, _) = char(']')(r)?;
1137            let size = if let Expression::Literal(Literal::Int(n)) = &size_expr {
1138                Some(*n as usize)
1139            } else {
1140                None
1141            };
1142            ty = Type::Array(Box::new(ty), size);
1143            let mut r = r;
1144            let (r2, _) = ws(r)?;
1145            // Additional dimensions
1146            while let Ok((r3, _)) = char::<&str, nom::error::Error<&str>>('[')(r2) {
1147                let (r3, _) = ws(r3)?;
1148                let (r3, _size2) = parse_expr(r3)?;
1149                let (r3, _) = ws(r3)?;
1150                let (r3, _) = char(']')(r3)?;
1151                r = r3;
1152                let (r4, _) = ws(r)?;
1153                // Check for more dimensions
1154                if r4.starts_with('[') {
1155                    continue;
1156                }
1157                break;
1158            }
1159            let (r, _) = ws(r)?;
1160            // Optional initializer
1161            let (r, init) = if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('=')(r) {
1162                let (r2, expr) = parse_expr(r2)?;
1163                (r2, Some(expr))
1164            } else {
1165                (r, None)
1166            };
1167            let (r, _) = ws(r)?;
1168            let (r, _) = char(';')(r)?;
1169            return Ok((r, Statement::VarDecl {
1170                name: name.to_string(),
1171                ty,
1172                init,
1173                storage,
1174            }));
1175        }
1176    }
1177
1178    // Optional initializer: = expr
1179    let (rest, init) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('=')(rest) {
1180        let (r, expr) = parse_expr(r)?;
1181        (r, Some(expr))
1182    } else {
1183        (rest, None)
1184    };
1185
1186    let (rest, _) = ws(rest)?;
1187    let (rest, _) = char(';')(rest)?;
1188
1189    Ok((rest, Statement::VarDecl {
1190        name: name.to_string(),
1191        ty,
1192        init,
1193        storage,
1194    }))
1195}
1196
1197fn parse_expr_stmt(input: &str) -> IResult<&str, Statement> {
1198    let (input, expr) = parse_expr(input)?;
1199    let (input, _) = ws(input)?;
1200    let (input, _) = char(';')(input)?;
1201
1202    // Convert __syncthreads() call to SyncThreads statement
1203    if let Expression::Call { ref name, ref args } = expr {
1204        if name == "__syncthreads" && args.is_empty() {
1205            return Ok((input, Statement::SyncThreads));
1206        }
1207    }
1208
1209    Ok((input, Statement::Expr(expr)))
1210}
1211
1212// ═══════════════════════════════════════════════════════════════
1213//  Top-level item parsing
1214// ═══════════════════════════════════════════════════════════════
1215
1216fn parse_parameter(input: &str) -> IResult<&str, Parameter> {
1217    let (input, _) = ws(input)?;
1218    let (input, (ty, qualifiers)) = parse_type(input)?;
1219    let (input, _) = ws(input)?;
1220    let (input, name) = identifier(input)?;
1221    // Optional array suffix on parameter: int arr[]
1222    let (input, _) = ws(input)?;
1223    let (input, _ty) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('[')(input) {
1224        let (r, _) = ws(r)?;
1225        if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(']')(r) {
1226            (r, Type::Pointer(Box::new(ty.clone())))
1227        } else {
1228            let (r, _) = parse_expr(r)?;
1229            let (r, _) = ws(r)?;
1230            let (r, _) = char(']')(r)?;
1231            (r, Type::Pointer(Box::new(ty.clone())))
1232        }
1233    } else {
1234        (input, ty.clone())
1235    };
1236
1237    Ok((input, Parameter {
1238        name: name.to_string(),
1239        ty: _ty,
1240        qualifiers,
1241    }))
1242}
1243
1244fn parse_param_list(input: &str) -> IResult<&str, Vec<Parameter>> {
1245    let (input, _) = ws(input)?;
1246    let (input, _) = char('(')(input)?;
1247    let (input, _) = ws(input)?;
1248    // Handle empty param list and void param list
1249    if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(')')(input) {
1250        return Ok((r, vec![]));
1251    }
1252    if let Ok((r, _)) = keyword("void")(input) {
1253        let (r, _) = ws(r)?;
1254        if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(')')(r) {
1255            return Ok((r, vec![]));
1256        }
1257    }
1258    let (input, params) = separated_list0(
1259        delimited(ws, char(','), ws),
1260        parse_parameter,
1261    )(input)?;
1262    let (input, _) = ws(input)?;
1263    let (input, _) = char(')')(input)?;
1264    Ok((input, params))
1265}
1266
1267/// Parse a kernel definition: __global__ void name(params) { body }
1268fn parse_kernel_def(input: &str) -> IResult<&str, Item> {
1269    let (input, _) = ws(input)?;
1270    // Optional template<...> - skip it
1271    let input = skip_template(input);
1272    let (input, _) = ws(input)?;
1273    let (input, _) = tag("__global__")(input)?;
1274    let (input, _) = ws(input)?;
1275    let (input, _) = keyword("void")(input)?;
1276    let (input, _) = ws(input)?;
1277    let (input, name) = identifier(input)?;
1278    let (input, params) = parse_param_list(input)?;
1279    let (input, body) = parse_block(input)?;
1280
1281    Ok((input, Item::Kernel(KernelDef {
1282        name: name.to_string(),
1283        params,
1284        body,
1285        attributes: vec![],
1286    })))
1287}
1288
1289/// Parse a __device__ function
1290fn parse_device_function(input: &str) -> IResult<&str, Item> {
1291    let (input, _) = ws(input)?;
1292    let input = skip_template(input);
1293    let (input, _) = ws(input)?;
1294    let (input, _) = tag("__device__")(input)?;
1295    let (input, _) = ws(input)?;
1296    // May also have __host__ qualifier
1297    let (input, also_host) = opt(preceded(tag("__host__"), ws))(input)?;
1298    // May have __forceinline__
1299    let (input, _) = opt(preceded(tag("__forceinline__"), ws))(input)?;
1300    let (input, _) = opt(preceded(keyword("inline"), ws))(input)?;
1301    let (input, (ret_ty, _)) = parse_type(input)?;
1302    let (input, _) = ws(input)?;
1303    let (input, name) = identifier(input)?;
1304    let (input, params) = parse_param_list(input)?;
1305    let (input, body) = parse_block(input)?;
1306
1307    let mut qualifiers = vec![FunctionQualifier::Device];
1308    if also_host.is_some() {
1309        qualifiers.push(FunctionQualifier::Host);
1310    }
1311
1312    Ok((input, Item::DeviceFunction(FunctionDef {
1313        name: name.to_string(),
1314        return_type: ret_ty,
1315        params,
1316        body,
1317        qualifiers,
1318    })))
1319}
1320
1321/// Parse a __host__ function
1322fn parse_host_function(input: &str) -> IResult<&str, Item> {
1323    let (input, _) = ws(input)?;
1324    let input = skip_template(input);
1325    let (input, _) = ws(input)?;
1326    let (input, _) = tag("__host__")(input)?;
1327    let (input, _) = ws(input)?;
1328    // May also have __device__
1329    let (input, also_device) = opt(preceded(tag("__device__"), ws))(input)?;
1330    let (input, _) = opt(preceded(keyword("inline"), ws))(input)?;
1331    let (input, (ret_ty, _)) = parse_type(input)?;
1332    let (input, _) = ws(input)?;
1333    let (input, name) = identifier(input)?;
1334    let (input, params) = parse_param_list(input)?;
1335    let (input, body) = parse_block(input)?;
1336
1337    let mut qualifiers = vec![FunctionQualifier::Host];
1338    if also_device.is_some() {
1339        qualifiers.push(FunctionQualifier::Device);
1340    }
1341
1342    Ok((input, Item::HostFunction(FunctionDef {
1343        name: name.to_string(),
1344        return_type: ret_ty,
1345        params,
1346        body,
1347        qualifiers,
1348    })))
1349}
1350
1351/// Parse an #include directive
1352fn parse_include(input: &str) -> IResult<&str, Item> {
1353    let (input, _) = ws(input)?;
1354    let (input, _) = char('#')(input)?;
1355    let (input, _) = ws(input)?;
1356    let (input, _) = tag("include")(input)?;
1357    let (input, _) = take_while(|c: char| c == ' ' || c == '\t')(input)?;
1358    // Match <...> or "..."
1359    let (input, path) = alt((
1360        delimited(char('<'), take_until(">"), char('>')),
1361        delimited(char('"'), take_until("\""), char('"')),
1362    ))(input)?;
1363    // Consume rest of line
1364    let (input, _) = take_while(|c: char| c != '\n')(input)?;
1365    Ok((input, Item::Include(path.to_string())))
1366}
1367
1368/// Parse a typedef: typedef type name;
1369fn parse_typedef(input: &str) -> IResult<&str, Item> {
1370    let (input, _) = ws(input)?;
1371    let (input, _) = keyword("typedef")(input)?;
1372    let (input, _) = ws(input)?;
1373
1374    // Handle typedef struct { ... } Name;
1375    if let Ok((r, _)) = keyword("struct")(input) {
1376        let (r, _) = ws(r)?;
1377        // Optional struct name
1378        let (r, _struct_name) = opt(identifier)(r)?;
1379        let (r, _) = ws(r)?;
1380        // Struct body - just skip it
1381        if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('{')(r) {
1382            let r = skip_balanced_braces(r);
1383            let (r, _) = ws(r)?;
1384            let (r, name) = identifier(r)?;
1385            let (r, _) = ws(r)?;
1386            let (r, _) = char(';')(r)?;
1387            return Ok((r, Item::TypeDef(TypeDef {
1388                name: name.to_string(),
1389                ty: Type::Named(name.to_string()),
1390            })));
1391        }
1392    }
1393
1394    let (input, (ty, _)) = parse_type(input)?;
1395    let (input, _) = ws(input)?;
1396    let (input, name) = identifier(input)?;
1397    let (input, _) = ws(input)?;
1398    let (input, _) = char(';')(input)?;
1399    Ok((input, Item::TypeDef(TypeDef {
1400        name: name.to_string(),
1401        ty,
1402    })))
1403}
1404
1405/// Parse a struct definition (non-typedef)
1406fn parse_struct(input: &str) -> IResult<&str, Item> {
1407    let (input, _) = ws(input)?;
1408    let (input, _) = keyword("struct")(input)?;
1409    let (input, _) = ws(input)?;
1410    let (input, name) = identifier(input)?;
1411    let (input, _) = ws(input)?;
1412    let (input, _) = char('{')(input)?;
1413    let rest = skip_balanced_braces(input);
1414    let (rest, _) = ws(rest)?;
1415    let (rest, _) = char(';')(rest)?;
1416    Ok((rest, Item::TypeDef(TypeDef {
1417        name: name.to_string(),
1418        ty: Type::Named(name.to_string()),
1419    })))
1420}
1421
/// Skip a leading `template<...>` prefix.
///
/// Leading whitespace is ignored when looking for the prefix. On success the
/// text following the `>` that closes the first `<` is returned (nested
/// `<`/`>` pairs are counted). If the input does not start with `template`
/// followed by `<`, or the angle brackets never balance, the input is
/// returned unchanged.
fn skip_template(input: &str) -> &str {
    let after_kw = match input.trim_start().strip_prefix("template") {
        Some(tail) => tail.trim_start(),
        None => return input,
    };
    if !after_kw.starts_with('<') {
        return input;
    }

    // Track bracket nesting; the first character is the opening `<`, so the
    // depth can only return to zero on the `>` that closes it.
    let mut depth = 0i32;
    let closing = after_kw.char_indices().find(|&(_, c)| {
        match c {
            '<' => depth += 1,
            '>' => depth -= 1,
            _ => {}
        }
        depth == 0
    });

    match closing {
        Some((idx, _)) => &after_kw[idx + 1..],
        None => input,
    }
}
1448
/// Skip balanced braces, returning the rest after the closing `}`.
///
/// The caller is expected to have consumed the opening `{` already, so
/// counting starts at depth one. Every `{` and `}` in the text is counted,
/// wherever it appears. If the matching `}` is never found, the input is
/// returned unchanged.
fn skip_balanced_braces(input: &str) -> &str {
    let mut open = 1i32;
    input
        .char_indices()
        .find(|&(_, c)| {
            match c {
                '{' => open += 1,
                '}' => open -= 1,
                _ => {}
            }
            open == 0
        })
        .map_or(input, |(idx, _)| &input[idx + 1..])
}
1466
1467/// Skip a preprocessor directive (everything until end of line)
1468fn parse_preprocessor(input: &str) -> IResult<&str, ()> {
1469    let (input, _) = ws(input)?;
1470    let (input, _) = char('#')(input)?;
1471    let (input, _) = take_while(|c: char| c != '\n')(input)?;
1472    Ok((input, ()))
1473}
1474
1475// ═══════════════════════════════════════════════════════════════
1476//  Top-level parser
1477// ═══════════════════════════════════════════════════════════════
1478
1479/// Parse a top-level global variable declaration (__constant__ or __shared__)
1480fn parse_global_var_decl(input: &str) -> IResult<&str, Item> {
1481    let (input, _) = ws(input)?;
1482    // Only match __constant__ or __shared__ at top level
1483    let rest = input;
1484    let storage;
1485    let rest = if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__constant__")(rest) {
1486        storage = StorageClass::Constant;
1487        r
1488    } else if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__shared__")(rest) {
1489        storage = StorageClass::Shared;
1490        r
1491    } else {
1492        return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
1493    };
1494    let (rest, _) = ws(rest)?;
1495    let (rest, (ty, _qualifiers)) = parse_type(rest)?;
1496    let (rest, _) = ws(rest)?;
1497    let (rest, name) = identifier(rest)?;
1498    let (rest, _) = ws(rest)?;
1499    // Array suffix
1500    let (rest, ty) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('[')(rest) {
1501        let (r, _) = ws(r)?;
1502        let (r, size_expr) = parse_expr(r)?;
1503        let (r, _) = ws(r)?;
1504        let (r, _) = char(']')(r)?;
1505        let size = if let Expression::Literal(Literal::Int(n)) = &size_expr {
1506            Some(*n as usize)
1507        } else {
1508            None
1509        };
1510        (r, Type::Array(Box::new(ty), size))
1511    } else {
1512        (rest, ty)
1513    };
1514    let (rest, _) = ws(rest)?;
1515    // Optional initializer
1516    let (rest, init) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('=')(rest) {
1517        let (r, _) = ws(r)?;
1518        // Handle brace-enclosed initializers: {1.0, 2.0, ...}
1519        if r.starts_with('{') {
1520            // Skip to matching }
1521            let end = r.find('}').unwrap_or(r.len() - 1);
1522            let r = &r[end + 1..];
1523            (r, None) // We don't parse initializer lists into AST yet
1524        } else {
1525            let (r, expr) = parse_expr(r)?;
1526            (r, Some(expr))
1527        }
1528    } else {
1529        (rest, None)
1530    };
1531    let (rest, _) = ws(rest)?;
1532    let (rest, _) = char(';')(rest)?;
1533    Ok((rest, Item::GlobalVar(GlobalVar {
1534        name: name.to_string(),
1535        ty,
1536        storage,
1537        init,
1538    })))
1539}
1540
1541fn parse_top_level_item(input: &str) -> IResult<&str, Option<Item>> {
1542    let (input, _) = ws(input)?;
1543    if input.is_empty() {
1544        return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Eof)));
1545    }
1546
1547    // Try each top-level construct
1548    if let Ok((r, item)) = parse_include(input) {
1549        return Ok((r, Some(item)));
1550    }
1551    // Global vars before kernels (so __constant__ is not skipped)
1552    if let Ok((r, item)) = parse_global_var_decl(input) {
1553        return Ok((r, Some(item)));
1554    }
1555    if let Ok((r, item)) = parse_kernel_def(input) {
1556        return Ok((r, Some(item)));
1557    }
1558    if let Ok((r, item)) = parse_device_function(input) {
1559        return Ok((r, Some(item)));
1560    }
1561    if let Ok((r, item)) = parse_host_function(input) {
1562        return Ok((r, Some(item)));
1563    }
1564    if let Ok((r, item)) = parse_typedef(input) {
1565        return Ok((r, Some(item)));
1566    }
1567    if let Ok((r, item)) = parse_struct(input) {
1568        return Ok((r, Some(item)));
1569    }
1570
1571    // Skip preprocessor directives
1572    if input.starts_with('#') {
1573        let (r, _) = parse_preprocessor(input)?;
1574        return Ok((r, None));
1575    }
1576
1577    // Skip unrecognized top-level constructs (free functions, etc.)
1578    // Try to skip to next semicolon or closing brace
1579    if let Some(pos) = input.find(|c: char| c == '{' || c == ';') {
1580        if input.as_bytes()[pos] == b'{' {
1581            let rest = skip_balanced_braces(&input[pos + 1..]);
1582            return Ok((rest, None));
1583        } else {
1584            return Ok((&input[pos + 1..], None));
1585        }
1586    }
1587
1588    Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Eof)))
1589}
1590
1591// ═══════════════════════════════════════════════════════════════
1592//  Public API
1593// ═══════════════════════════════════════════════════════════════
1594
/// Main CUDA parser.
///
/// The struct currently has no fields; it exists as the entry point for
/// parsing and as a place to hang parser state later.
pub struct CudaParser {
    // Parser state can be added here
}
1599
1600impl CudaParser {
1601    /// Create a new CUDA parser
1602    pub fn new() -> Self {
1603        Self {}
1604    }
1605
1606    /// Parse CUDA source code into AST
1607    pub fn parse(&self, source: &str) -> Result<Ast> {
1608        let mut items = Vec::new();
1609        let mut rest = source;
1610
1611        loop {
1612            // Skip whitespace and comments
1613            match ws(rest) {
1614                Ok((r, _)) => rest = r,
1615                Err(_) => break,
1616            }
1617            if rest.is_empty() {
1618                break;
1619            }
1620
1621            match parse_top_level_item(rest) {
1622                Ok((r, Some(item))) => {
1623                    items.push(item);
1624                    rest = r;
1625                }
1626                Ok((r, None)) => {
1627                    // Skipped preprocessor or unrecognized construct
1628                    rest = r;
1629                }
1630                Err(_) => {
1631                    // Skip one character and try again (error recovery)
1632                    if rest.is_empty() {
1633                        break;
1634                    }
1635                    // Try to find the next meaningful token
1636                    if let Some(pos) = rest[1..].find(|c: char| {
1637                        c == '#' || c == '_' || c.is_alphabetic()
1638                    }) {
1639                        rest = &rest[pos + 1..];
1640                    } else {
1641                        break;
1642                    }
1643                }
1644            }
1645        }
1646
1647        Ok(Ast { items })
1648    }
1649}
1650
1651impl Default for CudaParser {
1652    fn default() -> Self {
1653        Self::new()
1654    }
1655}
1656
1657// ═══════════════════════════════════════════════════════════════
1658//  Tests
1659// ═══════════════════════════════════════════════════════════════
1660
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` with a freshly constructed [`CudaParser`] and returns the AST.
    ///
    /// Panics (via `unwrap`) when parsing fails. `#[track_caller]` makes that
    /// panic report the calling test's location instead of this helper's.
    #[track_caller]
    fn parse_ok(src: &str) -> Ast {
        CudaParser::new().parse(src).unwrap()
    }

    /// A single element-wise add kernel: checks name, parameters and body size.
    #[test]
    fn test_vector_add() {
        let src = r#"
__global__ void vectorAdd(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"#;
        let ast = parse_ok(src);
        assert_eq!(ast.items.len(), 1);

        let kernel = match &ast.items[0] {
            Item::Kernel(k) => k,
            _ => panic!("Expected kernel"),
        };
        assert_eq!(kernel.name, "vectorAdd");
        assert_eq!(kernel.params.len(), 4);
        assert_eq!(kernel.params[0].name, "a");
        assert_eq!(kernel.params[3].name, "n");
        // The body is expected to hold exactly two statements:
        // the variable declaration and the `if`.
        assert_eq!(kernel.body.statements.len(), 2);
    }

    /// A tiled matrix multiply using `__shared__` arrays, nested `for` loops
    /// and `__syncthreads()` calls: checks it parses to one kernel.
    #[test]
    fn test_mat_mul() {
        let src = r#"
__global__ void matMul(float* A, float* B, float* C, int M, int N, int K) {
    __shared__ float sA[16][16];
    __shared__ float sB[16][16];
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    float sum = 0.0f;
    for (int t = 0; t < (K + 15) / 16; t++) {
        sA[threadIdx.y][threadIdx.x] = A[row * K + t * 16 + threadIdx.x];
        sB[threadIdx.y][threadIdx.x] = B[(t * 16 + threadIdx.y) * N + col];
        __syncthreads();
        for (int k = 0; k < 16; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }
        __syncthreads();
    }
    C[row * N + col] = sum;
}
"#;
        let ast = parse_ok(src);
        assert_eq!(ast.items.len(), 1);

        let kernel = match &ast.items[0] {
            Item::Kernel(k) => k,
            _ => panic!("Expected kernel"),
        };
        assert_eq!(kernel.name, "matMul");
        assert_eq!(kernel.params.len(), 6);
    }

    /// A reduction kernel exercising `extern __shared__`, `unsigned int`,
    /// the ternary operator, `>>=` and a brace-less `if`.
    #[test]
    fn test_reduce() {
        let src = r#"
__global__ void reduce(float* input, float* output, int n) {
    extern __shared__ float sdata[];
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    sdata[tid] = (i < n) ? input[i] : 0.0f;
    __syncthreads();
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) output[blockIdx.x] = sdata[0];
}
"#;
        let ast = parse_ok(src);
        assert_eq!(ast.items.len(), 1);

        let kernel = match &ast.items[0] {
            Item::Kernel(k) => k,
            _ => panic!("Expected kernel"),
        };
        assert_eq!(kernel.name, "reduce");
        assert_eq!(kernel.params.len(), 3);
    }

    /// Both `#include <...>` and `#include "..."` forms must surface as
    /// `Item::Include` carrying the bare path.
    #[test]
    fn test_include_directive() {
        let src = r#"#include <cuda_runtime.h>
#include "myheader.h"
"#;
        let ast = parse_ok(src);
        // Only a lower bound is asserted; the first two items are then checked in order.
        assert!(ast.items.len() >= 2);
        assert!(matches!(&ast.items[0], Item::Include(p) if p == "cuda_runtime.h"));
        assert!(matches!(&ast.items[1], Item::Include(p) if p == "myheader.h"));
    }

    /// Two back-to-back kernels in one source must yield two top-level items.
    #[test]
    fn test_multiple_kernels() {
        let src = r#"
__global__ void kernel1(int* a) {
    a[threadIdx.x] = 0;
}
__global__ void kernel2(float* b, int n) {
    int i = threadIdx.x;
    if (i < n) b[i] = 1.0f;
}
"#;
        let ast = parse_ok(src);
        assert_eq!(ast.items.len(), 2);
    }

    /// A `__device__` function must be classified as `Item::DeviceFunction`.
    #[test]
    fn test_device_function() {
        let src = r#"
__device__ float clamp(float x, float lo, float hi) {
    if (x < lo) return lo;
    if (x > hi) return hi;
    return x;
}
"#;
        let ast = parse_ok(src);
        assert_eq!(ast.items.len(), 1);
        assert!(matches!(&ast.items[0], Item::DeviceFunction(_)));
    }

    /// Smoke-tests `parse_expr` on a range of expression shapes in isolation.
    ///
    /// NOTE(review): only `is_ok()` is checked. If `parse_expr` returns a nom
    /// `IResult`, unconsumed trailing input would go unnoticed here — confirm.
    #[test]
    fn test_expressions() {
        // Binary operators and precedence
        assert!(parse_expr("a + b * c").is_ok());
        // Indexing and member access
        assert!(parse_expr("a[i]").is_ok());
        assert!(parse_expr("threadIdx.x").is_ok());
        assert!(parse_expr("blockIdx.x * blockDim.x + threadIdx.x").is_ok());
        // Cast
        assert!(parse_expr("(float)x").is_ok());
        // Comparison combined with logical operators
        assert!(parse_expr("a < b && c > d").is_ok());
        // Postfix and prefix increment
        assert!(parse_expr("i++").is_ok());
        assert!(parse_expr("++i").is_ok());
        // Function call with an address-of argument
        assert!(parse_expr("atomicAdd(&x, 1)").is_ok());
    }

    /// Smoke-tests `parse_type` on pointer, qualified, multi-word and vector type spellings.
    ///
    /// NOTE(review): as above, only `is_ok()` is asserted — the parsed type
    /// itself is not inspected.
    #[test]
    fn test_type_parsing() {
        assert!(parse_type("float*").is_ok());
        assert!(parse_type("const float*").is_ok());
        assert!(parse_type("unsigned int").is_ok());
        assert!(parse_type("int").is_ok());
        assert!(parse_type("float4").is_ok());
        assert!(parse_type("double").is_ok());
    }
}