1use crate::balancer::BalancingStrategy;
7use crate::error::{NetError, NetResult};
8use crate::pool::{ConnectionPool, ConnectionPoolBuilder, PoolConfig, PoolStats};
9use std::sync::Arc;
12use std::time::Duration;
13use tonic::transport::Channel;
14
15#[derive(Debug, Clone)]
17pub struct ClientConfig {
18 pub connect_timeout: Duration,
20 pub request_timeout: Duration,
22 pub keep_alive: bool,
24 pub keep_alive_interval: Duration,
26 pub pool: PoolConfig,
28}
29
30impl Default for ClientConfig {
31 fn default() -> Self {
32 Self {
33 connect_timeout: Duration::from_secs(10),
34 request_timeout: Duration::from_secs(30),
35 keep_alive: true,
36 keep_alive_interval: Duration::from_secs(60),
37 pool: PoolConfig::default(),
38 }
39 }
40}
41
42pub struct AqlClient {
44 pool: Arc<ConnectionPool>,
45 config: ClientConfig,
46}
47
48impl AqlClient {
49 pub fn new() -> Self {
51 Self::with_config(ClientConfig::default())
52 }
53
54 pub fn with_config(config: ClientConfig) -> Self {
56 let pool = ConnectionPool::new(config.pool.clone());
57
58 Self {
59 pool: Arc::new(pool),
60 config,
61 }
62 }
63
64 pub fn builder() -> AqlClientBuilder {
66 AqlClientBuilder::new()
67 }
68
69 pub fn add_endpoint(&self, id: String, address: String) {
71 self.pool.add_endpoint(id, address);
72 }
73
74 pub fn add_endpoint_with_weight(&self, id: String, address: String, weight: u32) {
76 self.pool.add_endpoint_with_weight(id, address, weight);
77 }
78
79 pub fn remove_endpoint(&self, endpoint_id: &str) -> bool {
81 self.pool.remove_endpoint(endpoint_id)
82 }
83
84 pub fn pool_stats(&self) -> PoolStats {
95 self.pool.stats()
96 }
97
98 pub fn circuit_breaker_stats(&self) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
100 self.pool.circuit_breaker_stats()
101 }
102
103 pub async fn drain(&self) -> NetResult<()> {
134 self.pool.drain().await
135 }
136
137 pub async fn shutdown(self) -> NetResult<()> {
139 Arc::try_unwrap(self.pool)
140 .map_err(|_| {
141 NetError::ServerInternal("Cannot shutdown: pool still has references".to_string())
142 })?
143 .shutdown()
144 .await
145 }
146}
147
148impl Default for AqlClient {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154pub struct AqlClientBuilder {
156 config: ClientConfig,
157 pool_builder: ConnectionPoolBuilder,
158}
159
160impl AqlClientBuilder {
161 pub fn new() -> Self {
163 Self {
164 config: ClientConfig::default(),
165 pool_builder: ConnectionPoolBuilder::new(),
166 }
167 }
168
169 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
171 self.config.connect_timeout = timeout;
172 self.pool_builder = self.pool_builder.connect_timeout(timeout);
173 self
174 }
175
176 pub fn request_timeout(mut self, timeout: Duration) -> Self {
178 self.config.request_timeout = timeout;
179 self
180 }
181
182 pub fn keep_alive(mut self, enabled: bool) -> Self {
184 self.config.keep_alive = enabled;
185 self
186 }
187
188 pub fn keep_alive_interval(mut self, interval: Duration) -> Self {
190 self.config.keep_alive_interval = interval;
191 self
192 }
193
194 pub fn min_pool_size(mut self, size: usize) -> Self {
196 self.config.pool.min_size = size;
197 self.pool_builder = self.pool_builder.min_size(size);
198 self
199 }
200
201 pub fn max_pool_size(mut self, size: usize) -> Self {
203 self.config.pool.max_size = size;
204 self.pool_builder = self.pool_builder.max_size(size);
205 self
206 }
207
208 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
210 self.config.pool.idle_timeout = timeout;
211 self.pool_builder = self.pool_builder.idle_timeout(timeout);
212 self
213 }
214
215 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
217 self.config.pool.max_lifetime = lifetime;
218 self.pool_builder = self.pool_builder.max_lifetime(lifetime);
219 self
220 }
221
222 pub fn health_check_interval(mut self, interval: Duration) -> Self {
224 self.config.pool.health_check_interval = interval;
225 self.pool_builder = self.pool_builder.health_check_interval(interval);
226 self
227 }
228
229 pub fn balancing_strategy(mut self, strategy: BalancingStrategy) -> Self {
231 self.config.pool.balancing_strategy = strategy;
232 self.pool_builder = self.pool_builder.balancing_strategy(strategy);
233 self
234 }
235
236 pub fn circuit_breaker(mut self, enabled: bool) -> Self {
238 self.config.pool.enable_circuit_breaker = enabled;
239 self.pool_builder = self.pool_builder.circuit_breaker(enabled);
240 self
241 }
242
243 pub fn add_endpoint(mut self, id: String, address: String) -> Self {
245 self.pool_builder = self.pool_builder.add_endpoint(id, address);
246 self
247 }
248
249 pub fn add_endpoint_with_weight(mut self, id: String, address: String, weight: u32) -> Self {
251 self.pool_builder = self
252 .pool_builder
253 .add_endpoint_with_weight(id, address, weight);
254 self
255 }
256
257 pub fn build(self) -> AqlClient {
259 let pool = self.pool_builder.build();
260
261 AqlClient {
262 pool: Arc::new(pool),
263 config: self.config,
264 }
265 }
266}
267
268impl Default for AqlClientBuilder {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client whose pool builder has a single endpoint `ep1` at `localhost:50051`.
    fn client_with_one_endpoint() -> AqlClient {
        AqlClient::builder()
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
    }

    #[test]
    fn test_client_config_default() {
        let ClientConfig {
            connect_timeout,
            request_timeout,
            keep_alive,
            ..
        } = ClientConfig::default();

        assert_eq!(connect_timeout, Duration::from_secs(10));
        assert_eq!(request_timeout, Duration::from_secs(30));
        assert!(keep_alive);
    }

    #[tokio::test]
    async fn test_client_creation() {
        let _client = AqlClient::with_config(ClientConfig::default());
    }

    #[tokio::test]
    async fn test_client_builder() {
        let builder = AqlClient::builder()
            .connect_timeout(Duration::from_secs(5))
            .request_timeout(Duration::from_secs(15))
            .min_pool_size(3)
            .max_pool_size(15)
            .balancing_strategy(BalancingStrategy::RoundRobin);
        let client = builder
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .add_endpoint("ep2".to_string(), "localhost:50052".to_string())
            .build();

        assert_eq!(client.pool_stats().active_connections, 0);
    }

    #[tokio::test]
    async fn test_client_add_remove_endpoint() {
        let client = AqlClient::new();

        for (id, address) in [("ep1", "localhost:50051"), ("ep2", "localhost:50052")] {
            client.add_endpoint(id.to_string(), address.to_string());
        }

        assert!(client.remove_endpoint("ep1"));
        assert!(!client.remove_endpoint("ep3"));
    }

    #[tokio::test]
    async fn test_client_pool_stats() {
        let client = client_with_one_endpoint();

        assert_eq!(client.pool_stats().total_connections, 0);
    }

    #[tokio::test]
    async fn test_client_drain() {
        let client = client_with_one_endpoint();

        let drained = client.drain().await;
        assert!(drained.is_ok());
    }
}
338}