use runmat_builtins::{BuiltinSemanticKind, ConcatKind, ShapeTransformKind};
use runmat_hir::{BindingId, HirDiagnostic, HirDiagnosticSeverity, OperatorKind, Span};
use runmat_mir::analysis::{AnalysisStore, MirLocalKey};
use std::collections::{HashMap, HashSet};
5
6pub fn lint_shapes(result: &runmat_hir::LoweringResult) -> Vec<HirDiagnostic> {
7 let mir = match runmat_mir::lowering::lower_assembly(&result.assembly) {
8 Ok(mir) => mir,
9 Err(err) => return vec![mir_lowering_diagnostic(err)],
10 };
11 let store = runmat_mir::analysis::analyze_assembly(&mir);
12 let mut ctx = ShapeLintContext::default();
13 ctx.seed_from_analysis(&mir, &store);
14 ctx.walk_mir_assembly(&mir);
15 ctx.diagnostics
16}
17
18pub fn infer_binding_shapes(
19 result: &runmat_hir::LoweringResult,
20) -> HashMap<BindingId, Vec<Option<usize>>> {
21 let Ok(mir) = runmat_mir::lowering::lower_assembly(&result.assembly) else {
22 return HashMap::new();
23 };
24 let store = runmat_mir::analysis::analyze_assembly(&mir);
25 let mut ctx = ShapeLintContext::default();
26 ctx.seed_from_analysis(&mir, &store);
27 ctx.walk_mir_assembly(&mir);
28 ctx.env
29 .into_iter()
30 .map(|(binding, shape)| (binding, shape.0))
31 .collect()
32}
33
34fn mir_lowering_diagnostic(err: runmat_hir::HirError) -> HirDiagnostic {
35 HirDiagnostic::new(
36 "lint.mir.lowering_failed",
37 HirDiagnosticSeverity::Error,
38 format!("MIR lowering failed: {}", err.message),
39 err.span.unwrap_or(runmat_hir::Span { start: 0, end: 0 }),
40 )
41 .with_category("mir-lowering")
42}
43
44#[derive(Debug, Clone, PartialEq)]
45struct Shape(Vec<Option<usize>>);
46
47#[derive(Default)]
48struct ShapeLintContext {
49 env: HashMap<BindingId, Shape>,
50 local_env: HashMap<MirLocalKey, Shape>,
51 number_env: HashMap<MirLocalKey, f64>,
52 int_vector_env: HashMap<MirLocalKey, Vec<usize>>,
53 diagnostics: Vec<HirDiagnostic>,
54}
55
56#[derive(Default)]
57struct MirShapeValue {
58 shape: Option<Shape>,
59 number: Option<f64>,
60 int_vector: Option<Vec<usize>>,
61}
62
63impl ShapeLintContext {
64 fn seed_from_analysis(&mut self, mir: &runmat_mir::MirAssembly, store: &AnalysisStore) {
65 for body in mir.bodies.values() {
66 for local in &body.locals {
67 let Some(binding) = local.binding else {
68 continue;
69 };
70 let Some(fact) = store.mir_locals.get(&MirLocalKey {
71 function: body.function,
72 local: local.id,
73 }) else {
74 continue;
75 };
76 if let Some(shape) = shape_from_fact(&fact.shape) {
77 self.env.insert(binding, shape);
78 }
79 }
80 }
81 }
82
83 fn walk_mir_assembly(&mut self, mir: &runmat_mir::MirAssembly) {
84 for body in mir.bodies.values() {
85 for block in &body.blocks {
86 for stmt in &block.statements {
87 match &stmt.kind {
88 runmat_mir::MirStmtKind::Assign { place, value } => {
89 let value = self.infer_mir_rvalue(body, value, stmt.span);
90 if let runmat_mir::MirPlace::Local(local) = place {
91 self.record_mir_value(body, *local, value);
92 }
93 }
94 runmat_mir::MirStmtKind::MultiAssign { value, .. }
95 | runmat_mir::MirStmtKind::Expr(value) => {
96 self.infer_mir_rvalue(body, value, stmt.span);
97 }
98 runmat_mir::MirStmtKind::PlaceMutation(_)
99 | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
100 | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
101 }
102 }
103 }
104 }
105 }
106
107 fn record_mir_value(
108 &mut self,
109 body: &runmat_mir::MirBody,
110 local: runmat_mir::MirLocalId,
111 value: MirShapeValue,
112 ) {
113 let key = MirLocalKey {
114 function: body.function,
115 local,
116 };
117 if let Some(shape) = value.shape {
118 if let Some(binding) = body.locals.get(local.0).and_then(|local| local.binding) {
119 self.env.insert(binding, shape.clone());
120 }
121 self.local_env.insert(key, shape);
122 }
123 if let Some(number) = value.number {
124 self.number_env.insert(key, number);
125 }
126 if let Some(vector) = value.int_vector {
127 self.int_vector_env.insert(key, vector);
128 }
129 }
130
131 fn infer_mir_rvalue(
132 &mut self,
133 body: &runmat_mir::MirBody,
134 value: &runmat_mir::MirRvalue,
135 span: Span,
136 ) -> MirShapeValue {
137 match value {
138 runmat_mir::MirRvalue::Use(operand) => self.infer_mir_operand(body, operand),
139 runmat_mir::MirRvalue::Unary(op, operand) => {
140 let inner = self.infer_mir_operand(body, operand);
141 let number = match (op, inner.number) {
142 (OperatorKind::UnaryMinus, Some(value)) => Some(-value),
143 _ => None,
144 };
145 MirShapeValue {
146 shape: inner.shape,
147 number,
148 int_vector: None,
149 }
150 }
151 runmat_mir::MirRvalue::Binary(left, op, right) => {
152 let lhs = self.infer_mir_operand(body, left);
153 let rhs = self.infer_mir_operand(body, right);
154 MirShapeValue {
155 shape: self.infer_mir_binary(span, lhs.shape.as_ref(), op, rhs.shape.as_ref()),
156 number: None,
157 int_vector: None,
158 }
159 }
160 runmat_mir::MirRvalue::ShortCircuit {
161 left,
162 right_temps,
163 right,
164 ..
165 } => {
166 self.infer_mir_operand(body, left);
167 for stmt in right_temps {
168 match &stmt.kind {
169 runmat_mir::MirStmtKind::Assign { place, value } => {
170 let inferred = self.infer_mir_rvalue(body, value, stmt.span);
171 if let runmat_mir::MirPlace::Local(local) = place {
172 self.record_mir_value(body, *local, inferred);
173 }
174 }
175 runmat_mir::MirStmtKind::MultiAssign { value, .. }
176 | runmat_mir::MirStmtKind::Expr(value) => {
177 self.infer_mir_rvalue(body, value, stmt.span);
178 }
179 runmat_mir::MirStmtKind::PlaceMutation(_)
180 | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
181 | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
182 }
183 }
184 self.infer_mir_operand(body, right);
185 MirShapeValue::default()
186 }
187 runmat_mir::MirRvalue::Range { start, step, end } => {
188 let start = self.infer_mir_operand(body, start).number;
189 let step = step
190 .as_ref()
191 .and_then(|step| self.infer_mir_operand(body, step).number)
192 .unwrap_or(1.0);
193 let end = self.infer_mir_operand(body, end).number;
194 let width = start
195 .zip(end)
196 .and_then(|(start, end)| range_width(start, step, end));
197 MirShapeValue {
198 shape: Some(Shape(vec![Some(1), width])),
199 number: None,
200 int_vector: None,
201 }
202 }
203 runmat_mir::MirRvalue::Call(call) => self.infer_mir_call(body, span, call),
204 runmat_mir::MirRvalue::Aggregate {
205 kind,
206 rows,
207 elements,
208 ..
209 } => self.infer_mir_aggregate(body, span, kind, *rows, elements),
210 runmat_mir::MirRvalue::StructLiteral { fields } => {
211 for (_, value) in fields {
212 self.infer_mir_operand(body, value);
213 }
214 MirShapeValue {
215 shape: Some(Shape(vec![Some(1), Some(1)])),
216 number: None,
217 int_vector: None,
218 }
219 }
220 runmat_mir::MirRvalue::ObjectLiteral { fields, .. } => {
221 for (_, value) in fields {
222 self.infer_mir_operand(body, value);
223 }
224 MirShapeValue {
225 shape: Some(Shape(vec![Some(1), Some(1)])),
226 number: None,
227 int_vector: None,
228 }
229 }
230 runmat_mir::MirRvalue::Index { base, indexing } => {
231 let base_shape = self.infer_mir_operand(body, base).shape;
232 for component in &indexing.components {
233 match component {
234 runmat_mir::MirIndexComponent::Expr(operand) => {
235 let idx_shape = self.infer_mir_operand(body, operand).shape;
236 if indexing.components.len() == 1 {
237 self.check_logical_index(
238 span,
239 base_shape.as_ref(),
240 idx_shape.as_ref(),
241 );
242 }
243 }
244 runmat_mir::MirIndexComponent::Colon
245 | runmat_mir::MirIndexComponent::End { .. } => {}
246 }
247 }
248 MirShapeValue::default()
249 }
250 runmat_mir::MirRvalue::Member { base, .. } => self.infer_mir_operand(body, base),
251 runmat_mir::MirRvalue::DynamicMember { base, member } => {
252 self.infer_mir_operand(body, member);
253 self.infer_mir_operand(body, base)
254 }
255 runmat_mir::MirRvalue::Future { args, .. } => {
256 for arg in args {
257 self.infer_mir_operand(body, arg.operand());
258 }
259 MirShapeValue::default()
260 }
261 runmat_mir::MirRvalue::Spawn(operand) => self.infer_mir_operand(body, operand),
262 runmat_mir::MirRvalue::MetaClass(_)
263 | runmat_mir::MirRvalue::Colon
264 | runmat_mir::MirRvalue::End => MirShapeValue::default(),
265 }
266 }
267
268 fn infer_mir_operand(
269 &self,
270 body: &runmat_mir::MirBody,
271 operand: &runmat_mir::MirOperand,
272 ) -> MirShapeValue {
273 match operand {
274 runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Number(value)) => {
275 MirShapeValue {
276 shape: Some(Shape(vec![Some(1), Some(1)])),
277 number: value.parse().ok(),
278 int_vector: None,
279 }
280 }
281 runmat_mir::MirOperand::Local(local) => {
282 let key = MirLocalKey {
283 function: body.function,
284 local: *local,
285 };
286 MirShapeValue {
287 shape: self.local_env.get(&key).cloned(),
288 number: self.number_env.get(&key).copied(),
289 int_vector: self.int_vector_env.get(&key).cloned(),
290 }
291 }
292 runmat_mir::MirOperand::Constant(_) | runmat_mir::MirOperand::FunctionHandle(_) => {
293 MirShapeValue::default()
294 }
295 }
296 }
297
298 fn infer_mir_binary(
299 &mut self,
300 span: Span,
301 lhs: Option<&Shape>,
302 op: &OperatorKind,
303 rhs: Option<&Shape>,
304 ) -> Option<Shape> {
305 match op {
306 OperatorKind::MatrixMultiply => {
307 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
308 if matrix_dims(lhs)
309 .zip(matrix_dims(rhs))
310 .is_some_and(|((_, lc), (rr, _))| lc.is_some() && rr.is_some() && lc != rr)
311 {
312 self.warn(
313 "lint.shape.matmul",
314 "matrix multiply inner dimensions must match",
315 span,
316 );
317 }
318 }
319 match (lhs.and_then(matrix_dims), rhs.and_then(matrix_dims)) {
320 (Some((rows, _)), Some((_, cols))) => Some(Shape(vec![rows, cols])),
321 _ => None,
322 }
323 }
324 OperatorKind::Add
325 | OperatorKind::Subtract
326 | OperatorKind::ElementwiseMultiply
327 | OperatorKind::ElementwiseDivide
328 | OperatorKind::ElementwiseLeftDivide
329 | OperatorKind::ElementwisePower
330 | OperatorKind::Greater
331 | OperatorKind::GreaterEqual
332 | OperatorKind::Less
333 | OperatorKind::LessEqual
334 | OperatorKind::Equal
335 | OperatorKind::NotEqual => {
336 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
337 if !broadcast_compatible(lhs, rhs) {
338 self.warn(
339 "lint.shape.broadcast",
340 "array dimensions are not broadcast compatible",
341 span,
342 );
343 }
344 }
345 lhs.cloned().or_else(|| rhs.cloned())
346 }
347 _ => lhs.cloned().or_else(|| rhs.cloned()),
348 }
349 }
350
351 fn infer_mir_aggregate(
352 &mut self,
353 body: &runmat_mir::MirBody,
354 span: Span,
355 kind: &runmat_mir::MirAggregateKind,
356 rows: usize,
357 elements: &[runmat_mir::MirOperand],
358 ) -> MirShapeValue {
359 let values: Vec<_> = elements
360 .iter()
361 .map(|element| self.infer_mir_operand(body, element))
362 .collect();
363 let int_vector = values
364 .iter()
365 .map(|value| value.number.and_then(number_to_int))
366 .collect::<Option<Vec<_>>>();
367 let shape = match kind {
368 runmat_mir::MirAggregateKind::Tensor => {
369 let row_count = rows.max(1);
370 let cols_per_row = if row_count == 0 {
371 0
372 } else {
373 elements.len() / row_count
374 };
375 let mut row_dims = Vec::new();
376 for row_idx in 0..row_count {
377 let start = row_idx * cols_per_row;
378 let end = start + cols_per_row;
379 let mut total_cols = 0usize;
380 let mut expected_rows = None;
381 for value in &values[start..end] {
382 if let Some((rows, cols)) = value.shape.as_ref().and_then(matrix_dims) {
383 if let (Some(expected), Some(rows)) = (expected_rows, rows) {
384 if expected != rows {
385 self.warn(
386 "lint.shape.horzcat",
387 "horizontal concatenation row dimensions do not agree",
388 span,
389 );
390 }
391 }
392 expected_rows = expected_rows.or(rows);
393 total_cols += cols.unwrap_or(1);
394 } else {
395 total_cols += 1;
396 }
397 }
398 row_dims.push((expected_rows.unwrap_or(1), total_cols));
399 }
400 if let Some((_, first_cols)) = row_dims.first().copied() {
401 for (_, cols) in &row_dims {
402 if *cols != first_cols {
403 self.warn(
404 "lint.shape.vertcat",
405 "vertical concatenation column dimensions do not agree",
406 span,
407 );
408 }
409 }
410 Some(Shape(vec![
411 Some(row_dims.iter().map(|(rows, _)| rows).sum()),
412 Some(first_cols),
413 ]))
414 } else {
415 Some(Shape(vec![Some(0), Some(0)]))
416 }
417 }
418 runmat_mir::MirAggregateKind::Cell => Some(Shape(vec![Some(1), Some(elements.len())])),
419 };
420 MirShapeValue {
421 shape,
422 number: None,
423 int_vector,
424 }
425 }
426
427 fn infer_mir_call(
428 &mut self,
429 body: &runmat_mir::MirBody,
430 span: Span,
431 call: &runmat_mir::MirCall,
432 ) -> MirShapeValue {
433 let arg_values: Vec<_> = call
434 .args
435 .iter()
436 .map(|arg| self.infer_mir_operand(body, arg.operand()))
437 .collect();
438 let shape = match call.semantic_kind {
439 BuiltinSemanticKind::Elementwise => {
440 arg_values.first().and_then(|value| value.shape.clone())
441 }
442 BuiltinSemanticKind::ArrayConstructor => sized_constructor_shape(&arg_values),
443 BuiltinSemanticKind::ParameterizedArrayConstructor => {
444 sized_constructor_shape(arg_values.get(1..).unwrap_or(&[]))
445 }
446 BuiltinSemanticKind::PermutationConstructor => Some(Shape(vec![
447 Some(1),
448 arg_values
449 .first()
450 .and_then(|value| value.number.and_then(number_to_int)),
451 ])),
452 BuiltinSemanticKind::RangeConstructor => Some(Shape(vec![Some(1), None])),
453 BuiltinSemanticKind::EmptyConstructor => Some(Shape(vec![Some(0), Some(0)])),
454 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Dot) => {
455 let lhs = arg_values.first().and_then(|value| value.shape.as_ref());
456 let rhs = arg_values.get(1).and_then(|value| value.shape.as_ref());
457 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
458 if vector_len(lhs)
459 .zip(vector_len(rhs))
460 .is_some_and(|(l, r)| l != r)
461 {
462 self.warn(
463 "lint.shape.dot",
464 "dot product vector lengths do not agree",
465 span,
466 );
467 }
468 }
469 Some(Shape(vec![Some(1), Some(1)]))
470 }
471 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Reshape) => {
472 let input = arg_values.first().and_then(|value| value.shape.as_ref());
473 let dims = mir_parse_dims(&arg_values[1..]);
474 if dims.iter().filter(|dim| matches!(dim, Dim::Infer)).count() > 1
475 || incompatible_element_count(input, &dims)
476 {
477 self.warn(
478 "lint.shape.reshape",
479 "reshape dimensions are not compatible",
480 span,
481 );
482 }
483 Some(Shape(dims.iter().map(|dim| dim.as_shape_dim()).collect()))
484 }
485 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Repmat) => {
486 for arg in &arg_values[1..] {
487 if !matches!(mir_parse_dim(arg), Dim::Known(_)) {
488 self.warn(
489 "lint.shape.repmat",
490 "repmat dimensions must be non-negative integers",
491 span,
492 );
493 }
494 }
495 arg_values.first().and_then(|value| value.shape.clone())
496 }
497 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Permute) => {
498 let base = arg_values.first().and_then(|value| value.shape.clone());
499 let order = arg_values.get(1).and_then(|value| value.int_vector.clone());
500 if let Some(order) = &order {
501 let mut sorted = order.clone();
502 sorted.sort_unstable();
503 if sorted.windows(2).any(|pair| pair[0] == pair[1])
504 || base
505 .as_ref()
506 .is_some_and(|shape| order.len() != shape.0.len())
507 {
508 self.warn(
509 "lint.shape.permute",
510 "permute order is invalid for input rank",
511 span,
512 );
513 }
514 }
515 base
516 }
517 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Transpose) => {
518 let base = arg_values.first().and_then(|value| value.shape.clone());
519 base.map(|shape| {
520 if shape.0.len() >= 2 {
521 Shape(vec![shape.0[1], shape.0[0]])
522 } else {
523 shape
524 }
525 })
526 }
527 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Concatenate(kind)) => {
528 self.infer_mir_concat(span, kind, &arg_values)
529 }
530 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::General) => {
531 arg_values.first().and_then(|value| value.shape.clone())
532 }
533 BuiltinSemanticKind::Reduction => {
534 let base = arg_values.first().and_then(|value| value.shape.clone());
535 if let (Some(base_shape), Some(dim)) = (
536 base.as_ref(),
537 arg_values
538 .get(1)
539 .and_then(|value| value.number.and_then(number_to_int)),
540 ) {
541 if dim == 0 || dim > base_shape.0.len() {
542 self.warn(
543 "lint.shape.reduction",
544 "reduction dimension is out of range",
545 span,
546 );
547 }
548 }
549 base
550 }
551 _ => None,
552 };
553 MirShapeValue {
554 shape,
555 number: None,
556 int_vector: None,
557 }
558 }
559
560 fn infer_mir_concat(
561 &mut self,
562 span: Span,
563 kind: ConcatKind,
564 arg_values: &[MirShapeValue],
565 ) -> Option<Shape> {
566 let (dim, values) = match kind {
567 ConcatKind::Dimension => {
568 let dim = arg_values
569 .first()
570 .and_then(|value| value.number.and_then(number_to_int))?;
571 (dim, &arg_values[1..])
572 }
573 ConcatKind::Horizontal => (2, arg_values),
574 ConcatKind::Vertical => (1, arg_values),
575 };
576 let shapes: Vec<_> = values
577 .iter()
578 .filter_map(|value| value.shape.as_ref())
579 .collect();
580 if shapes.is_empty() || dim == 0 {
581 return None;
582 }
583 let rank = shapes
584 .iter()
585 .map(|shape| shape.0.len())
586 .max()
587 .unwrap_or(dim);
588 let axis = dim - 1;
589 if axis >= rank {
590 return None;
591 }
592 let mut out = vec![Some(1); rank];
593 for (idx, out_dim) in out.iter_mut().enumerate().take(rank) {
594 if idx == axis {
595 *out_dim = shapes
596 .iter()
597 .map(|shape| shape.0.get(idx).copied().flatten())
598 .try_fold(0usize, |sum, dim| dim.map(|dim| sum + dim));
599 continue;
600 }
601 let mut expected = None;
602 for shape in &shapes {
603 let dim = shape.0.get(idx).copied().flatten().or(Some(1));
604 if let (Some(expected), Some(dim)) = (expected, dim) {
605 if expected != dim {
606 self.warn(
607 "lint.shape.concat",
608 "concatenation dimensions do not agree",
609 span,
610 );
611 }
612 }
613 expected = expected.or(dim);
614 }
615 *out_dim = expected;
616 }
617 Some(Shape(out))
618 }
619
620 fn check_logical_index(&mut self, span: Span, base: Option<&Shape>, idx: Option<&Shape>) {
621 if let (Some(base), Some(idx)) = (base, idx) {
622 if element_count(base)
623 .zip(element_count(idx))
624 .is_some_and(|(base, idx)| base != idx)
625 {
626 self.warn(
627 "lint.shape.logical_index",
628 "logical index shape does not match indexed value",
629 span,
630 );
631 }
632 }
633 }
634
635 fn warn(&mut self, code: &'static str, message: &'static str, span: Span) {
636 self.diagnostics.push(
637 HirDiagnostic::new(code, HirDiagnosticSeverity::Warning, message, span)
638 .with_category("shape"),
639 );
640 }
641}
642
/// One size argument of `reshape`/`repmat` as understood by the linter.
#[derive(Clone, Copy, PartialEq)]
enum Dim {
    /// A statically known, non-negative integer extent.
    Known(usize),
    /// A placeholder extent to be inferred from the element count.
    Infer,
    /// Not statically known, or not a usable extent.
    Unknown,
}

