webnn-graph 0.3.0

Simple DSL for WebNN graphs
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
// Constant folding for ONNX models
// Eliminates nodes with all-constant inputs by evaluating them at conversion time

pub mod evaluators;

use crate::onnx::convert::OnnxError;
use crate::protos::onnx::{ModelProto, NodeProto, TensorProto, TensorProto_DataType};
use std::collections::{HashMap, HashSet};

/// Represents constant tensor data with various types
///
/// Each variant stores the tensor's elements as a flat vector; the shape is
/// tracked separately (alongside the data) in `ConstantTensor`.
#[derive(Debug, Clone)]
pub enum TensorData {
    /// Elements of an ONNX `INT64` tensor.
    Int64(Vec<i64>),
    /// Elements of an ONNX `INT32` tensor.
    Int32(Vec<i32>),
    /// Elements of an ONNX `FLOAT` (32-bit float) tensor.
    Float32(Vec<f32>),
    /// Elements of an ONNX `DOUBLE` (64-bit float) tensor.
    Float64(Vec<f64>),
    /// Elements of an ONNX `UINT8` tensor.
    UInt8(Vec<u8>),
    /// Elements of an ONNX `INT8` tensor.
    Int8(Vec<i8>),
}

impl TensorData {
    /// Number of elements held by this tensor (not the number of bytes).
    pub fn len(&self) -> usize {
        match self {
            TensorData::Int64(v) => v.len(),
            TensorData::Int32(v) => v.len(),
            TensorData::Float32(v) => v.len(),
            TensorData::Float64(v) => v.len(),
            TensorData::UInt8(v) => v.len(),
            TensorData::Int8(v) => v.len(),
        }
    }

    /// Returns `true` when the tensor holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// ONNX data-type tag corresponding to this variant.
    pub fn data_type(&self) -> TensorProto_DataType {
        match self {
            TensorData::Int64(_) => TensorProto_DataType::Int64,
            TensorData::Int32(_) => TensorProto_DataType::Int32,
            TensorData::Float32(_) => TensorProto_DataType::Float,
            TensorData::Float64(_) => TensorProto_DataType::Double,
            TensorData::UInt8(_) => TensorProto_DataType::Uint8,
            TensorData::Int8(_) => TensorProto_DataType::Int8,
        }
    }

    /// Serialize the elements as packed little-endian bytes (the ONNX `raw_data` layout).
    pub fn to_bytes(&self) -> Vec<u8> {
        match self {
            TensorData::Int64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
            TensorData::Int32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
            TensorData::Float32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
            TensorData::Float64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
            TensorData::UInt8(v) => v.clone(),
            TensorData::Int8(v) => v.iter().map(|&x| x as u8).collect(),
        }
    }

    /// Decode the element payload of an ONNX `TensorProto`.
    ///
    /// When `raw_data` is non-empty it is decoded as packed little-endian
    /// elements of `data_type`. Otherwise the typed repeated field is used:
    /// `int64_data`, `int32_data`, `float_data` or `double_data`. Per the ONNX
    /// spec, `UINT8` and `INT8` values in typed form are stored widened in
    /// `int32_data`, so they are narrowed back here.
    ///
    /// Note: `raw_data` is split with `chunks_exact`, so trailing bytes that do
    /// not form a whole element are silently ignored.
    ///
    /// # Errors
    ///
    /// Returns `OnnxError::TypeConversion(UnsupportedOnnxDataType)` for any data
    /// type not listed above.
    pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
        let raw_data = tensor.raw_data.as_slice();
        let data_type = tensor.data_type;

