//! mlmdquery 0.2.0
//!
//! A command-line tool to query the contents of an ml-metadata DB.
use mlmd::metadata::{
    Artifact, ArtifactId, ArtifactType, Event, EventType, Execution, ExecutionId, ExecutionType,
    TypeId,
};
use mlmd::MetadataStore;
use palette::{Gradient, Srgb};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::io::Write;
use tinytemplate::TinyTemplate;

/// Graph-wide identifier of a node: either an artifact id or an execution id.
///
/// The variant is part of the identity, so an artifact and an execution that share the
/// same numeric id are distinct keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeId {
    Artifact(ArtifactId),
    Execution(ExecutionId),
}

impl std::fmt::Display for NodeId {
    /// Formats the id as `<id>@artifact` or `<id>@execution`.
    ///
    /// The graph renderer uses this string as the DOT node identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let (raw, kind) = match *self {
            NodeId::Artifact(id) => (id.get(), "artifact"),
            NodeId::Execution(id) => (id.get(), "execution"),
        };
        write!(f, "{}@{}", raw, kind)
    }
}

/// A vertex of the lineage graph: an MLMD artifact or execution record.
#[derive(Debug)]
pub enum Node {
    Artifact(Artifact),
    Execution(Execution),
}

impl Node {
    /// Returns the graph-wide identifier of this node.
    pub fn id(&self) -> NodeId {
        match *self {
            Node::Artifact(ref artifact) => NodeId::Artifact(artifact.id),
            Node::Execution(ref execution) => NodeId::Execution(execution.id),
        }
    }

    /// Returns the text drawn inside the node: its numeric MLMD id.
    pub fn label(&self) -> String {
        let raw_id = match self {
            Node::Artifact(artifact) => artifact.id.get(),
            Node::Execution(execution) => execution.id.get(),
        };
        raw_id.to_string()
    }

    /// Returns the fill color assigned to this node's type, formatted as `#rrggbb`.
    ///
    /// # Panics
    ///
    /// Panics if `colors` has no entry for the node's type id.
    pub fn color(&self, colors: &HashMap<TypeId, Srgb<u8>>) -> String {
        let rgb = match self {
            Node::Artifact(artifact) => colors[&artifact.type_id],
            Node::Execution(execution) => colors[&execution.type_id],
        };
        format!("#{:02x}{:02x}{:02x}", rgb.red, rgb.green, rgb.blue)
    }

    /// Renders the template registered under `"url"` for this node.
    ///
    /// The template context exposes `node_type` (`"artifact"` or `"execution"`) and `id`.
    /// Returns an empty string when `template` is `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if the template fails to render.
    pub fn url(&self, template: Option<&TinyTemplate>) -> anyhow::Result<String> {
        let tt = match template {
            Some(tt) => tt,
            None => return Ok("".to_owned()),
        };
        let context = match self {
            Node::Artifact(artifact) => UrlTemplateContext {
                node_type: "artifact",
                id: artifact.id.get(),
            },
            Node::Execution(execution) => UrlTemplateContext {
                node_type: "execution",
                id: execution.id.get(),
            },
        };
        let rendered = tt.render("url", &context)?;
        Ok(rendered)
    }

    /// Returns the DOT shape: `ellipse` for artifacts, `box` for executions.
    pub fn shape(&self) -> &str {
        if let Node::Artifact(_) = self {
            "ellipse"
        } else {
            "box"
        }
    }

    /// Returns the DOT style; the `origin` node is additionally drawn bold and dashed.
    pub fn style(&self, origin: NodeId) -> &str {
        let is_origin = self.id() == origin;
        if is_origin {
            "bold,dashed,filled"
        } else {
            "solid,filled"
        }
    }

