replicate_rust/training.rs
1//! Used to interact with the [Training Endpoints](https://replicate.com/docs/reference/http#trainings.create).
2//!
3//!
4//! # Example
5//!
//! ```no_run
7//! use replicate_rust::{Replicate, config::Config, training::TrainingOptions};
8//! use std::collections::HashMap;
9//!
10//! let config = Config::default();
11//! let replicate = Replicate::new(config);
12//!
13//! let mut input = HashMap::new();
14//! input.insert(String::from("train_data"), String::from("https://example.com/70k_samples.jsonl"));
15//!
16//! let result = replicate.trainings.create(
17//! "owner",
18//! "model",
19//! "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
20//! TrainingOptions {
21//! destination: String::from("new_owner/new_name"),
22//! input,
23//! webhook: String::from("https://example.com/my-webhook"),
24//! _webhook_events_filter: None,
25//! },
26//! )?;
27//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
28//! ```
29//!
30//!
31
32use std::collections::HashMap;
33
34use crate::{api_definitions::{CreateTraining, GetTraining, ListTraining, WebhookEvents}, errors::ReplicateError};
35
36/// Contains all the options for creating a training.
37pub struct TrainingOptions {
38
39 /// A string representing the desired model to push to in the format {destination_model_owner}/{destination_model_name}. This should be an existing model owned by the user or organization making the API request. If the destination is invalid, the server returns an appropriate 4XX response.
40 pub destination: String,
41
42 /// An object containing inputs to the Cog model's train() function.
43 pub input: HashMap<String, String>,
44
45 /// An HTTPS URL for receiving a webhook when the training completes. The webhook will be a POST request where the request body is the same as the response body of the get training operation. If there are network problems, we will retry the webhook a few times, so make sure it can be safely called more than once.
46 pub webhook: String,
47
48 /// TO only send specifc events to the webhook, use this field. If not specified, all events will be sent. TODO : Add this to the API
49 pub _webhook_events_filter: Option<WebhookEvents>,
50}
51
52
53/// Data to be sent to the API when creating a training.
54#[derive(Debug, serde::Serialize, serde::Deserialize)]
55pub struct CreateTrainingPayload {
56
57 /// A string representing the desired model to push to in the format {destination_model_owner}/{destination_model_name}. This should be an existing model owned by the user or organization making the API request. If the destination is invalid, the server returns an appropriate 4XX response.
58 pub destination: String,
59
60 /// An object containing inputs to the Cog model's train() function.
61 pub input: HashMap<String, String>,
62
63 /// An HTTPS URL for receiving a webhook when the training completes. The webhook will be a POST request where the request body is the same as the response body of the get training operation. If there are network problems, we will retry the webhook a few times, so make sure it can be safely called more than once.
64 pub webhook: String,
65}
66
67/// Used to interact with the [Training Endpoints](https://replicate.com/docs/reference/http#trainings.create).
68#[derive(Clone, Debug)]
69pub struct Training {
70 /// Holds a reference to a Configuration struct, which contains the base url, auth token among other settings.
71 pub parent: crate::config::Config,
72}
73
74/// Training struct contains all the functionality for interacting with the training endpoints of the Replicate API.
75impl Training {
76
77 /// Create a new Training struct.
78 pub fn new(rep: crate::config::Config) -> Self {
79 Self { parent: rep }
80 }
81
82 /// Create a new training.
83 ///
84 /// # Arguments
85 /// * `model_owner` - The name of the user or organization that owns the model.
86 /// * `model_name` - The name of the model.
87 /// * `version_id` - The ID of the version.
88 /// * `options` - The options for creating a training.
89 /// * `destination` - A string representing the desired model to push to in the format {destination_model_owner}/{destination_model_name}. This should be an existing model owned by the user or organization making the API request. If the destination is invalid, the server returns an appropriate 4XX response.
90 /// * `input` - An object containing inputs to the Cog model's train() function.
91 /// * `webhook` - An HTTPS URL for receiving a webhook when the training completes. The webhook will be a POST request where the request body is the same as the response body of the get training operation. If there are network problems, we will retry the webhook a few times, so make sure it can be safely called more than once.
92 /// * `_webhook_events_filter` - TO only send specifc events to the webhook, use this field. If not specified, all events will be sent. The following events are supported:
93 ///
94 /// # Example
95 /// ```
96 /// use replicate_rust::{Replicate, config::Config, training::TrainingOptions};
97 /// use std::collections::HashMap;
98 ///
99 /// let config = Config::default();
100 /// let replicate = Replicate::new(config);
101 ///
102 /// let mut input = HashMap::new();
103 /// input.insert(String::from("training_data"), String::from("https://example.com/70k_samples.jsonl"));
104 ///
105 /// let result = replicate.trainings.create(
106 /// "owner",
107 /// "model",
108 /// "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
109 /// TrainingOptions {
110 /// destination: String::from("new_owner/new_name"),
111 /// input,
112 /// webhook: String::from("https://example.com/my-webhook"),
113 /// _webhook_events_filter: None,
114 /// },
115 /// )?;
116 ///
117 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
118 /// ```
119 ///
120 pub fn create(
121 &self,
122 model_owner: &str,
123 model_name: &str,
124 version_id: &str,
125 options: TrainingOptions,
126 ) -> Result<CreateTraining, ReplicateError> {
127 let client = reqwest::blocking::Client::new();
128
129 let payload = CreateTrainingPayload {
130 destination: options.destination,
131 input: options.input,
132 webhook: options.webhook,
133 };
134
135 let response = client
136 .post(format!(
137 "{}/models/{}/{}/versions/{}/trainings",
138 self.parent.base_url, model_owner, model_name, version_id,
139 ))
140 .header("Authorization", format!("Token {}", self.parent.auth))
141 .header("User-Agent", &self.parent.user_agent)
142 .json(&payload)
143 .send()?;
144
145 if !response.status().is_success() {
146 return Err(ReplicateError::ResponseError(response.text()?));
147 }
148
149 let response_string = response.text()?;
150 let response_struct: CreateTraining = serde_json::from_str(&response_string)?;
151
152 Ok(response_struct)
153 }
154
155
156 /// Get the details of a training.
157 ///
158 /// # Arguments
159 /// * `training_id` - The ID of the training you want to get.
160 ///
161 /// # Example
162 /// ```
163 /// use replicate_rust::{Replicate, config::Config};
164 ///
165 /// let config = Config::default();
166 /// let replicate = Replicate::new(config);
167 ///
168 /// let training = replicate.trainings.get("zz4ibbonubfz7carwiefibzgga")?;
169 /// println!("Training : {:?}", training);
170 ///
171 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
172 /// ```
173 pub fn get(&self, training_id: &str) -> Result<GetTraining, ReplicateError> {
174 let client = reqwest::blocking::Client::new();
175
176 let response = client
177 .get(format!(
178 "{}/trainings/{}",
179 self.parent.base_url, training_id,
180 ))
181 .header("Authorization", format!("Token {}", self.parent.auth))
182 .header("User-Agent", &self.parent.user_agent)
183 .send()?;
184
185 if !response.status().is_success() {
186 return Err(ReplicateError::ResponseError(response.text()?));
187 }
188
189 let response_string = response.text()?;
190 let response_struct: GetTraining = serde_json::from_str(&response_string)?;
191
192 Ok(response_struct)
193 }
194
195 /// Get a paginated list of trainings that you've created with your account. Returns 100 records per page.
196 ///
197 /// # Example
198 /// ```
199 /// use replicate_rust::{Replicate, config::Config};
200 ///
201 /// let config = Config::default();
202 /// let replicate = Replicate::new(config);
203 ///
204 /// let trainings = replicate.trainings.list()?;
205 /// println!("Trainings : {:?}", trainings);
206 ///
207 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
208 /// ```
209 pub fn list(&self) -> Result<ListTraining, ReplicateError> {
210 let client = reqwest::blocking::Client::new();
211
212 let response = client
213 .get(format!("{}/trainings", self.parent.base_url,))
214 .header("Authorization", format!("Token {}", self.parent.auth))
215 .header("User-Agent", &self.parent.user_agent)
216 .send()?;
217
218 if !response.status().is_success() {
219 return Err(ReplicateError::ResponseError(response.text()?));
220 }
221
222 let response_string = response.text()?;
223 let response_struct: ListTraining = serde_json::from_str(&response_string)?;
224
225 Ok(response_struct)
226 }
227
228 /// Cancel a training.
229 ///
230 /// # Arguments
231 /// * `training_id` - The ID of the training you want to cancel.
232 ///
233 /// # Example
234 /// ```
235 /// use replicate_rust::{Replicate, config::Config};
236 ///
237 /// let config = Config::default();
238 /// let replicate = Replicate::new(config);
239 ///
240 /// let result = replicate.trainings.cancel("zz4ibbonubfz7carwiefibzgga")?;
241 /// println!("Result : {:?}", result);
242 ///
243 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
244 /// ```
245 pub fn cancel(&self, training_id: &str) -> Result<GetTraining, ReplicateError> {
246 let client = reqwest::blocking::Client::new();
247
248 let response = client
249 .post(format!(
250 "{}/trainings/{}/cancel",
251 self.parent.base_url, training_id
252 ))
253 .header("Authorization", format!("Token {}", self.parent.auth))
254 .header("User-Agent", &self.parent.user_agent)
255 .send()?;
256
257 if !response.status().is_success() {
258 return Err(ReplicateError::ResponseError(response.text()?));
259 }
260 let response_string = response.text()?;
261 let response_struct: GetTraining = serde_json::from_str(&response_string)?;
262
263 Ok(response_struct)
264 }
265}
266
#[cfg(test)]
mod tests {
    use crate::{api_definitions::PredictionStatus, config::Config, Replicate};

