1use nom::{
6 IResult,
7 branch::alt,
8 bytes::complete::{tag, take_while, take_while1, take_until},
9 character::complete::{char, multispace0, multispace1, digit1, alpha1, one_of},
10 combinator::{opt, map, value, recognize},
11 multi::{separated_list0, many1},
12 sequence::{pair, tuple, delimited, preceded},
13};
14
15use crate::{Result, parse_error};
16use super::ast::*;
17
18fn ws(input: &str) -> IResult<&str, ()> {
24 let mut rest = input;
25 loop {
26 let (r, _) = multispace0(rest)?;
27 rest = r;
28 if rest.starts_with("//") {
29 let end = rest.find('\n').unwrap_or(rest.len());
30 rest = &rest[end..];
31 } else if rest.starts_with("/*") {
32 if let Some(end) = rest.find("*/") {
33 rest = &rest[end + 2..];
34 } else {
35 return Err(nom::Err::Error(nom::error::Error::new(rest, nom::error::ErrorKind::Tag)));
36 }
37 } else {
38 break;
39 }
40 }
41 Ok((rest, ()))
42}
43
44fn ws_around<'a, F, O>(inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
46where
47 F: FnMut(&'a str) -> IResult<&'a str, O>,
48{
49 delimited(ws, inner, ws)
50}
51
52fn identifier(input: &str) -> IResult<&str, &str> {
54 recognize(pair(
55 alt((alpha1, tag("_"))),
56 take_while(|c: char| c.is_alphanumeric() || c == '_'),
57 ))(input)
58}
59
60fn ws_ident(input: &str) -> IResult<&str, &str> {
62 let (input, _) = ws(input)?;
63 identifier(input)
64}
65
66fn ws_tag<'a>(t: &'a str) -> impl FnMut(&'a str) -> IResult<&'a str, &'a str> {
68 delimited(ws, tag(t), ws)
69}
70
71fn t<'a>(s: &'a str) -> impl FnMut(&'a str) -> IResult<&'a str, &'a str> {
73 tag(s)
74}
75
76fn keyword<'a>(kw: &'a str) -> impl FnMut(&'a str) -> IResult<&'a str, &'a str> {
78 move |input: &'a str| {
79 let (rest, matched) = tag(kw)(input)?;
80 if let Some(c) = rest.chars().next() {
82 if c.is_alphanumeric() || c == '_' {
83 return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
84 }
85 }
86 Ok((rest, matched))
87 }
88}
89
90fn parse_float_literal(input: &str) -> IResult<&str, Expression> {
95 let (rest, text) = recognize(alt((
97 recognize(tuple((
99 digit1,
100 char('.'),
101 opt(digit1),
102 opt(recognize(tuple((one_of("eE"), opt(one_of("+-")), digit1)))),
103 opt(one_of("fF")),
104 ))),
105 recognize(tuple((
107 char('.'),
108 digit1,
109 opt(recognize(tuple((one_of("eE"), opt(one_of("+-")), digit1)))),
110 opt(one_of("fF")),
111 ))),
112 recognize(tuple((
114 digit1,
115 one_of("eE"),
116 opt(one_of("+-")),
117 digit1,
118 opt(one_of("fF")),
119 ))),
120 recognize(tuple((digit1, one_of("fF")))),
122 )))(input)?;
123
124 let clean = text.trim_end_matches(|c| c == 'f' || c == 'F');
125 let val: f64 = clean.parse().unwrap_or(0.0);
126 Ok((rest, Expression::Literal(Literal::Float(val))))
127}
128
129fn parse_hex_literal(input: &str) -> IResult<&str, Expression> {
130 let (rest, text) = recognize(tuple((
131 tag("0"),
132 one_of("xX"),
133 take_while1(|c: char| c.is_ascii_hexdigit()),
134 opt(one_of("uUlL")),
135 )))(input)?;
136 let clean = text.trim_start_matches("0x").trim_start_matches("0X");
137 let clean = clean.trim_end_matches(|c| c == 'u' || c == 'U' || c == 'l' || c == 'L');
138 let val = u64::from_str_radix(clean, 16).unwrap_or(0);
139 if text.contains('u') || text.contains('U') {
140 Ok((rest, Expression::Literal(Literal::UInt(val))))
141 } else {
142 Ok((rest, Expression::Literal(Literal::Int(val as i64))))
143 }
144}
145
146fn parse_int_literal(input: &str) -> IResult<&str, Expression> {
147 let (rest, text) = recognize(pair(
148 digit1,
149 opt(recognize(many1(one_of("uUlL")))),
150 ))(input)?;
151 let clean = text.trim_end_matches(|c| c == 'u' || c == 'U' || c == 'l' || c == 'L');
152 if text.contains('u') || text.contains('U') {
153 let val: u64 = clean.parse().unwrap_or(0);
154 Ok((rest, Expression::Literal(Literal::UInt(val))))
155 } else {
156 let val: i64 = clean.parse().unwrap_or(0);
157 Ok((rest, Expression::Literal(Literal::Int(val))))
158 }
159}
160
161fn parse_literal(input: &str) -> IResult<&str, Expression> {
162 alt((
163 parse_hex_literal,
164 parse_float_literal,
165 parse_int_literal,
166 ))(input)
167}
168
169fn parse_base_type(input: &str) -> IResult<&str, Type> {
174 let (input, _) = ws(input)?;
175 alt((
176 value(Type::Void, keyword("void")),
177 value(Type::Bool, keyword("bool")),
178 value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F32)), size: 2 }), keyword("float2")),
180 value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F32)), size: 3 }), keyword("float3")),
181 value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F32)), size: 4 }), keyword("float4")),
182 value(Type::Vector(VectorType { element: Box::new(Type::Int(IntType::I32)), size: 2 }), keyword("int2")),
183 value(Type::Vector(VectorType { element: Box::new(Type::Int(IntType::I32)), size: 3 }), keyword("int3")),
184 value(Type::Vector(VectorType { element: Box::new(Type::Int(IntType::I32)), size: 4 }), keyword("int4")),
185 value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F64)), size: 2 }), keyword("double2")),
186 value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F64)), size: 3 }), keyword("double3")),
187 value(Type::Vector(VectorType { element: Box::new(Type::Float(FloatType::F64)), size: 4 }), keyword("double4")),
188 value(Type::Named("dim3".to_string()), keyword("dim3")),
189 value(Type::Float(FloatType::F64), keyword("double")),
190 value(Type::Float(FloatType::F32), keyword("float")),
191 map(preceded(keyword("unsigned"), opt(preceded(multispace1, keyword("int")))), |_| Type::Int(IntType::U32)),
193 map(pair(keyword("long"), opt(preceded(multispace1, keyword("long")))), |(_, ll)| {
195 if ll.is_some() { Type::Int(IntType::I64) } else { Type::Int(IntType::I64) }
196 }),
197 value(Type::Int(IntType::I16), keyword("short")),
198 value(Type::Int(IntType::I8), keyword("char")),
199 value(Type::Int(IntType::I32), keyword("int")),
200 value(Type::Int(IntType::U64), keyword("size_t")),
202 map(identifier, |name: &str| Type::Named(name.to_string())),
204 ))(input)
205}
206
207fn parse_type(input: &str) -> IResult<&str, (Type, Vec<ParamQualifier>)> {
209 let (input, _) = ws(input)?;
210
211 let mut qualifiers = Vec::new();
213 let mut rest = input;
214 loop {
215 let (r, _) = ws(rest)?;
216 if let Ok((r2, _)) = keyword("const")(r) {
217 qualifiers.push(ParamQualifier::Const);
218 rest = r2;
219 } else if let Ok((r2, _)) = keyword("volatile")(r) {
220 qualifiers.push(ParamQualifier::Volatile);
221 rest = r2;
222 } else if let Ok((r2, _)) = t("__restrict__")(r) {
223 qualifiers.push(ParamQualifier::Restrict);
224 rest = r2;
225 } else {
226 break;
227 }
228 }
229
230 let (rest, mut ty) = parse_base_type(rest)?;
232
233 let mut rest = rest;
235 loop {
236 let (r, _) = ws(rest)?;
237 if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('*')(r) {
238 ty = Type::Pointer(Box::new(ty));
239 rest = r2;
240 let (r3, _) = ws(rest)?;
242 if let Ok((r4, _)) = keyword("const")(r3) {
243 qualifiers.push(ParamQualifier::Const);
244 rest = r4;
245 } else if let Ok((r4, _)) = t("__restrict__")(r3) {
246 qualifiers.push(ParamQualifier::Restrict);
247 rest = r4;
248 } else if let Ok((r4, _)) = keyword("restrict")(r3) {
249 qualifiers.push(ParamQualifier::Restrict);
250 rest = r4;
251 } else {
252 rest = r3;
253 }
254 } else {
255 rest = r;
256 break;
257 }
258 }
259
260 Ok((rest, (ty, qualifiers)))
261}
262
263fn parse_primary(input: &str) -> IResult<&str, Expression> {
269 let (input, _) = ws(input)?;
270 alt((
271 parse_cuda_builtin,
272 parse_sizeof_expr,
273 parse_cast_or_paren,
274 parse_literal,
275 parse_ident_or_call,
276 ))(input)
277}
278
279fn parse_sizeof_expr(input: &str) -> IResult<&str, Expression> {
281 let (input, _) = keyword("sizeof")(input)?;
282 let (input, _) = ws(input)?;
283 let (input, _) = char('(')(input)?;
284 let (input, _) = ws(input)?;
285 if let Ok((rest, (ty, _))) = parse_type(input) {
287 let (rest, _) = ws(rest)?;
288 let (rest, _) = char(')')(rest)?;
289 Ok((rest, Expression::Call {
291 name: "sizeof".to_string(),
292 args: vec![Expression::Var(format!("{:?}", ty))],
293 }))
294 } else {
295 let (input, expr) = parse_expr(input)?;
296 let (input, _) = ws(input)?;
297 let (input, _) = char(')')(input)?;
298 Ok((input, Expression::Call {
299 name: "sizeof".to_string(),
300 args: vec![expr],
301 }))
302 }
303}
304
305fn parse_cuda_builtin(input: &str) -> IResult<&str, Expression> {
307 let (input, builtin) = alt((
308 tag("threadIdx"),
309 tag("blockIdx"),
310 tag("blockDim"),
311 tag("gridDim"),
312 ))(input)?;
313 if let Some(c) = input.chars().next() {
315 if c.is_alphanumeric() || c == '_' {
316 return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
317 }
318 }
319 let (input, _) = ws(input)?;
320 let (input, _) = char('.')(input)?;
321 let (input, _) = ws(input)?;
322 let (input, dim_str) = alt((tag("x"), tag("y"), tag("z")))(input)?;
323 let dim = match dim_str {
324 "x" => Dimension::X,
325 "y" => Dimension::Y,
326 "z" => Dimension::Z,
327 _ => unreachable!(),
328 };
329 let expr = match builtin {
330 "threadIdx" => Expression::ThreadIdx(dim),
331 "blockIdx" => Expression::BlockIdx(dim),
332 "blockDim" => Expression::BlockDim(dim),
333 "gridDim" => Expression::GridDim(dim),
334 _ => unreachable!(),
335 };
336 Ok((input, expr))
337}
338
339fn parse_cast_or_paren(input: &str) -> IResult<&str, Expression> {
341 let (input, _) = char('(')(input)?;
342 let (input, _) = ws(input)?;
343
344 let checkpoint = input;
348 if let Ok((after_ty, (ty, _))) = parse_type(checkpoint) {
349 let (after_ty, _) = ws(after_ty)?;
350 if let Ok((after_close, _)) = char::<&str, nom::error::Error<&str>>(')')(after_ty) {
351 let (peek_rest, _) = ws(after_close)?;
353 let looks_like_expr = peek_rest.starts_with('(')
354 || peek_rest.starts_with(|c: char| c.is_alphanumeric() || c == '_' || c == '-' || c == '!' || c == '~' || c == '.');
355 if looks_like_expr {
356 let is_real_type = matches!(ty,
358 Type::Void | Type::Bool | Type::Int(_) | Type::Float(_) | Type::Pointer(_)
359 | Type::Vector(_) | Type::Array(_, _));
360 if is_real_type {
361 let (rest, expr) = parse_unary(after_close)?;
362 return Ok((rest, Expression::Cast { ty, expr: Box::new(expr) }));
363 }
364 }
365 }
366 }
367
368 let (input, expr) = parse_expr(checkpoint)?;
370 let (input, _) = ws(input)?;
371 let (input, _) = char(')')(input)?;
372 Ok((input, expr))
373}
374
375fn parse_ident_or_call(input: &str) -> IResult<&str, Expression> {
377 if let Ok((rest, _)) = tag::<&str, &str, nom::error::Error<&str>>("__syncthreads")(input) {
379 let (rest, _) = ws(rest)?;
380 if let Ok((rest, _)) = char::<&str, nom::error::Error<&str>>('(')(rest) {
381 let (rest, _) = ws(rest)?;
382 let (rest, _) = char(')')(rest)?;
383 return Ok((rest, Expression::Call { name: "__syncthreads".to_string(), args: vec![] }));
385 }
386 }
387
388 let (input, name) = identifier(input)?;
389 let (input, _) = ws(input)?;
390
391 if let Ok((rest, _)) = char::<&str, nom::error::Error<&str>>('(')(input) {
393 let (rest, _) = ws(rest)?;
394 let (rest, args) = separated_list0(
395 delimited(ws, char(','), ws),
396 parse_expr,
397 )(rest)?;
398 let (rest, _) = ws(rest)?;
399 let (rest, _) = char(')')(rest)?;
400
401 let expr = match name {
403 "__shfl_sync" => Expression::WarpPrimitive { op: WarpOp::Shuffle, args },
404 "__shfl_xor_sync" => Expression::WarpPrimitive { op: WarpOp::ShuffleXor, args },
405 "__shfl_up_sync" => Expression::WarpPrimitive { op: WarpOp::ShuffleUp, args },
406 "__shfl_down_sync" => Expression::WarpPrimitive { op: WarpOp::ShuffleDown, args },
407 "__ballot_sync" => Expression::WarpPrimitive { op: WarpOp::Ballot, args },
408 "__activemask" => Expression::WarpPrimitive { op: WarpOp::ActiveMask, args },
409 _ => Expression::Call { name: name.to_string(), args },
410 };
411 return Ok((rest, expr));
412 }
413
414 Ok((input, Expression::Var(name.to_string())))
415}
416
417fn parse_postfix(input: &str) -> IResult<&str, Expression> {
419 let (mut rest, mut expr) = parse_primary(input)?;
420
421 loop {
422 let (r, _) = ws(rest)?;
423
424 if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('[')(r) {
426 let (r2, _) = ws(r2)?;
427 let (r2, index) = parse_expr(r2)?;
428 let (r2, _) = ws(r2)?;
429 let (r2, _) = char(']')(r2)?;
430 expr = Expression::Index {
431 array: Box::new(expr),
432 index: Box::new(index),
433 };
434 rest = r2;
435 continue;
436 }
437
438 if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('.')(r) {
440 if let Ok((r3, field)) = identifier(r2) {
441 expr = Expression::Member {
443 object: Box::new(expr),
444 field: field.to_string(),
445 };
446 rest = r3;
447 continue;
448 }
449 }
450
451 if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("->")(r) {
453 let (r3, field) = identifier(r2)?;
454 expr = Expression::Member {
455 object: Box::new(expr),
456 field: field.to_string(),
457 };
458 rest = r3;
459 continue;
460 }
461
462 if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("++")(r) {
464 expr = Expression::Unary {
465 op: UnaryOp::PostInc,
466 expr: Box::new(expr),
467 };
468 rest = r2;
469 continue;
470 }
471
472 if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("--")(r) {
474 expr = Expression::Unary {
475 op: UnaryOp::PostDec,
476 expr: Box::new(expr),
477 };
478 rest = r2;
479 continue;
480 }
481
482 rest = r;
483 break;
484 }
485
486 Ok((rest, expr))
487}
488
489fn parse_unary(input: &str) -> IResult<&str, Expression> {
491 let (input, _) = ws(input)?;
492
493 if let Ok((rest, _)) = tag::<&str, &str, nom::error::Error<&str>>("++")(input) {
495 let (rest, expr) = parse_unary(rest)?;
496 return Ok((rest, Expression::Unary { op: UnaryOp::PreInc, expr: Box::new(expr) }));
497 }
498 if let Ok((rest, _)) = tag::<&str, &str, nom::error::Error<&str>>("--")(input) {
500 let (rest, expr) = parse_unary(rest)?;
501 return Ok((rest, Expression::Unary { op: UnaryOp::PreDec, expr: Box::new(expr) }));
502 }
503 if input.starts_with('-') && !input.starts_with("--") && !input.starts_with("->") {
505 let (rest, _) = char('-')(input)?;
506 let (rest, expr) = parse_unary(rest)?;
507 return Ok((rest, Expression::Unary { op: UnaryOp::Neg, expr: Box::new(expr) }));
508 }
509 if input.starts_with('!') && !input.starts_with("!=") {
511 let (rest, _) = char('!')(input)?;
512 let (rest, expr) = parse_unary(rest)?;
513 return Ok((rest, Expression::Unary { op: UnaryOp::Not, expr: Box::new(expr) }));
514 }
515 if input.starts_with('~') {
517 let (rest, _) = char('~')(input)?;
518 let (rest, expr) = parse_unary(rest)?;
519 return Ok((rest, Expression::Unary { op: UnaryOp::BitNot, expr: Box::new(expr) }));
520 }
521 if input.starts_with('*') && !input.starts_with("*=") {
523 let (rest, _) = char('*')(input)?;
524 let (rest, expr) = parse_unary(rest)?;
525 return Ok((rest, Expression::Unary { op: UnaryOp::Deref, expr: Box::new(expr) }));
526 }
527 if input.starts_with('&') && !input.starts_with("&&") && !input.starts_with("&=") {
529 let (rest, _) = char('&')(input)?;
530 let (rest, expr) = parse_unary(rest)?;
531 return Ok((rest, Expression::Unary { op: UnaryOp::AddrOf, expr: Box::new(expr) }));
532 }
533
534 parse_postfix(input)
535}
536
537fn parse_expr(input: &str) -> IResult<&str, Expression> {
539 parse_assignment(input)
540}
541
542fn parse_assignment(input: &str) -> IResult<&str, Expression> {
543 let (mut rest, mut left) = parse_ternary(input)?;
544
545 loop {
546 let (r, _) = ws(rest)?;
547
548 let compound_op: Option<(usize, BinaryOp)> = if r.starts_with("<<=") {
550 Some((3, BinaryOp::Shl))
551 } else if r.starts_with(">>=") {
552 Some((3, BinaryOp::Shr))
553 } else if r.starts_with("+=") {
554 Some((2, BinaryOp::Add))
555 } else if r.starts_with("-=") {
556 Some((2, BinaryOp::Sub))
557 } else if r.starts_with("*=") {
558 Some((2, BinaryOp::Mul))
559 } else if r.starts_with("/=") {
560 Some((2, BinaryOp::Div))
561 } else if r.starts_with("%=") {
562 Some((2, BinaryOp::Mod))
563 } else if r.starts_with("&=") {
564 Some((2, BinaryOp::And))
565 } else if r.starts_with("|=") {
566 Some((2, BinaryOp::Or))
567 } else if r.starts_with("^=") {
568 Some((2, BinaryOp::Xor))
569 } else {
570 None
571 };
572
573 if let Some((len, op)) = compound_op {
574 let r2 = &r[len..];
575 let (r2, right) = parse_assignment(r2)?;
576 left = Expression::Binary {
578 op: BinaryOp::Assign,
579 left: Box::new(left.clone()),
580 right: Box::new(Expression::Binary {
581 op,
582 left: Box::new(left),
583 right: Box::new(right),
584 }),
585 };
586 rest = r2;
587 continue;
588 }
589
590 if r.starts_with('=') && !r.starts_with("==") {
592 let r2 = &r[1..];
593 let (r2, right) = parse_assignment(r2)?;
594 left = Expression::Binary { op: BinaryOp::Assign, left: Box::new(left), right: Box::new(right) };
595 rest = r2;
596 continue;
597 }
598
599 break;
600 }
601
602 Ok((rest, left))
603}
604
605fn parse_ternary(input: &str) -> IResult<&str, Expression> {
606 let (rest, cond) = parse_logical_or(input)?;
607 let (r, _) = ws(rest)?;
608 if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('?')(r) {
609 let (r2, then_expr) = parse_expr(r2)?;
610 let (r2, _) = ws(r2)?;
611 let (r2, _) = char(':')(r2)?;
612 let (r2, else_expr) = parse_ternary(r2)?;
613 Ok((r2, Expression::Call {
617 name: "__ternary__".to_string(),
618 args: vec![cond, then_expr, else_expr],
619 }))
620 } else {
621 Ok((rest, cond))
622 }
623}
624
625fn parse_logical_or(input: &str) -> IResult<&str, Expression> {
628 let (mut rest, mut left) = parse_logical_and(input)?;
629 loop {
630 let (r, _) = ws(rest)?;
631 if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("||")(r) {
632 let (r2, right) = parse_logical_and(r2)?;
633 left = Expression::Binary { op: BinaryOp::LogicalOr, left: Box::new(left), right: Box::new(right) };
634 rest = r2;
635 } else {
636 rest = r;
637 break;
638 }
639 }
640 Ok((rest, left))
641}
642
643fn parse_logical_and(input: &str) -> IResult<&str, Expression> {
644 let (mut rest, mut left) = parse_bitwise_or(input)?;
645 loop {
646 let (r, _) = ws(rest)?;
647 if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("&&")(r) {
648 let (r2, right) = parse_bitwise_or(r2)?;
649 left = Expression::Binary { op: BinaryOp::LogicalAnd, left: Box::new(left), right: Box::new(right) };
650 rest = r2;
651 } else {
652 rest = r;
653 break;
654 }
655 }
656 Ok((rest, left))
657}
658
659fn parse_bitwise_or(input: &str) -> IResult<&str, Expression> {
660 let (mut rest, mut left) = parse_bitwise_xor(input)?;
661 loop {
662 let (r, _) = ws(rest)?;
663 if r.starts_with('|') && !r.starts_with("||") && !r.starts_with("|=") {
665 let (r2, _) = char('|')(r)?;
666 let (r2, right) = parse_bitwise_xor(r2)?;
667 left = Expression::Binary { op: BinaryOp::Or, left: Box::new(left), right: Box::new(right) };
668 rest = r2;
669 } else {
670 rest = r;
671 break;
672 }
673 }
674 Ok((rest, left))
675}
676
677fn parse_bitwise_xor(input: &str) -> IResult<&str, Expression> {
678 let (mut rest, mut left) = parse_bitwise_and(input)?;
679 loop {
680 let (r, _) = ws(rest)?;
681 if r.starts_with('^') && !r.starts_with("^=") {
682 let (r2, _) = char('^')(r)?;
683 let (r2, right) = parse_bitwise_and(r2)?;
684 left = Expression::Binary { op: BinaryOp::Xor, left: Box::new(left), right: Box::new(right) };
685 rest = r2;
686 } else {
687 rest = r;
688 break;
689 }
690 }
691 Ok((rest, left))
692}
693
694fn parse_bitwise_and(input: &str) -> IResult<&str, Expression> {
695 let (mut rest, mut left) = parse_equality(input)?;
696 loop {
697 let (r, _) = ws(rest)?;
698 if r.starts_with('&') && !r.starts_with("&&") && !r.starts_with("&=") {
700 let (r2, _) = char('&')(r)?;
701 let (r2, right) = parse_equality(r2)?;
702 left = Expression::Binary { op: BinaryOp::And, left: Box::new(left), right: Box::new(right) };
703 rest = r2;
704 } else {
705 rest = r;
706 break;
707 }
708 }
709 Ok((rest, left))
710}
711
712fn parse_equality(input: &str) -> IResult<&str, Expression> {
713 let (mut rest, mut left) = parse_relational(input)?;
714 loop {
715 let (r, _) = ws(rest)?;
716 if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("==")(r) {
717 let (r2, right) = parse_relational(r2)?;
718 left = Expression::Binary { op: BinaryOp::Eq, left: Box::new(left), right: Box::new(right) };
719 rest = r2;
720 } else if let Ok((r2, _)) = tag::<&str, &str, nom::error::Error<&str>>("!=")(r) {
721 let (r2, right) = parse_relational(r2)?;
722 left = Expression::Binary { op: BinaryOp::Ne, left: Box::new(left), right: Box::new(right) };
723 rest = r2;
724 } else {
725 rest = r;
726 break;
727 }
728 }
729 Ok((rest, left))
730}
731
732fn parse_relational(input: &str) -> IResult<&str, Expression> {
733 let (mut rest, mut left) = parse_shift(input)?;
734 loop {
735 let (r, _) = ws(rest)?;
736 if r.starts_with("<=") {
737 let r2 = &r[2..];
738 let (r2, right) = parse_shift(r2)?;
739 left = Expression::Binary { op: BinaryOp::Le, left: Box::new(left), right: Box::new(right) };
740 rest = r2;
741 } else if r.starts_with(">=") {
742 let r2 = &r[2..];
743 let (r2, right) = parse_shift(r2)?;
744 left = Expression::Binary { op: BinaryOp::Ge, left: Box::new(left), right: Box::new(right) };
745 rest = r2;
746 } else if r.starts_with('<') && !r.starts_with("<<") {
747 let r2 = &r[1..];
748 let (r2, right) = parse_shift(r2)?;
749 left = Expression::Binary { op: BinaryOp::Lt, left: Box::new(left), right: Box::new(right) };
750 rest = r2;
751 } else if r.starts_with('>') && !r.starts_with(">>") {
752 let r2 = &r[1..];
753 let (r2, right) = parse_shift(r2)?;
754 left = Expression::Binary { op: BinaryOp::Gt, left: Box::new(left), right: Box::new(right) };
755 rest = r2;
756 } else {
757 rest = r;
758 break;
759 }
760 }
761 Ok((rest, left))
762}
763
764fn parse_shift(input: &str) -> IResult<&str, Expression> {
765 let (mut rest, mut left) = parse_additive(input)?;
766 loop {
767 let (r, _) = ws(rest)?;
768 if r.starts_with("<<=") {
769 rest = r;
770 break;
771 } else if r.starts_with(">>=") {
772 rest = r;
773 break;
774 } else if r.starts_with("<<") {
775 let r2 = &r[2..];
776 let (r2, right) = parse_additive(r2)?;
777 left = Expression::Binary { op: BinaryOp::Shl, left: Box::new(left), right: Box::new(right) };
778 rest = r2;
779 } else if r.starts_with(">>") {
780 let r2 = &r[2..];
781 let (r2, right) = parse_additive(r2)?;
782 left = Expression::Binary { op: BinaryOp::Shr, left: Box::new(left), right: Box::new(right) };
783 rest = r2;
784 } else {
785 rest = r;
786 break;
787 }
788 }
789 Ok((rest, left))
790}
791
792fn parse_additive(input: &str) -> IResult<&str, Expression> {
793 let (mut rest, mut left) = parse_multiplicative(input)?;
794 loop {
795 let (r, _) = ws(rest)?;
796 if r.starts_with('+') && !r.starts_with("++") && !r.starts_with("+=") {
797 let (r2, _) = char('+')(r)?;
798 let (r2, right) = parse_multiplicative(r2)?;
799 left = Expression::Binary { op: BinaryOp::Add, left: Box::new(left), right: Box::new(right) };
800 rest = r2;
801 } else if r.starts_with('-') && !r.starts_with("--") && !r.starts_with("-=") && !r.starts_with("->") {
802 let (r2, _) = char('-')(r)?;
803 let (r2, right) = parse_multiplicative(r2)?;
804 left = Expression::Binary { op: BinaryOp::Sub, left: Box::new(left), right: Box::new(right) };
805 rest = r2;
806 } else {
807 rest = r;
808 break;
809 }
810 }
811 Ok((rest, left))
812}
813
814fn parse_multiplicative(input: &str) -> IResult<&str, Expression> {
815 let (mut rest, mut left) = parse_unary(input)?;
816 loop {
817 let (r, _) = ws(rest)?;
818 if r.starts_with('*') && !r.starts_with("*=") {
819 let (r2, _) = char('*')(r)?;
820 let (r2, right) = parse_unary(r2)?;
821 left = Expression::Binary { op: BinaryOp::Mul, left: Box::new(left), right: Box::new(right) };
822 rest = r2;
823 } else if r.starts_with('/') && !r.starts_with("/=") && !r.starts_with("//") && !r.starts_with("/*") {
824 let (r2, _) = char('/')(r)?;
825 let (r2, right) = parse_unary(r2)?;
826 left = Expression::Binary { op: BinaryOp::Div, left: Box::new(left), right: Box::new(right) };
827 rest = r2;
828 } else if r.starts_with('%') && !r.starts_with("%=") {
829 let (r2, _) = char('%')(r)?;
830 let (r2, right) = parse_unary(r2)?;
831 left = Expression::Binary { op: BinaryOp::Mod, left: Box::new(left), right: Box::new(right) };
832 rest = r2;
833 } else {
834 rest = r;
835 break;
836 }
837 }
838 Ok((rest, left))
839}
840
841fn parse_block(input: &str) -> IResult<&str, Block> {
846 let (input, _) = ws(input)?;
847 let (input, _) = char('{')(input)?;
848 let (input, stmts) = parse_statement_list(input)?;
849 let (input, _) = ws(input)?;
850 let (input, _) = char('}')(input)?;
851 Ok((input, Block { statements: stmts }))
852}
853
854fn parse_statement_list(input: &str) -> IResult<&str, Vec<Statement>> {
855 let mut stmts = Vec::new();
856 let mut rest = input;
857 loop {
858 let (r, _) = ws(rest)?;
859 if r.starts_with('}') || r.is_empty() {
860 rest = r;
861 break;
862 }
863 match parse_statement(r) {
864 Ok((r2, stmt)) => {
865 stmts.push(stmt);
866 rest = r2;
867 }
868 Err(_) => {
869 if let Some(pos) = r.find(|c: char| c == ';' || c == '}') {
871 if r.as_bytes()[pos] == b';' {
872 rest = &r[pos + 1..];
873 } else {
874 rest = &r[pos..];
875 }
876 } else {
877 break;
878 }
879 }
880 }
881 }
882 Ok((rest, stmts))
883}
884
885fn parse_statement(input: &str) -> IResult<&str, Statement> {
886 let (input, _) = ws(input)?;
887 alt((
888 parse_syncthreads_stmt,
889 parse_return_stmt,
890 parse_break_stmt,
891 parse_continue_stmt,
892 parse_if_stmt,
893 parse_for_stmt,
894 parse_while_stmt,
895 parse_do_while_stmt,
896 parse_block_stmt,
897 parse_var_decl_stmt,
898 parse_expr_stmt,
899 ))(input)
900}
901
902fn parse_syncthreads_stmt(input: &str) -> IResult<&str, Statement> {
903 let (input, _) = tag("__syncthreads")(input)?;
904 let (input, _) = ws(input)?;
905 let (input, _) = char('(')(input)?;
906 let (input, _) = ws(input)?;
907 let (input, _) = char(')')(input)?;
908 let (input, _) = ws(input)?;
909 let (input, _) = char(';')(input)?;
910 Ok((input, Statement::SyncThreads))
911}
912
913fn parse_return_stmt(input: &str) -> IResult<&str, Statement> {
914 let (input, _) = keyword("return")(input)?;
915 let (input, _) = ws(input)?;
916 if let Ok((rest, _)) = char::<&str, nom::error::Error<&str>>(';')(input) {
917 return Ok((rest, Statement::Return(None)));
918 }
919 let (input, expr) = parse_expr(input)?;
920 let (input, _) = ws(input)?;
921 let (input, _) = char(';')(input)?;
922 Ok((input, Statement::Return(Some(expr))))
923}
924
925fn parse_break_stmt(input: &str) -> IResult<&str, Statement> {
926 let (input, _) = keyword("break")(input)?;
927 let (input, _) = ws(input)?;
928 let (input, _) = char(';')(input)?;
929 Ok((input, Statement::Break))
930}
931
932fn parse_continue_stmt(input: &str) -> IResult<&str, Statement> {
933 let (input, _) = keyword("continue")(input)?;
934 let (input, _) = ws(input)?;
935 let (input, _) = char(';')(input)?;
936 Ok((input, Statement::Continue))
937}
938
939fn parse_if_stmt(input: &str) -> IResult<&str, Statement> {
940 let (input, _) = keyword("if")(input)?;
941 let (input, _) = ws(input)?;
942 let (input, _) = char('(')(input)?;
943 let (input, condition) = parse_expr(input)?;
944 let (input, _) = ws(input)?;
945 let (input, _) = char(')')(input)?;
946 let (input, then_branch) = parse_statement(input)?;
947 let (input, _) = ws(input)?;
948 let (input, else_branch) = opt(preceded(
949 pair(keyword("else"), ws),
950 parse_statement,
951 ))(input)?;
952
953 Ok((input, Statement::If {
954 condition,
955 then_branch: Box::new(then_branch),
956 else_branch: else_branch.map(Box::new),
957 }))
958}
959
960fn parse_for_stmt(input: &str) -> IResult<&str, Statement> {
961 let (input, _) = keyword("for")(input)?;
962 let (input, _) = ws(input)?;
963 let (input, _) = char('(')(input)?;
964
965 let (input, _) = ws(input)?;
967 let (input, init) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(';')(input) {
968 (r, None)
969 } else if let Ok((r, stmt)) = parse_var_decl_stmt(input) {
970 (r, Some(Box::new(stmt)))
972 } else {
973 let (r, expr) = parse_expr(input)?;
974 let (r, _) = ws(r)?;
975 let (r, _) = char(';')(r)?;
976 (r, Some(Box::new(Statement::Expr(expr))))
977 };
978
979 let (input, _) = ws(input)?;
981 let (input, condition) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(';')(input) {
982 (r, None)
983 } else {
984 let (r, expr) = parse_expr(input)?;
985 let (r, _) = ws(r)?;
986 let (r, _) = char(';')(r)?;
987 (r, Some(expr))
988 };
989
990 let (input, _) = ws(input)?;
992 let (input, update) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(')')(input) {
993 (r, None)
994 } else {
995 let (r, expr) = parse_expr(input)?;
996 let (r, _) = ws(r)?;
997 let (r, _) = char(')')(r)?;
998 (r, Some(expr))
999 };
1000
1001 let (input, body) = parse_statement(input)?;
1002
1003 Ok((input, Statement::For {
1004 init,
1005 condition,
1006 update,
1007 body: Box::new(body),
1008 }))
1009}
1010
1011fn parse_while_stmt(input: &str) -> IResult<&str, Statement> {
1012 let (input, _) = keyword("while")(input)?;
1013 let (input, _) = ws(input)?;
1014 let (input, _) = char('(')(input)?;
1015 let (input, condition) = parse_expr(input)?;
1016 let (input, _) = ws(input)?;
1017 let (input, _) = char(')')(input)?;
1018 let (input, body) = parse_statement(input)?;
1019
1020 Ok((input, Statement::While {
1021 condition,
1022 body: Box::new(body),
1023 }))
1024}
1025
1026fn parse_do_while_stmt(input: &str) -> IResult<&str, Statement> {
1027 let (input, _) = keyword("do")(input)?;
1028 let (input, body) = parse_statement(input)?;
1029 let (input, _) = ws(input)?;
1030 let (input, _) = keyword("while")(input)?;
1031 let (input, _) = ws(input)?;
1032 let (input, _) = char('(')(input)?;
1033 let (input, condition) = parse_expr(input)?;
1034 let (input, _) = ws(input)?;
1035 let (input, _) = char(')')(input)?;
1036 let (input, _) = ws(input)?;
1037 let (input, _) = char(';')(input)?;
1038
1039 Ok((input, Statement::While {
1040 condition,
1041 body: Box::new(body),
1042 }))
1043}
1044
1045fn parse_block_stmt(input: &str) -> IResult<&str, Statement> {
1046 let (input, block) = parse_block(input)?;
1047 Ok((input, Statement::Block(block)))
1048}
1049
1050fn parse_var_decl_stmt(input: &str) -> IResult<&str, Statement> {
1057 let (input, _) = ws(input)?;
1058
1059 let mut storage = StorageClass::Auto;
1061 let mut rest = input;
1062 let mut has_extern = false;
1063
1064 if let Ok((r, _)) = keyword("extern")(rest) {
1066 has_extern = true;
1067 rest = r;
1068 let (r, _) = ws(rest)?;
1069 rest = r;
1070 }
1071
1072 if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__shared__")(rest) {
1074 storage = StorageClass::Shared;
1075 rest = r;
1076 } else if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__constant__")(rest) {
1077 storage = StorageClass::Constant;
1078 rest = r;
1079 } else if let Ok((r, _)) = keyword("register")(rest) {
1080 storage = StorageClass::Register;
1081 rest = r;
1082 } else if let Ok((r, _)) = keyword("static")(rest) {
1083 rest = r;
1085 } else if has_extern {
1086 return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
1088 }
1089
1090 let (rest, (mut ty, qualifiers)) = parse_type(rest)?;
1092 let (rest, _) = ws(rest)?;
1093
1094 let (rest, name) = identifier(rest)?;
1096
1097 let kw_set = ["if", "else", "for", "while", "do", "return", "break", "continue",
1099 "switch", "case", "default", "goto", "__syncthreads"];
1100 if kw_set.contains(&name) {
1101 return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
1102 }
1103
1104 let (rest, _) = ws(rest)?;
1105
1106 if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('[')(rest) {
1108 let (r, _) = ws(r)?;
1109 if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(']')(r) {
1110 ty = Type::Array(Box::new(ty), None);
1111 let (r, _) = ws(r)?;
1112 let mut r = r;
1114 while let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('[')(r) {
1115 let (r2, _) = ws(r2)?;
1116 let (r2, size_expr) = parse_expr(r2)?;
1117 let (r2, _) = ws(r2)?;
1118 let (r2, _) = char(']')(r2)?;
1119 let (r2, _) = ws(r2)?;
1120 r = r2;
1121 }
1124 let (r, _) = ws(r)?;
1125 let (r, _) = char(';')(r)?;
1126 return Ok((r, Statement::VarDecl {
1127 name: name.to_string(),
1128 ty,
1129 init: None,
1130 storage,
1131 }));
1132 } else {
1133 let (r, size_expr) = parse_expr(r)?;
1135 let (r, _) = ws(r)?;
1136 let (r, _) = char(']')(r)?;
1137 let size = if let Expression::Literal(Literal::Int(n)) = &size_expr {
1138 Some(*n as usize)
1139 } else {
1140 None
1141 };
1142 ty = Type::Array(Box::new(ty), size);
1143 let mut r = r;
1144 let (r2, _) = ws(r)?;
1145 while let Ok((r3, _)) = char::<&str, nom::error::Error<&str>>('[')(r2) {
1147 let (r3, _) = ws(r3)?;
1148 let (r3, _size2) = parse_expr(r3)?;
1149 let (r3, _) = ws(r3)?;
1150 let (r3, _) = char(']')(r3)?;
1151 r = r3;
1152 let (r4, _) = ws(r)?;
1153 if r4.starts_with('[') {
1155 continue;
1156 }
1157 break;
1158 }
1159 let (r, _) = ws(r)?;
1160 let (r, init) = if let Ok((r2, _)) = char::<&str, nom::error::Error<&str>>('=')(r) {
1162 let (r2, expr) = parse_expr(r2)?;
1163 (r2, Some(expr))
1164 } else {
1165 (r, None)
1166 };
1167 let (r, _) = ws(r)?;
1168 let (r, _) = char(';')(r)?;
1169 return Ok((r, Statement::VarDecl {
1170 name: name.to_string(),
1171 ty,
1172 init,
1173 storage,
1174 }));
1175 }
1176 }
1177
1178 let (rest, init) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('=')(rest) {
1180 let (r, expr) = parse_expr(r)?;
1181 (r, Some(expr))
1182 } else {
1183 (rest, None)
1184 };
1185
1186 let (rest, _) = ws(rest)?;
1187 let (rest, _) = char(';')(rest)?;
1188
1189 Ok((rest, Statement::VarDecl {
1190 name: name.to_string(),
1191 ty,
1192 init,
1193 storage,
1194 }))
1195}
1196
1197fn parse_expr_stmt(input: &str) -> IResult<&str, Statement> {
1198 let (input, expr) = parse_expr(input)?;
1199 let (input, _) = ws(input)?;
1200 let (input, _) = char(';')(input)?;
1201
1202 if let Expression::Call { ref name, ref args } = expr {
1204 if name == "__syncthreads" && args.is_empty() {
1205 return Ok((input, Statement::SyncThreads));
1206 }
1207 }
1208
1209 Ok((input, Statement::Expr(expr)))
1210}
1211
1212fn parse_parameter(input: &str) -> IResult<&str, Parameter> {
1217 let (input, _) = ws(input)?;
1218 let (input, (ty, qualifiers)) = parse_type(input)?;
1219 let (input, _) = ws(input)?;
1220 let (input, name) = identifier(input)?;
1221 let (input, _) = ws(input)?;
1223 let (input, _ty) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('[')(input) {
1224 let (r, _) = ws(r)?;
1225 if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(']')(r) {
1226 (r, Type::Pointer(Box::new(ty.clone())))
1227 } else {
1228 let (r, _) = parse_expr(r)?;
1229 let (r, _) = ws(r)?;
1230 let (r, _) = char(']')(r)?;
1231 (r, Type::Pointer(Box::new(ty.clone())))
1232 }
1233 } else {
1234 (input, ty.clone())
1235 };
1236
1237 Ok((input, Parameter {
1238 name: name.to_string(),
1239 ty: _ty,
1240 qualifiers,
1241 }))
1242}
1243
1244fn parse_param_list(input: &str) -> IResult<&str, Vec<Parameter>> {
1245 let (input, _) = ws(input)?;
1246 let (input, _) = char('(')(input)?;
1247 let (input, _) = ws(input)?;
1248 if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(')')(input) {
1250 return Ok((r, vec![]));
1251 }
1252 if let Ok((r, _)) = keyword("void")(input) {
1253 let (r, _) = ws(r)?;
1254 if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>(')')(r) {
1255 return Ok((r, vec![]));
1256 }
1257 }
1258 let (input, params) = separated_list0(
1259 delimited(ws, char(','), ws),
1260 parse_parameter,
1261 )(input)?;
1262 let (input, _) = ws(input)?;
1263 let (input, _) = char(')')(input)?;
1264 Ok((input, params))
1265}
1266
1267fn parse_kernel_def(input: &str) -> IResult<&str, Item> {
1269 let (input, _) = ws(input)?;
1270 let input = skip_template(input);
1272 let (input, _) = ws(input)?;
1273 let (input, _) = tag("__global__")(input)?;
1274 let (input, _) = ws(input)?;
1275 let (input, _) = keyword("void")(input)?;
1276 let (input, _) = ws(input)?;
1277 let (input, name) = identifier(input)?;
1278 let (input, params) = parse_param_list(input)?;
1279 let (input, body) = parse_block(input)?;
1280
1281 Ok((input, Item::Kernel(KernelDef {
1282 name: name.to_string(),
1283 params,
1284 body,
1285 attributes: vec![],
1286 })))
1287}
1288
1289fn parse_device_function(input: &str) -> IResult<&str, Item> {
1291 let (input, _) = ws(input)?;
1292 let input = skip_template(input);
1293 let (input, _) = ws(input)?;
1294 let (input, _) = tag("__device__")(input)?;
1295 let (input, _) = ws(input)?;
1296 let (input, also_host) = opt(preceded(tag("__host__"), ws))(input)?;
1298 let (input, _) = opt(preceded(tag("__forceinline__"), ws))(input)?;
1300 let (input, _) = opt(preceded(keyword("inline"), ws))(input)?;
1301 let (input, (ret_ty, _)) = parse_type(input)?;
1302 let (input, _) = ws(input)?;
1303 let (input, name) = identifier(input)?;
1304 let (input, params) = parse_param_list(input)?;
1305 let (input, body) = parse_block(input)?;
1306
1307 let mut qualifiers = vec![FunctionQualifier::Device];
1308 if also_host.is_some() {
1309 qualifiers.push(FunctionQualifier::Host);
1310 }
1311
1312 Ok((input, Item::DeviceFunction(FunctionDef {
1313 name: name.to_string(),
1314 return_type: ret_ty,
1315 params,
1316 body,
1317 qualifiers,
1318 })))
1319}
1320
1321fn parse_host_function(input: &str) -> IResult<&str, Item> {
1323 let (input, _) = ws(input)?;
1324 let input = skip_template(input);
1325 let (input, _) = ws(input)?;
1326 let (input, _) = tag("__host__")(input)?;
1327 let (input, _) = ws(input)?;
1328 let (input, also_device) = opt(preceded(tag("__device__"), ws))(input)?;
1330 let (input, _) = opt(preceded(keyword("inline"), ws))(input)?;
1331 let (input, (ret_ty, _)) = parse_type(input)?;
1332 let (input, _) = ws(input)?;
1333 let (input, name) = identifier(input)?;
1334 let (input, params) = parse_param_list(input)?;
1335 let (input, body) = parse_block(input)?;
1336
1337 let mut qualifiers = vec![FunctionQualifier::Host];
1338 if also_device.is_some() {
1339 qualifiers.push(FunctionQualifier::Device);
1340 }
1341
1342 Ok((input, Item::HostFunction(FunctionDef {
1343 name: name.to_string(),
1344 return_type: ret_ty,
1345 params,
1346 body,
1347 qualifiers,
1348 })))
1349}
1350
1351fn parse_include(input: &str) -> IResult<&str, Item> {
1353 let (input, _) = ws(input)?;
1354 let (input, _) = char('#')(input)?;
1355 let (input, _) = ws(input)?;
1356 let (input, _) = tag("include")(input)?;
1357 let (input, _) = take_while(|c: char| c == ' ' || c == '\t')(input)?;
1358 let (input, path) = alt((
1360 delimited(char('<'), take_until(">"), char('>')),
1361 delimited(char('"'), take_until("\""), char('"')),
1362 ))(input)?;
1363 let (input, _) = take_while(|c: char| c != '\n')(input)?;
1365 Ok((input, Item::Include(path.to_string())))
1366}
1367
1368fn parse_typedef(input: &str) -> IResult<&str, Item> {
1370 let (input, _) = ws(input)?;
1371 let (input, _) = keyword("typedef")(input)?;
1372 let (input, _) = ws(input)?;
1373
1374 if let Ok((r, _)) = keyword("struct")(input) {
1376 let (r, _) = ws(r)?;
1377 let (r, _struct_name) = opt(identifier)(r)?;
1379 let (r, _) = ws(r)?;
1380 if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('{')(r) {
1382 let r = skip_balanced_braces(r);
1383 let (r, _) = ws(r)?;
1384 let (r, name) = identifier(r)?;
1385 let (r, _) = ws(r)?;
1386 let (r, _) = char(';')(r)?;
1387 return Ok((r, Item::TypeDef(TypeDef {
1388 name: name.to_string(),
1389 ty: Type::Named(name.to_string()),
1390 })));
1391 }
1392 }
1393
1394 let (input, (ty, _)) = parse_type(input)?;
1395 let (input, _) = ws(input)?;
1396 let (input, name) = identifier(input)?;
1397 let (input, _) = ws(input)?;
1398 let (input, _) = char(';')(input)?;
1399 Ok((input, Item::TypeDef(TypeDef {
1400 name: name.to_string(),
1401 ty,
1402 })))
1403}
1404
1405fn parse_struct(input: &str) -> IResult<&str, Item> {
1407 let (input, _) = ws(input)?;
1408 let (input, _) = keyword("struct")(input)?;
1409 let (input, _) = ws(input)?;
1410 let (input, name) = identifier(input)?;
1411 let (input, _) = ws(input)?;
1412 let (input, _) = char('{')(input)?;
1413 let rest = skip_balanced_braces(input);
1414 let (rest, _) = ws(rest)?;
1415 let (rest, _) = char(';')(rest)?;
1416 Ok((rest, Item::TypeDef(TypeDef {
1417 name: name.to_string(),
1418 ty: Type::Named(name.to_string()),
1419 })))
1420}
1421
/// Skips a leading `template <...>` header, returning the text after the
/// matching `>`.
///
/// Nested angle brackets are balanced. If `input` does not start (after
/// leading whitespace) with `template` followed by `<`, or the brackets never
/// close, `input` is returned unchanged, including its leading whitespace.
fn skip_template(input: &str) -> &str {
    let after_kw = match input.trim_start().strip_prefix("template") {
        Some(tail) => tail.trim_start(),
        None => return input,
    };
    if !after_kw.starts_with('<') {
        return input;
    }

    let mut depth = 0i32;
    for (idx, ch) in after_kw.char_indices() {
        if ch == '<' {
            depth += 1;
        } else if ch == '>' {
            depth -= 1;
            if depth == 0 {
                return &after_kw[idx + 1..];
            }
        }
    }
    input
}
1448
/// Given text just after an opening `{`, returns the text after its matching
/// `}`.
///
/// Nested braces are balanced. Braces inside string literals or comments are
/// not special-cased. If no matching `}` exists, `input` is returned unchanged.
fn skip_balanced_braces(input: &str) -> &str {
    // Braces are ASCII, so their byte offsets are always valid char boundaries.
    let mut depth = 1usize;
    for (pos, byte) in input.bytes().enumerate() {
        match byte {
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    return &input[pos + 1..];
                }
            }
            _ => (),
        }
    }
    input
}
1466
1467fn parse_preprocessor(input: &str) -> IResult<&str, ()> {
1469 let (input, _) = ws(input)?;
1470 let (input, _) = char('#')(input)?;
1471 let (input, _) = take_while(|c: char| c != '\n')(input)?;
1472 Ok((input, ()))
1473}
1474
1475fn parse_global_var_decl(input: &str) -> IResult<&str, Item> {
1481 let (input, _) = ws(input)?;
1482 let rest = input;
1484 let storage;
1485 let rest = if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__constant__")(rest) {
1486 storage = StorageClass::Constant;
1487 r
1488 } else if let Ok((r, _)) = tag::<&str, &str, nom::error::Error<&str>>("__shared__")(rest) {
1489 storage = StorageClass::Shared;
1490 r
1491 } else {
1492 return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Tag)));
1493 };
1494 let (rest, _) = ws(rest)?;
1495 let (rest, (ty, _qualifiers)) = parse_type(rest)?;
1496 let (rest, _) = ws(rest)?;
1497 let (rest, name) = identifier(rest)?;
1498 let (rest, _) = ws(rest)?;
1499 let (rest, ty) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('[')(rest) {
1501 let (r, _) = ws(r)?;
1502 let (r, size_expr) = parse_expr(r)?;
1503 let (r, _) = ws(r)?;
1504 let (r, _) = char(']')(r)?;
1505 let size = if let Expression::Literal(Literal::Int(n)) = &size_expr {
1506 Some(*n as usize)
1507 } else {
1508 None
1509 };
1510 (r, Type::Array(Box::new(ty), size))
1511 } else {
1512 (rest, ty)
1513 };
1514 let (rest, _) = ws(rest)?;
1515 let (rest, init) = if let Ok((r, _)) = char::<&str, nom::error::Error<&str>>('=')(rest) {
1517 let (r, _) = ws(r)?;
1518 if r.starts_with('{') {
1520 let end = r.find('}').unwrap_or(r.len() - 1);
1522 let r = &r[end + 1..];
1523 (r, None) } else {
1525 let (r, expr) = parse_expr(r)?;
1526 (r, Some(expr))
1527 }
1528 } else {
1529 (rest, None)
1530 };
1531 let (rest, _) = ws(rest)?;
1532 let (rest, _) = char(';')(rest)?;
1533 Ok((rest, Item::GlobalVar(GlobalVar {
1534 name: name.to_string(),
1535 ty,
1536 storage,
1537 init,
1538 })))
1539}
1540
1541fn parse_top_level_item(input: &str) -> IResult<&str, Option<Item>> {
1542 let (input, _) = ws(input)?;
1543 if input.is_empty() {
1544 return Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Eof)));
1545 }
1546
1547 if let Ok((r, item)) = parse_include(input) {
1549 return Ok((r, Some(item)));
1550 }
1551 if let Ok((r, item)) = parse_global_var_decl(input) {
1553 return Ok((r, Some(item)));
1554 }
1555 if let Ok((r, item)) = parse_kernel_def(input) {
1556 return Ok((r, Some(item)));
1557 }
1558 if let Ok((r, item)) = parse_device_function(input) {
1559 return Ok((r, Some(item)));
1560 }
1561 if let Ok((r, item)) = parse_host_function(input) {
1562 return Ok((r, Some(item)));
1563 }
1564 if let Ok((r, item)) = parse_typedef(input) {
1565 return Ok((r, Some(item)));
1566 }
1567 if let Ok((r, item)) = parse_struct(input) {
1568 return Ok((r, Some(item)));
1569 }
1570
1571 if input.starts_with('#') {
1573 let (r, _) = parse_preprocessor(input)?;
1574 return Ok((r, None));
1575 }
1576
1577 if let Some(pos) = input.find(|c: char| c == '{' || c == ';') {
1580 if input.as_bytes()[pos] == b'{' {
1581 let rest = skip_balanced_braces(&input[pos + 1..]);
1582 return Ok((rest, None));
1583 } else {
1584 return Ok((&input[pos + 1..], None));
1585 }
1586 }
1587
1588 Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Eof)))
1589}
1590
/// Entry point for turning CUDA source text into an [`Ast`].
///
/// The parser currently carries no state. Construct it with
/// [`CudaParser::new`] or [`Default::default`].
pub struct CudaParser {}
1599
1600impl CudaParser {
1601 pub fn new() -> Self {
1603 Self {}
1604 }
1605
1606 pub fn parse(&self, source: &str) -> Result<Ast> {
1608 let mut items = Vec::new();
1609 let mut rest = source;
1610
1611 loop {
1612 match ws(rest) {
1614 Ok((r, _)) => rest = r,
1615 Err(_) => break,
1616 }
1617 if rest.is_empty() {
1618 break;
1619 }
1620
1621 match parse_top_level_item(rest) {
1622 Ok((r, Some(item))) => {
1623 items.push(item);
1624 rest = r;
1625 }
1626 Ok((r, None)) => {
1627 rest = r;
1629 }
1630 Err(_) => {
1631 if rest.is_empty() {
1633 break;
1634 }
1635 if let Some(pos) = rest[1..].find(|c: char| {
1637 c == '#' || c == '_' || c.is_alphabetic()
1638 }) {
1639 rest = &rest[pos + 1..];
1640 } else {
1641 break;
1642 }
1643 }
1644 }
1645 }
1646
1647 Ok(Ast { items })
1648 }
1649}
1650
1651impl Default for CudaParser {
1652 fn default() -> Self {
1653 Self::new()
1654 }
1655}
1656
// Unit tests: end-to-end parses of small CUDA programs through `CudaParser`,
// plus direct acceptance checks of `parse_expr` and `parse_type`.
#[cfg(test)]
mod tests {
    use super::*;

