1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct GraphJson {
6 pub format: String, pub version: u32, #[serde(skip_serializing_if = "Option::is_none")]
9 pub name: Option<String>,
10 pub inputs: BTreeMap<String, OperandDesc>,
11 #[serde(default)]
12 pub consts: BTreeMap<String, ConstDecl>,
13 pub nodes: Vec<Node>,
14 pub outputs: BTreeMap<String, String>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19pub struct OperandDesc {
20 #[serde(rename = "dataType")]
21 pub data_type: DataType,
22 pub shape: Vec<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
26pub enum DataType {
27 #[serde(rename = "float32")]
28 Float32,
29 #[serde(rename = "float16")]
30 Float16,
31 #[serde(rename = "int32")]
32 Int32,
33 #[serde(rename = "uint32")]
34 Uint32,
35 #[serde(rename = "int64")]
36 Int64,
37 #[serde(rename = "uint64")]
38 Uint64,
39 #[serde(rename = "int8")]
40 Int8,
41 #[serde(rename = "uint8")]
42 Uint8,
43}
44
45impl DataType {
46 pub fn from_wg(s: &str) -> Option<Self> {
47 match s {
48 "f32" => Some(Self::Float32),
49 "f16" => Some(Self::Float16),
50 "i32" => Some(Self::Int32),
51 "u32" => Some(Self::Uint32),
52 "i64" => Some(Self::Int64),
53 "u64" => Some(Self::Uint64),
54 "i8" => Some(Self::Int8),
55 "u8" => Some(Self::Uint8),
56 _ => None,
57 }
58 }
59
60 pub fn to_wg_text(&self) -> &'static str {
61 match self {
62 Self::Float32 => "f32",
63 Self::Float16 => "f16",
64 Self::Int32 => "i32",
65 Self::Uint32 => "u32",
66 Self::Int64 => "i64",
67 Self::Uint64 => "u64",
68 Self::Int8 => "i8",
69 Self::Uint8 => "u8",
70 }
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
75pub struct ConstDecl {
76 #[serde(rename = "dataType")]
77 pub data_type: DataType,
78 pub shape: Vec<u32>,
79 pub init: ConstInit,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
83#[serde(tag = "kind", rename_all = "camelCase")]
84pub enum ConstInit {
85 Weights { r#ref: String },
86 Scalar { value: serde_json::Value },
87 InlineBytes { bytes: Vec<u8> },
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Node {
92 pub id: String,
93 pub op: String,
94 pub inputs: Vec<String>,
95 #[serde(default)]
96 pub options: serde_json::Map<String, serde_json::Value>,
97 #[serde(default)]
98 pub outputs: Option<Vec<String>>,
99}
100
101pub fn new_graph_json() -> GraphJson {
102 GraphJson {
103 format: "webnn-graph-json".to_string(),
104 version: 1,
105 name: None,
106 inputs: BTreeMap::new(),
107 consts: BTreeMap::new(),
108 nodes: Vec::new(),
109 outputs: BTreeMap::new(),
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    // Besides the in-memory checks, these tests pin the JSON wire format (key names,
    // lowercase type names, the `kind` tag, omitted/defaulted fields), since that format is
    // the contract other tools read and write.

    /// Every `DataType` variant, in declaration order.
    fn all_data_types() -> Vec<DataType> {
        vec![
            DataType::Float32,
            DataType::Float16,
            DataType::Int32,
            DataType::Uint32,
            DataType::Int64,
            DataType::Uint64,
            DataType::Int8,
            DataType::Uint8,
        ]
    }

    #[test]
    fn test_datatype_from_wg() {
        assert_eq!(DataType::from_wg("f32"), Some(DataType::Float32));
        assert_eq!(DataType::from_wg("f16"), Some(DataType::Float16));
        assert_eq!(DataType::from_wg("i32"), Some(DataType::Int32));
        assert_eq!(DataType::from_wg("u32"), Some(DataType::Uint32));
        assert_eq!(DataType::from_wg("i64"), Some(DataType::Int64));
        assert_eq!(DataType::from_wg("u64"), Some(DataType::Uint64));
        assert_eq!(DataType::from_wg("i8"), Some(DataType::Int8));
        assert_eq!(DataType::from_wg("u8"), Some(DataType::Uint8));
        assert_eq!(DataType::from_wg("invalid"), None);
        assert_eq!(DataType::from_wg("float32"), None);
    }

    #[test]
    fn test_datatype_wg_text_roundtrip() {
        for dt in all_data_types() {
            let text = dt.to_wg_text();
            assert_eq!(DataType::from_wg(text), Some(dt));
        }
    }

    #[test]
    fn test_datatype_json_names() {
        let names = [
            "float32", "float16", "int32", "uint32", "int64", "uint64", "int8", "uint8",
        ];
        for (dt, name) in all_data_types().into_iter().zip(names.iter()) {
            let expected = serde_json::Value::from(*name);
            assert_eq!(serde_json::to_value(&dt).unwrap(), expected);
            assert_eq!(serde_json::from_value::<DataType>(expected).unwrap(), dt);
        }
    }

    #[test]
    fn test_new_graph_json() {
        let graph = new_graph_json();
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        assert!(graph.name.is_none());
        assert!(graph.inputs.is_empty());
        assert!(graph.consts.is_empty());
        assert!(graph.nodes.is_empty());
        assert!(graph.outputs.is_empty());
    }

    #[test]
    fn test_graph_json_name_serialization() {
        let mut graph = new_graph_json();
        let value = serde_json::to_value(&graph).unwrap();
        assert!(value.get("name").is_none());
        assert_eq!(value["format"], "webnn-graph-json");
        assert_eq!(value["version"], 1);

        graph.name = Some("g".to_string());
        let value = serde_json::to_value(&graph).unwrap();
        assert_eq!(value["name"], "g");
    }

    #[test]
    fn test_graph_json_defaults_when_fields_absent() {
        let doc = serde_json::json!({
            "format": "webnn-graph-json",
            "version": 1,
            "inputs": { "x": { "dataType": "float32", "shape": [1, 4] } },
            "nodes": [ { "id": "n0", "op": "relu", "inputs": ["x"] } ],
            "outputs": { "y": "n0" }
        });
        let graph: GraphJson = serde_json::from_value(doc).unwrap();
        assert!(graph.name.is_none());
        assert!(graph.consts.is_empty());
        assert_eq!(
            graph.inputs["x"],
            OperandDesc {
                data_type: DataType::Float32,
                shape: vec![1, 4],
            }
        );
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.nodes[0].id, "n0");
        assert_eq!(graph.nodes[0].inputs, vec!["x".to_string()]);
        assert!(graph.nodes[0].options.is_empty());
        assert!(graph.nodes[0].outputs.is_none());
        assert_eq!(graph.outputs["y"], "n0");
    }

    #[test]
    fn test_operand_desc_equality() {
        let desc1 = OperandDesc {
            data_type: DataType::Float32,
            shape: vec![1, 2, 3],
        };
        let desc2 = OperandDesc {
            data_type: DataType::Float32,
            shape: vec![1, 2, 3],
        };
        let desc3 = OperandDesc {
            data_type: DataType::Float16,
            shape: vec![1, 2, 3],
        };
        assert_eq!(desc1, desc2);
        assert_ne!(desc1, desc3);
    }

    #[test]
    fn test_operand_desc_json_keys() {
        let desc = OperandDesc {
            data_type: DataType::Float16,
            shape: vec![2, 3],
        };
        let value = serde_json::to_value(&desc).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "dataType": "float16", "shape": [2, 3] })
        );
        assert_eq!(serde_json::from_value::<OperandDesc>(value).unwrap(), desc);
    }

    #[test]
    fn test_const_init_variants() {
        let weights_init = ConstInit::Weights {
            r#ref: "W".to_string(),
        };
        let scalar_init = ConstInit::Scalar {
            value: serde_json::json!(1.0),
        };
        let bytes_init = ConstInit::InlineBytes {
            bytes: vec![1, 2, 3, 4],
        };

        assert!(matches!(weights_init, ConstInit::Weights { .. }));
        assert!(matches!(scalar_init, ConstInit::Scalar { .. }));
        assert!(matches!(bytes_init, ConstInit::InlineBytes { .. }));
    }

    #[test]
    fn test_const_init_json_tagging() {
        let cases = vec![
            (
                ConstInit::Weights {
                    r#ref: "W".to_string(),
                },
                serde_json::json!({ "kind": "weights", "ref": "W" }),
            ),
            (
                ConstInit::Scalar {
                    value: serde_json::json!(2),
                },
                serde_json::json!({ "kind": "scalar", "value": 2 }),
            ),
            (
                ConstInit::InlineBytes {
                    bytes: vec![1, 2, 3],
                },
                serde_json::json!({ "kind": "inlineBytes", "bytes": [1, 2, 3] }),
            ),
        ];
        for (init, expected) in cases {
            assert_eq!(serde_json::to_value(&init).unwrap(), expected);
            assert_eq!(serde_json::from_value::<ConstInit>(expected).unwrap(), init);
        }
    }

    #[test]
    fn test_const_decl_json_keys() {
        let decl = ConstDecl {
            data_type: DataType::Int8,
            shape: vec![4],
            init: ConstInit::InlineBytes {
                bytes: vec![0, 1, 2, 3],
            },
        };
        let value = serde_json::to_value(&decl).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "dataType": "int8",
                "shape": [4],
                "init": { "kind": "inlineBytes", "bytes": [0, 1, 2, 3] }
            })
        );
        assert_eq!(serde_json::from_value::<ConstDecl>(value).unwrap(), decl);
    }
}
177}