        if !raw_data.is_empty() {
            // Parse from raw bytes
            match data_type {
                x if x == TensorProto_DataType::Int64 as i32 => {
                    let values = raw_data
                        .chunks_exact(8)
                        .map(|c| {
                            i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
                        })
                        .collect();
                    Ok(TensorData::Int64(values))
                }
                x if x == TensorProto_DataType::Int32 as i32 => {
                    let values = raw_data
                        .chunks_exact(4)
                        .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                        .collect();
                    Ok(TensorData::Int32(values))
                }
                x if x == TensorProto_DataType::Float as i32 => {
                    let values = raw_data
                        .chunks_exact(4)
                        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                        .collect();
                    Ok(TensorData::Float32(values))
                }
                x if x == TensorProto_DataType::Double as i32 => {
                    let values = raw_data
                        .chunks_exact(8)
                        .map(|c| {
                            f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
                        })
                        .collect();
                    Ok(TensorData::Float64(values))
                }
                x if x == TensorProto_DataType::Uint8 as i32 => {
                    Ok(TensorData::UInt8(raw_data.to_vec()))
                }
                x if x == TensorProto_DataType::Int8 as i32 => Ok(TensorData::Int8(
                    raw_data.iter().map(|&x| x as i8).collect(),
                )),
                _ => Err(OnnxError::TypeConversion(
                    webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
                )),
            }
        } else {
            // Parse from typed data fields
            match data_type {
                x if x == TensorProto_DataType::Int64 as i32 => {
                    Ok(TensorData::Int64(tensor.int64_data.as_slice().to_vec()))
                }
                x if x == TensorProto_DataType::Int32 as i32 => {
                    Ok(TensorData::Int32(tensor.int32_data.as_slice().to_vec()))
                }
                x if x == TensorProto_DataType::Float as i32 => {
                    Ok(TensorData::Float32(tensor.float_data.as_slice().to_vec()))
                }
                x if x == TensorProto_DataType::Double as i32 => {
                    Ok(TensorData::Float64(tensor.double_data.as_slice().to_vec()))
                }
                // ONNX stores 8-bit integers widened in `int32_data`; keep the low byte.
                x if x == TensorProto_DataType::Uint8 as i32 => Ok(TensorData::UInt8(
                    tensor.int32_data.as_slice().iter().map(|&v| v as u8).collect(),
                )),
                x if x == TensorProto_DataType::Int8 as i32 => Ok(TensorData::Int8(
                    tensor.int32_data.as_slice().iter().map(|&v| v as i8).collect(),
                )),
                _ => Err(OnnxError::TypeConversion(
                    webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
                )),
            }
        }
    }
}

/// Represents a constant tensor with its shape and type
///
/// `data_type` is the raw ONNX data-type tag as stored in `TensorProto`.
#[derive(Debug, Clone)]
pub struct ConstantTensor {
    pub data: TensorData,
    pub shape: Vec<i64>,
    pub data_type: i32,
}

impl ConstantTensor {
    /// Build a constant from an ONNX `TensorProto`, copying its payload, dims and type tag.
    ///
    /// # Errors
    ///
    /// Propagates any error from `TensorData::from_tensor_proto`.
    pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
        // Field initializers run in order, so the payload is decoded (and may
        // fail) before the dims are copied.
        Ok(Self {
            data: TensorData::from_tensor_proto(tensor)?,
            shape: tensor.dims.as_slice().to_vec(),
            data_type: tensor.data_type,
        })
    }

    /// Serialize into a `TensorProto` named `name`, with the payload written to `raw_data`.
    pub fn to_tensor_proto(&self, name: &str) -> TensorProto {
        let raw_data = self.data.to_bytes();
        TensorProto {
            name: String::from(name),
            data_type: self.data_type,
            dims: self.shape.to_vec(),
            raw_data,
            ..Default::default()
        }
    }

    /// Total element count implied by `shape`.
    ///
    /// A scalar (empty shape) counts as one element, because the product of
    /// no factors is 1.
    pub fn numel(&self) -> i64 {
        self.shape.iter().product()
    }
}

/// Context for constant folding operations
///
/// Holds the parsed constant tensors that evaluators may read during a pass.
#[derive(Debug)]
pub struct ConstantFoldingContext<'a> {
    /// Map from tensor name to constant value
    pub constants: HashMap<String, ConstantTensor>,
    /// Original ONNX initializers (for reference)
    pub initializers: &'a HashMap<String, &'a TensorProto>,
}

impl<'a> ConstantFoldingContext<'a> {
    /// Parse every initializer that carries inline data into a `ConstantTensor`.
    ///
    /// An initializer is skipped when `raw_data`, `int64_data`, `int32_data`,
    /// `float_data` and `double_data` are all empty. NOTE(review): presumably
    /// these are externally stored tensors; confirm. An initializer that fails
    /// to parse is skipped with a debug warning instead of aborting.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is part of the public signature.
    pub fn new(initializers: &'a HashMap<String, &'a TensorProto>) -> Result<Self, OnnxError> {
        let mut constants = HashMap::new();

        for (name, tensor) in initializers {
            let has_no_payload = tensor.raw_data.as_slice().is_empty()
                && tensor.int64_data.as_slice().is_empty()
                && tensor.int32_data.as_slice().is_empty()
                && tensor.float_data.as_slice().is_empty()
                && tensor.double_data.as_slice().is_empty();
            if has_no_payload {
                continue;
            }

            match ConstantTensor::from_tensor_proto(tensor) {
                Ok(parsed) => {
                    constants.insert(name.to_owned(), parsed);
                }
                Err(e) => {
                    crate::debug_println!(
                        "Warning: Failed to parse initializer '{}': {}",
                        name,
                        e
                    );
                }
            }
        }

        Ok(Self {
            constants,
            initializers,
        })
    }

