mlmdquery/
artifacts.rs

1//! `$ mlmdquery {get,count} artifacts` implementation.
2use crate::serialize::Artifact;
3use std::collections::{BTreeMap, BTreeSet};
4use std::time::Duration;
5
6/// `$ mlmdquery {get,count} artifacts` common options.
7#[derive(Debug, Clone, structopt::StructOpt, serde::Serialize, serde::Deserialize)]
8#[structopt(rename_all = "kebab-case")]
9#[serde(rename_all = "kebab-case")]
10pub struct CommonArtifactsOpt {
11    /// Database URL.
12    #[structopt(long, env = "MLMD_DB", hide_env_values = true)]
13    #[serde(skip)]
14    pub db: String,
15
16    /// Target artifact IDs.
17    #[structopt(long = "id")]
18    #[serde(default, skip_serializing_if = "Vec::is_empty")]
19    pub ids: Vec<i32>,
20
21    /// Target artifact name.
22    #[structopt(long, requires("type-name"))]
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub name: Option<String>,
25
26    /// Target artifact name pattern (SQL LIKE statement value).
27    #[structopt(long, requires("type-name"), conflicts_with("name"))]
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub name_pattern: Option<String>,
30
31    /// Target artifact type.
32    #[structopt(long = "type")]
33    #[serde(rename = "type")]
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub type_name: Option<String>,
36
37    /// Target artifact URI.
38    #[structopt(long)]
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub uri: Option<String>,
41
42    /// Context ID to which target artifacts belong.
43    #[structopt(long)]
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub context: Option<i32>,
46
47    /// Start of creation time (UNIX timestamp seconds).
48    #[structopt(long)]
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub ctime_start: Option<f64>,
51
52    /// End of creation time (UNIX timestamp seconds).
53    #[structopt(long)]
54    #[serde(default, skip_serializing_if = "Option::is_none")]
55    pub ctime_end: Option<f64>,
56
57    /// Start of update time (UNIX timestamp seconds).
58    #[structopt(long)]
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub mtime_start: Option<f64>,
61
62    /// End of update time (UNIX timestamp seconds).
63    #[structopt(long)]
64    #[serde(default, skip_serializing_if = "Option::is_none")]
65    pub mtime_end: Option<f64>,
66}
67
68impl CommonArtifactsOpt {
69    fn request<'a>(
70        &self,
71        store: &'a mut mlmd::MetadataStore,
72    ) -> mlmd::requests::GetArtifactsRequest<'a> {
73        let mut request = store.get_artifacts();
74
75        if !self.ids.is_empty() {
76            request = request.ids(
77                self.ids
78                    .iter()
79                    .copied()
80                    .map(mlmd::metadata::ArtifactId::new),
81            );
82        }
83        match (&self.name, &self.name_pattern, &self.type_name) {
84            (Some(name), None, Some(type_name)) => {
85                request = request.type_and_name(type_name, name);
86            }
87            (None, Some(name_pattern), Some(type_name)) => {
88                request = request.type_and_name_pattern(type_name, name_pattern);
89            }
90            (None, None, Some(type_name)) => {
91                request = request.ty(type_name);
92            }
93            _ => {}
94        }
95        if let Some(x) = &self.uri {
96            request = request.uri(x);
97        }
98        if let Some(x) = self.context {
99            request = request.context(mlmd::metadata::ContextId::new(x));
100        }
101        request = match (self.ctime_start, self.ctime_end) {
102            (None, None) => request,
103            (Some(s), None) => request.create_time(Duration::from_secs_f64(s)..),
104            (None, Some(e)) => request.create_time(..Duration::from_secs_f64(e)),
105            (Some(s), Some(e)) => {
106                request.create_time(Duration::from_secs_f64(s)..Duration::from_secs_f64(e))
107            }
108        };
109        request = match (self.mtime_start, self.mtime_end) {
110            (None, None) => request,
111            (Some(s), None) => request.update_time(Duration::from_secs_f64(s)..),
112            (None, Some(e)) => request.update_time(..Duration::from_secs_f64(e)),
113            (Some(s), Some(e)) => {
114                request.update_time(Duration::from_secs_f64(s)..Duration::from_secs_f64(e))
115            }
116        };
117
118        request
119    }
120}
121
122/// Fields that can be used to sort a search result.
123#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
124#[serde(rename_all = "snake_case")]
125#[allow(missing_docs)]
126pub enum ArtifactOrderByField {
127    Id,
128    Name,
129    #[serde(rename = "ctime")]
130    CreateTime,
131    #[serde(rename = "mtime")]
132    UpdateTime,
133}
134
135impl ArtifactOrderByField {
136    const POSSIBLE_VALUES: &'static [&'static str] = &["id", "name", "ctime", "mtime"];
137}
138
139impl Default for ArtifactOrderByField {
140    fn default() -> Self {
141        Self::Id
142    }
143}
144
145impl std::str::FromStr for ArtifactOrderByField {
146    type Err = anyhow::Error;
147
148    fn from_str(s: &str) -> anyhow::Result<Self> {
149        match s {
150            "id" => Ok(Self::Id),
151            "name" => Ok(Self::Name),
152            "ctime" => Ok(Self::CreateTime),
153            "mtime" => Ok(Self::UpdateTime),
154            _ => anyhow::bail!("invalid value: {:?}", s),
155        }
156    }
157}
158
159impl From<ArtifactOrderByField> for mlmd::requests::ArtifactOrderByField {
160    fn from(x: ArtifactOrderByField) -> Self {
161        match x {
162            ArtifactOrderByField::Id => Self::Id,
163            ArtifactOrderByField::Name => Self::Name,
164            ArtifactOrderByField::CreateTime => Self::CreateTime,
165            ArtifactOrderByField::UpdateTime => Self::UpdateTime,
166        }
167    }
168}
169
170/// `$ mlmdquery count artifacts` options.
171#[derive(Debug, Clone, structopt::StructOpt, serde::Serialize, serde::Deserialize)]
172pub struct CountArtifactsOpt {
173    /// Common options.
174    #[structopt(flatten)]
175    #[serde(flatten)]
176    pub common: CommonArtifactsOpt,
177}
178
179impl CountArtifactsOpt {
180    /// `$ mlmdquery count artifacts` implementation.
181    pub async fn count(&self, store: &mut mlmd::MetadataStore) -> anyhow::Result<usize> {
182        let n = self.common.request(store).count().await?;
183        Ok(n)
184    }
185}
186
187/// `$ mlmdquery get artifacts` options.
188#[derive(Debug, Clone, structopt::StructOpt, serde::Serialize, serde::Deserialize)]
189#[serde(rename_all = "kebab-case")]
190pub struct GetArtifactsOpt {
191    /// Common options.
192    #[structopt(flatten)]
193    pub common: CommonArtifactsOpt,
194
195    /// Field to be used to sort a search result.
196    #[structopt(long, default_value="id", possible_values = ArtifactOrderByField::POSSIBLE_VALUES)]
197    #[serde(default)]
198    pub order_by: ArtifactOrderByField,
199
200    /// If specified, the search results will be sorted in ascending order.
201    #[structopt(long)]
202    #[serde(default)]
203    pub asc: bool,
204
205    /// Maximum number of artifacts in a search result.
206    #[structopt(long, default_value = "100")]
207    #[serde(default = "GetArtifactsOpt::limit_default")]
208    pub limit: usize,
209
210    /// Number of artifacts to be skipped from a search result.
211    #[structopt(long, default_value = "0")]
212    #[serde(default)]
213    pub offset: usize,
214}
215
216impl GetArtifactsOpt {
217    fn limit_default() -> usize {
218        100
219    }
220
221    /// `$ mlmdquery get artifacts` implementation.
222    pub async fn get(&self, store: &mut mlmd::MetadataStore) -> anyhow::Result<Vec<Artifact>> {
223        let artifacts = self
224            .common
225            .request(store)
226            .limit(self.limit)
227            .offset(self.offset)
228            .order_by(self.order_by.into(), self.asc)
229            .execute()
230            .await?;
231
232        let artifact_types = self.get_artifact_types(store, &artifacts).await?;
233        Ok(artifacts
234            .into_iter()
235            .map(|x| Artifact::new(artifact_types[&x.type_id].clone(), x))
236            .collect())
237    }
238
239    async fn get_artifact_types(
240        &self,
241        store: &mut mlmd::MetadataStore,
242        artifacts: &[mlmd::metadata::Artifact],
243    ) -> anyhow::Result<BTreeMap<mlmd::metadata::TypeId, String>> {
244        Ok(store
245            .get_artifact_types()
246            .ids(
247                artifacts
248                    .iter()
249                    .map(|x| x.type_id)
250                    .collect::<BTreeSet<_>>()
251                    .into_iter(),
252            )
253            .execute()
254            .await?
255            .into_iter()
256            .map(|x| (x.id, x.name))
257            .collect::<BTreeMap<_, _>>())
258    }
259}