replicate_rust/
model.rs

//! # Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
//!
//! The model module contains all the functionality for interacting with the model endpoints of the Replicate API.
//! It currently supports the following endpoints:
//! * [Get Model](https://replicate.com/docs/reference/http#models.get)
//! * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
//! * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
//! * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
//!
//! # Example
//! ```no_run
//! use replicate_rust::{Replicate, config::Config};
//!
//! let config = Config::default();
//! let replicate = Replicate::new(config);
//!
//! let model = replicate.models.get("replicate", "hello-world")?;
//! println!("Model : {:?}", model);
//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
//! ```

22use crate::{api_definitions::GetModel, errors::ReplicateError, version::Version};
23
24// #[derive(Clone)]
25/// Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
26#[derive(Clone, Debug)]
27pub struct Model {
28    /// Holds a reference to a Configuration struct, which contains the base url,  auth token among other settings.
29    pub parent: crate::config::Config,
30
31    /// Holds a reference to a Version struct, which contains the functionality for interacting with the version endpoints of the Replicate API.
32    pub versions: Version,
33}
34
35/// Model struct contains all the functionality for interacting with the model endpoints of the Replicate API.
36/// Currently supports the following endpoint:
37/// * [Get Model](https://replicate.com/docs/reference/http#models.get)
38/// * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
39/// * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
40/// * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
41///
42impl Model {
43    /// Create a new Model struct.
44    /// # Arguments
45    /// * `rep` - The config (`crate::config::Config`) to use for authentication and communication.
46    ///
47    pub fn new(rep: crate::config::Config) -> Self {
48        let versions = Version::new(rep.clone());
49        Self {
50            parent: rep,
51            versions,
52        }
53    }
54
55    /// Get the details of a model.
56    /// # Arguments
57    /// * `model_owner` - The owner of the model.
58    /// * `model_name` - The name of the model.
59    ///
60    /// # Example
61    /// ```
62    /// use replicate_rust::{Replicate, config::Config};
63    ///
64    /// let config = Config::default();
65    /// let replicate = Replicate::new(config);
66    ///
67    /// let model = replicate.models.get("replicate", "hello-world")?;
68    /// println!("Model : {:?}", model);
69    ///
70    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
71    /// ```
72    pub fn get(&self, model_owner: &str, model_name: &str) -> Result<GetModel, ReplicateError> {
73        let client = reqwest::blocking::Client::new();
74
75        let response = client
76            .get(format!(
77                "{}/models/{}/{}",
78                self.parent.base_url, model_owner, model_name
79            ))
80            .header("Authorization", format!("Token {}", self.parent.auth))
81            .header("User-Agent", &self.parent.user_agent)
82            .send()?;
83
84        if !response.status().is_success() {
85            return Err(ReplicateError::ResponseError(response.text()?));
86        }
87
88        let response_string = response.text()?;
89        let response_struct: GetModel = serde_json::from_str(&response_string)?;
90
91        Ok(response_struct)
92    }
93}
94
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};

    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// `Model::get` should hit `/models/{owner}/{name}` and decode the JSON reply.
    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let mock_server = MockServer::start();

        // Canned JSON body served by the mock for the model lookup.
        let model_body = json!({
            "url": "https://replicate.com/replicate/hello-world",
            "owner": "replicate",
            "name": "hello-world",
            "description": "A tiny model that says hello",
            "visibility": "public",
            "github_url": "https://github.com/replicate/cog-examples",
            "paper_url": null,
            "license_url": null,
            "run_count": 12345,
            "cover_image_url": "",
            "default_example": {},
            "latest_version": {}
        });

        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/hello-world");
            then.status(200).json_body_obj(&model_body);
        });

        let replicate = Replicate::new(Config {
            auth: "test".to_string(),
            base_url: mock_server.base_url(),
            ..Config::default()
        });

        let model = replicate.models.get("replicate", "hello-world")?;
        println!("{:?}", model);

        // Verify the mocked endpoint was actually hit.
        model_mock.assert();

        Ok(())
    }
}