    /// Returns `true` if `name` is a known constant.
    pub fn is_constant(&self, name: &str) -> bool {
        self.get_constant(name).is_some()
    }

    /// Look up a constant by tensor name.
    pub fn get_constant(&self, name: &str) -> Option<&ConstantTensor> {
        self.constants.get(name)
    }

    /// Register `tensor` under `name`, replacing any previous value for that name.
    pub fn add_constant(&mut self, name: String, tensor: ConstantTensor) {
        self.constants.insert(name, tensor);
    }
}

/// Result of a constant folding pass
///
/// Produced by `evaluate_constant_nodes`. `fold_constants_in_model` appends
/// `new_initializers` to the graph and drops the nodes listed in
/// `nodes_to_remove`.
#[derive(Debug, Default)]
pub struct FoldingResult {
    /// New initializers to add to the model, named after the folded nodes' outputs
    pub new_initializers: Vec<TensorProto>,
    /// Node indices to remove from the graph (indices into the graph's `node`
    /// list as it stood when the pass began)
    pub nodes_to_remove: HashSet<usize>,
    /// Number of nodes folded in this pass
    pub nodes_folded: usize,
}

/// Trait for operations that support constant evaluation
///
/// Implementations are consulted in slice order; the first one whose
/// `can_evaluate` returns `true` for a node is used to evaluate it.
pub trait ConstantEvaluator {
    /// Get the operation type this evaluator handles
    ///
    /// NOTE(review): dispatch in this module goes through `can_evaluate` only;
    /// this appears to be informational. Confirm against implementations.
    fn op_type(&self) -> &str;

    /// Check if this evaluator can handle the given node
    ///
    /// Called while identifying foldable nodes, and again just before
    /// evaluation. By then `ctx` may already contain constants folded earlier in
    /// the same pass.
    fn can_evaluate(&self, node: &NodeProto, ctx: &ConstantFoldingContext) -> bool;

    /// Evaluate the node with constant inputs, returning output tensors
    ///
    /// Returned tensors are matched to `node.output` by position. Tensors beyond
    /// the node's output count are ignored. On `Err`, the caller logs a warning
    /// and leaves the node in the graph.
    fn evaluate(
        &self,
        node: &NodeProto,
        ctx: &ConstantFoldingContext,
    ) -> Result<Vec<ConstantTensor>, OnnxError>;
}

/// Create the folding context for one pass from the given initializer map.
///
/// `_model` is accepted but not currently consulted; only `initializers_map`
/// feeds the context.
///
/// # Errors
///
/// Propagates any error from `ConstantFoldingContext::new`.
fn build_context<'a>(
    _model: &'a ModelProto,
    initializers_map: &'a HashMap<String, &'a TensorProto>,
) -> Result<ConstantFoldingContext<'a>, OnnxError> {
    let context = ConstantFoldingContext::new(initializers_map)?;
    Ok(context)
}

/// Collect the indices of graph nodes that at least one evaluator reports it
/// can fold, given the constants currently in `ctx`.
///
/// Indices refer to positions in the graph's `node` list and are returned in
/// ascending order.
///
/// # Panics
///
/// Panics if `model.graph` is not set.
fn identify_constant_nodes(
    model: &ModelProto,
    ctx: &ConstantFoldingContext,
    evaluators: &[Box<dyn ConstantEvaluator>],
) -> Result<Vec<usize>, OnnxError> {
    let nodes = model.graph.as_ref().unwrap().node.as_slice();

    let foldable = nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| {
            evaluators
                .iter()
                .any(|evaluator| evaluator.can_evaluate(node, ctx))
        })
        .map(|(position, _)| position)
        .collect();

    Ok(foldable)
}

