Skip to main content

ringkernel_ir/
validation.rs

1//! IR validation.
2//!
3//! Validates IR modules for correctness before lowering to backend code.
4
5use std::collections::HashSet;
6
7use crate::{Block, BlockId, IrModule, IrNode, IrType, Terminator, ValueId};
8
/// How strictly an IR module is validated.
///
/// Variants are declared from least to most strict. The derived
/// `PartialOrd`/`Ord` follow that declaration order, so levels can be compared
/// directly (for example `level >= ValidationLevel::Full`).
///
/// NOTE(review): `Validator::validate` in this file does not distinguish
/// `Strict` from `Full` — confirm whether warning promotion happens elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ValidationLevel {
    /// Validation is skipped.
    None,
    /// Basic structural validation only.
    Basic,
    /// Full type checking and SSA validation.
    Full,
    /// Strict validation: warnings are treated as errors.
    Strict,
}
21
22/// Result of validation.
23#[derive(Debug, Clone)]
24pub struct ValidationResult {
25    /// Errors found.
26    pub errors: Vec<ValidationError>,
27    /// Warnings found.
28    pub warnings: Vec<ValidationWarning>,
29}
30
31impl ValidationResult {
32    /// Create a successful result.
33    pub fn success() -> Self {
34        Self {
35            errors: Vec::new(),
36            warnings: Vec::new(),
37        }
38    }
39
40    /// Check if validation passed.
41    pub fn is_ok(&self) -> bool {
42        self.errors.is_empty()
43    }
44
45    /// Check if validation passed with no warnings.
46    pub fn is_clean(&self) -> bool {
47        self.errors.is_empty() && self.warnings.is_empty()
48    }
49
50    /// Add an error.
51    pub fn add_error(&mut self, error: ValidationError) {
52        self.errors.push(error);
53    }
54
55    /// Add a warning.
56    pub fn add_warning(&mut self, warning: ValidationWarning) {
57        self.warnings.push(warning);
58    }
59}
60
61/// Validation error.
62#[derive(Debug, Clone)]
63pub struct ValidationError {
64    /// Error kind.
65    pub kind: ValidationErrorKind,
66    /// Location in IR.
67    pub location: Option<ValidationLocation>,
68    /// Error message.
69    pub message: String,
70}
71
72impl std::fmt::Display for ValidationError {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        if let Some(loc) = &self.location {
75            write!(f, "{}: {}: {}", loc, self.kind, self.message)
76        } else {
77            write!(f, "{}: {}", self.kind, self.message)
78        }
79    }
80}
81
/// Categories of validation failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationErrorKind {
    /// Operand or result types do not agree.
    TypeMismatch,
    /// A referenced value is not defined.
    UndefinedValue,
    /// A referenced block is not defined.
    UndefinedBlock,
    /// A block has no terminator.
    UnterminatedBlock,
    /// Invalid operation.
    InvalidOperation,
    /// SSA violation.
    SsaViolation,
    /// Control flow error.
    ControlFlow,
    /// The module's entry block is missing.
    MissingEntry,
}

