1use crate::ast::{ConstInit, Dimension, GraphJson, Node};
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum SerializeError {
6 #[error("invalid format: {0}")]
7 InvalidFormat(String),
8 #[error("unsupported version: {0}")]
9 UnsupportedVersion(u32),
10}
11
/// Caller-side switches for [`serialize_graph_to_wg_text`].
#[derive(Clone, Copy, Debug, Default)]
pub struct SerializeOptions {
    /// When `true`, the emitted header carries ` @quantized` even if the graph
    /// itself is not flagged as quantized. Defaults to `false`.
    pub quantized: bool,
}
20
21pub fn serialize_graph_to_wg_text(
22 graph: &GraphJson,
23 opts: SerializeOptions,
24) -> Result<String, SerializeError> {
25 let mut output = String::new();
26
27 if graph.format != "webnn-graph-json" {
29 return Err(SerializeError::InvalidFormat(graph.format.clone()));
30 }
31 if graph.version != 1 && graph.version != 2 {
32 return Err(SerializeError::UnsupportedVersion(graph.version));
33 }
34
35 let name = graph.name.as_deref().unwrap_or("graph");
37 let quantized_flag = if opts.quantized || graph.quantized {
38 " @quantized"
39 } else {
40 ""
41 };
42 output.push_str(&format!(
43 "webnn_graph \"{}\" v{}{} {{\n",
44 escape_string(name),
45 graph.version,
46 quantized_flag
47 ));
48
49 if !graph.inputs.is_empty() {
51 output.push_str(" inputs {\n");
52 for (name, desc) in &graph.inputs {
53 let dtype = desc.data_type.to_wg_text();
54 let shape = serialize_shape(&desc.shape)?;
55 output.push_str(&format!(" {}: {}{};\n", name, dtype, shape));
56 }
57 output.push_str(" }\n\n");
58 }
59
60 if !graph.consts.is_empty() {
62 output.push_str(" consts {\n");
63 for (name, const_decl) in &graph.consts {
64 let dtype = const_decl.data_type.to_wg_text();
65 let shape = serialize_shape_u32(&const_decl.shape);
66 let annotation = serialize_const_init(&const_decl.init);
67 output.push_str(&format!(" {}: {}{}{}", name, dtype, shape, annotation));
68 output.push_str(";\n");
69 }
70 output.push_str(" }\n\n");
71 }
72
73 if !graph.nodes.is_empty() {
75 output.push_str(" nodes {\n");
76 for node in &graph.nodes {
77 output.push_str(&format!(" {}\n", serialize_node(node)));
78 }
79 output.push_str(" }\n\n");
80 }
81
82 output.push_str(" outputs {");
84 if !graph.outputs.is_empty() {
85 let outputs: Vec<String> = graph.outputs.keys().map(|k| format!(" {};", k)).collect();
86 output.push_str(&outputs.join(""));
87 output.push(' ');
88 }
89 output.push_str("}\n");
90
91 output.push_str("}\n");
92 Ok(output)
93}
94
95fn serialize_shape(shape: &[Dimension]) -> Result<String, SerializeError> {
96 let mut dims: Vec<String> = Vec::with_capacity(shape.len());
97 for dim in shape {
98 match dim {
99 Dimension::Static(v) => dims.push(v.to_string()),
100 Dimension::Dynamic(d) => {
101 dims.push(format!(
102 "dyn(\"{}\", {})",
103 escape_string(&d.name),
104 d.max_size
105 ));
106 }
107 }
108 }
109 Ok(format!("[{}]", dims.join(", ")))
110}
111
/// Formats a fully static shape as `[d0, d1, ...]`; an empty slice yields `[]`.
fn serialize_shape_u32(shape: &[u32]) -> String {
    let joined = shape
        .iter()
        .map(u32::to_string)
        .collect::<Vec<_>>()
        .join(", ");
    format!("[{}]", joined)
}
116
117fn serialize_const_init(init: &ConstInit) -> String {
118 match init {
119 ConstInit::Weights { r#ref } => {
120 format!(" @weights(\"{}\")", escape_string(r#ref))
121 }
122 ConstInit::Scalar { value } => {
123 format!(" @scalar({})", serialize_json_value(value))
124 }
125 ConstInit::InlineBytes { bytes } => {
126 let nums: Vec<String> = bytes.iter().map(|b| b.to_string()).collect();
127 format!(" @bytes([{}])", nums.join(", "))
128 }
129 }
130}
131
132fn serialize_node(node: &Node) -> String {
133 let call = serialize_call(&node.op, &node.inputs, &node.options);
134
135 if let Some(outputs) = &node.outputs {
136 let out_list = outputs.join(", ");
138 format!("[{}] = {};", out_list, call)
139 } else {
140 format!("{} = {};", node.id, call)
142 }
143}
144
145fn serialize_call(
146 op: &str,
147 inputs: &[String],
148 options: &serde_json::Map<String, serde_json::Value>,
149) -> String {
150 let mut args = Vec::new();
151
152 for input in inputs {
154 args.push(input.clone());
155 }
156
157 for (key, value) in options {
159 args.push(format!("{}={}", key, serialize_json_value(value)));
160 }
161
162 format!("{}({})", op, args.join(", "))
163}
164
165fn serialize_json_value(value: &serde_json::Value) -> String {
166 match value {
167 serde_json::Value::Null => "null".to_string(),
168 serde_json::Value::Bool(b) => b.to_string(),
169 serde_json::Value::Number(n) => n.to_string(),
170 serde_json::Value::String(s) => format!("\"{}\"", escape_string(s)),
171 serde_json::Value::Array(arr) => {
172 let items: Vec<String> = arr.iter().map(serialize_json_value).collect();
173 format!("[{}]", items.join(", "))
174 }
175 serde_json::Value::Object(_) => {
176 value.to_string()
178 }
179 }
180}
181
/// Escapes `s` for use inside a double-quoted `.wg` string literal.
///
/// Only backslashes and double quotes are escaped; every other character
/// is copied unchanged.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            other => escaped.push(other),
        }
    }
    escaped
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        new_graph_json, to_dimension_vector, ConstDecl, ConstInit, DataType, Node, OperandDesc,
    };
    use crate::parser::parse_wg_text;

    /// The option map type carried by every [`Node`].
    type Options = serde_json::Map<String, serde_json::Value>;

    /// A fresh graph from `new_graph_json()` with its name set.
    fn named_graph(name: &str) -> GraphJson {
        let mut graph = new_graph_json();
        graph.name = Some(name.to_string());
        graph
    }

    /// An operand descriptor with the given dtype and shape.
    fn operand(data_type: DataType, shape: Vec<Dimension>) -> OperandDesc {
        OperandDesc { data_type, shape }
    }

    /// A node whose result is bound to `id` (no explicit `outputs` list).
    fn node(id: &str, op: &str, inputs: &[&str], options: Options) -> Node {
        Node {
            id: id.to_string(),
            op: op.to_string(),
            inputs: inputs.iter().map(|&input| input.to_string()).collect(),
            options,
            outputs: None,
        }
    }

    /// Registers `name` as a graph output referring to the operand of the same name.
    fn expose(graph: &mut GraphJson, name: &str) {
        graph.outputs.insert(name.to_string(), name.to_string());
    }

    /// Serializes with default options; any serialization error fails the test.
    fn render(graph: &GraphJson) -> String {
        serialize_graph_to_wg_text(graph, SerializeOptions::default()).unwrap()
    }

    /// Asserts that every needle occurs somewhere in `text`.
    fn assert_contains_all(text: &str, needles: &[&str]) {
        for needle in needles {
            assert!(text.contains(*needle), "expected {:?} in:\n{}", needle, text);
        }
    }

    #[test]
    fn test_serialize_simple_graph() {
        let mut graph = named_graph("test");
        graph.inputs.insert(
            "x".to_string(),
            operand(DataType::Float32, to_dimension_vector(&[1, 10])),
        );
        graph.nodes.push(node("result", "relu", &["x"], Options::new()));
        expose(&mut graph, "result");

        let text = render(&graph);
        let header = format!("webnn_graph \"test\" v{}", graph.version);
        assert_contains_all(
            &text,
            &[
                header.as_str(),
                "inputs {",
                "x: f32[1, 10];",
                "nodes {",
                "result = relu(x);",
                "outputs { result; }",
            ],
        );
    }

    #[test]
    fn test_serialize_dynamic_input_shape() {
        let batch = Dimension::Dynamic(crate::ast::DynamicDimension {
            name: "batch_size".to_string(),
            max_size: 8,
        });
        let mut graph = named_graph("dyn");
        graph.inputs.insert(
            "x".to_string(),
            operand(DataType::Float32, vec![batch, Dimension::Static(128)]),
        );
        expose(&mut graph, "x");

        assert!(render(&graph).contains("x: f32[dyn(\"batch_size\", 8), 128];"));
    }

    #[test]
    fn test_serialize_weights_annotation() {
        let weights = ConstDecl {
            data_type: DataType::Float32,
            shape: vec![10, 5],
            init: ConstInit::Weights {
                r#ref: "W".to_string(),
            },
        };
        let mut graph = named_graph("test");
        graph.consts.insert("W".to_string(), weights);
        expose(&mut graph, "W");

        assert!(render(&graph).contains("W: f32[10, 5] @weights(\"W\");"));
    }

    #[test]
    fn test_serialize_scalar_annotation() {
        let scale = ConstDecl {
            data_type: DataType::Float32,
            shape: vec![1],
            init: ConstInit::Scalar {
                value: serde_json::json!(3.5),
            },
        };
        let mut graph = named_graph("test");
        graph.consts.insert("scale".to_string(), scale);
        expose(&mut graph, "scale");

        assert!(render(&graph).contains("scale: f32[1] @scalar(3.5);"));
    }

    #[test]
    fn test_serialize_multi_output_node() {
        let mut graph = named_graph("test");
        graph.inputs.insert(
            "x".to_string(),
            operand(DataType::Float32, to_dimension_vector(&[10])),
        );
        graph.nodes.push(Node {
            outputs: Some(vec!["a".to_string(), "b".to_string()]),
            ..node("a", "split", &["x"], Options::new())
        });
        expose(&mut graph, "a");

        assert!(render(&graph).contains("[a, b] = split(x);"));
    }

    #[test]
    fn test_serialize_node_options() {
        let mut graph = named_graph("test");
        graph.inputs.insert(
            "x".to_string(),
            operand(DataType::Float32, to_dimension_vector(&[1, 10])),
        );

        let options: Options = vec![
            ("axis", serde_json::json!(1)),
            ("keepdims", serde_json::json!(true)),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();
        graph.nodes.push(node("result", "softmax", &["x"], options));
        expose(&mut graph, "result");

        assert_contains_all(&render(&graph), &["softmax(x,", "axis=1", "keepdims=true"]);
    }

    #[test]
    fn test_serialize_various_dtypes() {
        let mut graph = named_graph("test");

        let cases = vec![
            ("f32_input", DataType::Float32, "f32_input: f32[1];"),
            ("f16_input", DataType::Float16, "f16_input: f16[1];"),
            ("i32_input", DataType::Int32, "i32_input: i32[1];"),
            ("u32_input", DataType::Uint32, "u32_input: u32[1];"),
            ("i64_input", DataType::Int64, "i64_input: i64[1];"),
            ("u64_input", DataType::Uint64, "u64_input: u64[1];"),
            ("i8_input", DataType::Int8, "i8_input: i8[1];"),
            ("u8_input", DataType::Uint8, "u8_input: u8[1];"),
        ];

        let mut expected_lines = Vec::with_capacity(cases.len());
        for (name, data_type, expected) in cases {
            graph
                .inputs
                .insert(name.to_string(), operand(data_type, to_dimension_vector(&[1])));
            expected_lines.push(expected);
        }
        expose(&mut graph, "f32_input");

        assert_contains_all(&render(&graph), &expected_lines);
    }

    #[test]
    fn test_roundtrip() {
        let source = r#"
webnn_graph "resnet_head" v1 {
 inputs {
 x: f32[1, 2048];
 }
 consts {
 W: f32[2048, 1000] @weights("W");
 b: f32[1000] @weights("b");
 }
 nodes {
 logits0 = matmul(x, W);
 logits = add(logits0, b);
 probs = softmax(logits, axis=1);
 }
 outputs { probs; }
}
"#;
        let original = parse_wg_text(source).unwrap();
        let rendered = render(&original);
        let reparsed = parse_wg_text(&rendered).unwrap();

        assert_eq!(original.name, reparsed.name);
        assert_eq!(original.inputs.len(), reparsed.inputs.len());
        assert_eq!(original.consts.len(), reparsed.consts.len());
        assert_eq!(original.nodes.len(), reparsed.nodes.len());
        assert_eq!(original.outputs.len(), reparsed.outputs.len());
    }

    #[test]
    fn test_default_graph_name() {
        let mut graph = new_graph_json();
        expose(&mut graph, "x");

        let expected = format!("webnn_graph \"graph\" v{}", graph.version);
        assert!(render(&graph).contains(&expected));
    }

    #[test]
    fn test_string_escaping() {
        let mut graph = named_graph("test\"with\\quotes");
        expose(&mut graph, "x");

        let expected = format!("webnn_graph \"test\\\"with\\\\quotes\" v{}", graph.version);
        assert!(render(&graph).contains(&expected));
    }

    #[test]
    fn test_value_types() {
        let mut graph = named_graph("test");
        graph.inputs.insert(
            "x".to_string(),
            operand(DataType::Float32, to_dimension_vector(&[1])),
        );

        let options: Options = vec![
            ("int_val", serde_json::json!(42)),
            ("float_val", serde_json::json!(3.5)),
            ("bool_val", serde_json::json!(true)),
            ("null_val", serde_json::json!(null)),
            ("array_val", serde_json::json!([1, 2, 3])),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();
        graph.nodes.push(node("result", "test_op", &["x"], options));
        expose(&mut graph, "result");

        assert_contains_all(
            &render(&graph),
            &[
                "int_val=42",
                "float_val=3.5",
                "bool_val=true",
                "null_val=null",
                "array_val=[1, 2, 3]",
            ],
        );
    }
}
473}