/// Evaluate constant nodes and return the folding result
///
/// For each index in `constant_node_indices`, the first evaluator whose
/// `can_evaluate` accepts the node is run. Its outputs are bound to the node's
/// output names by position, recorded as new initializers, and added to `ctx`
/// so later nodes in the same pass can use them.
///
/// A node is marked for removal only if every *named* output receives a
/// tensor. Removing a node with an output left undefined would leave dangling
/// references in the graph. Outputs with an empty name are omitted optional
/// outputs in ONNX and are not registered, so `""` never becomes a constant
/// (it also denotes omitted optional inputs). Evaluation failures are logged
/// and the node is kept. A model without a graph yields an empty result, and
/// out-of-range indices are skipped.
fn evaluate_constant_nodes(
    model: &ModelProto,
    constant_node_indices: &[usize],
    ctx: &mut ConstantFoldingContext,
    evaluators: &[Box<dyn ConstantEvaluator>],
) -> Result<FoldingResult, OnnxError> {
    let mut result = FoldingResult::default();
    let graph = match model.graph.as_ref() {
        Some(graph) => graph,
        None => return Ok(result),
    };
    let nodes = graph.node.as_slice();

    for &idx in constant_node_indices {
        let node = match nodes.get(idx) {
            Some(node) => node,
            None => continue,
        };
        let outputs = node.output.as_slice();

        // Re-check with the (possibly extended) context and pick the evaluator.
        let evaluator = evaluators.iter().find(|e| e.can_evaluate(node, ctx));

        if let Some(evaluator) = evaluator {
            match evaluator.evaluate(node, ctx) {
                Ok(output_tensors) => {
                    // Any named output without a produced tensor would be
                    // undefined once the node is removed; keep the node instead.
                    let missing_named_output = outputs
                        .iter()
                        .skip(output_tensors.len())
                        .any(|name| !name.is_empty());
                    if missing_named_output {
                        crate::debug_println!(
                            "Warning: Evaluator for node '{}' ({}) produced {} outputs but the node declares {}; not folding",
                            node.name.as_str(),
                            node.op_type.as_str(),
                            output_tensors.len(),
                            outputs.len()
                        );
                        continue;
                    }

                    // Move each tensor into the context rather than cloning it.
                    for (output_name, tensor) in outputs.iter().zip(output_tensors) {
                        if output_name.is_empty() {
                            continue;
                        }
                        result
                            .new_initializers
                            .push(tensor.to_tensor_proto(output_name));
                        ctx.add_constant(output_name.to_string(), tensor);
                    }

                    result.nodes_to_remove.insert(idx);
                    result.nodes_folded += 1;
                }
                Err(e) => {
                    crate::debug_println!(
                        "Warning: Failed to evaluate constant node '{}' ({}): {}",
                        node.name.as_str(),
                        node.op_type.as_str(),
                        e
                    );
                }
            }
        }
    }

    Ok(result)
}

