1use std::collections::HashMap;
2
3use time::OffsetDateTime;
4
5use crate::http::table::TableReference;
6use crate::http::types::{EncryptionConfiguration, StandardSqlField};
7
8pub mod delete;
9pub mod get;
10pub mod list;
11pub mod patch;
12
13#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
14#[serde(rename_all = "camelCase")]
15pub struct Model {
16 pub etag: String,
18 pub model_reference: ModelReference,
20 #[serde(deserialize_with = "crate::http::from_str")]
22 pub creation_time: i64,
23 #[serde(deserialize_with = "crate::http::from_str")]
25 pub last_modified_time: u64,
26 pub description: Option<String>,
28 pub friendly_name: Option<String>,
30 pub labels: Option<HashMap<String, String>>,
37 #[serde(default, deserialize_with = "crate::http::from_str_option")]
40 pub expiration_time: Option<i64>,
41 pub location: Option<String>,
43 pub encryption_configuration: Option<EncryptionConfiguration>,
47 pub model_type: Option<ModelType>,
49 pub training_runs: Option<Vec<TrainingRun>>,
51 pub feature_columns: Option<Vec<StandardSqlField>>,
53 pub label_columns: Option<Vec<StandardSqlField>>,
56 pub hparam_search_spaces: Option<HparamSearchSpaces>,
58 #[serde(default, deserialize_with = "crate::http::from_str_option")]
60 pub default_trial_id: Option<i64>,
61 pub hparam_trials: Option<Vec<HparamTuningTrial>>,
63 #[serde(default, deserialize_with = "crate::http::from_str_vec_option")]
66 pub optimal_trial_ids: Option<Vec<i64>>,
67 pub remote_model_info: Option<RemoteModelInfo>,
69}
70
71#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
72#[serde(rename_all = "camelCase")]
73pub struct TrainingRun {
74 pub training_options: Option<TrainingOptions>,
76 #[serde(default, with = "time::serde::rfc3339::option")]
78 pub start_time: Option<OffsetDateTime>,
79 pub results: Option<Vec<IterationResult>>,
81 pub evaluation_metrics: Option<EvaluationMetrics>,
83 pub data_split_result: Option<DataSplitResult>,
85 pub model_level_global_explanation: Option<GlobalExplanation>,
87 pub class_level_global_explanations: Option<Vec<GlobalExplanation>>,
89 pub vertex_ai_model_id: Option<String>,
91 pub vertex_ai_model_version: Option<String>,
93}
94
95#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
96#[serde(rename_all = "camelCase")]
97pub struct DataSplitResult {
98 pub training_table: Option<TableReference>,
100 pub evaluation_table: Option<TableReference>,
102 pub test_table: Option<TableReference>,
104}
105
106#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
107#[serde(rename_all = "camelCase")]
108pub struct GlobalExplanation {
109 pub explanations: Option<Vec<Explanation>>,
111 pub class_label: Option<String>,
114}
115
116#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
117#[serde(rename_all = "camelCase")]
118pub struct Explanation {
119 pub feature_name: String,
123 pub attribution: f64,
125}
126
127#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
128#[serde(rename_all = "camelCase")]
129pub struct RemoteModelInfo {
130 pub connection: String,
133 #[serde(default, deserialize_with = "crate::http::from_str_option")]
135 pub max_batching_rows: Option<i64>,
136 pub endpoint: String,
138 pub remote_service_type: RemoteServiceType,
140}
141
142#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
143#[serde(rename_all = "camelCase")]
144pub struct HparamSearchSpaces {
145 pub learn_rate: Option<DoubleHparamSearchSpace>,
147 pub l1_reg: Option<DoubleHparamSearchSpace>,
149 pub l2_reg: Option<DoubleHparamSearchSpace>,
151 pub num_clusters: Option<IntHparamSearchSpace>,
153 pub num_factors: Option<IntHparamSearchSpace>,
155 pub hidden_units: Option<IntArrayHparamSearchSpace>,
157 pub batch_size: Option<IntHparamSearchSpace>,
159 pub dropout: Option<DoubleHparamSearchSpace>,
161 pub max_tree_depth: Option<IntHparamSearchSpace>,
163 pub subsample: Option<DoubleHparamSearchSpace>,
165 pub min_split_loss: Option<DoubleHparamSearchSpace>,
167 pub wals_alpha: Option<DoubleHparamSearchSpace>,
169 pub booster_type: Option<StringHparamSearchSpace>,
171 pub num_parallel_tree: Option<IntHparamSearchSpace>,
173 pub dart_normalize_type: Option<StringHparamSearchSpace>,
175 pub tree_method: Option<StringHparamSearchSpace>,
177 pub min_tree_child_weight: Option<IntHparamSearchSpace>,
179 pub colsample_bytree: Option<DoubleHparamSearchSpace>,
181 pub colsample_bylevel: Option<DoubleHparamSearchSpace>,
183 pub colsample_bynode: Option<DoubleHparamSearchSpace>,
185 pub activation_fn: Option<StringHparamSearchSpace>,
187 pub optimizer: Option<StringHparamSearchSpace>,
189}
190
191#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
192#[serde(rename_all = "camelCase")]
193pub enum DoubleHparamSearchSpace {
194 Range(DoubleRange),
195 Candidates(DoubleCandidates),
196}
197
198#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
199#[serde(rename_all = "camelCase")]
200pub struct DoubleRange {
201 pub min: f64,
202 pub max: f64,
203}
204
205#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
206#[serde(rename_all = "camelCase")]
207pub struct DoubleCandidates {
208 pub candidates: Vec<f64>,
209}
210
211#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
212#[serde(rename_all = "camelCase")]
213pub enum IntHparamSearchSpace {
214 Range(IntRange),
215 Candidates(IntCandidates),
216}
217
218#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
219#[serde(rename_all = "camelCase")]
220pub struct IntRange {
221 #[serde(deserialize_with = "crate::http::from_str")]
222 pub min: i64,
223 #[serde(deserialize_with = "crate::http::from_str")]
224 pub max: i64,
225}
226
227#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
228#[serde(rename_all = "camelCase")]
229pub struct IntCandidates {
230 #[serde(deserialize_with = "crate::http::from_str_vec")]
231 pub candidates: Vec<i64>,
232}
233
234#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
235#[serde(rename_all = "camelCase")]
236pub struct IntArrayHparamSearchSpace {
237 pub candidates: Vec<IntArray>,
238}
239
240#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
241#[serde(rename_all = "camelCase")]
242pub struct IntArray {
243 #[serde(deserialize_with = "crate::http::from_str")]
244 pub elements: i64,
245}
246
247#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
248#[serde(rename_all = "camelCase")]
249pub struct StringHparamSearchSpace {
250 pub candidates: Vec<String>,
251}
252
253#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
254#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
255pub enum RemoteServiceType {
256 #[default]
257 RemoteServiceTypeUnspecified,
258 CloudAiTranslateV3,
259 CloudAiVisionV1,
260 CloudAiNaturalLanguageV1,
261}
262
263#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
264#[serde(rename_all = "camelCase")]
265pub struct ModelReference {
266 pub project_id: String,
268 pub dataset_id: String,
270 pub model_id: String,
273}
274
275#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
276#[serde(rename_all = "camelCase")]
277pub struct IterationResult {
278 pub index: i32,
280 #[serde(default, deserialize_with = "crate::http::from_str_option")]
282 pub duration_ms: Option<i64>,
283 pub training_loss: Option<f64>,
285 pub eval_loss: Option<f64>,
287 pub learn_rate: Option<f64>,
289 pub cluster_infos: Option<Vec<ClusterInfo>>,
291 pub arima_result: Option<ArimaResult>,
292 pub principal_component_infos: Option<Vec<PrincipalComponentInfo>>,
294}
295
296#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
297#[serde(rename_all = "camelCase")]
298pub struct ClusterInfo {
299 #[serde(default, deserialize_with = "crate::http::from_str_option")]
301 pub centroid_id: Option<i64>,
302 pub cluster_radius: Option<f64>,
304 #[serde(default, deserialize_with = "crate::http::from_str_option")]
306 pub cluster_size: Option<i64>,
307}
308
309#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
310#[serde(rename_all = "camelCase")]
311pub struct ArimaResult {
312 pub arima_model_info: Option<Vec<ArimaModelInfo>>,
314 pub seasonal_periods: Option<Vec<SeasonalPeriodType>>,
316}
317
318#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
319#[serde(rename_all = "camelCase")]
320pub struct ArimaModelInfo {
321 pub non_seasonal_order: Option<ArimaOrder>,
323 pub arima_coefficients: Option<ArimaCoefficients>,
325 pub arima_fitting_metrics: Option<ArimaFittingMetrics>,
327 pub has_drift: Option<bool>,
329 pub time_series_id: Option<String>,
333 pub time_series_ids: Option<Vec<String>>,
339 pub seasonal_periods: Option<Vec<SeasonalPeriodType>>,
341 pub has_holiday_effect: Option<bool>,
343 pub has_spikes_and_dips: Option<bool>,
345 pub has_step_changes: Option<bool>,
347}
348
349#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
350#[serde(rename_all = "camelCase")]
351pub struct ArimaCoefficients {
352 pub auto_regressive_coefficients: Option<Vec<f64>>,
354 pub moving_average_coefficients: Option<Vec<f64>>,
356 pub intercept_coefficient: Option<f64>,
358}
359
360#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
361#[serde(rename_all = "camelCase")]
362pub struct ArimaFittingMetrics {
363 pub log_likelihood: Option<f64>,
365 pub aic: Option<f64>,
367 pub variance: Option<f64>,
369}
370
371#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
372#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
373pub enum SeasonalPeriodType {
374 #[default]
375 SeasonalPeriodTypeUnspecified,
376 NoSeasonality,
377 Daily,
378 Weekly,
379 Monthly,
380 Quarterly,
381 Yearly,
382}
383
384#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
385#[serde(rename_all = "camelCase")]
386pub struct PrincipalComponentInfo {
387 #[serde(deserialize_with = "crate::http::from_str")]
389 pub principal_component_id: i64,
390 pub explained_variance: Option<f64>,
392 pub explained_variance_ratio: Option<f64>,
394 pub cumulative_explained_variance_ratio: Option<f64>,
397}
398
399#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
400#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
401pub enum ModelType {
402 #[default]
403 ModelTypeUnspecified,
404 LinearRegression,
406 LogisticRegression,
408 Kmeans,
410 MatrixFactorization,
412 DnnClassifier,
414 Tensorflow,
416 DnnRegression,
418 BoostedTreeRegressor,
420 BoostedTreeClassifier,
422 Arima,
424 AutomlRegressor,
426 AutomlClassifier,
428 Pca,
430 Autoencoder,
432 ArimaPlus,
434}
435
436#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
437#[serde(rename_all = "camelCase")]
438pub struct HparamTuningTrial {
439 #[serde(default, deserialize_with = "crate::http::from_str_option")]
441 pub trial_id: Option<i64>,
442 #[serde(default, deserialize_with = "crate::http::from_str_option")]
444 pub start_time_ms: Option<i64>,
445 #[serde(default, deserialize_with = "crate::http::from_str_option")]
447 pub end_time_ms: Option<i64>,
448 pub hparams: Option<TrainingOptions>,
450 pub evaluation_metrics: Option<EvaluationMetrics>,
452 pub status: Option<TrialStatus>,
454 pub error_message: Option<String>,
456 pub training_loss: Option<f64>,
458 pub eval_loss: Option<f64>,
460 pub hparam_tuning_evaluation_metrics: Option<EvaluationMetrics>,
463}
464
465#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
466#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
467pub enum TrialStatus {
468 #[default]
469 TrialStatusUnspecified,
470 NotStarted,
471 Running,
472 Succeeded,
473 Failed,
474 Infeasible,
475 StoppedEarly,
476}
477
478#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
479#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
480pub enum LossType {
481 #[default]
482 LossTypeUnspecified,
483 MeanSquaredLoss,
484 MeanLogLoss,
485}
486
487#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
488#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
489pub enum DataSplitMethod {
490 #[default]
491 DataSplitMethodUnspecified,
492 Random,
493 Custom,
494 Sequential,
495 NoSplit,
496 AutoSplit,
497}
498
499#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
500#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
501pub enum LearnRateStrategy {
502 #[default]
503 LearnRateStrategyUnspecified,
504 LineSearch,
505 Constant,
506}
507
508#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
509#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
510pub enum DistanceType {
511 #[default]
512 DistanceTypeUnspecified,
513 Euclidean,
514 Cosine,
515}
516
517#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
518#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
519pub enum OptimizationStrategy {
520 #[default]
521 OptimizationStrategyUnspecified,
522 BatchGradientDescent,
523 NormalEquation,
524}
525
526#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
527#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
528pub enum BoosterType {
529 #[default]
530 BoosterTypeUnspecified,
531 Gbtree,
532 Dart,
533}
534
535#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
536#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
537pub enum DartNormalizeType {
538 #[default]
539 DataNormalizeTypeUnspecified,
540 Tree,
541 Forest,
542}
543
544#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
545#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
546pub enum TestMethod {
547 #[default]
548 TreeMethodUnspecified,
549 Auto,
550 Exact,
551 Approx,
552 Hist,
553}
554
555#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
556#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
557pub enum FeedbackType {
558 #[default]
559 FeedbackTypeUnspecified,
560 Implicit,
561 Explicit,
562}
563
564#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
565#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
566pub enum KmeansInitializationMethod {
567 #[default]
568 KmeansInitializationMethodUnspecified,
569 Random,
570 Custom,
571 KmeansPlusPlus,
572}
573
574#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
575#[serde(rename_all = "camelCase")]
576pub struct ArimaOrder {
577 #[serde(default, deserialize_with = "crate::http::from_str_option")]
579 pub p: Option<i64>,
580 #[serde(default, deserialize_with = "crate::http::from_str_option")]
582 pub d: Option<i64>,
583 #[serde(default, deserialize_with = "crate::http::from_str_option")]
585 pub q: Option<i64>,
586}
587
588#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
589#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
590pub enum DataFrequency {
591 #[default]
592 DataFrequencyUnspecified,
593 AutoFrequency,
594 Yearly,
595 Quarterly,
596 Monthly,
597 Weekly,
598 Daily,
599 Hourly,
600 PerMinute,
601}
602
603#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
604#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
605pub enum HolidayRegion {
606 #[default]
607 HolidayRegionUnspecified,
608 Global,
609 Na,
610 Japac,
611 Emea,
612 Ae,
613 Ar,
614 At,
615 Au,
616 Be,
617 Br,
618 Ca,
619 Ch,
620 Cl,
621 Cn,
622 Co,
623 Cs,
624 Cz,
625 De,
626 Dk,
627 Dz,
628 Ec,
629 Ee,
630 Eg,
631 Es,
632 Fi,
633 Fr,
634 Gb,
635 Gr,
636 Hk,
637 Hu,
638 Id,
639 Ie,
640 Il,
641 In,
642 Ir,
643 It,
644 Jp,
645 Kr,
646 Lv,
647 Ma,
648 Mx,
649 My,
650 Mg,
651 Nl,
652 No,
653 Nz,
654 Pe,
655 Ph,
656 Pk,
657 Pl,
658 Pt,
659 Ro,
660 Rs,
661 Ru,
662 Sa,
663 Se,
664 Sg,
665 Si,
666 Sk,
667 Th,
668 Tr,
669 Tw,
670 Ua,
671 Us,
672 Ve,
673 Vn,
674 Za,
675}
676
677#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
678#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
679pub enum HparamTuningObjective {
680 #[default]
681 HparamTuningObjectiveUnspecified,
682 MeanAbsoluteError,
683 MeanSquaredError,
684 MeanSquaredLogError,
685 MedianAbsoluteError,
686 RSquared,
687 ExplainedVariance,
688 Precision,
689 Recall,
690 Accuracy,
691 F1Score,
692 LogLoss,
693 RocAuc,
694 DaviesBouldinIndex,
695 MeanAveragePrecision,
696 NormalizedDiscountedCumulativeGain,
697 AverageRank,
698}
699
700#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
701#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
702pub enum TreeMethod {
703 #[default]
704 TreeMethodUnspecified,
705 Auto,
706 Exact,
707 Approx,
708 Hist,
709}
710
711#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
712#[serde(rename_all = "camelCase")]
713pub struct TrainingOptions {
714 #[serde(default, deserialize_with = "crate::http::from_str_option")]
716 pub max_iterations: Option<i64>,
717 pub loss_type: Option<LossType>,
719 pub learn_rate: Option<f64>,
721 pub l1_regularization: Option<f64>,
723 pub l2_regularization: Option<f64>,
725 pub min_relative_progress: Option<f64>,
727 pub warm_start: Option<bool>,
729 pub early_stop: Option<bool>,
731 pub input_label_columns: Option<Vec<String>>,
733 pub data_split_method: Option<DataSplitMethod>,
735 pub data_split_eval_fraction: Option<f64>,
738 pub data_split_column: Option<String>,
745 pub learn_rate_strategy: Option<LearnRateStrategy>,
747 pub initial_learn_rate: Option<f64>,
749 pub label_class_weights: Option<HashMap<String, f64>>,
754 pub user_column: Option<String>,
756 pub item_column: Option<String>,
758 pub distance_type: Option<DistanceType>,
760 #[serde(default, deserialize_with = "crate::http::from_str_option")]
762 pub num_clusters: Option<i64>,
763 pub model_uri: Option<String>,
766 pub optimization_strategy: Option<OptimizationStrategy>,
768 #[serde(default, deserialize_with = "crate::http::from_str_vec_option")]
770 pub hidden_units: Option<Vec<i64>>,
771 #[serde(default, deserialize_with = "crate::http::from_str_option")]
773 pub batch_size: Option<i64>,
774 pub dropout: Option<f64>,
776 pub max_tree_depth: Option<i64>,
778 pub subsample: Option<f64>,
780 pub min_split_loss: Option<f64>,
782 pub booster_type: Option<BoosterType>,
784 #[serde(default, deserialize_with = "crate::http::from_str_option")]
786 pub num_parallel_tree: Option<i64>,
787 pub dart_normalize_type: Option<DartNormalizeType>,
789 pub tree_method: Option<TreeMethod>,
791 #[serde(default, deserialize_with = "crate::http::from_str_option")]
793 pub min_tree_child_weight: Option<i64>,
794 pub colsample_bytree: Option<f64>,
796 pub colsample_bylevel: Option<f64>,
798 pub colsample_bynode: Option<f64>,
800 #[serde(default, deserialize_with = "crate::http::from_str_option")]
802 pub num_factors: Option<i64>,
803 pub feedback_type: Option<FeedbackType>,
805 pub wals_alpha: Option<f64>,
807 pub kmeans_initialization_method: Option<KmeansInitializationMethod>,
809 pub kmeans_initialization_column: Option<String>,
811 pub time_series_timestamp_column: Option<String>,
813 pub time_series_data_column: Option<String>,
815 pub auto_arima: Option<bool>,
817 pub non_seasonal_order: Option<ArimaOrder>,
819 pub data_frequency: Option<DataFrequency>,
821 pub calculate_p_values: Option<bool>,
823 pub include_drift: Option<bool>,
825 pub holiday_region: Option<HolidayRegion>,
827 pub time_series_id_column: Option<String>,
829 pub time_series_id_columns: Option<Vec<String>>,
831 #[serde(default, deserialize_with = "crate::http::from_str_option")]
833 pub horizon: Option<i64>,
834 #[serde(skip_serializing_if = "Option::is_none")]
837 pub preserve_input_structs: Option<bool>,
838 #[serde(default, deserialize_with = "crate::http::from_str_option")]
840 pub auto_arima_max_order: Option<i64>,
841 #[serde(default, deserialize_with = "crate::http::from_str_option")]
843 pub auto_arima_min_order: Option<i64>,
844 #[serde(default, deserialize_with = "crate::http::from_str_option")]
846 pub num_trials: Option<i64>,
847 #[serde(default, deserialize_with = "crate::http::from_str_option")]
849 pub max_parallel_trials: Option<i64>,
850 pub hparam_tuning_objectives: Option<Vec<HparamTuningObjective>>,
852 pub decompose_time_series: Option<bool>,
854 pub clean_spikes_and_dips: Option<bool>,
856 pub adjust_step_changes: Option<bool>,
858 pub enable_global_explain: Option<bool>,
860 #[serde(default, deserialize_with = "crate::http::from_str_option")]
862 pub sampled_shapley_num_paths: Option<i64>,
863 #[serde(default, deserialize_with = "crate::http::from_str_option")]
865 pub integrated_gradients_num_steps: Option<i64>,
866}
867
868#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
869#[serde(rename_all = "camelCase")]
870pub enum EvaluationMetrics {
871 RegressionMetrics(RegressionMetrics),
873 BinaryClassificationMetrics(BinaryClassificationMetrics),
875 MultiClassClassificationMetrics(MultiClassClassificationMetrics),
877 ClusteringMetrics(ClusteringMetrics),
879 RankingMetrics(RankingMetrics),
881 ArimaForecastingMetrics(ArimaForecastingMetrics),
883 DimensionalityReductionMetrics(DimensionalityReductionMetrics),
885}
886
887impl Default for EvaluationMetrics {
888 fn default() -> Self {
889 Self::RegressionMetrics(RegressionMetrics::default())
890 }
891}
892
893#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
894#[serde(rename_all = "camelCase")]
895pub struct RegressionMetrics {
896 pub mean_absolute_error: Option<f64>,
898 pub mean_squared_error: Option<f64>,
900 pub mean_squared_log_error: Option<f64>,
902 pub median_absolute_error: Option<f64>,
904 pub r_squared: Option<f64>,
906}
907
908#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
909#[serde(rename_all = "camelCase")]
910pub struct BinaryClassificationMetrics {
911 pub aggregate_classification_metrics: Option<AggregateClassificationMetrics>,
913 pub binary_confusion_matrix_list: Option<Vec<BinaryConfusionMatrix>>,
915 pub positive_label: Option<String>,
917 pub negative_label: Option<String>,
919}
920
921#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
922#[serde(rename_all = "camelCase")]
923pub struct AggregateClassificationMetrics {
924 pub precision: Option<f64>,
927 pub recall: Option<f64>,
930 pub accuracy: Option<f64>,
932 pub threshold: Option<f64>,
936 pub f1_score: Option<f64>,
938 pub log_loss: Option<f64>,
940 pub roc_auc: Option<f64>,
942}
943
944#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
945#[serde(rename_all = "camelCase")]
946pub struct BinaryConfusionMatrix {
947 pub positive_class_threshold: Option<f64>,
949 #[serde(default, deserialize_with = "crate::http::from_str_option")]
951 pub true_positives: Option<i64>,
952 #[serde(default, deserialize_with = "crate::http::from_str_option")]
954 pub false_positives: Option<i64>,
955 #[serde(default, deserialize_with = "crate::http::from_str_option")]
957 pub true_negatives: Option<i64>,
958 #[serde(default, deserialize_with = "crate::http::from_str_option")]
960 pub false_negatives: Option<i64>,
961 pub precision: Option<f64>,
963 pub recall: Option<f64>,
965 pub f1_score: Option<f64>,
967 pub accuracy: Option<f64>,
969}
970
971#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
972#[serde(rename_all = "camelCase")]
973pub struct MultiClassClassificationMetrics {
974 pub aggregate_classification_metrics: Option<AggregateClassificationMetrics>,
976 pub confusion_matrix_list: Option<Vec<ConfusionMatrix>>,
978}
979
980#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
981#[serde(rename_all = "camelCase")]
982pub struct ConfusionMatrix {
983 pub confidence_threshold: Option<f64>,
985 pub rows: Option<Vec<Row>>,
987}
988
989#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
990#[serde(rename_all = "camelCase")]
991pub struct Row {
992 pub actual_label: Option<String>,
994 pub entries: Option<Vec<Entry>>,
996}
997
998#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
999#[serde(rename_all = "camelCase")]
1000pub struct Entry {
1001 pub predicted_label: Option<String>,
1004 #[serde(default, deserialize_with = "crate::http::from_str_option")]
1006 pub item_count: Option<i64>,
1007}
1008
1009#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1010#[serde(rename_all = "camelCase")]
1011pub struct ClusteringMetrics {
1012 pub davies_bouldin_index: Option<f64>,
1014 pub mean_squared_distance: Option<f64>,
1016 pub clusters: Option<Vec<Cluster>>,
1018}
1019
1020#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1021#[serde(rename_all = "camelCase")]
1022pub struct Cluster {
1023 #[serde(default, deserialize_with = "crate::http::from_str_option")]
1025 pub centroid_id: Option<i64>,
1026 pub feature_values: Option<Vec<FeatureValue>>,
1028 #[serde(default, deserialize_with = "crate::http::from_str_option")]
1030 pub count: Option<i64>,
1031}
1032
1033#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1034#[serde(rename_all = "camelCase")]
1035pub struct FeatureValue {
1036 pub feature_column: Option<String>,
1038 #[serde(flatten)]
1039 pub value: FeatureValueType,
1040}
1041
1042#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug)]
1043#[serde(rename_all = "camelCase")]
1044pub enum FeatureValueType {
1045 NumericalValue(f64),
1047 CategoricalValue(CategoricalValue),
1049}
1050
1051impl Default for FeatureValueType {
1052 fn default() -> Self {
1053 Self::NumericalValue(0.0)
1054 }
1055}
1056
1057#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1058#[serde(rename_all = "camelCase")]
1059pub struct CategoricalValue {
1060 pub category_counts: Option<Vec<CategoryCount>>,
1064}
1065
1066#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1067#[serde(rename_all = "camelCase")]
1068pub struct CategoryCount {
1069 pub category: Option<String>,
1071 #[serde(default, deserialize_with = "crate::http::from_str_option")]
1073 pub count: Option<i64>,
1074}
1075
1076#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1077#[serde(rename_all = "camelCase")]
1078pub struct RankingMetrics {
1079 pub mean_average_precision: Option<f64>,
1081 pub mean_squared_error: Option<f64>,
1083 pub normalized_discounted_cumulative_gain: Option<f64>,
1085 pub average_rank: Option<f64>,
1087}
1088
1089#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1090#[serde(rename_all = "camelCase")]
1091pub struct ArimaForecastingMetrics {
1092 pub arima_single_model_forecasting_metrics: Option<Vec<ArimaSingleModelForecastingMetrics>>,
1094}
1095
1096#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1097#[serde(rename_all = "camelCase")]
1098pub struct ArimaSingleModelForecastingMetrics {
1099 pub non_seasonal_order: Option<ArimaOrder>,
1101 pub arima_fitting_metrics: Option<ArimaFittingMetrics>,
1103 pub has_drift: Option<bool>,
1105 pub time_series_id: Option<String>,
1107 pub time_series_ids: Option<Vec<String>>,
1110 pub seasonal_periods: Option<Vec<SeasonalPeriodType>>,
1112 pub has_holiday_effect: Option<bool>,
1114 pub has_spikes_and_dips: Option<bool>,
1116 pub has_step_changes: Option<bool>,
1118}
1119
1120#[derive(Clone, PartialEq, serde::Deserialize, serde::Serialize, Debug, Default)]
1121#[serde(rename_all = "camelCase")]
1122pub struct DimensionalityReductionMetrics {
1123 pub total_explained_variance_ratio: Option<f64>,
1125}