gcloud_bigquery/http/
bigquery_model_client.rs1use std::sync::Arc;
2
3use crate::http::bigquery_client::BigqueryClient;
4use crate::http::error::Error;
5use crate::http::model;
6use crate::http::model::list::{ListModelsRequest, ListModelsResponse, ModelOverview};
7use crate::http::model::Model;
8
9#[derive(Debug, Clone)]
10pub struct BigqueryModelClient {
11 inner: Arc<BigqueryClient>,
12}
13
14impl BigqueryModelClient {
15 pub fn new(inner: Arc<BigqueryClient>) -> Self {
16 Self { inner }
17 }
18
19 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
21 pub async fn delete(&self, project_id: &str, dataset_id: &str, table_id: &str) -> Result<(), Error> {
22 let builder = model::delete::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, table_id);
23 self.inner.send_get_empty(builder).await
24 }
25
26 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
28 pub async fn patch(&self, metadata: &Model) -> Result<Model, Error> {
29 let builder = model::patch::build(self.inner.endpoint(), self.inner.http(), metadata);
30 self.inner.send(builder).await
31 }
32
33 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
35 pub async fn get(&self, project_id: &str, dataset_id: &str, model_id: &str) -> Result<Model, Error> {
36 let builder = model::get::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, model_id);
37 self.inner.send(builder).await
38 }
39
40 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
42 pub async fn list(
43 &self,
44 project_id: &str,
45 dataset_id: &str,
46 req: &ListModelsRequest,
47 ) -> Result<Vec<ModelOverview>, Error> {
48 let mut page_token: Option<String> = None;
49 let mut models = vec![];
50 loop {
51 let builder = model::list::build(
52 self.inner.endpoint(),
53 self.inner.http(),
54 project_id,
55 dataset_id,
56 req,
57 page_token,
58 );
59 let response: ListModelsResponse = self.inner.send(builder).await?;
60 models.extend(response.models);
61 if response.next_page_token.is_none() {
62 break;
63 }
64 page_token = response.next_page_token;
65 }
66 Ok(models)
67 }
68}
69
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use serial_test::serial;
    use time::OffsetDateTime;

    use crate::http::bigquery_client::test::{create_client, dataset_name};
    use crate::http::bigquery_job_client::BigqueryJobClient;
    use crate::http::bigquery_model_client::BigqueryModelClient;
    use crate::http::job::get::GetJobRequest;
    use crate::http::job::query::QueryRequest;
    use crate::http::job::{Job, JobConfiguration, JobConfigurationQuery, JobState, JobType, TrainingType};
    use crate::http::model::list::ListModelsRequest;
    use crate::http::model::ModelType;

    /// End-to-end model round trip against the project returned by `create_client`.
    ///
    /// The test:
    /// 1. Trains a `linear_reg` model with a `CREATE OR REPLACE MODEL` query job.
    /// 2. Runs `ML.PREDICT` against the model.
    /// 3. Lists the dataset's models and gets, patches and deletes each one.
    #[tokio::test]
    #[serial]
    pub async fn crud_model() {
        // Upper bound on job-status polls (3 s apart) before the test gives up.
        const MAX_JOB_POLLS: u32 = 20;

        let dataset = dataset_name("model");
        let (client, project) = create_client().await;
        let job_client = BigqueryJobClient::new(Arc::new(client.clone()));
        let client = BigqueryModelClient::new(Arc::new(client));

        let model_id = format!("penguins_model_{}", OffsetDateTime::now_utc().unix_timestamp());

        // Train the model with a DDL query job.
        let mut job1 = Job::default();
        job1.job_reference.job_id = format!("test_model_job_{}", OffsetDateTime::now_utc().unix_timestamp());
        job1.job_reference.project_id = project.to_string();
        job1.job_reference.location = Some("US".to_string());
        job1.configuration = JobConfiguration {
            job: JobType::Query(JobConfigurationQuery {
                use_legacy_sql: Some(false),
                query: format!(
                    "
                CREATE OR REPLACE MODEL `{dataset}.{model_id}`
                OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS
                SELECT
                    *
                FROM
                    `bigquery-public-data.ml_datasets.penguins`
                WHERE
                    body_mass_g IS NOT NULL
                LIMIT 100
                "
                ),
                ..Default::default()
            }),
            ..Default::default()
        };
        let mut job = job_client.create(&job1).await.unwrap();

        // Poll until the job is done. The counter is incremented on every poll; previously it
        // never changed, so the timeout assertion could not fire and a stuck job hung the test.
        let mut polls: u32 = 0;
        while job.status.state != JobState::Done {
            let jr = &job.job_reference;
            job = job_client
                .get(&jr.project_id, &jr.job_id, &GetJobRequest { location: None })
                .await
                .unwrap();
            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
            tracing::info!("current job status.state = {:?}", job.status.state);
            polls += 1;
            assert!(polls < MAX_JOB_POLLS, "model creation timedout");
        }
        let statistics = job.statistics.unwrap().query.unwrap().ml_statistics;
        let ml = statistics.unwrap();
        assert_eq!(ml.training_type, TrainingType::SingleTraining);
        assert_eq!(ml.model_type, ModelType::LinearRegression);

        // Use the trained model for prediction.
        let result = job_client
            .query(
                &project,
                &QueryRequest {
                    max_results: None,
                    query: format!(
                        "
                SELECT * FROM ML.PREDICT(MODEL `{dataset}.{model_id}`, (
                SELECT
                    *
                FROM
                    `bigquery-public-data.ml_datasets.penguins`
                WHERE
                    body_mass_g IS NOT NULL
                    AND island = 'Biscoe' LIMIT 10))
                "
                    ),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(result.total_rows.unwrap(), 10);

        // List, then get / patch / delete every model in the dataset.
        let models = client
            .list(&project, &dataset, &ListModelsRequest::default())
            .await
            .unwrap();
        assert!(!models.is_empty());

        for overview in models {
            let reference = overview.model_reference;
            let model = client
                .get(reference.project_id.as_str(), reference.dataset_id.as_str(), reference.model_id.as_str())
                .await
                .unwrap();
            assert_eq!(model.model_type, Some(ModelType::LinearRegression));
            let patched = client.patch(&model).await.unwrap();
            let reference = &patched.model_reference;
            client
                .delete(reference.project_id.as_str(), reference.dataset_id.as_str(), reference.model_id.as_str())
                .await
                .unwrap();
        }
    }
}