burn_central_core/artifacts/
client.rs1use burn_central_client::response::MultipartUploadResponse;
2use burn_central_client::{Client, ClientError};
3use sha2::Digest;
4use std::collections::BTreeMap;
5
6use crate::artifacts::ArtifactInfo;
7use crate::bundle::{BundleDecode, BundleEncode, InMemoryBundleReader, InMemoryBundleSources};
8use crate::schemas::ExperimentPath;
9use burn_central_client::request::{ArtifactFileSpecRequest, CreateArtifactRequest};
10
11#[derive(Debug, Clone, strum::Display, strum::EnumString)]
12#[strum(serialize_all = "snake_case")]
13pub enum ArtifactKind {
14 Model,
15 Log,
16 Other,
17}
18
19#[derive(Clone)]
21pub struct ExperimentArtifactClient {
22 client: Client,
23 exp_path: ExperimentPath,
24}
25
26impl ExperimentArtifactClient {
27 pub(crate) fn new(client: Client, exp_path: ExperimentPath) -> Self {
28 Self { client, exp_path }
29 }
30
31 pub fn upload<E: BundleEncode>(
33 &self,
34 name: impl Into<String>,
35 kind: ArtifactKind,
36 artifact: E,
37 settings: &E::Settings,
38 ) -> Result<String, ArtifactError> {
39 let name = name.into();
40 let mut sources = InMemoryBundleSources::new();
41 artifact.encode(&mut sources, settings).map_err(|e| {
42 ArtifactError::Encoding(format!("Failed to encode artifact: {}", e.into()))
43 })?;
44
45 let mut specs = Vec::with_capacity(sources.files().len());
46 for f in sources.files() {
47 let (checksum, size) = sha256_and_size_from_bytes(f.source());
48 specs.push(ArtifactFileSpecRequest {
49 rel_path: f.dest_path().to_string(),
50 size_bytes: size,
51 checksum,
52 });
53 }
54
55 let res = self.client.create_artifact(
56 self.exp_path.owner_name(),
57 self.exp_path.project_name(),
58 self.exp_path.experiment_num(),
59 CreateArtifactRequest {
60 name: name.clone(),
61 kind: kind.to_string(),
62 files: specs,
63 },
64 )?;
65
66 let mut multipart_map: BTreeMap<String, &MultipartUploadResponse> = BTreeMap::new();
67 for f in &res.files {
68 multipart_map.insert(f.rel_path.clone(), &f.urls);
69 }
70
71 for f in sources.into_files() {
72 let multipart_info = multipart_map.get(f.dest_path()).ok_or_else(|| {
73 ArtifactError::Internal(format!(
74 "Missing multipart upload info for file {}",
75 f.dest_path()
76 ))
77 })?;
78
79 self.upload_file_multipart(f.source(), multipart_info)?;
80 }
81
82 self.client.complete_artifact_upload(
83 self.exp_path.owner_name(),
84 self.exp_path.project_name(),
85 self.exp_path.experiment_num(),
86 &res.id,
87 None,
88 )?;
89
90 Ok(res.id)
91 }
92
93 pub fn download<D: BundleDecode>(
95 &self,
96 name: impl AsRef<str>,
97 settings: &D::Settings,
98 ) -> Result<D, ArtifactError> {
99 let reader = self.download_raw(name.as_ref())?;
100 D::decode(&reader, settings).map_err(|e| {
101 ArtifactError::Decoding(format!(
102 "Failed to decode artifact {}: {}",
103 name.as_ref(),
104 e.into()
105 ))
106 })
107 }
108
109 pub fn download_raw(
111 &self,
112 name: impl AsRef<str>,
113 ) -> Result<InMemoryBundleReader, ArtifactError> {
114 let name = name.as_ref();
115 let artifact = self.fetch(name)?;
116 let resp = self.client.presign_artifact_download(
117 self.exp_path.owner_name(),
118 self.exp_path.project_name(),
119 self.exp_path.experiment_num(),
120 &artifact.id.to_string(),
121 )?;
122
123 let mut data = BTreeMap::new();
124
125 for file in resp.files {
126 data.insert(
127 file.rel_path.clone(),
128 self.client.download_bytes_from_url(&file.url)?,
129 );
130 }
131
132 Ok(InMemoryBundleReader::new(data))
133 }
134
135 pub fn fetch(&self, name: impl AsRef<str>) -> Result<ArtifactInfo, ArtifactError> {
137 let name = name.as_ref();
138 let artifact_resp = self
139 .client
140 .list_artifacts_by_name(
141 self.exp_path.owner_name(),
142 self.exp_path.project_name(),
143 self.exp_path.experiment_num(),
144 name,
145 )?
146 .items
147 .into_iter()
148 .next()
149 .ok_or_else(|| ArtifactError::NotFound(name.to_owned()))?;
150
151 Ok(artifact_resp.into())
152 }
153
154 fn upload_file_multipart(
155 &self,
156 file_data: &[u8],
157 multipart_info: &MultipartUploadResponse,
158 ) -> Result<(), ArtifactError> {
159 let mut part_indices: Vec<usize> = (0..multipart_info.parts.len()).collect();
160 part_indices.sort_by_key(|&i| multipart_info.parts[i].part);
161
162 for (i, &part_idx) in part_indices.iter().enumerate() {
163 let part = &multipart_info.parts[part_idx];
164 if part.part != (i as u32 + 1) {
165 return Err(ArtifactError::Internal(format!(
166 "Invalid part numbering: expected part {}, got part {}",
167 i + 1,
168 part.part
169 )));
170 }
171 }
172
173 let mut current_offset = 0usize;
174 let total_parts = multipart_info.parts.len();
175
176 for (part_index, &part_idx) in part_indices.iter().enumerate() {
177 let part_info = &multipart_info.parts[part_idx];
178 let end_offset = std::cmp::min(
179 current_offset + part_info.size_bytes as usize,
180 file_data.len(),
181 );
182
183 if current_offset >= file_data.len() {
184 break;
185 }
186
187 let part_data = &file_data[current_offset..end_offset];
188
189 self.client
190 .upload_bytes_to_url(&part_info.url, part_data.to_vec())
191 .map_err(|e| {
192 ArtifactError::Internal(format!(
193 "Failed to upload part {} of {}: {}",
194 part_index + 1,
195 total_parts,
196 e
197 ))
198 })?;
199
200 current_offset = end_offset;
201 }
202
203 Ok(())
204 }
205}
206
207fn sha256_and_size_from_bytes(bytes: &[u8]) -> (String, u64) {
208 let mut hasher = sha2::Sha256::new();
209 hasher.update(bytes);
210 let digest = hasher.finalize();
211 (format!("{:x}", digest), bytes.len() as u64)
212}
213
214#[derive(Debug, thiserror::Error)]
215pub enum ArtifactError {
216 #[error("Artifact not found: {0}")]
217 NotFound(String),
218 #[error(transparent)]
219 Client(#[from] ClientError),
220 #[error("Error while encoding artifact: {0}")]
221 Encoding(String),
222 #[error("Error while decoding artifact: {0}")]
223 Decoding(String),
224 #[error("Internal error: {0}")]
225 Internal(String),
226}