1use async_trait::async_trait;
29use bb8::ManageConnection;
30use clickhouse::{Client, Compression};
31use std::ops::{Deref, DerefMut};
32use thiserror::Error;
33
34#[derive(Error, Debug)]
36pub enum ClickHouseError {
37 #[error("Failed to create connection")]
38 ConnectionFailed(#[from] clickhouse::error::Error),
39
40 #[error("Health check failed")]
41 HealthCheckFailed(#[from] Box<clickhouse::error::Error>),
42}
43
44#[derive(Clone)]
49pub struct ConnectionBuilder {
50 url: Option<String>,
51 database: Option<String>,
52 user: Option<String>,
53 password: Option<String>,
54 access_token: Option<String>,
55 compression: Option<Compression>,
56 headers: Vec<(String, String)>,
57 options: Vec<(String, String)>,
58 roles: Option<Vec<String>>,
59 default_roles: bool,
60 product_name: Option<String>,
61 product_version: Option<String>,
62 validation: bool,
63}
64
65impl Default for ConnectionBuilder {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl ConnectionBuilder {
72 pub fn new() -> Self {
74 Self {
75 url: None,
76 database: None,
77 user: None,
78 password: None,
79 access_token: None,
80 compression: None,
81 headers: Vec::new(),
82 options: Vec::new(),
83 roles: None,
84 default_roles: false,
85 product_name: None,
86 product_version: None,
87 validation: false,
88 }
89 }
90
91 pub fn with_url(mut self, url: impl Into<String>) -> Self {
93 self.url = Some(url.into());
94 self
95 }
96
97 pub fn with_database(mut self, database: impl Into<String>) -> Self {
99 self.database = Some(database.into());
100 self
101 }
102
103 pub fn with_user(mut self, user: impl Into<String>) -> Self {
105 self.user = Some(user.into());
106 self
107 }
108
109 pub fn with_password(mut self, password: impl Into<String>) -> Self {
111 self.password = Some(password.into());
112 self
113 }
114
115 pub fn with_access_token(mut self, access_token: impl Into<String>) -> Self {
117 self.access_token = Some(access_token.into());
118 self
119 }
120
121 pub fn with_compression(mut self, compression: Compression) -> Self {
123 self.compression = Some(compression);
124 self
125 }
126
127 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
129 self.headers.push((name.into(), value.into()));
130 self
131 }
132
133 pub fn with_option(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
135 self.options.push((name.into(), value.into()));
136 self
137 }
138
139 pub fn with_roles<I>(mut self, roles: I) -> Self
141 where
142 I: IntoIterator,
143 I::Item: Into<String>,
144 {
145 self.roles = Some(roles.into_iter().map(|r| r.into()).collect());
146 self
147 }
148
149 pub fn with_default_roles(mut self) -> Self {
151 self.default_roles = true;
152 self
153 }
154
155 pub fn with_product_info(
157 mut self,
158 product_name: impl Into<String>,
159 product_version: impl Into<String>,
160 ) -> Self {
161 self.product_name = Some(product_name.into());
162 self.product_version = Some(product_version.into());
163 self
164 }
165
166 pub fn with_validation(mut self, enabled: bool) -> Self {
168 self.validation = enabled;
169 self
170 }
171
172 fn build_client(&self) -> Client {
174 let mut client = Client::default();
175
176 if let Some(url) = &self.url {
178 client = client.with_url(url.clone());
179 }
180
181 if let Some(database) = &self.database {
182 client = client.with_database(database.clone());
183 }
184
185 if let Some(user) = &self.user {
186 client = client.with_user(user.clone());
187 }
188
189 if let Some(password) = &self.password {
190 client = client.with_password(password.clone());
191 }
192
193 if let Some(compression) = self.compression {
194 client = client.with_compression(compression);
195 }
196
197 for (name, value) in &self.options {
198 client = client.with_option(name.clone(), value.clone());
199 }
200
201 client
202 }
203}
204
205pub struct Connection {
211 client: Client,
212 is_broken: bool,
213}
214
215impl Connection {
216 fn new(client: Client) -> Self {
218 Self {
219 client,
220 is_broken: false,
221 }
222 }
223
224 fn mark_broken(&mut self) {
226 self.is_broken = true;
227 }
228
229 pub fn is_broken(&self) -> bool {
231 self.is_broken
232 }
233}
234
235impl Deref for Connection {
236 type Target = Client;
237
238 fn deref(&self) -> &Self::Target {
239 &self.client
240 }
241}
242
243impl DerefMut for Connection {
244 fn deref_mut(&mut self) -> &mut Self::Target {
245 &mut self.client
246 }
247}
248
249pub struct ConnectionManager {
255 builder: ConnectionBuilder,
256}
257
258impl ConnectionManager {
259 pub fn new(builder: ConnectionBuilder) -> Self {
261 Self { builder }
262 }
263}
264
265#[async_trait]
266impl ManageConnection for ConnectionManager {
267 type Connection = Connection;
268 type Error = ClickHouseError;
269
270 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
272 let client = self.builder.build_client();
273 Ok(Connection::new(client))
274 }
275
276 async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
278 match conn.query("select 1").fetch_optional::<u8>().await {
279 Ok(_) => Ok(()),
280 Err(e) => {
281 conn.mark_broken();
282 Err(ClickHouseError::HealthCheckFailed(Box::new(e)))
283 }
284 }
285 }
286
287 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
289 conn.is_broken()
290 }
291}
292
293pub type Pool = bb8::Pool<ConnectionManager>;
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `with_*` setter records its value on the builder.
    #[test]
    fn test_connection_builder_with_all_options() {
        let builder = ConnectionBuilder::new()
            .with_url("http://localhost:8123")
            .with_database("default")
            .with_user("default")
            .with_password("password")
            .with_access_token("token123")
            .with_compression(Compression::Lz4)
            .with_header("X-Custom", "value")
            .with_option("max_rows_to_read", "1000")
            .with_roles(vec!["role1"])
            .with_default_roles()
            .with_product_info("myapp", "1.0.0")
            .with_validation(true);

        assert_eq!(builder.url.as_deref(), Some("http://localhost:8123"));
        assert_eq!(builder.database.as_deref(), Some("default"));
        assert_eq!(builder.user.as_deref(), Some("default"));
        assert_eq!(builder.password.as_deref(), Some("password"));
        assert_eq!(builder.access_token.as_deref(), Some("token123"));
        assert!(builder.compression.is_some());
        assert_eq!(builder.headers.len(), 1);
        assert_eq!(builder.options.len(), 1);
        assert_eq!(builder.roles.as_ref().map(Vec::len), Some(1));
        assert!(builder.default_roles);
        assert_eq!(builder.product_name.as_deref(), Some("myapp"));
        assert_eq!(builder.product_version.as_deref(), Some("1.0.0"));
        assert!(builder.validation);
    }

    /// A freshly wrapped client starts out healthy.
    #[test]
    fn test_connection_creation() {
        let conn = Connection::new(Client::default());
        assert!(!conn.is_broken(), "a new connection must start healthy");
    }

    /// `mark_broken` flips the broken flag.
    #[test]
    fn test_connection_mark_broken() {
        let mut conn = Connection::new(Client::default());
        assert!(!conn.is_broken(), "a new connection must start healthy");

        conn.mark_broken();
        assert!(conn.is_broken(), "mark_broken must flag the connection");
    }
}