1use crate::cluster::ServerNode;
19use crate::error::Error;
20use crate::rpc::api_version::ApiVersion;
21use crate::rpc::error::RpcError;
22use crate::rpc::error::RpcError::ConnectionError;
23use crate::rpc::frame::{AsyncMessageRead, AsyncMessageWrite};
24use crate::rpc::message::{
25 ReadVersionedType, RequestBody, RequestHeader, ResponseHeader, WriteVersionedType,
26};
27use crate::rpc::transport::Transport;
28use futures::future::BoxFuture;
29use log::warn;
30use parking_lot::{Mutex, RwLock};
31use std::collections::HashMap;
32use std::fmt;
33use std::io::Cursor;
34use std::ops::DerefMut;
35use std::sync::Arc;
36use std::sync::atomic::{AtomicI32, Ordering};
37use std::task::Poll;
38use std::time::Duration;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, WriteHalf};
40use tokio::sync::Mutex as AsyncMutex;
41use tokio::sync::oneshot::{Sender, channel};
42use tokio::task::JoinHandle;
43
44pub type MessengerTransport = ServerConnectionInner<BufStream<Transport>>;
45
46pub type ServerConnection = Arc<MessengerTransport>;
47
/// Delay, in milliseconds, before the first SASL authentication retry.
const AUTH_INITIAL_BACKOFF_MS: f64 = 100.0;
/// Ceiling, in milliseconds, of the retry delay before jitter is applied.
const AUTH_MAX_BACKOFF_MS: f64 = 5_000.0;
/// Growth factor applied to the delay for each additional retry.
const AUTH_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Relative random spread applied to every delay (`0.2` scales it by a factor in `[0.8, 1.2)`).
const AUTH_JITTER: f64 = 0.2;
53
/// Username/password pair used to authenticate newly opened connections via SASL.
///
/// `Debug` is implemented by hand (not derived) so the password never ends up in logs.
#[derive(Clone)]
pub struct SaslConfig {
    /// Account name presented to the server.
    pub username: String,
    /// Secret presented to the server; masked in `Debug` output.
    pub password: String,
}
59
60impl fmt::Debug for SaslConfig {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("SaslConfig")
63 .field("username", &self.username)
64 .field("password", &"[REDACTED]")
65 .finish()
66 }
67}
68
69#[derive(Debug, Default)]
70pub struct RpcClient {
71 connections: RwLock<HashMap<String, ServerConnection>>,
72 client_id: Arc<str>,
73 timeout: Option<Duration>,
74 max_message_size: usize,
75 sasl_config: Option<SaslConfig>,
76}
77
78impl RpcClient {
79 pub fn new() -> Self {
80 RpcClient {
81 connections: Default::default(),
82 client_id: Arc::from(""),
83 timeout: None,
84 max_message_size: usize::MAX,
85 sasl_config: None,
86 }
87 }
88
89 pub fn with_timeout(mut self, timeout: Duration) -> Self {
90 self.timeout = Some(timeout);
91 self
92 }
93
94 pub fn with_sasl(mut self, username: String, password: String) -> Self {
95 self.sasl_config = Some(SaslConfig { username, password });
96 self
97 }
98
99 pub async fn get_connection(
100 &self,
101 server_node: &ServerNode,
102 ) -> Result<ServerConnection, Error> {
103 let server_id = server_node.uid();
104 {
105 let connections = self.connections.read();
106 if let Some(conn) = connections.get(server_id).cloned() {
107 if !conn.is_poisoned() {
108 return Ok(conn);
109 }
110 }
111 }
112 let new_server = self.connect(server_node).await?;
113 {
114 let mut connections = self.connections.write();
115 if let Some(race_conn) = connections.get(server_id) {
116 if !race_conn.is_poisoned() {
117 return Ok(race_conn.clone());
118 }
119 }
120
121 connections.insert(server_id.to_owned(), new_server.clone());
122 }
123 Ok(new_server)
124 }
125
126 async fn connect(&self, server_node: &ServerNode) -> Result<ServerConnection, Error> {
127 let url = server_node.url();
128 let transport = Transport::connect(&url, self.timeout)
129 .await
130 .map_err(|error| ConnectionError(error.to_string()))?;
131
132 let messenger = ServerConnectionInner::new(
133 BufStream::new(transport),
134 self.max_message_size,
135 self.client_id.clone(),
136 );
137 let connection = ServerConnection::new(messenger);
138
139 if let Some(ref sasl) = self.sasl_config {
140 Self::authenticate(&connection, &sasl.username, &sasl.password).await?;
141 }
142
143 Ok(connection)
144 }
145
146 async fn authenticate(
153 connection: &ServerConnection,
154 username: &str,
155 password: &str,
156 ) -> Result<(), Error> {
157 use crate::rpc::fluss_api_error::FlussError;
158 use crate::rpc::message::AuthenticateRequest;
159 use rand::Rng;
160
161 let initial_request = AuthenticateRequest::new_plain(username, password);
162 let mut retry_count: u32 = 0;
163
164 loop {
165 let request = initial_request.clone();
166 let result = connection.request(request).await;
167
168 match result {
169 Ok(response) => {
170 if let Some(challenge) = response.challenge {
174 let challenge_req = AuthenticateRequest::from_challenge("PLAIN", challenge);
175 connection.request(challenge_req).await?;
176 }
177 return Ok(());
178 }
179 Err(Error::FlussAPIError { ref api_error })
180 if FlussError::for_code(api_error.code)
181 == FlussError::RetriableAuthenticateException =>
182 {
183 retry_count += 1;
184 let exp_max = (AUTH_MAX_BACKOFF_MS / AUTH_INITIAL_BACKOFF_MS).log2();
188 let exp = ((retry_count as f64) - 1.0).min(exp_max);
189 let term = AUTH_INITIAL_BACKOFF_MS * AUTH_BACKOFF_MULTIPLIER.powf(exp);
190 let jitter_factor =
191 1.0 - AUTH_JITTER + rand::rng().random::<f64>() * (2.0 * AUTH_JITTER);
192 let backoff_ms = (term * jitter_factor) as u64;
193 log::warn!(
194 "SASL authentication retriable failure (attempt {retry_count}), \
195 retrying in {backoff_ms}ms: {}",
196 api_error.message
197 );
198 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
199 }
200 Err(e) => return Err(e),
203 }
204 }
205 }
206}
207
208#[derive(Debug)]
209struct Response {
210 #[allow(dead_code)]
211 header: ResponseHeader,
212 data: Cursor<Vec<u8>>,
213}
214
215#[derive(Debug)]
216struct ActiveRequest {
217 channel: Sender<Result<Response, RpcError>>,
218}
219
220#[derive(Debug)]
221enum ConnectionState {
222 RequestMap(HashMap<i32, ActiveRequest>),
226
227 Poison(Arc<RpcError>),
229}
230
231impl ConnectionState {
232 fn poison(&mut self, err: RpcError) -> Arc<RpcError> {
233 match self {
234 Self::RequestMap(map) => {
235 let err = Arc::new(err);
236
237 for (_request_id, active_request) in map.drain() {
239 active_request
241 .channel
242 .send(Err(RpcError::Poisoned(Arc::clone(&err))))
243 .ok();
244 }
245 *self = Self::Poison(Arc::clone(&err));
246 err
247 }
248 Self::Poison(e) => {
249 Arc::clone(e)
251 }
252 }
253 }
254}
255
256#[derive(Debug)]
257pub struct ServerConnectionInner<RW> {
258 stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
262
263 client_id: Arc<str>,
264
265 request_id: AtomicI32,
266
267 state: Arc<Mutex<ConnectionState>>,
268
269 join_handle: JoinHandle<()>,
270}
271
272impl<RW> ServerConnectionInner<RW>
273where
274 RW: AsyncRead + AsyncWrite + Send + 'static,
275{
276 pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
277 let (stream_read, stream_write) = tokio::io::split(stream);
278 let state = Arc::new(Mutex::new(ConnectionState::RequestMap(HashMap::default())));
279 let state_captured = Arc::clone(&state);
280
281 let join_handle = tokio::spawn(async move {
282 let mut stream_read = stream_read;
283 loop {
284 match stream_read.read_message(max_message_size).await {
285 Ok(msg) => {
286 let mut cursor = Cursor::new(msg);
288 let header =
289 match ResponseHeader::read_versioned(&mut cursor, ApiVersion(0)) {
290 Ok(header) => header,
291 Err(err) => {
292 log::warn!(
293 "Cannot read message header, ignoring message: {err:?}"
294 );
295 continue;
296 }
297 };
298
299 let active_request = match state_captured.lock().deref_mut() {
300 ConnectionState::RequestMap(map) => {
301 match map.remove(&header.request_id) {
302 Some(active_request) => active_request,
303 _ => {
304 log::warn!(
305 request_id:% = header.request_id;
306 "Got response for unknown request",
307 );
308 continue;
309 }
310 }
311 }
312 ConnectionState::Poison(_) => {
313 return;
315 }
316 };
317
318 active_request
320 .channel
321 .send(Ok(Response {
322 header,
323 data: cursor,
324 }))
325 .ok();
326 }
327 Err(e) => {
328 state_captured.lock().poison(RpcError::ReadMessageError(e));
329 return;
330 }
331 }
332 }
333 });
334
335 Self {
336 stream_write: Arc::new(AsyncMutex::new(stream_write)),
337 client_id,
338 request_id: AtomicI32::new(0),
339 state,
340 join_handle,
341 }
342 }
343
344 fn is_poisoned(&self) -> bool {
345 let guard = self.state.lock();
346 matches!(*guard, ConnectionState::Poison(_))
347 }
348
349 pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, Error>
350 where
351 R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
352 R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
353 {
354 let request_id = self.request_id.fetch_add(1, Ordering::SeqCst) & 0x7FFFFFFF;
355 let header = RequestHeader {
356 request_api_key: R::API_KEY,
357 request_api_version: ApiVersion(0),
358 request_id,
359 client_id: Some(String::from(self.client_id.as_ref())),
360 };
361
362 let header_version = ApiVersion(0);
363
364 let body_api_version = ApiVersion(0);
365
366 let mut buf = Vec::new();
367 header
369 .write_versioned(&mut buf, header_version)
370 .map_err(RpcError::WriteMessageError)?;
371 msg.write_versioned(&mut buf, body_api_version)
373 .map_err(RpcError::WriteMessageError)?;
374
375 let (tx, rx) = channel();
376
377 let _cleanup_on_cancel =
380 CleanupRequestStateOnCancel::new(Arc::clone(&self.state), request_id);
381
382 match self.state.lock().deref_mut() {
383 ConnectionState::RequestMap(map) => {
384 map.insert(request_id, ActiveRequest { channel: tx });
385 }
386 ConnectionState::Poison(e) => return Err(RpcError::Poisoned(Arc::clone(e)).into()),
387 }
388
389 self.send_message(buf).await?;
390 _cleanup_on_cancel.message_sent();
391 let mut response = rx.await.map_err(|e| Error::UnexpectedError {
392 message: "Got recvError, some one close the channel".to_string(),
393 source: Some(Box::new(e)),
394 })??;
395
396 if let Some(error_response) = response.header.error_response {
397 return Err(Error::FlussAPIError {
398 api_error: crate::rpc::ApiError::from(error_response),
399 });
400 }
401
402 let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)
403 .map_err(RpcError::ReadMessageError)?;
404
405 let read_bytes = response.data.position();
406 let message_bytes = response.data.into_inner().len() as u64;
407 if read_bytes != message_bytes {
408 return Err(RpcError::TooMuchData {
409 message_size: message_bytes,
410 read: read_bytes,
411 api_key: R::API_KEY,
412 api_version: body_api_version,
413 }
414 .into());
415 }
416 Ok(body)
417 }
418
419 async fn send_message(&self, msg: Vec<u8>) -> Result<(), RpcError> {
420 match self.send_message_inner(msg).await {
421 Ok(()) => Ok(()),
422 Err(e) => {
423 let mut state = self.state.lock();
425 Err(RpcError::Poisoned(state.poison(e)))
426 }
427 }
428 }
429
430 async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RpcError> {
431 let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;
432
433 let fut = CancellationSafeFuture::new(async move {
435 stream_write.write_message(&msg).await?;
436 stream_write.flush().await?;
437 Ok(())
438 });
439
440 fut.await
441 }
442}
443
444impl<RW> Drop for ServerConnectionInner<RW> {
445 fn drop(&mut self) {
446 self.join_handle.abort();
448 }
449}
450
451struct CancellationSafeFuture<F>
452where
453 F: Future + Send + 'static,
454{
455 done: bool,
457
458 inner: Option<BoxFuture<'static, F::Output>>,
465}
466
467impl<F> CancellationSafeFuture<F>
468where
469 F: Future + Send,
470{
471 fn new(fut: F) -> Self {
472 Self {
473 done: false,
474 inner: Some(Box::pin(fut)),
475 }
476 }
477}
478
479impl<F> Future for CancellationSafeFuture<F>
480where
481 F: Future + Send,
482{
483 type Output = F::Output;
484
485 fn poll(
486 mut self: std::pin::Pin<&mut Self>,
487 cx: &mut std::task::Context<'_>,
488 ) -> Poll<Self::Output> {
489 let inner = self
490 .inner
491 .as_mut()
492 .expect("CancellationSafeFuture polled after completion");
493
494 match inner.as_mut().poll(cx) {
495 Poll::Ready(res) => {
496 self.done = true;
497 self.inner = None; Poll::Ready(res)
499 }
500 Poll::Pending => Poll::Pending,
501 }
502 }
503}
504
505impl<F> Drop for CancellationSafeFuture<F>
506where
507 F: Future + Send + 'static,
508{
509 fn drop(&mut self) {
510 if let Some(fut) = self.inner.take() {
513 if let Ok(handle) = tokio::runtime::Handle::try_current() {
516 handle.spawn(async move {
517 let _ = fut.await;
518 });
519 } else {
520 warn!("Tokio runtime not found during drop; background task cancelled.");
525 }
526 }
527 }
528}
529
530struct CleanupRequestStateOnCancel {
532 state: Arc<Mutex<ConnectionState>>,
533 request_id: i32,
534 message_sent: bool,
535}
536
537impl CleanupRequestStateOnCancel {
538 fn new(state: Arc<Mutex<ConnectionState>>, request_id: i32) -> Self {
542 Self {
543 state,
544 request_id,
545 message_sent: false,
546 }
547 }
548
549 fn message_sent(mut self) {
551 self.message_sent = true;
552 }
553}
554
555impl Drop for CleanupRequestStateOnCancel {
556 fn drop(&mut self) {
557 if !self.message_sent {
558 if let ConnectionState::RequestMap(map) = self.state.lock().deref_mut() {
559 map.remove(&self.request_id);
560 }
561 }
562 }
563}