replicate_rust/
training.rs

1//! Used to interact with the [Training Endpoints](https://replicate.com/docs/reference/http#trainings.create).
2//!
3//!
4//! # Example
5//!
//! ```no_run
7//! use replicate_rust::{Replicate, config::Config, training::TrainingOptions};
8//! use std::collections::HashMap;
9//! 
10//! let config = Config::default();
11//! let replicate = Replicate::new(config);
12//! 
13//! let mut input = HashMap::new();
14//! input.insert(String::from("train_data"), String::from("https://example.com/70k_samples.jsonl"));
15//!
16//! let result = replicate.trainings.create(
17//!     "owner",
18//!     "model",
19//!     "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
20//!     TrainingOptions {
21//!         destination: String::from("new_owner/new_name"),
22//!         input,
23//!         webhook: String::from("https://example.com/my-webhook"),
24//!         _webhook_events_filter: None,
25//!     },
26//! )?;
27//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
28//! ```
29//!
30//!
31
32use std::collections::HashMap;
33
34use crate::{api_definitions::{CreateTraining, GetTraining, ListTraining, WebhookEvents}, errors::ReplicateError};
35
36/// Contains all the options for creating a training.
37pub struct TrainingOptions {
38
39    /// A string representing the desired model to push to in the format {destination_model_owner}/{destination_model_name}. This should be an existing model owned by the user or organization making the API request. If the destination is invalid, the server returns an appropriate 4XX response.
40    pub destination: String,
41
42    /// An object containing inputs to the Cog model's train() function.
43    pub input: HashMap<String, String>,
44
45    /// An HTTPS URL for receiving a webhook when the training completes. The webhook will be a POST request where the request body is the same as the response body of the get training operation. If there are network problems, we will retry the webhook a few times, so make sure it can be safely called more than once.
46    pub webhook: String,
47
48    /// TO only send specifc events to the webhook, use this field. If not specified, all events will be sent. TODO : Add this to the API 
49    pub _webhook_events_filter: Option<WebhookEvents>,
50}
51
52
53/// Data to be sent to the API when creating a training.
54#[derive(Debug, serde::Serialize, serde::Deserialize)]
55pub struct CreateTrainingPayload {
56
57    /// A string representing the desired model to push to in the format {destination_model_owner}/{destination_model_name}. This should be an existing model owned by the user or organization making the API request. If the destination is invalid, the server returns an appropriate 4XX response.
58    pub destination: String,
59
60    /// An object containing inputs to the Cog model's train() function.
61    pub input: HashMap<String, String>,
62
63    /// An HTTPS URL for receiving a webhook when the training completes. The webhook will be a POST request where the request body is the same as the response body of the get training operation. If there are network problems, we will retry the webhook a few times, so make sure it can be safely called more than once.
64    pub webhook: String,
65}
66
67/// Used to interact with the [Training Endpoints](https://replicate.com/docs/reference/http#trainings.create).
68#[derive(Clone, Debug)]
69pub struct Training {
70    /// Holds a reference to a Configuration struct, which contains the base url, auth token among other settings.
71    pub parent: crate::config::Config,
72}
73
74/// Training struct contains all the functionality for interacting with the training endpoints of the Replicate API.
75impl Training {
76
77    /// Create a new Training struct.
78    pub fn new(rep: crate::config::Config) -> Self {
79        Self { parent: rep }
80    }
81
82    /// Create a new training.
83    /// 
84    /// # Arguments
85    /// * `model_owner` - The name of the user or organization that owns the model.
86    /// * `model_name` - The name of the model.
87    /// * `version_id` - The ID of the version.
88    /// * `options` - The options for creating a training.
89    ///     * `destination` - A string representing the desired model to push to in the format {destination_model_owner}/{destination_model_name}. This should be an existing model owned by the user or organization making the API request. If the destination is invalid, the server returns an appropriate 4XX response.
90    ///    * `input` - An object containing inputs to the Cog model's train() function.
91    ///   * `webhook` - An HTTPS URL for receiving a webhook when the training completes. The webhook will be a POST request where the request body is the same as the response body of the get training operation. If there are network problems, we will retry the webhook a few times, so make sure it can be safely called more than once.
92    ///  * `_webhook_events_filter` - TO only send specifc events to the webhook, use this field. If not specified, all events will be sent. The following events are supported:
93    /// 
94    /// # Example
95    /// ```
96    /// use replicate_rust::{Replicate, config::Config, training::TrainingOptions};
97    /// use std::collections::HashMap;
98    /// 
99    /// let config = Config::default();
100    /// let replicate = Replicate::new(config);
101    /// 
102    /// let mut input = HashMap::new();
103    /// input.insert(String::from("training_data"), String::from("https://example.com/70k_samples.jsonl"));
104    /// 
105    /// let result = replicate.trainings.create(
106    ///    "owner",
107    ///    "model",
108    ///   "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
109    ///  TrainingOptions {
110    ///     destination: String::from("new_owner/new_name"),
111    ///     input,
112    ///     webhook: String::from("https://example.com/my-webhook"),
113    ///     _webhook_events_filter: None,
114    /// },
115    /// )?;
116    /// 
117    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
118    /// ```
119    /// 
120    pub fn create(
121        &self,
122        model_owner: &str,
123        model_name: &str,
124        version_id: &str,
125        options: TrainingOptions,
126    ) -> Result<CreateTraining, ReplicateError> {
127        let client = reqwest::blocking::Client::new();
128
129        let payload = CreateTrainingPayload {
130            destination: options.destination,
131            input: options.input,
132            webhook: options.webhook,
133        };
134
135        let response = client
136            .post(format!(
137                "{}/models/{}/{}/versions/{}/trainings",
138                self.parent.base_url, model_owner, model_name, version_id,
139            ))
140            .header("Authorization", format!("Token {}", self.parent.auth))
141            .header("User-Agent", &self.parent.user_agent)
142            .json(&payload)
143                .send()?;
144
145        if !response.status().is_success() {
146            return Err(ReplicateError::ResponseError(response.text()?));
147        }
148
149        let response_string = response.text()?;
150        let response_struct: CreateTraining = serde_json::from_str(&response_string)?;
151
152        Ok(response_struct)
153    }
154
155
156    /// Get the details of a training.
157    /// 
158    /// # Arguments
159    /// * `training_id` - The ID of the training you want to get.
160    /// 
161    /// # Example
162    /// ```
163    /// use replicate_rust::{Replicate, config::Config};
164    /// 
165    /// let config = Config::default();
166    /// let replicate = Replicate::new(config);
167    /// 
168    /// let training = replicate.trainings.get("zz4ibbonubfz7carwiefibzgga")?;
169    /// println!("Training : {:?}", training);
170    /// 
171    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
172    /// ``` 
173    pub fn get(&self, training_id: &str) -> Result<GetTraining, ReplicateError> {
174        let client = reqwest::blocking::Client::new();
175
176        let response = client
177            .get(format!(
178                "{}/trainings/{}",
179                self.parent.base_url, training_id,
180            ))
181            .header("Authorization", format!("Token {}", self.parent.auth))
182            .header("User-Agent", &self.parent.user_agent)
183                .send()?;
184
185        if !response.status().is_success() {
186            return Err(ReplicateError::ResponseError(response.text()?));
187        }
188
189        let response_string = response.text()?;
190        let response_struct: GetTraining = serde_json::from_str(&response_string)?;
191
192        Ok(response_struct)
193    }
194
195    /// Get a paginated list of trainings that you've created with your account. Returns 100 records per page.
196    /// 
197    /// # Example
198    /// ```
199    /// use replicate_rust::{Replicate, config::Config};
200    /// 
201    /// let config = Config::default();
202    /// let replicate = Replicate::new(config);
203    /// 
204    /// let trainings = replicate.trainings.list()?;
205    /// println!("Trainings : {:?}", trainings);
206    /// 
207    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
208    /// ```
209    pub fn list(&self) -> Result<ListTraining, ReplicateError> {
210        let client = reqwest::blocking::Client::new();
211
212        let response = client
213            .get(format!("{}/trainings", self.parent.base_url,))
214            .header("Authorization", format!("Token {}", self.parent.auth))
215            .header("User-Agent", &self.parent.user_agent)
216                .send()?;
217
218        if !response.status().is_success() {
219            return Err(ReplicateError::ResponseError(response.text()?));
220        }
221
222        let response_string = response.text()?;
223        let response_struct: ListTraining = serde_json::from_str(&response_string)?;
224
225        Ok(response_struct)
226    }
227
228    /// Cancel a training.
229    /// 
230    /// # Arguments
231    /// * `training_id` - The ID of the training you want to cancel.
232    /// 
233    /// # Example
234    /// ```
235    /// use replicate_rust::{Replicate, config::Config};
236    /// 
237    /// let config = Config::default();
238    /// let replicate = Replicate::new(config);
239    /// 
240    /// let result =  replicate.trainings.cancel("zz4ibbonubfz7carwiefibzgga")?;
241    /// println!("Result : {:?}", result);
242    /// 
243    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
244    /// ```
245    pub fn cancel(&self, training_id: &str) -> Result<GetTraining, ReplicateError> {
246        let client = reqwest::blocking::Client::new();
247
248        let response = client
249            .post(format!(
250                "{}/trainings/{}/cancel",
251                self.parent.base_url, training_id
252            ))
253            .header("Authorization", format!("Token {}", self.parent.auth))
254            .header("User-Agent", &self.parent.user_agent)
255                .send()?;
256
257        if !response.status().is_success() {
258            return Err(ReplicateError::ResponseError(response.text()?));
259        }
260        let response_string = response.text()?;
261        let response_struct: GetTraining = serde_json::from_str(&response_string)?;
262
263        Ok(response_struct)
264    }
265}
266
#[cfg(test)]
mod tests {
    use crate::{api_definitions::PredictionStatus, config::Config, Replicate};

