tract_libcli/
export.rs

1use std::collections::HashMap;
2use serde::Serialize;
3use tract_core::internal::*;
4use crate::annotations::Annotations;
5use crate::model::Model;
6
7#[derive(Clone, Debug, Default, Serialize)]
8pub struct GraphPerfInfo {
9    nodes: Vec<Node>,
10    profiling_info: Option<ProfilingInfo>,
11}
12
13#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize)]
14pub struct NodeQIdSer(pub Vec<(usize, String)>, pub usize);
15
16#[derive(Clone, Debug, Serialize)]
17pub struct Node {
18    qualified_id: NodeQIdSer,
19    op_name: String,
20    node_name: String,
21
22    #[serde(skip_serializing_if = "HashMap::is_empty")]
23    cost: HashMap<String, String>,
24
25    #[serde(skip_serializing_if = "Option::is_none")]
26    secs_per_iter: Option<f64>,
27}
28
29#[derive(Clone, Debug, Serialize)]
30pub struct ProfilingInfo {
31    iterations: usize,
32    secs_per_iter: f64,
33}
34
35impl GraphPerfInfo {
36    pub fn from(model: &dyn Model, annotations: &Annotations) -> GraphPerfInfo {
37        let nodes = annotations
38            .tags
39            .iter()
40            .map(|(id, node)| Node {
41                qualified_id: NodeQIdSer(id.0.iter().cloned().collect(), id.1),
42                cost: node
43                    .cost
44                    .iter()
45                    .map(|(k, v)| (format!("{k:?}"), format!("{v}")))
46                    .collect(),
47                node_name: id.model(model).unwrap().node_name(id.1).to_string(),
48                op_name: id.model(model).unwrap().node_op_name(id.1).to_string(),
49                secs_per_iter: node.profile.map(|s| s.as_secs_f64()),
50            })
51            .collect();
52        let profiling_info = annotations.profile_summary.as_ref().map(|summary| ProfilingInfo {
53            secs_per_iter: summary.entire.as_secs_f64(),
54            iterations: summary.iters,
55        });
56        GraphPerfInfo { nodes, profiling_info }
57    }
58}