use crate::{
Annotation, Error, Sample, Task,
api::{
AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
Project, ProjectID, SampleID, SamplesCountResult, SamplesListParams, SamplesListResult,
Snapshot, SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID,
SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus,
TasksListParams, TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
ValidationSessionID,
},
dataset::{
AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
},
retry::{create_retry_policy, log_retry_configuration},
storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
};
use base64::Engine as _;
use chrono::{DateTime, Utc};
use directories::ProjectDirs;
use futures::{StreamExt as _, future::join_all};
use log::{Level, debug, error, log_enabled, trace, warn};
use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::{
collections::HashMap,
ffi::OsStr,
fs::create_dir_all,
io::{SeekFrom, Write as _},
path::{Path, PathBuf},
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
vec,
};
use tokio::{
fs::{self, File},
io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
sync::{RwLock, Semaphore, mpsc::Sender},
};
use tokio_util::codec::{BytesCodec, FramedRead};
use walkdir::WalkDir;
#[cfg(feature = "polars")]
use polars::prelude::*;
/// Size of one part in bytes (100 MiB).
// NOTE(review): presumably the chunk size for multipart uploads — confirm against the upload code.
static PART_SIZE: usize = 100 * 1024 * 1024;
/// Origin of a file's contents: either a filesystem path or an in-memory buffer.
#[derive(Clone)]
enum FileSource {
/// The file is identified by a path.
Path(PathBuf),
/// The file contents are held directly as bytes.
Bytes(Vec<u8>),
}
/// Returns the maximum number of concurrent tasks.
///
/// The `MAX_TASKS` environment variable takes precedence when it parses as a
/// positive integer. Otherwise the limit is half of the available parallelism
/// (4 CPUs are assumed when it cannot be queried), clamped to `2..=8`.
///
/// A value of `0` is treated like an unset variable: the result is used to
/// size a semaphore, and a semaphore with zero permits would make every task
/// wait forever.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        // Zero permits can never make progress; fall back to the default.
        .filter(|&limit| limit > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
/// Returns the maximum number of concurrent upload tasks.
///
/// The `MAX_UPLOAD_TASKS` environment variable takes precedence when it parses
/// as a positive integer; otherwise the limit is 8.
///
/// A value of `0` is treated like an unset variable, matching `max_tasks`: a
/// concurrency limit of zero leaves no capacity to run any upload.
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        // A zero limit cannot make progress; fall back to the default.
        .filter(|&limit| limit > 0)
        .unwrap_or(8)
}
/// Keeps the items whose name contains `filter` (case-insensitively) and
/// orders them by match quality.
///
/// Ordering, most significant first:
/// 1. names equal to `filter` exactly,
/// 2. names equal to `filter` ignoring case,
/// 3. shorter names,
/// 4. lexicographic order of the name.
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let needle = filter.to_lowercase();

    let mut matches: Vec<T> = Vec::new();
    for item in items {
        if get_name(&item).to_lowercase().contains(&needle) {
            matches.push(item);
        }
    }

    matches.sort_by(|lhs, rhs| {
        let (left, right) = (get_name(lhs), get_name(rhs));
        // `true` must sort before `false`, hence the reversed comparisons.
        (right == filter)
            .cmp(&(left == filter))
            .then_with(|| (right.to_lowercase() == needle).cmp(&(left.to_lowercase() == needle)))
            .then_with(|| left.len().cmp(&right.len()))
            .then_with(|| left.cmp(right))
    });

    matches
}
/// Turns an arbitrary name into a string that is safe to use as a single path
/// component.
///
/// * Surrounding whitespace is trimmed.
/// * Only the final path component is kept when the name looks like a path.
/// * The characters `/ \ : * ? " < > |` are replaced with `_`.
/// * Empty results, and the special components `.` and `..`, become
///   `"unnamed"`, so joining the result onto a directory can never refer to
///   that directory itself or escape to its parent.
fn sanitize_path_component(name: &str) -> String {
    const FALLBACK: &str = "unnamed";

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return FALLBACK.to_string();
    }

    // `file_name()` is `None` for names such as ".." or "/"; fall back to the
    // raw text and let the checks below neutralise it.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    match sanitized.as_str() {
        // "." and ".." are directory references, not file names.
        "" | "." | ".." => FALLBACK.to_string(),
        _ => sanitized,
    }
}
/// Progress report sent over a channel by long-running operations.
#[derive(Debug, Clone)]
pub struct Progress {
/// Number of items processed so far.
pub current: usize,
/// Total number of items expected.
pub total: usize,
/// Optional label for the current phase (for example `"Downloading"`).
pub status: Option<String>,
}
/// Serializable JSON-RPC request envelope, generic over the parameter type.
#[derive(Serialize)]
struct RpcRequest<Params> {
/// Request identifier.
id: u64,
/// Protocol version string; the `Default` impl sets it to `"2.0"`.
jsonrpc: String,
/// Name of the remote method.
method: String,
/// Method parameters, or `None` when the call takes none.
params: Option<Params>,
}
impl<T> Default for RpcRequest<T> {
fn default() -> Self {
RpcRequest {
id: 0,
jsonrpc: "2.0".to_string(),
method: "".to_string(),
params: None,
}
}
}
/// Error object carried in an RPC response.
#[derive(Deserialize)]
struct RpcError {
/// Numeric error code.
code: i32,
/// Human-readable error description.
message: String,
}
/// Deserializable RPC response envelope, generic over the result type.
///
/// Both `error` and `result` are optional in the wire format.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
/// Response identifier (deserialized as a string; currently unused).
#[allow(dead_code)]
id: String,
/// Protocol version string (currently unused).
#[allow(dead_code)]
jsonrpc: String,
/// Error details, if the response carries any.
error: Option<RpcError>,
/// Result payload, if the response carries one.
result: Option<RpcResult>,
}
/// Result type with no fields, for responses whose payload is an empty object.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
/// Serializable parameters naming a snapshot and a list of keys.
// NOTE(review): presumably the request for a non-multipart snapshot creation call — confirm; marked dead_code here.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
/// Name of the snapshot.
snapshot_name: String,
/// Keys associated with the snapshot.
keys: Vec<String>,
}
/// Deserializable result carrying a snapshot id and a list of URLs.
// NOTE(review): presumably the URLs are upload targets, one per key — confirm against the server API.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
/// Identifier of the snapshot.
snapshot_id: SnapshotID,
/// URLs returned with the snapshot.
urls: Vec<String>,
}
/// Serializable parameters for a multipart snapshot creation request.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
/// Name of the snapshot.
snapshot_name: String,
/// Keys of the files in the snapshot.
keys: Vec<String>,
/// File sizes.
// NOTE(review): presumably in bytes and parallel to `keys` — confirm.
file_sizes: Vec<usize>,
/// Optional snapshot type; serialized as `type` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
snapshot_type: Option<String>,
}
/// One value in a multipart-creation result: either a bare integer id or a
/// part description.
///
/// The enum is `untagged`, so serde picks the first variant whose shape
/// matches the incoming value.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
/// A numeric identifier.
Id(u64),
/// Upload details for one file.
Part(SnapshotPart),
}
/// Serializable parameters for completing a multipart upload.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
/// Key of the uploaded file.
key: String,
/// Identifier of the multipart upload.
upload_id: String,
/// ETag and part number of each uploaded part.
etag_list: Vec<EtagPart>,
}
/// ETag and part number of one uploaded part.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
/// ETag of the part; serialized as `ETag`.
#[serde(rename = "ETag")]
etag: String,
/// Number of the part; serialized as `PartNumber`.
#[serde(rename = "PartNumber")]
part_number: usize,
}
/// Deserializable upload details for one file of a multipart snapshot.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
/// Key of the file, when present in the response.
key: Option<String>,
/// Identifier of the multipart upload.
upload_id: String,
/// URLs for the upload.
// NOTE(review): presumably one URL per part, in part order — confirm.
urls: Vec<String>,
}
/// Serializable parameters pairing a snapshot id with a status string.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
/// Identifier of the snapshot.
snapshot_id: SnapshotID,
/// Status value sent for the snapshot.
status: String,
}
/// Deserializable snapshot record returned with a status response.
///
/// All fields are currently unused by this module (hence `dead_code`).
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
/// Identifier of the snapshot.
#[allow(dead_code)]
pub id: SnapshotID,
/// `uid` field of the snapshot record.
#[allow(dead_code)]
pub uid: String,
/// Description of the snapshot.
#[allow(dead_code)]
pub description: String,
/// Date of the snapshot, kept as the raw string from the response.
#[allow(dead_code)]
pub date: String,
/// Status of the snapshot.
#[allow(dead_code)]
pub status: String,
}
/// Serializable parameters for an image listing request (marked dead_code here).
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
/// Filter selecting the images.
images_filter: ImagesFilter,
/// String-to-string filter applied to image files.
// NOTE(review): key/value semantics are not visible here — confirm against the server API.
image_files_filter: HashMap<String, String>,
/// Flag named `only_ids`; presumably limits the response to identifiers — confirm.
only_ids: bool,
}
/// Serializable image filter restricted to a single dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
/// Dataset whose images are selected.
dataset_id: DatasetID,
}
/// Client for the EdgeFirst Studio server.
///
/// The token is stored behind an `Arc`, so a plain `clone()` shares the same
/// token cell with the original; the `with_*` builders install a fresh one.
#[derive(Clone)]
pub struct Client {
/// HTTP client built with an overall request timeout.
http: reqwest::Client,
/// HTTP client built with a read timeout instead of an overall timeout.
// NOTE(review): presumably used for large uploads/downloads — confirm at the call sites.
bulk_http: reqwest::Client,
/// Base URL of the server.
url: String,
/// Current authentication token; empty when not logged in.
token: Arc<RwLock<String>>,
/// Token persistence backend, or `None` when persistence is disabled.
storage: Option<Arc<dyn TokenStorage>>,
/// Token file path (logged as "legacy" by `with_token_path`).
token_path: Option<PathBuf>,
}
/// Debug output lists the URL, whether a storage backend is configured, and
/// the token path. The token itself is never printed.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_storage = self.storage.is_some();
        let mut repr = f.debug_struct("Client");
        repr.field("url", &self.url);
        repr.field("has_storage", &has_storage);
        repr.field("token_path", &self.token_path);
        repr.finish()
    }
}
/// Filter settings and label lookup shared by the paginated fetch helpers.
struct FetchContext<'a> {
/// Dataset being listed.
dataset_id: DatasetID,
/// Annotation set to list against, if any.
annotation_set_id: Option<AnnotationSetID>,
/// Group names sent as the `group_names` filter.
groups: &'a [String],
/// Type names sent as the `types` filter.
types: Vec<String>,
/// Map from label name to label index.
labels: &'a HashMap<String, u64>,
}
impl Client {
/// Creates a client with default HTTP settings and any previously stored token.
///
/// * The request timeout comes from `EDGEFIRST_TIMEOUT` (seconds, default 30)
///   and the bulk read timeout from `EDGEFIRST_READ_TIMEOUT` (seconds,
///   default 120). Unset or unparsable values fall back to the default.
/// * Token storage is file-based; if it cannot be initialised, in-memory
///   storage is used instead.
/// * When a stored token is found, the server URL is derived from it;
///   otherwise (or if the token cannot be decoded) the default
///   `https://edgefirst.studio` is used.
///
/// # Errors
/// Returns an error if either HTTP client cannot be built.
pub fn new() -> Result<Self, Error> {
    const DEFAULT_URL: &str = "https://edgefirst.studio";

    log_retry_configuration();

    // Reads a number of seconds from the environment, falling back to `default`.
    let secs_from_env = |key: &str, default: u64| -> u64 {
        std::env::var(key)
            .ok()
            .and_then(|raw| raw.parse().ok())
            .unwrap_or(default)
    };
    let request_timeout = Duration::from_secs(secs_from_env("EDGEFIRST_TIMEOUT", 30));
    let bulk_read_timeout = Duration::from_secs(secs_from_env("EDGEFIRST_READ_TIMEOUT", 120));

    let http = reqwest::Client::builder()
        .connect_timeout(Duration::from_secs(10))
        .timeout(request_timeout)
        .pool_idle_timeout(Duration::from_secs(90))
        .pool_max_idle_per_host(10)
        .retry(create_retry_policy())
        .build()?;

    let bulk_http = reqwest::Client::builder()
        .connect_timeout(Duration::from_secs(30))
        .read_timeout(bulk_read_timeout)
        .pool_idle_timeout(Duration::from_secs(90))
        .pool_max_idle_per_host(10)
        .retry(create_retry_policy())
        .build()?;

    let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
        Ok(on_disk) => Arc::new(on_disk),
        Err(e) => {
            warn!(
                "Could not initialize file token storage: {}. Using memory storage.",
                e
            );
            Arc::new(MemoryTokenStorage::new())
        }
    };