    use super::*;
    use httpmock::{
        Method::{GET, POST},
        MockServer,
    };
    use serde_json::json;

    /// Build a `Replicate` client that talks to `server` and authenticates with token `test`.
    fn replicate_for(server: &MockServer) -> Replicate {
        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        Replicate::new(config)
    }

    #[test]
    fn test_create() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        // Besides the route, pin the auth header and the exact JSON payload so regressions in
        // request construction are caught, not only regressions in response parsing.
        let post_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings")
                .header("Authorization", "Token test")
                .json_body(json!({
                    "destination": "new_owner/new_model",
                    "input": { "text": "..." },
                    "webhook": "webhook",
                }));
            then.status(200).json_body_obj(&json!({
                "id": "zz4ibbonubfz7carwiefibzgga",
                "version": "{version}",
                "status": "starting",
                "input": {
                  "text": "...",
                },
                "output": None::<String>,
                "error": None::<String>,
                "logs": None::<String>,
                "started_at": None::<String>,
                "created_at": "2023-03-28T21:47:58.566434Z",
                "completed_at": None::<String>,
            }));
        });

        let replicate = replicate_for(&server);

        let mut input = HashMap::new();
        input.insert(String::from("text"), String::from("..."));

        let result = replicate.trainings.create(
            "owner",
            "model",
            "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
            TrainingOptions {
                destination: String::from("new_owner/new_model"),
                input,
                webhook: String::from("webhook"),
                _webhook_events_filter: None,
            },
        );

        assert_eq!(result?.id, "zz4ibbonubfz7carwiefibzgga");
        // Ensure the mocks were called as expected
        post_mock.assert();

        Ok(())
    }

    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/trainings/zz4ibbonubfz7carwiefibzgga");
            then.status(200).json_body_obj(&json!({
                "id": "zz4ibbonubfz7carwiefibzgga",
                "version": "{version}",
                "status": "succeeded",
                "input": {
                  "text": "...",
                  "param" : "..."
                },
                "output": {
                    "version": "...",
                },
                "error": None::<String>,
                "logs": None::<String>,
                "webhook_completed": None::<String>,
                "started_at": None::<String>,
                "created_at": "2023-03-28T21:47:58.566434Z",
                "completed_at": None::<String>,
            }));
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.get("zz4ibbonubfz7carwiefibzgga");

        assert_eq!(result?.status, PredictionStatus::succeeded);
        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }

    #[test]
    fn test_get_error_status() {
        let server = MockServer::start();

        // A non-2xx answer must surface as `ResponseError` carrying the raw body,
        // not as a JSON decoding error.
        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/trainings/missing");
            then.status(404).body("not found");
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.get("missing");

        assert!(matches!(
            result,
            Err(ReplicateError::ResponseError(ref body)) if body == "not found"
        ));
        get_mock.assert();
    }

    #[test]
    fn test_cancel() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let cancel_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/trainings/zz4ibbonubfz7carwiefibzgga/cancel");
            then.status(200).json_body_obj(&json!({
                "id": "zz4ibbonubfz7carwiefibzgga",
                "version": "{version}",
                "status": "canceled",
                "input": {
                  "text": "...",
                  "param1" : "..."
                },
                "output": {
                    "version": "...",
                },
                "error": None::<String>,
                "logs": None::<String>,
                "webhook_completed": None::<String>,
                "started_at": None::<String>,
                "created_at": "2023-03-28T21:47:58.566434Z",
                "completed_at": None::<String>,
            }));
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.cancel("zz4ibbonubfz7carwiefibzgga")?;

        assert_eq!(result.status, PredictionStatus::canceled);
        // Ensure the mocks were called as expected
        cancel_mock.assert();

        Ok(())
    }

    #[test]
    fn test_list() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let list_mock = server.mock(|when, then| {
            when.method(GET).path("/trainings");
            then.status(200).json_body_obj(&json!({
                "next": "https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
                "previous": None::<String>,
                "results": [
                  {
                    "id": "jpzd7hm5gfcapbfyt4mqytarku",
                    "version": "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05",
                    "urls": {
                      "get": "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku",
                      "cancel": "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel"
                    },
                    "created_at": "2022-04-26T20:00:40.658234Z",
                    "started_at": "2022-04-26T20:00:84.583803Z",
                    "completed_at": "2022-04-26T20:02:27.648305Z",
                    "source": "web",
                    "status": "succeeded"
                  }
                ]
            }));
        });

        let replicate = replicate_for(&server);

        let result = replicate.trainings.list()?;

        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].id, "jpzd7hm5gfcapbfyt4mqytarku");

        // Ensure the mocks were called as expected
        list_mock.assert();

        Ok(())
    }
}