1use std::collections::HashMap;
2
3use anyhow::{anyhow, Result};
4use once_cell::sync::OnceCell;
5use serde::Deserialize;
6
7use crate::graph::{AttrValue, OpAttrs, OpKind};
8use crate::tensor::DType;
9
/// Category of value an operator attribute holds.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpAttrType {
    /// A single scalar; the permitted scalar flavours live in [`OpAttrDef::scalar_kinds`].
    Scalar,
    /// A dtype value (declared as `"dtype"` in `ops.json`).
    DType,
    /// A tensor value (declared as `"tensor"` in `ops.json`).
    Tensor,
    /// A string value (declared as `"string"` in `ops.json`).
    String,
    /// A list of integers (declared as `"int_list"` in `ops.json`).
    IntList,
}

/// Flavour of scalar a [`OpAttrType::Scalar`] attribute may carry.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarAttrKind {
    /// Declared as `"float"` in `ops.json`.
    Float,
    /// Declared as `"int"` in `ops.json`.
    Int,
    /// Declared as `"uint"` in `ops.json`.
    UInt,
    /// Declared as `"bool"` in `ops.json`.
    Bool,
}

/// Definition of a single named operator attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpAttrDef {
    /// Attribute name as it appears in `ops.json`.
    pub name: &'static str,
    /// Category of value the attribute holds.
    pub kind: OpAttrType,
    /// Allowed scalar flavours; empty unless `kind` is [`OpAttrType::Scalar`].
    pub scalar_kinds: &'static [ScalarAttrKind],
}

impl OpAttrDef {
    /// Creates a definition of the given kind with no scalar flavours attached.
    pub const fn new(name: &'static str, kind: OpAttrType) -> Self {
        Self {
            name,
            kind,
            scalar_kinds: &[],
        }
    }

    /// Creates a [`OpAttrType::Scalar`] definition restricted to `scalar_kinds`.
    pub const fn scalar(name: &'static str, scalar_kinds: &'static [ScalarAttrKind]) -> Self {
        Self {
            scalar_kinds,
            ..Self::new(name, OpAttrType::Scalar)
        }
    }
}
58
59#[derive(Debug, Clone, Copy)]
61#[allow(dead_code)]
62pub struct OpDTypeSupport {
63 pub normal: &'static [DType],
64 pub accumulate: &'static [(DType, DType)],
65}
66
/// Whether an operator accepts broadcast inputs (`"allow"` / `"deny"` in `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastSupport {
    Deny,
    Allow,
}

impl BroadcastSupport {
    /// Returns `true` only for [`BroadcastSupport::Allow`].
    pub fn allow(self) -> bool {
        self == Self::Allow
    }
}

/// Whether an operator may run in place (`"allow"` / `"deny"` in `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InplaceSupport {
    Deny,
    Allow,
}

impl InplaceSupport {
    /// Returns `true` only for [`InplaceSupport::Allow`].
    pub fn allow(self) -> bool {
        self == Self::Allow
    }
}

/// Whether an operator supports accumulation (`"allow"` / `"deny"` in `ops.json`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulateSupport {
    Deny,
    Allow,
}

impl AccumulateSupport {
    /// Returns `true` only for [`AccumulateSupport::Allow`].
    pub fn allow(self) -> bool {
        self == Self::Allow
    }
}
111
112#[derive(Debug, Clone, Copy)]
114#[allow(dead_code)]
115pub struct OpSchema {
116 pub kind: OpKind,
117 pub inputs: InputArity,
118 pub outputs: OutputArity,
119 pub attrs: &'static [OpAttrDef],
120 pub broadcast: BroadcastSupport,
121 pub inplace: InplaceSupport,
122 pub accumulate: AccumulateSupport,
123 pub type_rule: TypeRule,
124 pub dtype_support: Option<&'static OpDTypeSupport>,
125 pub output_dtypes: Option<&'static [DType]>,
126}
127
128#[derive(Debug, Clone, Copy)]
130#[allow(dead_code)]
131pub enum TypeRule {
132 SameAsInput(usize),
133 Fixed(DType),
134 AccFromAttr { attr: &'static str },
135}
136
/// Number of inputs an operator accepts.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputArity {
    /// Exactly this many inputs.
    Fixed(usize),
    /// This many inputs or more.
    AtLeast(usize),
    /// Any number of inputs, including zero.
    Any,
}

impl InputArity {
    /// Returns `true` if an operator with `count` inputs satisfies this arity.
    pub fn allows(self, count: usize) -> bool {
        match self {
            Self::Any => true,
            Self::AtLeast(min) => min <= count,
            Self::Fixed(expected) => expected == count,
        }
    }

