Skip to main content

burn_central_client/experiment/
mod.rs

1pub mod request;
2pub mod response;
3pub mod websocket;
4
5use crate::{
6    Client, ClientError, WebSocketClient,
7    experiment::{request::CreateExperimentSchema, response::ExperimentResponse},
8    websocket::WebSocketError,
9};
10
11impl Client {
12    /// Formats a WebSocket URL for the given experiment.
13    fn format_websocket_url(&self, owner_name: &str, project_name: &str, exp_num: i32) -> String {
14        let mut url = self.join(&format!(
15            "projects/{owner_name}/{project_name}/experiments/{exp_num}/ws"
16        ));
17        url.set_scheme(if self.base_url.scheme() == "https" {
18            "wss"
19        } else {
20            "ws"
21        })
22        .expect("Should be able to set ws scheme");
23
24        url.to_string()
25    }
26
27    /// Create a new experiment for the given project.
28    ///
29    /// The client must be logged in before calling this method.
30    pub fn create_experiment(
31        &self,
32        owner_name: &str,
33        project_name: &str,
34        description: Option<String>,
35        code_version_digest: String,
36        routine: String,
37    ) -> Result<ExperimentResponse, ClientError> {
38        let url = self.join(&format!("projects/{owner_name}/{project_name}/experiments"));
39
40        // Create a new experiment
41        let experiment_response = self.post_json::<CreateExperimentSchema, ExperimentResponse>(
42            url,
43            Some(CreateExperimentSchema {
44                description,
45                code_version_digest,
46                routine_run: routine,
47            }),
48        )?;
49
50        Ok(experiment_response)
51    }
52
53    pub fn create_experiment_run_websocket(
54        &self,
55        owner_name: &str,
56        project_name: &str,
57        exp_num: i32,
58    ) -> Result<WebSocketClient, WebSocketError> {
59        let mut ws_client = WebSocketClient::new();
60
61        let ws_endpoint = self.format_websocket_url(owner_name, project_name, exp_num);
62
63        ws_client
64            .connect(
65                ws_endpoint,
66                &self.session_cookie.clone().unwrap_or("".to_string()),
67            )
68            .map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
69
70        Ok(ws_client)
71    }
72
73    /// Cancel an experiment.
74    ///
75    /// The client must be logged in before calling this method.
76    pub fn cancel_experiment(
77        &self,
78        owner_name: &str,
79        project_name: &str,
80        exp_num: i32,
81    ) -> Result<(), ClientError> {
82        let url = self.join(&format!(
83            "projects/{owner_name}/{project_name}/experiments/{exp_num}/cancel"
84        ));
85
86        self.post(url, None::<()>)
87    }
88}