1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
5#[serde(rename_all = "camelCase")]
6pub struct DynamicDimension {
7 pub name: String,
8 pub max_size: u32,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
12#[serde(untagged)]
13pub enum Dimension {
14 Static(u32),
15 Dynamic(DynamicDimension),
16}
17
18pub fn to_dimension_vector(shape: &[u32]) -> Vec<Dimension> {
19 shape.iter().copied().map(Dimension::Static).collect()
20}
21
22pub fn get_static_or_max_size(dim: &Dimension) -> u32 {
23 match dim {
24 Dimension::Static(v) => *v,
25 Dimension::Dynamic(d) => d.max_size,
26 }
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct GraphJson {
31 pub format: String, pub version: u32, #[serde(skip_serializing_if = "Option::is_none")]
34 pub name: Option<String>,
35 #[serde(default)]
36 pub quantized: bool,
37 pub inputs: BTreeMap<String, OperandDesc>,
38 #[serde(default)]
39 pub consts: BTreeMap<String, ConstDecl>,
40 pub nodes: Vec<Node>,
41 pub outputs: BTreeMap<String, String>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
46pub struct OperandDesc {
47 #[serde(rename = "dataType")]
48 pub data_type: DataType,
49 pub shape: Vec<Dimension>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
53pub enum DataType {
54 #[serde(rename = "float32")]
55 Float32,
56 #[serde(rename = "float16")]
57 Float16,
58 #[serde(rename = "int4")]
59 Int4,
60 #[serde(rename = "uint4")]
61 Uint4,
62 #[serde(rename = "int32")]
63 Int32,
64 #[serde(rename = "uint32")]
65 Uint32,
66 #[serde(rename = "int64")]
67 Int64,
68 #[serde(rename = "uint64")]
69 Uint64,
70 #[serde(rename = "int8")]
71 Int8,
72 #[serde(rename = "uint8")]
73 Uint8,
74}
75
76impl DataType {
77 pub fn from_wg(s: &str) -> Option<Self> {
78 match s {
79 "f32" => Some(Self::Float32),
80 "f16" => Some(Self::Float16),
81 "i4" => Some(Self::Int4),
82 "u4" => Some(Self::Uint4),
83 "i32" => Some(Self::Int32),
84 "u32" => Some(Self::Uint32),
85 "i64" => Some(Self::Int64),
86 "u64" => Some(Self::Uint64),
87 "i8" => Some(Self::Int8),
88 "u8" => Some(Self::Uint8),
89 _ => None,
90 }
91 }
92
93 pub fn to_wg_text(&self) -> &'static str {
94 match self {
95 Self::Float32 => "f32",
96 Self::Float16 => "f16",
97 Self::Int4 => "i4",
98 Self::Uint4 => "u4",
99 Self::Int32 => "i32",
100 Self::Uint32 => "u32",
101 Self::Int64 => "i64",
102 Self::Uint64 => "u64",
103 Self::Int8 => "i8",
104 Self::Uint8 => "u8",
105 }
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
110pub struct ConstDecl {
111 #[serde(rename = "dataType")]
112 pub data_type: DataType,
113 pub shape: Vec<u32>,
114 pub init: ConstInit,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
118#[serde(tag = "kind", rename_all = "camelCase")]
119pub enum ConstInit {
120 Weights { r#ref: String },
121 Scalar { value: serde_json::Value },
122 InlineBytes { bytes: Vec<u8> },
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct Node {
127 pub id: String,
128 pub op: String,
129 pub inputs: Vec<String>,
130 #[serde(default)]
131 pub options: serde_json::Map<String, serde_json::Value>,
132 #[serde(default)]
133 pub outputs: Option<Vec<String>>,
134}
135
136pub fn new_graph_json() -> GraphJson {
137 GraphJson {
138 format: "webnn-graph-json".to_string(),
139 version: 2,
140 name: None,
141 quantized: false,
142 inputs: BTreeMap::new(),
143 consts: BTreeMap::new(),
144 nodes: Vec::new(),
145 outputs: BTreeMap::new(),
146 }
147}
148
149impl OperandDesc {
150 pub fn static_shape(&self) -> Option<Vec<u32>> {
151 let mut shape = Vec::with_capacity(self.shape.len());
152 for dim in &self.shape {
153 match dim {
154 Dimension::Static(v) => shape.push(*v),
155 Dimension::Dynamic(_) => return None,
156 }
157 }
158 Some(shape)
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_wg` accepts every short name, including the 4-bit types, and
    /// rejects anything else.
    #[test]
    fn test_datatype_from_wg() {
        assert_eq!(DataType::from_wg("f32"), Some(DataType::Float32));
        assert_eq!(DataType::from_wg("f16"), Some(DataType::Float16));
        assert_eq!(DataType::from_wg("i4"), Some(DataType::Int4));
        assert_eq!(DataType::from_wg("u4"), Some(DataType::Uint4));
        assert_eq!(DataType::from_wg("i32"), Some(DataType::Int32));
        assert_eq!(DataType::from_wg("u32"), Some(DataType::Uint32));
        assert_eq!(DataType::from_wg("i64"), Some(DataType::Int64));
        assert_eq!(DataType::from_wg("u64"), Some(DataType::Uint64));
        assert_eq!(DataType::from_wg("i8"), Some(DataType::Int8));
        assert_eq!(DataType::from_wg("u8"), Some(DataType::Uint8));
        assert_eq!(DataType::from_wg("invalid"), None);
        assert_eq!(DataType::from_wg("float32"), None);
    }

    /// `to_wg_text` and `from_wg` are inverses for every variant, so the two
    /// match tables cannot drift apart unnoticed.
    #[test]
    fn test_datatype_wg_round_trip() {
        let table = [
            ("f32", DataType::Float32),
            ("f16", DataType::Float16),
            ("i4", DataType::Int4),
            ("u4", DataType::Uint4),
            ("i32", DataType::Int32),
            ("u32", DataType::Uint32),
            ("i64", DataType::Int64),
            ("u64", DataType::Uint64),
            ("i8", DataType::Int8),
            ("u8", DataType::Uint8),
        ];
        for (text, ty) in table.iter() {
            assert_eq!(ty.to_wg_text(), *text);
            assert_eq!(DataType::from_wg(ty.to_wg_text()), Some(ty.clone()));
        }
    }

    #[test]
    fn test_new_graph_json() {
        let graph = new_graph_json();
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 2);
        assert!(graph.name.is_none());
        assert!(!graph.quantized);
        assert!(graph.inputs.is_empty());
        assert!(graph.consts.is_empty());
        assert!(graph.nodes.is_empty());
        assert!(graph.outputs.is_empty());
    }

    #[test]
    fn test_operand_desc_equality() {
        let desc1 = OperandDesc {
            data_type: DataType::Float32,
            shape: to_dimension_vector(&[1, 2, 3]),
        };
        let desc2 = OperandDesc {
            data_type: DataType::Float32,
            shape: to_dimension_vector(&[1, 2, 3]),
        };
        let desc3 = OperandDesc {
            data_type: DataType::Float16,
            shape: to_dimension_vector(&[1, 2, 3]),
        };
        assert_eq!(desc1, desc2);
        assert_ne!(desc1, desc3);
    }

    /// A fully static shape is returned as integers; one dynamic entry makes
    /// `static_shape` return `None`, while `get_static_or_max_size` still
    /// resolves that entry to its `max_size`.
    #[test]
    fn test_static_shape_and_max_size() {
        let fixed = OperandDesc {
            data_type: DataType::Float32,
            shape: to_dimension_vector(&[2, 3]),
        };
        assert_eq!(fixed.static_shape(), Some(vec![2, 3]));

        let dynamic = OperandDesc {
            data_type: DataType::Float32,
            shape: vec![
                Dimension::Static(1),
                Dimension::Dynamic(DynamicDimension {
                    name: "batch".to_string(),
                    max_size: 8,
                }),
            ],
        };
        assert_eq!(dynamic.static_shape(), None);
        assert_eq!(get_static_or_max_size(&dynamic.shape[0]), 1);
        assert_eq!(get_static_or_max_size(&dynamic.shape[1]), 8);
    }

    #[test]
    fn test_const_init_variants() {
        let weights_init = ConstInit::Weights {
            r#ref: "W".to_string(),
        };
        let scalar_init = ConstInit::Scalar {
            value: serde_json::json!(1.0),
        };
        let bytes_init = ConstInit::InlineBytes {
            bytes: vec![1, 2, 3, 4],
        };

        assert!(matches!(weights_init, ConstInit::Weights { .. }));
        assert!(matches!(scalar_init, ConstInit::Scalar { .. }));
        assert!(matches!(bytes_init, ConstInit::InlineBytes { .. }));
    }
}