gcloud_bigquery/http/
bigquery_model_client.rs

1use std::sync::Arc;
2
3use crate::http::bigquery_client::BigqueryClient;
4use crate::http::error::Error;
5use crate::http::model;
6use crate::http::model::list::{ListModelsRequest, ListModelsResponse, ModelOverview};
7use crate::http::model::Model;
8
9#[derive(Debug, Clone)]
10pub struct BigqueryModelClient {
11    inner: Arc<BigqueryClient>,
12}
13
14impl BigqueryModelClient {
15    pub fn new(inner: Arc<BigqueryClient>) -> Self {
16        Self { inner }
17    }
18
19    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/delete
20    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
21    pub async fn delete(&self, project_id: &str, dataset_id: &str, table_id: &str) -> Result<(), Error> {
22        let builder = model::delete::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, table_id);
23        self.inner.send_get_empty(builder).await
24    }
25
26    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/patch
27    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
28    pub async fn patch(&self, metadata: &Model) -> Result<Model, Error> {
29        let builder = model::patch::build(self.inner.endpoint(), self.inner.http(), metadata);
30        self.inner.send(builder).await
31    }
32
33    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/get
34    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
35    pub async fn get(&self, project_id: &str, dataset_id: &str, model_id: &str) -> Result<Model, Error> {
36        let builder = model::get::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, model_id);
37        self.inner.send(builder).await
38    }
39
40    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/list
41    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
42    pub async fn list(
43        &self,
44        project_id: &str,
45        dataset_id: &str,
46        req: &ListModelsRequest,
47    ) -> Result<Vec<ModelOverview>, Error> {
48        let mut page_token: Option<String> = None;
49        let mut models = vec![];
50        loop {
51            let builder = model::list::build(
52                self.inner.endpoint(),
53                self.inner.http(),
54                project_id,
55                dataset_id,
56                req,
57                page_token,
58            );
59            let response: ListModelsResponse = self.inner.send(builder).await?;
60            models.extend(response.models);
61            if response.next_page_token.is_none() {
62                break;
63            }
64            page_token = response.next_page_token;
65        }
66        Ok(models)
67    }
68}
69
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use serial_test::serial;
    use time::OffsetDateTime;

    use crate::http::bigquery_client::test::{create_client, dataset_name};
    use crate::http::bigquery_job_client::BigqueryJobClient;
    use crate::http::bigquery_model_client::BigqueryModelClient;
    use crate::http::job::get::GetJobRequest;
    use crate::http::job::query::QueryRequest;
    use crate::http::job::{Job, JobConfiguration, JobConfigurationQuery, JobState, JobType, TrainingType};
    use crate::http::model::list::ListModelsRequest;
    use crate::http::model::ModelType;

    /// End-to-end check of the model client against a live BigQuery project:
    /// trains a small linear-regression model through a query job, runs a
    /// prediction with it, then lists, gets, patches and deletes the models in
    /// the test dataset.
    #[tokio::test]
    #[serial]
    pub async fn crud_model() {
        /// Maximum number of status polls before the test gives up.
        const MAX_POLLS: u32 = 20;
        /// Pause between two consecutive status polls, in seconds.
        const POLL_INTERVAL_SECS: u64 = 3;

        let dataset = dataset_name("model");
        let (client, project) = create_client().await;
        let job_client = BigqueryJobClient::new(Arc::new(client.clone()));
        let client = BigqueryModelClient::new(Arc::new(client));

        // create model
        let model_id = format!("penguins_model_{}", OffsetDateTime::now_utc().unix_timestamp());
        let mut job1 = Job::default();
        job1.job_reference.job_id = format!("test_model_job_{}", OffsetDateTime::now_utc().unix_timestamp());
        job1.job_reference.project_id = project.to_string();
        job1.job_reference.location = Some("US".to_string());
        job1.configuration = JobConfiguration {
            job: JobType::Query(JobConfigurationQuery {
                use_legacy_sql: Some(false),
                query: format!(
                    "
                    CREATE OR REPLACE MODEL `{dataset}.{model_id}`
                    OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS
                        SELECT
                            *
                        FROM
                            `bigquery-public-data.ml_datasets.penguins`
                        WHERE
                            body_mass_g IS NOT NULL
                        LIMIT 100
                    "
                ),
                ..Default::default()
            }),
            ..Default::default()
        };
        let mut job = job_client.create(&job1).await.unwrap();

        // wait for training complete
        // The poll counter is advanced on every iteration so the timeout
        // assertion can actually fire; a job that never reaches `Done` fails
        // the test instead of hanging it.
        let mut polls = 0;
        while job.status.state != JobState::Done {
            assert!(polls < MAX_POLLS, "model creation timedout");
            polls += 1;
            tokio::time::sleep(tokio::time::Duration::from_secs(POLL_INTERVAL_SECS)).await;
            let jr = &job.job_reference;
            job = job_client
                .get(&jr.project_id, &jr.job_id, &GetJobRequest { location: None })
                .await
                .unwrap();
            tracing::info!("current job status.state = {:?}", job.status.state);
        }
        let statistics = job.statistics.unwrap().query.unwrap().ml_statistics;
        let ml = statistics.unwrap();
        assert_eq!(ml.training_type, TrainingType::SingleTraining);
        assert_eq!(ml.model_type, ModelType::LinearRegression);

        // predict
        let result = job_client
            .query(
                &project,
                &QueryRequest {
                    max_results: None,
                    query: format!(
                        "
                    SELECT * FROM  ML.PREDICT(MODEL `{dataset}.{model_id}`, (
                        SELECT
                            *
                        FROM
                            `bigquery-public-data.ml_datasets.penguins`
                        WHERE
                            body_mass_g IS NOT NULL
                        AND island = 'Biscoe' LIMIT 10))
                    "
                    ),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(result.total_rows.unwrap(), 10);

        // list / get / patch / delete
        let models = client
            .list(&project, &dataset, &ListModelsRequest::default())
            .await
            .unwrap();
        assert!(!models.is_empty());

        // Every model found in the test dataset is fetched, patched with its
        // own metadata, and then deleted.
        for model in models {
            let model = model.model_reference;
            let model = client
                .get(model.project_id.as_str(), model.dataset_id.as_str(), model.model_id.as_str())
                .await
                .unwrap();
            assert_eq!(model.model_type.clone().unwrap(), ModelType::LinearRegression);
            let model = &client.patch(&model).await.unwrap().model_reference;
            client
                .delete(model.project_id.as_str(), model.dataset_id.as_str(), model.model_id.as_str())
                .await
                .unwrap();
        }
    }
}