    /// Returns a pretty-printed JSON description of the node (including its type name),
    /// used as the DOT tooltip.
    ///
    /// # Panics
    ///
    /// Panics if `types` has no entry for the node's type id.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails.
    pub fn tooltip(&self, types: &BTreeMap<TypeId, Type>) -> anyhow::Result<String> {
        let json = match self {
            Node::Artifact(artifact) => {
                let type_name = types[&artifact.type_id].name().to_owned();
                let payload = crate::serialize::ArtifactNode::new(type_name, artifact.clone());
                serde_json::to_string_pretty(&payload)?
            }
            Node::Execution(execution) => {
                let type_name = types[&execution.type_id].name().to_owned();
                let payload =
                    crate::serialize::ExecutionNode::new(type_name, execution.clone());
                serde_json::to_string_pretty(&payload)?
            }
        };
        Ok(json)
    }
}

/// A directed lineage edge derived from a single MLMD [`Event`].
///
/// Events whose type is `Input`, `DeclaredInput` or `InternalInput` are drawn
/// artifact → execution; every other event type is drawn execution → artifact.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Edge {
    event: Event,
}

impl Edge {
    /// Wraps `event` as a graph edge.
    pub fn new(event: Event) -> Self {
        Self { event }
    }

    /// Returns the edge label.
    ///
    /// The label is the event's path steps serialized as compact JSON, or an empty string
    /// when the event has no path.
    ///
    /// # Errors
    ///
    /// Returns an error if the path steps fail to serialize to JSON.
    pub fn label(&self) -> anyhow::Result<String> {
        let path = self
            .event
            .path
            .iter()
            .cloned()
            .map(crate::serialize::EventStep::from)
            .collect::<Vec<_>>();
        if path.is_empty() {
            return Ok(String::new());
        }
        Ok(serde_json::to_string(&path)?)
    }

    /// Returns the node the edge starts from.
    ///
    /// This is the artifact for input events and the execution otherwise.
    // `from_node` names the edge's source endpoint; it is not a `From`-style conversion,
    // which is what the clippy lint assumes a `from_*` method taking `&self` to be.
    #[allow(clippy::wrong_self_convention)]
    pub fn from_node(&self) -> NodeId {
        if self.is_input() {
            NodeId::Artifact(self.event.artifact_id)
        } else {
            NodeId::Execution(self.event.execution_id)
        }
    }

    /// Returns the node the edge points to.
    ///
    /// This is the execution for input events and the artifact otherwise.
    pub fn to_node(&self) -> NodeId {
        if self.is_input() {
            NodeId::Execution(self.event.execution_id)
        } else {
            NodeId::Artifact(self.event.artifact_id)
        }
    }

    /// Whether the event's type is one of the input kinds.
    ///
    /// Shared by `from_node` and `to_node` so the two directions can never disagree.
    fn is_input(&self) -> bool {
        matches!(
            self.event.ty,
            EventType::Input | EventType::DeclaredInput | EventType::InternalInput
        )
    }
}

/// Values exposed to the user-supplied node URL template.
#[derive(serde::Serialize, Debug)]
struct UrlTemplateContext {
    /// Either `"artifact"` or `"execution"`.
    node_type: &'static str,
    /// Numeric MLMD id of the node.
    id: i32,
}

/// The MLMD type a node belongs to: an artifact type or an execution type.
#[derive(Debug)]
pub enum Type {
    Artifact(ArtifactType),
    Execution(ExecutionType),
}

impl Type {
    /// Returns the MLMD id of this type.
    pub fn id(&self) -> TypeId {
        match *self {
            Type::Artifact(ref artifact_type) => artifact_type.id,
            Type::Execution(ref execution_type) => execution_type.id,
        }
    }

    /// Returns the type's registered name.
    pub fn name(&self) -> &str {
        match self {
            Type::Artifact(artifact_type) => &artifact_type.name,
            Type::Execution(execution_type) => &execution_type.name,
        }
    }

    /// Returns the DOT shape used for nodes of this kind.
    ///
    /// Artifacts use `ellipse` and executions use `box`, matching the shape of the nodes
    /// themselves.
    pub fn shape(&self) -> &str {
        if matches!(self, Type::Execution(_)) {
            "box"
        } else {
            "ellipse"
        }
    }
}

