1use crate::serialize::Artifact;
3use std::collections::{BTreeMap, BTreeSet};
4use std::time::Duration;
5
6#[derive(Debug, Clone, structopt::StructOpt, serde::Serialize, serde::Deserialize)]
8#[structopt(rename_all = "kebab-case")]
9#[serde(rename_all = "kebab-case")]
10pub struct CommonArtifactsOpt {
11 #[structopt(long, env = "MLMD_DB", hide_env_values = true)]
13 #[serde(skip)]
14 pub db: String,
15
16 #[structopt(long = "id")]
18 #[serde(default, skip_serializing_if = "Vec::is_empty")]
19 pub ids: Vec<i32>,
20
21 #[structopt(long, requires("type-name"))]
23 #[serde(default, skip_serializing_if = "Option::is_none")]
24 pub name: Option<String>,
25
26 #[structopt(long, requires("type-name"), conflicts_with("name"))]
28 #[serde(default, skip_serializing_if = "Option::is_none")]
29 pub name_pattern: Option<String>,
30
31 #[structopt(long = "type")]
33 #[serde(rename = "type")]
34 #[serde(default, skip_serializing_if = "Option::is_none")]
35 pub type_name: Option<String>,
36
37 #[structopt(long)]
39 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub uri: Option<String>,
41
42 #[structopt(long)]
44 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub context: Option<i32>,
46
47 #[structopt(long)]
49 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub ctime_start: Option<f64>,
51
52 #[structopt(long)]
54 #[serde(default, skip_serializing_if = "Option::is_none")]
55 pub ctime_end: Option<f64>,
56
57 #[structopt(long)]
59 #[serde(default, skip_serializing_if = "Option::is_none")]
60 pub mtime_start: Option<f64>,
61
62 #[structopt(long)]
64 #[serde(default, skip_serializing_if = "Option::is_none")]
65 pub mtime_end: Option<f64>,
66}
67
68impl CommonArtifactsOpt {
69 fn request<'a>(
70 &self,
71 store: &'a mut mlmd::MetadataStore,
72 ) -> mlmd::requests::GetArtifactsRequest<'a> {
73 let mut request = store.get_artifacts();
74
75 if !self.ids.is_empty() {
76 request = request.ids(
77 self.ids
78 .iter()
79 .copied()
80 .map(mlmd::metadata::ArtifactId::new),
81 );
82 }
83 match (&self.name, &self.name_pattern, &self.type_name) {
84 (Some(name), None, Some(type_name)) => {
85 request = request.type_and_name(type_name, name);
86 }
87 (None, Some(name_pattern), Some(type_name)) => {
88 request = request.type_and_name_pattern(type_name, name_pattern);
89 }
90 (None, None, Some(type_name)) => {
91 request = request.ty(type_name);
92 }
93 _ => {}
94 }
95 if let Some(x) = &self.uri {
96 request = request.uri(x);
97 }
98 if let Some(x) = self.context {
99 request = request.context(mlmd::metadata::ContextId::new(x));
100 }
101 request = match (self.ctime_start, self.ctime_end) {
102 (None, None) => request,
103 (Some(s), None) => request.create_time(Duration::from_secs_f64(s)..),
104 (None, Some(e)) => request.create_time(..Duration::from_secs_f64(e)),
105 (Some(s), Some(e)) => {
106 request.create_time(Duration::from_secs_f64(s)..Duration::from_secs_f64(e))
107 }
108 };
109 request = match (self.mtime_start, self.mtime_end) {
110 (None, None) => request,
111 (Some(s), None) => request.update_time(Duration::from_secs_f64(s)..),
112 (None, Some(e)) => request.update_time(..Duration::from_secs_f64(e)),
113 (Some(s), Some(e)) => {
114 request.update_time(Duration::from_secs_f64(s)..Duration::from_secs_f64(e))
115 }
116 };
117
118 request
119 }
120}
121
122#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
124#[serde(rename_all = "snake_case")]
125#[allow(missing_docs)]
126pub enum ArtifactOrderByField {
127 Id,
128 Name,
129 #[serde(rename = "ctime")]
130 CreateTime,
131 #[serde(rename = "mtime")]
132 UpdateTime,
133}
134
135impl ArtifactOrderByField {
136 const POSSIBLE_VALUES: &'static [&'static str] = &["id", "name", "ctime", "mtime"];
137}
138
139impl Default for ArtifactOrderByField {
140 fn default() -> Self {
141 Self::Id
142 }
143}
144
145impl std::str::FromStr for ArtifactOrderByField {
146 type Err = anyhow::Error;
147
148 fn from_str(s: &str) -> anyhow::Result<Self> {
149 match s {
150 "id" => Ok(Self::Id),
151 "name" => Ok(Self::Name),
152 "ctime" => Ok(Self::CreateTime),
153 "mtime" => Ok(Self::UpdateTime),
154 _ => anyhow::bail!("invalid value: {:?}", s),
155 }
156 }
157}
158
159impl From<ArtifactOrderByField> for mlmd::requests::ArtifactOrderByField {
160 fn from(x: ArtifactOrderByField) -> Self {
161 match x {
162 ArtifactOrderByField::Id => Self::Id,
163 ArtifactOrderByField::Name => Self::Name,
164 ArtifactOrderByField::CreateTime => Self::CreateTime,
165 ArtifactOrderByField::UpdateTime => Self::UpdateTime,
166 }
167 }
168}
169
170#[derive(Debug, Clone, structopt::StructOpt, serde::Serialize, serde::Deserialize)]
172pub struct CountArtifactsOpt {
173 #[structopt(flatten)]
175 #[serde(flatten)]
176 pub common: CommonArtifactsOpt,
177}
178
179impl CountArtifactsOpt {
180 pub async fn count(&self, store: &mut mlmd::MetadataStore) -> anyhow::Result<usize> {
182 let n = self.common.request(store).count().await?;
183 Ok(n)
184 }
185}
186
187#[derive(Debug, Clone, structopt::StructOpt, serde::Serialize, serde::Deserialize)]
189#[serde(rename_all = "kebab-case")]
190pub struct GetArtifactsOpt {
191 #[structopt(flatten)]
193 pub common: CommonArtifactsOpt,
194
195 #[structopt(long, default_value="id", possible_values = ArtifactOrderByField::POSSIBLE_VALUES)]
197 #[serde(default)]
198 pub order_by: ArtifactOrderByField,
199
200 #[structopt(long)]
202 #[serde(default)]
203 pub asc: bool,
204
205 #[structopt(long, default_value = "100")]
207 #[serde(default = "GetArtifactsOpt::limit_default")]
208 pub limit: usize,
209
210 #[structopt(long, default_value = "0")]
212 #[serde(default)]
213 pub offset: usize,
214}
215
216impl GetArtifactsOpt {
217 fn limit_default() -> usize {
218 100
219 }
220
221 pub async fn get(&self, store: &mut mlmd::MetadataStore) -> anyhow::Result<Vec<Artifact>> {
223 let artifacts = self
224 .common
225 .request(store)
226 .limit(self.limit)
227 .offset(self.offset)
228 .order_by(self.order_by.into(), self.asc)
229 .execute()
230 .await?;
231
232 let artifact_types = self.get_artifact_types(store, &artifacts).await?;
233 Ok(artifacts
234 .into_iter()
235 .map(|x| Artifact::new(artifact_types[&x.type_id].clone(), x))
236 .collect())
237 }
238
239 async fn get_artifact_types(
240 &self,
241 store: &mut mlmd::MetadataStore,
242 artifacts: &[mlmd::metadata::Artifact],
243 ) -> anyhow::Result<BTreeMap<mlmd::metadata::TypeId, String>> {
244 Ok(store
245 .get_artifact_types()
246 .ids(
247 artifacts
248 .iter()
249 .map(|x| x.type_id)
250 .collect::<BTreeSet<_>>()
251 .into_iter(),
252 )
253 .execute()
254 .await?
255 .into_iter()
256 .map(|x| (x.id, x.name))
257 .collect::<BTreeMap<_, _>>())
258 }
259}