replicate_rust/
collection.rs1use crate::{
21 api_definitions::{GetCollectionModels, ListCollectionModels},
22 errors::ReplicateError,
23};
24
25#[derive(Clone, Debug)]
27pub struct Collection {
28 pub parent: crate::config::Config,
30}
31
32impl Collection {
33 pub fn new(rep: crate::config::Config) -> Self {
35 Self { parent: rep }
36 }
37
38 pub fn get(&self, collection_slug: &str) -> Result<GetCollectionModels, ReplicateError> {
54 let client = reqwest::blocking::Client::new();
55
56 let response = client
57 .get(format!(
58 "{}/collections/{}",
59 self.parent.base_url, collection_slug
60 ))
61 .header("Authorization", format!("Token {}", self.parent.auth))
62 .header("User-Agent", &self.parent.user_agent)
63 .send()?;
64
65 if !response.status().is_success() {
66 return Err(ReplicateError::ResponseError(response.text()?));
67 }
68
69 let response_string = response.text()?;
70 let response_struct: GetCollectionModels = serde_json::from_str(&response_string)?;
71
72 Ok(response_struct)
73 }
74
75 pub fn list(&self) -> Result<ListCollectionModels, ReplicateError> {
91 let client = reqwest::blocking::Client::new();
92
93 let response = client
94 .get(format!("{}/collections", self.parent.base_url))
95 .header("Authorization", format!("Token {}", self.parent.auth))
96 .header("User-Agent", &self.parent.user_agent)
97 .send()?;
98
99 if !response.status().is_success() {
100 return Err(ReplicateError::ResponseError(response.text()?));
101 }
102
103 let response_string = response.text()?;
104 let response_struct: ListCollectionModels = serde_json::from_str(&response_string)?;
105
106 Ok(response_struct)
107 }
108}
109
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};

    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// Configuration pointing at `server` with a dummy auth token.
    fn mock_config(server: &MockServer) -> Config {
        Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        }
    }

    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let collection_mock = server.mock(|when, then| {
            when.method(GET).path("/collections/super-resolution");
            then.status(200).json_body_obj(&json!({
                "name": "Super resolution",
                "slug": "super-resolution",
                "description": "Upscaling models that create high-quality images from low-quality images.",
                "models": [],
            }));
        });

        let replicate = Replicate::new(mock_config(&server));
        let collection = replicate.collections.get("super-resolution")?;

        assert_eq!(collection.name, "Super resolution");
        collection_mock.assert();

        Ok(())
    }

    #[test]
    fn test_list() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let list_mock = server.mock(|when, then| {
            when.method(GET).path("/collections");
            then.status(200).json_body_obj(&json!({
                "results": [
                    {
                        "name": "Super resolution",
                        "slug": "super-resolution",
                        "description": "Upscaling models that create high-quality images from low-quality images.",
                    },
                    {
                        "name": "Image classification",
                        "slug": "image-classification",
                        "description": "Models that classify images.",
                    },
                ],
                "next": None::<String>,
                "previous": None::<String>,
            }));
        });

        let replicate = Replicate::new(mock_config(&server));
        let page = replicate.collections.list()?;

        assert_eq!(page.results.len(), 2);
        list_mock.assert();

        Ok(())
    }
}