/// A lineage graph ready to be rendered in Graphviz DOT format.
///
/// Holds the nodes and edges to draw and the MLMD types those nodes belong to. It also
/// keeps a fill color per type and an optional TinyTemplate source used to build
/// per-node URLs.
#[derive(Debug)]
pub struct Graph {
    /// Node drawn with the distinct "origin" style (see `Node::style`).
    origin: NodeId,
    nodes: HashMap<NodeId, Node>,
    edges: HashSet<Edge>,
    types: BTreeMap<TypeId, Type>,
    colors: HashMap<TypeId, Srgb<u8>>,
    url_template: Option<String>,
}

impl Graph {
    /// Builds a graph from already-fetched nodes and edges.
    ///
    /// Queries `store` for the artifact and execution types referenced by `nodes`. Each
    /// type is assigned a fill color from a white-to-grey gradient; artifact types and
    /// execution types are each spread over the whole gradient independently.
    ///
    /// # Errors
    ///
    /// Returns an error if either type query against `store` fails.
    pub async fn new(
        store: &mut MetadataStore,
        origin: NodeId,
        nodes: HashMap<NodeId, Node>,
        edges: HashSet<Edge>,
        url_template: Option<String>,
    ) -> anyhow::Result<Self> {
        let mut types = BTreeMap::new();
        types.extend(
            store
                .get_artifact_types()
                .ids(nodes.values().filter_map(|node| match node {
                    Node::Artifact(x) => Some(x.type_id),
                    Node::Execution(_) => None,
                }))
                .execute()
                .await?
                .into_iter()
                .map(|x| (x.id, Type::Artifact(x))),
        );
        types.extend(
            store
                .get_execution_types()
                .ids(nodes.values().filter_map(|node| match node {
                    Node::Execution(x) => Some(x.type_id),
                    Node::Artifact(_) => None,
                }))
                .execute()
                .await?
                .into_iter()
                .map(|x| (x.id, Type::Execution(x))),
        );

        // Count each kind from the map itself rather than from `len()` differences, so the
        // gradient sizes always match the ids they are zipped with.
        let artifact_type_ids: Vec<TypeId> = types
            .iter()
            .filter(|(_, ty)| matches!(ty, Type::Artifact(_)))
            .map(|(id, _)| *id)
            .collect();
        let execution_type_ids: Vec<TypeId> = types
            .iter()
            .filter(|(_, ty)| matches!(ty, Type::Execution(_)))
            .map(|(id, _)| *id)
            .collect();

        let gradient = Gradient::new(vec![
            Srgb::new(1.0, 1.0, 1.0).into_linear(),
            Srgb::new(0.5, 0.5, 0.5).into_linear(),
        ]);
        let colors = artifact_type_ids
            .iter()
            .copied()
            .zip(gradient.take(artifact_type_ids.len()))
            .chain(
                execution_type_ids
                    .iter()
                    .copied()
                    .zip(gradient.take(execution_type_ids.len())),
            )
            .map(|(id, color)| (id, Srgb::<u8>::from(color)))
            .collect();

        Ok(Self {
            origin,
            nodes,
            edges,
            types,
            colors,
            url_template,
        })
    }

