replicate_rs/
models.rs

//! Utilities for interacting with the Replicate `models` endpoints.
//!
//! This includes the following:
//! - [Get a Model](https://replicate.com/docs/reference/http#models.get)
//! - [Get a Model Version](https://replicate.com/docs/reference/http#models.versions.get)
//! - [List a Model's Versions](https://replicate.com/docs/reference/http#models.versions.list)
//! - [Delete a Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
//! - [List all Public Models](https://replicate.com/docs/reference/http#models.list)
//!
9use anyhow::anyhow;
10use reqwest::StatusCode;
11use serde::Deserialize;
12use serde_json::Value;
13
14use crate::config::ReplicateConfig;
15use crate::errors::{get_error, ReplicateError, ReplicateResult};
16
17#[derive(Debug, Deserialize)]
18struct ModelVersionError {
19    detail: String,
20}
21
22/// Version details for a particular model
23#[derive(Debug, Deserialize, Clone)]
24pub struct ModelVersion {
25    /// Id of the model
26    pub id: String,
27    /// Time in which the model was created
28    pub created_at: String,
29    /// Version of cog used to create the model
30    pub cog_version: String,
31    /// OpenAPI Schema of model input and outputs
32    pub openapi_schema: serde_json::Value,
33}
34
35/// Paginated view of all versions for a particular model
36#[derive(Debug, Deserialize)]
37pub struct ModelVersions {
38    /// Place in pagination
39    pub next: Option<String>,
40    /// Place in pagination
41    pub previous: Option<String>,
42    /// List of all versions available
43    pub results: Vec<ModelVersion>,
44}
45
46/// Paginated view of all available models
47#[derive(Debug, Deserialize)]
48pub struct Models {
49    /// Place in pagination
50    pub next: Option<String>,
51    /// Place in pagination
52    pub previous: Option<String>,
53    /// List of all versions available
54    pub results: Vec<Model>,
55}
56
57/// All details available for a particular Model
58#[derive(Deserialize, Debug)]
59pub struct Model {
60    /// URL for model homepage
61    pub url: String,
62    /// The owner of the model
63    pub owner: String,
64    /// The name of the model
65    pub name: String,
66    /// A brief description of the model
67    pub description: String,
68    /// Whether the model is public or private
69    pub visibility: String,
70    /// Github URL for the associated repo
71    pub github_url: String,
72    /// Url for an associated paper
73    pub paper_url: Option<String>,
74    /// Url for the model's license
75    pub license_url: Option<String>,
76    /// How many times the model has been run
77    pub run_count: usize,
78    /// Image URL to show on Replicate's Model page
79    pub cover_image_url: String,
80    /// A simple example to show model's use
81    pub default_example: Value,
82    /// The latest version's details
83    pub latest_version: ModelVersion,
84}
85
86/// A client for interacting with `models` endpoints
87pub struct ModelClient {
88    client: ReplicateConfig,
89}
90
91impl ModelClient {
92    /// Create a new `ModelClient` based upon a `ReplicateConfig` object
93    pub fn from(client: ReplicateConfig) -> Self {
94        ModelClient { client }
95    }
96
97    /// Retrieve details for a specific model
98    pub async fn get(&self, owner: &str, name: &str) -> anyhow::Result<Model> {
99        let api_key = self.client.get_api_key()?;
100        let base_url = self.client.get_base_url();
101        let endpoint = format!("{base_url}/models/{owner}/{name}");
102        let client = reqwest::Client::new();
103        let response = client
104            .get(endpoint)
105            .header("Authorization", format!("Token {api_key}"))
106            .send()
107            .await?;
108
109        let data = response.text().await?;
110        let model: Model = serde_json::from_str(&data)?;
111        anyhow::Ok(model)
112    }
113
114    /// Retrieve details for a specific model's version
115    pub async fn get_specific_version(
116        &self,
117        owner: &str,
118        name: &str,
119        version_id: &str,
120    ) -> ReplicateResult<Model> {
121        let api_key = self.client.get_api_key()?;
122        let base_url = self.client.get_base_url();
123        let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}");
124        let client = reqwest::Client::new();
125        let response = client
126            .get(endpoint)
127            .header("Authorization", format!("Token {api_key}"))
128            .send()
129            .await
130            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
131
132        let data = response
133            .text()
134            .await
135            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
136        let model: Model = serde_json::from_str(&data)
137            .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
138        Ok(model)
139    }
140
141    /// Delete specific model version
142    pub async fn delete_version(
143        &self,
144        owner: &str,
145        name: &str,
146        version_id: &str,
147    ) -> ReplicateResult<()> {
148        let api_key = self.client.get_api_key()?;
149        let base_url = self.client.get_base_url();
150        let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}");
151        let client = reqwest::Client::new();
152        let response = client
153            .delete(endpoint)
154            .header("Authorization", format!("Token {api_key}"))
155            .send()
156            .await
157            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
158
159        if response.status().is_success() {
160            Ok(())
161        } else {
162            Err(ReplicateError::Misc("delete request failed".to_string()))
163        }
164    }
165
166    /// Retrieve details for latest version of a specific model
167    pub async fn get_latest_version(
168        &self,
169        owner: &str,
170        name: &str,
171    ) -> ReplicateResult<ModelVersion> {
172        let all_versions = self.list_versions(owner, name).await?;
173        let latest_version = all_versions.results.get(0).ok_or(ReplicateError::Misc(
174            "no versions found for {owner}/{name}".to_string(),
175        ))?;
176        Ok(latest_version.clone())
177    }
178
179    /// Retrieve list of all available versions of a specific model
180    pub async fn list_versions(&self, owner: &str, name: &str) -> ReplicateResult<ModelVersions> {
181        let base_url = self.client.get_base_url();
182        let api_key = self.client.get_api_key()?;
183        let endpoint = format!("{base_url}/models/{owner}/{name}/versions");
184        let client = reqwest::Client::new();
185        let response = client
186            .get(endpoint)
187            .header("Authorization", format!("Token {api_key}"))
188            .send()
189            .await
190            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
191
192        let status = response.status();
193        let data = response
194            .text()
195            .await
196            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
197
198        return match status.clone() {
199            reqwest::StatusCode::OK => {
200                let data: ModelVersions = serde_json::from_str(&data)
201                    .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
202                Ok(data)
203            }
204            _ => Err(get_error(status, data.as_str())),
205        };
206    }
207
208    /// Retrieve all publically and private available models
209    pub async fn get_models(&self) -> ReplicateResult<Models> {
210        let base_url = self.client.get_base_url();
211        let api_key = self.client.get_api_key()?;
212        let endpoint = format!("{base_url}/models");
213        let client = reqwest::Client::new();
214        let response = client
215            .get(endpoint)
216            .header("Authorization", format!("Token {api_key}"))
217            .send()
218            .await
219            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
220
221        let data = response
222            .text()
223            .await
224            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
225        let models: Models = serde_json::from_str(&data)
226            .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
227        Ok(models)
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use serde_json::json;

    /// A versions listing for `replicate/hello-world` containing a single version.
    fn single_version_listing() -> serde_json::Value {
        json!({
            "next": null,
            "previous": null,
            "results": [{
                "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
                "created_at": "2022-04-26T19:29:04.418669Z",
                "cog_version": "0.3.0",
                "openapi_schema": null
            }]
        })
    }

    #[tokio::test]
    async fn test_get_model() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/hello-world");
            then.status(200).json_body_obj(&json!({
                "url": "https://replicate.com/replicate/hello-world",
                "owner": "replicate",
                "name": "hello-world",
                "description": "A tiny model that says hello",
                "visibility": "public",
                "github_url": "https://github.com/replicate/cog-examples",
                "paper_url": null,
                "license_url": null,
                "run_count": 5681081,
                "cover_image_url": "...",
                "default_example": null,
                "latest_version": {
                    "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
                    "created_at": "2022-04-26T19:29:04.418669Z",
                    "cog_version": "0.3.0",
                    "openapi_schema": {}
                }
            }));
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let model = model_client.get("replicate", "hello-world").await.unwrap();

        model_mock.assert();
        assert_eq!(model.owner, "replicate");
        assert_eq!(model.name, "hello-world");
        assert_eq!(model.run_count, 5681081);
        assert!(model.paper_url.is_none());
        assert_eq!(model.latest_version.cog_version, "0.3.0");
    }

    #[tokio::test]
    async fn test_get_model_error_status() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/missing");
            then.status(404).json_body_obj(&json!({ "detail": "Not found." }));
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let result = model_client.get("replicate", "missing").await;

        model_mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_specific_version() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions/1234");
            then.status(200).json_body_obj(&json!({
                "url": "https://replicate.com/replicate/hello-world",
                "owner": "replicate",
                "name": "hello-world",
                "description": "A tiny model that says hello",
                "visibility": "public",
                "github_url": "https://github.com/replicate/cog-examples",
                "paper_url": null,
                "license_url": null,
                "run_count": 5681081,
                "cover_image_url": "...",
                "default_example": null,
                "latest_version": {
                    "id": "1234",
                    "created_at": "2022-04-26T19:29:04.418669Z",
                    "cog_version": "0.3.0",
                    "openapi_schema": {}
                }
            }));
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let model = model_client
            .get_specific_version("replicate", "hello-world", "1234")
            .await
            .unwrap();

        model_mock.assert();
        assert_eq!(model.latest_version.id, "1234");
    }

    #[tokio::test]
    async fn test_list_model_versions() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&single_version_listing());
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let versions = model_client
            .list_versions("replicate", "hello-world")
            .await
            .unwrap();

        model_mock.assert();
        assert!(versions.next.is_none());
        assert!(versions.previous.is_none());
        assert_eq!(versions.results.len(), 1);
        assert_eq!(versions.results[0].cog_version, "0.3.0");
        assert!(versions.results[0].openapi_schema.is_null());
    }

    #[tokio::test]
    async fn test_list_model_versions_error_status() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/missing/versions");
            then.status(404).json_body_obj(&json!({ "detail": "Not found." }));
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let result = model_client.list_versions("replicate", "missing").await;

        model_mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_latest_version() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&single_version_listing());
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let version = model_client
            .get_latest_version("replicate", "hello-world")
            .await
            .unwrap();

        model_mock.assert();
        assert_eq!(
            version.id,
            "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
        );
    }

    #[tokio::test]
    async fn test_get_latest_version_without_versions() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": []
            }));
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let result = model_client
            .get_latest_version("replicate", "hello-world")
            .await;

        model_mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_models() {
        let mock_server = MockServer::start();

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models");
            then.status(200).json_body_obj(&json!({
                "next": "some pagination string or null",
                "previous": "some pagination string or null",
                "results": [
                    {
                        "url": "https://modelhomepage.example.com",
                        "owner": "jdoe",
                        "name": "super-cool-model",
                        "description": "A model that predicts something very cool.",
                        "visibility": "public",
                        "github_url": "https://github.com/jdoe/super-cool-model",
                        "paper_url": "https://research.example.com/super-cool-model-paper.pdf",
                        "license_url": null,
                        "run_count": 420,
                        "cover_image_url": "https://cdn.example.com/images/super-cool-model-cover.jpg",
                        "default_example": {
                            "input": "Example input data for the model."
                        },
                        "latest_version": {
                            "id": "v1.0.0",
                            "created_at": "2022-01-01T12:00:00Z",
                            "cog_version": "0.2",
                            "openapi_schema": null
                        }
                    },
                    {
                        "url": "https://anothermodelhomepage.example.com",
                        "owner": "asmith",
                        "name": "another-awesome-model",
                        "description": "This model does awesome things with data.",
                        "visibility": "private",
                        "github_url": "https://github.com/asmith/another-awesome-model",
                        "paper_url": null,
                        "license_url": "https://licenses.example.com/another-awesome-model-license.txt",
                        "run_count": 150,
                        "cover_image_url": "https://cdn.example.com/images/another-awesome-model-cover.jpg",
                        "default_example": {
                            "input": "Some example input for this awesome model."
                        },
                        "latest_version": {
                            "id": "v1.2.3",
                            "created_at": "2023-02-15T08:30:00Z",
                            "cog_version": "0.2",
                            "openapi_schema": null
                        }
                    }
                ]
            }));
        });

        let client = ReplicateConfig::test(mock_server.base_url()).unwrap();
        let model_client = ModelClient::from(client);
        let models = model_client.get_models().await.unwrap();

        model_mock.assert();
        assert_eq!(
            models.next.as_deref(),
            Some("some pagination string or null")
        );
        assert_eq!(models.results.len(), 2);
        assert_eq!(models.results[0].owner, "jdoe");
        assert_eq!(
            models.results[0].paper_url.as_deref(),
            Some("https://research.example.com/super-cool-model-paper.pdf")
        );
        assert_eq!(models.results[1].visibility, "private");
        assert_eq!(models.results[1].latest_version.id, "v1.2.3");
    }
}