use crate::error::CliError;
use aprender::format::rosetta::RosettaStone;
use colored::Colorize;
use std::collections::BTreeMap;
use std::path::Path;
/// Output formats the tree command can render.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TreeFormat {
    /// Selected by `"ascii"` or `"text"`.
    Ascii,
    /// Selected by `"dot"` or `"graphviz"`.
    Dot,
    /// Selected by `"mermaid"` or `"md"`.
    Mermaid,
    /// Selected by `"json"`.
    Json,
}

impl std::str::FromStr for TreeFormat {
    type Err = String;

    /// Parses a format name case-insensitively.
    ///
    /// # Errors
    ///
    /// Returns `"Unknown format: <input>"` (with the input as originally
    /// spelled) when the name matches none of the accepted aliases.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let parsed = match lowered.as_str() {
            "ascii" | "text" => Self::Ascii,
            "dot" | "graphviz" => Self::Dot,
            "mermaid" | "md" => Self::Mermaid,
            "json" => Self::Json,
            _ => return Err(format!("Unknown format: {s}")),
        };
        Ok(parsed)
    }
}
/// One node of the hierarchical tensor tree built from dotted tensor names.
#[derive(Debug, Clone)]
struct TreeNode {
    /// Last path component of this node (the text between two dots).
    name: String,
    /// Dotted path from the top of the tree down to this node.
    full_path: String,
    /// Tensor shape; only set when the node is a tensor.
    shape: Option<Vec<usize>>,
    /// Tensor size in bytes; only meaningful when the node is a tensor.
    size_bytes: usize,
    /// Child nodes keyed by their name, kept in sorted order by the map.
    children: BTreeMap<String, TreeNode>,
    /// `true` when a tensor is stored at exactly this path.
    is_leaf: bool,
}

impl TreeNode {
    /// Creates an empty, non-tensor node with the given name and path.
    fn new(name: &str, full_path: &str) -> Self {
        Self {
            name: name.to_string(),
            full_path: full_path.to_string(),
            shape: None,
            size_bytes: 0,
            children: BTreeMap::new(),
            is_leaf: false,
        }
    }

    /// Sums the byte sizes of every tensor at or below this node.
    ///
    /// A tensor node may itself have children when one tensor name is a
    /// dotted prefix of another (e.g. `a.b` and `a.b.c`), so the node's own
    /// size and its descendants' sizes are always added together.
    fn total_size(&self) -> usize {
        // `size_bytes` only counts when a tensor really lives at this node.
        let own = if self.is_leaf { self.size_bytes } else { 0 };
        own + self
            .children
            .values()
            .map(TreeNode::total_size)
            .sum::<usize>()
    }

