oss_vizier/
lib.rs

1// Copyright 2022 Sebastien Soudan.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Unofficial OSS Vizier Client API.
16//!
17//! See <https://github.com/google/vizier> for OSS Vizier backend.
18//!
//! ```no_run
//! # use oss_vizier::VizierClient;
//! # use oss_vizier::vizier::vizier_service_client::VizierServiceClient;
//! # async fn example() {
//! let endpoint = std::env::var("ENDPOINT").unwrap_or_else(|_| "http://localhost:28080".to_string());
//!
//! let service = VizierServiceClient::connect(endpoint).await.unwrap();
//!
//! let owner = "owner".to_string();
//!
//! let mut client = VizierClient::new(owner, service);
//!
//! let request = client
//!     .mk_list_studies_request_builder()
//!     .with_page_size(2)
//!     .build();
//!
//! let studies = client.service.list_studies(request).await.unwrap();
//! let study_list = &studies.get_ref().studies;
//! for t in study_list {
//!     println!("- {}", &t.display_name);
//! }
//! # }
//! ```
39
40use std::time::Duration;
41
42use prost::bytes::Bytes;
43pub use prost_types;
44use tokio::time::sleep;
45use tonic::codegen::http::uri::InvalidUri;
46use tonic::codegen::{Body, StdError};
47
48use crate::google::longrunning::{GetOperationRequest, Operation, operation};
49use crate::model::{study, trial};
50use crate::study::StudyName;
51use crate::trial::complete::FinalMeasurementOrReason;
52use crate::trial::{TrialName, early_stopping, optimal, stop};
53use crate::vizier::vizier_service_client::VizierServiceClient;
54use crate::vizier::{
55    AddTrialMeasurementRequest, CheckTrialEarlyStoppingStateRequest, CompleteTrialRequest,
56    CreateTrialRequest, DeleteStudyRequest, DeleteTrialRequest, GetStudyRequest, GetTrialRequest,
57    ListOptimalTrialsRequest, Measurement, StopTrialRequest, SuggestTrialsRequest,
58    SuggestTrialsResponse, Trial,
59};
60
61pub mod model;
62pub mod util;
63
64/// google protos.
65#[allow(missing_docs)]
66pub mod google {
67    /// google.apis protos.
68    pub mod api {
69        #![allow(clippy::derive_partial_eq_without_eq)]
70        #![allow(clippy::doc_overindented_list_items)]
71        tonic::include_proto!("google.api");
72    }
73
74    /// google.rpc protos.
75    pub mod rpc {
76        #![allow(clippy::derive_partial_eq_without_eq)]
77        tonic::include_proto!("google.rpc");
78    }
79
80    /// google.longrunning protos.
81    pub mod longrunning {
82        #![allow(clippy::derive_partial_eq_without_eq)]
83        tonic::include_proto!("google.longrunning");
84    }
85}
86
87/// vizier oss proto
88pub mod vizier {
89    #![allow(missing_docs)]
90    #![allow(clippy::derive_partial_eq_without_eq)]
91    tonic::include_proto!("vizier");
92}
93
94/// Vizier client.
95#[derive(Clone)]
96pub struct VizierClient<T> {
97    owner: String,
98    /// The Vizier service client.
99    pub service: VizierServiceClient<T>,
100}
101
102/// Errors that can occur when using [VizierClient].
103#[derive(thiserror::Error, Debug)]
104pub enum Error {
105    /// Transport error
106    #[error("tonic transport error - {0}")]
107    Tonic(#[from] tonic::transport::Error),
108    /// Invalid URI.
109    #[error("{0}")]
110    InvalidUri(#[from] InvalidUri),
111    /// Decoding error.
112    #[error("{0}")]
113    DecodingError(#[from] util::Error),
114    /// Vizier service error.
115    #[error("Status: {}", .0.message())]
116    Status(#[from] tonic::Status),
117}
118
119impl<T> VizierClient<T>
120where
121    T: tonic::client::GrpcService<tonic::body::Body>,
122    T::Error: Into<StdError>,
123    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
124    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
125{
126    /// Creates a new Vizier client.
127    pub fn new(owner: String, service: VizierServiceClient<T>) -> Self {
128        Self { owner, service }
129    }
130
131    /// Creates a new [crate::vizier::CreateStudyRequest] builder.
132    pub fn mk_study_request_builder(&self) -> study::create::RequestBuilder {
133        study::create::RequestBuilder::new(self.owner.clone())
134    }
135
136    /// Creates a new [GetStudyRequest].
137    pub fn mk_get_study_request(&self, study_name: StudyName) -> GetStudyRequest {
138        study::get::RequestBuilder::new(study_name).build()
139    }
140
141    /// Creates a new [DeleteStudyRequest].
142    pub fn mk_delete_study_request(&self, study_name: StudyName) -> DeleteStudyRequest {
143        study::delete::RequestBuilder::new(study_name).build()
144    }
145
146    /// Creates a new [crate::vizier::ListStudiesRequest] builder.
147    pub fn mk_list_studies_request_builder(&self) -> study::list::RequestBuilder {
148        study::list::RequestBuilder::new(self.owner.clone())
149    }
150
151    /// Creates a new [GetTrialRequest].
152    pub fn mk_get_trial_request(&self, trial_name: TrialName) -> GetTrialRequest {
153        trial::get::RequestBuilder::new(trial_name).build()
154    }
155
156    /// Creates a new [SuggestTrialsRequest].
157    pub fn mk_suggest_trials_request(
158        &self,
159        study_name: StudyName,
160        suggestion_count: i32,
161        client_id: String,
162    ) -> SuggestTrialsRequest {
163        trial::suggest::RequestBuilder::new(study_name, suggestion_count, client_id).build()
164    }
165
166    /// Creates a new [CreateTrialRequest].
167    pub fn mk_create_trial_request(
168        &self,
169        study_name: StudyName,
170        trial: Trial,
171    ) -> CreateTrialRequest {
172        trial::create::RequestBuilder::new(study_name, trial).build()
173    }
174
175    /// Creates a new [DeleteTrialRequest].
176    pub fn mk_delete_trial_request(&self, trial_name: TrialName) -> DeleteTrialRequest {
177        trial::delete::RequestBuilder::new(trial_name).build()
178    }
179
180    /// Creates a new [crate::vizier::ListTrialsRequest] builder.
181    pub fn mk_list_trials_request_builder(
182        &self,
183        study_name: StudyName,
184    ) -> trial::list::RequestBuilder {
185        trial::list::RequestBuilder::new(study_name)
186    }
187
188    /// Creates a new [AddTrialMeasurementRequest].
189    pub fn mk_add_trial_measurement_request(
190        &self,
191        trial_name: TrialName,
192        measurement: Measurement,
193    ) -> AddTrialMeasurementRequest {
194        trial::add_measurement::RequestBuilder::new(trial_name, measurement).build()
195    }
196
197    /// Creates a new [CompleteTrialRequest].
198    pub fn mk_complete_trial_request(
199        &self,
200        trial_name: TrialName,
201        final_measurement: FinalMeasurementOrReason,
202    ) -> CompleteTrialRequest {
203        trial::complete::RequestBuilder::new(trial_name, final_measurement).build()
204    }
205
206    /// Creates a new [CheckTrialEarlyStoppingStateRequest].
207    pub fn mk_check_trial_early_stopping_state_request(
208        &self,
209        trial_name: TrialName,
210    ) -> CheckTrialEarlyStoppingStateRequest {
211        early_stopping::RequestBuilder::new(trial_name).build()
212    }
213
214    /// Creates a new [StopTrialRequest].
215    pub fn mk_stop_trial_request(&self, trial_name: TrialName) -> StopTrialRequest {
216        stop::RequestBuilder::new(trial_name).build()
217    }
218
219    /// Creates a new [ListOptimalTrialsRequest].
220    pub fn mk_list_optimal_trials_request(
221        &self,
222        study_name: StudyName,
223    ) -> ListOptimalTrialsRequest {
224        optimal::RequestBuilder::new(study_name).build()
225    }
226
227    /// Creates a [TrialName] (of the form
228    /// "owners/{owner}/studies/{study}/trials/{trial}"). #
229    /// Arguments
230    /// * `study` - The study number - {study} in the pattern.
231    /// * `trial` - The trial number - {trial} in the pattern.
232    pub fn trial_name(&self, study: String, trial: String) -> TrialName {
233        TrialName::new(self.owner.clone(), study, trial)
234    }
235
236    /// Creates a [TrialName] from a [StudyName] and trial number.
237    /// # Arguments
238    /// * `study_name` - The [StudyName].
239    /// * `trial` - The trial number.
240    pub fn trial_name_from_study(
241        &self,
242        study_name: &StudyName,
243        trial: impl Into<String>,
244    ) -> TrialName {
245        TrialName::from_study(study_name, trial.into())
246    }
247
248    /// Creates a [StudyName] (of the form
249    /// "owners/{owner}/studies/{study}").  
250    /// # Arguments
251    ///  * `study` - The study number - {study} in the pattern.
252    pub fn study_name(&self, study: impl Into<String>) -> StudyName {
253        StudyName::new(self.owner.clone(), study.into())
254    }
255
256    /// Waits for an operation to be completed.
257    /// Makes `retries` attempts and return the error if it still fails.
258    /// # Arguments
259    /// * `retries` - The number of retries.
260    /// * `operation` - The operation to wait for.
261    pub async fn wait_for_operation(
262        &mut self,
263        mut retries: usize,
264        mut operation: Operation,
265    ) -> Result<Option<operation::Result>, Error> {
266        while !operation.done {
267            let mut wait_ms = 500;
268            let resp = loop {
269                match self
270                    .service
271                    .get_operation(GetOperationRequest {
272                        name: operation.name.clone(),
273                    })
274                    .await
275                {
276                    Err(_) if retries > 0 => {
277                        retries -= 1;
278                        sleep(Duration::from_millis(wait_ms)).await;
279                        wait_ms *= 2;
280                    }
281                    res => break res,
282                }
283            }?;
284
285            operation = resp.into_inner();
286        }
287
288        Ok(operation.result)
289    }
290
291    /// Gets the [operation::Result] of an [Operation] specified by its name.
292    pub async fn get_operation(
293        &mut self,
294        operation_name: String,
295    ) -> Result<Option<operation::Result>, Error> {
296        let resp = self
297            .service
298            .get_operation(GetOperationRequest {
299                name: operation_name,
300            })
301            .await?;
302
303        let operation = resp.into_inner();
304
305        if operation.done {
306            Ok(operation.result)
307        } else {
308            Ok(None)
309        }
310    }
311
312    /// Suggests trials to a study.
313    pub async fn suggest_trials(
314        &mut self,
315        request: SuggestTrialsRequest,
316    ) -> Result<SuggestTrialsResponse, Error> {
317        let trials = self.service.suggest_trials(request).await?;
318        let operation = trials.into_inner();
319
320        let result = loop {
321            if let Some(result) = self.get_operation(operation.name.clone()).await? {
322                break result;
323            }
324            sleep(Duration::from_millis(100)).await;
325        };
326
327        // parse the result into trials
328        let resp: SuggestTrialsResponse =
329            util::decode_operation_result_as(result, "SuggestTrialsResponse")?;
330
331        Ok(resp)
332    }
333}
334
#[cfg(test)]
mod trials {
    use std::time::Duration;

    use tonic::Code;
    use tonic::transport::Channel;

    use super::common::{create_dummy_study, test_client};
    use crate::study::StudyName;
    use crate::trial::complete::FinalMeasurementOrReason;
    use crate::util::decode_operation_result_as;
    use crate::vizier::{Measurement, measurement};
    use crate::{SuggestTrialsResponse, VizierClient};

    /// Suggests a single trial for `study_name` through the raw service call,
    /// waits for the long-running operation (tolerating up to 3 polling
    /// failures), decodes the result and asserts that exactly one trial was
    /// returned.
    async fn suggest_one_trial_raw(
        client: &mut VizierClient<Channel>,
        study_name: StudyName,
        client_id: String,
    ) {
        let request = client.mk_suggest_trials_request(study_name, 1, client_id);

        let resp = client.service.suggest_trials(request).await.unwrap();
        let operation = resp.into_inner();

        let Some(result) = client.wait_for_operation(3, operation).await.unwrap() else {
            panic!("no result");
        };

        // parse the result into trials
        let resp: SuggestTrialsResponse =
            decode_operation_result_as(result, "SuggestTrialsResponse").unwrap();

        dbg!(&resp);

        assert_eq!(resp.trials.len(), 1);
    }

    #[tokio::test]
    async fn it_can_get_a_trial() {
        let mut client = test_client().await;

        let study_name = "it_can_get_a_trial".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        // suggest a trial
        let _resp = client
            .suggest_trials(client.mk_suggest_trials_request(
                study_name.clone(),
                1,
                "it_can_get_a_trial".to_string(),
            ))
            .await
            .unwrap();

        // get the trial
        let trial = "1".to_string();

        dbg!(&study_name);
        let trial_name = client.trial_name_from_study(&study_name, trial);
        dbg!(&trial_name);
        let request = client.mk_get_trial_request(trial_name);

        let trial = client.service.get_trial(request).await.unwrap();
        let trial = trial.get_ref();
        dbg!(trial);
    }

    #[tokio::test]
    async fn it_deletes_a_trial() {
        let mut client = test_client().await;

        let study = "53316451264".to_string();
        let trial = "2".to_string();

        let study_name = client.study_name(study);
        let trial_name = client.trial_name_from_study(&study_name, trial);

        let request = client.mk_delete_trial_request(trial_name);

        match client.service.delete_trial(request).await {
            Ok(study) => {
                let study = study.get_ref();
                dbg!(study);
            }
            Err(err) => {
                assert_eq!(err.code(), Code::Unknown);
            }
        }
    }

    #[tokio::test]
    async fn it_suggests_trials_raw() {
        let mut client = test_client().await;

        let study_name = "it_suggests_trials_raw".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        suggest_one_trial_raw(
            &mut client,
            study_name,
            "it_can_suggest_trials".to_string(),
        )
        .await;
    }

    #[tokio::test]
    async fn it_suggests_trials() {
        let mut client = test_client().await;

        let study_name = "it_suggests_trials".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        let client_id = "it_can_suggest_trials".to_string();

        let request = client.mk_suggest_trials_request(study_name, 1, client_id);

        let resp = client.suggest_trials(request).await.unwrap();

        dbg!(resp);
    }

    #[tokio::test]
    async fn it_lists_trials() {
        let mut client = test_client().await;

        let study_name = "it_lists_trials".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        // suggest 3 trials
        let client_id = "it_can_list_trials".to_string();
        let request = client.mk_suggest_trials_request(study_name.clone(), 3, client_id);

        let _resp = client.suggest_trials(request).await.unwrap();

        // list the trials
        let request = client
            .mk_list_trials_request_builder(study_name.clone())
            .with_page_size(2)
            .build();

        let trials = client.service.list_trials(request).await.unwrap();
        let trial_list = &trials.get_ref().trials;
        for t in trial_list {
            dbg!(&t);
        }

        // follow the pagination until the server returns an empty token
        let mut page_token = trials.get_ref().next_page_token.clone();
        while !page_token.is_empty() {
            println!("There is more! - {:?}", &page_token);

            let request = client
                .mk_list_trials_request_builder(study_name.clone())
                .with_page_token(page_token)
                .with_page_size(2)
                .build();

            let trials = client.service.list_trials(request).await.unwrap();
            let trial_list = &trials.get_ref().trials;
            for t in trial_list {
                dbg!(&t);
            }

            page_token = trials.get_ref().next_page_token.clone();
        }
    }

    #[tokio::test]
    async fn it_can_add_trial_measurement() {
        let mut client = test_client().await;

        let study_name = "it_can_add_trial_measurement".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        // create trials
        suggest_one_trial_raw(
            &mut client,
            study_name.clone(),
            "it_can_add_trial_measurement".to_string(),
        )
        .await;

        // do something with the trials

        let trial = "1".to_string();

        let trial_name = client.trial_name_from_study(&study_name, trial);

        let measurement = Measurement {
            elapsed_duration: Some(Duration::from_secs(10).try_into().unwrap()),
            step_count: 13,
            metrics: vec![measurement::Metric {
                metric_id: "m1".to_string(),
                value: 2.1,
            }],
        };

        let request = client.mk_add_trial_measurement_request(trial_name, measurement);

        let trial = client.service.add_trial_measurement(request).await.unwrap();
        let trial = trial.get_ref();
        dbg!(trial);
    }

    #[tokio::test]
    async fn it_can_complete_a_trial() {
        let mut client = test_client().await;

        let study = "blah2".to_string();
        let trial = "3".to_string();

        let study_name = client.study_name(study);
        let trial_name = client.trial_name_from_study(&study_name, trial);

        let final_measurement_or_reason = FinalMeasurementOrReason::FinalMeasurement(Measurement {
            elapsed_duration: Some(Duration::from_secs(100).try_into().unwrap()),
            step_count: 14,
            metrics: vec![measurement::Metric {
                metric_id: "m1".to_string(),
                value: 3.1,
            }],
        });

        let request = client.mk_complete_trial_request(trial_name, final_measurement_or_reason);

        match client.service.complete_trial(request).await {
            Ok(trial) => {
                let trial = trial.get_ref();
                dbg!(trial);
            }
            Err(e) => {
                dbg!(e);
            }
        };
    }

    #[tokio::test]
    async fn it_can_check_trial_early_stopping_state() {
        let mut client = test_client().await;

        let study_name = "it_can_check_trial_early_stopping_state".to_string();

        // create a study
        create_dummy_study(&mut client, "RANDOM_SEARCH".to_string(), study_name.clone()).await;

        let study_name = client.study_name(study_name);

        // create trials
        suggest_one_trial_raw(
            &mut client,
            study_name.clone(),
            "it_can_check_trial_early_stopping_state".to_string(),
        )
        .await;

        // do something with the trials

        let trial = "1".to_string();

        let trial_name = client.trial_name_from_study(&study_name, trial);

        let request = client.mk_check_trial_early_stopping_state_request(trial_name);

        let resp = client
            .service
            .check_trial_early_stopping_state(request)
            .await
            .unwrap();

        dbg!(&resp);

        let result = resp.into_inner();

        dbg!(result);
    }

    #[tokio::test]
    async fn it_can_stop_a_trial() {
        let mut client = test_client().await;

        let study = "blah2".to_string();
        let trial = "3".to_string();

        let study_name = client.study_name(study);
        let trial_name = client.trial_name_from_study(&study_name, trial);

        let request = client.mk_stop_trial_request(trial_name);

        match client.service.stop_trial(request).await {
            Ok(trial) => {
                let trial = trial.get_ref();
                dbg!(trial);
            }
            Err(err) => {
                dbg!(err);
            }
        };
    }

    #[tokio::test]
    async fn it_lists_optimal_trials() {
        let mut client = test_client().await;

        let study_name = "it_lists_optimal_trials".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        let request = client.mk_list_optimal_trials_request(study_name);

        let trials = client.service.list_optimal_trials(request).await.unwrap();
        let trial_list = &trials.get_ref().optimal_trials;
        for t in trial_list {
            dbg!(&t.name);
        }
    }
}
718
#[cfg(test)]
mod studies {
    use tonic::Code;

    use super::common::{create_dummy_study, test_client};

    #[tokio::test]
    async fn it_lists_studies() {
        let mut client = test_client().await;

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            "it_lists_studies_1".to_string(),
        )
        .await;
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            "it_lists_studies_2".to_string(),
        )
        .await;

        // list studies
        let request = client
            .mk_list_studies_request_builder()
            .with_page_size(2)
            .build();

        let studies = client.service.list_studies(request).await.unwrap();
        let study_list_resp = studies.get_ref();
        let study_list = &study_list_resp.studies;
        for t in study_list {
            dbg!(&t.name);
            dbg!(&t.display_name);
        }

        // follow the pagination until the server returns an empty token
        let mut page_token = studies.get_ref().next_page_token.clone();
        while !page_token.is_empty() {
            println!("There is more! - {:?}", &page_token);

            let request = client
                .mk_list_studies_request_builder()
                .with_page_token(page_token)
                .with_page_size(2)
                .build();

            let studies = client.service.list_studies(request).await.unwrap();
            let study_list = &studies.get_ref().studies;
            for t in study_list {
                dbg!(&t.display_name);
            }

            page_token = studies.get_ref().next_page_token.clone();
        }
    }

    #[tokio::test]
    async fn it_creates_studies() {
        let mut client = test_client().await;

        // Same spec as the shared dummy study. The trial tests also use a
        // study named "blah2".
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            "blah2".to_string(),
        )
        .await;
    }

    #[tokio::test]
    async fn it_can_get_a_study() {
        let mut client = test_client().await;

        let study_name = "it_can_get_a_study".to_string();

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            study_name.clone(),
        )
        .await;

        let study_name = client.study_name(study_name);

        let request = client.mk_get_study_request(study_name);

        let study = client.service.get_study(request).await.unwrap();
        let study = study.get_ref();
        dbg!(study);
    }

    #[tokio::test]
    async fn it_deletes_a_study() {
        let mut client = test_client().await;

        let study = "blah_to_delete".to_string();
        let study_name = client.study_name(study);

        let request = client.mk_delete_study_request(study_name);

        match client.service.delete_study(request).await {
            Ok(study) => {
                let study = study.get_ref();
                dbg!(study);
            }
            Err(err) => {
                assert_eq!(err.code(), Code::Unknown);
            }
        }
    }
}
886
#[cfg(test)]
mod common {

