use structopt::StructOpt;
// Top-level command-line interface of this binary, parsed by `structopt`.
//
// NOTE: plain `//` comments are used here on purpose. `structopt` turns `///`
// doc comments on derived items into help text, so adding them would change
// what `--help` prints.
#[derive(Debug, StructOpt)]
#[structopt(rename_all = "kebab-case")]
// NOTE(review): presumably the variants differ a lot in size and boxing was
// judged unnecessary for a value that is parsed once per run — confirm.
#[allow(clippy::large_enum_variant)]
enum Opt {
// A single query. `flatten` merges the subcommands of `BatchableOpt`
// (`count`, `get`) directly into the top level instead of nesting them
// under a `batchable` subcommand.
#[allow(missing_docs)]
#[structopt(flatten)]
Batchable(BatchableOpt),
// `graph` subcommand; see `GraphOpt` for the available graph kinds.
Graph(GraphOpt),
// `batch` subcommand: runs several `BatchableOpt` commands in one
// invocation; see `BatchOpt`.
Batch(BatchOpt),
}
// Options of the `batch` subcommand: a shared database, a worker count and
// the list of commands to run.
// (Plain `//` comments on purpose: `///` on `StructOpt` items becomes help text.)
#[derive(Debug, StructOpt)]
#[structopt(rename_all = "kebab-case")]
struct BatchOpt {
// Database URI passed to `mlmd::MetadataStore::connect` by every worker.
// Read from `--db` or, failing that, the `MLMD_DB` environment variable;
// `hide_env_values` keeps the variable's value out of the help output.
#[structopt(long, env = "MLMD_DB", hide_env_values = true)]
db: String,
// Requested number of concurrent worker tasks (`--workers`, default 10).
// `NonZeroUsize` makes `0` a parse error, so the modulo/stride arithmetic
// in `execute` can never divide by zero.
#[structopt(long, default_value = "10")]
workers: std::num::NonZeroUsize,
// Commands to run, one per positional argument. Each argument is parsed
// through `BatchableOpt`'s `FromStr` impl, i.e. it is expected to be the
// JSON form of a `BatchableOpt`.
commands: Vec<BatchableOpt>,
}
impl BatchOpt {
    /// Runs every command in `self.commands` against the database at `self.db` and
    /// prints the results to stdout as one pretty-printed JSON array followed by a
    /// newline.
    ///
    /// The commands are dealt round-robin to worker tasks spawned on the tokio
    /// runtime; each worker opens its own `mlmd::MetadataStore` connection and runs
    /// its share sequentially. The output array is in the same order as
    /// `self.commands`, regardless of which worker produced which result.
    ///
    /// # Errors
    ///
    /// Workers are awaited in spawn order and the first failure found that way is
    /// returned: a failed connection, a failed command, a worker task that panicked
    /// or was cancelled, or an error while serializing/writing to stdout. Nothing is
    /// printed unless every command succeeded.
    async fn execute(&self) -> anyhow::Result<()> {
        // Never start more workers than there are commands: a worker without any
        // command would only open a database connection and drop it again. One worker
        // is always kept so that an empty batch still attempts the connection.
        let workers = self.workers.get().min(self.commands.len()).max(1);

        let handlers = (0..workers)
            .map(|worker| {
                let db = self.db.clone();
                // Worker `w` owns commands `w, w + workers, w + 2 * workers, ...`.
                // `skip` + `step_by` walks only this worker's share instead of testing
                // every command with a modulo. The original index is kept so the
                // results can be put back into input order afterwards.
                let commands = self
                    .commands
                    .iter()
                    .enumerate()
                    .skip(worker)
                    .step_by(workers)
                    .map(|(index, command)| (index, command.clone()))
                    .collect::<Vec<_>>();
                tokio::spawn(async move {
                    let mut store = mlmd::MetadataStore::connect(&db).await?;
                    let mut results = Vec::with_capacity(commands.len());
                    for (index, command) in commands {
                        let result = command.execute_with_store(&mut store).await?;
                        results.push((index, result));
                    }
                    Ok(results)
                })
            })
            .collect::<Vec<tokio::task::JoinHandle<anyhow::Result<_>>>>();

        let mut results = Vec::with_capacity(self.commands.len());
        for handler in handlers {
            // First `?`: the task itself failed to complete (panic/cancellation).
            // Second `?`: the worker returned an error.
            results.extend(handler.await??);
        }

        // Restore the order in which the commands were given.
        results.sort_by_key(|x| x.0);
        serde_json::to_writer_pretty(
            std::io::stdout().lock(),
            &results.into_iter().map(|x| x.1).collect::<Vec<_>>(),
        )?;
        println!();
        Ok(())
    }
}
// A single query that can be run on its own (as a top-level subcommand) or
// as one entry of a `batch`.
//
// The type derives both `StructOpt` (command-line form) and serde
// (`Serialize`/`Deserialize`); variant names are kebab-case in both. No serde
// tagging attribute is set, so the JSON form uses serde's default enum
// representation, e.g. `{"count": {"artifacts": { ... }}}` — the inner fields
// come from the `mlmdquery` option structs, which are not visible here.
// (Plain `//` comments on purpose: `///` on `StructOpt` items becomes help text.)
#[derive(Debug, Clone, StructOpt, serde::Serialize, serde::Deserialize)]
#[structopt(rename_all = "kebab-case")]
#[serde(rename_all = "kebab-case")]
enum BatchableOpt {
// `count` subcommand; see `CountOpt`.
Count(CountOpt),
// `get` subcommand; see `GetOpt`.
Get(GetOpt),
}
impl std::str::FromStr for BatchableOpt {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
let v = serde_json::from_str(s)?;
Ok(v)
}
}
impl BatchableOpt {
    /// Returns the database URI carried by the wrapped option struct.
    ///
    /// The `*Types` option structs expose `db` directly, the others keep it under
    /// their `common` field.
    fn db_uri(&self) -> &str {
        match self {
            Self::Count(count) => match count {
                CountOpt::Artifacts(opt) => &opt.common.db,
                CountOpt::ArtifactTypes(opt) => &opt.db,
                CountOpt::Executions(opt) => &opt.common.db,
                CountOpt::ExecutionTypes(opt) => &opt.db,
                CountOpt::Contexts(opt) => &opt.common.db,
                CountOpt::ContextTypes(opt) => &opt.db,
                CountOpt::Events(opt) => &opt.common.db,
            },
            Self::Get(get) => match get {
                GetOpt::Artifacts(opt) => &opt.common.db,
                GetOpt::ArtifactTypes(opt) => &opt.db,
                GetOpt::Executions(opt) => &opt.common.db,
                GetOpt::ExecutionTypes(opt) => &opt.db,
                GetOpt::Contexts(opt) => &opt.common.db,
                GetOpt::ContextTypes(opt) => &opt.db,
                GetOpt::Events(opt) => &opt.common.db,
            },
        }
    }