    /// Writes the graph to `writer` in Graphviz DOT format.
    ///
    /// The output contains:
    /// - one DOT node per graph node;
    /// - one DOT edge per distinct (source, target, label) triple;
    /// - two legend clusters listing the artifact and execution types with their fill
    ///   colors.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the URL template fails to parse;
    /// - a node's URL or tooltip cannot be rendered;
    /// - an edge references a node that is not part of the graph;
    /// - writing to `writer` fails.
    pub fn generate<W: Write>(&self, writer: &mut W) -> anyhow::Result<()> {
        let url_template = match &self.url_template {
            Some(source) => {
                let mut tt = TinyTemplate::new();
                tt.add_template("url", source)?;
                Some(tt)
            }
            None => None,
        };

        writeln!(writer, "digraph artifact_lineage_graph {{")?;
        writeln!(writer, "  concentrate=true;")?;
        for node in self.nodes.values() {
            writeln!(
                writer,
                "  {:?} [label={:?},shape={:?},style={:?},tooltip={:?},fillcolor={:?},URL={:?}];",
                node.id().to_string(),
                node.label(),
                node.shape(),
                node.style(self.origin),
                node.tooltip(&self.types)?,
                node.color(&self.colors),
                node.url(url_template.as_ref())?
            )?;
        }

        // Distinct events may connect the same two nodes with the same path; draw each
        // (from, to, label) combination only once.
        let mut visited_edges = HashSet::new();
        for edge in &self.edges {
            let key = (edge.from_node(), edge.to_node(), edge.label()?);
            if visited_edges.contains(&key) {
                continue;
            }

            let (from, to, label) = &key;
            // Fail with a descriptive error instead of panicking on a dangling endpoint.
            anyhow::ensure!(
                self.nodes.contains_key(from),
                "edge {} -> {} references node {} which is not in the graph",
                from,
                to,
                from
            );
            anyhow::ensure!(
                self.nodes.contains_key(to),
                "edge {} -> {} references node {} which is not in the graph",
                from,
                to,
                to
            );

            writeln!(
                writer,
                "  {:?} -> {:?} [label={:?}];",
                from.to_string(),
                to.to_string(),
                label
            )?;
            visited_edges.insert(key);
        }

        self.write_legend(writer, "artifact", "Artifact Legend", |ty| {
            matches!(ty, Type::Artifact(_))
        })?;
        self.write_legend(writer, "execution", "Execution Legend", |ty| {
            matches!(ty, Type::Execution(_))
        })?;

        writeln!(writer, "}}")?;
        Ok(())
    }

    /// Writes the legend cluster `cluster_{kind}_legend` listing every type accepted by
    /// `include`, each drawn with its shape and fill color.
    ///
    /// DOT node IDs are global across subgraphs. Legend entries are therefore namespaced
    /// as `legend_{kind}:{type name}`, with the plain type name as the visible label. This
    /// keeps an artifact type and an execution type with the same name from collapsing
    /// into one node, and keeps a type name from colliding with a real node ID.
    fn write_legend<W: Write>(
        &self,
        writer: &mut W,
        kind: &str,
        title: &str,
        include: fn(&Type) -> bool,
    ) -> anyhow::Result<()> {
        writeln!(writer, "  subgraph cluster_{}_legend {{", kind)?;
        writeln!(writer, "    label = {:?};", title)?;
        let mut prev: Option<String> = None;
        for ty in self.types.values().filter(|&ty| include(ty)) {
            let legend_id = format!("legend_{}:{}", kind, ty.name());
            writeln!(
                writer,
                "    {:?}[label={:?},shape={:?},style=filled,fillcolor={:?}];",
                legend_id,
                ty.name(),
                ty.shape(),
                self.fill_color(ty.id())
            )?;
            if let Some(prev) = &prev {
                // Invisible edge (zero pen width, no arrowhead) linking consecutive legend
                // entries; presumably there to keep the entries stacked in order.
                writeln!(
                    writer,
                    "    {:?} -> {:?}[penwidth=0,arrowhead=none];",
                    prev, legend_id
                )?;
            }
            prev = Some(legend_id);
        }
        writeln!(writer, "  }}")?;
        Ok(())
    }

    /// Formats the fill color assigned to `type_id` as a `#rrggbb` hex string.
    ///
    /// # Panics
    ///
    /// Panics if `type_id` has no assigned color. `new` assigns one to each type it
    /// stores in `types`.
    fn fill_color(&self, type_id: TypeId) -> String {
        let color = self.colors[&type_id];
        format!("#{:02x}{:02x}{:02x}", color.red, color.green, color.blue)
    }
}