Skip to main content

fluss/rpc/
server_connection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::cluster::ServerNode;
19use crate::error::Error;
20use crate::rpc::api_version::ApiVersion;
21use crate::rpc::error::RpcError;
22use crate::rpc::error::RpcError::ConnectionError;
23use crate::rpc::frame::{AsyncMessageRead, AsyncMessageWrite};
24use crate::rpc::message::{
25    ReadVersionedType, RequestBody, RequestHeader, ResponseHeader, WriteVersionedType,
26};
27use crate::rpc::transport::Transport;
28use futures::future::BoxFuture;
29use log::warn;
30use parking_lot::{Mutex, RwLock};
31use std::collections::HashMap;
32use std::fmt;
33use std::io::Cursor;
34use std::ops::DerefMut;
35use std::sync::Arc;
36use std::sync::atomic::{AtomicI32, Ordering};
37use std::task::Poll;
38use std::time::Duration;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, WriteHalf};
40use tokio::sync::Mutex as AsyncMutex;
41use tokio::sync::oneshot::{Sender, channel};
42use tokio::task::JoinHandle;
43
44pub type MessengerTransport = ServerConnectionInner<BufStream<Transport>>;
45
46pub type ServerConnection = Arc<MessengerTransport>;
47
48// Matches Java's ExponentialBackoff(100ms initial, 2x multiplier, 5000ms max, 0.2 jitter).
49const AUTH_INITIAL_BACKOFF_MS: f64 = 100.0;
50const AUTH_MAX_BACKOFF_MS: f64 = 5000.0;
51const AUTH_BACKOFF_MULTIPLIER: f64 = 2.0;
52const AUTH_JITTER: f64 = 0.2;
53
/// Credentials for the SASL/PLAIN handshake performed on every new connection.
///
/// `Debug` is written by hand so that `{:?}` formatting never reveals the password.
#[derive(Clone)]
pub struct SaslConfig {
    /// Principal name sent to the server.
    pub username: String,
    /// Secret sent to the server; replaced by a placeholder in `Debug` output.
    pub password: String,
}

