// xrpc/lb_client.rs

1use async_trait::async_trait;
2use parking_lot::RwLock;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::channel::message::MessageChannel;
8use crate::client::RpcClient;
9use crate::codec::{BincodeCodec, Codec};
10use crate::discovery::Endpoint;
11use crate::error::{Result, RpcError};
12use crate::loadbalancer::{LoadBalanceStrategy, LoadBalancer};
13use crate::streaming::StreamReceiver;
14
15#[async_trait]
16pub trait ClientFactory<T, C>: Send + Sync
17where
18    T: MessageChannel<C>,
19    C: Codec,
20{
21    async fn create(&self, endpoint: &Endpoint) -> Result<RpcClient<T, C>>;
22}
23
24pub struct LoadBalancedClient<T, C, S>
25where
26    T: MessageChannel<C>,
27    C: Codec,
28    S: LoadBalanceStrategy,
29{
30    load_balancer: Arc<LoadBalancer<S>>,
31    clients: RwLock<HashMap<usize, Arc<RpcClient<T, C>>>>,
32    factory: Arc<dyn ClientFactory<T, C>>,
33}
34
35impl<T, C, S> LoadBalancedClient<T, C, S>
36where
37    T: MessageChannel<C> + 'static,
38    C: Codec + Clone + Default + 'static,
39    S: LoadBalanceStrategy + 'static,
40{
41    pub fn new(load_balancer: Arc<LoadBalancer<S>>, factory: Arc<dyn ClientFactory<T, C>>) -> Self {
42        Self {
43            load_balancer,
44            clients: RwLock::new(HashMap::new()),
45            factory,
46        }
47    }
48
49    pub async fn init(&self) -> Result<()> {
50        self.load_balancer.init().await?;
51
52        for i in 0..self.load_balancer.server_count() {
53            if let Some(endpoint) = self.load_balancer.get_endpoint(i) {
54                let _ = self.get_or_create_client(i, &endpoint).await;
55            }
56        }
57
58        Ok(())
59    }
60
61    async fn get_or_create_client(
62        &self,
63        server_idx: usize,
64        endpoint: &Endpoint,
65    ) -> Result<Arc<RpcClient<T, C>>> {
66        {
67            let clients = self.clients.read();
68            if let Some(client) = clients.get(&server_idx) {
69                if client.is_connected() {
70                    return Ok(client.clone());
71                }
72            }
73        }
74
75        let client = self.factory.create(endpoint).await?;
76        let client = Arc::new(client);
77
78        let _handle = client.start();
79
80        self.clients.write().insert(server_idx, client.clone());
81
82        Ok(client)
83    }
84
85    pub async fn call<Req, Resp>(&self, method: &str, request: &Req) -> Result<Resp>
86    where
87        Req: Serialize,
88        Resp: for<'de> Deserialize<'de>,
89    {
90        let max_attempts = self.load_balancer.config().max_failover_attempts as usize + 1;
91        let mut last_error = None;
92        let mut tried_servers = Vec::new();
93
94        for _ in 0..max_attempts {
95            let server_idx = match self.select_excluding(&tried_servers) {
96                Some(idx) => idx,
97                None => break,
98            };
99
100            tried_servers.push(server_idx);
101
102            let endpoint = match self.load_balancer.get_endpoint(server_idx) {
103                Some(ep) => ep,
104                None => continue,
105            };
106
107            let client = match self.get_or_create_client(server_idx, &endpoint).await {
108                Ok(c) => c,
109                Err(e) => {
110                    self.load_balancer.record_failure(server_idx);
111                    last_error = Some(e);
112                    continue;
113                }
114            };
115
116            self.load_balancer.acquire(server_idx);
117            let result = client.call(method, request).await;
118            self.load_balancer.release(server_idx);
119
120            match result {
121                Ok(resp) => {
122                    self.load_balancer.record_success(server_idx);
123                    return Ok(resp);
124                }
125                Err(e) => {
126                    self.load_balancer.record_failure(server_idx);
127                    self.clients.write().remove(&server_idx);
128                    last_error = Some(e);
129                }
130            }
131        }
132
133        Err(last_error.unwrap_or_else(|| RpcError::ClientError("No servers available".to_string())))
134    }
135
136    fn select_excluding(&self, excluded: &[usize]) -> Option<usize> {
137        for _ in 0..self.load_balancer.server_count().max(10) {
138            if let Some(idx) = self.load_balancer.select() {
139                if !excluded.contains(&idx) {
140                    return Some(idx);
141                }
142            } else {
143                break;
144            }
145        }
146        None
147    }
148
149    pub async fn call_server_stream<Req, Resp>(
150        &self,
151        method: &str,
152        request: &Req,
153    ) -> Result<StreamReceiver<Resp, C>>
154    where
155        Req: Serialize,
156        Resp: for<'de> Deserialize<'de>,
157    {
158        let stream_id = crate::streaming::next_stream_id();
159
160        let server_idx = self
161            .load_balancer
162            .select_for_stream(stream_id)
163            .ok_or_else(|| RpcError::ClientError("No servers available".to_string()))?;
164
165        let endpoint = self
166            .load_balancer
167            .get_endpoint(server_idx)
168            .ok_or_else(|| RpcError::ClientError("Server not found".to_string()))?;
169
170        let client = self.get_or_create_client(server_idx, &endpoint).await?;
171
172        client.call_server_stream(method, request).await
173    }
174
175    pub async fn notify<Req: Serialize>(&self, method: &str, request: &Req) -> Result<()> {
176        let server_idx = self
177            .load_balancer
178            .select()
179            .ok_or_else(|| RpcError::ClientError("No servers available".to_string()))?;
180
181        let endpoint = self
182            .load_balancer
183            .get_endpoint(server_idx)
184            .ok_or_else(|| RpcError::ClientError("Server not found".to_string()))?;
185
186        let client = self.get_or_create_client(server_idx, &endpoint).await?;
187        client.notify(method, request).await
188    }
189
190    pub fn load_balancer(&self) -> &LoadBalancer<S> {
191        &self.load_balancer
192    }
193
194    pub fn connection_count(&self) -> usize {
195        self.clients.read().len()
196    }
197
198    pub fn release_stream(&self, stream_id: crate::streaming::StreamId) {
199        self.load_balancer.release_stream(stream_id);
200    }
201}
202
203impl<T, C, S> std::fmt::Debug for LoadBalancedClient<T, C, S>
204where
205    T: MessageChannel<C>,
206    C: Codec,
207    S: LoadBalanceStrategy + 'static,
208{
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        f.debug_struct("LoadBalancedClient")
211            .field("server_count", &self.load_balancer.server_count())
212            .field("available_count", &self.load_balancer.available_count())
213            .field("connection_count", &self.clients.read().len())
214            .field("strategy", &self.load_balancer.strategy_name())
215            .finish()
216    }
217}
218
219pub type DefaultLoadBalancedClient<T, S> = LoadBalancedClient<T, BincodeCodec, S>;
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::discovery::StaticDiscovery;
    use crate::loadbalancer::RoundRobin;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Factory that counts `create` calls and returns a client over an in-memory channel
    /// transport (the peer end of the pair is dropped immediately).
    struct MockClientFactory {
        create_count: AtomicUsize,
    }

    impl MockClientFactory {
        fn new() -> Self {
            Self {
                create_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl ClientFactory<MessageChannelAdapter<ChannelFrameTransport>, BincodeCodec>
        for MockClientFactory
    {
        async fn create(
            &self,
            _endpoint: &Endpoint,
        ) -> Result<RpcClient<MessageChannelAdapter<ChannelFrameTransport>, BincodeCodec>> {
            self.create_count.fetch_add(1, Ordering::Relaxed);

            let (t1, _t2) =
                ChannelFrameTransport::create_pair("mock", ChannelConfig::default()).unwrap();
            let channel = MessageChannelAdapter::new(t1);
            Ok(RpcClient::new(channel))
        }
    }

    #[tokio::test]
    async fn test_lb_client_creation() {
        let endpoints = vec![
            Endpoint::tcp_from_str("127.0.0.1:8001").unwrap(),
            Endpoint::tcp_from_str("127.0.0.1:8002").unwrap(),
        ];

        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let lb = Arc::new(LoadBalancer::new(discovery, RoundRobin::new()));

        let factory = Arc::new(MockClientFactory::new());
        let client = LoadBalancedClient::new(lb.clone(), factory.clone());

        client.init().await.unwrap();

        assert_eq!(client.load_balancer().server_count(), 2);
        // `init` builds one client per server and caches each of them.
        assert_eq!(factory.create_count.load(Ordering::Relaxed), 2);
        assert_eq!(client.connection_count(), 2);
        assert!(format!("{:?}", client).starts_with("LoadBalancedClient"));
    }
}