/// Main entry point: fold constants in an ONNX model
///
/// Runs up to ten folding passes. Each pass rebuilds the constant context from
/// the graph's current initializers, identifies foldable nodes, evaluates them,
/// appends their outputs as initializers and removes the folded nodes. Constants
/// produced in one pass can therefore enable folding in the next. Stops early
/// once a pass finds or folds nothing.
///
/// Returns the total number of nodes folded. A model without a graph is left
/// untouched and reports `0`.
///
/// # Errors
///
/// Propagates errors from building the context, identifying nodes or
/// evaluating them.
pub fn fold_constants_in_model(
    model: &mut ModelProto,
    evaluators: &[Box<dyn ConstantEvaluator>],
) -> Result<usize, OnnxError> {
    // Upper bound on passes; each pass can only unlock nodes whose inputs
    // became constant in the previous one, so chains longer than this stay
    // partially folded.
    const MAX_ITERATIONS: usize = 10;

    // Nothing to fold without a graph.
    if model.graph.as_ref().is_none() {
        return Ok(0);
    }

    let mut total_folded = 0;

    for iteration in 0..MAX_ITERATIONS {
        // Steps 1-3 only read the model; scope their shared borrows so the
        // graph can be mutated afterwards.
        let result = {
            // 1. Build context from current initializers
            let graph = model
                .graph
                .as_ref()
                .expect("graph presence is checked before the loop");
            let initializers_map: HashMap<String, &TensorProto> = graph
                .initializer
                .as_slice()
                .iter()
                .map(|init| (init.name.as_str().to_string(), init))
                .collect();

            let mut ctx = build_context(model, &initializers_map)?;

            // 2. Identify constant nodes
            let constant_nodes = identify_constant_nodes(model, &ctx, evaluators)?;
            if constant_nodes.is_empty() {
                break;
            }

            // 3. Evaluate constant nodes
            evaluate_constant_nodes(model, &constant_nodes, &mut ctx, evaluators)?
        };

        if result.nodes_folded == 0 {
            break;
        }

        // 4. Add new initializers to the model
        let graph_mut = model
            .graph
            .as_mut()
            .expect("graph presence is checked before the loop");
        for init in result.new_initializers {
            graph_mut.initializer.push(init);
        }

        // 5. Remove evaluated nodes, moving the survivors instead of cloning
        //    the whole node list. Indices refer to the pre-removal order.
        let nodes = std::mem::take(&mut graph_mut.node);
        for (idx, node) in nodes.into_iter().enumerate() {
            if !result.nodes_to_remove.contains(&idx) {
                graph_mut.node.push(node);
            }
        }

        total_folded += result.nodes_folded;

        crate::debug_println!(
            "Constant folding iteration {}: {} nodes folded",
            iteration + 1,
            result.nodes_folded
        );
    }

    Ok(total_folded)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_data_len() {
        let data = TensorData::Int64(vec![1, 2, 3]);
        assert_eq!(data.len(), 3);

        let data = TensorData::Float32(vec![1.0, 2.0]);
        assert_eq!(data.len(), 2);
    }

    #[test]
    fn test_tensor_data_to_bytes() {
        let data = TensorData::Int32(vec![1, 2, 3]);
        let bytes = data.to_bytes();
        assert_eq!(bytes.len(), 12); // 3 * 4 bytes

        let data = TensorData::Int64(vec![1, 2]);
        let bytes = data.to_bytes();
        assert_eq!(bytes.len(), 16); // 2 * 8 bytes
    }

    /// `to_bytes` must emit little-endian element bytes, matching ONNX `raw_data`.
    #[test]
    fn test_tensor_data_to_bytes_little_endian() {
        let data = TensorData::Int32(vec![1, -1]);
        assert_eq!(data.to_bytes(), vec![1, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF]);

        // Int8 values are reinterpreted bit-for-bit as bytes.
        let data = TensorData::Int8(vec![-1, 2]);
        assert_eq!(data.to_bytes(), vec![0xFF, 0x02]);
    }

    #[test]
    fn test_tensor_data_data_type_mapping() {
        assert!(matches!(
            TensorData::Float32(vec![]).data_type(),
            TensorProto_DataType::Float
        ));
        assert!(matches!(
            TensorData::Float64(vec![]).data_type(),
            TensorProto_DataType::Double
        ));
        assert!(matches!(
            TensorData::UInt8(vec![]).data_type(),
            TensorProto_DataType::Uint8
        ));
    }

    /// Serializing to a `TensorProto` and parsing it back must preserve data, shape and type.
    #[test]
    fn test_constant_tensor_round_trip() {
        let ct = ConstantTensor {
            data: TensorData::Float32(vec![1.5, -2.0, 0.25]),
            shape: vec![3],
            data_type: TensorProto_DataType::Float as i32,
        };

        let proto = ct.to_tensor_proto("weights");
        assert_eq!(proto.name, "weights");
        assert_eq!(proto.dims, vec![3]);

        let parsed = match ConstantTensor::from_tensor_proto(&proto) {
            Ok(parsed) => parsed,
            Err(_) => panic!("a float tensor written by to_tensor_proto should parse back"),
        };
        assert_eq!(parsed.shape, vec![3]);
        assert_eq!(parsed.data_type, TensorProto_DataType::Float as i32);
        match parsed.data {
            TensorData::Float32(values) => assert_eq!(values, vec![1.5, -2.0, 0.25]),
            other => panic!("unexpected variant after round trip: {:?}", other),
        }
    }

    #[test]
    fn test_constant_tensor_numel() {
        let ct = ConstantTensor {
            data: TensorData::Int64(vec![1, 2, 3, 4, 5, 6]),
            shape: vec![2, 3],
            data_type: TensorProto_DataType::Int64 as i32,
        };
        assert_eq!(ct.numel(), 6);

        let ct = ConstantTensor {
            data: TensorData::Int64(vec![42]),
            shape: vec![],
            data_type: TensorProto_DataType::Int64 as i32,
        };
        assert_eq!(ct.numel(), 1);
    }
}