    use super::*;
    use httpmock::{
        Method::{GET, POST},
        MockServer,
    };
    use serde_json::json;

    /// Version ID used both in the mocked request path and in the `create` call.
    const VERSION_ID: &str = "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532";

    /// Builds a client that talks to `server` using the auth token `test`.
    fn replicate_for(server: &MockServer) -> Replicate {
        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        Replicate::new(config)
    }

    #[test]
    fn test_create() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let post_mock = server.mock(|when, then| {
            // Match only a request that carries the auth token and exactly the payload built
            // from the options below. A regression in either leaves the mock unmatched, and
            // `create` then fails with an error response.
            when.method(POST)
                .path(format!("/models/owner/model/versions/{}/trainings", VERSION_ID))
                .header("authorization", "Token test")
                .json_body(json!({
                    "destination": "new_owner/new_model",
                    "input": { "text": "..." },
                    "webhook": "webhook",
                }));
            then.status(200).json_body_obj(&json!( {
                "id": "zz4ibbonubfz7carwiefibzgga",
                "version": "{version}",
                "status": "starting",
                "input": {
                    "text": "...",
                },
                "output": None::<String>,
                "error": None::<String>,
                "logs": None::<String>,
                "started_at": None::<String>,
                "created_at": "2023-03-28T21:47:58.566434Z",
                "completed_at": None::<String>,
            }
            ));
        });

