// burn_central_core/models/mod.rs

1use std::collections::BTreeMap;
2
3use crate::bundle::{BundleDecode, InMemoryBundleReader};
4use burn_central_client::response::{ModelResponse, ModelVersionResponse};
5use burn_central_client::{Client, ClientError};
6
7/// An interface for downloading models from Burn Central.
8#[derive(Clone)]
9pub struct ModelRegistry {
10    client: Client,
11}
12
/// Fully-qualified location of a model on Burn Central: owner, project and model name.
///
/// Build one with [`ModelPath::new`]. Its [`Display`](std::fmt::Display) form is
/// `owner/project/model`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelPath {
    owner_name: String,
    project_name: String,
    model_name: String,
}

impl ModelPath {
    /// Create a path from the owner, project and model names.
    ///
    /// Without this constructor the type could not be built outside this module (its fields are
    /// private), even though every [`ModelRegistry`] method requires one.
    pub fn new(
        owner_name: impl Into<String>,
        project_name: impl Into<String>,
        model_name: impl Into<String>,
    ) -> Self {
        Self {
            owner_name: owner_name.into(),
            project_name: project_name.into(),
            model_name: model_name.into(),
        }
    }

    /// Name of the owner of the project that contains the model.
    pub fn owner_name(&self) -> &str {
        &self.owner_name
    }

    /// Name of the project that contains the model.
    pub fn project_name(&self) -> &str {
        &self.project_name
    }

    /// Name of the model inside its project.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
}

impl std::fmt::Display for ModelPath {
    /// Formats the path as `owner/project/model`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}/{}/{}",
            self.owner_name, self.project_name, self.model_name
        )
    }
}
19
20impl ModelRegistry {
21    pub fn new(client: Client) -> Self {
22        Self { client }
23    }
24
25    /// Create a scope for a specific model within a project.
26    pub fn model(&self, model_path: ModelPath) -> Result<ModelClient, ModelError> {
27        let response = self
28            .client
29            .get_model(
30                &model_path.owner_name,
31                &model_path.project_name,
32                &model_path.model_name,
33            )
34            .map_err(|e| {
35                if matches!(e, ClientError::NotFound) {
36                    ModelError::NotFound(format!("Model not found: {:?}", model_path))
37                } else {
38                    ModelError::Client(e)
39                }
40            })?;
41
42        Ok(ModelClient::new(self.client.clone(), model_path, response))
43    }
44
45    /// Download a specific model version and decode it using the BundleDecode trait.
46    pub fn download<T: BundleDecode>(
47        &self,
48        model_path: ModelPath,
49        version: u32,
50        settings: &T::Settings,
51    ) -> Result<T, ModelError> {
52        let scope = self.model(model_path)?;
53        scope.download(version, settings)
54    }
55
56    /// Download a specific model version as a memory reader for dynamic access.
57    pub fn download_raw(
58        &self,
59        model_path: ModelPath,
60        version: u32,
61    ) -> Result<InMemoryBundleReader, ModelError> {
62        let scope = self.model(model_path)?;
63        scope.download_raw(version)
64    }
65}
66
67/// A scope for operations on a specific model within a project.
68#[derive(Clone)]
69pub struct ModelClient {
70    client: Client,
71    model_path: ModelPath,
72    model: ModelResponse,
73}
74
75impl ModelClient {
76    pub(crate) fn new(client: Client, model_path: ModelPath, model: ModelResponse) -> Self {
77        Self {
78            client,
79            model_path,
80            model,
81        }
82    }
83
84    /// Download a specific version of this model and decode it using the BundleDecode trait.
85    /// This allows reusing existing bundle decoders for models.
86    pub fn download<T: BundleDecode>(
87        &self,
88        version: u32,
89        settings: &T::Settings,
90    ) -> Result<T, ModelError> {
91        let reader = self.download_raw(version)?;
92        T::decode(&reader, settings).map_err(|e| {
93            ModelError::Decode(format!(
94                "Failed to decode model {:?}: {}",
95                self.model_path,
96                e.into()
97            ))
98        })
99    }
100
101    /// Download a specific version of this model as a memory reader for dynamic access.
102    pub fn download_raw(&self, version: u32) -> Result<InMemoryBundleReader, ModelError> {
103        let resp = self
104            .client
105            .presign_model_download(
106                &self.model_path.owner_name,
107                &self.model_path.project_name,
108                &self.model_path.model_name,
109                version,
110            )
111            .map_err(|e| {
112                if matches!(e, ClientError::NotFound) {
113                    ModelError::VersionNotFound(format!("{:?} v{}", self.model_path, version))
114                } else {
115                    ModelError::Client(e)
116                }
117            })?;
118
119        let mut data = BTreeMap::new();
120
121        for file in resp.files {
122            let bytes = self.client.download_bytes_from_url(&file.url)?;
123            data.insert(file.rel_path, bytes);
124        }
125
126        Ok(InMemoryBundleReader::new(data))
127    }
128
129    /// Get information about a specific model version.
130    pub fn fetch(&self, version: u32) -> Result<ModelVersionResponse, ModelError> {
131        self.client
132            .get_model_version(
133                &self.model_path.owner_name,
134                &self.model_path.project_name,
135                &self.model_path.model_name,
136                version,
137            )
138            .map_err(|e| {
139                if matches!(e, ClientError::NotFound) {
140                    ModelError::VersionNotFound(format!("{:?} v{}", self.model_path, version))
141                } else {
142                    ModelError::Client(e)
143                }
144            })
145    }
146
147    /// Get the total number of versions available for this model.
148    pub fn total_versions(&self) -> u64 {
149        self.model.version_count
150    }
151}
152
153#[derive(Debug, thiserror::Error)]
154pub enum ModelError {
155    #[error("Client error: {0}")]
156    Client(#[from] ClientError),
157    #[error("Decode error: {0}")]
158    Decode(String),
159    #[error("Model not found: {0}")]
160    NotFound(String),
161    #[error("Model version not found: {0}")]
162    VersionNotFound(String),
163}