impl Dim {
    /// The shape entry for this argument: only `Known` extents are concrete.
    fn as_shape_dim(self) -> Option<usize> {
        if let Dim::Known(extent) = self {
            Some(extent)
        } else {
            None
        }
    }
}
658
/// Interprets `value` as a non-negative integer extent/index.
///
/// Values within `1e-9` of an integer on either side are accepted, which absorbs
/// floating-point noise such as `0.3 / 0.1`. Negative, non-finite, and non-integral values
/// give `None`, as do values too large for `usize` (they previously saturated to
/// `usize::MAX`).
fn number_to_int(value: f64) -> Option<usize> {
    const TOLERANCE: f64 = 1e-9;
    if !value.is_finite() {
        return None;
    }
    let rounded = value.round();
    if (value - rounded).abs() > TOLERANCE || rounded < 0.0 || rounded >= usize::MAX as f64 {
        return None;
    }
    Some(rounded as usize)
}
666
667fn mir_parse_dim(value: &MirShapeValue) -> Dim {
668 match value.number {
669 Some(-1.0) => Dim::Infer,
670 Some(value) => number_to_int(value).map(Dim::Known).unwrap_or(Dim::Unknown),
671 None => Dim::Unknown,
672 }
673}
674
675fn mir_parse_dims(args: &[MirShapeValue]) -> Vec<Dim> {
676 if args.len() == 1 {
677 if let Some(values) = &args[0].int_vector {
678 return values.iter().copied().map(Dim::Known).collect();
679 }
680 }
681 args.iter().map(mir_parse_dim).collect()
682}
683
684fn sized_constructor_shape(args: &[MirShapeValue]) -> Option<Shape> {
685 let dims: Vec<_> = args
686 .iter()
687 .filter_map(|value| value.number.and_then(number_to_int))
688 .map(Some)
689 .collect();
690 match dims.as_slice() {
691 [] => None,
692 [dim] => Some(Shape(vec![*dim, *dim])),
693 _ => Some(Shape(dims)),
694 }
695}
696
697fn shape_from_fact(shape: &runmat_hir::ShapeFact) -> Option<Shape> {
698 match shape {
699 runmat_hir::ShapeFact::Scalar => Some(Shape(vec![Some(1), Some(1)])),
700 runmat_hir::ShapeFact::Shaped { dims } => Some(Shape(
701 dims.iter()
702 .map(|dim| match dim {
703 runmat_hir::DimFact::Known(value) => Some(*value),
704 runmat_hir::DimFact::Symbolic(_) | runmat_hir::DimFact::Unknown => None,
705 })
706 .collect(),
707 )),
708 runmat_hir::ShapeFact::Ranked { .. }
709 | runmat_hir::ShapeFact::Unknown
710 | runmat_hir::ShapeFact::Unreachable => None,
711 }
712}
713
/// Number of elements in the colon range `start:step:end`, or `None` if it cannot be
/// determined.
///
/// - A zero or non-finite bound/step gives `None`.
/// - A step pointing away from `end` gives 0 elements.
/// - A quotient within a small relative tolerance of a whole number is rounded, so
///   `0:0.1:0.3` has 4 elements despite `0.3 / 0.1 == 2.9999999999999996`.
/// - Element counts too large for `usize` give `None` instead of overflowing.
fn range_width(start: f64, step: f64, end: f64) -> Option<usize> {
    const RELATIVE_TOLERANCE: f64 = 1e-10;
    if step == 0.0 || !start.is_finite() || !step.is_finite() || !end.is_finite() {
        return None;
    }
    let span = end - start;
    if (span > 0.0 && step < 0.0) || (span < 0.0 && step > 0.0) {
        return Some(0);
    }
    let ratio = span / step;
    let nearest = ratio.round();
    let whole_steps = if (ratio - nearest).abs() <= RELATIVE_TOLERANCE * nearest.abs().max(1.0) {
        nearest
    } else {
        ratio.floor()
    };
    // `as` saturates for huge quotients; the checked add turns that into `None`.
    (whole_steps.abs() as usize).checked_add(1)
}
724
725fn matrix_dims(shape: &Shape) -> Option<(Option<usize>, Option<usize>)> {
726 Some((*shape.0.first()?, *shape.0.get(1)?))
727}
728
729fn element_count(shape: &Shape) -> Option<usize> {
730 shape
731 .0
732 .iter()
733 .try_fold(1usize, |acc, dim| dim.map(|dim| acc * dim))
734}
735
736fn vector_len(shape: &Shape) -> Option<usize> {
737 let count = element_count(shape)?;
738 if shape.0.len() == 1
739 || (shape.0.len() == 2 && (shape.0[0] == Some(1) || shape.0[1] == Some(1)))
740 {
741 Some(count)
742 } else {
743 None
744 }
745}
746
747fn broadcast_compatible(left: &Shape, right: &Shape) -> bool {
748 let len = left.0.len().max(right.0.len());
749 (0..len).all(|idx| {
750 let l = left.0.iter().rev().nth(idx).copied().flatten().unwrap_or(1);
751 let r = right
752 .0
753 .iter()
754 .rev()
755 .nth(idx)
756 .copied()
757 .flatten()
758 .unwrap_or(1);
759 l == r || l == 1 || r == 1
760 })
761}
762
763fn incompatible_element_count(input: Option<&Shape>, dims: &[Dim]) -> bool {
764 let Some(input_count) = input.and_then(element_count) else {
765 return false;
766 };
767 if dims.iter().any(|dim| matches!(dim, Dim::Unknown)) {
768 return true;
769 }
770 let known_product = dims.iter().fold(1usize, |acc, dim| match dim {
771 Dim::Known(value) => acc * value,
772 Dim::Infer | Dim::Unknown => acc,
773 });
774 if dims.iter().any(|dim| matches!(dim, Dim::Infer)) {
775 known_product == 0 || input_count % known_product != 0
776 } else {
777 known_product != input_count
778 }
779}