replicate_rust/model.rs
1//! # Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
2//!
3//! The model module contains all the functionality for interacting with the model endpoints of the Replicate API.
//! Currently supports the following endpoints (the version endpoints are reached through the `versions` field):
5//! * [Get Model](https://replicate.com/docs/reference/http#models.get)
6//! * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
7//! * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
8//! * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
9//!
10//! # Example
//! ```no_run
12//! use replicate_rust::{Replicate, config::Config};
13//!
14//! let config = Config::default();
15//! let replicate = Replicate::new(config);
16//!
17//! let model = replicate.models.get("replicate", "hello-world")?;
18//! println!("Model : {:?}", model);
19//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
20//! ```
21
22use crate::{api_definitions::GetModel, errors::ReplicateError, version::Version};
23
24// #[derive(Clone)]
25/// Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
26#[derive(Clone, Debug)]
27pub struct Model {
28 /// Holds a reference to a Configuration struct, which contains the base url, auth token among other settings.
29 pub parent: crate::config::Config,
30
31 /// Holds a reference to a Version struct, which contains the functionality for interacting with the version endpoints of the Replicate API.
32 pub versions: Version,
33}
34
35/// Model struct contains all the functionality for interacting with the model endpoints of the Replicate API.
36/// Currently supports the following endpoint:
37/// * [Get Model](https://replicate.com/docs/reference/http#models.get)
38/// * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
39/// * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
40/// * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
41///
42impl Model {
43 /// Create a new Model struct.
44 /// # Arguments
45 /// * `rep` - The config (`crate::config::Config`) to use for authentication and communication.
46 ///
47 pub fn new(rep: crate::config::Config) -> Self {
48 let versions = Version::new(rep.clone());
49 Self {
50 parent: rep,
51 versions,
52 }
53 }
54
55 /// Get the details of a model.
56 /// # Arguments
57 /// * `model_owner` - The owner of the model.
58 /// * `model_name` - The name of the model.
59 ///
60 /// # Example
61 /// ```
62 /// use replicate_rust::{Replicate, config::Config};
63 ///
64 /// let config = Config::default();
65 /// let replicate = Replicate::new(config);
66 ///
67 /// let model = replicate.models.get("replicate", "hello-world")?;
68 /// println!("Model : {:?}", model);
69 ///
70 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
71 /// ```
72 pub fn get(&self, model_owner: &str, model_name: &str) -> Result<GetModel, ReplicateError> {
73 let client = reqwest::blocking::Client::new();
74
75 let response = client
76 .get(format!(
77 "{}/models/{}/{}",
78 self.parent.base_url, model_owner, model_name
79 ))
80 .header("Authorization", format!("Token {}", self.parent.auth))
81 .header("User-Agent", &self.parent.user_agent)
82 .send()?;
83
84 if !response.status().is_success() {
85 return Err(ReplicateError::ResponseError(response.text()?));
86 }
87
88 let response_string = response.text()?;
89 let response_struct: GetModel = serde_json::from_str(&response_string)?;
90
91 Ok(response_struct)
92 }
93}
94
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};

    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// Builds a client pointed at `server`, authenticated with the token `test`.
    fn mock_client(server: &MockServer) -> Replicate {
        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        Replicate::new(config)
    }

    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        // Matching on the Authorization header makes a regression in header
        // construction fail the test instead of silently passing.
        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!( {
                "url": "https://replicate.com/replicate/hello-world",
                "owner": "replicate",
                "name": "hello-world",
                "description": "A tiny model that says hello",
                "visibility": "public",
                "github_url": "https://github.com/replicate/cog-examples",
                "paper_url": None::<String>,
                "license_url": None::<String>,
                "run_count": 12345,
                "cover_image_url": "",
                "default_example": {},
                "latest_version": {}
            }
            ));
        });

        let result = mock_client(&server)
            .models
            .get("replicate", "hello-world")?;

        println!("{:?}", result);

        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }

    #[test]
    fn test_get_error_status() {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/models/replicate/missing-model");
            then.status(404).body("not found");
        });

        let result = mock_client(&server)
            .models
            .get("replicate", "missing-model");

        // A non-success status must surface as ResponseError carrying the raw body.
        match result {
            Err(ReplicateError::ResponseError(body)) => assert_eq!(body, "not found"),
            other => panic!("expected ResponseError, got {:?}", other),
        }

        get_mock.assert();
    }
}