        let replicate = replicate_for(&server);

        let mut input = HashMap::new();
        input.insert(String::from("text"), String::from("..."));

        let result = replicate.trainings.create(
            "owner",
            "model",
            VERSION_ID,
            TrainingOptions {
                destination: String::from("new_owner/new_model"),
                input,
                webhook: String::from("webhook"),
                _webhook_events_filter: None,
            },
        )?;

        assert_eq!(result.id, "zz4ibbonubfz7carwiefibzgga");
        // Ensure the mocks were called as expected
        post_mock.assert();

        Ok(())
    }

    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/trainings/zz4ibbonubfz7carwiefibzgga");
            then.status(200).json_body_obj(&json!( {
                "id": "zz4ibbonubfz7carwiefibzgga",
                "version": "{version}",
                "status": "succeeded",
                "input": {
                    "text": "...",
                    "param" : "..."
                },
                "output": {
                    "version": "...",
                },
                "error": None::<String>,
                "logs": None::<String>,
                "webhook_completed": None::<String>,
                "started_at": None::<String>,
                "created_at": "2023-03-28T21:47:58.566434Z",
                "completed_at": None::<String>,
            }
            ));
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.get("zz4ibbonubfz7carwiefibzgga")?;

        assert_eq!(result.status, PredictionStatus::succeeded);
        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }

    #[test]
    fn test_cancel() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let cancel_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/trainings/zz4ibbonubfz7carwiefibzgga/cancel");
            then.status(200).json_body_obj(&json!( {
                "id": "zz4ibbonubfz7carwiefibzgga",
                "version": "{version}",
                "status": "canceled",
                "input": {
                    "text": "...",
                    "param1" : "..."
                },
                "output": {
                    "version": "...",
                },
                "error": None::<String>,
                "logs": None::<String>,
                "webhook_completed": None::<String>,
                "started_at": None::<String>,
                "created_at": "2023-03-28T21:47:58.566434Z",
                "completed_at": None::<String>,
            }
            ));
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.cancel("zz4ibbonubfz7carwiefibzgga")?;

        assert_eq!(result.status, PredictionStatus::canceled);
        // Ensure the mocks were called as expected
        cancel_mock.assert();

        Ok(())
    }

    #[test]
    fn test_list() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/trainings");
            then.status(200).json_body_obj(&json!( {
                "next": "https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
                "previous": None::<String>,
                "results": [
                    {
                        "id": "jpzd7hm5gfcapbfyt4mqytarku",
                        "version": "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05",
                        "urls": {
                            "get": "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku",
                            "cancel": "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel"
                        },
                        "created_at": "2022-04-26T20:00:40.658234Z",
                        "started_at": "2022-04-26T20:00:84.583803Z",
                        "completed_at": "2022-04-26T20:02:27.648305Z",
                        "source": "web",
                        "status": "succeeded"
                    }
                ]
            }
            ));
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.list()?;

        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].id, "jpzd7hm5gfcapbfyt4mqytarku");

        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }
}