impl std::fmt::Display for ValidationErrorKind {
    /// Writes the short human-readable label for this error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::TypeMismatch => "type mismatch",
            Self::UndefinedValue => "undefined value",
            Self::UndefinedBlock => "undefined block",
            Self::UnterminatedBlock => "unterminated block",
            Self::InvalidOperation => "invalid operation",
            Self::SsaViolation => "SSA violation",
            Self::ControlFlow => "control flow error",
            Self::MissingEntry => "missing entry block",
        };
        f.write_str(label)
    }
}
117
118/// Validation warning.
119#[derive(Debug, Clone)]
120pub struct ValidationWarning {
121    /// Warning kind.
122    pub kind: ValidationWarningKind,
123    /// Location in IR.
124    pub location: Option<ValidationLocation>,
125    /// Warning message.
126    pub message: String,
127}
128
/// Categories of validation warning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationWarningKind {
    /// A value is never used.
    UnusedValue,
    /// Code that cannot be reached.
    UnreachableCode,
    /// A potential performance issue.
    Performance,
    /// Use of a deprecated feature.
    Deprecated,
}
141
142/// Location in IR for error reporting.
143#[derive(Debug, Clone)]
144pub struct ValidationLocation {
145    /// Block ID.
146    pub block: Option<BlockId>,
147    /// Instruction index.
148    pub instruction: Option<usize>,
149    /// Value ID.
150    pub value: Option<ValueId>,
151}
152
153impl std::fmt::Display for ValidationLocation {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        let mut parts = Vec::new();
156        if let Some(block) = &self.block {
157            parts.push(format!("block {}", block));
158        }
159        if let Some(inst) = &self.instruction {
160            parts.push(format!("instruction {}", inst));
161        }
162        if let Some(value) = &self.value {
163            parts.push(format!("value {}", value));
164        }
165        write!(f, "{}", parts.join(", "))
166    }
167}
168
169/// IR validator.
170pub struct Validator {
171    level: ValidationLevel,
172    result: ValidationResult,
173    defined_values: HashSet<ValueId>,
174    defined_blocks: HashSet<BlockId>,
175}
176
177impl Validator {
178    /// Create a new validator.
179    pub fn new(level: ValidationLevel) -> Self {
180        Self {
181            level,
182            result: ValidationResult::success(),
183            defined_values: HashSet::new(),
184            defined_blocks: HashSet::new(),
185        }
186    }
187
188    /// Validate a module.
189    pub fn validate(mut self, module: &IrModule) -> ValidationResult {
190        if self.level == ValidationLevel::None {
191            return ValidationResult::success();
192        }
193
194        // Collect defined values and blocks
195        self.collect_definitions(module);
196
197        // Check entry block exists
198        if !self.defined_blocks.contains(&module.entry_block) {
199            self.result.add_error(ValidationError {
200                kind: ValidationErrorKind::MissingEntry,
201                location: None,
202                message: "Module has no entry block".to_string(),
203            });
204        }
205
206        // Validate each block
207        for (block_id, block) in &module.blocks {
208            self.validate_block(module, *block_id, block);
209        }
210
211        // Full validation includes type checking
212        if self.level >= ValidationLevel::Full {
213            self.validate_types(module);
214        }
215
216        self.result
217    }
218
219    fn collect_definitions(&mut self, module: &IrModule) {
220        // Parameters define values
221        for param in &module.parameters {
222            self.defined_values.insert(param.value_id);
223        }
224
225        // Collect all values
226        for value_id in module.values.keys() {
227            self.defined_values.insert(*value_id);
228        }
229
230        // Collect all blocks
231        for block_id in module.blocks.keys() {
232            self.defined_blocks.insert(*block_id);
233        }
234    }
235
236    fn validate_block(&mut self, module: &IrModule, block_id: BlockId, block: &Block) {
237        // Check block is terminated
238        if block.terminator.is_none() {
239            self.result.add_error(ValidationError {
240                kind: ValidationErrorKind::UnterminatedBlock,
241                location: Some(ValidationLocation {
242                    block: Some(block_id),
243                    instruction: None,
244                    value: None,
245                }),
246                message: format!("Block {} is not terminated", block.label),
247            });
248        }
249
250        // Validate instructions
251        for (idx, inst) in block.instructions.iter().enumerate() {
252            self.validate_instruction(module, block_id, idx, &inst.node);
253        }
254
255        // Validate terminator
256        if let Some(term) = &block.terminator {
257            self.validate_terminator(block_id, term);
258        }
259    }
260
261    fn validate_instruction(
262        &mut self,
263        _module: &IrModule,
264        block_id: BlockId,
265        idx: usize,
266        node: &IrNode,
267    ) {
268        let location = ValidationLocation {
269            block: Some(block_id),
270            instruction: Some(idx),
271            value: None,
272        };
273
274        // Check value references
275        match node {
276            IrNode::BinaryOp(_, lhs, rhs) => {
277                self.check_value_defined(*lhs, &location);
278                self.check_value_defined(*rhs, &location);
279            }
280            IrNode::UnaryOp(_, val) => {
281                self.check_value_defined(*val, &location);
282            }
283            IrNode::Compare(_, lhs, rhs) => {
284                self.check_value_defined(*lhs, &location);
285                self.check_value_defined(*rhs, &location);
286            }
287            IrNode::Load(ptr) => {
288                self.check_value_defined(*ptr, &location);
289            }
290            IrNode::Store(ptr, val) => {
291                self.check_value_defined(*ptr, &location);
292                self.check_value_defined(*val, &location);
293            }
294            IrNode::Select(cond, then_val, else_val) => {
295                self.check_value_defined(*cond, &location);
296                self.check_value_defined(*then_val, &location);
297                self.check_value_defined(*else_val, &location);
298            }
299            IrNode::Phi(entries) => {
300                for (pred_block, val) in entries {
301                    self.check_block_defined(*pred_block, &location);
302                    self.check_value_defined(*val, &location);
303                }
304            }
305            _ => {}
306        }
307    }
308
309    fn validate_terminator(&mut self, block_id: BlockId, term: &Terminator) {
310        let location = ValidationLocation {
311            block: Some(block_id),
312            instruction: None,
313            value: None,
314        };
315
316        match term {
317            Terminator::Branch(target) => {
318                self.check_block_defined(*target, &location);
319            }
320            Terminator::CondBranch(cond, then_block, else_block) => {
321                self.check_value_defined(*cond, &location);
322                self.check_block_defined(*then_block, &location);
323                self.check_block_defined(*else_block, &location);
324            }
325            Terminator::Switch(val, default, cases) => {
326                self.check_value_defined(*val, &location);
327                self.check_block_defined(*default, &location);
328                for (_, target) in cases {
329                    self.check_block_defined(*target, &location);
330                }
331            }
332            Terminator::Return(Some(val)) => {
333                self.check_value_defined(*val, &location);
334            }
335            Terminator::Return(None) | Terminator::Unreachable => {}
336        }
337    }
338
339    fn validate_types(&mut self, module: &IrModule) {
340        for block in module.blocks.values() {
341            for inst in &block.instructions {
342                if let Err(msg) =
343                    self.check_instruction_types(module, &inst.node, &inst.result_type)
344                {
345                    self.result.add_error(ValidationError {
346                        kind: ValidationErrorKind::TypeMismatch,
347                        location: Some(ValidationLocation {
348                            block: Some(block.id),
349                            instruction: None,
350                            value: Some(inst.result),
351                        }),
352                        message: msg,
353                    });
354                }
355            }
356        }
357    }
358
359    fn check_instruction_types(
360        &self,
361        module: &IrModule,
362        node: &IrNode,
363        _result_ty: &IrType,
364    ) -> Result<(), String> {
365        match node {
366            IrNode::BinaryOp(_, lhs, rhs) => {
367                let lhs_ty = self.get_value_type(module, *lhs);
368                let rhs_ty = self.get_value_type(module, *rhs);
369                if lhs_ty != rhs_ty {
370                    return Err(format!(
371                        "Binary operation operand types don't match: {} vs {}",
372                        lhs_ty, rhs_ty
373                    ));
374                }
375            }
376            IrNode::Compare(_, lhs, rhs) => {
377                let lhs_ty = self.get_value_type(module, *lhs);
378                let rhs_ty = self.get_value_type(module, *rhs);
379                if lhs_ty != rhs_ty {
380                    return Err(format!(
381                        "Comparison operand types don't match: {} vs {}",
382                        lhs_ty, rhs_ty
383                    ));
384                }
385            }
386            IrNode::Load(ptr) => {
387                let ptr_ty = self.get_value_type(module, *ptr);
388                if !ptr_ty.is_ptr() {
389                    return Err(format!("Load requires pointer type, got {}", ptr_ty));
390                }
391            }
392            IrNode::Store(ptr, _val) => {
393                let ptr_ty = self.get_value_type(module, *ptr);
394                if !ptr_ty.is_ptr() {
395                    return Err(format!("Store requires pointer type, got {}", ptr_ty));
396                }
397            }
398            _ => {}
399        }
400        Ok(())
401    }
402
403    fn get_value_type(&self, module: &IrModule, id: ValueId) -> IrType {
404        module
405            .values
406            .get(&id)
407            .map(|v| v.ty.clone())
408            .unwrap_or(IrType::Void)
409    }
410
411    fn check_value_defined(&mut self, id: ValueId, location: &ValidationLocation) {
412        if !self.defined_values.contains(&id) {
413            self.result.add_error(ValidationError {
414                kind: ValidationErrorKind::UndefinedValue,
415                location: Some(location.clone()),
416                message: format!("Value {} is not defined", id),
417            });
418        }
419    }
420
421    fn check_block_defined(&mut self, id: BlockId, location: &ValidationLocation) {
422        if !self.defined_blocks.contains(&id) {
423            self.result.add_error(ValidationError {
424                kind: ValidationErrorKind::UndefinedBlock,
425                location: Some(location.clone()),
426                message: format!("Block {} is not defined", id),
427            });
428        }
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    /// A module built with a single `ret` passes full validation.
    #[test]
    fn test_validation_success() {
        let mut builder = IrBuilder::new("test");
        builder.ret();
        let module = builder.build();

        let outcome = module.validate(ValidationLevel::Full);
        assert!(outcome.is_ok());
    }

    /// A freshly created module is reported as having an unterminated block.
    #[test]
    fn test_validation_unterminated_block() {
        // Entry block has no terminator
        let module = IrModule::new("test");

        let outcome = Validator::new(ValidationLevel::Basic).validate(&module);
        assert!(!outcome.is_ok());

        let has_unterminated = outcome
            .errors
            .iter()
            .any(|err| err.kind == ValidationErrorKind::UnterminatedBlock);
        assert!(has_unterminated);
    }

    /// With validation disabled the same module is accepted.
    #[test]
    fn test_validation_level_none() {
        // No terminator, but validation level is None
        let module = IrModule::new("test");

        let outcome = Validator::new(ValidationLevel::None).validate(&module);
        assert!(outcome.is_ok());
    }

    /// The error's display text includes both the kind label and the message.
    #[test]
    fn test_validation_result_display() {
        let error = ValidationError {
            kind: ValidationErrorKind::TypeMismatch,
            location: None,
            message: "expected i32".to_string(),
        };
        let rendered = error.to_string();
        assert!(rendered.contains("type mismatch"));
        assert!(rendered.contains("expected i32"));
    }
}