google_cloud_bigquery/http/
bigquery_model_client.rs1use std::sync::Arc;
2
3use crate::http::bigquery_client::BigqueryClient;
4use crate::http::error::Error;
5use crate::http::model;
6use crate::http::model::list::{ListModelsRequest, ListModelsResponse, ModelOverview};
7use crate::http::model::Model;
8
9#[derive(Debug, Clone)]
10pub struct BigqueryModelClient {
11 inner: Arc<BigqueryClient>,
12}
13
14impl BigqueryModelClient {
15 pub fn new(inner: Arc<BigqueryClient>) -> Self {
16 Self { inner }
17 }
18
19 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
21 pub async fn delete(&self, project_id: &str, dataset_id: &str, table_id: &str) -> Result<(), Error> {
22 let builder = model::delete::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, table_id);
23 self.inner.send_get_empty(builder).await
24 }
25
26 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
28 pub async fn patch(&self, metadata: &Model) -> Result<Model, Error> {
29 let builder = model::patch::build(self.inner.endpoint(), self.inner.http(), metadata);
30 self.inner.send(builder).await
31 }
32
33 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
35 pub async fn get(&self, project_id: &str, dataset_id: &str, model_id: &str) -> Result<Model, Error> {
36 let builder = model::get::build(self.inner.endpoint(), self.inner.http(), project_id, dataset_id, model_id);
37 self.inner.send(builder).await
38 }
39
40 #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
42 pub async fn list(
43 &self,
44 project_id: &str,
45 dataset_id: &str,
46 req: &ListModelsRequest,
47 ) -> Result<Vec<ModelOverview>, Error> {
48 let mut page_token: Option<String> = None;
49 let mut models = vec![];
50 loop {
51 let builder = model::list::build(
52 self.inner.endpoint(),
53 self.inner.http(),
54 project_id,
55 dataset_id,
56 req,
57 page_token,
58 );
59 let response: ListModelsResponse = self.inner.send(builder).await?;
60 models.extend(response.models);
61 if response.next_page_token.is_none() {
62 break;
63 }
64 page_token = response.next_page_token;
65 }
66 Ok(models)
67 }
68}
69
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use serial_test::serial;
    use time::OffsetDateTime;

    use crate::http::bigquery_client::test::{create_client, dataset_name};
    use crate::http::bigquery_job_client::BigqueryJobClient;
    use crate::http::bigquery_model_client::BigqueryModelClient;
    use crate::http::job::get::GetJobRequest;
    use crate::http::job::query::QueryRequest;
    use crate::http::job::{Job, JobConfiguration, JobConfigurationQuery, JobState, JobType, TrainingType};
    use crate::http::model::list::ListModelsRequest;
    use crate::http::model::ModelType;

    /// End-to-end check of the model client: trains a small linear-regression
    /// model through a query job, runs a prediction against it, then lists,
    /// gets, patches and deletes every model found in the test dataset.
    #[tokio::test]
    #[serial]
    pub async fn crud_model() {
        /// Upper bound on status polls while waiting for the training job.
        const MAX_POLLS: u32 = 20;

        let dataset = dataset_name("model");
        let (client, project) = create_client().await;
        let job_client = BigqueryJobClient::new(Arc::new(client.clone()));
        let client = BigqueryModelClient::new(Arc::new(client));

        // Create the model by running a `CREATE OR REPLACE MODEL` query job.
        let model_id = format!("penguins_model_{}", OffsetDateTime::now_utc().unix_timestamp());
        let mut job1 = Job::default();
        job1.job_reference.job_id = format!("test_model_job_{}", OffsetDateTime::now_utc().unix_timestamp());
        job1.job_reference.project_id = project.to_string();
        job1.job_reference.location = Some("US".to_string());
        job1.configuration = JobConfiguration {
            job: JobType::Query(JobConfigurationQuery {
                use_legacy_sql: Some(false),
                query: format!(
                    "
                    CREATE OR REPLACE MODEL `{}.{}`
                    OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS
                    SELECT
                        *
                    FROM
                        `bigquery-public-data.ml_datasets.penguins`
                    WHERE
                        body_mass_g IS NOT NULL
                    LIMIT 100
                    ",
                    dataset, model_id
                ),
                ..Default::default()
            }),
            ..Default::default()
        };
        let mut job = job_client.create(&job1).await.unwrap();

        // Wait for training. The poll counter is actually advanced so that the
        // timeout assertion can fire; the budget is checked only while the job
        // is still unfinished, so a job completing on the last poll passes.
        let mut polls = 0;
        while job.status.state != JobState::Done {
            assert!(polls < MAX_POLLS, "model creation timedout");
            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
            let jr = &job.job_reference;
            job = job_client
                .get(&jr.project_id, &jr.job_id, &GetJobRequest { location: None })
                .await
                .unwrap();
            polls += 1;
            tracing::info!("current job status.state = {:?}", job.status.state);
        }
        let statistics = job.statistics.unwrap().query.unwrap().ml_statistics;
        let ml = statistics.unwrap();
        assert_eq!(ml.training_type, TrainingType::SingleTraining);
        assert_eq!(ml.model_type, ModelType::LinearRegression);

        // Predict with the freshly trained model.
        let result = job_client
            .query(
                &project,
                &QueryRequest {
                    max_results: None,
                    query: format!(
                        "
                        SELECT * FROM ML.PREDICT(MODEL `{}.{}`, (
                        SELECT
                            *
                        FROM
                            `bigquery-public-data.ml_datasets.penguins`
                        WHERE
                            body_mass_g IS NOT NULL
                            AND island = 'Biscoe' LIMIT 10))
                        ",
                        dataset, model_id
                    ),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(result.total_rows.unwrap(), 10);

        // List the dataset's models; at least the one just created must exist.
        let models = client
            .list(&project, &dataset, &ListModelsRequest::default())
            .await
            .unwrap();
        assert!(!models.is_empty());

        // Get, patch and delete every listed model.
        for overview in models {
            let listed = overview.model_reference;
            let fetched = client
                .get(listed.project_id.as_str(), listed.dataset_id.as_str(), listed.model_id.as_str())
                .await
                .unwrap();
            assert_eq!(fetched.model_type.clone().unwrap(), ModelType::LinearRegression);
            let patched = client.patch(&fetched).await.unwrap();
            let reference = &patched.model_reference;
            client
                .delete(
                    reference.project_id.as_str(),
                    reference.dataset_id.as_str(),
                    reference.model_id.as_str(),
                )
                .await
                .unwrap();
        }
    }
}