    /// Counts every tensor at or below this node, including this node itself
    /// when it is a tensor.
    fn tensor_count(&self) -> usize {
        usize::from(self.is_leaf)
            + self
                .children
                .values()
                .map(TreeNode::tensor_count)
                .sum::<usize>()
    }
}
/// Builds the tensor tree for an inspection report.
///
/// The returned root is named `"model"` with an empty path. When `filter` is
/// `Some`, only tensors whose name contains that substring are inserted;
/// with `None` every tensor in the report is inserted.
fn build_tree_from_rosetta(
    report: &aprender::format::rosetta::InspectionReport,
    filter: Option<&str>,
) -> TreeNode {
    let mut tree = TreeNode::new("model", "");
    for entry in &report.tensors {
        let selected = match filter {
            Some(needle) => entry.name.contains(needle),
            None => true,
        };
        if selected {
            insert_tensor(&mut tree, &entry.name, &entry.shape, entry.size_bytes);
        }
    }
    tree
}
/// Inserts one tensor into the tree, creating intermediate nodes as needed.
///
/// `name` is split on `.`; each component becomes one level of the tree, and
/// the node for the final component is marked as a tensor carrying `shape`
/// and `size`.
fn insert_tensor(root: &mut TreeNode, name: &str, shape: &[usize], size: usize) {
    let mut cursor = root;
    // Dotted path of the node currently being visited, grown one component
    // at a time.
    let mut path_so_far = String::new();
    for (depth, segment) in name.split('.').enumerate() {
        if depth > 0 {
            path_so_far.push('.');
        }
        path_so_far.push_str(segment);
        cursor = cursor
            .children
            .entry(segment.to_string())
            .or_insert_with(|| TreeNode::new(segment, &path_so_far));
    }
    // `split` always yields at least one component, so `cursor` now points at
    // the node for the last component rather than at `root`.
    cursor.is_leaf = true;
    cursor.shape = Some(shape.to_vec());
    cursor.size_bytes = size;
}
/// Inspects the model file at `apr_path` and prints its tensor hierarchy.
///
/// `filter` keeps only tensors whose name contains the given substring.
/// `show_sizes` and `max_depth` are only consulted for [`TreeFormat::Ascii`];
/// the other formats ignore them.
///
/// # Errors
///
/// Returns `CliError::FileNotFound` when `apr_path` does not exist and
/// `CliError::InvalidFormat` when inspection fails.
#[provable_contracts_macros::contract(
    "apr-cli-operations-v1",
    equation = "side_effect_classification"
)]
pub(crate) fn run(
    apr_path: &Path,
    filter: Option<&str>,
    format: TreeFormat,
    show_sizes: bool,
    max_depth: Option<usize>,
) -> Result<(), CliError> {
    if !apr_path.exists() {
        return Err(CliError::FileNotFound(apr_path.to_path_buf()));
    }

    let inspector = RosettaStone::new();
    let report = match inspector.inspect(apr_path) {
        Ok(report) => report,
        Err(e) => {
            return Err(CliError::InvalidFormat(format!("Failed to inspect: {e}")));
        }
    };

    let tree = build_tree_from_rosetta(&report, filter);
    match format {
        TreeFormat::Json => print_json_tree(&tree),
        TreeFormat::Mermaid => print_mermaid_graph(&tree),
        TreeFormat::Dot => print_dot_graph(&tree),
        TreeFormat::Ascii => print_ascii_tree(&tree, apr_path, show_sizes, max_depth),
    }
    Ok(())
}
/// Prints the tree in ASCII form to stdout.
///
/// The first line shows the file name of `path` (or `model` when it has no
/// UTF-8 file name), the tensor count and the total size; every top-level
/// child is then printed recursively starting at depth 0.
fn print_ascii_tree(root: &TreeNode, path: &Path, show_sizes: bool, max_depth: Option<usize>) {
    let display_name = path
        .file_name()
        .and_then(|s| s.to_str())
        .unwrap_or("model");
    let header_name = display_name.cyan().bold();
    let header_count = root.tensor_count().to_string().green();
    let header_size = format_size(root.total_size()).yellow();
    println!("{header_name} ({header_count} tensors, {header_size})");

    let top_level = root.children.len();
    for (index, child) in root.children.values().enumerate() {
        let is_last = index + 1 == top_level;
        print_tree_node(child, "", is_last, show_sizes, 0, max_depth);
    }
}
#[allow(clippy::disallowed_methods)] fn print_tree_node(
node: &TreeNode,
prefix: &str,
is_last: bool,
show_sizes: bool,
depth: usize,
max_depth: Option<usize>,
) {
if let Some(max) = max_depth {
if depth >= max {
return;
}
}
let connector = if is_last { "└── " } else { "├── " };
let child_prefix = if is_last { " " } else { "│ " };
let name_str = if node.is_leaf {
format!("{}", node.name.green())
} else {
format!("{}", node.name.blue().bold())
};
let info_str = if node.is_leaf {
let shape_str = node
.shape
.as_ref()
.map(|s| format!("{s:?}"))
.unwrap_or_default();
let size_str = if show_sizes {
format!(" ━━━ {}", format_size(node.size_bytes).dimmed())
} else {
String::new()
};
format!(" {}{}", shape_str.dimmed(), size_str)
} else if show_sizes {
let total = node.total_size();
let count = node.tensor_count();
format!(
" [{} tensors, {}]",
count.to_string().dimmed(),
format_size(total).dimmed()
)
} else {
String::new()
};
println!("{prefix}{connector}{name_str}{info_str}");
let children: Vec<_> = node.children.values().collect();
for (i, child) in children.iter().enumerate() {
let child_is_last = i == children.len() - 1;
let new_prefix = format!("{prefix}{child_prefix}");
print_tree_node(
child,
&new_prefix,
child_is_last,
show_sizes,
depth + 1,
max_depth,
);
}
}
/// Prints the whole tree as a Graphviz `digraph` on stdout.
///
/// Emits the fixed graph preamble, then every node and edge below `root`
/// (whose DOT identifier is `root`), then the closing brace.
fn print_dot_graph(root: &TreeNode) {
    const PREAMBLE: [&str; 5] = [
        "digraph model {",
        " rankdir=TB;",
        " node [shape=box, fontname=\"Helvetica\"];",
        " edge [fontname=\"Helvetica\", fontsize=10];",
        "",
    ];
    for line in PREAMBLE.iter() {
        println!("{line}");
    }
    print_dot_nodes(root, "root");
    println!("}}");
}
/// Prints a DOT node statement and an edge from `parent_id` for every child
/// of `node`, recursing into each child.
///
/// Node identifiers are `parent_id`, an underscore, and the child's name.
/// They are written unquoted, so any ASCII character of the name that is not
/// a letter, digit or underscore is replaced by `_`; names made only of such
/// characters produce the same identifier as before. Label text is escaped
/// so that `"` and `\` in a name cannot terminate or corrupt the quoted
/// label.
#[allow(clippy::disallowed_methods)]
fn print_dot_nodes(node: &TreeNode, parent_id: &str) {
    /// Replaces ASCII characters that are not valid in an unquoted DOT ID.
    fn sanitize_id(raw: &str) -> String {
        raw.chars()
            .map(|c| {
                if c.is_ascii() && !c.is_ascii_alphanumeric() && c != '_' {
                    '_'
                } else {
                    c
                }
            })
            .collect()
    }

    /// Escapes backslashes and double quotes for a double-quoted DOT string.
    fn escape_label(raw: &str) -> String {
        raw.replace('\\', "\\\\").replace('"', "\\\"")
    }

    for (name, child) in &node.children {
        let node_id = format!("{parent_id}_{}", sanitize_id(name));
        let (shape, color) = if child.is_leaf {
            ("box", "lightblue")
        } else if child.children.len() > 3 {
            ("folder", "lightyellow")
        } else {
            ("box", "white")
        };
        let shown_name = escape_label(name);
        // `\\n` emits a literal backslash-n, which DOT renders as a line break.
        let label = if child.is_leaf {
            let shape_str = child
                .shape
                .as_ref()
                .map(|s| format!("{s:?}"))
                .unwrap_or_default();
            format!("{shown_name}\\n{shape_str}")
        } else {
            format!("{}\\n({} tensors)", shown_name, child.tensor_count())
        };
        println!(
            " {node_id} [label=\"{label}\", shape={shape}, fillcolor=\"{color}\", style=filled];"
        );
        println!(" {parent_id} -> {node_id};");
        print_dot_nodes(child, &node_id);
    }
}
/// Prints the tree as a top-down Mermaid graph wrapped in a fenced
/// ```` ```mermaid ```` code block.
fn print_mermaid_graph(root: &TreeNode) {
    // Opening fence and graph direction, emitted as two lines.
    println!("```mermaid\ngraph TD");
    print_mermaid_nodes(root, "root", "Model");
    println!("```");
}
/// Prints Mermaid node definitions and edges for every child of `node`,
/// recursing into each child.
///
/// Child identifiers are `parent_id` followed by `_` and the child's
/// position among its siblings. `parent_label` is only used when
/// `parent_id` is `"root"`, to declare the root node itself.
///
/// All labels are emitted inside double quotes: tensor labels contain the
/// bracketed shape (for example `[768, 768]`), and unquoted brackets inside
/// a `id[...]` node break Mermaid's parser. A `"` inside a label is written
/// as the `#quot;` entity.
#[allow(clippy::disallowed_methods)]
fn print_mermaid_nodes(node: &TreeNode, parent_id: &str, parent_label: &str) {
    /// Wraps label text in double quotes, escaping embedded quotes.
    fn quote_label(raw: &str) -> String {
        format!("\"{}\"", raw.replace('"', "#quot;"))
    }

    if parent_id == "root" {
        println!(" {parent_id}[{}]", quote_label(parent_label));
    }
    for (i, (name, child)) in node.children.iter().enumerate() {
        let node_id = format!("{parent_id}_{i}");
        let node_def = if child.is_leaf {
            let shape_str = child
                .shape
                .as_ref()
                .map(|s| format!("{s:?}"))
                .unwrap_or_default();
            // Rectangle node: tensor name with its shape on a second line.
            format!(
                "{node_id}[{}]",
                quote_label(&format!("{name}<br/>{shape_str}"))
            )
        } else {
            // Hexagon node (`{{ ... }}` in Mermaid) for groups.
            format!("{node_id}{{{{{}}}}}", quote_label(name))
        };
        println!(" {parent_id} --> {node_def}");
        print_mermaid_nodes(child, &node_id, name);
    }
}
/// Prints the tree as pretty-printed JSON on stdout.
///
/// Tensor nodes become objects of type `"tensor"` with name, path, shape and
/// size; all other nodes become objects of type `"group"` with aggregate
/// totals and a `children` array. Nothing is printed if serialization fails.
#[allow(clippy::disallowed_methods)]
fn print_json_tree(root: &TreeNode) {
    /// Converts one node (and, for groups, its descendants) to a JSON value.
    fn to_json(node: &TreeNode) -> serde_json::Value {
        if !node.is_leaf {
            let nested: Vec<serde_json::Value> =
                node.children.values().map(|child| to_json(child)).collect();
            return serde_json::json!({
                "name": node.name,
                "path": node.full_path,
                "total_size": node.total_size(),
                "tensor_count": node.tensor_count(),
                "type": "group",
                "children": nested
            });
        }
        serde_json::json!({
            "name": node.name,
            "path": node.full_path,
            "shape": node.shape,
            "size_bytes": node.size_bytes,
            "type": "tensor"
        })
    }

    let document = to_json(root);
    if let Ok(rendered) = serde_json::to_string_pretty(&document) {
        println!("{rendered}");
    }
}
/// Renders a byte count as text by delegating to
/// `batuta_common::fmt::format_bytes`.
fn format_size(bytes: usize) -> String {
    let byte_count = bytes as u64;
    batuta_common::fmt::format_bytes(byte_count)
}
// Unit tests are compiled only for test builds and are loaded from the
// sibling file `tree_tests.rs` rather than from an inline module body.
#[cfg(test)]
#[path = "tree_tests.rs"]
mod tests;