pub mod artifact;
pub mod config;
mod error;
pub mod experiment;
mod http;
pub mod modelversion;
pub mod registered;
pub mod run;
use crate::config::{BasicAuth, ClientConfig, DownloadConfig, MlflowConfig};
use crate::error::is_retryable;
use crate::http::Response;
use crate::{artifact::*, experiment::*, modelversion::*, registered::*, run::*};
use anyhow::{anyhow, Context, Result};
use backon::{ExponentialBuilder, Retryable};
use reqwest::{Method, RequestBuilder};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::path::Path;
use tokio::{
fs,
fs::{create_dir_all, File},
io::AsyncWriteExt,
task::JoinSet,
};
use tracing::{debug, info, warn};
/// Asynchronous client for the MLflow REST API.
///
/// Cloning is cheap: `reqwest::Client` is an `Arc`-backed handle that shares
/// its connection pool between clones.
#[derive(Clone, Debug)]
pub struct Client {
    // Underlying HTTP client.
    pub client: reqwest::Client,
    // Server base URL, credentials and retry policy.
    pub mlflow: MlflowConfig,
    // Artifact transfer tuning (worker count, blacklist, caching, size checks).
    pub download: DownloadConfig,
}
/// Wrapper ("boilerplate") types mirroring the envelope objects the MLflow
/// REST API wraps its JSON responses in, e.g. `{"run": {...}}`.
pub(crate) mod bs {
    use serde::Deserialize;
    /// Empty response body — used where only success/failure matters.
    #[derive(Deserialize)]
    pub struct Dummy {}
    /// `{"run": {...}}`
    #[derive(Deserialize)]
    pub struct Run {
        pub run: super::Run,
    }
    /// `{"runs": [...]}`; `default` because the server omits the key when empty.
    #[derive(Deserialize)]
    pub struct Runs {
        #[serde(default)]
        pub runs: Vec<super::Run>,
    }
    /// `{"run_info": {...}}`
    #[derive(Deserialize)]
    pub struct RunInfo {
        pub run_info: super::RunInfo,
    }
    /// `{"model_version": {...}}`
    #[derive(Deserialize)]
    pub struct ModelVersion {
        pub model_version: super::ModelVersion,
    }
    /// `{"registered_model": {...}}`
    #[derive(Deserialize)]
    pub struct RegisteredModel {
        pub registered_model: super::RegisteredModel,
    }
    /// `{"registered_models": [...]}`; `default` for the omitted-when-empty case.
    #[derive(Deserialize)]
    pub struct RegisteredModels {
        #[serde(default)]
        pub registered_models: Vec<super::RegisteredModel>,
    }
}
impl Client {
/// Create a client for the MLflow REST API rooted at `urlbase`,
/// using default retry and download settings and no authentication.
pub fn new(urlbase: impl AsRef<str>) -> Self {
    // Client-side error statuses that retrying cannot fix.
    let no_retry_status_codes = vec![400, 401, 403, 404, 409, 422];
    let mlflow = MlflowConfig {
        urlbase: urlbase.as_ref().to_owned(),
        auth: None,
        no_retry_status_codes,
    };
    let download = DownloadConfig {
        retry: false,
        tasks: 16,
        blacklist: Default::default(),
        cache_local_artifacts: false,
        file_size_check: true,
    };
    Self {
        client: reqwest::Client::new(),
        mlflow,
        download,
    }
}
/// Build a client from a fully specified [`ClientConfig`].
pub fn new_from_config(config: ClientConfig) -> Self {
    let mlflow = config.mlflow;
    let download = config.download;
    Self {
        client: reqwest::Client::new(),
        mlflow,
        download,
    }
}
/// Builder-style setter enabling HTTP basic auth for all subsequent requests.
pub fn with_auth(mut self, auth: BasicAuth) -> Self {
    self.mlflow.auth = Some(auth);
    self
}
/// Join an API `path` onto the configured base URL.
/// NOTE(review): assumes `urlbase` has no trailing slash — otherwise the
/// resulting URL contains `//`; confirm the config contract.
fn make_url(&self, path: impl AsRef<str>) -> String {
    format!("{}/{}", self.mlflow.urlbase, path.as_ref())
}
/// Send a single JSON request without retry logic.
///
/// Attaches basic-auth credentials when configured and turns non-success
/// HTTP statuses into errors via `Response::error_for_status`.
async fn send_json_request_no_retry<P: Serialize>(
    &self,
    method: Method,
    path: impl AsRef<str>,
    payload: &P,
) -> Result<Response> {
    let mut request = self
        .client
        .request(method, self.make_url(path.as_ref()))
        .json(&payload);
    request = self.assign_basic_auth(request);
    Response(request.send().await?).error_for_status().await
}
async fn send_request<P: Serialize>(
&self,
method: Method,
path: impl AsRef<str>,
payload: P,
) -> Result<Response> {
let future = || async {
self.send_json_request_no_retry(method.clone(), path.as_ref(), &payload)
.await
};
future
.retry(ExponentialBuilder::default())
.when(|e| is_retryable(&self.mlflow.no_retry_status_codes, e))
.notify(|error, duration| {
warn!("Retrying in {:?}, got an error: {:?}", error, duration);
})
.await
.with_context(|| format!("Cannot send request to: {:?}", path.as_ref()))
}
/// Attach the configured HTTP basic-auth credentials to `request`, if any.
fn assign_basic_auth(&self, request: RequestBuilder) -> RequestBuilder {
    match self.mlflow.auth.as_ref() {
        None => request,
        Some(auth) => request.basic_auth(
            auth.user.expose_secret().to_owned(),
            Some(auth.password.expose_secret().to_owned()),
        ),
    }
}
/// Create a new experiment named `name` with `tags`; returns the new id.
pub async fn create_experiment(&self, name: &str, tags: Vec<KeyValue>) -> Result<String> {
    let path = "2.0/mlflow/experiments/create";
    let request = CreateExperiment {
        name: name.to_owned(),
        tags,
    };
    let response = self.send_request(Method::POST, path, request).await?;
    let created = response.json::<ExperimentId>().await?;
    Ok(created.experiment_id)
}
/// Mark the experiment with `experiment_id` as deleted.
pub async fn delete_experiment(&self, experiment_id: &str) -> Result<()> {
    let path = "2.0/mlflow/experiments/delete";
    let request = ExperimentId {
        experiment_id: experiment_id.to_owned(),
    };
    self.send_request(Method::POST, path, request).await?;
    Ok(())
}
/// Fetch a single experiment by its id.
pub async fn get_experiment(&self, experiment_id: &str) -> Result<Experiment> {
    let path = "2.0/mlflow/experiments/get";
    let request = ExperimentId {
        experiment_id: experiment_id.to_owned(),
    };
    // Server wraps the payload: `{"experiment": {...}}`.
    #[derive(Deserialize)]
    struct Response {
        experiment: Experiment,
    }
    self.send_request(Method::GET, path, request)
        .await?
        .json::<Response>()
        .await
        .map(|response| response.experiment)
}
/// Fetch a single experiment by its (unique) name.
pub async fn get_experiment_by_name(&self, name: &str) -> Result<Experiment> {
    // Server wraps the payload: `{"experiment": {...}}`.
    #[derive(Deserialize)]
    struct Response {
        experiment: Experiment,
    }
    let path = "2.0/mlflow/experiments/get-by-name";
    let request = json!({
        "experiment_name": name
    });
    let response = self.send_request(Method::GET, path, request).await?;
    Ok(response.json::<Response>().await?.experiment)
}
/// Return the experiment named `name`, creating it (with `tags`) when the
/// lookup fails.
///
/// NOTE(review): *any* lookup failure — including transient network or
/// server errors, not just "not found" — falls through to creation; the
/// lookup is deliberately not retried. Confirm this is the intended
/// behavior.
pub async fn get_or_create_experiment(
    &self,
    name: &str,
    tags: Vec<KeyValue>,
) -> Result<Experiment> {
    let path = "2.0/mlflow/experiments/get-by-name";
    let request = json!({
        "experiment_name": name
    });
    // Server wraps the payload: `{"experiment": {...}}`.
    #[derive(Deserialize)]
    struct Response {
        experiment: Experiment,
    }
    let result = self
        .send_json_request_no_retry(Method::GET, path, &request)
        .await;
    match result {
        Err(_) => {
            let experiment_id = self.create_experiment(name, tags).await?;
            self.get_experiment(&experiment_id).await
        }
        Ok(response) => response
            .json::<Response>()
            .await
            .map(|response| response.experiment),
    }
}
/// Create a new run and return the server's view of it.
pub async fn create_run(&self, create: CreateRun) -> Result<Run> {
    let path = "2.0/mlflow/runs/create";
    let response = self.send_request(Method::POST, path, create).await?;
    Ok(response.json::<bs::Run>().await?.run)
}
pub async fn get_run(&self, run_id: &str) -> Result<Run> {
let path = "2.0/mlflow/runs/get";
let request = Request {
run_id: run_id.to_string(),
};
#[derive(Serialize)]
struct Request {
run_id: String,
}
self.send_request(Method::GET, path, request)
.await?
.json::<bs::Run>()
.await
.map(|value| value.run)
}
/// Search runs matching the given criteria; an empty result is not an error.
pub async fn search_runs(&self, search: SearchRuns) -> Result<Vec<Run>> {
    let path = "2.0/mlflow/runs/search";
    let response = self.send_request(Method::POST, path, search).await?;
    Ok(response.json::<bs::Runs>().await?.runs)
}
/// Update run metadata (status, end time, …); returns the updated run info.
pub async fn update_run(&self, update: UpdateRun) -> Result<RunInfo> {
    let path = "2.0/mlflow/runs/update";
    self.send_request(Method::POST, path, update)
        .await?
        .json::<bs::RunInfo>()
        .await
        .map(|value| value.run_info)
}
/// Log metrics, params and tags for a run in one batched request
/// (`runs/log-batch`).
pub async fn add_run_meta(&self, run_id: &str, meta: RunData) -> Result<()> {
    let path = "2.0/mlflow/runs/log-batch";
    let request = json!({
        "run_id": run_id.to_owned(),
        "metrics": meta.metrics,
        "params": meta.params,
        "tags": meta.tags
    });
    self.send_request(Method::POST, path, request)
        .await
        .map(|_| ())
}
/// Mark the run with `run_id` as deleted.
pub async fn delete_run(&self, run_id: &str) -> Result<()> {
    let path = "2.0/mlflow/runs/delete";
    let request = json!({"run_id": run_id});
    self.send_request(Method::POST, path, request).await?;
    Ok(())
}
/// Attach dataset inputs to a run (`runs/log-inputs`).
pub async fn add_run_inputs(&self, run_id: &str, inputs: Vec<DataSetInput>) -> Result<()> {
    let path = "2.0/mlflow/runs/log-inputs";
    let request = json!({
        "run_id": run_id.to_owned(),
        "datasets": inputs,
    });
    self.send_request(Method::POST, path, request)
        .await
        .map(|_| ())
}
/// Delete the registered model `name`.
/// Note: this endpoint uses HTTP DELETE with a JSON body.
pub async fn delete_registered_model(&self, name: &str) -> Result<()> {
    let path = "2.0/mlflow/registered-models/delete";
    let request = json!({"name": name.to_owned()});
    self.send_request(Method::DELETE, path, request)
        .await
        .map(|_| ())
}
/// Create a new registered model in the model registry.
pub async fn register_model(&self, request: RegisterModel) -> Result<RegisteredModel> {
    let path = "2.0/mlflow/registered-models/create";
    let response = self.send_request(Method::POST, path, request).await?;
    Ok(response.json::<bs::RegisteredModel>().await?.registered_model)
}
/// Fetch a registered model by name.
pub async fn get_registered_model(&self, name: &str) -> Result<RegisteredModel> {
    let path = "2.0/mlflow/registered-models/get";
    let request = json!({"name": name.to_owned()});
    self.send_request(Method::GET, path, request)
        .await?
        .json::<bs::RegisteredModel>()
        .await
        .map(|response| response.registered_model)
}
/// Search registered models using an MLflow `filter` expression
/// (e.g. `name like 'foo%'`); an empty result is not an error.
pub async fn search_registered_models(&self, filter: &str) -> Result<Vec<RegisteredModel>> {
    let path = "2.0/mlflow/registered-models/search";
    let request = json!({"filter": filter.to_owned()});
    self.send_request(Method::GET, path, request)
        .await?
        .json::<bs::RegisteredModels>()
        .await
        .map(|response| response.registered_models)
}
/// Create a new version of an existing registered model.
pub async fn create_model_version(&self, request: CreateModelVersion) -> Result<ModelVersion> {
    let path = "2.0/mlflow/model-versions/create";
    let response = self.send_request(Method::POST, path, request).await?;
    Ok(response.json::<bs::ModelVersion>().await?.model_version)
}
/// Move a model version to a new stage (e.g. Staging, Production).
pub async fn transition_model_version_stage(
    &self,
    request: TransitionModelVersionStage,
) -> Result<ModelVersion> {
    let path = "2.0/mlflow/model-versions/transition-stage";
    self.send_request(Method::POST, path, request)
        .await?
        .json::<bs::ModelVersion>()
        .await
        .map(|response| response.model_version)
}
/// List all artifacts of a run, recursively expanding directory entries so
/// the returned listing contains only plain files.
pub async fn list_run_artifacts(&self, run_id: &str) -> Result<ListedArtifacts> {
    let path = "2.0/mlflow/artifacts/list";
    let request = json!({
        "run_id": run_id
    });
    let mut response = self
        .send_request(Method::GET, path, request)
        .await?
        .json::<ListedArtifacts>()
        .await?;
    // Worklist of entries still to classify: files are moved back into
    // `response.files`, directories are listed and their contents queued.
    let mut tocheck = response.files.drain(..).collect::<Vec<_>>();
    while let Some(info) = tocheck.pop() {
        if !info.is_dir {
            response.files.push(info);
            continue;
        }
        debug!("Listing directory: {:?}", info.path);
        let request = json!({
            "run_id": run_id,
            "path": info.path,
        });
        let mut files = self
            .send_request(Method::GET, path, request)
            .await?
            .json::<ListedArtifacts>()
            .await?
            .files;
        tocheck.append(&mut files);
    }
    Ok(response)
}
/// Download the given artifact sets concurrently.
///
/// Work is flattened into single-file downloads, blacklisted files are
/// skipped, and the remainder is split across at most `download.tasks`
/// Tokio tasks. Returns one [`RunArtifacts`] per distinct run.
pub async fn download_artifacts(
    &self,
    downloads: Vec<DownloadRunArtifacts>,
) -> Result<Vec<RunArtifacts>> {
    let mut downloads = downloads
        .into_iter()
        .flat_map(|downloads| downloads.as_single_downloads())
        .collect::<Vec<_>>();
    downloads.retain(|download| {
        if self.download.is_blacklisted(&download.file) {
            debug!("Filtering out blacklisted file: {}", download.file);
            false
        } else {
            true
        }
    });
    // FIX: floor division could spawn far more than `download.tasks`
    // workers (e.g. 33 files / 16 tasks -> chunk_size 2 -> 17 tasks).
    // Ceiling division caps the batch count at `tasks`; `.max(1)` keeps
    // `chunks()` valid when there is nothing to download.
    let chunk_size = downloads.len().div_ceil(self.download.tasks).max(1);
    let batches = downloads.chunks(chunk_size).map(|chunks| chunks.to_vec());
    let mut tasks = JoinSet::<Result<Vec<_>>>::new();
    for (idx, batch) in batches.enumerate() {
        let client = self.clone();
        tasks.spawn(async move {
            debug!("Worker {} has got #{} files", idx, batch.len());
            let mut completed = Vec::with_capacity(batch.len());
            for download in batch {
                debug!("Starting {:?}", download);
                let size = client.download_artifact(&download).await?;
                debug!(
                    "Downloaded artifact {} [size={}, expected={}]",
                    download.file, size, download.expected_size
                );
                completed.push(download);
            }
            Ok(completed)
        });
    }
    // Group completed downloads by run id.
    let mut downloaded = HashMap::new();
    while let Some(result) = tasks.join_next().await {
        let completed = result
            .context("We cannot join download task")?
            .context("Download task has returned an error")?;
        for download in completed {
            downloaded
                .entry(download.run_id.clone())
                .and_modify(|artifacts: &mut RunArtifacts| {
                    artifacts.paths.push(download.path());
                })
                .or_insert_with(|| RunArtifacts {
                    paths: vec![download.path()],
                    experiment_id: download.experiment_id,
                    run_id: download.run_id,
                    root: download.destination,
                });
        }
    }
    Ok(downloaded.into_values().collect())
}
/// Build a [`DownloadRunArtifacts`] description for all artifacts of
/// `run_id`, rooted at `directory` (no files are downloaded yet).
pub async fn prepare_run_download(
    &self,
    run_id: &str,
    directory: impl AsRef<Path>,
) -> Result<DownloadRunArtifacts> {
    let list = self
        .list_run_artifacts(run_id)
        .await
        .context("Cannot list run artifacts")?;
    let run = self
        .get_run(run_id)
        .await
        .context("Cannot get run from mlflow")?;
    Ok(DownloadRunArtifacts::new_from_run(directory, run, list))
}
/// Download all artifacts described by `download` for a single run.
/// Thin wrapper around [`Self::download_artifacts`] for the one-run case.
pub async fn download_run_artifacts(
    &self,
    download: DownloadRunArtifacts,
) -> Result<RunArtifacts> {
    debug!("Starting: {:#?}", download);
    // One input run yields exactly one RunArtifacts; pop() extracts it.
    self.download_artifacts(vec![download.clone()])
        .await
        .with_context(|| format!("Cannot download artifacts for {:#?}", download))?
        .pop()
        .context("BUG: We've not received any RunArtifacts")
}
/// Serialize `data` as JSON and upload it as a run artifact (single
/// attempt, no retry) via the `mlflow-artifacts` proxy endpoint.
pub async fn upload_json_artifact_no_retry(
    &self,
    data: &impl Serialize,
    artifact: &Artifact,
) -> Result<()> {
    let path = format!(
        "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
        artifact.experiment_id,
        artifact.run_id,
        artifact.path.to_string_lossy()
    );
    let mut request = self.client.put(self.make_url(path)).json(data);
    request = self.assign_basic_auth(request);
    // Body is parsed as an empty object only to confirm it is valid JSON.
    Response(request.send().await?)
        .error_for_status()
        .await?
        .json::<bs::Dummy>()
        .await
        .map(|_| ())
}
/// Serialize `data` as JSON and upload it as a run artifact, retrying with
/// exponential backoff on retryable failures.
pub async fn upload_json_artifact(
    &self,
    data: &impl Serialize,
    artifact: &Artifact,
) -> Result<()> {
    let future = || async { self.upload_json_artifact_no_retry(data, artifact).await };
    future
        .retry(ExponentialBuilder::default())
        .when(|e| is_retryable(&self.mlflow.no_retry_status_codes, e))
        .notify(|error, duration| {
            // FIX: arguments were swapped (error logged as the delay).
            warn!(
                "Retrying upload in {:?}, got an error: {:?}",
                duration, error
            );
        })
        .await
        .with_context(|| format!("Cannot upload: {:?}", artifact))
}
/// Download a run artifact and deserialize it from JSON into `T`
/// (single attempt, no retry).
pub async fn download_json_artifact_no_retry<T: DeserializeOwned>(
    &self,
    artifact: &Artifact,
) -> Result<T> {
    let path = format!(
        "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
        artifact.experiment_id,
        artifact.run_id,
        artifact.path.to_string_lossy()
    );
    let mut request = self.client.get(self.make_url(path));
    request = self.assign_basic_auth(request);
    Response(request.send().await?)
        .error_for_status()
        .await?
        .json::<T>()
        .await
}
/// Download a run artifact and deserialize it from JSON into `T`, retrying
/// with exponential backoff on retryable failures.
pub async fn download_json_artifact<T: DeserializeOwned>(
    &self,
    artifact: &Artifact,
) -> Result<T> {
    let future = || async { self.download_json_artifact_no_retry(artifact).await };
    future
        .retry(ExponentialBuilder::default())
        .when(|e| is_retryable(&self.mlflow.no_retry_status_codes, e))
        .notify(|error, duration| {
            // FIX: arguments were swapped (error logged as the delay).
            warn!(
                "Retrying download in {:?}, got an error: {:?}",
                duration, error
            );
        })
        .await
        .with_context(|| format!("Cannot download: {:?}", artifact))
}
/// Upload the file at `source` as a run artifact (single attempt, no
/// retry). The file is streamed as the request body.
pub async fn upload_artifact_no_retry(
    &self,
    source: impl AsRef<Path>,
    artifact: &Artifact,
) -> Result<()> {
    let path = format!(
        "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
        artifact.experiment_id,
        artifact.run_id,
        artifact.path.to_string_lossy()
    );
    let file = File::open(source.as_ref())
        .await
        .with_context(|| format!("Cannot open artifact file: {:?}", source.as_ref()))?;
    let mut request = self.client.put(self.make_url(path)).body(file);
    request = self.assign_basic_auth(request);
    // Body is parsed as an empty object only to confirm it is valid JSON.
    Response(request.send().await?)
        .error_for_status()
        .await?
        .json::<bs::Dummy>()
        .await
        .map(|_| ())
}
/// Upload the given artifacts concurrently across at most
/// `download.tasks` Tokio tasks.
pub async fn upload_artifacts(&self, uploads: Vec<UploadArtifact>) -> Result<()> {
    // FIX: floor division could spawn more than `download.tasks` workers;
    // ceiling division caps the batch count, `.max(1)` keeps `chunks()`
    // valid for an empty upload list.
    let chunk_size = uploads.len().div_ceil(self.download.tasks).max(1);
    let batches = uploads.chunks(chunk_size).map(|chunks| chunks.to_vec());
    let mut tasks = JoinSet::<Result<()>>::new();
    for (idx, batch) in batches.enumerate() {
        let client = self.clone();
        tasks.spawn(async move {
            debug!("Worker {} has got #{} files", idx, batch.len());
            for upload in batch {
                debug!("Starting {:?}", upload);
                client
                    .upload_artifact(&upload.local, &upload.remote)
                    .await?;
            }
            Ok(())
        });
    }
    while let Some(result) = tasks.join_next().await {
        result
            // FIX: message said "download task" in the upload path.
            .context("We cannot join upload task")?
            .context("Upload task has returned an error")?;
    }
    Ok(())
}
/// Upload the file at `source` as a run artifact, retrying with
/// exponential backoff on retryable failures.
pub async fn upload_artifact(
    &self,
    source: impl AsRef<Path>,
    artifact: &Artifact,
) -> Result<()> {
    let future = || async {
        self.upload_artifact_no_retry(source.as_ref(), artifact)
            .await
    };
    future
        .retry(ExponentialBuilder::default())
        .when(|e| is_retryable(&self.mlflow.no_retry_status_codes, e))
        .notify(|error, duration| {
            // FIX: arguments were swapped (error logged as the delay).
            warn!(
                "Retrying upload in {:?}, got an error: {:?}",
                duration, error
            );
        })
        .await
        .with_context(|| format!("Cannot upload: {:?}", artifact))
}
/// Download a single artifact to `download.path()` (single attempt, no
/// retry); returns the number of bytes written.
///
/// When `cache_local_artifacts` is set and the destination already exists,
/// the download is skipped and the existing file's size is returned.
pub async fn download_artifact_no_retry(&self, download: &DownloadRunArtifact) -> Result<u64> {
    let path = format!(
        "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
        download.experiment_id, download.run_id, download.file,
    );
    // Make sure the destination directory exists before creating the file.
    if let Some(directory) = download.path().parent() {
        create_dir_all(directory).await.with_context(|| {
            format!(
                "Unable to create a directory {:?} for the artifact",
                directory
            )
        })?;
    }
    // Cache hit: report the cached file's on-disk size instead of fetching.
    if self.download.cache_local_artifacts && download.path().exists() {
        let path = download.path();
        info!("Path {path:?} does exists and caching is enabled, skipping download");
        let file = File::open(&path)
            .await
            .with_context(|| format!("Unable to open cached artifact: {path:?} for reading"))?;
        return file
            .metadata()
            .await
            .with_context(|| format!("Unable to get metadata for: {path:?}"))
            .map(|metadata| metadata.len());
    }
    let mut file = File::create(&download.path())
        .await
        .with_context(|| format!("Cannot create artifact file: {:?}", download.path()))?;
    let mut request = self.client.get(self.make_url(path));
    request = self.assign_basic_auth(request);
    let mut response = Response(request.send().await?).error_for_status().await?.0;
    debug!(
        "Content-Length of {:?} is {:?}",
        download.file,
        response.content_length()
    );
    // Stream the body chunk by chunk, counting bytes actually written.
    let mut size = 0;
    while let Some(chunk) = response
        .chunk()
        .await
        .context("Cannot read artifact data")?
    {
        file.write_all(&chunk)
            .await
            .context("Cannot write data to artifact file")?;
        size += chunk.len() as u64;
    }
    // Flush buffered data, then fsync so the size check sees real bytes.
    file.flush()
        .await
        .context("Unable to flush artifact file to disk")?;
    file.sync_all()
        .await
        .context("Unable to write file metadata to disk")?;
    Ok(size)
}
/// Download a single artifact with retries; when `file_size_check` is on,
/// a size mismatch removes the partial file and triggers a retry.
pub async fn download_artifact(&self, download: &DownloadRunArtifact) -> Result<u64> {
    let future = || async {
        let size = self.download_artifact_no_retry(download).await?;
        if self.download.file_size_check && size != download.expected_size {
            // Drop the corrupt/partial file so a retry starts clean.
            fs::remove_file(download.path()).await?;
            // FIX: the message arguments were swapped — the downloaded size
            // was reported as the expected one and vice versa.
            Err(anyhow!(
                "Expected size {} and downloaded {} do not match for {}",
                download.expected_size,
                size,
                download.file,
            ))
        } else {
            Ok(size)
        }
    };
    future
        .retry(ExponentialBuilder::default())
        .when(|e| is_retryable(&self.mlflow.no_retry_status_codes, e))
        .notify(|error, duration| {
            // FIX: arguments were swapped (error logged as the delay).
            warn!(
                "Retrying download in {:?}, got an error: {:?}",
                duration, error
            );
        })
        .await
        .with_context(|| format!("Cannot download: {:?}", download))
}
}
#[cfg(test)]
mod test {
use super::*;
use redact::Secret;
use rstest::{fixture, rstest};
use std::path::PathBuf;
use tracing_test::traced_test;
// Fixture: random run name so parallel tests never collide.
#[fixture]
fn run_name() -> String {
    format!("run-{}", rand::random::<u64>())
}
// Fixture: random registered-model name, unique per test invocation.
#[fixture]
fn model_name() -> String {
    format!("registered-model-{}", rand::random::<u64>())
}
// Fixture: random experiment name, unique per test invocation.
#[fixture]
fn experiment_name() -> String {
    format!("experiment-{}", rand::random::<u64>())
}
// Fixture: client pointed at the test MLflow server — host "mlflow" on CI,
// localhost otherwise. Credentials match the test server setup.
#[fixture]
fn client() -> Client {
    let host = match std::env::var("CI") {
        Ok(_) => "mlflow",
        Err(_) => "localhost",
    };
    Client::new(format!("http://{}:5000/api", host)).with_auth(BasicAuth {
        user: Secret::from("kokot"),
        password: Secret::from("bezpeci2021"),
    })
}
// The Debug representation must not leak basic-auth secrets
// (credentials are wrapped in redact::Secret).
#[rstest]
fn test_client_debug(client: Client) {
    let formatted = format!("{client:?}");
    assert!(!formatted.contains("kokot"));
    assert!(!formatted.contains("bezpeci2021"));
}
// get_or_create_experiment must succeed for a name that does not exist yet.
#[rstest]
#[tokio::test]
async fn test_get_or_create_experiment(experiment_name: String, client: Client) {
    client
        .get_or_create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
}
// Full experiment lifecycle: create, fetch, delete, and verify that the
// deleted name cannot be reused.
#[rstest]
#[tokio::test]
async fn test_create_get_delete_experiment(experiment_name: String, client: Client) {
    let id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let experiment = client
        .get_experiment(&id)
        .await
        .expect("BUG: Cannot get experiment");
    assert_eq!(experiment.name, experiment_name);
    client
        .delete_experiment(&id)
        .await
        // FIX: message said "update" although this call deletes.
        .expect("BUG: Cannot delete experiment");
    client
        .create_experiment(&experiment_name, vec![])
        .await
        // FIX: expect_err's message is shown when the call unexpectedly
        // succeeds, so it must describe that failure mode.
        .expect_err("BUG: Created an experiment under a deleted name");
}
// Looking an experiment up by name must return the id it was created with.
#[rstest]
#[tokio::test]
async fn test_get_experiment_by_name(experiment_name: String, client: Client) {
    let experiment_id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let experiment = client
        .get_experiment_by_name(&experiment_name)
        .await
        // FIX: typo "crated" -> "created".
        .expect("BUG: Cannot get created experiment");
    assert_eq!(experiment.experiment_id, experiment_id);
}
// Searching by experiment id must find the fixture run, both with and
// without an explicit max_results limit.
#[rstest]
#[tokio::test]
#[awt]
async fn test_run_search(client: Client, #[future] run: Run) {
    let search = SearchRuns::new()
        .experiment_ids(vec![run.info.experiment_id.clone()])
        .max_results(Some(16))
        .view(ViewType::All)
        .build();
    let runs = client
        .search_runs(search)
        .await
        .expect("BUG: Cannot search runs");
    assert!(!runs.is_empty());
    // Same search without max_results.
    let search = SearchRuns::new()
        .experiment_ids(vec![run.info.experiment_id.clone()])
        .view(ViewType::All)
        .build();
    let runs = client
        .search_runs(search)
        .await
        .expect("BUG: Cannot search runs");
    assert!(!runs.is_empty());
}
// A created run must round-trip: get_run returns exactly what create_run did.
#[rstest]
#[tokio::test]
async fn test_run_create(experiment_name: String, run_name: String, client: Client) {
    let experiment_id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let create = CreateRun::new()
        .run_name(&run_name)
        .experiment_id(&experiment_id)
        .build();
    let run = client
        .create_run(create)
        .await
        .expect("BUG: Cannot create run");
    let run1 = client
        .get_run(&run.info.run_id)
        .await
        .expect("BUG: Cannot get run");
    assert_eq!(run, run1);
}
// Deleting a run must succeed; the second delete asserts the current
// server behavior for an already-deleted run.
// NOTE(review): the second expect asserts that deleting an already-deleted
// run *succeeds*, while its message suggests it should not — confirm the
// intended semantics.
#[rstest]
#[tokio::test]
async fn test_run_delete(experiment_name: String, run_name: String, client: Client) {
    let experiment_id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let create = CreateRun::new()
        .run_name(&run_name)
        .experiment_id(&experiment_id)
        .build();
    let run = client
        .create_run(create)
        .await
        .expect("BUG: Cannot create run");
    client
        .delete_run(&run.info.run_id)
        .await
        .expect("BUG: Cannot delete run");
    client
        .delete_run(&run.info.run_id)
        .await
        .expect("BUG: Can delete non-existing run");
}
// log-batch metadata must be visible on a subsequent get_run; the server
// additionally prepends the mlflow.runName tag, which the expectation has
// to account for.
#[rstest]
#[tokio::test]
async fn test_run_update_data(experiment_name: String, run_name: String, client: Client) {
    let experiment_id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let create = CreateRun::new()
        .run_name(&run_name)
        .experiment_id(&experiment_id)
        .build();
    let run = client
        .create_run(create)
        .await
        .expect("BUG: Cannot create run");
    let mut update = RunData::new()
        .metrics(vec![Metric::new().key("m").value(37.).step(0).build()])
        .params(vec![KeyValue::new().key("p").value("42").build()])
        .tags(vec![KeyValue::new().key("t").value("73").build()])
        .build();
    client
        .add_run_meta(&run.info.run_id, update.clone())
        .await
        .expect("BUG: Cannot update run data");
    // The server injects the run-name tag first; mirror that locally so the
    // comparison below matches.
    update.tags.insert(
        0,
        KeyValue::new()
            .key("mlflow.runName")
            .value(run_name)
            .build(),
    );
    let run1 = client
        .get_run(&run.info.run_id)
        .await
        .expect("BUG: Cannot get run");
    assert_eq!(run1.data, update);
}
// Dataset inputs logged via log-inputs must come back on get_run.
#[rstest]
#[tokio::test]
async fn test_run_add_inputs(experiment_name: String, run_name: String, client: Client) {
    let experiment_id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let create = CreateRun::new()
        .run_name(&run_name)
        .experiment_id(&experiment_id)
        .build();
    let run = client
        .create_run(create)
        .await
        .expect("BUG: Cannot create run");
    let input = DataSetInput::new()
        .tags(vec![KeyValue::new().key("a").value("x").build()])
        .dataset(
            DataSet::new()
                .name("kokot")
                .digest("123")
                .source_type("kokot1")
                .source("s3")
                .schema("{}")
                .profile("{\"rows\": 22}")
                .build(),
        )
        .build();
    client
        .add_run_inputs(&run.info.run_id, vec![input.clone()])
        .await
        .expect("BUG: Unable to add inputs to run");
    let run1 = client
        .get_run(&run.info.run_id)
        .await
        .expect("BUG: Cannot get run");
    assert_eq!(
        run1.inputs,
        RunInputs {
            inputs: vec![input]
        }
    );
}
// Async fixture: a fresh run inside a fresh experiment.
#[fixture]
async fn run(experiment_name: String, run_name: String, client: Client) -> Run {
    let experiment_id = client
        .create_experiment(&experiment_name, vec![])
        .await
        .expect("BUG: Cannot create experiment");
    let create = CreateRun::new()
        .run_name(&run_name)
        .experiment_id(&experiment_id)
        .build();
    client
        .create_run(create)
        .await
        .expect("BUG: Cannot create run")
}
// update_run must persist status and end time (milliseconds since epoch).
#[rstest]
#[tokio::test]
#[awt]
async fn test_run_update(client: Client, #[future] run: Run) {
    let end_time_ms: u64 = 1733565663000;
    let update = UpdateRun::new()
        .run_id(&run.info.run_id)
        .status(RunStatus::Killed)
        // Builder takes i64; the value fits, so the cast is lossless.
        .end_time(end_time_ms as i64)
        .experiment_id(&run.info.experiment_id)
        .build();
    client
        .update_run(update)
        .await
        .expect("BUG: Unable to update run");
    let run1 = client
        .get_run(&run.info.run_id)
        .await
        .expect("BUG: Unable to fetch run");
    assert_eq!(run1.info.end_time, Some(end_time_ms));
    assert_eq!(run1.info.status, RunStatus::Killed);
}
// End-to-end artifact flow: upload two files, list them, download them and
// check the resulting local paths.
#[rstest]
#[tokio::test]
#[awt]
#[traced_test]
async fn test_artifacts(client: Client, #[future] run: Run) {
    let artifact = Artifact {
        experiment_id: run.info.experiment_id.clone(),
        run_id: run.info.run_id.clone(),
        path: PathBuf::from("abc/lock"),
    };
    let mut artifact1 = artifact.clone();
    artifact1.path = PathBuf::from("abc/lock1");
    // Cargo.lock is used simply as a file guaranteed to exist in the repo.
    let artifacts = vec![
        UploadArtifact {
            local: "Cargo.lock".into(),
            remote: artifact.clone(),
        },
        UploadArtifact {
            local: "Cargo.lock".into(),
            remote: artifact1.clone(),
        },
    ];
    client
        .upload_artifacts(artifacts)
        .await
        .expect("BUG: Unable to upload artifacts");
    let list = client
        .list_run_artifacts(&run.info.run_id)
        .await
        .expect("BUG: Unable to list run artifacts");
    assert_eq!(list.files.len(), 2);
    let download = client
        .prepare_run_download(&run.info.run_id, "/tmp")
        .await
        .expect("BUG: Unable to prepare run download");
    let mut artifacts = client
        .download_run_artifacts(download)
        .await
        .expect("BUG: Cannot download artifact");
    // Download order is nondeterministic (parallel workers); sort first.
    artifacts.paths.sort();
    assert_eq!(
        artifacts.paths,
        vec![
            PathBuf::from("/tmp/abc/lock"),
            PathBuf::from("/tmp/abc/lock1"),
        ]
    );
}
// With cache_local_artifacts on, an existing local file must short-circuit
// the download: the empty dummy file stays empty (size 0) instead of being
// overwritten by the real artifact.
#[rstest]
#[tokio::test]
#[awt]
async fn test_artifacts_cache(mut client: Client, #[future] run: Run) {
    let artifact = Artifact {
        experiment_id: run.info.experiment_id.clone(),
        run_id: run.info.run_id.clone(),
        path: PathBuf::from("run.json"),
    };
    client
        .upload_json_artifact(&run, &artifact)
        .await
        .expect("BUG: Cannot upload artifact");
    client.download.cache_local_artifacts = true;
    let filedir = PathBuf::from(format!("/tmp/{}/", run.info.run_id));
    let filepath = filedir.join("run.json");
    create_dir_all(&filedir)
        .await
        .expect("BUG: Unable to create destdir");
    // Pre-create an empty file at the download destination.
    let file = File::create(&filepath)
        .await
        .expect("BUG: Unable to create dummy file");
    let download = client
        .prepare_run_download(&run.info.run_id, filedir)
        .await
        .expect("BUG: Unable to prepare run download");
    let artifacts = client
        .download_run_artifacts(download)
        .await
        .expect("BUG: Cannot download artifact");
    assert_eq!(artifacts.paths, vec![filepath]);
    // Still zero bytes => the cached file was not overwritten.
    assert_eq!(
        file.metadata()
            .await
            .expect("BUG: Unable to get file metadata")
            .len(),
        0
    );
}
// A JSON artifact must round-trip: upload the run as JSON, download it,
// and get back an equal value.
#[rstest]
#[tokio::test]
#[awt]
async fn test_json_artifact(client: Client, #[future] run: Run) {
    let artifact = Artifact {
        experiment_id: run.info.experiment_id.clone(),
        run_id: run.info.run_id.clone(),
        path: PathBuf::from("run.json"),
    };
    client
        .upload_json_artifact(&run, &artifact)
        .await
        .expect("BUG: Cannot upload artifact");
    let run1 = client
        .download_json_artifact::<Run>(&artifact)
        .await
        .expect("BUG: Cannot download artifact");
    assert_eq!(run, run1);
}
// Register, fetch and delete a model; name and creation timestamp must
// round-trip.
#[rstest]
#[tokio::test]
#[awt]
async fn test_registered_model(client: Client, model_name: String) {
    let registered = client
        .register_model(
            RegisterModel::new()
                .name(&model_name)
                .description("")
                .build(),
        )
        .await
        .expect("BUG: Unable to register model");
    let registered1 = client
        .get_registered_model(&model_name)
        .await
        .expect("BUG: Cannot get registered model");
    assert_eq!(registered.name, registered1.name);
    assert_eq!(
        registered.creation_timestamp,
        registered1.creation_timestamp
    );
    // FIX: dropped the redundant `let _ =` — `expect` already returns `()`.
    client
        .delete_registered_model(&model_name)
        .await
        .expect("BUG: Unable to delete registered model");
}
// Searching by tag must find the model, and its latest version must point
// at the run the version was created for.
#[rstest]
#[tokio::test]
#[awt]
async fn test_search_registered_models(client: Client, #[future] run: Run, model_name: String) {
    let _registered = client
        .register_model(
            RegisterModel::new()
                .name(&model_name)
                .description("yep")
                // Tag value is the random model name so the filter below is
                // specific to this test instance.
                .tags(vec![KeyValue::new()
                    .key("kokot")
                    .value(&model_name)
                    .build()])
                .build(),
        )
        .await
        .expect("BUG: Unable to register model");
    let create = CreateModelVersion::new()
        .registered_model_name(&model_name)
        .artifacts_url("s3:///kokot")
        .run_id(&run.info.run_id)
        .description("xxx")
        .build();
    client
        .create_model_version(create)
        .await
        .expect("BUG: Cannot create model version");
    let mut models = client
        .search_registered_models(&format!("tags.kokot like '({}|abc)'", model_name))
        .await
        .expect("BUG: Cannot search registered models");
    let mut model = models
        .pop()
        .expect("BUG: We must get at least one registred model");
    let latest = model
        .latest_versions
        .pop()
        .expect("BUG: Model must have at least one latest version");
    assert_eq!(latest.run_id, run.info.run_id);
}
// Two versions of the same registered model (from two distinct runs):
// a fresh version starts in stage "None" and can be transitioned to
// "Production".
#[rstest]
#[tokio::test]
#[awt]
async fn test_add_model_version(
    client: Client,
    #[future] run: Run,
    // Second, independent run produced by the same fixture.
    #[future]
    #[from(run)]
    run1: Run,
    model_name: String,
) {
    let _registered = client
        .register_model(
            RegisterModel::new()
                .name(&model_name)
                .description("yep")
                .build(),
        )
        .await
        .expect("BUG: Unable to register model");
    let create = CreateModelVersion::new()
        .registered_model_name(&model_name)
        .artifacts_url("s3:///kokot")
        .run_id(&run.info.run_id)
        .description("xxx")
        .build();
    let mut create1 = create.clone();
    create1.run_id = run1.info.run_id;
    client
        .create_model_version(create)
        .await
        .expect("BUG: Cannot create model version");
    let version = client
        .create_model_version(create1)
        .await
        .expect("BUG: Cannot create model version");
    assert_eq!(version.current_stage, "None");
    let transition = TransitionModelVersionStage::new()
        .name(&model_name)
        .version(version.version)
        .stage(ModelVersionStage::Production)
        .archive_existing_versions(false)
        .build();
    let version1 = client
        .transition_model_version_stage(transition)
        .await
        .expect("BUG: Cannot transition model stage");
    assert_eq!(version1.current_stage, "Production");
}
}