    /// One kernel: checks its name, its four parameters (first and last by
    /// name), and that the body has two top-level statements (a declaration
    /// and an `if`).
    #[test]
    fn test_vector_add() {
        let src = r#"
__global__ void vectorAdd(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        assert_eq!(ast.items.len(), 1);
        if let Item::Kernel(ref k) = ast.items[0] {
            assert_eq!(k.name, "vectorAdd");
            assert_eq!(k.params.len(), 4);
            assert_eq!(k.params[0].name, "a");
            assert_eq!(k.params[3].name, "n");
            assert_eq!(k.body.statements.len(), 2);
        } else {
            panic!("Expected kernel");
        }
    }

    /// Tiled matrix multiply: exercises 2-D `__shared__` arrays, nested `for`
    /// loops, multi-index assignments and `__syncthreads()`. Only the kernel
    /// name and parameter count are asserted.
    #[test]
    fn test_mat_mul() {
        let src = r#"
__global__ void matMul(float* A, float* B, float* C, int M, int N, int K) {
    __shared__ float sA[16][16];
    __shared__ float sB[16][16];
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    float sum = 0.0f;
    for (int t = 0; t < (K + 15) / 16; t++) {
        sA[threadIdx.y][threadIdx.x] = A[row * K + t * 16 + threadIdx.x];
        sB[threadIdx.y][threadIdx.x] = B[(t * 16 + threadIdx.y) * N + col];
        __syncthreads();
        for (int k = 0; k < 16; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }
        __syncthreads();
    }
    C[row * N + col] = sum;
}
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        assert_eq!(ast.items.len(), 1);
        if let Item::Kernel(ref k) = ast.items[0] {
            assert_eq!(k.name, "matMul");
            assert_eq!(k.params.len(), 6);
        } else {
            panic!("Expected kernel");
        }
    }

    /// Shared-memory reduction: exercises `extern __shared__` with `[]`,
    /// `unsigned int`, the ternary operator, `>>=`, and an unbraced `if` body.
    #[test]
    fn test_reduce() {
        let src = r#"
__global__ void reduce(float* input, float* output, int n) {
    extern __shared__ float sdata[];
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    sdata[tid] = (i < n) ? input[i] : 0.0f;
    __syncthreads();
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) output[blockIdx.x] = sdata[0];
}
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        assert_eq!(ast.items.len(), 1);
        if let Item::Kernel(ref k) = ast.items[0] {
            assert_eq!(k.name, "reduce");
            assert_eq!(k.params.len(), 3);
        } else {
            panic!("Expected kernel");
        }
    }

