/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! DSL syntax validation for GPU kernel entry points.
//!
//! This module validates that kernel functions follow the restrictions and requirements
//! of the cuTile Rust DSL. It ensures type safety and prevents unsupported patterns that
//! would fail during MLIR compilation or GPU execution.
//!
//! ## Validation Rules
//!
//! ### Parameter Types
//!
//! Kernel entry points may only use the following parameter types:
//!
//! - **Scalars** - Primitive types like `i32`, `f32`, etc.
//! - **`&Tensor<T, S>`** - Immutable tensor references (read-only access)
//! - **`&mut Tensor<T, S>`** - Mutable tensor references (partitioned tensors)
//! - **`*mut T`** - Raw pointers (for unsafe kernels only)
//!
//! ### Disallowed Patterns
//!
//! - **Owned tensors** - Cannot move tensors into kernels
//! - **Arbitrary references** - Only `&Tensor` and `&mut Tensor` are supported
//! - **Complex types** - No user-defined structs (except DSL types)
//! - **Closures** - No closure parameters
//!
//! ## Examples
//!
//! ### Valid Entry Points
//!
//! ```rust,ignore
//! #[cutile::entry]
//! fn valid_kernel<const N: i32>(
//!     output: &mut Tensor<f32, {[N]}>,  // ✓ Mutable tensor (partitioned)
//!     input: &Tensor<f32, {[-1]}>,      // ✓ Immutable tensor
//!     scalar: f32,                      // ✓ Scalar parameter
//! ) { }
//! ```
//!
//! ### Invalid Entry Points
//!
//! ```rust,ignore
//! #[cutile::entry]
//! fn invalid_kernel(
//!     owned: Tensor<f32, {[128]}>,  // ✗ Cannot move tensors
//!     vec_ref: &Vec<f32>,           // ✗ Only &Tensor references allowed
//! ) { }
//! ```
//!
//! ## Error Messages
//!
//! When validation fails, the validator reports an error that:
//!
//! - Explains why a parameter type is not supported
//! - Suggests correct usage (e.g., using `&mut Tensor` for partitioned tensors)
//! - Points to the specific parameter that caused the error
use ;
use get_ptr_type;
use ToTokens;
use ;
use crate;
/// Validates that kernel entry point parameters follow DSL restrictions.
///
/// This function checks each parameter in a kernel function signature to ensure it uses
/// only supported types. It enforces the safety guarantees of the cuTile Rust DSL.
///
/// ## Supported Parameter Types
///
/// - **Scalars**: `i32`, `f32`, `bool`, etc.
/// - **Tensor references**: `&Tensor<T, S>` (immutable) or `&mut Tensor<T, S>` (mutable/partitioned)
/// - **Raw pointers**: `*mut T` (unsafe kernels only)
///
/// ## Validation Logic
///
/// For each parameter:
/// 1. **References** - Must be `&Tensor` or `&mut Tensor`
/// 2. **Path types** - Disallows owned `Tensor` (suggests using references)
/// 3. **Pointers** - Validates pointer type is supported
/// 4. **Other types** - Assumed to be scalars (validated elsewhere)
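///
/// The four rules above can be sketched with the standard library alone. This is
/// a minimal illustration, not this module's implementation: `ParamTy` and
/// `validate_param` are hypothetical names, and the real validator inspects the
/// parsed function signature rather than a hand-built enum. Rejecting `*const T`
/// is an assumption, made because only `*mut T` is listed as supported.
///
/// ```rust
/// // Hypothetical stand-in for a parsed parameter type.
/// enum ParamTy {
///     /// `&T` or `&mut T`; `target` is the referenced type's name.
///     Ref { target: &'static str },
///     /// An owned path type such as `f32` or `Tensor`.
///     Path { name: &'static str },
///     /// A raw pointer: `*mut T` when `mutable`, `*const T` otherwise.
///     Ptr { mutable: bool },
///     /// Anything else; treated as a scalar and checked elsewhere.
///     Other,
/// }
///
/// // Hypothetical checker mirroring rules 1-4 in order.
/// fn validate_param(ty: &ParamTy) -> Result<(), String> {
///     match ty {
///         // Rule 1: references must point at `Tensor`.
///         ParamTy::Ref { target } if *target == "Tensor" => Ok(()),
///         ParamTy::Ref { target } => Err(format!(
///             "Only &Tensor and &mut Tensor references are supported, found a reference to {}.",
///             target
///         )),
///         // Rule 2: owned tensors are rejected with a suggestion.
///         ParamTy::Path { name } if *name == "Tensor" => Err(String::from(
///             "Tensors cannot be moved into kernel functions. Use &mut Tensor for \
///              partitioned tensors or &Tensor for tensor references.",
///         )),
///         // Rule 3 (assumed): only `*mut T` counts as a supported pointer.
///         ParamTy::Ptr { mutable: false } => {
///             Err(String::from("Only *mut T raw pointers are supported."))
///         }
///         // Rule 4: remaining types are assumed to be scalars.
///         ParamTy::Path { .. } | ParamTy::Ptr { mutable: true } | ParamTy::Other => Ok(()),
///     }
/// }
/// ```
///
/// In this sketch, `validate_param(&ParamTy::Path { name: "Tensor" })` yields the
/// "cannot be moved" error, while a `Ref` whose target is `Tensor` is accepted.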
///
/// ## Errors
///
/// Returns an `Error` if an unsupported parameter type is encountered.
/// The error message includes the problematic type and suggestions for fixing it.
///
/// ## Examples
///
/// ```rust,ignore
/// // This would pass validation
/// fn valid_kernel(x: &mut Tensor<f32, {[128]}>, y: &Tensor<f32, {[-1]}>) { }
///
/// // This would fail validation with a helpful error message
/// fn invalid_kernel(x: Tensor<f32, {[128]}>) { }
/// // Error: "Tensors cannot be moved into kernel functions. Use &mut Tensor for
/// // partitioned tensors or &Tensor for tensor references."
/// ```
// Ensure only valid parameters have been specified in function signatures.
// Currently supports only scalars, `&Tensor`, and `&mut Tensor` for safe kernels,
// and `*mut T` for unsafe kernels.
// TODO (hme): Implement a comprehensive validation pass on entire module.