    // A load failure is not fatal: start without a token.
    let token = storage
        .load()
        .unwrap_or_else(|e| {
            warn!(
                "Failed to load token from storage: {}. Starting with empty token.",
                e
            );
            None
        })
        .unwrap_or_default();

    let url = if token.is_empty() {
        DEFAULT_URL.to_string()
    } else {
        Self::extract_server_from_token(&token)
            .map(|server| format!("https://{}.edgefirst.studio", server))
            .unwrap_or_else(|e| {
                warn!(
                    "Failed to extract server from token: {}. Using default server.",
                    e
                );
                DEFAULT_URL.to_string()
            })
    };

    Ok(Client {
        http,
        bulk_http,
        url,
        token: Arc::new(tokio::sync::RwLock::new(token)),
        storage: Some(storage),
        token_path: None,
    })
}
pub fn with_server(&self, server: &str) -> Result<Self, Error> {
let url = match server {
"" | "saas" => "https://edgefirst.studio".to_string(),
name => format!("https://{}.edgefirst.studio", name),
};
if let Some(ref storage) = self.storage
&& let Err(e) = storage.clear()
{
warn!(
"Failed to clear token from storage when changing servers: {}",
e
);
}
Ok(Client {
url,
token: Arc::new(tokio::sync::RwLock::new(String::new())),
..self.clone()
})
}
/// Replaces the token storage backend and loads the token it holds.
///
/// If the backend has no token, or loading fails (which is logged), the
/// client starts with an empty token. Any token path is cleared.
pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
    let token = storage
        .load()
        .unwrap_or_else(|e| {
            warn!(
                "Failed to load token from storage: {}. Starting with empty token.",
                e
            );
            None
        })
        .unwrap_or_default();

    Client {
        token: Arc::new(tokio::sync::RwLock::new(token)),
        storage: Some(storage),
        token_path: None,
        ..self
    }
}
pub fn with_memory_storage(self) -> Self {
Client {
token: Arc::new(tokio::sync::RwLock::new(String::new())),
storage: Some(Arc::new(MemoryTokenStorage::new())),
token_path: None,
..self
}
}
pub fn with_no_storage(self) -> Self {
Client {
storage: None,
token_path: None,
..self
}
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
let params = HashMap::from([("username", username), ("password", password)]);
let login: LoginResult = self
.rpc_without_auth("auth.login".to_owned(), Some(params))
.await?;
if login.token.is_empty() {
return Err(Error::EmptyToken);
}
if let Some(ref storage) = self.storage
&& let Err(e) = storage.store(&login.token)
{
warn!("Failed to persist token to storage: {}", e);
}
Ok(Client {
token: Arc::new(tokio::sync::RwLock::new(login.token)),
..self.clone()
})
}
pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
let token_path = match token_path {
Some(path) => path.to_path_buf(),
None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
.ok_or_else(|| {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::NotFound,
"Could not determine user config directory",
))
})?
.config_dir()
.join("token"),
};
debug!("Using token path (legacy): {:?}", token_path);
let token = match token_path.exists() {
true => std::fs::read_to_string(&token_path)?,
false => "".to_string(),
};
if !token.is_empty() {
match self.with_token(&token) {
Ok(client) => Ok(Client {
token_path: Some(token_path),
storage: None, ..client
}),
Err(e) => {
warn!(
"Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
token_path, e
);
if let Err(remove_err) = std::fs::remove_file(&token_path) {
warn!("Failed to remove corrupted token file: {:?}", remove_err);
}
Ok(Client {
token_path: Some(token_path),
storage: None,
token: Arc::new(RwLock::new("".to_string())),
..self.clone()
})
}
}
} else {
Ok(Client {
token_path: Some(token_path),
storage: None,
token: Arc::new(RwLock::new("".to_string())),
..self.clone()
})
}
}
fn extract_server_from_token(token: &str) -> Result<String, Error> {
let token_parts: Vec<&str> = token.split('.').collect();
if token_parts.len() != 3 {
return Err(Error::InvalidToken);
}
let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
.decode(token_parts[1])
.map_err(|_| Error::InvalidToken)?;
let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
let server = match payload.get("server") {
Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
None => return Err(Error::InvalidToken),
};
Ok(server)
}
pub fn with_token(&self, token: &str) -> Result<Self, Error> {
if token.is_empty() {
return Ok(self.clone());
}
let server = Self::extract_server_from_token(token)?;
if let Some(ref storage) = self.storage
&& let Err(e) = storage.store(token)
{
warn!("Failed to persist token to storage: {}", e);
}
Ok(Client {
url: format!("https://{}.edgefirst.studio", server),
token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
..self.clone()
})
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn save_token(&self) -> Result<(), Error> {
let token = self.token.read().await;
if let Some(ref storage) = self.storage {
storage.store(&token)?;
debug!("Token saved to storage");
return Ok(());
}
let path = self.token_path.clone().unwrap_or_else(|| {
ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
.map(|dirs| dirs.config_dir().join("token"))
.unwrap_or_else(|| PathBuf::from(".token"))
});
create_dir_all(path.parent().ok_or_else(|| {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Token path has no parent directory",
))
})?)?;
let mut file = std::fs::File::create(&path)?;
file.write_all(token.as_bytes())?;
debug!("Saved token to {:?}", path);
Ok(())
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn version(&self) -> Result<String, Error> {
let version: HashMap<String, String> = self
.rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
.await?;
let version = version.get("version").ok_or(Error::InvalidResponse)?;
Ok(version.to_owned())
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn logout(&self) -> Result<(), Error> {
{
let mut token = self.token.write().await;
*token = "".to_string();
}
if let Some(ref storage) = self.storage
&& let Err(e) = storage.clear()
{
warn!("Failed to clear token from storage: {}", e);
}
if let Some(path) = &self.token_path
&& path.exists()
{
fs::remove_file(path).await?;
}
Ok(())
}
/// Returns a copy of the current token (empty when not logged in).
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn token(&self) -> String {
    let guard = self.token.read().await;
    (*guard).clone()
}
/// Checks the current token with the server through the `auth.verify_token`
/// RPC.
///
/// # Errors
/// Returns the RPC error if the call fails.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn verify_token(&self) -> Result<(), Error> {
    self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
        .await
        .map(|_| ())
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn renew_token(&self) -> Result<(), Error> {
let params = HashMap::from([("username".to_string(), self.username().await?)]);
let result: LoginResult = self
.rpc_without_auth("auth.refresh".to_owned(), Some(params))
.await?;
{
let mut token = self.token.write().await;
*token = result.token.clone();
}
if let Some(ref storage) = self.storage
&& let Err(e) = storage.store(&result.token)
{
warn!("Failed to persist renewed token to storage: {}", e);
}
if self.token_path.is_some() {
self.save_token().await?;
}
Ok(())
}
async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
let token = self.token.read().await;
if token.is_empty() {
return Err(Error::EmptyToken);
}
let token_parts: Vec<&str> = token.split('.').collect();
if token_parts.len() != 3 {
return Err(Error::InvalidToken);
}
let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
.decode(token_parts[1])
.map_err(|_| Error::InvalidToken)?;
let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
match payload.get(field) {
Some(value) => Ok(value.to_owned()),
None => Err(Error::InvalidToken),
}
}
/// Returns the base URL of the server this client talks to.
pub fn url(&self) -> &str {
    self.url.as_str()
}
/// Returns the server instance name derived from the URL.
///
/// For `https://<name>.edgefirst.studio` this is `<name>`. The default URL
/// `https://edgefirst.studio`, and any URL that does not follow the pattern,
/// yield `"saas"`.
pub fn server(&self) -> &str {
    // The bare default host has no ".edgefirst.studio" suffix to strip, so it
    // falls through to "saas" as well.
    self.url
        .strip_prefix("https://")
        .and_then(|host| host.strip_suffix(".edgefirst.studio"))
        .unwrap_or("saas")
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn username(&self) -> Result<String, Error> {
match self.token_field("username").await? {
serde_json::Value::String(username) => Ok(username),
_ => Err(Error::InvalidToken),
}
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
let ts = match self.token_field("exp").await? {
serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
_ => return Err(Error::InvalidToken),
};
match DateTime::<Utc>::from_timestamp(ts, 0) {
Some(dt) => Ok(dt),
None => Err(Error::InvalidToken),
}
}
/// Fetches the organization of the authenticated user through the `org.get`
/// RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn organization(&self) -> Result<Organization, Error> {
    let method = "org.get".to_owned();
    self.rpc::<(), Organization>(method, None).await
}
/// Lists projects through the `project.list` RPC.
///
/// With `name` set, only projects whose name contains it (ignoring case) are
/// returned, best matches first. Without it, the server's list is returned
/// as received.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
    let all = self
        .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
        .await?;
    Ok(match name {
        Some(filter) => filter_and_sort_by_name(all, filter, |p| p.name()),
        None => all,
    })
}
/// Fetches a single project by id through the `project.get` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
    let request = Some(HashMap::from([("project_id", project_id)]));
    self.rpc("project.get".to_owned(), request).await
}
/// Lists the datasets of a project through the `dataset.list` RPC.
///
/// With `name` set, only datasets whose name contains it (ignoring case) are
/// returned, best matches first. Without it, the server's list is returned
/// as received.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn datasets(
    &self,
    project_id: ProjectID,
    name: Option<&str>,
) -> Result<Vec<Dataset>, Error> {
    let request = HashMap::from([("project_id", project_id)]);
    let all: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(request)).await?;
    Ok(match name {
        Some(filter) => filter_and_sort_by_name(all, filter, |d| d.name()),
        None => all,
    })
}
/// Fetches a single dataset by id through the `dataset.get` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
    let request = Some(HashMap::from([("dataset_id", dataset_id)]));
    self.rpc("dataset.get".to_owned(), request).await
}
/// Lists the labels of a dataset through the `label.list` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
    let request = Some(HashMap::from([("dataset_id", dataset_id)]));
    self.rpc("label.list".to_owned(), request).await
}
/// Adds one label named `name` to a dataset through the `label.add2` RPC.
///
/// The string returned by the server is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
    let label = NewLabelObject {
        name: name.to_owned(),
    };
    let request = NewLabel {
        dataset_id,
        labels: vec![label],
    };
    let _reply: String = self.rpc("label.add2".to_owned(), Some(request)).await?;
    Ok(())
}
/// Deletes a label by id through the `label.del` RPC.
///
/// The string returned by the server is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
    let request = Some(HashMap::from([("label_id", label_id)]));
    let _reply: String = self.rpc("label.del".to_owned(), request).await?;
    Ok(())
}
/// Creates a dataset in a project through the `dataset.create` RPC and
/// returns its id.
///
/// `description` is sent only when provided.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn create_dataset(
    &self,
    project_id: &str,
    name: &str,
    description: Option<&str>,
) -> Result<DatasetID, Error> {
    /// Shape of the `dataset.create` reply; only the id is read.
    #[derive(Deserialize)]
    struct CreateDatasetResult {
        id: DatasetID,
    }

    let mut request = HashMap::from([("project_id", project_id), ("name", name)]);
    if let Some(text) = description {
        request.insert("description", text);
    }

    let created: CreateDatasetResult =
        self.rpc("dataset.create".to_owned(), Some(request)).await?;
    Ok(created.id)
}
/// Deletes a dataset by id through the `dataset.delete` RPC.
///
/// Whatever JSON value the server returns is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
    let request = Some(HashMap::from([("id", dataset_id)]));
    let _reply: serde_json::Value = self.rpc("dataset.delete".to_owned(), request).await?;
    Ok(())
}
/// Sends a label's current name and index to the server through the
/// `label.update` RPC.
///
/// The string returned by the server is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
    /// Wire format of the `label.update` request.
    #[derive(Serialize)]
    struct Params {
        dataset_id: DatasetID,
        label_id: u64,
        label_name: String,
        label_index: u64,
    }

    let request = Params {
        dataset_id: label.dataset_id(),
        label_id: label.id(),
        label_name: label.name().to_owned(),
        label_index: label.index(),
    };
    let _reply: String = self.rpc("label.update".to_owned(), Some(request)).await?;
    Ok(())
}
/// Lists the groups of a dataset through the `groups.list` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
    let request = Some(HashMap::from([("dataset_id", dataset_id)]));
    self.rpc("groups.list".to_owned(), request).await
}
/// Returns the id of the group called `name`, creating the group if it does
/// not exist yet.
///
/// The group is created through the `groups.create` RPC with a split value
/// of `0`. If the creation reply does not contain the group, the group list
/// is fetched once more before giving up.
///
/// # Errors
/// Returns the RPC error, or `Error::RpcError(0, ..)` if the group still
/// cannot be found after creation.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn get_or_create_group(
    &self,
    dataset_id: DatasetID,
    name: &str,
) -> Result<u64, Error> {
    /// Wire format of the `groups.create` request.
    #[derive(Serialize)]
    struct CreateGroupParams {
        dataset_id: DatasetID,
        group_names: Vec<String>,
        group_splits: Vec<i64>,
    }

    let existing = self.groups(dataset_id).await?;
    if let Some(id) = existing.iter().find(|g| g.name == name).map(|g| g.id) {
        return Ok(id);
    }

    let request = CreateGroupParams {
        dataset_id,
        group_names: vec![name.to_string()],
        group_splits: vec![0],
    };
    let created: Vec<Group> = self.rpc("groups.create".to_owned(), Some(request)).await?;

    match created.into_iter().find(|g| g.name == name) {
        Some(group) => Ok(group.id),
        None => {
            // The reply did not list the new group; look it up explicitly.
            let refreshed = self.groups(dataset_id).await?;
            let found = refreshed.iter().find(|g| g.name == name).map(|g| g.id);
            found.ok_or_else(|| {
                Error::RpcError(0, format!("Failed to create or find group '{}'", name))
            })
        }
    }
}
/// Assigns a sample to a group through the `image.set_group_id` RPC.
///
/// The sample id is sent under the `image_id` key. The string returned by
/// the server is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn set_sample_group_id(
    &self,
    sample_id: SampleID,
    group_id: u64,
) -> Result<(), Error> {
    /// Wire format of the `image.set_group_id` request.
    #[derive(Serialize)]
    struct SetGroupParams {
        image_id: SampleID,
        group_id: u64,
    }

    let request = Some(SetGroupParams {
        image_id: sample_id,
        group_id,
    });
    let _reply: String = self.rpc("image.set_group_id".to_owned(), request).await?;
    Ok(())
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
pub async fn download_dataset(
&self,
dataset_id: DatasetID,
groups: &[String],
file_types: &[FileType],
output: PathBuf,
flatten: bool,
progress: Option<Sender<Progress>>,
) -> Result<(), Error> {
let samples = self
.samples(dataset_id, None, &[], groups, file_types, progress.clone())
.await?;
fs::create_dir_all(&output).await?;
let total = samples.len();
let current = Arc::new(AtomicUsize::new(0));
let sem = Arc::new(Semaphore::new(max_tasks()));
if let Some(ref progress) = progress {
let _ = progress
.send(Progress {
current: 0,
total,
status: Some("Downloading".to_string()),
})
.await;
}
let tasks = samples
.into_iter()
.map(|sample| {
let client = self.clone();
let file_types = file_types.to_vec();
let output = output.clone();
let progress = progress.clone();
let current = current.clone();
let sem = sem.clone();
tokio::spawn(async move {
let _permit = sem.acquire().await.map_err(|_| {
Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
})?;
for file_type in &file_types {
if let Some(data) = sample.download(&client, file_type.clone()).await? {
let (file_ext, is_image) = match file_type {
FileType::Image => (
infer::get(&data)
.expect("Failed to identify image file format for sample")
.extension()
.to_string(),
true,
),
other => (other.file_extension().to_string(), false),
};
let sequence_dir = sample
.sequence_name()
.map(|name| sanitize_path_component(name));
let target_dir = if flatten {
output.clone()
} else {
sequence_dir
.as_ref()
.map(|seq| output.join(seq))
.unwrap_or_else(|| output.clone())
};
fs::create_dir_all(&target_dir).await?;
let sanitized_sample_name = sample
.name()
.map(|name| sanitize_path_component(&name))
.unwrap_or_else(|| "unknown".to_string());
let image_name = sample.image_name().map(sanitize_path_component);
let file_name = if is_image {
if let Some(img_name) = image_name {
Client::build_filename(
&img_name,
flatten,
sequence_dir.as_ref(),
sample.frame_number(),
)
} else {
format!("{}.{}", sanitized_sample_name, file_ext)
}
} else {
let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
Client::build_filename(
&base_name,
flatten,
sequence_dir.as_ref(),
sample.frame_number(),
)
};
let file_path = target_dir.join(&file_name);
let mut file = File::create(&file_path).await?;
file.write_all(&data).await?;
}
}
if let Some(progress) = &progress {
let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
let _ = progress
.send(Progress {
current: completed,
total,
status: Some("Downloading".to_string()),
})
.await;
}
Ok::<(), Error>(())
})
})
.collect::<Vec<_>>();
join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
Ok(())
}
/// Builds the output file name for a downloaded file.
///
/// Unless `flatten` is set and a sequence name is known, `base_name` is
/// returned unchanged. Otherwise the name is prefixed with `<sequence>_`
/// (followed by `<frame>_` when a frame number is known), unless `base_name`
/// already starts with `<sequence>_`.
fn build_filename(
    base_name: &str,
    flatten: bool,
    sequence_name: Option<&String>,
    frame_number: Option<u32>,
) -> String {
    let sequence = match sequence_name {
        Some(sequence) if flatten => sequence,
        _ => return base_name.to_string(),
    };

    let prefix = format!("{}_", sequence);
    if base_name.starts_with(&prefix) {
        return base_name.to_string();
    }

    frame_number.map_or_else(
        || format!("{}{}", prefix, base_name),
        |frame| format!("{}{}_{}", prefix, frame, base_name),
    )
}
/// Lists the annotation sets of a dataset through the `annset.list` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn annotation_sets(
    &self,
    dataset_id: DatasetID,
) -> Result<Vec<AnnotationSet>, Error> {
    let request = Some(HashMap::from([("dataset_id", dataset_id)]));
    self.rpc("annset.list".to_owned(), request).await
}
/// Creates an annotation set in a dataset through the `annset.add` RPC and
/// returns its id.
///
/// The `operator` field is filled with the username from the current token.
/// `description` is omitted from the request when `None`.
///
/// # Errors
/// Returns an error if the username cannot be read from the token or the RPC
/// fails.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn create_annotation_set(
    &self,
    dataset_id: DatasetID,
    name: &str,
    description: Option<&str>,
) -> Result<AnnotationSetID, Error> {
    /// Wire format of the `annset.add` request.
    #[derive(Serialize)]
    struct Params<'a> {
        dataset_id: DatasetID,
        name: &'a str,
        operator: &'a str,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<&'a str>,
    }

    /// Shape of the `annset.add` reply; only the id is read.
    #[derive(Deserialize)]
    struct CreateAnnotationSetResult {
        id: AnnotationSetID,
    }

    let operator = self.username().await?;
    let request = Params {
        dataset_id,
        name,
        operator: &operator,
        description,
    };
    let created: CreateAnnotationSetResult =
        self.rpc("annset.add".to_owned(), Some(request)).await?;
    Ok(created.id)
}
/// Deletes an annotation set by id through the `annset.delete` RPC.
///
/// Whatever JSON value the server returns is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
pub async fn delete_annotation_set(
    &self,
    annotation_set_id: AnnotationSetID,
) -> Result<(), Error> {
    let request = Some(HashMap::from([("id", annotation_set_id)]));
    let _reply: serde_json::Value = self.rpc("annset.delete".to_owned(), request).await?;
    Ok(())
}
/// Fetches a single annotation set by id through the `annset.get` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
pub async fn annotation_set(
    &self,
    annotation_set_id: AnnotationSetID,
) -> Result<AnnotationSet, Error> {
    let request = Some(HashMap::from([("annotation_set_id", annotation_set_id)]));
    self.rpc("annset.get".to_owned(), request).await
}
/// Fetches all annotations of an annotation set.
///
/// * `groups` restricts the samples to the named groups.
/// * `annotation_types` restricts the annotation types that are listed.
/// * `progress`, when given, receives an update after every fetched page.
///
/// The dataset is looked up from the annotation set, and its labels are used
/// to fill in the label index of each annotation. Returns an empty vector
/// when the server reports zero matching samples.
///
/// The type filter uses each type's server-side name (`as_server_type`), the
/// same encoding `samples_count` and `samples` use, so the count and the
/// listing always apply the same filter.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
pub async fn annotations(
    &self,
    annotation_set_id: AnnotationSetID,
    groups: &[String],
    annotation_types: &[AnnotationType],
    progress: Option<Sender<Progress>>,
) -> Result<Vec<Annotation>, Error> {
    let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();

    // Label name -> label index, used to resolve indices on each annotation.
    let labels = self
        .labels(dataset_id)
        .await?
        .into_iter()
        .map(|label| (label.name().to_string(), label.index()))
        .collect::<HashMap<_, _>>();

    let total = self
        .samples_count(
            dataset_id,
            Some(annotation_set_id),
            annotation_types,
            groups,
            &[],
        )
        .await?
        .total as usize;
    if total == 0 {
        return Ok(vec![]);
    }

    let context = FetchContext {
        dataset_id,
        annotation_set_id: Some(annotation_set_id),
        groups,
        types: annotation_types
            .iter()
            .map(|t| t.as_server_type().to_string())
            .collect(),
        labels: &labels,
    };

    self.fetch_annotations_paginated(context, total, progress)
        .await
}
/// Walks the `samples.list` RPC page by page and collects the annotations of
/// every returned sample.
///
/// Paging stops when a page is empty or when the server returns no (or an
/// empty) continuation token. After each non-empty page a progress update
/// with the running sample count is sent; send failures are ignored.
async fn fetch_annotations_paginated(
    &self,
    context: FetchContext<'_>,
    total: usize,
    progress: Option<Sender<Progress>>,
) -> Result<Vec<Annotation>, Error> {
    let mut collected = Vec::new();
    let mut continue_token: Option<String> = None;
    let mut fetched = 0;

    loop {
        let request = SamplesListParams {
            dataset_id: context.dataset_id,
            annotation_set_id: context.annotation_set_id,
            types: context.types.clone(),
            group_names: context.groups.to_vec(),
            continue_token,
        };
        let page: SamplesListResult =
            self.rpc("samples.list".to_owned(), Some(request)).await?;

        fetched += page.samples.len();
        continue_token = page.continue_token;

        if page.samples.is_empty() {
            break;
        }

        self.process_sample_annotations(&page.samples, context.labels, &mut collected);

        if let Some(tx) = progress.as_ref() {
            let update = Progress {
                current: fetched,
                total,
                status: None,
            };
            let _ = tx.send(update).await;
        }

        let has_more = continue_token
            .as_deref()
            .is_some_and(|token| !token.is_empty());
        if !has_more {
            break;
        }
    }

    drop(progress);
    Ok(collected)
}
/// Appends the annotations of `samples` to `annotations`, stamping each one
/// with its sample's id, name, sequence name, frame number, and group.
///
/// A sample without annotations contributes a single placeholder annotation
/// that carries only the sample metadata. For real annotations the label
/// index is also resolved from `labels`.
fn process_sample_annotations(
    &self,
    samples: &[Sample],
    labels: &HashMap<String, u64>,
    annotations: &mut Vec<Annotation>,
) {
    for sample in samples {
        // Copies the sample-level metadata onto one annotation.
        let stamp = |annotation: &mut Annotation| {
            annotation.set_sample_id(sample.id());
            annotation.set_name(sample.name());
            annotation.set_sequence_name(sample.sequence_name().cloned());
            annotation.set_frame_number(sample.frame_number());
            annotation.set_group(sample.group().cloned());
        };

        if sample.annotations().is_empty() {
            let mut placeholder = Annotation::new();
            stamp(&mut placeholder);
            annotations.push(placeholder);
            continue;
        }

        for original in sample.annotations() {
            let mut annotation = original.clone();
            stamp(&mut annotation);
            Self::set_label_index_from_map(&mut annotation, labels);
            annotations.push(annotation);
        }
    }
}
/// Deletes annotations of the given types from the given samples in one
/// `annotation.bulk.del` RPC.
///
/// Sample ids are sent under `image_ids`, and `delete_all` is left unset.
/// The string returned by the server is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
pub async fn delete_annotations_bulk(
    &self,
    annotation_set_id: AnnotationSetID,
    annotation_types: &[String],
    sample_ids: &[SampleID],
) -> Result<(), Error> {
    use crate::api::AnnotationBulkDeleteParams;

    let request = AnnotationBulkDeleteParams {
        annotation_set_id: annotation_set_id.into(),
        annotation_types: annotation_types.to_vec(),
        image_ids: sample_ids.iter().copied().map(Into::into).collect(),
        delete_all: None,
    };
    let _reply: String = self
        .rpc("annotation.bulk.del".to_owned(), Some(request))
        .await?;
    Ok(())
}
/// Adds many annotations to an annotation set in one `annotation.add_bulk`
/// RPC and returns the server's reply as raw JSON values.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
pub async fn add_annotations_bulk(
    &self,
    annotation_set_id: AnnotationSetID,
    annotations: Vec<crate::api::ServerAnnotation>,
) -> Result<Vec<serde_json::Value>, Error> {
    use crate::api::AnnotationAddBulkParams;

    let request = Some(AnnotationAddBulkParams {
        annotation_set_id: annotation_set_id.into(),
        annotations,
    });
    self.rpc("annotation.add_bulk".to_owned(), request).await
}
/// Parses the frame number out of an image name of the form
/// `<sequence>_<frame>.<ext>`.
///
/// Returns `None` when either name is missing, the file stem does not start
/// with `<sequence>_`, or the remainder is not a valid `u32`.
fn parse_frame_from_image_name(
    image_name: Option<&String>,
    sequence_name: Option<&String>,
) -> Option<u32> {
    let (sequence, name) = (sequence_name?, image_name?);
    let stem = std::path::Path::new(name).file_stem()?.to_str()?;
    let frame = stem.strip_prefix(sequence.as_str())?.strip_prefix('_')?;
    frame.parse::<u32>().ok()
}
/// Sets the label index of `annotation` from the name-to-index map.
///
/// Nothing happens when the annotation has no label. When the label is not
/// present in `labels` (for example because the label list and the
/// annotations are out of sync), the annotation's label index is left
/// untouched instead of panicking on the missing key.
fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
    let index = annotation
        .label()
        .and_then(|label| labels.get(label.as_str()).copied());
    if let Some(index) = index {
        annotation.set_label_index(Some(index));
    }
}
/// Counts the samples matching the given filters through the
/// `samples.count` RPC.
///
/// The type filter is the server-side name of each annotation type followed
/// by the string form of each file type.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
pub async fn samples_count(
    &self,
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    annotation_types: &[AnnotationType],
    groups: &[String],
    types: &[FileType],
) -> Result<SamplesCountResult, Error> {
    let mut type_names: Vec<String> = Vec::with_capacity(annotation_types.len() + types.len());
    type_names.extend(
        annotation_types
            .iter()
            .map(|t| t.as_server_type().to_string()),
    );
    type_names.extend(types.iter().map(|t| t.to_string()));

    let request = SamplesListParams {
        dataset_id,
        annotation_set_id,
        group_names: groups.to_vec(),
        types: type_names,
        continue_token: None,
    };
    self.rpc("samples.count".to_owned(), Some(request)).await
}
/// Fetches all samples of a dataset that match the given filters.
///
/// * `annotation_set_id` optionally selects the annotation set to list
///   against.
/// * `annotation_types` and `types` together form the type filter.
/// * `groups` restricts the samples to the named groups.
/// * `progress`, when given, is passed on to the paginated fetch.
///
/// Returns an empty vector when the server reports zero matching samples.
/// The count is requested without the file-type filter.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
pub async fn samples(
    &self,
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    annotation_types: &[AnnotationType],
    groups: &[String],
    types: &[FileType],
    progress: Option<Sender<Progress>>,
) -> Result<Vec<Sample>, Error> {
    let mut type_names: Vec<String> = Vec::with_capacity(annotation_types.len() + types.len());
    type_names.extend(
        annotation_types
            .iter()
            .map(|t| t.as_server_type().to_string()),
    );
    type_names.extend(types.iter().map(|t| t.to_string()));

    // Label name -> label index for the dataset.
    let mut labels = HashMap::new();
    for label in self.labels(dataset_id).await? {
        labels.insert(label.name().to_string(), label.index());
    }

    let count = self
        .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
        .await?;
    let total = count.total as usize;
    if total == 0 {
        return Ok(vec![]);
    }

    let context = FetchContext {
        dataset_id,
        annotation_set_id,
        groups,
        types: type_names,
        labels: &labels,
    };
    self.fetch_samples_paginated(context, total, progress).await
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn sample_names(
&self,
dataset_id: DatasetID,
groups: &[String],
progress: Option<Sender<Progress>>,
) -> Result<std::collections::HashSet<String>, Error> {
use std::collections::HashSet;
let total = self
.samples_count(dataset_id, None, &[], groups, &[])
.await?
.total as usize;
if total == 0 {
return Ok(HashSet::new());
}
let mut names = HashSet::with_capacity(total);
let mut continue_token: Option<String> = None;
let mut current = 0;
loop {
let params = SamplesListParams {
dataset_id,
annotation_set_id: None,
types: vec![], group_names: groups.to_vec(),
continue_token: continue_token.clone(),
};
let result: SamplesListResult =
self.rpc("samples.list".to_owned(), Some(params)).await?;
current += result.samples.len();
continue_token = result.continue_token;
if result.samples.is_empty() {
break;
}
for sample in result.samples {
if let Some(name) = sample.name() {
names.insert(name);
}
}
if let Some(ref p) = progress {
let _ = p
.send(Progress {
current,
total,
status: None,
})
.await;
}
match &continue_token {
Some(token) if !token.is_empty() => continue,
_ => break,
}
}
Ok(names)
}
/// Pages through `samples.list` for the filters in `context` and returns all
/// samples.
///
/// For every sample the frame number falls back to one parsed from the image
/// name when the server did not provide it, and each annotation is stamped
/// with the sample's name, group, sequence name, frame number and label
/// index. Paging stops on an empty page or an absent/empty continue token.
/// Progress updates are sent after each page; send failures are ignored.
async fn fetch_samples_paginated(
    &self,
    context: FetchContext<'_>,
    total: usize,
    progress: Option<Sender<Progress>>,
) -> Result<Vec<Sample>, Error> {
    let mut collected = Vec::new();
    let mut next_token: Option<String> = None;
    let mut fetched = 0;
    loop {
        let request = SamplesListParams {
            dataset_id: context.dataset_id,
            annotation_set_id: context.annotation_set_id,
            types: context.types.clone(),
            group_names: context.groups.to_vec(),
            // The token is replaced from the response right below.
            continue_token: next_token.take(),
        };
        let page: SamplesListResult =
            self.rpc("samples.list".to_owned(), Some(request)).await?;
        next_token = page.continue_token;
        if page.samples.is_empty() {
            break;
        }
        fetched += page.samples.len();

        for sample in page.samples {
            let frame_number = sample.frame_number.or_else(|| {
                Self::parse_frame_from_image_name(
                    sample.image_name.as_ref(),
                    sample.sequence_name.as_ref(),
                )
            });
            let mut annotations = sample.annotations().to_vec();
            for annotation in annotations.iter_mut() {
                annotation.set_name(sample.name());
                annotation.set_group(sample.group().cloned());
                annotation.set_sequence_name(sample.sequence_name().cloned());
                annotation.set_frame_number(frame_number);
                Self::set_label_index_from_map(annotation, context.labels);
            }
            collected.push(
                sample
                    .with_annotations(annotations)
                    .with_frame_number(frame_number),
            );
        }

        if let Some(sender) = progress.as_ref() {
            let update = Progress {
                current: fetched,
                total,
                status: None,
            };
            let _ = sender.send(update).await;
        }

        let has_more = next_token.as_deref().is_some_and(|token| !token.is_empty());
        if !has_more {
            break;
        }
    }
    // Release the sender before returning.
    drop(progress);
    Ok(collected)
}
/// Populates a dataset with `samples`.
///
/// Forwards to [`Self::populate_samples_with_concurrency`] with no explicit
/// concurrency value (`None`).
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
pub async fn populate_samples(
    &self,
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    samples: Vec<Sample>,
    progress: Option<Sender<Progress>>,
) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
    let concurrency: Option<usize> = None;
    let populated = self
        .populate_samples_with_concurrency(
            dataset_id,
            annotation_set_id,
            samples,
            progress,
            concurrency,
        )
        .await;
    populated
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
pub async fn populate_samples_with_concurrency(
&self,
dataset_id: DatasetID,
annotation_set_id: Option<AnnotationSetID>,
samples: Vec<Sample>,
progress: Option<Sender<Progress>>,
concurrency: Option<usize>,
) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
use crate::api::SamplesPopulateParams;
let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
let has_files_to_upload = !files_to_upload.is_empty();
let params = SamplesPopulateParams {
dataset_id,
annotation_set_id,
presigned_urls: Some(has_files_to_upload),
samples,
};
let results: Vec<crate::SamplesPopulateResult> = self
.rpc("samples.populate2".to_owned(), Some(params))
.await?;
if has_files_to_upload {
self.upload_sample_files(&results, files_to_upload, progress, concurrency)
.await?;
}
Ok(results)
}
/// Prepares samples for the populate request.
///
/// Every sample without a UUID receives a freshly generated v4 UUID, and
/// each of its files is passed through `process_sample_file`, which may
/// queue the file in `files_to_upload` and returns the file entry stored
/// back on the sample.
fn prepare_samples_for_upload(
    &self,
    samples: Vec<Sample>,
    files_to_upload: &mut Vec<(String, String, FileSource, String)>,
) -> Result<Vec<Sample>, Error> {
    let mut prepared = Vec::with_capacity(samples.len());
    for mut sample in samples {
        let sample_uuid = sample
            .uuid
            .get_or_insert_with(|| uuid::Uuid::new_v4().to_string())
            .clone();

        // Iterate over a copy so the sample itself can be mutated while its
        // files are processed.
        let original_files = sample.files.clone();
        let mut rewritten: Vec<crate::SampleFile> = Vec::with_capacity(original_files.len());
        for file in &original_files {
            rewritten.push(self.process_sample_file(
                file,
                &sample_uuid,
                &mut sample,
                files_to_upload,
            ));
        }
        sample.files = rewritten;
        prepared.push(sample);
    }
    Ok(prepared)
}
/// Decides how a single sample file is sent to the server.
///
/// * A file carrying both bytes and a filename is queued for upload from
///   memory under that filename.
/// * Otherwise, a filename that points at an existing regular file on disk
///   is queued for upload from that path under its base name.
/// * Anything else is returned as an unchanged clone.
///
/// For files of type `"image"`, a missing sample width or height is filled
/// in from the image data when `imagesize` can determine it.
fn process_sample_file(
    &self,
    file: &crate::SampleFile,
    sample_uuid: &str,
    sample: &mut Sample,
    files_to_upload: &mut Vec<(String, String, FileSource, String)>,
) -> crate::SampleFile {
    let kind = file.file_type().to_string();
    // The sample is not modified before the probes below, so this can be
    // evaluated once up front.
    let wants_dimensions =
        kind == "image" && (sample.width.is_none() || sample.height.is_none());

    // In-memory payload.
    if let Some(bytes) = file.bytes()
        && let Some(filename) = file.filename()
    {
        if wants_dimensions && let Ok(size) = imagesize::blob_size(bytes) {
            sample.width = Some(size.width as u32);
            sample.height = Some(size.height as u32);
        }
        let upload_name = filename.to_string();
        files_to_upload.push((
            sample_uuid.to_string(),
            kind.clone(),
            FileSource::Bytes(bytes.to_vec()),
            upload_name.clone(),
        ));
        return crate::SampleFile::with_filename(kind, upload_name);
    }

    // Local file on disk.
    if let Some(filename) = file.filename() {
        let local_path = std::path::Path::new(filename);
        if local_path.exists()
            && local_path.is_file()
            && let Some(basename) = local_path.file_name().and_then(|name| name.to_str())
        {
            if wants_dimensions && let Ok(size) = imagesize::size(local_path) {
                sample.width = Some(size.width as u32);
                sample.height = Some(size.height as u32);
            }
            let upload_name = basename.to_string();
            files_to_upload.push((
                sample_uuid.to_string(),
                kind.clone(),
                FileSource::Path(local_path.to_path_buf()),
                upload_name.clone(),
            ));
            return crate::SampleFile::with_filename(kind, upload_name);
        }
    }

    file.clone()
}
async fn upload_sample_files(
&self,
results: &[crate::SamplesPopulateResult],
files_to_upload: Vec<(String, String, FileSource, String)>,
progress: Option<Sender<Progress>>,
concurrency: Option<usize>,
) -> Result<(), Error> {
let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
for (uuid, _file_type, source, basename) in files_to_upload {
upload_map.insert((uuid, basename), source);
}
let http = self.bulk_http.clone();
let upload_tasks: Vec<_> = results
.iter()
.map(|result| (result.uuid.clone(), result.urls.clone()))
.collect();
parallel_foreach_items(
upload_tasks,
progress.clone(),
concurrency,
move |(uuid, urls)| {
let http = http.clone();
let upload_map = upload_map.clone();
async move {
for url_info in &urls {
if let Some(source) =
upload_map.get(&(uuid.clone(), url_info.filename.clone()))
{
match source {
FileSource::Path(path) => {
upload_file_to_presigned_url(
http.clone(),
&url_info.url,
path.clone(),
)
.await?;
}
FileSource::Bytes(bytes) => {
upload_bytes_to_presigned_url(
http.clone(),
&url_info.url,
bytes.clone(),
&url_info.filename,
)
.await?;
}
}
}
}
Ok(())
}
},
)
.await
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err(Error::InvalidParameters(format!(
"Invalid URL (must be absolute): {}",
url
)));
}
let resp = self.bulk_http.get(url).send().await?;
if !resp.status().is_success() {
return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
}
let bytes = resp.bytes().await?;
Ok(bytes.to_vec())
}
/// Fetches the matching samples and converts them into a polars
/// `DataFrame`.
///
/// `types` is forwarded to [`Self::samples`] as the annotation-type filter;
/// no file-type filter is applied.
#[cfg(feature = "polars")]
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn samples_dataframe(
    &self,
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    groups: &[String],
    types: &[AnnotationType],
    progress: Option<Sender<Progress>>,
) -> Result<DataFrame, Error> {
    let fetched = self
        .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
        .await?;
    crate::dataset::samples_dataframe(&fetched)
}
/// Lists snapshots through the `snapshots.list` RPC.
///
/// When `name` is given, the list is passed through
/// `filter_and_sort_by_name`, matching against each snapshot's description.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
    let all: Vec<Snapshot> = self
        .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
        .await?;
    Ok(match name {
        Some(filter) => filter_and_sort_by_name(all, filter, |s| s.description()),
        None => all,
    })
}
/// Fetches a single snapshot by id through the `snapshots.get` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
    let mut request = HashMap::new();
    request.insert("snapshot_id", snapshot_id);
    self.rpc("snapshots.get".to_owned(), Some(request)).await
}
/// Creates a snapshot from a local file or directory.
///
/// A directory is delegated to `create_snapshot_folder`. A single file is
/// uploaded with the multipart flow: request upload URLs, upload the parts,
/// complete the multipart upload, then set the snapshot status to
/// `"available"`. The snapshot is re-fetched and returned at the end.
///
/// # Errors
///
/// Returns `Error::IoError` for a path that is not valid UTF-8 (directory
/// case) or has no usable file name, and `Error::InvalidResponse` when the
/// server response lacks the snapshot id or the expected part entry. I/O and
/// RPC errors are propagated.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
pub async fn create_snapshot(
&self,
path: &str,
progress: Option<Sender<Progress>>,
) -> Result<Snapshot, Error> {
let path = Path::new(path);
// Directories take the multi-file path.
if path.is_dir() {
let path_str = path.to_str().ok_or_else(|| {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Path contains invalid UTF-8",
))
})?;
return self.create_snapshot_folder(path_str, progress).await;
}
// The file name doubles as snapshot name and upload key.
let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid filename",
))
})?;
let total = path.metadata()?.len() as usize;
// Shared byte counter handed to the uploader.
let current = Arc::new(AtomicUsize::new(0));
// Announce the total size up front; send failures are ignored.
if let Some(progress) = &progress {
let _ = progress
.send(Progress {
current: 0,
total,
status: None,
})
.await;
}
let params = SnapshotCreateMultipartParams {
snapshot_name: name.to_owned(),
keys: vec![name.to_owned()],
file_sizes: vec![total],
snapshot_type: None,
};
// The response map holds the snapshot id under "snapshot_id" and one
// entry per upload key.
let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
.rpc(
"snapshots.create_upload_url_multipart".to_owned(),
Some(params),
)
.await?;
let snapshot_id = match multipart.get("snapshot_id") {
Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
_ => return Err(Error::InvalidResponse),
};
let snapshot = self.snapshot(snapshot_id).await?;
// Upload keys are looked up as "<text after the last `::/` in the
// snapshot path>/<file name>".
let part_prefix = snapshot
.path()
.split("::/")
.last()
.ok_or(Error::InvalidResponse)?
.to_owned();
let part_key = format!("{}/{}", part_prefix, name);
let mut part = match multipart.get(&part_key) {
Some(SnapshotCreateMultipartResultField::Part(part)) => part,
_ => return Err(Error::InvalidResponse),
}
.clone();
part.key = Some(part_key);
// `upload_multipart` returns the parameters for the completion call.
let params = upload_multipart(
self.bulk_http.clone(),
part.clone(),
path.to_path_buf(),
total,
current,
progress.clone(),
)
.await?;
let complete: String = self
.rpc(
"snapshots.complete_multipart_upload".to_owned(),
Some(params),
)
.await?;
debug!("Snapshot Multipart Complete: {:?}", complete);
// Mark the snapshot as available now that the upload is complete.
let params: SnapshotStatusParams = SnapshotStatusParams {
snapshot_id,
status: "available".to_owned(),
};
let _: SnapshotStatusResult = self
.rpc("snapshots.update".to_owned(), Some(params))
.await?;
// Release the sender before the final fetch.
if let Some(progress) = progress {
drop(progress);
}
self.snapshot(snapshot_id).await
}
/// Creates a snapshot from every regular file below the directory `path`.
///
/// Upload keys are the files' paths relative to `path`. Keys and sizes are
/// resolved together, before any server call, so the two lists sent to the
/// server always line up entry by entry. A file whose relative path is not
/// valid UTF-8, or whose metadata cannot be read, fails the call up front
/// instead of producing misaligned lists or aborting after the snapshot has
/// already been created remotely.
///
/// Files are uploaded one after another with the multipart flow; afterwards
/// the snapshot status is set to `"available"` and the snapshot is
/// re-fetched and returned.
///
/// # Errors
///
/// Returns `Error::IoError` for an unusable directory name, a non-UTF-8
/// relative path or unreadable file metadata, and `Error::InvalidResponse`
/// when the server response lacks the snapshot id or a part entry. RPC and
/// upload errors are propagated.
async fn create_snapshot_folder(
    &self,
    path: &str,
    progress: Option<Sender<Progress>>,
) -> Result<Snapshot, Error> {
    let path = Path::new(path);
    let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
        Error::IoError(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "Invalid directory name",
        ))
    })?;

    // Relative paths of all regular files; unreadable directory entries are
    // skipped by the walk.
    let files = WalkDir::new(path)
        .into_iter()
        .filter_map(|entry| entry.ok())
        .filter(|entry| entry.file_type().is_file())
        .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
        .collect::<Vec<_>>();

    // Build keys and sizes in lock-step so index i of one always describes
    // index i of the other.
    let mut keys = Vec::with_capacity(files.len());
    let mut file_sizes = Vec::with_capacity(files.len());
    for file in &files {
        let key = file.to_str().ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "File path contains invalid UTF-8",
            ))
        })?;
        keys.push(key.to_owned());
        file_sizes.push(path.join(file).metadata()?.len() as usize);
    }
    let total: usize = file_sizes.iter().sum();

    // Shared byte counter handed to every upload.
    let current = Arc::new(AtomicUsize::new(0));
    if let Some(progress) = &progress {
        let _ = progress
            .send(Progress {
                current: 0,
                total,
                status: None,
            })
            .await;
    }

    let params = SnapshotCreateMultipartParams {
        snapshot_name: name.to_owned(),
        keys,
        file_sizes,
        snapshot_type: None,
    };
    let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
        .rpc(
            "snapshots.create_upload_url_multipart".to_owned(),
            Some(params),
        )
        .await?;
    let snapshot_id = match multipart.get("snapshot_id") {
        Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
        _ => return Err(Error::InvalidResponse),
    };
    let snapshot = self.snapshot(snapshot_id).await?;
    // Upload keys are looked up as "<text after the last `::/` in the
    // snapshot path>/<relative file path>".
    let part_prefix = snapshot
        .path()
        .split("::/")
        .last()
        .ok_or(Error::InvalidResponse)?
        .to_owned();

    for file in files {
        // Already validated above; kept so this loop never assumes UTF-8.
        let file_str = file.to_str().ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "File path contains invalid UTF-8",
            ))
        })?;
        let part_key = format!("{}/{}", part_prefix, file_str);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);
        let params = upload_multipart(
            self.bulk_http.clone(),
            part.clone(),
            path.join(file),
            total,
            current.clone(),
            progress.clone(),
        )
        .await?;
        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Part Complete: {:?}", complete);
    }

    let params = SnapshotStatusParams {
        snapshot_id,
        status: "available".to_owned(),
    };
    let _: SnapshotStatusResult = self
        .rpc("snapshots.update".to_owned(), Some(params))
        .await?;
    // Release the sender before the final fetch.
    drop(progress);
    self.snapshot(snapshot_id).await
}
/// Creates a snapshot of type `"ziparrow"` from an Arrow file and a ZIP
/// file.
///
/// The snapshot name is `description` when given, otherwise the Arrow file's
/// stem, otherwise `"edgefirst_dataset"`. Both files are uploaded with the
/// multipart flow (Arrow first, then ZIP), sharing one progress counter
/// whose total is the sum of both file sizes. Afterwards the snapshot status
/// is set to `"available"` and the snapshot is re-fetched and returned.
///
/// # Errors
///
/// Returns `Error::IoError` when either file does not exist or has an
/// unusable file name, and `Error::InvalidResponse` when the server response
/// lacks the snapshot id or a part entry. I/O, RPC and upload errors are
/// propagated.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
pub async fn create_snapshot_edgefirst_format(
&self,
arrow_path: &str,
zip_path: &str,
description: Option<&str>,
progress: Option<Sender<Progress>>,
) -> Result<Snapshot, Error> {
let arrow_path = Path::new(arrow_path);
let zip_path = Path::new(zip_path);
// Fail early when either input is missing.
if !arrow_path.exists() {
return Err(Error::IoError(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Arrow file not found: {}", arrow_path.display()),
)));
}
if !zip_path.exists() {
return Err(Error::IoError(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("ZIP file not found: {}", zip_path.display()),
)));
}
// The file names double as upload keys.
let arrow_name = arrow_path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid Arrow filename",
))
})?;
let zip_name = zip_path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid ZIP filename",
))
})?;
// Name preference: explicit description, Arrow file stem, fixed default.
let snapshot_name = description
.map(|s| s.to_string())
.or_else(|| {
arrow_path
.file_stem()
.and_then(|s| s.to_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| "edgefirst_dataset".to_string());
let arrow_size = arrow_path.metadata()?.len() as usize;
let zip_size = zip_path.metadata()?.len() as usize;
let total = arrow_size + zip_size;
// Shared byte counter handed to both uploads.
let current = Arc::new(AtomicUsize::new(0));
// Announce the total size up front; send failures are ignored.
if let Some(progress) = &progress {
let _ = progress
.send(Progress {
current: 0,
total,
status: None,
})
.await;
}
let params = SnapshotCreateMultipartParams {
snapshot_name,
keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
file_sizes: vec![arrow_size, zip_size],
snapshot_type: Some("ziparrow".to_string()),
};
// The response map holds the snapshot id under "snapshot_id" and one
// entry per upload key.
let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
.rpc(
"snapshots.create_upload_url_multipart".to_owned(),
Some(params),
)
.await?;
let snapshot_id = match multipart.get("snapshot_id") {
Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
_ => return Err(Error::InvalidResponse),
};
let snapshot = self.snapshot(snapshot_id).await?;
// Upload keys are looked up as "<text after the last `::/` in the
// snapshot path>/<file name>".
let part_prefix = snapshot
.path()
.split("::/")
.last()
.ok_or(Error::InvalidResponse)?
.to_owned();
// Arrow file: upload, then complete its multipart upload.
let arrow_key = format!("{}/{}", part_prefix, arrow_name);
let mut arrow_part = match multipart.get(&arrow_key) {
Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
_ => return Err(Error::InvalidResponse),
};
arrow_part.key = Some(arrow_key);
let params = upload_multipart(
self.bulk_http.clone(),
arrow_part,
arrow_path.to_path_buf(),
total,
current.clone(),
progress.clone(),
)
.await?;
let _: String = self
.rpc(
"snapshots.complete_multipart_upload".to_owned(),
Some(params),
)
.await?;
debug!("Arrow file upload complete");
// ZIP file: same sequence.
let zip_key = format!("{}/{}", part_prefix, zip_name);
let mut zip_part = match multipart.get(&zip_key) {
Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
_ => return Err(Error::InvalidResponse),
};
zip_part.key = Some(zip_key);
let params = upload_multipart(
self.bulk_http.clone(),
zip_part,
zip_path.to_path_buf(),
total,
current.clone(),
progress.clone(),
)
.await?;
let _: String = self
.rpc(
"snapshots.complete_multipart_upload".to_owned(),
Some(params),
)
.await?;
debug!("ZIP file upload complete");
// Mark the snapshot as available now that both uploads are complete.
let params = SnapshotStatusParams {
snapshot_id,
status: "available".to_owned(),
};
let _: SnapshotStatusResult = self
.rpc("snapshots.update".to_owned(), Some(params))
.await?;
// Release the sender before the final fetch.
if let Some(progress) = progress {
drop(progress);
}
self.snapshot(snapshot_id).await
}
/// Deletes a snapshot through the `snapshots.delete` RPC.
///
/// The response payload is decoded as generic JSON and discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
    let mut request = HashMap::new();
    request.insert("snapshot_id", snapshot_id);
    self.rpc::<_, serde_json::Value>("snapshots.delete".to_owned(), Some(request))
        .await
        .map(|_ignored| ())
}
/// Creates a snapshot from an existing dataset through the
/// `snapshots.create` RPC.
///
/// When `annotation_set_id` is `None`, the dataset's annotation sets are
/// listed and the one named `"annotations"` is used, falling back to the
/// first set in the list.
///
/// # Errors
///
/// Returns `Error::InvalidParameters` when no annotation set id is given and
/// the dataset has no annotation sets.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn create_snapshot_from_dataset(
    &self,
    dataset_id: DatasetID,
    description: &str,
    annotation_set_id: Option<AnnotationSetID>,
) -> Result<SnapshotFromDatasetResult, Error> {
    let annotation_set_id = if let Some(explicit) = annotation_set_id {
        explicit
    } else {
        let sets = self.annotation_sets(dataset_id).await?;
        let Some(fallback) = sets.first() else {
            return Err(Error::InvalidParameters(
                "No annotation sets available for dataset".to_owned(),
            ));
        };
        sets.iter()
            .find(|candidate| candidate.name() == "annotations")
            .unwrap_or(fallback)
            .id()
    };

    let request = SnapshotCreateFromDataset {
        description: description.to_owned(),
        dataset_id,
        annotation_set_id,
    };
    self.rpc("snapshots.create".to_owned(), Some(request)).await
}
/// Downloads every file of a snapshot into the directory `output`.
///
/// The download URLs are requested through `snapshots.create_download_url`;
/// each `(key, url)` pair is fetched in its own spawned task, with the
/// number of concurrently running downloads bounded by a semaphore sized by
/// `max_tasks()`. Each file is written to `output.join(key)`.
///
/// Progress updates carry the bytes written so far across all files; the
/// reported total grows as responses announce their content length.
///
/// Each file is flushed before its task finishes. Dropping a tokio `File`
/// without flushing may leave buffered writes incomplete and hides their
/// errors. Parent directories of the target path are created first so keys
/// containing sub-directories can be written.
/// NOTE(review): keys come from the server and are joined onto `output`
/// unchecked; confirm they can never be absolute or contain `..`.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
pub async fn download_snapshot(
    &self,
    snapshot_id: SnapshotID,
    output: PathBuf,
    progress: Option<Sender<Progress>>,
) -> Result<(), Error> {
    fs::create_dir_all(&output).await?;
    let params = HashMap::from([("snapshot_id", snapshot_id)]);
    let items: HashMap<String, String> = self
        .rpc("snapshots.create_download_url".to_owned(), Some(params))
        .await?;
    let http = self.bulk_http.clone();
    let current = Arc::new(AtomicUsize::new(0));
    let total = Arc::new(AtomicUsize::new(0));
    let sem = Arc::new(Semaphore::new(max_tasks()));
    let tasks = items
        .into_iter()
        .map(|(key, url)| {
            let http = http.clone();
            let output = output.clone();
            let progress = progress.clone();
            let current = current.clone();
            let total = total.clone();
            let sem = sem.clone();
            tokio::spawn(async move {
                // Held for the whole download to bound concurrency.
                let _permit = sem.acquire().await.map_err(|_| {
                    Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                })?;
                let res = http.get(url).send().await?;
                let res = res.error_for_status()?;
                if let Some(len) = res.content_length() {
                    total.fetch_add(len as usize, Ordering::SeqCst);
                }
                let target = output.join(key);
                if let Some(parent) = target.parent() {
                    fs::create_dir_all(parent).await?;
                }
                let mut file = File::create(target).await?;
                let mut stream = res.bytes_stream();
                while let Some(chunk) = stream.next().await {
                    let chunk = chunk?;
                    file.write_all(&chunk).await?;
                    let len = chunk.len();
                    if let Some(progress) = &progress {
                        let cur = current.fetch_add(len, Ordering::SeqCst) + len;
                        let tot = total.load(Ordering::SeqCst);
                        let _ = progress
                            .send(Progress {
                                current: cur,
                                total: tot,
                                status: None,
                            })
                            .await;
                    }
                }
                // Make sure buffered writes reach the file and surface errors.
                file.flush().await?;
                Ok::<(), Error>(())
            })
        })
        .collect::<Vec<_>>();
    // First unwrap task join errors, then the per-download results.
    join_all(tasks)
        .await
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?;
    Ok(())
}
/// Restores a snapshot into a project through the `snapshots.restore` RPC.
///
/// `fps` is always sent as `1`, and `agtg_pipeline` is enabled exactly when
/// `autolabel` is non-empty.
#[allow(clippy::too_many_arguments)]
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn restore_snapshot(
    &self,
    project_id: ProjectID,
    snapshot_id: SnapshotID,
    topics: &[String],
    autolabel: &[String],
    autodepth: bool,
    dataset_name: Option<&str>,
    dataset_description: Option<&str>,
) -> Result<SnapshotRestoreResult, Error> {
    let wants_autolabel = !autolabel.is_empty();
    let request = SnapshotRestore {
        project_id,
        snapshot_id,
        fps: 1,
        autodepth,
        agtg_pipeline: wants_autolabel,
        autolabel: autolabel.to_vec(),
        topics: topics.to_vec(),
        dataset_name: dataset_name.map(str::to_owned),
        dataset_description: dataset_description.map(str::to_owned),
    };
    self.rpc("snapshots.restore".to_owned(), Some(request)).await
}
/// Lists the experiments of a project through the `trainer.list2` RPC.
///
/// When `name` is given, the list is passed through
/// `filter_and_sort_by_name`, matching against each experiment's name.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn experiments(
    &self,
    project_id: ProjectID,
    name: Option<&str>,
) -> Result<Vec<Experiment>, Error> {
    let request = HashMap::from([("project_id", project_id)]);
    let all: Vec<Experiment> = self.rpc("trainer.list2".to_owned(), Some(request)).await?;
    Ok(match name {
        Some(filter) => filter_and_sort_by_name(all, filter, |e| e.name()),
        None => all,
    })
}
/// Fetches a single experiment through the `trainer.get` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
    let mut request = HashMap::new();
    request.insert("trainer_id", experiment_id);
    self.rpc("trainer.get".to_owned(), Some(request)).await
}
/// Lists the training sessions of an experiment through the
/// `trainer.session.list` RPC.
///
/// When `name` is given, the list is passed through
/// `filter_and_sort_by_name`, matching against each session's name.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn training_sessions(
    &self,
    experiment_id: ExperimentID,
    name: Option<&str>,
) -> Result<Vec<TrainingSession>, Error> {
    let request = HashMap::from([("trainer_id", experiment_id)]);
    let all: Vec<TrainingSession> = self
        .rpc("trainer.session.list".to_owned(), Some(request))
        .await?;
    Ok(match name {
        Some(filter) => filter_and_sort_by_name(all, filter, |s| s.name()),
        None => all,
    })
}
/// Fetches a single training session through the `trainer.session.get`
/// RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn training_session(
    &self,
    session_id: TrainingSessionID,
) -> Result<TrainingSession, Error> {
    let mut request = HashMap::new();
    request.insert("trainer_session_id", session_id);
    self.rpc("trainer.session.get".to_owned(), Some(request))
        .await
}
/// Lists the validation sessions of a project through the
/// `validate.session.list` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn validation_sessions(
    &self,
    project_id: ProjectID,
) -> Result<Vec<ValidationSession>, Error> {
    let mut request = HashMap::new();
    request.insert("project_id", project_id);
    self.rpc("validate.session.list".to_owned(), Some(request))
        .await
}
/// Fetches a single validation session through the `validate.session.get`
/// RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn validation_session(
    &self,
    session_id: ValidationSessionID,
) -> Result<ValidationSession, Error> {
    let mut request = HashMap::new();
    request.insert("validate_session_id", session_id);
    self.rpc("validate.session.get".to_owned(), Some(request))
        .await
}
/// Lists the artifacts of a training session through the
/// `trainer.get_artifacts` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn artifacts(
    &self,
    training_session_id: TrainingSessionID,
) -> Result<Vec<Artifact>, Error> {
    let mut request = HashMap::new();
    request.insert("training_session_id", training_session_id);
    self.rpc("trainer.get_artifacts".to_owned(), Some(request))
        .await
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
pub async fn download_artifact(
&self,
training_session_id: TrainingSessionID,
modelname: &str,
filename: Option<PathBuf>,
progress: Option<Sender<Progress>>,
) -> Result<(), Error> {
let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
let resp = self
.bulk_http
.get(format!(
"{}/download_model?training_session_id={}&file={}",
self.url,
training_session_id.value(),
modelname
))
.header("Authorization", format!("Bearer {}", self.token().await))
.send()
.await?;
if !resp.status().is_success() {
let err = resp.error_for_status_ref().unwrap_err();
return Err(Error::HttpError(err));
}
if let Some(parent) = filename.parent() {
fs::create_dir_all(parent).await?;
}
let total = resp.content_length().unwrap_or(0) as usize;
if let Some(ref progress) = progress {
let _ = progress
.send(Progress {
current: 0,
total,
status: None,
})
.await;
}
let mut file = File::create(filename).await?;
let mut current = 0;
let mut stream = resp.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item?;
file.write_all(&chunk).await?;
current += chunk.len();
if let Some(ref progress) = progress {
let _ = progress
.send(Progress {
current,
total,
status: None,
})
.await;
}
}
file.flush().await?;
Ok(())
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
pub async fn download_checkpoint(
&self,
training_session_id: TrainingSessionID,
checkpoint: &str,
filename: Option<PathBuf>,
progress: Option<Sender<Progress>>,
) -> Result<(), Error> {
let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
let resp = self
.bulk_http
.get(format!(
"{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
self.url,
training_session_id.value(),
checkpoint
))
.header("Authorization", format!("Bearer {}", self.token().await))
.send()
.await?;
if !resp.status().is_success() {
let err = resp.error_for_status_ref().unwrap_err();
return Err(Error::HttpError(err));
}
if let Some(parent) = filename.parent() {
fs::create_dir_all(parent).await?;
}
let total = resp.content_length().unwrap_or(0) as usize;
if let Some(ref progress) = progress {
let _ = progress
.send(Progress {
current: 0,
total,
status: None,
})
.await;
}
let mut file = File::create(filename).await?;
let mut current = 0;
let mut stream = resp.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item?;
file.write_all(&chunk).await?;
current += chunk.len();
if let Some(ref progress) = progress {
let _ = progress
.send(Progress {
current,
total,
status: None,
})
.await;
}
}
file.flush().await?;
Ok(())
}
/// Lists tasks through the `task.list` RPC, following continue tokens until
/// the server returns none (or an empty one).
///
/// `workflow`, `status` and `manager` are each sent as a single-element
/// filter list when given. When `name` is given, the collected tasks are
/// passed through `filter_and_sort_by_name`, matching against each task's
/// name.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn tasks(
    &self,
    name: Option<&str>,
    workflow: Option<&str>,
    status: Option<&str>,
    manager: Option<&str>,
) -> Result<Vec<Task>, Error> {
    let single = |value: &str| vec![value.to_owned()];
    let mut params = TasksListParams {
        continue_token: None,
        types: workflow.map(single),
        status: status.map(single),
        manager: manager.map(single),
    };

    let mut collected = Vec::new();
    loop {
        let page = self
            .rpc::<_, TasksListResult>("task.list".to_owned(), Some(&params))
            .await?;
        collected.extend(page.tasks);
        // An empty token means the same as no token: there are no more pages.
        params.continue_token = page.continue_token.filter(|token| !token.is_empty());
        if params.continue_token.is_none() {
            break;
        }
    }

    Ok(match name {
        Some(filter) => filter_and_sort_by_name(collected, filter, |t| t.name()),
        None => collected,
    })
}
/// Fetches detailed information about a task through the `task.get` RPC.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
    let mut request = HashMap::new();
    request.insert("id", task_id);
    self.rpc("task.get".to_owned(), Some(request)).await
}
/// Sets the status of a task through the `docker.update.status` RPC and
/// returns the task sent back by the server.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
    let update = TaskStatus {
        task_id,
        status: status.to_owned(),
    };
    self.rpc("docker.update.status".to_owned(), Some(update))
        .await
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
let stages: Vec<HashMap<String, String>> = stages
.iter()
.map(|(key, value)| {
let mut stage_map = HashMap::new();
stage_map.insert(key.to_string(), value.to_string());
stage_map
})
.collect();
let params = TaskStages { task_id, stages };
let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
Ok(())
}
/// Reports the state of one task stage through the `status.update` RPC.
///
/// The task returned by the server is discarded.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn update_stage(
    &self,
    task_id: TaskID,
    stage: &str,
    status: &str,
    message: &str,
    percentage: u8,
) -> Result<(), Error> {
    let update = Stage::new(
        Some(task_id),
        stage.to_owned(),
        Some(status.to_owned()),
        Some(message.to_owned()),
        percentage,
    );
    self.rpc::<_, Task>("status.update".to_owned(), Some(update))
        .await?;
    Ok(())
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
let req = self
.bulk_http
.get(format!("{}/{}", self.url, query))
.header("User-Agent", "EdgeFirst Client")
.header("Authorization", format!("Bearer {}", self.token().await));
let resp = req.send().await?;
if resp.status().is_success() {
let body = resp.bytes().await?;
if log_enabled!(Level::Trace) {
trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
}
Ok(body.to_vec())
} else {
let err = resp.error_for_status_ref().unwrap_err();
Err(Error::HttpError(err))
}
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(600u64);
let req = self
.http
.post(format!("{}/api?method={}", self.url, method))
.header("Accept", "application/json")
.header("User-Agent", "EdgeFirst Client")
.header("Authorization", format!("Bearer {}", self.token().await))
.timeout(Duration::from_secs(upload_timeout_secs))
.multipart(form);
let resp = req.send().await?;
if resp.status().is_success() {
let body = resp.bytes().await?;
if log_enabled!(Level::Trace) {
trace!(
"POST Multipart Response: {}",
String::from_utf8_lossy(&body)
);
}
let response: RpcResponse<String> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(err) => {
error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
return Err(err.into());
}
};
if let Some(error) = response.error {
Err(Error::RpcError(error.code, error.message))
} else if let Some(result) = response.result {
Ok(result)
} else {
Err(Error::InvalidResponse)
}
} else {
let err = resp.error_for_status_ref().unwrap_err();
Err(Error::HttpError(err))
}
}
/// Performs an RPC call, renewing the token first when it is about to
/// expire.
///
/// The token is renewed when its expiration lies within the next 3600
/// seconds (or in the past); the call itself is then delegated to
/// `rpc_without_auth`.
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
pub async fn rpc<Params, RpcResult>(
    &self,
    method: String,
    params: Option<Params>,
) -> Result<RpcResult, Error>
where
    Params: Serialize,
    RpcResult: DeserializeOwned,
{
    let expires_at = self.token_expiration().await?;
    let renewal_deadline = Utc::now() + Duration::from_secs(3600);
    let needs_renewal = expires_at <= renewal_deadline;
    if needs_renewal {
        self.renew_token().await?;
    }
    self.rpc_without_auth(method, params).await
}
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
async fn rpc_without_auth<Params, RpcResult>(
&self,
method: String,
params: Option<Params>,
) -> Result<RpcResult, Error>
where
Params: Serialize,
RpcResult: DeserializeOwned,
{
let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(5usize);
let url = format!("{}/api", self.url);
let request = RpcRequest {
method: method.clone(),
params,
..Default::default()
};
let request_json = if method == "auth.login" {
serde_json::json!({
"jsonrpc": "2.0",
"method": &method,
"params": "[REDACTED - contains credentials]",
"id": request.id
})
.to_string()
} else {
serde_json::to_string(&request)?
};
if log_enabled!(Level::Trace) {
trace!("RPC Request: {}", request_json);
}
#[cfg(feature = "profiling")]
tracing::Span::current().record("request", &request_json);
let request_body = serde_json::to_vec(&request)?;
let mut last_error: Option<Error> = None;
for attempt in 0..=max_retries {
if attempt > 0 {
let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
let delay = Duration::from_millis(delay_ms);
warn!(
"Retry {}/{} for RPC '{}' after {:?}",
attempt, max_retries, method, delay
);
tokio::time::sleep(delay).await;
}
let result = self
.http
.post(&url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("User-Agent", "EdgeFirst Client")
.header("Authorization", format!("Bearer {}", self.token().await))
.body(request_body.clone())
.send()
.await;
match result {
Ok(res) => {
let status = res.status();
let status_code = status.as_u16();
if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
&& attempt < max_retries
{
warn!(
"RPC '{}' failed with HTTP {} (retrying)",
method, status_code
);
last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
continue;
}
match self.process_rpc_response(res).await {
Ok(result) => {
if attempt > 0 {
debug!("RPC '{}' succeeded on retry {}", method, attempt);
}
return Ok(result);
}
Err(e) => {
if attempt > 0 {
error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
}
return Err(e);
}
}
}
Err(e) => {
let is_timeout = e.is_timeout();
let is_connect = e.is_connect();
if (is_timeout || is_connect) && attempt < max_retries {
warn!(
"RPC '{}' transport error (retrying): {}",
method,
if is_timeout {
"timeout"
} else {
"connection failed"
}
);
last_error = Some(Error::HttpError(e));
continue;
}
if attempt > 0 {
error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
}
return Err(Error::HttpError(e));
}
}
}
Err(last_error.unwrap_or_else(|| {
Error::InvalidParameters(format!(
"RPC '{}' failed after {} retries",
method, max_retries
))
}))
}
async fn process_rpc_response<RpcResult>(
&self,
res: reqwest::Response,
) -> Result<RpcResult, Error>
where
RpcResult: DeserializeOwned,
{
let body = res.bytes().await?;
let response_str = String::from_utf8_lossy(&body);
if log_enabled!(Level::Trace) {
trace!("RPC Response: {}", response_str);
}
#[cfg(feature = "profiling")]
{
const MAX_RESPONSE_LEN: usize = 4096;
let truncated = if response_str.len() > MAX_RESPONSE_LEN {
let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
format!(
"{}...[truncated {} bytes]",
&response_str[..safe_end],
response_str.len() - safe_end
)
} else {
response_str.to_string()
};
tracing::Span::current().record("response", &truncated);
}
let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(err) => {
error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
return Err(err.into());
}
};
if let Some(error) = response.error {
Err(Error::RpcError(error.code, error.message))
} else if let Some(result) = response.result {
Ok(result)
} else {
Err(Error::InvalidResponse)
}
}
}
/// Runs `work_fn` once per item on spawned Tokio tasks with bounded
/// concurrency, optionally reporting progress.
///
/// # Parameters
///
/// * `items` - the work items; each one is moved into its own task.
/// * `progress` - when present, a [`Progress`] message is sent after every
///   item that completes successfully, with `current` counting completed
///   items and `total` set to `items.len()`. Send failures are ignored.
/// * `concurrency` - maximum number of items processed at the same time;
///   `max_tasks()` is used when `None`. Values below 1 are raised to 1.
/// * `work_fn` - the asynchronous operation applied to each item.
///
/// # Errors
///
/// All tasks are awaited before the outcome is inspected. A task join failure
/// is reported first; otherwise the first item error in input order is
/// returned.
async fn parallel_foreach_items<T, F, Fut>(
    items: Vec<T>,
    progress: Option<Sender<Progress>>,
    concurrency: Option<usize>,
    work_fn: F,
) -> Result<(), Error>
where
    T: Send + 'static,
    F: Fn(T) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), Error>> + Send + 'static,
{
    let total = items.len();
    let current = Arc::new(AtomicUsize::new(0));
    // A semaphore created with zero permits never hands one out, so a limit
    // of 0 (from the caller or from `MAX_TASKS=0`) would make every task wait
    // forever. Always allow at least one item in flight.
    let limit = concurrency.unwrap_or_else(max_tasks).max(1);
    let sem = Arc::new(Semaphore::new(limit));
    let work_fn = Arc::new(work_fn);

    let tasks = items
        .into_iter()
        .map(|item| {
            let sem = sem.clone();
            let current = current.clone();
            let progress = progress.clone();
            let work_fn = work_fn.clone();
            tokio::spawn(async move {
                let _permit = sem.acquire().await.map_err(|_| {
                    Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                })?;
                work_fn(item).await?;
                if let Some(progress) = &progress {
                    // `fetch_add` returns the previous value, hence the `+ 1`.
                    let current = current.fetch_add(1, Ordering::SeqCst);
                    let _ = progress
                        .send(Progress {
                            current: current + 1,
                            total,
                            status: None,
                        })
                        .await;
                }
                Ok::<(), Error>(())
            })
        })
        .collect::<Vec<_>>();

    // First `collect` surfaces join failures, the second surfaces item errors.
    join_all(tasks)
        .await
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?;
    Ok(())
}
async fn upload_multipart(
http: reqwest::Client,
part: SnapshotPart,
path: PathBuf,
total: usize,
confirmed_bytes: Arc<AtomicUsize>,
progress: Option<Sender<Progress>>,
) -> Result<SnapshotCompleteMultipartParams, Error> {
let filesize = path.metadata()?.len() as usize;
let n_parts = filesize.div_ceil(PART_SIZE);
let sem = Arc::new(Semaphore::new(max_upload_tasks()));
let key = part.key.ok_or(Error::InvalidResponse)?;
let upload_id = part.upload_id;
let urls = part.urls.clone();
let etags = Arc::new(tokio::sync::Mutex::new(vec![
EtagPart {
etag: "".to_owned(),
part_number: 0,
};
n_parts
]));
let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
(0..n_parts)
.map(|_| AtomicUsize::new(0))
.collect::<Vec<_>>(),
);
let tasks = (0..n_parts)
.map(|part_idx| {
let http = http.clone();
let url = urls[part_idx].clone();
let etags = etags.clone();
let path = path.to_owned();
let sem = sem.clone();
let progress = progress.clone();
let confirmed_bytes = confirmed_bytes.clone();
let part_bytes = part_bytes.clone();
let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
filesize % PART_SIZE
} else {
PART_SIZE
};
tokio::spawn(async move {
let _permit = sem.acquire().await.map_err(|_| {
Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
})?;
let etag = upload_part_with_progress(
http,
url,
path,
part_idx,
n_parts,
part_size,
total,
confirmed_bytes.clone(),
part_bytes.clone(),
progress.clone(),
)
.await?;
let mut etags_guard = etags.lock().await;
etags_guard[part_idx] = EtagPart {
etag,
part_number: part_idx + 1,
};
confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
part_bytes[part_idx].store(0, Ordering::SeqCst);
if let Some(progress) = &progress {
let current = confirmed_bytes.load(Ordering::SeqCst)
+ part_bytes
.iter()
.map(|p| p.load(Ordering::SeqCst))
.sum::<usize>();
let _ = progress
.send(Progress {
current,
total,
status: None,
})
.await;
}
Ok::<(), Error>(())
})
})
.collect::<Vec<_>>();
join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
Ok(SnapshotCompleteMultipartParams {
key,
upload_id,
etag_list: etags.lock().await.clone(),
})
}
#[allow(clippy::too_many_arguments)]
async fn upload_part_with_progress(
http: reqwest::Client,
url: String,
path: PathBuf,
part_idx: usize,
n_parts: usize,
part_size: usize,
total: usize,
confirmed_bytes: Arc<AtomicUsize>,
part_bytes: Arc<Vec<AtomicUsize>>,
progress: Option<Sender<Progress>>,
) -> Result<String, Error> {
let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(5usize);
let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(600u64);
let mut last_error: Option<Error> = None;
for attempt in 0..=max_retries {
if attempt > 0 {
part_bytes[part_idx].store(0, Ordering::SeqCst);
let delay = Duration::from_secs(1 << (attempt - 1).min(4));
warn!(
"Retry {}/{} for part {} after {:?}",
attempt, max_retries, part_idx, delay
);
tokio::time::sleep(delay).await;
}
match upload_part_streaming(
http.clone(),
url.clone(),
path.clone(),
part_idx,
n_parts,
part_size,
total,
upload_timeout_secs,
confirmed_bytes.clone(),
part_bytes.clone(),
progress.clone(),
)
.await
{
Ok(etag) => return Ok(etag),
Err(e) => {
let is_retryable = matches!(
&e,
Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
);
if is_retryable && attempt < max_retries {
last_error = Some(e);
continue;
}
return Err(e);
}
}
}
Err(last_error
.unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
}
#[allow(clippy::too_many_arguments)]
async fn upload_part_streaming(
http: reqwest::Client,
url: String,
path: PathBuf,
part_idx: usize,
n_parts: usize,
_part_size: usize,
total: usize,
upload_timeout_secs: u64,
confirmed_bytes: Arc<AtomicUsize>,
part_bytes: Arc<Vec<AtomicUsize>>,
progress: Option<Sender<Progress>>,
) -> Result<String, Error> {
let filesize = path.metadata()?.len() as usize;
let mut file = File::open(&path).await?;
file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
.await?;
let file = file.take(PART_SIZE as u64);
let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
filesize % PART_SIZE
} else {
PART_SIZE
};
let stream = FramedRead::new(file, BytesCodec::new());
let progress_stream = stream.map(move |result| {
if let Ok(ref bytes) = result {
let bytes_len = bytes.len();
part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
if let Some(ref progress) = progress {
let current = confirmed_bytes.load(Ordering::SeqCst)
+ part_bytes
.iter()
.map(|p| p.load(Ordering::SeqCst))
.sum::<usize>();
let _ = progress.try_send(Progress {
current,
total,
status: None,
});
}
}
result.map(|b| b.freeze())
});
let body = Body::wrap_stream(progress_stream);
let resp = http
.put(url)
.header(CONTENT_LENGTH, body_length)
.timeout(Duration::from_secs(upload_timeout_secs))
.body(body)
.send()
.await?
.error_for_status()?;
let etag = resp
.headers()
.get("etag")
.ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
.to_str()
.map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
.to_owned();
let etag = etag
.strip_prefix("\"")
.ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
let etag = etag
.strip_suffix("\"")
.ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
Ok(etag.to_owned())
}
async fn upload_file_to_presigned_url(
http: reqwest::Client,
url: &str,
path: PathBuf,
) -> Result<(), Error> {
let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(5usize);
let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(600u64);
let file_data = fs::read(&path).await?;
let file_size = file_data.len();
let filename = path.file_name().unwrap_or_default().to_string_lossy();
let mut last_error: Option<Error> = None;
for attempt in 0..=max_retries {
if attempt > 0 {
let delay = Duration::from_secs(1 << (attempt - 1).min(4));
warn!(
"Retry {}/{} for upload '{}' after {:?}",
attempt, max_retries, filename, delay
);
tokio::time::sleep(delay).await;
}
let result = http
.put(url)
.header(CONTENT_LENGTH, file_size)
.timeout(Duration::from_secs(upload_timeout_secs))
.body(file_data.clone())
.send()
.await;
match result {
Ok(resp) => {
if resp.status().is_success() {
if attempt > 0 {
debug!(
"Upload '{}' succeeded on retry {} ({} bytes)",
filename, attempt, file_size
);
} else {
debug!(
"Successfully uploaded file: {} ({} bytes)",
filename, file_size
);
}
return Ok(());
}
let status = resp.status();
let status_code = status.as_u16();
let is_retryable =
matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
if is_retryable && attempt < max_retries {
let error_text = resp.text().await.unwrap_or_default();
warn!(
"Upload '{}' failed with HTTP {} (retryable): {}",
filename, status_code, error_text
);
last_error = Some(Error::InvalidParameters(format!(
"Upload failed: HTTP {} - {}",
status, error_text
)));
continue;
}
let error_text = resp.text().await.unwrap_or_default();
if attempt > 0 {
error!(
"Upload '{}' failed after {} retries: HTTP {} - {}",
filename, attempt, status, error_text
);
}
return Err(Error::InvalidParameters(format!(
"Upload failed: HTTP {} - {}",
status, error_text
)));
}
Err(e) => {
let is_timeout = e.is_timeout();
let is_connect = e.is_connect();
if (is_timeout || is_connect) && attempt < max_retries {
warn!(
"Upload '{}' transport error (retrying): {}",
filename,
if is_timeout {
"timeout"
} else {
"connection failed"
}
);
last_error = Some(Error::HttpError(e));
continue;
}
if attempt > 0 {
error!(
"Upload '{}' failed after {} retries: {}",
filename, attempt, e
);
}
return Err(Error::HttpError(e));
}
}
}
Err(last_error.unwrap_or_else(|| {
Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
}))
}
async fn upload_bytes_to_presigned_url(
http: reqwest::Client,
url: &str,
file_data: Vec<u8>,
filename: &str,
) -> Result<(), Error> {
let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(5usize);
let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(600u64);
let file_size = file_data.len();
let mut last_error: Option<Error> = None;
for attempt in 0..=max_retries {
if attempt > 0 {
let delay = Duration::from_secs(1 << (attempt - 1).min(4));
warn!(
"Retry {}/{} for upload '{}' after {:?}",
attempt, max_retries, filename, delay
);
tokio::time::sleep(delay).await;
}
let result = http
.put(url)
.header(CONTENT_LENGTH, file_size)
.timeout(Duration::from_secs(upload_timeout_secs))
.body(file_data.clone())
.send()
.await;
match result {
Ok(resp) => {
if resp.status().is_success() {
if attempt > 0 {
debug!(
"Upload '{}' succeeded on retry {} ({} bytes)",
filename, attempt, file_size
);
} else {
debug!(
"Successfully uploaded file: {} ({} bytes)",
filename, file_size
);
}
return Ok(());
}
let status = resp.status();
let status_code = status.as_u16();
let is_retryable =
matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
if is_retryable && attempt < max_retries {
let error_text = resp.text().await.unwrap_or_default();
warn!(
"Upload '{}' failed with HTTP {} (retryable): {}",
filename, status_code, error_text
);
last_error = Some(Error::InvalidParameters(format!(
"Upload failed: HTTP {} - {}",
status, error_text
)));
continue;
}
let error_text = resp.text().await.unwrap_or_default();
if attempt > 0 {
error!(
"Upload '{}' failed after {} retries: HTTP {} - {}",
filename, attempt, status, error_text
);
}
return Err(Error::InvalidParameters(format!(
"Upload failed: HTTP {} - {}",
status, error_text
)));
}
Err(e) => {
let is_timeout = e.is_timeout();
let is_connect = e.is_connect();
if (is_timeout || is_connect) && attempt < max_retries {
warn!(
"Upload '{}' transport error (retrying): {}",
filename,
if is_timeout {
"timeout"
} else {
"connection failed"
}
);
last_error = Some(Error::HttpError(e));
continue;
}
if attempt > 0 {
error!(
"Upload '{}' failed after {} retries: {}",
filename, attempt, e
);
}
return Err(Error::HttpError(e));
}
}
}
Err(last_error.unwrap_or_else(|| {
Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
}))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_filter_and_sort_by_name_exact_match_first() {
let items = vec![
"Deer Roundtrip 123".to_string(),
"Deer".to_string(),
"Reindeer".to_string(),
"DEER".to_string(),
];
let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
assert_eq!(result[0], "Deer"); assert_eq!(result[1], "DEER"); }
#[test]
fn test_filter_and_sort_by_name_shorter_names_preferred() {
let items = vec![
"Test Dataset ABC".to_string(),
"Test".to_string(),
"Test Dataset".to_string(),
];
let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
assert_eq!(result[0], "Test"); assert_eq!(result[1], "Test Dataset"); assert_eq!(result[2], "Test Dataset ABC"); }
#[test]
fn test_filter_and_sort_by_name_case_insensitive_filter() {
let items = vec![
"UPPERCASE".to_string(),
"lowercase".to_string(),
"MixedCase".to_string(),
];
let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
assert_eq!(result.len(), 3); }
#[test]
fn test_filter_and_sort_by_name_no_matches() {
let items = vec!["Apple".to_string(), "Banana".to_string()];
let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
assert!(result.is_empty());
}
#[test]
fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
let items = vec![
"TestC".to_string(),
"TestA".to_string(),
"TestB".to_string(),
];
let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
}
#[test]
fn test_build_filename_no_flatten() {
let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
assert_eq!(result, "image.jpg");
let result = Client::build_filename("test.png", false, None, None);
assert_eq!(result, "test.png");
}
#[test]
fn test_build_filename_flatten_no_sequence() {
let result = Client::build_filename("standalone.jpg", true, None, None);
assert_eq!(result, "standalone.jpg");
}
#[test]
fn test_build_filename_flatten_with_sequence_not_prefixed() {
let result = Client::build_filename(
"image.camera.jpeg",
true,
Some(&"deer_sequence".to_string()),
Some(42),
);
assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
}
#[test]
fn test_build_filename_flatten_with_sequence_no_frame() {
let result =
Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
assert_eq!(result, "sequence_A_image.jpg");
}
#[test]
fn test_build_filename_flatten_already_prefixed() {
let result = Client::build_filename(
"deer_sequence_042.camera.jpeg",
true,
Some(&"deer_sequence".to_string()),
Some(42),
);
assert_eq!(result, "deer_sequence_042.camera.jpeg");
}
#[test]
fn test_build_filename_flatten_already_prefixed_different_frame() {
let result = Client::build_filename(
"sequence_A_001.jpg",
true,
Some(&"sequence_A".to_string()),
Some(2),
);
assert_eq!(result, "sequence_A_001.jpg");
}
#[test]
fn test_build_filename_flatten_partial_match() {
let result = Client::build_filename(
"test_sequence_A_image.jpg",
true,
Some(&"sequence_A".to_string()),
Some(5),
);
assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
}
#[test]
fn test_build_filename_flatten_preserves_extension() {
let extensions = vec![
"jpeg",
"jpg",
"png",
"camera.jpeg",
"lidar.pcd",
"depth.png",
];
for ext in extensions {
let filename = format!("image.{}", ext);
let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
assert!(
result.ends_with(&format!(".{}", ext)),
"Extension .{} not preserved in {}",
ext,
result
);
}
}
#[test]
fn test_build_filename_flatten_sanitization_compatibility() {
let result = Client::build_filename(
"sample_001.jpg",
true,
Some(&"seq_name_with_underscores".to_string()),
Some(10),
);
assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
}
#[test]
fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
let items = vec![
"Deer Roundtrip 20251129".to_string(),
"White-Tailed Deer".to_string(),
"Deer".to_string(),
"Deer Snapshot Test".to_string(),
"Reindeer Dataset".to_string(),
];
let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
assert_eq!(
result.first().map(|s| s.as_str()),
Some("Deer"),
"Expected exact match 'Deer' first, got: {:?}",
result.first()
);
assert_eq!(result.len(), 5);
}
#[test]
fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
let items = vec![
"DEER".to_string(),
"deer".to_string(),
"Deer".to_string(),
"Deer Test".to_string(),
];
let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
assert_eq!(result[0], "Deer");
assert!(result[1] == "DEER" || result[1] == "deer");
assert!(result[2] == "DEER" || result[2] == "deer");
}
#[test]
fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
let items = vec![
"Unit Testing - Deer Dataset Backup".to_string(),
"Deer".to_string(),
"Deer Snapshot 2025-01-15".to_string(),
"Original Deer".to_string(),
];
let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
assert_eq!(
result[0], "Deer",
"Searching for 'Deer' should return exact 'Deer' first"
);
}
#[test]
fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
let items = vec![
"Deer Roundtrip".to_string(),
"Deer".to_string(),
"deer".to_string(),
"White-Tailed Deer".to_string(),
"Deer-V2".to_string(),
];
let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
assert_eq!(result[0], "Deer");
assert_eq!(result[1], "deer");
assert!(
result.iter().position(|s| s == "Deer-V2").unwrap()
< result.iter().position(|s| s == "Deer Roundtrip").unwrap()
);
}
#[test]
fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
let scenarios = vec![
(vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
(vec!["test", "TEST", "Test Data"], "test", "test"),
(vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
];
for (items, filter, expected_first) in scenarios {
let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
let result = filter_and_sort_by_name(items, filter, |s| s.as_str());
assert_eq!(
result.first().map(|s| s.as_str()),
Some(expected_first),
"For filter '{}', expected first result '{}', got: {:?}",
filter,
expected_first,
result.first()
);
}
}
#[test]
fn test_with_server_clears_storage() {
use crate::storage::MemoryTokenStorage;
let storage = Arc::new(MemoryTokenStorage::new());
storage.store("test-token").unwrap();
let client = Client::new().unwrap().with_storage(storage.clone());
assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));
let _new_client = client.with_server("test").unwrap();
assert_eq!(storage.load().unwrap(), None);
}
}