    /// Both `#include` spellings (angle brackets and quotes) become
    /// `Item::Include` with the delimiters stripped, in source order.
    #[test]
    fn test_include_directive() {
        let src = r#"#include <cuda_runtime.h>
#include "myheader.h"
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        assert!(ast.items.len() >= 2);
        assert!(matches!(&ast.items[0], Item::Include(p) if p == "cuda_runtime.h"));
        assert!(matches!(&ast.items[1], Item::Include(p) if p == "myheader.h"));
    }

    /// Two consecutive kernels each produce their own item.
    #[test]
    fn test_multiple_kernels() {
        let src = r#"
__global__ void kernel1(int* a) {
    a[threadIdx.x] = 0;
}
__global__ void kernel2(float* b, int n) {
    int i = threadIdx.x;
    if (i < n) b[i] = 1.0f;
}
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        assert_eq!(ast.items.len(), 2);
    }

    /// A `__device__` function is classified as `Item::DeviceFunction`.
    #[test]
    fn test_device_function() {
        let src = r#"
__device__ float clamp(float x, float lo, float hi) {
    if (x < lo) return lo;
    if (x > hi) return hi;
    return x;
}
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        assert_eq!(ast.items.len(), 1);
        assert!(matches!(&ast.items[0], Item::DeviceFunction(_)));
    }

    /// `parse_expr` accepts arithmetic, indexing, member access, casts,
    /// logical operators, pre/post increment, and calls with address-of.
    /// Only success is checked, not the produced tree.
    #[test]
    fn test_expressions() {
        assert!(parse_expr("a + b * c").is_ok());
        assert!(parse_expr("a[i]").is_ok());
        assert!(parse_expr("threadIdx.x").is_ok());
        assert!(parse_expr("blockIdx.x * blockDim.x + threadIdx.x").is_ok());
        assert!(parse_expr("(float)x").is_ok());
        assert!(parse_expr("a < b && c > d").is_ok());
        assert!(parse_expr("i++").is_ok());
        assert!(parse_expr("++i").is_ok());
        assert!(parse_expr("atomicAdd(&x, 1)").is_ok());
    }

    /// `parse_type` accepts pointers, `const`, `unsigned`, scalar and vector
    /// types. Only success is checked.
    #[test]
    fn test_type_parsing() {
        assert!(parse_type("float*").is_ok());
        assert!(parse_type("const float*").is_ok());
        assert!(parse_type("unsigned int").is_ok());
        assert!(parse_type("int").is_ok());
        assert!(parse_type("float4").is_ok());
        assert!(parse_type("double").is_ok());
    }
}