Skip to main content

runmat_static_analysis/lints/
shape.rs

1use runmat_builtins::{BuiltinSemanticKind, ConcatKind, ShapeTransformKind};
2use runmat_hir::{BindingId, HirDiagnostic, HirDiagnosticSeverity, OperatorKind, Span};
3use runmat_mir::analysis::{AnalysisStore, MirLocalKey};
4use std::collections::HashMap;
5
6pub fn lint_shapes(result: &runmat_hir::LoweringResult) -> Vec<HirDiagnostic> {
7    let mir = match runmat_mir::lowering::lower_assembly(&result.assembly) {
8        Ok(mir) => mir,
9        Err(err) => return vec![mir_lowering_diagnostic(err)],
10    };
11    let store = runmat_mir::analysis::analyze_assembly(&mir);
12    let mut ctx = ShapeLintContext::default();
13    ctx.seed_from_analysis(&mir, &store);
14    ctx.walk_mir_assembly(&mir);
15    ctx.diagnostics
16}
17
18pub fn infer_binding_shapes(
19    result: &runmat_hir::LoweringResult,
20) -> HashMap<BindingId, Vec<Option<usize>>> {
21    let Ok(mir) = runmat_mir::lowering::lower_assembly(&result.assembly) else {
22        return HashMap::new();
23    };
24    let store = runmat_mir::analysis::analyze_assembly(&mir);
25    let mut ctx = ShapeLintContext::default();
26    ctx.seed_from_analysis(&mir, &store);
27    ctx.walk_mir_assembly(&mir);
28    ctx.env
29        .into_iter()
30        .map(|(binding, shape)| (binding, shape.0))
31        .collect()
32}
33
34fn mir_lowering_diagnostic(err: runmat_hir::HirError) -> HirDiagnostic {
35    HirDiagnostic::new(
36        "lint.mir.lowering_failed",
37        HirDiagnosticSeverity::Error,
38        format!("MIR lowering failed: {}", err.message),
39        err.span.unwrap_or(runmat_hir::Span { start: 0, end: 0 }),
40    )
41    .with_category("mir-lowering")
42}
43
/// Array dimensions as `[rows, cols, ...]`; a `None` entry marks a dimension of unknown size.
#[derive(Clone, Debug, PartialEq)]
struct Shape(Vec<Option<usize>>);
46
47#[derive(Default)]
48struct ShapeLintContext {
49    env: HashMap<BindingId, Shape>,
50    local_env: HashMap<MirLocalKey, Shape>,
51    number_env: HashMap<MirLocalKey, f64>,
52    int_vector_env: HashMap<MirLocalKey, Vec<usize>>,
53    diagnostics: Vec<HirDiagnostic>,
54}
55
56#[derive(Default)]
57struct MirShapeValue {
58    shape: Option<Shape>,
59    number: Option<f64>,
60    int_vector: Option<Vec<usize>>,
61}
62
63impl ShapeLintContext {
64    fn seed_from_analysis(&mut self, mir: &runmat_mir::MirAssembly, store: &AnalysisStore) {
65        for body in mir.bodies.values() {
66            for local in &body.locals {
67                let Some(binding) = local.binding else {
68                    continue;
69                };
70                let Some(fact) = store.mir_locals.get(&MirLocalKey {
71                    function: body.function,
72                    local: local.id,
73                }) else {
74                    continue;
75                };
76                if let Some(shape) = shape_from_fact(&fact.shape) {
77                    self.env.insert(binding, shape);
78                }
79            }
80        }
81    }
82
83    fn walk_mir_assembly(&mut self, mir: &runmat_mir::MirAssembly) {
84        for body in mir.bodies.values() {
85            for block in &body.blocks {
86                for stmt in &block.statements {
87                    match &stmt.kind {
88                        runmat_mir::MirStmtKind::Assign { place, value } => {
89                            let value = self.infer_mir_rvalue(body, value, stmt.span);
90                            if let runmat_mir::MirPlace::Local(local) = place {
91                                self.record_mir_value(body, *local, value);
92                            }
93                        }
94                        runmat_mir::MirStmtKind::MultiAssign { value, .. }
95                        | runmat_mir::MirStmtKind::Expr(value) => {
96                            self.infer_mir_rvalue(body, value, stmt.span);
97                        }
98                        runmat_mir::MirStmtKind::PlaceMutation(_)
99                        | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
100                        | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
101                    }
102                }
103            }
104        }
105    }
106
107    fn record_mir_value(
108        &mut self,
109        body: &runmat_mir::MirBody,
110        local: runmat_mir::MirLocalId,
111        value: MirShapeValue,
112    ) {
113        let key = MirLocalKey {
114            function: body.function,
115            local,
116        };
117        if let Some(shape) = value.shape {
118            if let Some(binding) = body.locals.get(local.0).and_then(|local| local.binding) {
119                self.env.insert(binding, shape.clone());
120            }
121            self.local_env.insert(key, shape);
122        }
123        if let Some(number) = value.number {
124            self.number_env.insert(key, number);
125        }
126        if let Some(vector) = value.int_vector {
127            self.int_vector_env.insert(key, vector);
128        }
129    }
130
131    fn infer_mir_rvalue(
132        &mut self,
133        body: &runmat_mir::MirBody,
134        value: &runmat_mir::MirRvalue,
135        span: Span,
136    ) -> MirShapeValue {
137        match value {
138            runmat_mir::MirRvalue::Use(operand) => self.infer_mir_operand(body, operand),
139            runmat_mir::MirRvalue::Unary(op, operand) => {
140                let inner = self.infer_mir_operand(body, operand);
141                let number = match (op, inner.number) {
142                    (OperatorKind::UnaryMinus, Some(value)) => Some(-value),
143                    _ => None,
144                };
145                MirShapeValue {
146                    shape: inner.shape,
147                    number,
148                    int_vector: None,
149                }
150            }
151            runmat_mir::MirRvalue::Binary(left, op, right) => {
152                let lhs = self.infer_mir_operand(body, left);
153                let rhs = self.infer_mir_operand(body, right);
154                MirShapeValue {
155                    shape: self.infer_mir_binary(span, lhs.shape.as_ref(), op, rhs.shape.as_ref()),
156                    number: None,
157                    int_vector: None,
158                }
159            }
160            runmat_mir::MirRvalue::ShortCircuit {
161                left,
162                right_temps,
163                right,
164                ..
165            } => {
166                self.infer_mir_operand(body, left);
167                for stmt in right_temps {
168                    match &stmt.kind {
169                        runmat_mir::MirStmtKind::Assign { place, value } => {
170                            let inferred = self.infer_mir_rvalue(body, value, stmt.span);
171                            if let runmat_mir::MirPlace::Local(local) = place {
172                                self.record_mir_value(body, *local, inferred);
173                            }
174                        }
175                        runmat_mir::MirStmtKind::MultiAssign { value, .. }
176                        | runmat_mir::MirStmtKind::Expr(value) => {
177                            self.infer_mir_rvalue(body, value, stmt.span);
178                        }
179                        runmat_mir::MirStmtKind::PlaceMutation(_)
180                        | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
181                        | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
182                    }
183                }
184                self.infer_mir_operand(body, right);
185                MirShapeValue::default()
186            }
187            runmat_mir::MirRvalue::Range { start, step, end } => {
188                let start = self.infer_mir_operand(body, start).number;
189                let step = step
190                    .as_ref()
191                    .and_then(|step| self.infer_mir_operand(body, step).number)
192                    .unwrap_or(1.0);
193                let end = self.infer_mir_operand(body, end).number;
194                let width = start
195                    .zip(end)
196                    .and_then(|(start, end)| range_width(start, step, end));
197                MirShapeValue {
198                    shape: Some(Shape(vec![Some(1), width])),
199                    number: None,
200                    int_vector: None,
201                }
202            }
203            runmat_mir::MirRvalue::Call(call) => self.infer_mir_call(body, span, call),
204            runmat_mir::MirRvalue::Aggregate {
205                kind,
206                rows,
207                elements,
208                ..
209            } => self.infer_mir_aggregate(body, span, kind, *rows, elements),
210            runmat_mir::MirRvalue::StructLiteral { fields } => {
211                for (_, value) in fields {
212                    self.infer_mir_operand(body, value);
213                }
214                MirShapeValue {
215                    shape: Some(Shape(vec![Some(1), Some(1)])),
216                    number: None,
217                    int_vector: None,
218                }
219            }
220            runmat_mir::MirRvalue::ObjectLiteral { fields, .. } => {
221                for (_, value) in fields {
222                    self.infer_mir_operand(body, value);
223                }
224                MirShapeValue {
225                    shape: Some(Shape(vec![Some(1), Some(1)])),
226                    number: None,
227                    int_vector: None,
228                }
229            }
230            runmat_mir::MirRvalue::Index { base, indexing } => {
231                let base_shape = self.infer_mir_operand(body, base).shape;
232                for component in &indexing.components {
233                    match component {
234                        runmat_mir::MirIndexComponent::Expr(operand) => {
235                            let idx_shape = self.infer_mir_operand(body, operand).shape;
236                            if indexing.components.len() == 1 {
237                                self.check_logical_index(
238                                    span,
239                                    base_shape.as_ref(),
240                                    idx_shape.as_ref(),
241                                );
242                            }
243                        }
244                        runmat_mir::MirIndexComponent::Colon
245                        | runmat_mir::MirIndexComponent::End { .. } => {}
246                    }
247                }
248                MirShapeValue::default()
249            }
250            runmat_mir::MirRvalue::Member { base, .. } => self.infer_mir_operand(body, base),
251            runmat_mir::MirRvalue::DynamicMember { base, member } => {
252                self.infer_mir_operand(body, member);
253                self.infer_mir_operand(body, base)
254            }
255            runmat_mir::MirRvalue::Future { args, .. } => {
256                for arg in args {
257                    self.infer_mir_operand(body, arg.operand());
258                }
259                MirShapeValue::default()
260            }
261            runmat_mir::MirRvalue::Spawn(operand) => self.infer_mir_operand(body, operand),
262            runmat_mir::MirRvalue::MetaClass(_)
263            | runmat_mir::MirRvalue::Colon
264            | runmat_mir::MirRvalue::End => MirShapeValue::default(),
265        }
266    }
267
268    fn infer_mir_operand(
269        &self,
270        body: &runmat_mir::MirBody,
271        operand: &runmat_mir::MirOperand,
272    ) -> MirShapeValue {
273        match operand {
274            runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Number(value)) => {
275                MirShapeValue {
276                    shape: Some(Shape(vec![Some(1), Some(1)])),
277                    number: value.parse().ok(),
278                    int_vector: None,
279                }
280            }
281            runmat_mir::MirOperand::Local(local) => {
282                let key = MirLocalKey {
283                    function: body.function,
284                    local: *local,
285                };
286                MirShapeValue {
287                    shape: self.local_env.get(&key).cloned(),
288                    number: self.number_env.get(&key).copied(),
289                    int_vector: self.int_vector_env.get(&key).cloned(),
290                }
291            }
292            runmat_mir::MirOperand::Constant(_) | runmat_mir::MirOperand::FunctionHandle(_) => {
293                MirShapeValue::default()
294            }
295        }
296    }
297
298    fn infer_mir_binary(
299        &mut self,
300        span: Span,
301        lhs: Option<&Shape>,
302        op: &OperatorKind,
303        rhs: Option<&Shape>,
304    ) -> Option<Shape> {
305        match op {
306            OperatorKind::MatrixMultiply => {
307                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
308                    if matrix_dims(lhs)
309                        .zip(matrix_dims(rhs))
310                        .is_some_and(|((_, lc), (rr, _))| lc.is_some() && rr.is_some() && lc != rr)
311                    {
312                        self.warn(
313                            "lint.shape.matmul",
314                            "matrix multiply inner dimensions must match",
315                            span,
316                        );
317                    }
318                }
319                match (lhs.and_then(matrix_dims), rhs.and_then(matrix_dims)) {
320                    (Some((rows, _)), Some((_, cols))) => Some(Shape(vec![rows, cols])),
321                    _ => None,
322                }
323            }
324            OperatorKind::Add
325            | OperatorKind::Subtract
326            | OperatorKind::ElementwiseMultiply
327            | OperatorKind::ElementwiseDivide
328            | OperatorKind::ElementwiseLeftDivide
329            | OperatorKind::ElementwisePower
330            | OperatorKind::Greater
331            | OperatorKind::GreaterEqual
332            | OperatorKind::Less
333            | OperatorKind::LessEqual
334            | OperatorKind::Equal
335            | OperatorKind::NotEqual => {
336                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
337                    if !broadcast_compatible(lhs, rhs) {
338                        self.warn(
339                            "lint.shape.broadcast",
340                            "array dimensions are not broadcast compatible",
341                            span,
342                        );
343                    }
344                }
345                lhs.cloned().or_else(|| rhs.cloned())
346            }
347            _ => lhs.cloned().or_else(|| rhs.cloned()),
348        }
349    }
350
351    fn infer_mir_aggregate(
352        &mut self,
353        body: &runmat_mir::MirBody,
354        span: Span,
355        kind: &runmat_mir::MirAggregateKind,
356        rows: usize,
357        elements: &[runmat_mir::MirOperand],
358    ) -> MirShapeValue {
359        let values: Vec<_> = elements
360            .iter()
361            .map(|element| self.infer_mir_operand(body, element))
362            .collect();
363        let int_vector = values
364            .iter()
365            .map(|value| value.number.and_then(number_to_int))
366            .collect::<Option<Vec<_>>>();
367        let shape = match kind {
368            runmat_mir::MirAggregateKind::Tensor => {
369                let row_count = rows.max(1);
370                let cols_per_row = if row_count == 0 {
371                    0
372                } else {
373                    elements.len() / row_count
374                };
375                let mut row_dims = Vec::new();
376                for row_idx in 0..row_count {
377                    let start = row_idx * cols_per_row;
378                    let end = start + cols_per_row;
379                    let mut total_cols = 0usize;
380                    let mut expected_rows = None;
381                    for value in &values[start..end] {
382                        if let Some((rows, cols)) = value.shape.as_ref().and_then(matrix_dims) {
383                            if let (Some(expected), Some(rows)) = (expected_rows, rows) {
384                                if expected != rows {
385                                    self.warn(
386                                        "lint.shape.horzcat",
387                                        "horizontal concatenation row dimensions do not agree",
388                                        span,
389                                    );
390                                }
391                            }
392                            expected_rows = expected_rows.or(rows);
393                            total_cols += cols.unwrap_or(1);
394                        } else {
395                            total_cols += 1;
396                        }
397                    }
398                    row_dims.push((expected_rows.unwrap_or(1), total_cols));
399                }
400                if let Some((_, first_cols)) = row_dims.first().copied() {
401                    for (_, cols) in &row_dims {
402                        if *cols != first_cols {
403                            self.warn(
404                                "lint.shape.vertcat",
405                                "vertical concatenation column dimensions do not agree",
406                                span,
407                            );
408                        }
409                    }
410                    Some(Shape(vec![
411                        Some(row_dims.iter().map(|(rows, _)| rows).sum()),
412                        Some(first_cols),
413                    ]))
414                } else {
415                    Some(Shape(vec![Some(0), Some(0)]))
416                }
417            }
418            runmat_mir::MirAggregateKind::Cell => Some(Shape(vec![Some(1), Some(elements.len())])),
419        };
420        MirShapeValue {
421            shape,
422            number: None,
423            int_vector,
424        }
425    }
426
427    fn infer_mir_call(
428        &mut self,
429        body: &runmat_mir::MirBody,
430        span: Span,
431        call: &runmat_mir::MirCall,
432    ) -> MirShapeValue {
433        let arg_values: Vec<_> = call
434            .args
435            .iter()
436            .map(|arg| self.infer_mir_operand(body, arg.operand()))
437            .collect();
438        let shape = match call.semantic_kind {
439            BuiltinSemanticKind::Elementwise => {
440                arg_values.first().and_then(|value| value.shape.clone())
441            }
442            BuiltinSemanticKind::ArrayConstructor => sized_constructor_shape(&arg_values),
443            BuiltinSemanticKind::ParameterizedArrayConstructor => {
444                sized_constructor_shape(arg_values.get(1..).unwrap_or(&[]))
445            }
446            BuiltinSemanticKind::PermutationConstructor => Some(Shape(vec![
447                Some(1),
448                arg_values
449                    .first()
450                    .and_then(|value| value.number.and_then(number_to_int)),
451            ])),
452            BuiltinSemanticKind::RangeConstructor => Some(Shape(vec![Some(1), None])),
453            BuiltinSemanticKind::EmptyConstructor => Some(Shape(vec![Some(0), Some(0)])),
454            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Dot) => {
455                let lhs = arg_values.first().and_then(|value| value.shape.as_ref());
456                let rhs = arg_values.get(1).and_then(|value| value.shape.as_ref());
457                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
458                    if vector_len(lhs)
459                        .zip(vector_len(rhs))
460                        .is_some_and(|(l, r)| l != r)
461                    {
462                        self.warn(
463                            "lint.shape.dot",
464                            "dot product vector lengths do not agree",
465                            span,
466                        );
467                    }
468                }
469                Some(Shape(vec![Some(1), Some(1)]))
470            }
471            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Reshape) => {
472                let input = arg_values.first().and_then(|value| value.shape.as_ref());
473                let dims = mir_parse_dims(&arg_values[1..]);
474                if dims.iter().filter(|dim| matches!(dim, Dim::Infer)).count() > 1
475                    || incompatible_element_count(input, &dims)
476                {
477                    self.warn(
478                        "lint.shape.reshape",
479                        "reshape dimensions are not compatible",
480                        span,
481                    );
482                }
483                Some(Shape(dims.iter().map(|dim| dim.as_shape_dim()).collect()))
484            }
485            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Repmat) => {
486                for arg in &arg_values[1..] {
487                    if !matches!(mir_parse_dim(arg), Dim::Known(_)) {
488                        self.warn(
489                            "lint.shape.repmat",
490                            "repmat dimensions must be non-negative integers",
491                            span,
492                        );
493                    }
494                }
495                arg_values.first().and_then(|value| value.shape.clone())
496            }
497            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Permute) => {
498                let base = arg_values.first().and_then(|value| value.shape.clone());
499                let order = arg_values.get(1).and_then(|value| value.int_vector.clone());
500                if let Some(order) = &order {
501                    let mut sorted = order.clone();
502                    sorted.sort_unstable();
503                    if sorted.windows(2).any(|pair| pair[0] == pair[1])
504                        || base
505                            .as_ref()
506                            .is_some_and(|shape| order.len() != shape.0.len())
507                    {
508                        self.warn(
509                            "lint.shape.permute",
510                            "permute order is invalid for input rank",
511                            span,
512                        );
513                    }
514                }
515                base
516            }
517            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Transpose) => {
518                let base = arg_values.first().and_then(|value| value.shape.clone());
519                base.map(|shape| {
520                    if shape.0.len() >= 2 {
521                        Shape(vec![shape.0[1], shape.0[0]])
522                    } else {
523                        shape
524                    }
525                })
526            }
527            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Concatenate(kind)) => {
528                self.infer_mir_concat(span, kind, &arg_values)
529            }
530            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::General) => {
531                arg_values.first().and_then(|value| value.shape.clone())
532            }
533            BuiltinSemanticKind::Reduction => {
534                let base = arg_values.first().and_then(|value| value.shape.clone());
535                if let (Some(base_shape), Some(dim)) = (
536                    base.as_ref(),
537                    arg_values
538                        .get(1)
539                        .and_then(|value| value.number.and_then(number_to_int)),
540                ) {
541                    if dim == 0 || dim > base_shape.0.len() {
542                        self.warn(
543                            "lint.shape.reduction",
544                            "reduction dimension is out of range",
545                            span,
546                        );
547                    }
548                }
549                base
550            }
551            _ => None,
552        };
553        MirShapeValue {
554            shape,
555            number: None,
556            int_vector: None,
557        }
558    }
559
560    fn infer_mir_concat(
561        &mut self,
562        span: Span,
563        kind: ConcatKind,
564        arg_values: &[MirShapeValue],
565    ) -> Option<Shape> {
566        let (dim, values) = match kind {
567            ConcatKind::Dimension => {
568                let dim = arg_values
569                    .first()
570                    .and_then(|value| value.number.and_then(number_to_int))?;
571                (dim, &arg_values[1..])
572            }
573            ConcatKind::Horizontal => (2, arg_values),
574            ConcatKind::Vertical => (1, arg_values),
575        };
576        let shapes: Vec<_> = values
577            .iter()
578            .filter_map(|value| value.shape.as_ref())
579            .collect();
580        if shapes.is_empty() || dim == 0 {
581            return None;
582        }
583        let rank = shapes
584            .iter()
585            .map(|shape| shape.0.len())
586            .max()
587            .unwrap_or(dim);
588        let axis = dim - 1;
589        if axis >= rank {
590            return None;
591        }
592        let mut out = vec![Some(1); rank];
593        for (idx, out_dim) in out.iter_mut().enumerate().take(rank) {
594            if idx == axis {
595                *out_dim = shapes
596                    .iter()
597                    .map(|shape| shape.0.get(idx).copied().flatten())
598                    .try_fold(0usize, |sum, dim| dim.map(|dim| sum + dim));
599                continue;
600            }
601            let mut expected = None;
602            for shape in &shapes {
603                let dim = shape.0.get(idx).copied().flatten().or(Some(1));
604                if let (Some(expected), Some(dim)) = (expected, dim) {
605                    if expected != dim {
606                        self.warn(
607                            "lint.shape.concat",
608                            "concatenation dimensions do not agree",
609                            span,
610                        );
611                    }
612                }
613                expected = expected.or(dim);
614            }
615            *out_dim = expected;
616        }
617        Some(Shape(out))
618    }
619
620    fn check_logical_index(&mut self, span: Span, base: Option<&Shape>, idx: Option<&Shape>) {
621        if let (Some(base), Some(idx)) = (base, idx) {
622            if element_count(base)
623                .zip(element_count(idx))
624                .is_some_and(|(base, idx)| base != idx)
625            {
626                self.warn(
627                    "lint.shape.logical_index",
628                    "logical index shape does not match indexed value",
629                    span,
630                );
631            }
632        }
633    }
634
635    fn warn(&mut self, code: &'static str, message: &'static str, span: Span) {
636        self.diagnostics.push(
637            HirDiagnostic::new(code, HirDiagnosticSeverity::Warning, message, span)
638                .with_category("shape"),
639        );
640    }
641}
642
/// One parsed size argument of `reshape`/`repmat`.
#[derive(PartialEq, Clone, Copy)]
enum Dim {
    /// A known non-negative integer size.
    Known(usize),
    /// A placeholder whose size is derived from the input's element count.
    Infer,
    /// A size that could not be determined as a usable value.
    Unknown,
}
649
650impl Dim {
651    fn as_shape_dim(self) -> Option<usize> {
652        match self {
653            Dim::Known(value) => Some(value),
654            Dim::Infer | Dim::Unknown => None,
655        }
656    }
657}
658
/// Converts `value` to a non-negative integer when it is within `1e-9` of one.
///
/// The tolerance is symmetric: values slightly above *and* slightly below an integer (such as
/// `2.9999999999999996`, the result of `0.3 / 0.1 * 3`-style arithmetic) round to it. NaN,
/// infinities, negative integers, and values far from an integer yield `None`. Values beyond
/// `usize::MAX` saturate.
fn number_to_int(value: f64) -> Option<usize> {
    const TOLERANCE: f64 = 1e-9;
    if !value.is_finite() {
        return None;
    }
    let nearest = value.round();
    if nearest < 0.0 || (value - nearest).abs() > TOLERANCE {
        return None;
    }
    Some(nearest as usize)
}
666
667fn mir_parse_dim(value: &MirShapeValue) -> Dim {
668    match value.number {
669        Some(-1.0) => Dim::Infer,
670        Some(value) => number_to_int(value).map(Dim::Known).unwrap_or(Dim::Unknown),
671        None => Dim::Unknown,
672    }
673}
674
675fn mir_parse_dims(args: &[MirShapeValue]) -> Vec<Dim> {
676    if args.len() == 1 {
677        if let Some(values) = &args[0].int_vector {
678            return values.iter().copied().map(Dim::Known).collect();
679        }
680    }
681    args.iter().map(mir_parse_dim).collect()
682}
683
684fn sized_constructor_shape(args: &[MirShapeValue]) -> Option<Shape> {
685    let dims: Vec<_> = args
686        .iter()
687        .filter_map(|value| value.number.and_then(number_to_int))
688        .map(Some)
689        .collect();
690    match dims.as_slice() {
691        [] => None,
692        [dim] => Some(Shape(vec![*dim, *dim])),
693        _ => Some(Shape(dims)),
694    }
695}
696
697fn shape_from_fact(shape: &runmat_hir::ShapeFact) -> Option<Shape> {
698    match shape {
699        runmat_hir::ShapeFact::Scalar => Some(Shape(vec![Some(1), Some(1)])),
700        runmat_hir::ShapeFact::Shaped { dims } => Some(Shape(
701            dims.iter()
702                .map(|dim| match dim {
703                    runmat_hir::DimFact::Known(value) => Some(*value),
704                    runmat_hir::DimFact::Symbolic(_) | runmat_hir::DimFact::Unknown => None,
705                })
706                .collect(),
707        )),
708        runmat_hir::ShapeFact::Ranked { .. }
709        | runmat_hir::ShapeFact::Unknown
710        | runmat_hir::ShapeFact::Unreachable => None,
711    }
712}
713
/// Number of elements in the range `start:step:end`.
///
/// Returns `Some(0)` when the step points away from `end`. Returns `None` for a zero or
/// non-finite step, non-finite endpoints, or a count that does not fit in `usize`. A small
/// tolerance absorbs floating-point error in `(end - start) / step`, in the spirit of
/// MATLAB's colon operator (`2 * eps * max(|start|, |end|)`). Without it, `0:0.1:0.3` would
/// count 3 elements instead of 4.
fn range_width(start: f64, step: f64, end: f64) -> Option<usize> {
    if step == 0.0 || !start.is_finite() || !step.is_finite() || !end.is_finite() {
        return None;
    }
    let span = end - start;
    if (span > 0.0 && step < 0.0) || (span < 0.0 && step > 0.0) {
        return Some(0);
    }
    // Non-negative here: span and step share a sign (or span is zero).
    let ratio = span / step;
    let tolerance = 2.0 * f64::EPSILON * start.abs().max(end.abs()) / step.abs()
        + 4.0 * f64::EPSILON * ratio.abs();
    let intervals = (ratio + tolerance).floor();
    // Saturates for huge/infinite ratios; the checked add then reports "does not fit".
    (intervals as usize).checked_add(1)
}
724
725fn matrix_dims(shape: &Shape) -> Option<(Option<usize>, Option<usize>)> {
726    Some((*shape.0.first()?, *shape.0.get(1)?))
727}
728
729fn element_count(shape: &Shape) -> Option<usize> {
730    shape
731        .0
732        .iter()
733        .try_fold(1usize, |acc, dim| dim.map(|dim| acc * dim))
734}
735
736fn vector_len(shape: &Shape) -> Option<usize> {
737    let count = element_count(shape)?;
738    if shape.0.len() == 1
739        || (shape.0.len() == 2 && (shape.0[0] == Some(1) || shape.0[1] == Some(1)))
740    {
741        Some(count)
742    } else {
743        None
744    }
745}
746
747fn broadcast_compatible(left: &Shape, right: &Shape) -> bool {
748    let len = left.0.len().max(right.0.len());
749    (0..len).all(|idx| {
750        let l = left.0.iter().rev().nth(idx).copied().flatten().unwrap_or(1);
751        let r = right
752            .0
753            .iter()
754            .rev()
755            .nth(idx)
756            .copied()
757            .flatten()
758            .unwrap_or(1);
759        l == r || l == 1 || r == 1
760    })
761}
762
763fn incompatible_element_count(input: Option<&Shape>, dims: &[Dim]) -> bool {
764    let Some(input_count) = input.and_then(element_count) else {
765        return false;
766    };
767    if dims.iter().any(|dim| matches!(dim, Dim::Unknown)) {
768        return true;
769    }
770    let known_product = dims.iter().fold(1usize, |acc, dim| match dim {
771        Dim::Known(value) => acc * value,
772        Dim::Infer | Dim::Unknown => acc,
773    });
774    if dims.iter().any(|dim| matches!(dim, Dim::Infer)) {
775        known_product == 0 || input_count % known_product != 0
776    } else {
777        known_product != input_count
778    }
779}