burn_central_core/models/
mod.rs1use std::collections::BTreeMap;
2
3use crate::bundle::{BundleDecode, InMemoryBundleReader};
4use burn_central_client::response::{ModelResponse, ModelVersionResponse};
5use burn_central_client::{Client, ClientError};
6
7#[derive(Clone)]
9pub struct ModelRegistry {
10 client: Client,
11}
12
/// Identifies a model by its owner, project, and model names.
///
/// All three components are plain strings, so the type is totally comparable
/// and hashable and can serve as a key in sets and maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelPath {
    /// Name of the owner, passed as the first argument to client lookups.
    owner_name: String,
    /// Name of the project, passed as the second argument to client lookups.
    project_name: String,
    /// Name of the model, passed as the third argument to client lookups.
    model_name: String,
}
19
20impl ModelRegistry {
21 pub fn new(client: Client) -> Self {
22 Self { client }
23 }
24
25 pub fn model(&self, model_path: ModelPath) -> Result<ModelClient, ModelError> {
27 let response = self
28 .client
29 .get_model(
30 &model_path.owner_name,
31 &model_path.project_name,
32 &model_path.model_name,
33 )
34 .map_err(|e| {
35 if matches!(e, ClientError::NotFound) {
36 ModelError::NotFound(format!("Model not found: {:?}", model_path))
37 } else {
38 ModelError::Client(e)
39 }
40 })?;
41
42 Ok(ModelClient::new(self.client.clone(), model_path, response))
43 }
44
45 pub fn download<T: BundleDecode>(
47 &self,
48 model_path: ModelPath,
49 version: u32,
50 settings: &T::Settings,
51 ) -> Result<T, ModelError> {
52 let scope = self.model(model_path)?;
53 scope.download(version, settings)
54 }
55
56 pub fn download_raw(
58 &self,
59 model_path: ModelPath,
60 version: u32,
61 ) -> Result<InMemoryBundleReader, ModelError> {
62 let scope = self.model(model_path)?;
63 scope.download_raw(version)
64 }
65}
66
67#[derive(Clone)]
69pub struct ModelClient {
70 client: Client,
71 model_path: ModelPath,
72 model: ModelResponse,
73}
74
75impl ModelClient {
76 pub(crate) fn new(client: Client, model_path: ModelPath, model: ModelResponse) -> Self {
77 Self {
78 client,
79 model_path,
80 model,
81 }
82 }
83
84 pub fn download<T: BundleDecode>(
87 &self,
88 version: u32,
89 settings: &T::Settings,
90 ) -> Result<T, ModelError> {
91 let reader = self.download_raw(version)?;
92 T::decode(&reader, settings).map_err(|e| {
93 ModelError::Decode(format!(
94 "Failed to decode model {:?}: {}",
95 self.model_path,
96 e.into()
97 ))
98 })
99 }
100
101 pub fn download_raw(&self, version: u32) -> Result<InMemoryBundleReader, ModelError> {
103 let resp = self
104 .client
105 .presign_model_download(
106 &self.model_path.owner_name,
107 &self.model_path.project_name,
108 &self.model_path.model_name,
109 version,
110 )
111 .map_err(|e| {
112 if matches!(e, ClientError::NotFound) {
113 ModelError::VersionNotFound(format!("{:?} v{}", self.model_path, version))
114 } else {
115 ModelError::Client(e)
116 }
117 })?;
118
119 let mut data = BTreeMap::new();
120
121 for file in resp.files {
122 let bytes = self.client.download_bytes_from_url(&file.url)?;
123 data.insert(file.rel_path, bytes);
124 }
125
126 Ok(InMemoryBundleReader::new(data))
127 }
128
129 pub fn fetch(&self, version: u32) -> Result<ModelVersionResponse, ModelError> {
131 self.client
132 .get_model_version(
133 &self.model_path.owner_name,
134 &self.model_path.project_name,
135 &self.model_path.model_name,
136 version,
137 )
138 .map_err(|e| {
139 if matches!(e, ClientError::NotFound) {
140 ModelError::VersionNotFound(format!("{:?} v{}", self.model_path, version))
141 } else {
142 ModelError::Client(e)
143 }
144 })
145 }
146
147 pub fn total_versions(&self) -> u64 {
149 self.model.version_count
150 }
151}
152
153#[derive(Debug, thiserror::Error)]
154pub enum ModelError {
155 #[error("Client error: {0}")]
156 Client(#[from] ClientError),
157 #[error("Decode error: {0}")]
158 Decode(String),
159 #[error("Model not found: {0}")]
160 NotFound(String),
161 #[error("Model version not found: {0}")]
162 VersionNotFound(String),
163}