1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use std::collections::HashMap;
use serde::Serialize;
use tract_core::internal::*;
use crate::annotations::Annotations;
use crate::model::Model;
/// Serializable performance report for a model graph, built by
/// [`GraphPerfInfo::from`].
///
/// Field order matters: serde's derive emits fields in declaration order.
#[derive(Clone, Debug, Default, Serialize)]
pub struct GraphPerfInfo {
/// One entry per tagged node found in the annotations.
nodes: Vec<Node>,
/// Whole-run profiling summary; `None` when the annotations carry no
/// profile summary.
profiling_info: Option<ProfilingInfo>,
}
/// Serializable copy of a tagged node's qualified id.
///
/// `.0` is copied element-wise from the source id's path. NOTE(review): presumably
/// this is the chain of (node id, label) pairs leading into nested sub-models —
/// confirm against the annotations' id type. `.1` is the node id within the
/// (sub)model that the path resolves to; it is the index passed to
/// `node_name` / `node_op_name` when building a [`Node`].
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize)]
pub struct NodeQIdSer(pub Vec<(usize, String)>, pub usize);
/// One node's entry in a [`GraphPerfInfo`] report.
///
/// Field order matters: serde's derive emits fields in declaration order.
#[derive(Clone, Debug, Serialize)]
pub struct Node {
/// Qualified id of the node inside the (possibly nested) model.
qualified_id: NodeQIdSer,
/// Operator name, as reported by the owning model's `node_op_name`.
op_name: String,
/// Node name, as reported by the owning model's `node_name`.
node_name: String,
/// Cost annotations rendered as text. Keys use the `Debug` form of each cost key
/// and values use the `Display` form of the amount. The field is omitted from
/// the serialized output when the map is empty.
#[serde(skip_serializing_if = "HashMap::is_empty")]
cost: HashMap<String, String>,
/// Profiled duration of the node, in seconds. It is omitted from the serialized
/// output when the node has no profile. NOTE(review): the name implies a
/// per-iteration figure — confirm the profiler stores an averaged duration.
#[serde(skip_serializing_if = "Option::is_none")]
secs_per_iter: Option<f64>,
}
/// Whole-run profiling summary, copied from the annotations' profile summary.
#[derive(Clone, Debug, Serialize)]
pub struct ProfilingInfo {
/// Number of profiling iterations (the summary's `iters`).
iterations: usize,
/// The summary's `entire` duration, in seconds. NOTE(review): the field name
/// says per-iteration — confirm that `entire` is already averaged over `iters`.
secs_per_iter: f64,
}
impl GraphPerfInfo {
pub fn from(model: &dyn Model, annotations: &Annotations) -> GraphPerfInfo {
let nodes = annotations
.tags
.iter()
.map(|(id, node)| Node {
qualified_id: NodeQIdSer(id.0.iter().cloned().collect(), id.1),
cost: node
.cost
.iter()
.map(|(k, v)| (format!("{k:?}"), format!("{v}")))
.collect(),
node_name: id.model(model).unwrap().node_name(id.1).to_string(),
op_name: id.model(model).unwrap().node_op_name(id.1).to_string(),
secs_per_iter: node.profile.map(|s| s.as_secs_f64()),
})
.collect();
let profiling_info = annotations.profile_summary.as_ref().map(|summary| ProfilingInfo {
secs_per_iter: summary.entire.as_secs_f64(),
iterations: summary.iters,
});
GraphPerfInfo { nodes, profiling_info }
}
}