// burn_central_core/artifacts/client.rs

1use burn_central_client::response::MultipartUploadResponse;
2use burn_central_client::{Client, ClientError};
3use sha2::Digest;
4use std::collections::BTreeMap;
5
6use crate::artifacts::ArtifactInfo;
7use crate::bundle::{BundleDecode, BundleEncode, InMemoryBundleReader, InMemoryBundleSources};
8use crate::schemas::ExperimentPath;
9use burn_central_client::request::{ArtifactFileSpecRequest, CreateArtifactRequest};
10
11#[derive(Debug, Clone, strum::Display, strum::EnumString)]
12#[strum(serialize_all = "snake_case")]
13pub enum ArtifactKind {
14    Model,
15    Log,
16    Other,
17}
18
19/// A scope for artifact operations within a specific experiment
20#[derive(Clone)]
21pub struct ExperimentArtifactClient {
22    client: Client,
23    exp_path: ExperimentPath,
24}
25
26impl ExperimentArtifactClient {
27    pub(crate) fn new(client: Client, exp_path: ExperimentPath) -> Self {
28        Self { client, exp_path }
29    }
30
31    /// Upload an artifact using the BundleEncode trait
32    pub fn upload<E: BundleEncode>(
33        &self,
34        name: impl Into<String>,
35        kind: ArtifactKind,
36        artifact: E,
37        settings: &E::Settings,
38    ) -> Result<String, ArtifactError> {
39        let name = name.into();
40        let mut sources = InMemoryBundleSources::new();
41        artifact.encode(&mut sources, settings).map_err(|e| {
42            ArtifactError::Encoding(format!("Failed to encode artifact: {}", e.into()))
43        })?;
44
45        let mut specs = Vec::with_capacity(sources.files().len());
46        for f in sources.files() {
47            let (checksum, size) = sha256_and_size_from_bytes(f.source());
48            specs.push(ArtifactFileSpecRequest {
49                rel_path: f.dest_path().to_string(),
50                size_bytes: size,
51                checksum,
52            });
53        }
54
55        let res = self.client.create_artifact(
56            self.exp_path.owner_name(),
57            self.exp_path.project_name(),
58            self.exp_path.experiment_num(),
59            CreateArtifactRequest {
60                name: name.clone(),
61                kind: kind.to_string(),
62                files: specs,
63            },
64        )?;
65
66        let mut multipart_map: BTreeMap<String, &MultipartUploadResponse> = BTreeMap::new();
67        for f in &res.files {
68            multipart_map.insert(f.rel_path.clone(), &f.urls);
69        }
70
71        for f in sources.into_files() {
72            let multipart_info = multipart_map.get(f.dest_path()).ok_or_else(|| {
73                ArtifactError::Internal(format!(
74                    "Missing multipart upload info for file {}",
75                    f.dest_path()
76                ))
77            })?;
78
79            self.upload_file_multipart(f.source(), multipart_info)?;
80        }
81
82        self.client.complete_artifact_upload(
83            self.exp_path.owner_name(),
84            self.exp_path.project_name(),
85            self.exp_path.experiment_num(),
86            &res.id,
87            None,
88        )?;
89
90        Ok(res.id)
91    }
92
93    /// Download an artifact and decode it using the BundleDecode trait
94    pub fn download<D: BundleDecode>(
95        &self,
96        name: impl AsRef<str>,
97        settings: &D::Settings,
98    ) -> Result<D, ArtifactError> {
99        let reader = self.download_raw(name.as_ref())?;
100        D::decode(&reader, settings).map_err(|e| {
101            ArtifactError::Decoding(format!(
102                "Failed to decode artifact {}: {}",
103                name.as_ref(),
104                e.into()
105            ))
106        })
107    }
108
109    /// Download an artifact as a raw memory bundle reader
110    pub fn download_raw(
111        &self,
112        name: impl AsRef<str>,
113    ) -> Result<InMemoryBundleReader, ArtifactError> {
114        let name = name.as_ref();
115        let artifact = self.fetch(name)?;
116        let resp = self.client.presign_artifact_download(
117            self.exp_path.owner_name(),
118            self.exp_path.project_name(),
119            self.exp_path.experiment_num(),
120            &artifact.id.to_string(),
121        )?;
122
123        let mut data = BTreeMap::new();
124
125        for file in resp.files {
126            data.insert(
127                file.rel_path.clone(),
128                self.client.download_bytes_from_url(&file.url)?,
129            );
130        }
131
132        Ok(InMemoryBundleReader::new(data))
133    }
134
135    /// Fetch information about an artifact by name
136    pub fn fetch(&self, name: impl AsRef<str>) -> Result<ArtifactInfo, ArtifactError> {
137        let name = name.as_ref();
138        let artifact_resp = self
139            .client
140            .list_artifacts_by_name(
141                self.exp_path.owner_name(),
142                self.exp_path.project_name(),
143                self.exp_path.experiment_num(),
144                name,
145            )?
146            .items
147            .into_iter()
148            .next()
149            .ok_or_else(|| ArtifactError::NotFound(name.to_owned()))?;
150
151        Ok(artifact_resp.into())
152    }
153
154    fn upload_file_multipart(
155        &self,
156        file_data: &[u8],
157        multipart_info: &MultipartUploadResponse,
158    ) -> Result<(), ArtifactError> {
159        let mut part_indices: Vec<usize> = (0..multipart_info.parts.len()).collect();
160        part_indices.sort_by_key(|&i| multipart_info.parts[i].part);
161
162        for (i, &part_idx) in part_indices.iter().enumerate() {
163            let part = &multipart_info.parts[part_idx];
164            if part.part != (i as u32 + 1) {
165                return Err(ArtifactError::Internal(format!(
166                    "Invalid part numbering: expected part {}, got part {}",
167                    i + 1,
168                    part.part
169                )));
170            }
171        }
172
173        let mut current_offset = 0usize;
174        let total_parts = multipart_info.parts.len();
175
176        for (part_index, &part_idx) in part_indices.iter().enumerate() {
177            let part_info = &multipart_info.parts[part_idx];
178            let end_offset = std::cmp::min(
179                current_offset + part_info.size_bytes as usize,
180                file_data.len(),
181            );
182
183            if current_offset >= file_data.len() {
184                break;
185            }
186
187            let part_data = &file_data[current_offset..end_offset];
188
189            self.client
190                .upload_bytes_to_url(&part_info.url, part_data.to_vec())
191                .map_err(|e| {
192                    ArtifactError::Internal(format!(
193                        "Failed to upload part {} of {}: {}",
194                        part_index + 1,
195                        total_parts,
196                        e
197                    ))
198                })?;
199
200            current_offset = end_offset;
201        }
202
203        Ok(())
204    }
205}
206
207fn sha256_and_size_from_bytes(bytes: &[u8]) -> (String, u64) {
208    let mut hasher = sha2::Sha256::new();
209    hasher.update(bytes);
210    let digest = hasher.finalize();
211    (format!("{:x}", digest), bytes.len() as u64)
212}
213
214#[derive(Debug, thiserror::Error)]
215pub enum ArtifactError {
216    #[error("Artifact not found: {0}")]
217    NotFound(String),
218    #[error(transparent)]
219    Client(#[from] ClientError),
220    #[error("Error while encoding artifact: {0}")]
221    Encoding(String),
222    #[error("Error while decoding artifact: {0}")]
223    Decoding(String),
224    #[error("Internal error: {0}")]
225    Internal(String),
226}