replicate_rust/
collection.rs

//! Used to interact with the [Collection Endpoints](https://replicate.com/docs/reference/http#collections.get).
//!
//! The [`Collection`] struct is initialized with a [`Config`](crate::config::Config) struct.
//!
//! # Example
//!
//! The example performs a live API request, so it is compiled but not run by `cargo test`.
//!
//! ```no_run
//! use replicate_rust::{Replicate, config::Config};
//!
//! let config = Config::default();
//! let replicate = Replicate::new(config);
//!
//! let collection = replicate.collections.get("audio-generation")?;
//! println!("Collection : {:?}", collection);
//!
//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
//! ```
//!

20use crate::{
21    api_definitions::{GetCollectionModels, ListCollectionModels},
22    errors::ReplicateError,
23};
24
25/// Used to interact with the [Collection Endpoints](https://replicate.com/docs/reference/http#collections.get).
26#[derive(Clone, Debug)]
27pub struct Collection {
28    /// Holds a reference to a Config struct, which contains the base url,  auth token among other settings.
29    pub parent: crate::config::Config,
30}
31
32impl Collection {
33    /// Create a new Collection struct.
34    pub fn new(rep: crate::config::Config) -> Self {
35        Self { parent: rep }
36    }
37
38    /// Get a collection by slug.
39    ///
40    /// # Example
41    ///
42    /// ```
43    /// use replicate_rust::{Replicate, config::Config};
44    ///
45    /// let config = Config::default();
46    /// let replicate = Replicate::new(config);
47    ///
48    /// let collections = replicate.collections.get("audio-generation")?;
49    /// println!("Collections : {:?}", collections);
50    ///
51    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
52    /// ```
53    pub fn get(&self, collection_slug: &str) -> Result<GetCollectionModels, ReplicateError> {
54        let client = reqwest::blocking::Client::new();
55
56        let response = client
57            .get(format!(
58                "{}/collections/{}",
59                self.parent.base_url, collection_slug
60            ))
61            .header("Authorization", format!("Token {}", self.parent.auth))
62            .header("User-Agent", &self.parent.user_agent)
63            .send()?;
64
65        if !response.status().is_success() {
66            return Err(ReplicateError::ResponseError(response.text()?));
67        }
68
69        let response_string = response.text()?;
70        let response_struct: GetCollectionModels = serde_json::from_str(&response_string)?;
71
72        Ok(response_struct)
73    }
74
75    /// List all collections present in Replicate.
76    ///
77    /// # Example
78    ///
79    /// ```
80    /// use replicate_rust::{Replicate, config::Config};
81    ///
82    /// let config = Config::default();
83    /// let replicate = Replicate::new(config);
84    ///
85    /// let collections = replicate.collections.list()?;
86    /// println!("Collections : {:?}", collections);
87    ///
88    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
89    /// ```
90    pub fn list(&self) -> Result<ListCollectionModels, ReplicateError> {
91        let client = reqwest::blocking::Client::new();
92
93        let response = client
94            .get(format!("{}/collections", self.parent.base_url))
95            .header("Authorization", format!("Token {}", self.parent.auth))
96            .header("User-Agent", &self.parent.user_agent)
97            .send()?;
98
99        if !response.status().is_success() {
100            return Err(ReplicateError::ResponseError(response.text()?));
101        }
102
103        let response_string = response.text()?;
104        let response_struct: ListCollectionModels = serde_json::from_str(&response_string)?;
105
106        Ok(response_struct)
107    }
108}
109
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};

    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// Build a client pointed at `server`, authenticated with the token `test`.
    fn replicate_for(server: &MockServer) -> Replicate {
        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        Replicate::new(config)
    }

    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        // Matching on the Authorization header also verifies the token is actually sent.
        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/collections/super-resolution")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!({
                "name": "Super resolution",
                "slug": "super-resolution",
                "description": "Upscaling models that create high-quality images from low-quality images.",
                "models": [],
            }));
        });

        let result = replicate_for(&server)
            .collections
            .get("super-resolution")?;

        // Assert that the returned value is correct
        assert_eq!(result.name, "Super resolution");

        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }

    #[test]
    fn test_get_non_success_status_returns_response_error() {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/collections/missing");
            then.status(404).body("not found");
        });

        let result = replicate_for(&server).collections.get("missing");

        // A non-2xx status must be reported as ResponseError carrying the raw body.
        assert!(matches!(
            result,
            Err(ReplicateError::ResponseError(ref body)) if body == "not found"
        ));

        get_mock.assert();
    }

    #[test]
    fn test_list() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let list_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/collections")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!({
                "results": [
                    {
                        "name": "Super resolution",
                        "slug": "super-resolution",
                        "description": "Upscaling models that create high-quality images from low-quality images.",
                    },
                    {
                        "name": "Image classification",
                        "slug": "image-classification",
                        "description": "Models that classify images.",
                    },
                ],
                "next": None::<String>,
                "previous": None::<String>,
            }));
        });

        let result = replicate_for(&server).collections.list()?;

        // Assert that the returned value is correct
        assert_eq!(result.results.len(), 2);

        // Ensure the mocks were called as expected
        list_mock.assert();

        Ok(())
    }
}