    /// Connects to the database named by [`Self::db_uri`] and runs this command on
    /// the fresh connection.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established or if
    /// [`Self::execute_with_store`] fails.
    async fn execute(&self) -> anyhow::Result<serde_json::Value> {
        let uri = self.db_uri();
        let mut store = mlmd::MetadataStore::connect(uri).await?;
        self.execute_with_store(&mut store).await
    }

    /// Runs this command on an already connected `store` and returns the outcome as
    /// a JSON value.
    ///
    /// `count` commands call the option struct's `count(store)`, `get` commands its
    /// `get(store)`; either result is converted with `to_json`. The command's own
    /// database URI is not consulted here.
    ///
    /// # Errors
    ///
    /// Propagates errors from the query and from the JSON conversion.
    async fn execute_with_store(
        &self,
        store: &mut mlmd::MetadataStore,
    ) -> anyhow::Result<serde_json::Value> {
        match self {
            Self::Count(count) => match count {
                CountOpt::Artifacts(opt) => to_json(opt.count(store).await?),
                CountOpt::ArtifactTypes(opt) => to_json(opt.count(store).await?),
                CountOpt::Executions(opt) => to_json(opt.count(store).await?),
                CountOpt::ExecutionTypes(opt) => to_json(opt.count(store).await?),
                CountOpt::Contexts(opt) => to_json(opt.count(store).await?),
                CountOpt::ContextTypes(opt) => to_json(opt.count(store).await?),
                CountOpt::Events(opt) => to_json(opt.count(store).await?),
            },
            Self::Get(get) => match get {
                GetOpt::Artifacts(opt) => to_json(opt.get(store).await?),
                GetOpt::ArtifactTypes(opt) => to_json(opt.get(store).await?),
                GetOpt::Executions(opt) => to_json(opt.get(store).await?),
                GetOpt::ExecutionTypes(opt) => to_json(opt.get(store).await?),
                GetOpt::Contexts(opt) => to_json(opt.get(store).await?),
                GetOpt::ContextTypes(opt) => to_json(opt.get(store).await?),
                GetOpt::Events(opt) => to_json(opt.get(store).await?),
            },
        }
    }
}
// Targets of the `count` command, one variant per kind of record. Variant
// names are kebab-case on the command line and in JSON (e.g. `artifact-types`).
//
// Each wrapped `mlmdquery` option struct provides the `count(store)` method
// that `BatchableOpt::execute_with_store` calls. The three `*Types` variants
// wrap the same option structs as their `GetOpt` counterparts.
// (Plain `//` comments on purpose: `///` on `StructOpt` items becomes help text.)
#[derive(Debug, Clone, StructOpt, serde::Serialize, serde::Deserialize)]
#[structopt(rename_all = "kebab-case")]
#[serde(rename_all = "kebab-case")]
enum CountOpt {
Artifacts(mlmdquery::artifacts::CountArtifactsOpt),
ArtifactTypes(mlmdquery::artifact_types::ArtifactTypesOpt),
Executions(mlmdquery::executions::CountExecutionsOpt),
ExecutionTypes(mlmdquery::execution_types::ExecutionTypesOpt),
Contexts(mlmdquery::contexts::CountContextsOpt),
ContextTypes(mlmdquery::context_types::ContextTypesOpt),
Events(mlmdquery::events::CountEventsOpt),
}
// Targets of the `get` command, one variant per kind of record. Variant
// names are kebab-case on the command line and in JSON (e.g. `artifact-types`).
//
// Each wrapped `mlmdquery` option struct provides the `get(store)` method
// that `BatchableOpt::execute_with_store` calls. The three `*Types` variants
// wrap the same option structs as their `CountOpt` counterparts.
// (Plain `//` comments on purpose: `///` on `StructOpt` items becomes help text.)
#[derive(Debug, Clone, StructOpt, serde::Serialize, serde::Deserialize)]
#[structopt(rename_all = "kebab-case")]
#[serde(rename_all = "kebab-case")]
enum GetOpt {
Artifacts(mlmdquery::artifacts::GetArtifactsOpt),
ArtifactTypes(mlmdquery::artifact_types::ArtifactTypesOpt),
Executions(mlmdquery::executions::GetExecutionsOpt),
ExecutionTypes(mlmdquery::execution_types::ExecutionTypesOpt),
Contexts(mlmdquery::contexts::GetContextsOpt),
ContextTypes(mlmdquery::context_types::ContextTypesOpt),
Events(mlmdquery::events::GetEventsOpt),
}
// Kinds of graph the `graph` subcommand can produce. `main` calls the
// wrapped option struct's `graph` method with a locked stdout as the writer.
//
// Unlike `BatchableOpt`, this enum derives neither `Clone` nor the serde
// traits, so graph commands cannot be part of a `batch`.
// (Plain `//` comments on purpose: `///` on `StructOpt` items becomes help text.)
#[derive(Debug, StructOpt)]
#[structopt(rename_all = "kebab-case")]
enum GraphOpt {
Lineage(mlmdquery::lineage::GraphLineageOpt),
Derived(mlmdquery::derived::GraphDerivedOpt),
Io(mlmdquery::io::GraphIoOpt),
}
/// Entry point: parses the command line and dispatches to the selected command.
///
/// * a single query prints its JSON result, pretty-printed, plus a trailing newline;
/// * `batch` delegates to `BatchOpt::execute`, which prints its own output;
/// * `graph` commands write to a locked stdout.
///
/// Any error is returned from `main`.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    match Opt::from_args() {
        Opt::Batchable(query) => {
            serde_json::to_writer_pretty(std::io::stdout().lock(), &query.execute().await?)?;
            println!();
        }
        Opt::Batch(batch) => batch.execute().await?,
        Opt::Graph(graph) => match graph {
            GraphOpt::Lineage(lineage) => lineage.graph(&mut std::io::stdout().lock()).await?,
            GraphOpt::Derived(derived) => derived.graph(&mut std::io::stdout().lock()).await?,
            GraphOpt::Io(io) => io.graph(&mut std::io::stdout().lock()).await?,
        },
    }
    Ok(())
}
fn to_json(item: impl serde::Serialize) -> anyhow::Result<serde_json::Value> {
let v = serde_json::to_value(&item)?;
Ok(v)
}