impl fmt::Debug for SaslConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut builder = f.debug_struct("SaslConfig");
        builder.field("username", &self.username);
        builder.field("password", &REDACTED);
        builder.finish()
    }
}
68
69#[derive(Debug, Default)]
70pub struct RpcClient {
71    connections: RwLock<HashMap<String, ServerConnection>>,
72    client_id: Arc<str>,
73    timeout: Option<Duration>,
74    max_message_size: usize,
75    sasl_config: Option<SaslConfig>,
76}
77
78impl RpcClient {
79    pub fn new() -> Self {
80        RpcClient {
81            connections: Default::default(),
82            client_id: Arc::from(""),
83            timeout: None,
84            max_message_size: usize::MAX,
85            sasl_config: None,
86        }
87    }
88
89    pub fn with_timeout(mut self, timeout: Duration) -> Self {
90        self.timeout = Some(timeout);
91        self
92    }
93
94    pub fn with_sasl(mut self, username: String, password: String) -> Self {
95        self.sasl_config = Some(SaslConfig { username, password });
96        self
97    }
98
99    pub async fn get_connection(
100        &self,
101        server_node: &ServerNode,
102    ) -> Result<ServerConnection, Error> {
103        let server_id = server_node.uid();
104        {
105            let connections = self.connections.read();
106            if let Some(conn) = connections.get(server_id).cloned() {
107                if !conn.is_poisoned() {
108                    return Ok(conn);
109                }
110            }
111        }
112        let new_server = self.connect(server_node).await?;
113        {
114            let mut connections = self.connections.write();
115            if let Some(race_conn) = connections.get(server_id) {
116                if !race_conn.is_poisoned() {
117                    return Ok(race_conn.clone());
118                }
119            }
120
121            connections.insert(server_id.to_owned(), new_server.clone());
122        }
123        Ok(new_server)
124    }
125
126    async fn connect(&self, server_node: &ServerNode) -> Result<ServerConnection, Error> {
127        let url = server_node.url();
128        let transport = Transport::connect(&url, self.timeout)
129            .await
130            .map_err(|error| ConnectionError(error.to_string()))?;
131
132        let messenger = ServerConnectionInner::new(
133            BufStream::new(transport),
134            self.max_message_size,
135            self.client_id.clone(),
136        );
137        let connection = ServerConnection::new(messenger);
138
139        if let Some(ref sasl) = self.sasl_config {
140            Self::authenticate(&connection, &sasl.username, &sasl.password).await?;
141        }
142
143        Ok(connection)
144    }
145
146    /// Perform SASL/PLAIN authentication handshake.
147    ///
148    /// Retries on `RetriableAuthenticateException` with exponential backoff
149    /// (matching Java's unbounded retry behaviour). Non-retriable errors
150    /// (wrong password, unknown user) propagate immediately as
151    /// `Error::FlussAPIError` with the original error code.
152    async fn authenticate(
153        connection: &ServerConnection,
154        username: &str,
155        password: &str,
156    ) -> Result<(), Error> {
157        use crate::rpc::fluss_api_error::FlussError;
158        use crate::rpc::message::AuthenticateRequest;
159        use rand::Rng;
160
161        let initial_request = AuthenticateRequest::new_plain(username, password);
162        let mut retry_count: u32 = 0;
163
164        loop {
165            let request = initial_request.clone();
166            let result = connection.request(request).await;
167
168            match result {
169                Ok(response) => {
170                    // Check for server challenge (multi-round auth).
171                    // PLAIN mechanism never sends a challenge, but we handle it
172                    // for protocol correctness matching Java's handleAuthenticateResponse.
173                    if let Some(challenge) = response.challenge {
174                        let challenge_req = AuthenticateRequest::from_challenge("PLAIN", challenge);
175                        connection.request(challenge_req).await?;
176                    }
177                    return Ok(());
178                }
179                Err(Error::FlussAPIError { ref api_error })
180                    if FlussError::for_code(api_error.code)
181                        == FlussError::RetriableAuthenticateException =>
182                {
183                    retry_count += 1;
184                    // Cap the exponent like Java's ExponentialBackoff.expMax so that
185                    // jitter still produces a range at steady state instead of being
186                    // clamped to AUTH_MAX_BACKOFF_MS.
187                    let exp_max = (AUTH_MAX_BACKOFF_MS / AUTH_INITIAL_BACKOFF_MS).log2();
188                    let exp = ((retry_count as f64) - 1.0).min(exp_max);
189                    let term = AUTH_INITIAL_BACKOFF_MS * AUTH_BACKOFF_MULTIPLIER.powf(exp);
190                    let jitter_factor =
191                        1.0 - AUTH_JITTER + rand::rng().random::<f64>() * (2.0 * AUTH_JITTER);
192                    let backoff_ms = (term * jitter_factor) as u64;
193                    log::warn!(
194                        "SASL authentication retriable failure (attempt {retry_count}), \
195                         retrying in {backoff_ms}ms: {}",
196                        api_error.message
197                    );
198                    tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
199                }
200                // Server-side auth errors (wrong password, unknown user, etc.)
201                // propagate with their original error code preserved.
202                Err(e) => return Err(e),
203            }
204        }
205    }
206}
207
208#[derive(Debug)]
209struct Response {
210    #[allow(dead_code)]
211    header: ResponseHeader,
212    data: Cursor<Vec<u8>>,
213}
214
215#[derive(Debug)]
216struct ActiveRequest {
217    channel: Sender<Result<Response, RpcError>>,
218}
219
220#[derive(Debug)]
221enum ConnectionState {
222    /// Currently active requests by request ID.
223    ///
224    /// An active request is one that got prepared or send but the response wasn't received yet.
225    RequestMap(HashMap<i32, ActiveRequest>),
226
227    /// One or our streams died and we are unable to process any more requests.
228    Poison(Arc<RpcError>),
229}
230
231impl ConnectionState {
232    fn poison(&mut self, err: RpcError) -> Arc<RpcError> {
233        match self {
234            Self::RequestMap(map) => {
235                let err = Arc::new(err);
236
237                // inform all active requests
238                for (_request_id, active_request) in map.drain() {
239                    // it's OK if the other side is gone
240                    active_request
241                        .channel
242                        .send(Err(RpcError::Poisoned(Arc::clone(&err))))
243                        .ok();
244                }
245                *self = Self::Poison(Arc::clone(&err));
246                err
247            }
248            Self::Poison(e) => {
249                // already poisoned, used existing error
250                Arc::clone(e)
251            }
252        }
253    }
254}
255
256#[derive(Debug)]
257pub struct ServerConnectionInner<RW> {
258    /// The half of the stream that we use to send data TO the broker.
259    ///
260    /// This will be used by [`request`](Self::request) to queue up messages.
261    stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
262
263    client_id: Arc<str>,
264
265    request_id: AtomicI32,
266
267    state: Arc<Mutex<ConnectionState>>,
268
269    join_handle: JoinHandle<()>,
270}
271
272impl<RW> ServerConnectionInner<RW>
273where
274    RW: AsyncRead + AsyncWrite + Send + 'static,
275{
276    pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
277        let (stream_read, stream_write) = tokio::io::split(stream);
278        let state = Arc::new(Mutex::new(ConnectionState::RequestMap(HashMap::default())));
279        let state_captured = Arc::clone(&state);
280
281        let join_handle = tokio::spawn(async move {
282            let mut stream_read = stream_read;
283            loop {
284                match stream_read.read_message(max_message_size).await {
285                    Ok(msg) => {
286                        // message was read, so all subsequent errors should not poison the whole stream
287                        let mut cursor = Cursor::new(msg);
288                        let header =
289                            match ResponseHeader::read_versioned(&mut cursor, ApiVersion(0)) {
290                                Ok(header) => header,
291                                Err(err) => {
292                                    log::warn!(
293                                        "Cannot read message header, ignoring message: {err:?}"
294                                    );
295                                    continue;
296                                }
297                            };
298
299                        let active_request = match state_captured.lock().deref_mut() {
300                            ConnectionState::RequestMap(map) => {
301                                match map.remove(&header.request_id) {
302                                    Some(active_request) => active_request,
303                                    _ => {
304                                        log::warn!(
305                                            request_id:% = header.request_id;
306                                            "Got response for unknown request",
307                                        );
308                                        continue;
309                                    }
310                                }
311                            }
312                            ConnectionState::Poison(_) => {
313                                // stream is poisoned, no need to anything
314                                return;
315                            }
316                        };
317
318                        // we don't care if the other side is gone
319                        active_request
320                            .channel
321                            .send(Ok(Response {
322                                header,
323                                data: cursor,
324                            }))
325                            .ok();
326                    }
327                    Err(e) => {
328                        state_captured.lock().poison(RpcError::ReadMessageError(e));
329                        return;
330                    }
331                }
332            }
333        });
334
335        Self {
336            stream_write: Arc::new(AsyncMutex::new(stream_write)),
337            client_id,
338            request_id: AtomicI32::new(0),
339            state,
340            join_handle,
341        }
342    }
343
344    fn is_poisoned(&self) -> bool {
345        let guard = self.state.lock();
346        matches!(*guard, ConnectionState::Poison(_))
347    }
348
349    pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, Error>
350    where
351        R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
352        R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
353    {
354        let request_id = self.request_id.fetch_add(1, Ordering::SeqCst) & 0x7FFFFFFF;
355        let header = RequestHeader {
356            request_api_key: R::API_KEY,
357            request_api_version: ApiVersion(0),
358            request_id,
359            client_id: Some(String::from(self.client_id.as_ref())),
360        };
361
362        let header_version = ApiVersion(0);
363
364        let body_api_version = ApiVersion(0);
365
366        let mut buf = Vec::new();
367        // write header
368        header
369            .write_versioned(&mut buf, header_version)
370            .map_err(RpcError::WriteMessageError)?;
371        // write message body
372        msg.write_versioned(&mut buf, body_api_version)
373            .map_err(RpcError::WriteMessageError)?;
374
375        let (tx, rx) = channel();
376
377        // to prevent stale data in inner state, ensure that we would remove the request again if we are cancelled while
378        // sending the request
379        let _cleanup_on_cancel =
380            CleanupRequestStateOnCancel::new(Arc::clone(&self.state), request_id);
381
382        match self.state.lock().deref_mut() {
383            ConnectionState::RequestMap(map) => {
384                map.insert(request_id, ActiveRequest { channel: tx });
385            }
386            ConnectionState::Poison(e) => return Err(RpcError::Poisoned(Arc::clone(e)).into()),
387        }
388
389        self.send_message(buf).await?;
390        _cleanup_on_cancel.message_sent();
391        let mut response = rx.await.map_err(|e| Error::UnexpectedError {
392            message: "Got recvError, some one close the channel".to_string(),
393            source: Some(Box::new(e)),
394        })??;
395
396        if let Some(error_response) = response.header.error_response {
397            return Err(Error::FlussAPIError {
398                api_error: crate::rpc::ApiError::from(error_response),
399            });
400        }
401
402        let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)
403            .map_err(RpcError::ReadMessageError)?;
404
405        let read_bytes = response.data.position();
406        let message_bytes = response.data.into_inner().len() as u64;
407        if read_bytes != message_bytes {
408            return Err(RpcError::TooMuchData {
409                message_size: message_bytes,
410                read: read_bytes,
411                api_key: R::API_KEY,
412                api_version: body_api_version,
413            }
414            .into());
415        }
416        Ok(body)
417    }
418
419    async fn send_message(&self, msg: Vec<u8>) -> Result<(), RpcError> {
420        match self.send_message_inner(msg).await {
421            Ok(()) => Ok(()),
422            Err(e) => {
423                // need to poison the stream because message framing might be out-of-sync
424                let mut state = self.state.lock();
425                Err(RpcError::Poisoned(state.poison(e)))
426            }
427        }
428    }
429
430    async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RpcError> {
431        let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;
432
433        // use a wrapper so that cancellation doesn't cancel the send operation and leaves half-send messages on the wire
434        let fut = CancellationSafeFuture::new(async move {
435            stream_write.write_message(&msg).await?;
436            stream_write.flush().await?;
437            Ok(())
438        });
439
440        fut.await
441    }
442}
443
444impl<RW> Drop for ServerConnectionInner<RW> {
445    fn drop(&mut self) {
446        // todo: should remove from server_connections map?
447        self.join_handle.abort();
448    }
449}
450
451struct CancellationSafeFuture<F>
452where
453    F: Future + Send + 'static,
454{
455    /// Mark if the inner future finished. If not, we must spawn a helper task on drop.
456    done: bool,
457
458    /// Inner future.
459    ///
460    /// Wrapped in an `Option` so we can extract it during drop. Inside that option however we also need a pinned
461    /// box because once this wrapper is polled, it will be pinned in memory -- even during drop. Now the inner
462    /// future does not necessarily implement `Unpin`, so we need a heap allocation to pin it in memory even when we
463    /// move it out of this option.
464    inner: Option<BoxFuture<'static, F::Output>>,
465}
466
467impl<F> CancellationSafeFuture<F>
468where
469    F: Future + Send,
470{
471    fn new(fut: F) -> Self {
472        Self {
473            done: false,
474            inner: Some(Box::pin(fut)),
475        }
476    }
477}
478
479impl<F> Future for CancellationSafeFuture<F>
480where
481    F: Future + Send,
482{
483    type Output = F::Output;
484
485    fn poll(
486        mut self: std::pin::Pin<&mut Self>,
487        cx: &mut std::task::Context<'_>,
488    ) -> Poll<Self::Output> {
489        let inner = self
490            .inner
491            .as_mut()
492            .expect("CancellationSafeFuture polled after completion");
493
494        match inner.as_mut().poll(cx) {
495            Poll::Ready(res) => {
496                self.done = true;
497                self.inner = None; // Prevent re-polling
498                Poll::Ready(res)
499            }
500            Poll::Pending => Poll::Pending,
501        }
502    }
503}
504
505impl<F> Drop for CancellationSafeFuture<F>
506where
507    F: Future + Send + 'static,
508{
509    fn drop(&mut self) {
510        // If the future hasn't finished yet, we must ensure it completes in the background.
511        // This prevents leaving half-sent messages on the wire if the caller cancels the request.
512        if let Some(fut) = self.inner.take() {
513            // Attempt to get a handle to the current Tokio runtime.
514            // This avoids a panic if the runtime has already shut down.
515            if let Ok(handle) = tokio::runtime::Handle::try_current() {
516                handle.spawn(async move {
517                    let _ = fut.await;
518                });
519            } else {
520                // Fallback: If no runtime is active, we cannot spawn.
521                // At this point, the future 'fut' will be dropped.
522                // Since the runtime is likely shutting down anyway,
523                // the underlying connection is probably being closed.
524                warn!("Tokio runtime not found during drop; background task cancelled.");
525            }
526        }
527    }
528}
529
530/// Helper that ensures that a request is removed when a request is cancelled before it was actually sent out.
531struct CleanupRequestStateOnCancel {
532    state: Arc<Mutex<ConnectionState>>,
533    request_id: i32,
534    message_sent: bool,
535}
536
537impl CleanupRequestStateOnCancel {
538    /// Create new helper.
539    ///
540    /// You must call [`message_sent`](Self::message_sent) when the request was sent.
541    fn new(state: Arc<Mutex<ConnectionState>>, request_id: i32) -> Self {
542        Self {
543            state,
544            request_id,
545            message_sent: false,
546        }
547    }
548
549    /// Request was sent. Do NOT clean the state any longer.
550    fn message_sent(mut self) {
551        self.message_sent = true;
552    }
553}
554
555impl Drop for CleanupRequestStateOnCancel {
556    fn drop(&mut self) {
557        if !self.message_sent {
558            if let ConnectionState::RequestMap(map) = self.state.lock().deref_mut() {
559                map.remove(&self.request_id);
560            }
561        }
562    }
563}