1use async_trait::async_trait;
2use parking_lot::RwLock;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::channel::message::MessageChannel;
8use crate::client::RpcClient;
9use crate::codec::{BincodeCodec, Codec};
10use crate::discovery::Endpoint;
11use crate::error::{Result, RpcError};
12use crate::loadbalancer::{LoadBalanceStrategy, LoadBalancer};
13use crate::streaming::StreamReceiver;
14
15#[async_trait]
16pub trait ClientFactory<T, C>: Send + Sync
17where
18 T: MessageChannel<C>,
19 C: Codec,
20{
21 async fn create(&self, endpoint: &Endpoint) -> Result<RpcClient<T, C>>;
22}
23
24pub struct LoadBalancedClient<T, C, S>
25where
26 T: MessageChannel<C>,
27 C: Codec,
28 S: LoadBalanceStrategy,
29{
30 load_balancer: Arc<LoadBalancer<S>>,
31 clients: RwLock<HashMap<usize, Arc<RpcClient<T, C>>>>,
32 factory: Arc<dyn ClientFactory<T, C>>,
33}
34
35impl<T, C, S> LoadBalancedClient<T, C, S>
36where
37 T: MessageChannel<C> + 'static,
38 C: Codec + Clone + Default + 'static,
39 S: LoadBalanceStrategy + 'static,
40{
41 pub fn new(load_balancer: Arc<LoadBalancer<S>>, factory: Arc<dyn ClientFactory<T, C>>) -> Self {
42 Self {
43 load_balancer,
44 clients: RwLock::new(HashMap::new()),
45 factory,
46 }
47 }
48
49 pub async fn init(&self) -> Result<()> {
50 self.load_balancer.init().await?;
51
52 for i in 0..self.load_balancer.server_count() {
53 if let Some(endpoint) = self.load_balancer.get_endpoint(i) {
54 let _ = self.get_or_create_client(i, &endpoint).await;
55 }
56 }
57
58 Ok(())
59 }
60
61 async fn get_or_create_client(
62 &self,
63 server_idx: usize,
64 endpoint: &Endpoint,
65 ) -> Result<Arc<RpcClient<T, C>>> {
66 {
67 let clients = self.clients.read();
68 if let Some(client) = clients.get(&server_idx) {
69 if client.is_connected() {
70 return Ok(client.clone());
71 }
72 }
73 }
74
75 let client = self.factory.create(endpoint).await?;
76 let client = Arc::new(client);
77
78 let _handle = client.start();
79
80 self.clients.write().insert(server_idx, client.clone());
81
82 Ok(client)
83 }
84
85 pub async fn call<Req, Resp>(&self, method: &str, request: &Req) -> Result<Resp>
86 where
87 Req: Serialize,
88 Resp: for<'de> Deserialize<'de>,
89 {
90 let max_attempts = self.load_balancer.config().max_failover_attempts as usize + 1;
91 let mut last_error = None;
92 let mut tried_servers = Vec::new();
93
94 for _ in 0..max_attempts {
95 let server_idx = match self.select_excluding(&tried_servers) {
96 Some(idx) => idx,
97 None => break,
98 };
99
100 tried_servers.push(server_idx);
101
102 let endpoint = match self.load_balancer.get_endpoint(server_idx) {
103 Some(ep) => ep,
104 None => continue,
105 };
106
107 let client = match self.get_or_create_client(server_idx, &endpoint).await {
108 Ok(c) => c,
109 Err(e) => {
110 self.load_balancer.record_failure(server_idx);
111 last_error = Some(e);
112 continue;
113 }
114 };
115
116 self.load_balancer.acquire(server_idx);
117 let result = client.call(method, request).await;
118 self.load_balancer.release(server_idx);
119
120 match result {
121 Ok(resp) => {
122 self.load_balancer.record_success(server_idx);
123 return Ok(resp);
124 }
125 Err(e) => {
126 self.load_balancer.record_failure(server_idx);
127 self.clients.write().remove(&server_idx);
128 last_error = Some(e);
129 }
130 }
131 }
132
133 Err(last_error.unwrap_or_else(|| RpcError::ClientError("No servers available".to_string())))
134 }
135
136 fn select_excluding(&self, excluded: &[usize]) -> Option<usize> {
137 for _ in 0..self.load_balancer.server_count().max(10) {
138 if let Some(idx) = self.load_balancer.select() {
139 if !excluded.contains(&idx) {
140 return Some(idx);
141 }
142 } else {
143 break;
144 }
145 }
146 None
147 }
148
149 pub async fn call_server_stream<Req, Resp>(
150 &self,
151 method: &str,
152 request: &Req,
153 ) -> Result<StreamReceiver<Resp, C>>
154 where
155 Req: Serialize,
156 Resp: for<'de> Deserialize<'de>,
157 {
158 let stream_id = crate::streaming::next_stream_id();
159
160 let server_idx = self
161 .load_balancer
162 .select_for_stream(stream_id)
163 .ok_or_else(|| RpcError::ClientError("No servers available".to_string()))?;
164
165 let endpoint = self
166 .load_balancer
167 .get_endpoint(server_idx)
168 .ok_or_else(|| RpcError::ClientError("Server not found".to_string()))?;
169
170 let client = self.get_or_create_client(server_idx, &endpoint).await?;
171
172 client.call_server_stream(method, request).await
173 }
174
175 pub async fn notify<Req: Serialize>(&self, method: &str, request: &Req) -> Result<()> {
176 let server_idx = self
177 .load_balancer
178 .select()
179 .ok_or_else(|| RpcError::ClientError("No servers available".to_string()))?;
180
181 let endpoint = self
182 .load_balancer
183 .get_endpoint(server_idx)
184 .ok_or_else(|| RpcError::ClientError("Server not found".to_string()))?;
185
186 let client = self.get_or_create_client(server_idx, &endpoint).await?;
187 client.notify(method, request).await
188 }
189
190 pub fn load_balancer(&self) -> &LoadBalancer<S> {
191 &self.load_balancer
192 }
193
194 pub fn connection_count(&self) -> usize {
195 self.clients.read().len()
196 }
197
198 pub fn release_stream(&self, stream_id: crate::streaming::StreamId) {
199 self.load_balancer.release_stream(stream_id);
200 }
201}
202
203impl<T, C, S> std::fmt::Debug for LoadBalancedClient<T, C, S>
204where
205 T: MessageChannel<C>,
206 C: Codec,
207 S: LoadBalanceStrategy + 'static,
208{
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.debug_struct("LoadBalancedClient")
211 .field("server_count", &self.load_balancer.server_count())
212 .field("available_count", &self.load_balancer.available_count())
213 .field("connection_count", &self.clients.read().len())
214 .field("strategy", &self.load_balancer.strategy_name())
215 .finish()
216 }
217}
218
/// [`LoadBalancedClient`] with the codec fixed to [`BincodeCodec`].
pub type DefaultLoadBalancedClient<T, S> = LoadBalancedClient<T, BincodeCodec, S>;
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::discovery::StaticDiscovery;
    use crate::loadbalancer::RoundRobin;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};
    use std::sync::atomic::{AtomicUsize, Ordering};

    type MockChannel = MessageChannelAdapter<ChannelFrameTransport>;

    /// Factory that counts how often it is invoked and hands out clients
    /// backed by an in-memory channel pair whose peer end is dropped when
    /// `create` returns.
    #[derive(Default)]
    struct MockClientFactory {
        create_count: AtomicUsize,
    }

    impl MockClientFactory {
        fn new() -> Self {
            Self::default()
        }

        /// Number of clients created so far.
        fn creations(&self) -> usize {
            self.create_count.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl ClientFactory<MockChannel, BincodeCodec> for MockClientFactory {
        async fn create(
            &self,
            _endpoint: &Endpoint,
        ) -> Result<RpcClient<MockChannel, BincodeCodec>> {
            self.create_count.fetch_add(1, Ordering::Relaxed);

            let (local, _remote) =
                ChannelFrameTransport::create_pair("mock", ChannelConfig::default()).unwrap();
            Ok(RpcClient::new(MessageChannelAdapter::new(local)))
        }
    }

    #[tokio::test]
    async fn test_lb_client_creation() {
        let endpoints: Vec<Endpoint> = ["127.0.0.1:8001", "127.0.0.1:8002"]
            .iter()
            .map(|&addr| Endpoint::tcp_from_str(addr).unwrap())
            .collect();

        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let balancer = Arc::new(LoadBalancer::new(discovery, RoundRobin::new()));
        let factory = Arc::new(MockClientFactory::new());

        let client = LoadBalancedClient::new(balancer.clone(), factory.clone());
        client.init().await.unwrap();

        // init() warms up one connection per discovered server.
        assert_eq!(client.load_balancer().server_count(), 2);
        assert_eq!(factory.creations(), 2);
    }
}