1use anyhow::anyhow;
10use reqwest::StatusCode;
11use serde::Deserialize;
12use serde_json::Value;
13
14use crate::config::ReplicateConfig;
15use crate::errors::{get_error, ReplicateError, ReplicateResult};
16
17#[derive(Debug, Deserialize)]
18struct ModelVersionError {
19 detail: String,
20}
21
22#[derive(Debug, Deserialize, Clone)]
24pub struct ModelVersion {
25 pub id: String,
27 pub created_at: String,
29 pub cog_version: String,
31 pub openapi_schema: serde_json::Value,
33}
34
35#[derive(Debug, Deserialize)]
37pub struct ModelVersions {
38 pub next: Option<String>,
40 pub previous: Option<String>,
42 pub results: Vec<ModelVersion>,
44}
45
46#[derive(Debug, Deserialize)]
48pub struct Models {
49 pub next: Option<String>,
51 pub previous: Option<String>,
53 pub results: Vec<Model>,
55}
56
57#[derive(Deserialize, Debug)]
59pub struct Model {
60 pub url: String,
62 pub owner: String,
64 pub name: String,
66 pub description: String,
68 pub visibility: String,
70 pub github_url: String,
72 pub paper_url: Option<String>,
74 pub license_url: Option<String>,
76 pub run_count: usize,
78 pub cover_image_url: String,
80 pub default_example: Value,
82 pub latest_version: ModelVersion,
84}
85
86pub struct ModelClient {
88 client: ReplicateConfig,
89}
90
91impl ModelClient {
92 pub fn from(client: ReplicateConfig) -> Self {
94 ModelClient { client }
95 }
96
97 pub async fn get(&self, owner: &str, name: &str) -> anyhow::Result<Model> {
99 let api_key = self.client.get_api_key()?;
100 let base_url = self.client.get_base_url();
101 let endpoint = format!("{base_url}/models/{owner}/{name}");
102 let client = reqwest::Client::new();
103 let response = client
104 .get(endpoint)
105 .header("Authorization", format!("Token {api_key}"))
106 .send()
107 .await?;
108
109 let data = response.text().await?;
110 let model: Model = serde_json::from_str(&data)?;
111 anyhow::Ok(model)
112 }
113
114 pub async fn get_specific_version(
116 &self,
117 owner: &str,
118 name: &str,
119 version_id: &str,
120 ) -> ReplicateResult<Model> {
121 let api_key = self.client.get_api_key()?;
122 let base_url = self.client.get_base_url();
123 let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}");
124 let client = reqwest::Client::new();
125 let response = client
126 .get(endpoint)
127 .header("Authorization", format!("Token {api_key}"))
128 .send()
129 .await
130 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
131
132 let data = response
133 .text()
134 .await
135 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
136 let model: Model = serde_json::from_str(&data)
137 .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
138 Ok(model)
139 }
140
141 pub async fn delete_version(
143 &self,
144 owner: &str,
145 name: &str,
146 version_id: &str,
147 ) -> ReplicateResult<()> {
148 let api_key = self.client.get_api_key()?;
149 let base_url = self.client.get_base_url();
150 let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}");
151 let client = reqwest::Client::new();
152 let response = client
153 .delete(endpoint)
154 .header("Authorization", format!("Token {api_key}"))
155 .send()
156 .await
157 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
158
159 if response.status().is_success() {
160 Ok(())
161 } else {
162 Err(ReplicateError::Misc("delete request failed".to_string()))
163 }
164 }
165
166 pub async fn get_latest_version(
168 &self,
169 owner: &str,
170 name: &str,
171 ) -> ReplicateResult<ModelVersion> {
172 let all_versions = self.list_versions(owner, name).await?;
173 let latest_version = all_versions.results.get(0).ok_or(ReplicateError::Misc(
174 "no versions found for {owner}/{name}".to_string(),
175 ))?;
176 Ok(latest_version.clone())
177 }
178
179 pub async fn list_versions(&self, owner: &str, name: &str) -> ReplicateResult<ModelVersions> {
181 let base_url = self.client.get_base_url();
182 let api_key = self.client.get_api_key()?;
183 let endpoint = format!("{base_url}/models/{owner}/{name}/versions");
184 let client = reqwest::Client::new();
185 let response = client
186 .get(endpoint)
187 .header("Authorization", format!("Token {api_key}"))
188 .send()
189 .await
190 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
191
192 let status = response.status();
193 let data = response
194 .text()
195 .await
196 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
197
198 return match status.clone() {
199 reqwest::StatusCode::OK => {
200 let data: ModelVersions = serde_json::from_str(&data)
201 .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
202 Ok(data)
203 }
204 _ => Err(get_error(status, data.as_str())),
205 };
206 }
207
208 pub async fn get_models(&self) -> ReplicateResult<Models> {
210 let base_url = self.client.get_base_url();
211 let api_key = self.client.get_api_key()?;
212 let endpoint = format!("{base_url}/models");
213 let client = reqwest::Client::new();
214 let response = client
215 .get(endpoint)
216 .header("Authorization", format!("Token {api_key}"))
217 .send()
218 .await
219 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
220
221 let data = response
222 .text()
223 .await
224 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
225 let models: Models = serde_json::from_str(&data)
226 .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
227 Ok(models)
228 }
229}
230
231#[cfg(test)]
232mod tests {
233 use super::*;
234 use httpmock::prelude::*;
235 use serde_json::json;
236
237 #[tokio::test]
238 async fn test_get_model() {
239 let mock_server = MockServer::start();
240
241 let model_mock = mock_server.mock(|when, then| {
242 when.method(GET).path("/models/replicate/hello-world");
243 then.status(200).json_body_obj(&json!({
244 "url": "https://replicate.com/replicate/hello-world",
245 "owner": "replicate",
246 "name": "hello-world",
247 "description": "A tiny model that says hello",
248 "visibility": "public",
249 "github_url": "https://github.com/replicate/cog-examples",
250 "paper_url": null,
251 "license_url": null,
252 "run_count": 5681081,
253 "cover_image_url": "...",
254 "default_example": null,
255 "latest_version": {
256 "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
257 "created_at": "2022-04-26T19:29:04.418669Z",
258 "cog_version": "0.3.0",
259 "openapi_schema": {}
260 }
261 }));
262 });
263
264 let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
265 let model_client = ModelClient::from(client);
266 model_client.get("replicate", "hello-world").await.unwrap();
267
268 model_mock.assert();
269 }
270
271 #[tokio::test]
272 async fn test_get_specific_version() {
273 let mock_server = MockServer::start();
274
275 let model_mock = mock_server.mock(|when, then| {
276 when.method(GET)
277 .path("/models/replicate/hello-world/versions/1234");
278 then.status(200).json_body_obj(&json!({
279 "url": "https://replicate.com/replicate/hello-world",
280 "owner": "replicate",
281 "name": "hello-world",
282 "description": "A tiny model that says hello",
283 "visibility": "public",
284 "github_url": "https://github.com/replicate/cog-examples",
285 "paper_url": null,
286 "license_url": null,
287 "run_count": 5681081,
288 "cover_image_url": "...",
289 "default_example": null,
290 "latest_version": {
291 "id": "1234",
292 "created_at": "2022-04-26T19:29:04.418669Z",
293 "cog_version": "0.3.0",
294 "openapi_schema": {}
295 }
296 }));
297 });
298
299 let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
300 let model_client = ModelClient::from(client);
301 model_client
302 .get_specific_version("replicate", "hello-world", "1234")
303 .await
304 .unwrap();
305
306 model_mock.assert();
307 }
308 #[tokio::test]
309 async fn test_list_model_versions() {
310 let mock_server = MockServer::start();
311
312 let model_mock = mock_server.mock(|when, then| {
314 when.method(GET)
315 .path("/models/replicate/hello-world/versions");
316
317 then.status(200).json_body_obj(&json!({
318 "next": null,
319 "previous": null,
320 "results": [{
321 "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
322 "created_at": "2022-04-26T19:29:04.418669Z",
323 "cog_version": "0.3.0",
324 "openapi_schema": null
325 }]
326 }));
327 });
328
329 let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
330 let model_client = ModelClient::from(client);
331 model_client
332 .list_versions("replicate", "hello-world")
333 .await
334 .unwrap();
335
336 model_mock.assert();
337 }
338
339 #[tokio::test]
340 async fn test_get_latest_version() {
341 let mock_server = MockServer::start();
342
343 let model_mock = mock_server.mock(|when, then| {
345 when.method(GET)
346 .path("/models/replicate/hello-world/versions");
347
348 then.status(200).json_body_obj(&json!({
349 "next": null,
350 "previous": null,
351 "results": [{
352 "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
353 "created_at": "2022-04-26T19:29:04.418669Z",
354 "cog_version": "0.3.0",
355 "openapi_schema": null
356 }]
357 }));
358 });
359
360 let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
361 let model_client = ModelClient::from(client);
362 model_client
363 .get_latest_version("replicate", "hello-world")
364 .await
365 .unwrap();
366
367 model_mock.assert();
368 }
369
370 #[tokio::test]
371 async fn test_get_models() {
372 let mock_server = MockServer::start();
373
374 let model_mock = mock_server.mock(|when, then| {
376 when.method(GET).path("/models");
377 then.status(200).json_body_obj(&json!({
378 "next": "some pagination string or null",
379 "previous": "some pagination string or null",
380 "results": [
381 {
382 "url": "https://modelhomepage.example.com",
383 "owner": "jdoe",
384 "name": "super-cool-model",
385 "description": "A model that predicts something very cool.",
386 "visibility": "public",
387 "github_url": "https://github.com/jdoe/super-cool-model",
388 "paper_url": "https://research.example.com/super-cool-model-paper.pdf",
389 "license_url": null,
390 "run_count": 420,
391 "cover_image_url": "https://cdn.example.com/images/super-cool-model-cover.jpg",
392 "default_example": {
393 "input": "Example input data for the model."
394 },
395 "latest_version": {
396 "id": "v1.0.0",
397 "created_at": "2022-01-01T12:00:00Z",
398 "cog_version": "0.2",
399 "openapi_schema": null
400 }
401 },
402 {
403 "url": "https://anothermodelhomepage.example.com",
404 "owner": "asmith",
405 "name": "another-awesome-model",
406 "description": "This model does awesome things with data.",
407 "visibility": "private",
408 "github_url": "https://github.com/asmith/another-awesome-model",
409 "paper_url": null,
410 "license_url": "https://licenses.example.com/another-awesome-model-license.txt",
411 "run_count": 150,
412 "cover_image_url": "https://cdn.example.com/images/another-awesome-model-cover.jpg",
413 "default_example": {
414 "input": "Some example input for this awesome model."
415 },
416 "latest_version": {
417 "id": "v1.2.3",
418 "created_at": "2023-02-15T08:30:00Z",
419 "cog_version": "0.2",
420 "openapi_schema": null
421 }
422 }
423 ]}));
424 });
425
426 let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
427 let model_client = ModelClient::from(client);
428 model_client.get_models().await.unwrap();
429
430 model_mock.assert();
431 }
432}