google_cloud_bigquery/http/bigquery_model_client.rs

1use std::sync::Arc;
2
3use crate::http::bigquery_client::BigqueryClient;
4use crate::http::error::Error;
5use crate::http::model;
6use crate::http::model::list::{ListModelsRequest, ListModelsResponse, ModelOverview};
7use crate::http::model::Model;
8
9#[derive(Debug, Clone)]
10pub struct BigqueryModelClient {
11    inner: Arc<BigqueryClient>,
12}
13
14impl BigqueryModelClient {
15    pub fn new(inner: Arc<BigqueryClient>) -> Self {
16        Self { inner }
17    }
18
19    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/delete
20    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
21    pub async fn delete(&self, project_id: &str, dataset_id: &str, table_id: &str) -> Result<(), Error> {
22        let builder = model::delete::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, table_id);
23        self.inner.send_get_empty(builder).await
24    }
25
26    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/patch
27    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
28    pub async fn patch(&self, metadata: &Model) -> Result<Model, Error> {
29        let builder = model::patch::build(self.inner.endpoint(), self.inner.http(), metadata);
30        self.inner.send(builder).await
31    }
32
33    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/get
34    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
35    pub async fn get(&self, project_id: &str, dataset_id: &str, model_id: &str) -> Result<Model, Error> {
36        let builder = model::get::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, model_id);
37        self.inner.send(builder).await
38    }
39
40    /// https://cloud.google.com/bigquery/docs/reference/rest/v2/models/list
41    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
42    pub async fn list(
43        &self,
44        project_id: &str,
45        dataset_id: &str,
46        req: &ListModelsRequest,
47    ) -> Result<Vec<ModelOverview>, Error> {
48        let mut page_token: Option<String> = None;
49        let mut models = vec![];
50        loop {
51            let builder = model::list::build(
52                self.inner.endpoint(),
53                self.inner.http(),
54                project_id,
55                dataset_id,
56                req,
57                page_token,
58            );
59            let response: ListModelsResponse = self.inner.send(builder).await?;
60            models.extend(response.models);
61            if response.next_page_token.is_none() {
62                break;
63            }
64            page_token = response.next_page_token;
65        }
66        Ok(models)
67    }
68}
69
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use serial_test::serial;
    use time::OffsetDateTime;

    use crate::http::bigquery_client::test::{create_client, dataset_name};
    use crate::http::bigquery_job_client::BigqueryJobClient;
    use crate::http::bigquery_model_client::BigqueryModelClient;
    use crate::http::job::get::GetJobRequest;
    use crate::http::job::query::QueryRequest;
    use crate::http::job::{Job, JobConfiguration, JobConfigurationQuery, JobState, JobType, TrainingType};
    use crate::http::model::list::ListModelsRequest;
    use crate::http::model::ModelType;

    /// Delay between successive polls of the model-training job.
    const POLL_INTERVAL_SECS: u64 = 3;
    /// Maximum number of polls before the test gives up on training
    /// (`MAX_POLLS * POLL_INTERVAL_SECS`, roughly 5 minutes). Kept generous because
    /// BigQuery ML training latency is outside the test's control.
    const MAX_POLLS: u32 = 100;

    /// End-to-end test against a live project. It trains a linear-regression model with a
    /// query job and runs `ML.PREDICT` on it. It then exercises list / get / patch / delete
    /// on every model found in the test dataset.
    #[tokio::test]
    #[serial]
    pub async fn crud_model() {
        let dataset = dataset_name("model");
        let (client, project) = create_client().await;
        let job_client = BigqueryJobClient::new(Arc::new(client.clone()));
        let client = BigqueryModelClient::new(Arc::new(client));

        // create model (one timestamp keeps the job id and model id paired)
        let suffix = OffsetDateTime::now_utc().unix_timestamp();
        let model_id = format!("penguins_model_{}", suffix);
        let mut job1 = Job::default();
        job1.job_reference.job_id = format!("test_model_job_{}", suffix);
        job1.job_reference.project_id = project.to_string();
        job1.job_reference.location = Some("US".to_string());
        job1.configuration = JobConfiguration {
            job: JobType::Query(JobConfigurationQuery {
                use_legacy_sql: Some(false),
                query: format!(
                    "
                    CREATE OR REPLACE MODEL `{}.{}`
                    OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS
                        SELECT
                            *
                        FROM
                            `bigquery-public-data.ml_datasets.penguins`
                        WHERE
                            body_mass_g IS NOT NULL
                        LIMIT 100
                    ",
                    dataset, model_id
                ),
                ..Default::default()
            }),
            ..Default::default()
        };
        let mut job = job_client.create(&job1).await.unwrap();

        // Wait for training to complete. The poll counter really advances, so a job that never
        // reaches `Done` fails the test instead of looping forever.
        let mut polls = 0;
        while job.status.state != JobState::Done {
            assert!(polls < MAX_POLLS, "model creation timedout");
            polls += 1;
            tokio::time::sleep(tokio::time::Duration::from_secs(POLL_INTERVAL_SECS)).await;
            // Look the job up in the location it was created in.
            let jr = &job.job_reference;
            let get_request = GetJobRequest {
                location: jr.location.clone(),
            };
            job = job_client
                .get(&jr.project_id, &jr.job_id, &get_request)
                .await
                .unwrap();
            tracing::info!("current job status.state = {:?}", job.status.state);
        }
        let ml = job
            .statistics
            .expect("completed job has statistics")
            .query
            .expect("query job has query statistics")
            .ml_statistics
            .expect("CREATE MODEL job has ML statistics");
        assert_eq!(ml.training_type, TrainingType::SingleTraining);
        assert_eq!(ml.model_type, ModelType::LinearRegression);

        // predict
        let result = job_client
            .query(
                &project,
                &QueryRequest {
                    max_results: None,
                    query: format!(
                        "
                    SELECT * FROM  ML.PREDICT(MODEL `{}.{}`, (
                        SELECT
                            *
                        FROM
                            `bigquery-public-data.ml_datasets.penguins`
                        WHERE
                            body_mass_g IS NOT NULL
                        AND island = 'Biscoe' LIMIT 10))
                    ",
                        dataset, model_id
                    ),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(result.total_rows.unwrap(), 10);

        // list / get / patch / delete
        let models = client
            .list(&project, &dataset, &ListModelsRequest::default())
            .await
            .unwrap();
        assert!(!models.is_empty());

        for overview in models {
            let reference = overview.model_reference;
            let model = client
                .get(&reference.project_id, &reference.dataset_id, &reference.model_id)
                .await
                .unwrap();
            assert_eq!(model.model_type, Some(ModelType::LinearRegression));
            let patched = client.patch(&model).await.unwrap();
            let reference = &patched.model_reference;
            client
                .delete(&reference.project_id, &reference.dataset_id, &reference.model_id)
                .await
                .unwrap();
        }
    }
}