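//! `openinfer_simulator/op_defs.rs` — op schema registry.
//!
//! Loads the op definitions embedded at compile time from `ops.json` and exposes them as
//! static [`OpSchema`] entries.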

1use std::collections::HashMap;
2
3use anyhow::{anyhow, Result};
4use once_cell::sync::OnceCell;
5use serde::Deserialize;
6
7use crate::graph::{AttrValue, OpAttrs, OpKind};
8use crate::tensor::DType;
9
/// Kind of value an op attribute carries, as declared in the `attr_defs` table of `ops.json`.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpAttrType {
    /// Scalar value (`"scalar"`); the accepted scalar kinds are listed per attribute.
    Scalar,
    /// Data-type value (`"dtype"`).
    DType,
    /// Tensor value (`"tensor"`).
    Tensor,
    /// String value (`"string"`).
    String,
    /// List of integers (`"int_list"`).
    IntList,
}
20
/// Scalar kind a scalar attribute may accept (`"float"`, `"int"`, `"uint"`, `"bool"` in
/// `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarAttrKind {
    /// Floating-point scalar.
    Float,
    /// Signed integer scalar.
    Int,
    /// Unsigned integer scalar.
    UInt,
    /// Boolean scalar.
    Bool,
}
30
31/// Definition of a single op attribute.
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub struct OpAttrDef {
34    pub name: &'static str,
35    pub kind: OpAttrType,
36    pub scalar_kinds: &'static [ScalarAttrKind],
37}
38
39impl OpAttrDef {
40    /// Create a non-scalar attribute definition.
41    pub const fn new(name: &'static str, kind: OpAttrType) -> Self {
42        Self {
43            name,
44            kind,
45            scalar_kinds: &[],
46        }
47    }
48
49    /// Create a scalar attribute definition.
50    pub const fn scalar(name: &'static str, scalar_kinds: &'static [ScalarAttrKind]) -> Self {
51        Self {
52            name,
53            kind: OpAttrType::Scalar,
54            scalar_kinds,
55        }
56    }
57}
58
59/// Supported dtypes for an op (normal/accumulate modes).
60#[derive(Debug, Clone, Copy)]
61#[allow(dead_code)]
62pub struct OpDTypeSupport {
63    pub normal: &'static [DType],
64    pub accumulate: &'static [(DType, DType)],
65}
66
/// Whether an op permits broadcasting (`"allow"` / `"deny"` in `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastSupport {
    Deny,
    Allow,
}

impl BroadcastSupport {
    /// Returns `true` only for [`BroadcastSupport::Allow`].
    pub fn allow(self) -> bool {
        self == BroadcastSupport::Allow
    }
}
81
/// Whether an op may execute in place (`"allow"` / `"deny"` in `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InplaceSupport {
    Deny,
    Allow,
}

impl InplaceSupport {
    /// Returns `true` only for [`InplaceSupport::Allow`].
    pub fn allow(self) -> bool {
        self == InplaceSupport::Allow
    }
}
96
/// Whether an op supports accumulate mode (`"allow"` / `"deny"` in `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulateSupport {
    Deny,
    Allow,
}

impl AccumulateSupport {
    /// Returns `true` only for [`AccumulateSupport::Allow`].
    pub fn allow(self) -> bool {
        self == AccumulateSupport::Allow
    }
}
111
112/// Static schema definition for an op.
113#[derive(Debug, Clone, Copy)]
114#[allow(dead_code)]
115pub struct OpSchema {
116    pub kind: OpKind,
117    pub inputs: InputArity,
118    pub outputs: OutputArity,
119    pub attrs: &'static [OpAttrDef],
120    pub broadcast: BroadcastSupport,
121    pub inplace: InplaceSupport,
122    pub accumulate: AccumulateSupport,
123    pub type_rule: TypeRule,
124    pub dtype_support: Option<&'static OpDTypeSupport>,
125    pub output_dtypes: Option<&'static [DType]>,
126}
127
128/// Type inference rule for op outputs.
129#[derive(Debug, Clone, Copy)]
130#[allow(dead_code)]
131pub enum TypeRule {
132    SameAsInput(usize),
133    Fixed(DType),
134    AccFromAttr { attr: &'static str },
135}
136
/// How many inputs an op accepts.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputArity {
    /// Exactly this many inputs.
    Fixed(usize),
    /// This many inputs or more.
    AtLeast(usize),
    /// Any number of inputs, including zero.
    Any,
}

impl InputArity {
    /// Returns whether `count` inputs satisfy this arity.
    pub fn allows(self, count: usize) -> bool {
        match self {
            InputArity::Any => true,
            InputArity::AtLeast(min) => min <= count,
            InputArity::Fixed(expected) => expected == count,
        }
    }

    /// Returns the exact input count for [`InputArity::Fixed`], otherwise `None`.
    pub fn fixed(self) -> Option<usize> {
        if let InputArity::Fixed(count) = self {
            Some(count)
        } else {
            None
        }
    }
}
164
/// How many outputs an op produces.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputArity {
    /// Exactly this many outputs.
    Fixed(usize),
    /// This many outputs or more.
    AtLeast(usize),
    /// Any number of outputs, including zero.
    Any,
}

#[allow(dead_code)]
impl OutputArity {
    /// Returns whether `count` outputs satisfy this arity.
    pub fn allows(self, count: usize) -> bool {
        match self {
            OutputArity::Any => true,
            OutputArity::AtLeast(min) => min <= count,
            OutputArity::Fixed(expected) => expected == count,
        }
    }

    /// Returns the exact output count for [`OutputArity::Fixed`], otherwise `None`.
    pub fn fixed(self) -> Option<usize> {
        if let OutputArity::Fixed(count) = self {
            Some(count)
        } else {
            None
        }
    }
}
194
195impl TypeRule {
196    /// Infer an output dtype from inputs and attributes.
197    pub fn output_dtype(self, inputs: &[DType], attrs: &OpAttrs) -> Result<DType> {
198        match self {
199            TypeRule::SameAsInput(index) => inputs
200                .get(index)
201                .copied()
202                .ok_or_else(|| anyhow!("missing input dtype at {}", index)),
203            TypeRule::Fixed(dtype) => Ok(dtype),
204            TypeRule::AccFromAttr { attr } => attrs
205                .items
206                .iter()
207                .find(|item| item.name == attr)
208                .ok_or_else(|| anyhow!("missing {} attribute", attr))
209                .and_then(|item| match &item.value {
210                    AttrValue::DType(dtype) => Ok(*dtype),
211                    _ => Err(anyhow!("{} attribute must be a dtype", attr)),
212                }),
213        }
214    }
215}
216
217#[derive(Debug)]
218#[allow(dead_code)]
219struct OpRegistry {
220    schemas: Vec<OpSchema>,
221    dtype_supports: HashMap<String, &'static OpDTypeSupport>,
222    output_dtype_sets: HashMap<String, &'static [DType]>,
223}
224
225static REGISTRY: OnceCell<OpRegistry> = OnceCell::new();
226
227#[derive(Debug, Deserialize)]
228struct OpsFile {
229    version: u32,
230    attr_defs: HashMap<String, AttrDefJson>,
231    dtype_sets: HashMap<String, DTypeSupportJson>,
232    output_dtype_sets: Option<HashMap<String, Vec<String>>>,
233    ops: Vec<OpSchemaJson>,
234}
235
236#[derive(Debug, Deserialize)]
237struct AttrDefJson {
238    kind: String,
239    #[serde(default)]
240    scalar_kinds: Vec<String>,
241}
242
243#[derive(Debug, Deserialize)]
244#[allow(dead_code)]
245struct OpSchemaJson {
246    name: String,
247    kind: OpKind,
248    inputs: ArityJson,
249    outputs: ArityJson,
250    #[serde(default)]
251    attrs: Vec<String>,
252    broadcast: String,
253    inplace: String,
254    accumulate: String,
255    type_rule: TypeRuleJson,
256    dtype_support_ref: Option<String>,
257    output_dtypes_ref: Option<String>,
258    #[serde(default)]
259    devices: Option<serde_json::Value>,
260}
261
262#[derive(Debug, Deserialize)]
263struct ArityJson {
264    arity: String,
265    count: Option<usize>,
266}
267
268#[derive(Debug, Deserialize)]
269struct TypeRuleJson {
270    kind: String,
271    index: Option<usize>,
272    dtype: Option<String>,
273    attr: Option<String>,
274}
275
276#[derive(Debug, Deserialize)]
277struct DTypeSupportJson {
278    normal: Vec<String>,
279    #[serde(default)]
280    accumulate: Vec<AccumulatePairJson>,
281}
282
283#[derive(Debug, Deserialize)]
284struct AccumulatePairJson {
285    input: String,
286    acc: String,
287}
288
289fn registry() -> &'static OpRegistry {
290    REGISTRY.get_or_init(|| {
291        load_registry().unwrap_or_else(|err| panic!("ops registry init failed: {err}"))
292    })
293}
294
295fn load_registry() -> Result<OpRegistry> {
296    let json = include_str!("../ops.json");
297    let file: OpsFile = serde_json::from_str(json)?;
298    if file.version != 1 {
299        return Err(anyhow!("unsupported ops.json version {}", file.version));
300    }
301
302    let attr_defs = build_attr_defs(&file.attr_defs)?;
303    let dtype_supports = build_dtype_supports(&file.dtype_sets)?;
304    let output_dtype_sets = build_output_dtype_sets(file.output_dtype_sets.as_ref())?;
305    let mut schemas = Vec::with_capacity(file.ops.len());
306    for op in file.ops {
307        let attrs = build_attr_list(&attr_defs, &op.attrs)?;
308        let inputs = parse_input_arity(&op.inputs)?;
309        let outputs = parse_output_arity(&op.outputs)?;
310        let broadcast = parse_broadcast(&op.broadcast)?;
311        let inplace = parse_inplace(&op.inplace)?;
312        let accumulate = parse_accumulate(&op.accumulate)?;
313        let type_rule = parse_type_rule(op.type_rule)?;
314        let dtype_support = op
315            .dtype_support_ref
316            .as_deref()
317            .and_then(|name| dtype_supports.get(name).copied())
318            .ok_or_else(|| anyhow!("unknown dtype_support_ref for {}", op.name))?;
319        let output_dtypes = match op.output_dtypes_ref.as_deref() {
320            Some(name) => Some(
321                output_dtype_sets
322                    .get(name)
323                    .copied()
324                    .ok_or_else(|| anyhow!("unknown output_dtypes_ref for {}", op.name))?,
325            ),
326            None => None,
327        };
328        schemas.push(OpSchema {
329            kind: op.kind,
330            inputs,
331            outputs,
332            attrs,
333            broadcast,
334            inplace,
335            accumulate,
336            type_rule,
337            dtype_support: Some(dtype_support),
338            output_dtypes,
339        });
340    }
341    Ok(OpRegistry {
342        schemas,
343        dtype_supports,
344        output_dtype_sets,
345    })
346}
347
348fn build_attr_defs(defs: &HashMap<String, AttrDefJson>) -> Result<HashMap<String, OpAttrDef>> {
349    let mut out = HashMap::new();
350    for (name, def) in defs {
351        let name_static: &'static str = Box::leak(name.clone().into_boxed_str());
352        let kind = match def.kind.as_str() {
353            "scalar" => OpAttrType::Scalar,
354            "dtype" => OpAttrType::DType,
355            "tensor" => OpAttrType::Tensor,
356            "string" => OpAttrType::String,
357            "int_list" => OpAttrType::IntList,
358            other => return Err(anyhow!("unknown attr kind {other} for {name}")),
359        };
360        let scalar_kinds = if matches!(kind, OpAttrType::Scalar) {
361            let kinds = def
362                .scalar_kinds
363                .iter()
364                .map(|kind| match kind.as_str() {
365                    "float" => Ok(ScalarAttrKind::Float),
366                    "int" => Ok(ScalarAttrKind::Int),
367                    "uint" => Ok(ScalarAttrKind::UInt),
368                    "bool" => Ok(ScalarAttrKind::Bool),
369                    other => Err(anyhow!("unknown scalar kind {other} for {name}")),
370                })
371                .collect::<Result<Vec<_>>>()?;
372            Box::leak(kinds.into_boxed_slice()) as &'static [ScalarAttrKind]
373        } else {
374            &[]
375        };
376        out.insert(
377            name.clone(),
378            OpAttrDef {
379                name: name_static,
380                kind,
381                scalar_kinds,
382            },
383        );
384    }
385    Ok(out)
386}
387
388fn build_attr_list(
389    defs: &HashMap<String, OpAttrDef>,
390    attrs: &[String],
391) -> Result<&'static [OpAttrDef]> {
392    let mut out = Vec::with_capacity(attrs.len());
393    for attr in attrs {
394        let def = defs
395            .get(attr)
396            .copied()
397            .ok_or_else(|| anyhow!("unknown attr {attr} in ops.json"))?;
398        out.push(def);
399    }
400    Ok(Box::leak(out.into_boxed_slice()))
401}
402
403fn parse_input_arity(arity: &ArityJson) -> Result<InputArity> {
404    match arity.arity.as_str() {
405        "fixed" => Ok(InputArity::Fixed(required_count(arity, "fixed")?)),
406        "at_least" => Ok(InputArity::AtLeast(required_count(arity, "at_least")?)),
407        "any" => Ok(InputArity::Any),
408        other => Err(anyhow!("unknown input arity {other}")),
409    }
410}
411
412fn parse_output_arity(arity: &ArityJson) -> Result<OutputArity> {
413    match arity.arity.as_str() {
414        "fixed" => Ok(OutputArity::Fixed(required_count(arity, "fixed")?)),
415        "at_least" => Ok(OutputArity::AtLeast(required_count(arity, "at_least")?)),
416        "any" => Ok(OutputArity::Any),
417        other => Err(anyhow!("unknown output arity {other}")),
418    }
419}
420
421fn required_count(arity: &ArityJson, label: &str) -> Result<usize> {
422    arity
423        .count
424        .ok_or_else(|| anyhow!("missing count for {label} arity"))
425}
426
427fn parse_broadcast(value: &str) -> Result<BroadcastSupport> {
428    match value {
429        "allow" => Ok(BroadcastSupport::Allow),
430        "deny" => Ok(BroadcastSupport::Deny),
431        other => Err(anyhow!("unknown broadcast support {other}")),
432    }
433}
434
435fn parse_inplace(value: &str) -> Result<InplaceSupport> {
436    match value {
437        "allow" => Ok(InplaceSupport::Allow),
438        "deny" => Ok(InplaceSupport::Deny),
439        other => Err(anyhow!("unknown inplace support {other}")),
440    }
441}
442
443fn parse_accumulate(value: &str) -> Result<AccumulateSupport> {
444    match value {
445        "allow" => Ok(AccumulateSupport::Allow),
446        "deny" => Ok(AccumulateSupport::Deny),
447        other => Err(anyhow!("unknown accumulate support {other}")),
448    }
449}
450
451fn parse_type_rule(rule: TypeRuleJson) -> Result<TypeRule> {
452    match rule.kind.as_str() {
453        "same_as_input" => Ok(TypeRule::SameAsInput(
454            rule.index.ok_or_else(|| anyhow!("missing index for same_as_input"))?,
455        )),
456        "fixed" => {
457            let dtype = rule
458                .dtype
459                .ok_or_else(|| anyhow!("missing dtype for fixed type_rule"))?;
460            Ok(TypeRule::Fixed(DType::from_ident(&dtype)?))
461        }
462        "acc_from_attr" => {
463            let attr = rule
464                .attr
465                .ok_or_else(|| anyhow!("missing attr for acc_from_attr"))?;
466            let attr_static: &'static str = Box::leak(attr.into_boxed_str());
467            Ok(TypeRule::AccFromAttr { attr: attr_static })
468        }
469        other => Err(anyhow!("unknown type_rule {other}")),
470    }
471}
472
473fn build_dtype_supports(
474    dtype_sets: &HashMap<String, DTypeSupportJson>,
475) -> Result<HashMap<String, &'static OpDTypeSupport>> {
476    let mut out = HashMap::new();
477    for (name, support) in dtype_sets {
478        let normal = support
479            .normal
480            .iter()
481            .map(|ident| DType::from_ident(ident))
482            .collect::<Result<Vec<_>>>()?;
483        let accumulate = support
484            .accumulate
485            .iter()
486            .map(|pair| {
487                Ok((
488                    DType::from_ident(&pair.input)?,
489                    DType::from_ident(&pair.acc)?,
490                ))
491            })
492            .collect::<Result<Vec<_>>>()?;
493        let normal_static = Box::leak(normal.into_boxed_slice());
494        let acc_static = Box::leak(accumulate.into_boxed_slice());
495        let support_static: &'static OpDTypeSupport = Box::leak(Box::new(OpDTypeSupport {
496            normal: normal_static,
497            accumulate: acc_static,
498        }));
499        out.insert(name.clone(), support_static);
500    }
501    Ok(out)
502}
503
504fn build_output_dtype_sets(
505    output_sets: Option<&HashMap<String, Vec<String>>>,
506) -> Result<HashMap<String, &'static [DType]>> {
507    let mut out = HashMap::new();
508    if let Some(output_sets) = output_sets {
509        for (name, dtypes) in output_sets {
510            let converted = dtypes
511                .iter()
512                .map(|ident| DType::from_ident(ident))
513                .collect::<Result<Vec<_>>>()?;
514            let leaked: &'static [DType] = Box::leak(converted.into_boxed_slice());
515            out.insert(name.clone(), leaked);
516        }
517    }
518    Ok(out)
519}
520
521/// Convenience accessor for the `acc` dtype attribute.
522#[allow(unused)]
523pub fn acc_dtype(attrs: &OpAttrs) -> Result<DType> {
524    attrs
525        .items
526        .iter()
527        .find(|attr| attr.name == "acc")
528        .ok_or_else(|| anyhow!("missing acc attribute"))
529        .and_then(|attr| match &attr.value {
530            AttrValue::DType(dtype) => Ok(*dtype),
531            _ => Err(anyhow!("acc attribute must be a dtype")),
532        })
533}
534
535/// Lookup the schema for a specific op kind.
536pub fn op_schema(kind: OpKind) -> Option<&'static OpSchema> {
537    registry().schemas.iter().find(|op| op.kind == kind)
538}
539
540/// Initialize the global op registry (idempotent).
541pub fn init_ops_registry() {
542    let _ = registry();
543}