    /// Returns the exact count for [`InputArity::Fixed`], `None` otherwise.
    pub fn fixed(self) -> Option<usize> {
        if let Self::Fixed(count) = self {
            Some(count)
        } else {
            None
        }
    }
}

/// Number of outputs an operator produces.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputArity {
    /// Exactly this many outputs.
    Fixed(usize),
    /// This many outputs or more.
    AtLeast(usize),
    /// Any number of outputs, including zero.
    Any,
}

#[allow(dead_code)]
impl OutputArity {
    /// Returns `true` if an operator with `count` outputs satisfies this arity.
    pub fn allows(self, count: usize) -> bool {
        match self {
            Self::Any => true,
            Self::AtLeast(min) => min <= count,
            Self::Fixed(expected) => expected == count,
        }
    }

    /// Returns the exact count for [`OutputArity::Fixed`], `None` otherwise.
    pub fn fixed(self) -> Option<usize> {
        if let Self::Fixed(count) = self {
            Some(count)
        } else {
            None
        }
    }
}
194
195impl TypeRule {
196 pub fn output_dtype(self, inputs: &[DType], attrs: &OpAttrs) -> Result<DType> {
198 match self {
199 TypeRule::SameAsInput(index) => inputs
200 .get(index)
201 .copied()
202 .ok_or_else(|| anyhow!("missing input dtype at {}", index)),
203 TypeRule::Fixed(dtype) => Ok(dtype),
204 TypeRule::AccFromAttr { attr } => attrs
205 .items
206 .iter()
207 .find(|item| item.name == attr)
208 .ok_or_else(|| anyhow!("missing {} attribute", attr))
209 .and_then(|item| match &item.value {
210 AttrValue::DType(dtype) => Ok(*dtype),
211 _ => Err(anyhow!("{} attribute must be a dtype", attr)),
212 }),
213 }
214 }
215}
216
217#[derive(Debug)]
218#[allow(dead_code)]
219struct OpRegistry {
220 schemas: Vec<OpSchema>,
221 dtype_supports: HashMap<String, &'static OpDTypeSupport>,
222 output_dtype_sets: HashMap<String, &'static [DType]>,
223}
224
225static REGISTRY: OnceCell<OpRegistry> = OnceCell::new();
226
227#[derive(Debug, Deserialize)]
228struct OpsFile {
229 version: u32,
230 attr_defs: HashMap<String, AttrDefJson>,
231 dtype_sets: HashMap<String, DTypeSupportJson>,
232 output_dtype_sets: Option<HashMap<String, Vec<String>>>,
233 ops: Vec<OpSchemaJson>,
234}
235
236#[derive(Debug, Deserialize)]
237struct AttrDefJson {
238 kind: String,
239 #[serde(default)]
240 scalar_kinds: Vec<String>,
241}
242
243#[derive(Debug, Deserialize)]
244#[allow(dead_code)]
245struct OpSchemaJson {
246 name: String,
247 kind: OpKind,
248 inputs: ArityJson,
249 outputs: ArityJson,
250 #[serde(default)]
251 attrs: Vec<String>,
252 broadcast: String,
253 inplace: String,
254 accumulate: String,
255 type_rule: TypeRuleJson,
256 dtype_support_ref: Option<String>,
257 output_dtypes_ref: Option<String>,
258 #[serde(default)]
259 devices: Option<serde_json::Value>,
260}
261
262#[derive(Debug, Deserialize)]
263struct ArityJson {
264 arity: String,
265 count: Option<usize>,
266}
267
268#[derive(Debug, Deserialize)]
269struct TypeRuleJson {
270 kind: String,
271 index: Option<usize>,
272 dtype: Option<String>,
273 attr: Option<String>,
274}
275
276#[derive(Debug, Deserialize)]
277struct DTypeSupportJson {
278 normal: Vec<String>,
279 #[serde(default)]
280 accumulate: Vec<AccumulatePairJson>,
281}
282
283#[derive(Debug, Deserialize)]
284struct AccumulatePairJson {
285 input: String,
286 acc: String,
287}
288
289fn registry() -> &'static OpRegistry {
290 REGISTRY.get_or_init(|| {
291 load_registry().unwrap_or_else(|err| panic!("ops registry init failed: {err}"))
292 })
293}
294
295fn load_registry() -> Result<OpRegistry> {
296 let json = include_str!("../ops.json");
297 let file: OpsFile = serde_json::from_str(json)?;
298 if file.version != 1 {
299 return Err(anyhow!("unsupported ops.json version {}", file.version));
300 }
301
302 let attr_defs = build_attr_defs(&file.attr_defs)?;
303 let dtype_supports = build_dtype_supports(&file.dtype_sets)?;
304 let output_dtype_sets = build_output_dtype_sets(file.output_dtype_sets.as_ref())?;
305 let mut schemas = Vec::with_capacity(file.ops.len());
306 for op in file.ops {
307 let attrs = build_attr_list(&attr_defs, &op.attrs)?;
308 let inputs = parse_input_arity(&op.inputs)?;
309 let outputs = parse_output_arity(&op.outputs)?;
310 let broadcast = parse_broadcast(&op.broadcast)?;
311 let inplace = parse_inplace(&op.inplace)?;
312 let accumulate = parse_accumulate(&op.accumulate)?;
313 let type_rule = parse_type_rule(op.type_rule)?;
314 let dtype_support = op
315 .dtype_support_ref
316 .as_deref()
317 .and_then(|name| dtype_supports.get(name).copied())
318 .ok_or_else(|| anyhow!("unknown dtype_support_ref for {}", op.name))?;
319 let output_dtypes = match op.output_dtypes_ref.as_deref() {
320 Some(name) => Some(
321 output_dtype_sets
322 .get(name)
323 .copied()
324 .ok_or_else(|| anyhow!("unknown output_dtypes_ref for {}", op.name))?,
325 ),
326 None => None,
327 };
328 schemas.push(OpSchema {
329 kind: op.kind,
330 inputs,
331 outputs,
332 attrs,
333 broadcast,
334 inplace,
335 accumulate,
336 type_rule,
337 dtype_support: Some(dtype_support),
338 output_dtypes,
339 });
340 }
341 Ok(OpRegistry {
342 schemas,
343 dtype_supports,
344 output_dtype_sets,
345 })
346}
347
348fn build_attr_defs(defs: &HashMap<String, AttrDefJson>) -> Result<HashMap<String, OpAttrDef>> {
349 let mut out = HashMap::new();
350 for (name, def) in defs {
351 let name_static: &'static str = Box::leak(name.clone().into_boxed_str());
352 let kind = match def.kind.as_str() {
353 "scalar" => OpAttrType::Scalar,
354 "dtype" => OpAttrType::DType,
355 "tensor" => OpAttrType::Tensor,
356 "string" => OpAttrType::String,
357 "int_list" => OpAttrType::IntList,
358 other => return Err(anyhow!("unknown attr kind {other} for {name}")),
359 };
360 let scalar_kinds = if matches!(kind, OpAttrType::Scalar) {
361 let kinds = def
362 .scalar_kinds
363 .iter()
364 .map(|kind| match kind.as_str() {
365 "float" => Ok(ScalarAttrKind::Float),
366 "int" => Ok(ScalarAttrKind::Int),
367 "uint" => Ok(ScalarAttrKind::UInt),
368 "bool" => Ok(ScalarAttrKind::Bool),
369 other => Err(anyhow!("unknown scalar kind {other} for {name}")),
370 })
371 .collect::<Result<Vec<_>>>()?;
372 Box::leak(kinds.into_boxed_slice()) as &'static [ScalarAttrKind]
373 } else {
374 &[]
375 };
376 out.insert(
377 name.clone(),
378 OpAttrDef {
379 name: name_static,
380 kind,
381 scalar_kinds,
382 },
383 );
384 }
385 Ok(out)
386}
387
388fn build_attr_list(
389 defs: &HashMap<String, OpAttrDef>,
390 attrs: &[String],
391) -> Result<&'static [OpAttrDef]> {
392 let mut out = Vec::with_capacity(attrs.len());
393 for attr in attrs {
394 let def = defs
395 .get(attr)
396 .copied()
397 .ok_or_else(|| anyhow!("unknown attr {attr} in ops.json"))?;
398 out.push(def);
399 }
400 Ok(Box::leak(out.into_boxed_slice()))
401}
402
403fn parse_input_arity(arity: &ArityJson) -> Result<InputArity> {
404 match arity.arity.as_str() {
405 "fixed" => Ok(InputArity::Fixed(required_count(arity, "fixed")?)),
406 "at_least" => Ok(InputArity::AtLeast(required_count(arity, "at_least")?)),
407 "any" => Ok(InputArity::Any),
408 other => Err(anyhow!("unknown input arity {other}")),
409 }
410}
411
412fn parse_output_arity(arity: &ArityJson) -> Result<OutputArity> {
413 match arity.arity.as_str() {
414 "fixed" => Ok(OutputArity::Fixed(required_count(arity, "fixed")?)),
415 "at_least" => Ok(OutputArity::AtLeast(required_count(arity, "at_least")?)),
416 "any" => Ok(OutputArity::Any),
417 other => Err(anyhow!("unknown output arity {other}")),
418 }
419}
420
421fn required_count(arity: &ArityJson, label: &str) -> Result<usize> {
422 arity
423 .count
424 .ok_or_else(|| anyhow!("missing count for {label} arity"))
425}
426
427fn parse_broadcast(value: &str) -> Result<BroadcastSupport> {
428 match value {
429 "allow" => Ok(BroadcastSupport::Allow),
430 "deny" => Ok(BroadcastSupport::Deny),
431 other => Err(anyhow!("unknown broadcast support {other}")),
432 }
433}
434
435fn parse_inplace(value: &str) -> Result<InplaceSupport> {
436 match value {
437 "allow" => Ok(InplaceSupport::Allow),
438 "deny" => Ok(InplaceSupport::Deny),
439 other => Err(anyhow!("unknown inplace support {other}")),
440 }
441}
442
443fn parse_accumulate(value: &str) -> Result<AccumulateSupport> {
444 match value {
445 "allow" => Ok(AccumulateSupport::Allow),
446 "deny" => Ok(AccumulateSupport::Deny),
447 other => Err(anyhow!("unknown accumulate support {other}")),
448 }
449}
450
451fn parse_type_rule(rule: TypeRuleJson) -> Result<TypeRule> {
452 match rule.kind.as_str() {
453 "same_as_input" => Ok(TypeRule::SameAsInput(
454 rule.index.ok_or_else(|| anyhow!("missing index for same_as_input"))?,
455 )),
456 "fixed" => {
457 let dtype = rule
458 .dtype
459 .ok_or_else(|| anyhow!("missing dtype for fixed type_rule"))?;
460 Ok(TypeRule::Fixed(DType::from_ident(&dtype)?))
461 }
462 "acc_from_attr" => {
463 let attr = rule
464 .attr
465 .ok_or_else(|| anyhow!("missing attr for acc_from_attr"))?;
466 let attr_static: &'static str = Box::leak(attr.into_boxed_str());
467 Ok(TypeRule::AccFromAttr { attr: attr_static })
468 }
469 other => Err(anyhow!("unknown type_rule {other}")),
470 }
471}
472
473fn build_dtype_supports(
474 dtype_sets: &HashMap<String, DTypeSupportJson>,
475) -> Result<HashMap<String, &'static OpDTypeSupport>> {
476 let mut out = HashMap::new();
477 for (name, support) in dtype_sets {
478 let normal = support
479 .normal
480 .iter()
481 .map(|ident| DType::from_ident(ident))
482 .collect::<Result<Vec<_>>>()?;
483 let accumulate = support
484 .accumulate
485 .iter()
486 .map(|pair| {
487 Ok((
488 DType::from_ident(&pair.input)?,
489 DType::from_ident(&pair.acc)?,
490 ))
491 })
492 .collect::<Result<Vec<_>>>()?;
493 let normal_static = Box::leak(normal.into_boxed_slice());
494 let acc_static = Box::leak(accumulate.into_boxed_slice());
495 let support_static: &'static OpDTypeSupport = Box::leak(Box::new(OpDTypeSupport {
496 normal: normal_static,
497 accumulate: acc_static,
498 }));
499 out.insert(name.clone(), support_static);
500 }
501 Ok(out)
502}
503
504fn build_output_dtype_sets(
505 output_sets: Option<&HashMap<String, Vec<String>>>,
506) -> Result<HashMap<String, &'static [DType]>> {
507 let mut out = HashMap::new();
508 if let Some(output_sets) = output_sets {
509 for (name, dtypes) in output_sets {
510 let converted = dtypes
511 .iter()
512 .map(|ident| DType::from_ident(ident))
513 .collect::<Result<Vec<_>>>()?;
514 let leaked: &'static [DType] = Box::leak(converted.into_boxed_slice());
515 out.insert(name.clone(), leaked);
516 }
517 }
518 Ok(out)
519}
520
521#[allow(unused)]
523pub fn acc_dtype(attrs: &OpAttrs) -> Result<DType> {
524 attrs
525 .items
526 .iter()
527 .find(|attr| attr.name == "acc")
528 .ok_or_else(|| anyhow!("missing acc attribute"))
529 .and_then(|attr| match &attr.value {
530 AttrValue::DType(dtype) => Ok(*dtype),
531 _ => Err(anyhow!("acc attribute must be a dtype")),
532 })
533}
534
535pub fn op_schema(kind: OpKind) -> Option<&'static OpSchema> {
537 registry().schemas.iter().find(|op| op.kind == kind)
538}
539
540pub fn init_ops_registry() {
542 let _ = registry();
543}