oss_vizier/
lib.rs

1// Copyright 2022 Sebastien Soudan.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Unofficial OSS Vizier Client API.
16//!
17//! See <https://github.com/google/vizier> for OSS Vizier backend.
18//!
19//! ```no_run
20//! let endpoint = std::env::var("ENDPOINT").unwrap_or_else(|_| "http://localhost:28080".to_string());
21//!
22//! let service = VizierServiceClient::connect(endpoint).await.unwrap();
23//!
24//! let owner = "owner".to_string();
25//!
//! let mut client = VizierClient::new(owner, service);
27//!
28//! let request = client
29//!     .mk_list_studies_request_builder()
30//!     .with_page_size(2)
31//!     .build();
32//!
33//! let studies = client.service.list_studies(request).await.unwrap();
34//! let study_list = &studies.get_ref().studies;
35//! for t in study_list {
36//!     println!("- {}", &t.display_name);
37//! }
38//! ```
39
40use std::time::Duration;
41
42use prost::bytes::Bytes;
43pub use prost_types;
44use tokio::time::sleep;
45use tonic::codegen::http::uri::InvalidUri;
46use tonic::codegen::{Body, StdError};
47
48use crate::google::longrunning::{operation, GetOperationRequest, Operation};
49use crate::model::{study, trial};
50use crate::study::StudyName;
51use crate::trial::complete::FinalMeasurementOrReason;
52use crate::trial::{early_stopping, optimal, stop, TrialName};
53use crate::vizier::vizier_service_client::VizierServiceClient;
54use crate::vizier::{
55    AddTrialMeasurementRequest, CheckTrialEarlyStoppingStateRequest, CompleteTrialRequest,
56    CreateTrialRequest, DeleteStudyRequest, DeleteTrialRequest, GetStudyRequest, GetTrialRequest,
57    ListOptimalTrialsRequest, Measurement, StopTrialRequest, SuggestTrialsRequest,
58    SuggestTrialsResponse, Trial,
59};
60
61pub mod model;
62pub mod util;
63
64/// google protos.
65#[allow(missing_docs)]
66pub mod google {
67    /// google.apis protos.
68    pub mod api {
69        #![allow(clippy::derive_partial_eq_without_eq)]
70        tonic::include_proto!("google.api");
71    }
72
73    /// google.rpc protos.
74    pub mod rpc {
75        #![allow(clippy::derive_partial_eq_without_eq)]
76        tonic::include_proto!("google.rpc");
77    }
78
79    /// google.longrunning protos.
80    pub mod longrunning {
81        #![allow(clippy::derive_partial_eq_without_eq)]
82        tonic::include_proto!("google.longrunning");
83    }
84}
85
86/// vizier oss proto
87#[allow(missing_docs)]
88pub mod vizier {
89    #![allow(clippy::derive_partial_eq_without_eq)]
90    tonic::include_proto!("vizier");
91}
92
93/// Vizier client.
94#[derive(Clone)]
95pub struct VizierClient<T> {
96    owner: String,
97    /// The Vizier service client.
98    pub service: VizierServiceClient<T>,
99}
100
101/// Errors that can occur when using [VizierClient].
102#[derive(thiserror::Error, Debug)]
103pub enum Error {
104    /// Transport error
105    #[error("tonic transport error - {0}")]
106    Tonic(#[from] tonic::transport::Error),
107    /// Invalid URI.
108    #[error("{0}")]
109    InvalidUri(#[from] InvalidUri),
110    /// Decoding error.
111    #[error("{0}")]
112    DecodingError(#[from] util::Error),
113    /// Vizier service error.
114    #[error("Status: {}", .0.message())]
115    Status(#[from] tonic::Status),
116}
117
118impl<T> VizierClient<T>
119where
120    T: tonic::client::GrpcService<tonic::body::BoxBody>,
121    T::Error: Into<StdError>,
122    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
123    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
124{
125    /// Creates a new Vizier client.
126    pub fn new(owner: String, service: VizierServiceClient<T>) -> Self {
127        Self { owner, service }
128    }
129
130    /// Creates a new [crate::vizier::CreateStudyRequest] builder.
131    pub fn mk_study_request_builder(&self) -> study::create::RequestBuilder {
132        study::create::RequestBuilder::new(self.owner.clone())
133    }
134
135    /// Creates a new [GetStudyRequest].
136    pub fn mk_get_study_request(&self, study_name: StudyName) -> GetStudyRequest {
137        study::get::RequestBuilder::new(study_name).build()
138    }
139
140    /// Creates a new [DeleteStudyRequest].
141    pub fn mk_delete_study_request(&self, study_name: StudyName) -> DeleteStudyRequest {
142        study::delete::RequestBuilder::new(study_name).build()
143    }
144
145    /// Creates a new [crate::vizier::ListStudiesRequest] builder.
146    pub fn mk_list_studies_request_builder(&self) -> study::list::RequestBuilder {
147        study::list::RequestBuilder::new(self.owner.clone())
148    }
149
150    /// Creates a new [GetTrialRequest].
151    pub fn mk_get_trial_request(&self, trial_name: TrialName) -> GetTrialRequest {
152        trial::get::RequestBuilder::new(trial_name).build()
153    }
154
155    /// Creates a new [SuggestTrialsRequest].
156    pub fn mk_suggest_trials_request(
157        &self,
158        study_name: StudyName,
159        suggestion_count: i32,
160        client_id: String,
161    ) -> SuggestTrialsRequest {
162        trial::suggest::RequestBuilder::new(study_name, suggestion_count, client_id).build()
163    }
164
165    /// Creates a new [CreateTrialRequest].
166    pub fn mk_create_trial_request(
167        &self,
168        study_name: StudyName,
169        trial: Trial,
170    ) -> CreateTrialRequest {
171        trial::create::RequestBuilder::new(study_name, trial).build()
172    }
173
174    /// Creates a new [DeleteTrialRequest].
175    pub fn mk_delete_trial_request(&self, trial_name: TrialName) -> DeleteTrialRequest {
176        trial::delete::RequestBuilder::new(trial_name).build()
177    }
178
179    /// Creates a new [crate::vizier::ListTrialsRequest] builder.
180    pub fn mk_list_trials_request_builder(
181        &self,
182        study_name: StudyName,
183    ) -> trial::list::RequestBuilder {
184        trial::list::RequestBuilder::new(study_name)
185    }
186
187    /// Creates a new [AddTrialMeasurementRequest].
188    pub fn mk_add_trial_measurement_request(
189        &self,
190        trial_name: TrialName,
191        measurement: Measurement,
192    ) -> AddTrialMeasurementRequest {
193        trial::add_measurement::RequestBuilder::new(trial_name, measurement).build()
194    }
195
196    /// Creates a new [CompleteTrialRequest].
197    pub fn mk_complete_trial_request(
198        &self,
199        trial_name: TrialName,
200        final_measurement: FinalMeasurementOrReason,
201    ) -> CompleteTrialRequest {
202        trial::complete::RequestBuilder::new(trial_name, final_measurement).build()
203    }
204
205    /// Creates a new [CheckTrialEarlyStoppingStateRequest].
206    pub fn mk_check_trial_early_stopping_state_request(
207        &self,
208        trial_name: TrialName,
209    ) -> CheckTrialEarlyStoppingStateRequest {
210        early_stopping::RequestBuilder::new(trial_name).build()
211    }
212
213    /// Creates a new [StopTrialRequest].
214    pub fn mk_stop_trial_request(&self, trial_name: TrialName) -> StopTrialRequest {
215        stop::RequestBuilder::new(trial_name).build()
216    }
217
218    /// Creates a new [ListOptimalTrialsRequest].
219    pub fn mk_list_optimal_trials_request(
220        &self,
221        study_name: StudyName,
222    ) -> ListOptimalTrialsRequest {
223        optimal::RequestBuilder::new(study_name).build()
224    }
225
226    /// Creates a [TrialName] (of the form
227    /// "owners/{owner}/studies/{study}/trials/{trial}"). #
228    /// Arguments
229    /// * `study` - The study number - {study} in the pattern.
230    /// * `trial` - The trial number - {trial} in the pattern.
231    pub fn trial_name(&self, study: String, trial: String) -> TrialName {
232        TrialName::new(self.owner.clone(), study, trial)
233    }
234
235    /// Creates a [TrialName] from a [StudyName] and trial number.
236    /// # Arguments
237    /// * `study_name` - The [StudyName].
238    /// * `trial` - The trial number.
239    pub fn trial_name_from_study(
240        &self,
241        study_name: &StudyName,
242        trial: impl Into<String>,
243    ) -> TrialName {
244        TrialName::from_study(study_name, trial.into())
245    }
246
247    /// Creates a [StudyName] (of the form
248    /// "owners/{owner}/studies/{study}").  
249    /// # Arguments
250    ///  * `study` - The study number - {study} in the pattern.
251    pub fn study_name(&self, study: impl Into<String>) -> StudyName {
252        StudyName::new(self.owner.clone(), study.into())
253    }
254
255    /// Waits for an operation to be completed.
256    /// Makes `retries` attempts and return the error if it still fails.
257    /// # Arguments
258    /// * `retries` - The number of retries.
259    /// * `operation` - The operation to wait for.
260    pub async fn wait_for_operation(
261        &mut self,
262        mut retries: usize,
263        mut operation: Operation,
264    ) -> Result<Option<operation::Result>, Error> {
265        while !operation.done {
266            let mut wait_ms = 500;
267            let resp = loop {
268                match self
269                    .service
270                    .get_operation(GetOperationRequest {
271                        name: operation.name.clone(),
272                    })
273                    .await
274                {
275                    Err(_) if retries > 0 => {
276                        retries -= 1;
277                        sleep(Duration::from_millis(wait_ms)).await;
278                        wait_ms *= 2;
279                    }
280                    res => break res,
281                }
282            }?;
283
284            operation = resp.into_inner();
285        }
286
287        Ok(operation.result)
288    }
289
290    /// Gets the [operation::Result] of an [Operation] specified by its name.
291    pub async fn get_operation(
292        &mut self,
293        operation_name: String,
294    ) -> Result<Option<operation::Result>, Error> {
295        let resp = self
296            .service
297            .get_operation(GetOperationRequest {
298                name: operation_name,
299            })
300            .await?;
301
302        let operation = resp.into_inner();
303
304        if operation.done {
305            Ok(operation.result)
306        } else {
307            Ok(None)
308        }
309    }
310
311    /// Suggests trials to a study.
312    pub async fn suggest_trials(
313        &mut self,
314        request: SuggestTrialsRequest,
315    ) -> Result<SuggestTrialsResponse, Error> {
316        let trials = self.service.suggest_trials(request).await?;
317        let operation = trials.into_inner();
318
319        let result = loop {
320            if let Some(result) = self.get_operation(operation.name.clone()).await? {
321                break result;
322            }
323            sleep(Duration::from_millis(100)).await;
324        };
325
326        // parse the result into trials
327        let resp: SuggestTrialsResponse =
328            util::decode_operation_result_as(result, "SuggestTrialsResponse")?;
329
330        Ok(resp)
331    }
332}
333
/// Integration tests for the trial-related calls.
///
/// They talk to the Vizier server that `common::test_client` connects to.
#[cfg(test)]
mod trials {
    use std::time::Duration;

    use tonic::transport::Channel;
    use tonic::Code;

    use super::common::{create_dummy_study, test_client};
    use crate::trial::complete::FinalMeasurementOrReason;
    use crate::util::decode_operation_result_as;
    use crate::vizier::{measurement, Measurement};
    use crate::{StudyName, SuggestTrialsResponse, VizierClient};

    /// Creates (best effort) a dummy study displayed as `display_name` and
    /// returns the matching [StudyName].
    async fn dummy_study(
        client: &mut VizierClient<Channel>,
        algorithm: &str,
        display_name: &str,
    ) -> StudyName {
        create_dummy_study(client, algorithm.to_string(), display_name.to_string()).await;
        client.study_name(display_name)
    }

    /// Requests one suggestion through the raw service call, waits for the
    /// operation and checks that exactly one trial is returned.
    async fn suggest_one_trial_raw(
        client: &mut VizierClient<Channel>,
        study_name: StudyName,
        client_id: &str,
    ) {
        let request = client.mk_suggest_trials_request(study_name, 1, client_id.to_string());

        let operation = client
            .service
            .suggest_trials(request)
            .await
            .unwrap()
            .into_inner();

        match client.wait_for_operation(3, operation).await.unwrap() {
            Some(result) => {
                // parse the result into trials
                let resp: SuggestTrialsResponse =
                    decode_operation_result_as(result, "SuggestTrialsResponse").unwrap();

                dbg!(&resp);

                assert_eq!(resp.trials.len(), 1);
            }
            None => panic!("no result"),
        }
    }

    #[tokio::test]
    async fn it_can_get_a_trial() {
        let mut client = test_client().await;

        let study_name =
            dummy_study(&mut client, "ALGORITHM_UNSPECIFIED", "it_can_get_a_trial").await;

        // suggest a trial
        let request = client.mk_suggest_trials_request(
            study_name.clone(),
            1,
            "it_can_get_a_trial".to_string(),
        );
        let _resp = client.suggest_trials(request).await.unwrap();

        // get the trial
        dbg!(&study_name);
        let trial_name = client.trial_name_from_study(&study_name, "1");
        dbg!(&trial_name);

        let request = client.mk_get_trial_request(trial_name);
        let trial = client.service.get_trial(request).await.unwrap();
        dbg!(trial.get_ref());
    }

    #[tokio::test]
    async fn it_deletes_a_trial() {
        let mut client = test_client().await;

        let study_name = client.study_name("53316451264");
        let trial_name = client.trial_name_from_study(&study_name, "2");

        let request = client.mk_delete_trial_request(trial_name);

        match client.service.delete_trial(request).await {
            Ok(deleted) => {
                dbg!(deleted.get_ref());
            }
            Err(err) => assert_eq!(err.code(), Code::Unknown),
        }
    }

    #[tokio::test]
    async fn it_suggests_trials_raw() {
        let mut client = test_client().await;

        let study_name = dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED",
            "it_suggests_trials_raw",
        )
        .await;

        suggest_one_trial_raw(&mut client, study_name, "it_can_suggest_trials").await;
    }

    #[tokio::test]
    async fn it_suggests_trials() {
        let mut client = test_client().await;

        let study_name =
            dummy_study(&mut client, "ALGORITHM_UNSPECIFIED", "it_suggests_trials").await;

        let request =
            client.mk_suggest_trials_request(study_name, 1, "it_can_suggest_trials".to_string());

        let resp = client.suggest_trials(request).await.unwrap();

        dbg!(resp);
    }

    #[tokio::test]
    async fn it_lists_trials() {
        let mut client = test_client().await;

        let study_name =
            dummy_study(&mut client, "ALGORITHM_UNSPECIFIED", "it_lists_trials").await;

        // suggest 3 trials
        let request = client.mk_suggest_trials_request(
            study_name.clone(),
            3,
            "it_can_list_trials".to_string(),
        );
        let _resp = client.suggest_trials(request).await.unwrap();

        // list the first page of trials
        let request = client
            .mk_list_trials_request_builder(study_name.clone())
            .with_page_size(2)
            .build();

        let first_page = client
            .service
            .list_trials(request)
            .await
            .unwrap()
            .into_inner();
        for t in &first_page.trials {
            dbg!(&t);
        }

        // follow the page tokens until the listing is exhausted
        let mut page_token = first_page.next_page_token;
        while !page_token.is_empty() {
            println!("There is more! - {:?}", &page_token);

            let request = client
                .mk_list_trials_request_builder(study_name.clone())
                .with_page_token(page_token)
                .with_page_size(2)
                .build();

            let page = client
                .service
                .list_trials(request)
                .await
                .unwrap()
                .into_inner();
            for t in &page.trials {
                dbg!(&t);
            }

            page_token = page.next_page_token;
        }
    }

    #[tokio::test]
    async fn it_can_add_trial_measurement() {
        let mut client = test_client().await;

        let study_name = dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED",
            "it_can_add_trial_measurement",
        )
        .await;

        // create trials
        suggest_one_trial_raw(
            &mut client,
            study_name.clone(),
            "it_can_add_trial_measurement",
        )
        .await;

        // do something with the trials
        let trial_name = client.trial_name_from_study(&study_name, "1");

        let measurement = Measurement {
            elapsed_duration: Some(Duration::from_secs(10).try_into().unwrap()),
            step_count: 13,
            metrics: vec![measurement::Metric {
                metric_id: "m1".to_string(),
                value: 2.1,
            }],
        };

        let request = client.mk_add_trial_measurement_request(trial_name, measurement);

        let trial = client.service.add_trial_measurement(request).await.unwrap();
        dbg!(trial.get_ref());
    }

    #[tokio::test]
    async fn it_can_complete_a_trial() {
        let mut client = test_client().await;

        let study_name = client.study_name("blah2");
        let trial_name = client.trial_name_from_study(&study_name, "3");

        let final_measurement_or_reason = FinalMeasurementOrReason::FinalMeasurement(Measurement {
            elapsed_duration: Some(Duration::from_secs(100).try_into().unwrap()),
            step_count: 14,
            metrics: vec![measurement::Metric {
                metric_id: "m1".to_string(),
                value: 3.1,
            }],
        });

        let request = client.mk_complete_trial_request(trial_name, final_measurement_or_reason);

        // The outcome is only printed: the trial may not exist on the server.
        match client.service.complete_trial(request).await {
            Ok(trial) => {
                dbg!(trial.get_ref());
            }
            Err(e) => {
                dbg!(e);
            }
        }
    }

    #[tokio::test]
    async fn it_can_check_trial_early_stopping_state() {
        let mut client = test_client().await;

        let study_name = dummy_study(
            &mut client,
            "RANDOM_SEARCH",
            "it_can_check_trial_early_stopping_state",
        )
        .await;

        // create trials
        suggest_one_trial_raw(
            &mut client,
            study_name.clone(),
            "it_can_check_trial_early_stopping_state",
        )
        .await;

        // do something with the trials
        let trial_name = client.trial_name_from_study(&study_name, "1");

        let request = client.mk_check_trial_early_stopping_state_request(trial_name);

        let resp = client
            .service
            .check_trial_early_stopping_state(request)
            .await
            .unwrap();

        dbg!(&resp);

        let result = resp.into_inner();

        dbg!(result);
    }

    #[tokio::test]
    async fn it_can_stop_a_trial() {
        let mut client = test_client().await;

        let study_name = client.study_name("blah2");
        let trial_name = client.trial_name_from_study(&study_name, "3");

        let request = client.mk_stop_trial_request(trial_name);

        // The outcome is only printed: the trial may not exist on the server.
        match client.service.stop_trial(request).await {
            Ok(trial) => {
                dbg!(trial.get_ref());
            }
            Err(err) => {
                dbg!(err);
            }
        }
    }

    #[tokio::test]
    async fn it_lists_optimal_trials() {
        let mut client = test_client().await;

        let study_name = dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED",
            "it_lists_optimal_trials",
        )
        .await;

        let request = client.mk_list_optimal_trials_request(study_name);

        let trials = client.service.list_optimal_trials(request).await.unwrap();
        for t in &trials.get_ref().optimal_trials {
            dbg!(&t.name);
        }
    }
}
717
/// Integration tests for the study-related calls.
///
/// They talk to the Vizier server that `common::test_client` connects to.
#[cfg(test)]
mod studies {
    use tonic::Code;

    use super::common::{create_dummy_study, test_client};

    #[tokio::test]
    async fn it_lists_studies() {
        let mut client = test_client().await;

        // create two studies so that there is something to list
        for display_name in ["it_lists_studies_1", "it_lists_studies_2"] {
            create_dummy_study(
                &mut client,
                "ALGORITHM_UNSPECIFIED".to_string(),
                display_name.to_string(),
            )
            .await;
        }

        // list the first page of studies
        let request = client
            .mk_list_studies_request_builder()
            .with_page_size(2)
            .build();

        let first_page = client
            .service
            .list_studies(request)
            .await
            .unwrap()
            .into_inner();
        for t in &first_page.studies {
            dbg!(&t.name);
            dbg!(&t.display_name);
        }

        // follow the page tokens until the listing is exhausted
        let mut page_token = first_page.next_page_token;
        while !page_token.is_empty() {
            println!("There is more! - {:?}", &page_token);

            let request = client
                .mk_list_studies_request_builder()
                .with_page_token(page_token)
                .with_page_size(2)
                .build();

            let page = client
                .service
                .list_studies(request)
                .await
                .unwrap()
                .into_inner();
            for t in &page.studies {
                dbg!(&t.display_name);
            }

            page_token = page.next_page_token;
        }
    }

    #[tokio::test]
    async fn it_creates_studies() {
        let mut client = test_client().await;

        // The shared helper builds the spec this test needs (one metric `m1`,
        // parameters `a` and `b`) and prints the outcome of the creation.
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            "blah2".to_string(),
        )
        .await;
    }

    #[tokio::test]
    async fn it_can_get_a_study() {
        let mut client = test_client().await;

        let display_name = "it_can_get_a_study";

        // create a study
        create_dummy_study(
            &mut client,
            "ALGORITHM_UNSPECIFIED".to_string(),
            display_name.to_string(),
        )
        .await;

        let request = client.mk_get_study_request(client.study_name(display_name));

        let study = client.service.get_study(request).await.unwrap();
        dbg!(study.get_ref());
    }

    #[tokio::test]
    async fn it_deletes_a_study() {
        let mut client = test_client().await;

        let study_name = client.study_name("blah_to_delete");

        let request = client.mk_delete_study_request(study_name);

        match client.service.delete_study(request).await {
            Ok(deleted) => {
                dbg!(deleted.get_ref());
            }
            Err(err) => assert_eq!(err.code(), Code::Unknown),
        }
    }
}
885
/// Helpers shared by the integration tests.
#[cfg(test)]
mod common {