    use tonic::transport::Channel;

    use crate::VizierClient;
    use crate::study::spec::StudySpecBuilder;
    use crate::vizier::study_spec::metric_spec::GoalType;
    use crate::vizier::study_spec::parameter_spec::{
        DoubleValueSpec, IntegerValueSpec, ParameterValueSpec, ScaleType,
    };
    use crate::vizier::study_spec::{MetricSpec, ObservationNoise, ParameterSpec};
    use crate::vizier::vizier_service_client::VizierServiceClient;

    /// Endpoint used when the `ENDPOINT` environment variable is not set.
    const DEFAULT_ENDPOINT: &str = "http://localhost:28080";

    /// Connects to the Vizier server named by `$ENDPOINT` (falling back to
    /// [DEFAULT_ENDPOINT]) and wraps it in a client for owner "owner".
    ///
    /// Panics if the connection cannot be established.
    pub(crate) async fn test_client() -> VizierClient<Channel> {
        let endpoint = match std::env::var("ENDPOINT") {
            Ok(value) => value,
            Err(_) => DEFAULT_ENDPOINT.to_string(),
        };

        let service = VizierServiceClient::connect(endpoint).await.unwrap();

        VizierClient::new("owner".to_string(), service)
    }

    /// Creates a study displayed as `study_name` using `algorithm`, with one
    /// maximized metric "m1" and two parameters: a double "a" in [0, 12]
    /// (default 4) and an integer "b" in [4, 10] (default 7).
    ///
    /// A failing `CreateStudy` call is only printed with `dbg!`, not
    /// propagated — presumably so tests can re-run against an existing study.
    pub(crate) async fn create_dummy_study(
        client: &mut VizierClient<Channel>,
        algorithm: String,
        study_name: String,
    ) {
        let metric = MetricSpec {
            metric_id: "m1".to_string(),
            goal: GoalType::Maximize as i32,
            safety_config: None,
        };

        let double_param = ParameterSpec {
            parameter_id: "a".to_string(),
            scale_type: ScaleType::Unspecified as i32,
            conditional_parameter_specs: vec![],
            parameter_value_spec: Some(ParameterValueSpec::DoubleValueSpec(DoubleValueSpec {
                min_value: 0.0,
                max_value: 12.0,
                default_value: Some(4.0),
            })),
        };

        let integer_param = ParameterSpec {
            parameter_id: "b".to_string(),
            scale_type: ScaleType::Unspecified as i32,
            conditional_parameter_specs: vec![],
            parameter_value_spec: Some(ParameterValueSpec::IntegerValueSpec(IntegerValueSpec {
                min_value: 4,
                max_value: 10,
                default_value: Some(7),
            })),
        };

        let spec = StudySpecBuilder::new(algorithm, ObservationNoise::Low)
            .with_metric_specs(vec![metric])
            .with_parameters(vec![double_param, integer_param])
            .build();

        let request = client
            .mk_study_request_builder()
            .with_display_name(study_name)
            .with_study_spec(spec)
            .build()
            .unwrap();

        let outcome = client.service.create_study(request).await;
        match outcome {
            Ok(response) => {
                dbg!(response.get_ref());
            }
            Err(status) => {
                dbg!(status);
            }
        }
    }
}