burn_central_core/artifacts/
mod.rs1use burn_central_client::response::{ArtifactResponse, MultipartUploadResponse};
2use burn_central_client::{Client, ClientError};
3use sha2::Digest;
4use std::collections::BTreeMap;
5
6use crate::bundle::{BundleDecode, BundleEncode, InMemoryBundleReader, InMemoryBundleSources};
7use crate::schemas::ExperimentPath;
8use burn_central_client::request::{ArtifactFileSpecRequest, CreateArtifactRequest};
9
10#[derive(Debug, Clone, strum::Display, strum::EnumString)]
11#[strum(serialize_all = "snake_case")]
12pub enum ArtifactKind {
13 Model,
14 Log,
15 Other,
16}
17
18#[derive(Clone)]
20pub struct ExperimentArtifactClient {
21 client: Client,
22 exp_path: ExperimentPath,
23}
24
25impl ExperimentArtifactClient {
26 pub(crate) fn new(client: Client, exp_path: ExperimentPath) -> Self {
27 Self { client, exp_path }
28 }
29
30 pub fn upload<E: BundleEncode>(
32 &self,
33 name: impl Into<String>,
34 kind: ArtifactKind,
35 artifact: E,
36 settings: &E::Settings,
37 ) -> Result<String, ArtifactError> {
38 let name = name.into();
39 let mut sources = InMemoryBundleSources::new();
40 artifact.encode(&mut sources, settings).map_err(|e| {
41 ArtifactError::Encoding(format!("Failed to encode artifact: {}", e.into()))
42 })?;
43
44 let mut specs = Vec::with_capacity(sources.files().len());
45 for f in sources.files() {
46 let (checksum, size) = sha256_and_size_from_bytes(f.source());
47 specs.push(ArtifactFileSpecRequest {
48 rel_path: f.dest_path().to_string(),
49 size_bytes: size,
50 checksum,
51 });
52 }
53
54 let res = self.client.create_artifact(
55 self.exp_path.owner_name(),
56 self.exp_path.project_name(),
57 self.exp_path.experiment_num(),
58 CreateArtifactRequest {
59 name: name.clone(),
60 kind: kind.to_string(),
61 files: specs,
62 },
63 )?;
64
65 let mut multipart_map: BTreeMap<String, &MultipartUploadResponse> = BTreeMap::new();
66 for f in &res.files {
67 multipart_map.insert(f.rel_path.clone(), &f.urls);
68 }
69
70 for f in sources.into_files() {
71 let multipart_info = multipart_map.get(f.dest_path()).ok_or_else(|| {
72 ArtifactError::Internal(format!(
73 "Missing multipart upload info for file {}",
74 f.dest_path()
75 ))
76 })?;
77
78 self.upload_file_multipart(f.source(), multipart_info)?;
79 }
80
81 self.client.complete_artifact_upload(
82 self.exp_path.owner_name(),
83 self.exp_path.project_name(),
84 self.exp_path.experiment_num(),
85 &res.id,
86 None,
87 )?;
88
89 Ok(res.id)
90 }
91
92 pub fn download<D: BundleDecode>(
94 &self,
95 name: impl AsRef<str>,
96 settings: &D::Settings,
97 ) -> Result<D, ArtifactError> {
98 let reader = self.download_raw(name.as_ref())?;
99 D::decode(&reader, settings).map_err(|e| {
100 ArtifactError::Decoding(format!(
101 "Failed to decode artifact {}: {}",
102 name.as_ref(),
103 e.into()
104 ))
105 })
106 }
107
108 pub fn download_raw(
110 &self,
111 name: impl AsRef<str>,
112 ) -> Result<InMemoryBundleReader, ArtifactError> {
113 let name = name.as_ref();
114 let artifact = self.fetch(name)?;
115 let resp = self.client.presign_artifact_download(
116 self.exp_path.owner_name(),
117 self.exp_path.project_name(),
118 self.exp_path.experiment_num(),
119 &artifact.id.to_string(),
120 )?;
121
122 let mut data = BTreeMap::new();
123
124 for file in resp.files {
125 data.insert(
126 file.rel_path.clone(),
127 self.client.download_bytes_from_url(&file.url)?,
128 );
129 }
130
131 Ok(InMemoryBundleReader::new(data))
132 }
133
134 pub fn fetch(&self, name: impl AsRef<str>) -> Result<ArtifactResponse, ArtifactError> {
136 let name = name.as_ref();
137 self.client
138 .list_artifacts_by_name(
139 self.exp_path.owner_name(),
140 self.exp_path.project_name(),
141 self.exp_path.experiment_num(),
142 name,
143 )?
144 .items
145 .into_iter()
146 .next()
147 .ok_or_else(|| ArtifactError::NotFound(name.to_owned()))
148 }
149
150 fn upload_file_multipart(
151 &self,
152 file_data: &[u8],
153 multipart_info: &MultipartUploadResponse,
154 ) -> Result<(), ArtifactError> {
155 let mut part_indices: Vec<usize> = (0..multipart_info.parts.len()).collect();
156 part_indices.sort_by_key(|&i| multipart_info.parts[i].part);
157
158 for (i, &part_idx) in part_indices.iter().enumerate() {
159 let part = &multipart_info.parts[part_idx];
160 if part.part != (i as u32 + 1) {
161 return Err(ArtifactError::Internal(format!(
162 "Invalid part numbering: expected part {}, got part {}",
163 i + 1,
164 part.part
165 )));
166 }
167 }
168
169 let mut current_offset = 0usize;
170 let total_parts = multipart_info.parts.len();
171
172 for (part_index, &part_idx) in part_indices.iter().enumerate() {
173 let part_info = &multipart_info.parts[part_idx];
174 let end_offset = std::cmp::min(
175 current_offset + part_info.size_bytes as usize,
176 file_data.len(),
177 );
178
179 if current_offset >= file_data.len() {
180 break;
181 }
182
183 let part_data = &file_data[current_offset..end_offset];
184
185 self.client
186 .upload_bytes_to_url(&part_info.url, part_data.to_vec())
187 .map_err(|e| {
188 ArtifactError::Internal(format!(
189 "Failed to upload part {} of {}: {}",
190 part_index + 1,
191 total_parts,
192 e
193 ))
194 })?;
195
196 current_offset = end_offset;
197 }
198
199 Ok(())
200 }
201}
202
203fn sha256_and_size_from_bytes(bytes: &[u8]) -> (String, u64) {
204 let mut hasher = sha2::Sha256::new();
205 hasher.update(bytes);
206 let digest = hasher.finalize();
207 (format!("{:x}", digest), bytes.len() as u64)
208}
209
210#[derive(Debug, thiserror::Error)]
211pub enum ArtifactError {
212 #[error("Artifact not found: {0}")]
213 NotFound(String),
214 #[error(transparent)]
215 Client(#[from] ClientError),
216 #[error("Error while encoding artifact: {0}")]
217 Encoding(String),
218 #[error("Error while decoding artifact: {0}")]
219 Decoding(String),
220 #[error("Internal error: {0}")]
221 Internal(String),
222}