    use tonic::transport::Channel;

    use crate::study::spec::StudySpecBuilder;
    use crate::vizier::study_spec::metric_spec::GoalType;
    use crate::vizier::study_spec::parameter_spec::{
        DoubleValueSpec, IntegerValueSpec, ParameterValueSpec, ScaleType,
    };
    use crate::vizier::study_spec::{MetricSpec, ObservationNoise, ParameterSpec};
    use crate::vizier::vizier_service_client::VizierServiceClient;
    use crate::VizierClient;

    /// Connects to the Vizier server named by the `ENDPOINT` environment
    /// variable (falling back to `http://localhost:28080`) and wraps the
    /// connection in a client owned by `owner`.
    ///
    /// Panics when the connection cannot be established.
    pub(crate) async fn test_client() -> VizierClient<Channel> {
        let endpoint = match std::env::var("ENDPOINT") {
            Ok(url) => url,
            Err(_) => "http://localhost:28080".to_string(),
        };

        let service = VizierServiceClient::connect(endpoint).await.unwrap();

        VizierClient::new("owner".to_string(), service)
    }

    /// Asks the server to create a study displayed as `study_name`, using
    /// `algorithm`, a maximized metric `m1`, a double parameter `a` and an
    /// integer parameter `b`.
    ///
    /// The outcome of the creation, success or error, is only printed.
    pub(crate) async fn create_dummy_study(
        client: &mut VizierClient<Channel>,
        algorithm: String,
        study_name: String,
    ) {
        let metric = MetricSpec {
            metric_id: "m1".to_string(),
            goal: GoalType::Maximize as i32,
            safety_config: None,
        };

        let double_param = ParameterSpec {
            parameter_id: "a".to_string(),
            scale_type: ScaleType::Unspecified as i32,
            conditional_parameter_specs: vec![],
            parameter_value_spec: Some(ParameterValueSpec::DoubleValueSpec(DoubleValueSpec {
                min_value: 0.0,
                max_value: 12.0,
                default_value: Some(4.0),
            })),
        };

        let integer_param = ParameterSpec {
            parameter_id: "b".to_string(),
            scale_type: ScaleType::Unspecified as i32,
            conditional_parameter_specs: vec![],
            parameter_value_spec: Some(ParameterValueSpec::IntegerValueSpec(IntegerValueSpec {
                min_value: 4,
                max_value: 10,
                default_value: Some(7),
            })),
        };

        let study_spec = StudySpecBuilder::new(algorithm, ObservationNoise::Low)
            .with_metric_specs(vec![metric])
            .with_parameters(vec![double_param, integer_param])
            .build();

        let request = client
            .mk_study_request_builder()
            .with_display_name(study_name)
            .with_study_spec(study_spec)
            .build()
            .unwrap();

        // Best effort: a failed creation is reported but does not abort the test.
        match client.service.create_study(request).await {
            Ok(created) => {
                dbg!(created.get_ref());
            }
            Err(e) => {
                dbg!(e);
            }
        }
    }
}
967}