1use std::collections::HashSet;
6
7use crate::{Block, BlockId, IrModule, IrNode, IrType, Terminator, ValueId};
8
/// How thoroughly a [`Validator`] inspects a module.
///
/// Variants are ordered from least to most thorough; the derived `Ord` relies on the
/// declaration order below, so `None < Basic < Full < Strict`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum ValidationLevel {
    /// Skip validation entirely; an empty result is returned.
    None,
    /// Structural checks: entry block present, blocks terminated, and referenced
    /// values/blocks defined.
    Basic,
    /// Everything in `Basic`, plus operand type checks.
    Full,
    /// Most thorough level; in the visible validator it enables the same checks as `Full`.
    Strict,
}
21
22#[derive(Debug, Clone)]
24pub struct ValidationResult {
25 pub errors: Vec<ValidationError>,
27 pub warnings: Vec<ValidationWarning>,
29}
30
31impl ValidationResult {
32 pub fn success() -> Self {
34 Self {
35 errors: Vec::new(),
36 warnings: Vec::new(),
37 }
38 }
39
40 pub fn is_ok(&self) -> bool {
42 self.errors.is_empty()
43 }
44
45 pub fn is_clean(&self) -> bool {
47 self.errors.is_empty() && self.warnings.is_empty()
48 }
49
50 pub fn add_error(&mut self, error: ValidationError) {
52 self.errors.push(error);
53 }
54
55 pub fn add_warning(&mut self, warning: ValidationWarning) {
57 self.warnings.push(warning);
58 }
59}
60
61#[derive(Debug, Clone)]
63pub struct ValidationError {
64 pub kind: ValidationErrorKind,
66 pub location: Option<ValidationLocation>,
68 pub message: String,
70}
71
72impl std::fmt::Display for ValidationError {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 if let Some(loc) = &self.location {
75 write!(f, "{}: {}: {}", loc, self.kind, self.message)
76 } else {
77 write!(f, "{}: {}", self.kind, self.message)
78 }
79 }
80}
81
/// Category of a [`ValidationError`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ValidationErrorKind {
    /// Operand types of an instruction are inconsistent or unsuitable.
    TypeMismatch,
    /// A value id is referenced that is not defined in the module.
    UndefinedValue,
    /// A block id is referenced that is not defined in the module.
    UndefinedBlock,
    /// A block has no terminator.
    UnterminatedBlock,
    /// An invalid operation.
    InvalidOperation,
    /// A violation of SSA form.
    SsaViolation,
    /// A control-flow error.
    ControlFlow,
    /// The module's entry block does not exist.
    MissingEntry,
}
102
103impl std::fmt::Display for ValidationErrorKind {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 ValidationErrorKind::TypeMismatch => write!(f, "type mismatch"),
107 ValidationErrorKind::UndefinedValue => write!(f, "undefined value"),
108 ValidationErrorKind::UndefinedBlock => write!(f, "undefined block"),
109 ValidationErrorKind::UnterminatedBlock => write!(f, "unterminated block"),
110 ValidationErrorKind::InvalidOperation => write!(f, "invalid operation"),
111 ValidationErrorKind::SsaViolation => write!(f, "SSA violation"),
112 ValidationErrorKind::ControlFlow => write!(f, "control flow error"),
113 ValidationErrorKind::MissingEntry => write!(f, "missing entry block"),
114 }
115 }
116}
117
118#[derive(Debug, Clone)]
120pub struct ValidationWarning {
121 pub kind: ValidationWarningKind,
123 pub location: Option<ValidationLocation>,
125 pub message: String,
127}
128
/// Category of a [`ValidationWarning`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ValidationWarningKind {
    /// A value that is never used.
    UnusedValue,
    /// Code that can never execute.
    UnreachableCode,
    /// A performance concern.
    Performance,
    /// Use of a deprecated construct.
    Deprecated,
}
141
142#[derive(Debug, Clone)]
144pub struct ValidationLocation {
145 pub block: Option<BlockId>,
147 pub instruction: Option<usize>,
149 pub value: Option<ValueId>,
151}
152
153impl std::fmt::Display for ValidationLocation {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 let mut parts = Vec::new();
156 if let Some(block) = &self.block {
157 parts.push(format!("block {}", block));
158 }
159 if let Some(inst) = &self.instruction {
160 parts.push(format!("instruction {}", inst));
161 }
162 if let Some(value) = &self.value {
163 parts.push(format!("value {}", value));
164 }
165 write!(f, "{}", parts.join(", "))
166 }
167}
168
169pub struct Validator {
171 level: ValidationLevel,
172 result: ValidationResult,
173 defined_values: HashSet<ValueId>,
174 defined_blocks: HashSet<BlockId>,
175}
176
177impl Validator {
178 pub fn new(level: ValidationLevel) -> Self {
180 Self {
181 level,
182 result: ValidationResult::success(),
183 defined_values: HashSet::new(),
184 defined_blocks: HashSet::new(),
185 }
186 }
187
188 pub fn validate(mut self, module: &IrModule) -> ValidationResult {
190 if self.level == ValidationLevel::None {
191 return ValidationResult::success();
192 }
193
194 self.collect_definitions(module);
196
197 if !self.defined_blocks.contains(&module.entry_block) {
199 self.result.add_error(ValidationError {
200 kind: ValidationErrorKind::MissingEntry,
201 location: None,
202 message: "Module has no entry block".to_string(),
203 });
204 }
205
206 for (block_id, block) in &module.blocks {
208 self.validate_block(module, *block_id, block);
209 }
210
211 if self.level >= ValidationLevel::Full {
213 self.validate_types(module);
214 }
215
216 self.result
217 }
218
219 fn collect_definitions(&mut self, module: &IrModule) {
220 for param in &module.parameters {
222 self.defined_values.insert(param.value_id);
223 }
224
225 for value_id in module.values.keys() {
227 self.defined_values.insert(*value_id);
228 }
229
230 for block_id in module.blocks.keys() {
232 self.defined_blocks.insert(*block_id);
233 }
234 }
235
236 fn validate_block(&mut self, module: &IrModule, block_id: BlockId, block: &Block) {
237 if block.terminator.is_none() {
239 self.result.add_error(ValidationError {
240 kind: ValidationErrorKind::UnterminatedBlock,
241 location: Some(ValidationLocation {
242 block: Some(block_id),
243 instruction: None,
244 value: None,
245 }),
246 message: format!("Block {} is not terminated", block.label),
247 });
248 }
249
250 for (idx, inst) in block.instructions.iter().enumerate() {
252 self.validate_instruction(module, block_id, idx, &inst.node);
253 }
254
255 if let Some(term) = &block.terminator {
257 self.validate_terminator(block_id, term);
258 }
259 }
260
261 fn validate_instruction(
262 &mut self,
263 _module: &IrModule,
264 block_id: BlockId,
265 idx: usize,
266 node: &IrNode,
267 ) {
268 let location = ValidationLocation {
269 block: Some(block_id),
270 instruction: Some(idx),
271 value: None,
272 };
273
274 match node {
276 IrNode::BinaryOp(_, lhs, rhs) => {
277 self.check_value_defined(*lhs, &location);
278 self.check_value_defined(*rhs, &location);
279 }
280 IrNode::UnaryOp(_, val) => {
281 self.check_value_defined(*val, &location);
282 }
283 IrNode::Compare(_, lhs, rhs) => {
284 self.check_value_defined(*lhs, &location);
285 self.check_value_defined(*rhs, &location);
286 }
287 IrNode::Load(ptr) => {
288 self.check_value_defined(*ptr, &location);
289 }
290 IrNode::Store(ptr, val) => {
291 self.check_value_defined(*ptr, &location);
292 self.check_value_defined(*val, &location);
293 }
294 IrNode::Select(cond, then_val, else_val) => {
295 self.check_value_defined(*cond, &location);
296 self.check_value_defined(*then_val, &location);
297 self.check_value_defined(*else_val, &location);
298 }
299 IrNode::Phi(entries) => {
300 for (pred_block, val) in entries {
301 self.check_block_defined(*pred_block, &location);
302 self.check_value_defined(*val, &location);
303 }
304 }
305 _ => {}
306 }
307 }
308
309 fn validate_terminator(&mut self, block_id: BlockId, term: &Terminator) {
310 let location = ValidationLocation {
311 block: Some(block_id),
312 instruction: None,
313 value: None,
314 };
315
316 match term {
317 Terminator::Branch(target) => {
318 self.check_block_defined(*target, &location);
319 }
320 Terminator::CondBranch(cond, then_block, else_block) => {
321 self.check_value_defined(*cond, &location);
322 self.check_block_defined(*then_block, &location);
323 self.check_block_defined(*else_block, &location);
324 }
325 Terminator::Switch(val, default, cases) => {
326 self.check_value_defined(*val, &location);
327 self.check_block_defined(*default, &location);
328 for (_, target) in cases {
329 self.check_block_defined(*target, &location);
330 }
331 }
332 Terminator::Return(Some(val)) => {
333 self.check_value_defined(*val, &location);
334 }
335 Terminator::Return(None) | Terminator::Unreachable => {}
336 }
337 }
338
339 fn validate_types(&mut self, module: &IrModule) {
340 for block in module.blocks.values() {
341 for inst in &block.instructions {
342 if let Err(msg) =
343 self.check_instruction_types(module, &inst.node, &inst.result_type)
344 {
345 self.result.add_error(ValidationError {
346 kind: ValidationErrorKind::TypeMismatch,
347 location: Some(ValidationLocation {
348 block: Some(block.id),
349 instruction: None,
350 value: Some(inst.result),
351 }),
352 message: msg,
353 });
354 }
355 }
356 }
357 }
358
359 fn check_instruction_types(
360 &self,
361 module: &IrModule,
362 node: &IrNode,
363 _result_ty: &IrType,
364 ) -> Result<(), String> {
365 match node {
366 IrNode::BinaryOp(_, lhs, rhs) => {
367 let lhs_ty = self.get_value_type(module, *lhs);
368 let rhs_ty = self.get_value_type(module, *rhs);
369 if lhs_ty != rhs_ty {
370 return Err(format!(
371 "Binary operation operand types don't match: {} vs {}",
372 lhs_ty, rhs_ty
373 ));
374 }
375 }
376 IrNode::Compare(_, lhs, rhs) => {
377 let lhs_ty = self.get_value_type(module, *lhs);
378 let rhs_ty = self.get_value_type(module, *rhs);
379 if lhs_ty != rhs_ty {
380 return Err(format!(
381 "Comparison operand types don't match: {} vs {}",
382 lhs_ty, rhs_ty
383 ));
384 }
385 }
386 IrNode::Load(ptr) => {
387 let ptr_ty = self.get_value_type(module, *ptr);
388 if !ptr_ty.is_ptr() {
389 return Err(format!("Load requires pointer type, got {}", ptr_ty));
390 }
391 }
392 IrNode::Store(ptr, _val) => {
393 let ptr_ty = self.get_value_type(module, *ptr);
394 if !ptr_ty.is_ptr() {
395 return Err(format!("Store requires pointer type, got {}", ptr_ty));
396 }
397 }
398 _ => {}
399 }
400 Ok(())
401 }
402
403 fn get_value_type(&self, module: &IrModule, id: ValueId) -> IrType {
404 module
405 .values
406 .get(&id)
407 .map(|v| v.ty.clone())
408 .unwrap_or(IrType::Void)
409 }
410
411 fn check_value_defined(&mut self, id: ValueId, location: &ValidationLocation) {
412 if !self.defined_values.contains(&id) {
413 self.result.add_error(ValidationError {
414 kind: ValidationErrorKind::UndefinedValue,
415 location: Some(location.clone()),
416 message: format!("Value {} is not defined", id),
417 });
418 }
419 }
420
421 fn check_block_defined(&mut self, id: BlockId, location: &ValidationLocation) {
422 if !self.defined_blocks.contains(&id) {
423 self.result.add_error(ValidationError {
424 kind: ValidationErrorKind::UndefinedBlock,
425 location: Some(location.clone()),
426 message: format!("Block {} is not defined", id),
427 });
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    /// A module whose only block returns immediately passes full validation.
    #[test]
    fn test_validation_success() {
        let module = {
            let mut b = IrBuilder::new("test");
            b.ret();
            b.build()
        };

        let outcome = module.validate(ValidationLevel::Full);
        assert!(outcome.is_ok());
    }

    /// A freshly created module has an entry block without a terminator.
    #[test]
    fn test_validation_unterminated_block() {
        let module = IrModule::new("test");
        let outcome = Validator::new(ValidationLevel::Basic).validate(&module);

        assert!(!outcome.is_ok());
        let saw_unterminated = outcome
            .errors
            .iter()
            .any(|err| matches!(err.kind, ValidationErrorKind::UnterminatedBlock));
        assert!(saw_unterminated);
    }

    /// `ValidationLevel::None` skips all checks, even on an invalid module.
    #[test]
    fn test_validation_level_none() {
        let module = IrModule::new("test");
        let outcome = Validator::new(ValidationLevel::None).validate(&module);
        assert!(outcome.is_ok());
    }

    /// The error's Display output includes both the kind and the message.
    #[test]
    fn test_validation_result_display() {
        let rendered = ValidationError {
            kind: ValidationErrorKind::TypeMismatch,
            location: None,
            message: "expected i32".to_string(),
        }
        .to_string();

        assert!(rendered.contains("type mismatch"));
        assert!(rendered.contains("expected i32"));
    }
}