Skip to main content

burn_central_core/artifacts/
mod.rs

1use burn_central_client::response::{ArtifactResponse, MultipartUploadResponse};
2use burn_central_client::{Client, ClientError};
3use sha2::Digest;
4use std::collections::BTreeMap;
5
6use crate::bundle::{BundleDecode, BundleEncode, InMemoryBundleReader, InMemoryBundleSources};
7use crate::schemas::ExperimentPath;
8use burn_central_client::request::{ArtifactFileSpecRequest, CreateArtifactRequest};
9
10#[derive(Debug, Clone, strum::Display, strum::EnumString)]
11#[strum(serialize_all = "snake_case")]
12pub enum ArtifactKind {
13    Model,
14    Log,
15    Other,
16}
17
/// A scope for artifact operations within a specific experiment
///
/// Bundles an API [`Client`] with the [`ExperimentPath`] of one experiment so
/// that every artifact call is issued against that experiment's owner, project
/// and experiment number.
#[derive(Clone)]
pub struct ExperimentArtifactClient {
    /// API client used for every request issued by this scope.
    client: Client,
    /// Path of the experiment whose artifacts are managed.
    exp_path: ExperimentPath,
}
24
25impl ExperimentArtifactClient {
26    pub(crate) fn new(client: Client, exp_path: ExperimentPath) -> Self {
27        Self { client, exp_path }
28    }
29
30    /// Upload an artifact using the BundleEncode trait
31    pub fn upload<E: BundleEncode>(
32        &self,
33        name: impl Into<String>,
34        kind: ArtifactKind,
35        artifact: E,
36        settings: &E::Settings,
37    ) -> Result<String, ArtifactError> {
38        let name = name.into();
39        let mut sources = InMemoryBundleSources::new();
40        artifact.encode(&mut sources, settings).map_err(|e| {
41            ArtifactError::Encoding(format!("Failed to encode artifact: {}", e.into()))
42        })?;
43
44        let mut specs = Vec::with_capacity(sources.files().len());
45        for f in sources.files() {
46            let (checksum, size) = sha256_and_size_from_bytes(f.source());
47            specs.push(ArtifactFileSpecRequest {
48                rel_path: f.dest_path().to_string(),
49                size_bytes: size,
50                checksum,
51            });
52        }
53
54        let res = self.client.create_artifact(
55            self.exp_path.owner_name(),
56            self.exp_path.project_name(),
57            self.exp_path.experiment_num(),
58            CreateArtifactRequest {
59                name: name.clone(),
60                kind: kind.to_string(),
61                files: specs,
62            },
63        )?;
64
65        let mut multipart_map: BTreeMap<String, &MultipartUploadResponse> = BTreeMap::new();
66        for f in &res.files {
67            multipart_map.insert(f.rel_path.clone(), &f.urls);
68        }
69
70        for f in sources.into_files() {
71            let multipart_info = multipart_map.get(f.dest_path()).ok_or_else(|| {
72                ArtifactError::Internal(format!(
73                    "Missing multipart upload info for file {}",
74                    f.dest_path()
75                ))
76            })?;
77
78            self.upload_file_multipart(f.source(), multipart_info)?;
79        }
80
81        self.client.complete_artifact_upload(
82            self.exp_path.owner_name(),
83            self.exp_path.project_name(),
84            self.exp_path.experiment_num(),
85            &res.id,
86            None,
87        )?;
88
89        Ok(res.id)
90    }
91
92    /// Download an artifact and decode it using the BundleDecode trait
93    pub fn download<D: BundleDecode>(
94        &self,
95        name: impl AsRef<str>,
96        settings: &D::Settings,
97    ) -> Result<D, ArtifactError> {
98        let reader = self.download_raw(name.as_ref())?;
99        D::decode(&reader, settings).map_err(|e| {
100            ArtifactError::Decoding(format!(
101                "Failed to decode artifact {}: {}",
102                name.as_ref(),
103                e.into()
104            ))
105        })
106    }
107
108    /// Download an artifact as a raw memory bundle reader
109    pub fn download_raw(
110        &self,
111        name: impl AsRef<str>,
112    ) -> Result<InMemoryBundleReader, ArtifactError> {
113        let name = name.as_ref();
114        let artifact = self.fetch(name)?;
115        let resp = self.client.presign_artifact_download(
116            self.exp_path.owner_name(),
117            self.exp_path.project_name(),
118            self.exp_path.experiment_num(),
119            &artifact.id.to_string(),
120        )?;
121
122        let mut data = BTreeMap::new();
123
124        for file in resp.files {
125            data.insert(
126                file.rel_path.clone(),
127                self.client.download_bytes_from_url(&file.url)?,
128            );
129        }
130
131        Ok(InMemoryBundleReader::new(data))
132    }
133
134    /// Fetch information about an artifact by name
135    pub fn fetch(&self, name: impl AsRef<str>) -> Result<ArtifactResponse, ArtifactError> {
136        let name = name.as_ref();
137        self.client
138            .list_artifacts_by_name(
139                self.exp_path.owner_name(),
140                self.exp_path.project_name(),
141                self.exp_path.experiment_num(),
142                name,
143            )?
144            .items
145            .into_iter()
146            .next()
147            .ok_or_else(|| ArtifactError::NotFound(name.to_owned()))
148    }
149
150    fn upload_file_multipart(
151        &self,
152        file_data: &[u8],
153        multipart_info: &MultipartUploadResponse,
154    ) -> Result<(), ArtifactError> {
155        let mut part_indices: Vec<usize> = (0..multipart_info.parts.len()).collect();
156        part_indices.sort_by_key(|&i| multipart_info.parts[i].part);
157
158        for (i, &part_idx) in part_indices.iter().enumerate() {
159            let part = &multipart_info.parts[part_idx];
160            if part.part != (i as u32 + 1) {
161                return Err(ArtifactError::Internal(format!(
162                    "Invalid part numbering: expected part {}, got part {}",
163                    i + 1,
164                    part.part
165                )));
166            }
167        }
168
169        let mut current_offset = 0usize;
170        let total_parts = multipart_info.parts.len();
171
172        for (part_index, &part_idx) in part_indices.iter().enumerate() {
173            let part_info = &multipart_info.parts[part_idx];
174            let end_offset = std::cmp::min(
175                current_offset + part_info.size_bytes as usize,
176                file_data.len(),
177            );
178
179            if current_offset >= file_data.len() {
180                break;
181            }
182
183            let part_data = &file_data[current_offset..end_offset];
184
185            self.client
186                .upload_bytes_to_url(&part_info.url, part_data.to_vec())
187                .map_err(|e| {
188                    ArtifactError::Internal(format!(
189                        "Failed to upload part {} of {}: {}",
190                        part_index + 1,
191                        total_parts,
192                        e
193                    ))
194                })?;
195
196            current_offset = end_offset;
197        }
198
199        Ok(())
200    }
201}
202
203fn sha256_and_size_from_bytes(bytes: &[u8]) -> (String, u64) {
204    let mut hasher = sha2::Sha256::new();
205    hasher.update(bytes);
206    let digest = hasher.finalize();
207    (format!("{:x}", digest), bytes.len() as u64)
208}
209
/// Errors returned by artifact operations.
#[derive(Debug, thiserror::Error)]
pub enum ArtifactError {
    /// No artifact with the given name was found; carries the requested name.
    #[error("Artifact not found: {0}")]
    NotFound(String),
    /// An error reported by the underlying API client, forwarded transparently.
    #[error(transparent)]
    Client(#[from] ClientError),
    /// Encoding the artifact into bundle files failed; carries the message.
    #[error("Error while encoding artifact: {0}")]
    Encoding(String),
    /// Decoding the downloaded bundle into the target type failed; carries the message.
    #[error("Error while decoding artifact: {0}")]
    Decoding(String),
    /// An inconsistency or failure during upload (missing multipart info,
    /// invalid part numbering, failed part upload); carries the message.
    #[error("Internal error: {0}")]
    Internal(String),
}