Skip to main content

couchbase_core/memdx/
client.rs

1/*
2 *
3 *  * Copyright (c) 2025 Couchbase, Inc.
4 *  *
5 *  * Licensed under the Apache License, Version 2.0 (the "License");
6 *  * you may not use this file except in compliance with the License.
7 *  * You may obtain a copy of the License at
8 *  *
9 *  *    http://www.apache.org/licenses/LICENSE-2.0
10 *  *
11 *  * Unless required by applicable law or agreed to in writing, software
12 *  * distributed under the License is distributed on an "AS IS" BASIS,
13 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  * See the License for the specific language governing permissions and
15 *  * limitations under the License.
16 *
17 */
18
19use async_trait::async_trait;
20use bytes::Bytes;
21use futures::{SinkExt, TryFutureExt};
22use snap::raw::Decoder;
23use std::backtrace::Backtrace;
24use std::cell::RefCell;
25use std::collections::HashMap;
26use std::io::empty;
27use std::net::SocketAddr;
28use std::pin::pin;
29use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
30use std::sync::Arc;
31use std::thread::spawn;
32use std::{env, mem};
33use tokio::io::{AsyncRead, AsyncWrite, Join, ReadHalf, WriteHalf};
34use tokio::select;
35use tokio::sync::mpsc::unbounded_channel;
36use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
37use tokio::sync::{mpsc, oneshot, Mutex, MutexGuard, RwLock};
38use tokio::task::JoinHandle;
39use tokio_stream::StreamExt;
40use tokio_util::codec::{FramedRead, FramedWrite};
41use tokio_util::sync::{CancellationToken, DropGuard};
42use tracing::{debug, error, info, trace, warn};
43use uuid::Uuid;
44
45use crate::memdx::client_response::ClientResponse;
46use crate::memdx::codec::KeyValueCodec;
47use crate::memdx::connection::{ConnectionType, Stream};
48use crate::memdx::datatype::DataTypeFlag;
49use crate::memdx::dispatcher::{
50    Dispatcher, DispatcherOptions, OnReadLoopCloseHandler, OrphanResponseHandler,
51    UnsolicitedPacketHandler,
52};
53use crate::memdx::error;
54use crate::memdx::error::{CancellationErrorKind, Error};
55use crate::memdx::hello_feature::HelloFeature::DataType;
56use crate::memdx::magic::Magic;
57use crate::memdx::opcode::OpCode;
58use crate::memdx::packet::{RequestPacket, ResponsePacket};
59use crate::memdx::pendingop::ClientPendingOp;
60use crate::memdx::subdoc::SubdocRequestInfo;
61use crate::orphan_reporter::OrphanContext;
62
/// Sending half of the tokio mpsc channel on which the read loop delivers the
/// response (or an error) for a registered request.
63pub(crate) type ResponseSender = Sender<error::Result<ClientResponse>>;
/// Table of registered requests, keyed by the opaque written into the request packet.
64pub(crate) type OpaqueMap = HashMap<u32, SenderContext>;
65
/// Optional per-request metadata supplied by the caller of `dispatch`.
///
/// It is stored with the request's registration and handed to `ClientResponse::new`
/// together with the response packet when a response arrives.
66#[derive(Debug, Clone)]
67pub struct ResponseContext {
    /// CAS value associated with the request, if any.
    /// NOTE(review): not consumed in this file — confirm its meaning against `ClientResponse`.
68    pub cas: Option<u64>,
    /// Sub-document request details, if any (consumer not visible in this file).
69    pub subdoc_info: Option<SubdocRequestInfo>,
    /// Presumably the scope the request targets, if any — verify against callers.
70    pub scope_name: Option<String>,
    /// Presumably the collection the request targets, if any — verify against callers.
71    pub collection_name: Option<String>,
72}
73
/// Value stored in the opaque map for every registered request.
74#[derive(Debug, Clone)]
75pub(crate) struct SenderContext {
    /// Channel on which the read loop delivers responses (or errors) for this request.
76    pub sender: ResponseSender,
    /// When `true`, the read loop puts the entry back into the map after a response,
    /// so later packets carrying the same opaque are delivered as well.
77    pub is_persistent: bool,
    /// Caller-supplied metadata passed to `ClientResponse::new` with each response.
78    pub context: Option<ResponseContext>,
79}
80
/// Everything the read loop needs besides the stream and the opaque map.
81struct ReadLoopOptions {
    /// Identifier used in log lines and forwarded in `OrphanContext`.
82    pub client_id: String,
    /// Invoked (and awaited) for every packet whose magic is `Magic::ServerReq`.
83    pub unsolicited_packet_handler: UnsolicitedPacketHandler,
    /// Invoked for responses whose opaque has no entry in the opaque map; such
    /// responses are dropped silently when this is `None`.
84    pub orphan_handler: Option<OrphanResponseHandler>,
    /// Notified with `send(())` once the read loop has shut down.
85    pub on_read_close_handler: OnReadLoopCloseHandler,
    /// Cancelling this token makes the read loop shut down.
86    pub on_close_cancel: CancellationToken,
    /// When `true`, values flagged `DataTypeFlag::Compressed` are passed through
    /// without being decompressed.
87    pub disable_decompression: bool,
    /// Local socket address, forwarded in `OrphanContext`.
88    pub local_addr: SocketAddr,
    /// Peer socket address, forwarded in `OrphanContext`.
89    pub peer_addr: SocketAddr,
90}
91
92#[derive(Debug)]
93struct ClientReadHandle {
94    read_handle: JoinHandle<()>,
95}
96
97impl ClientReadHandle {
98    pub async fn await_completion(&mut self) {
99        (&mut self.read_handle).await.unwrap_or_default()
100    }
101}
102
/// A memdx client bound to one connection: it writes request packets through a
/// framed writer while a spawned read loop routes responses back to callers via
/// the shared opaque map.
103#[derive(Debug)]
104pub struct Client {
    /// Source of opaques for new requests; initialised to 1 in `new`.
105    current_opaque: AtomicU32,
    /// Registered requests, shared with the read loop and with pending operations.
106    opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
107
    /// Identifier used in log lines.
108    client_id: String,
109
    /// Framed write half of the connection; the async mutex serialises writers.
110    writer: Mutex<FramedWrite<WriteHalf<Box<dyn Stream>>, KeyValueCodec>>,
    /// Drop guard of the cancellation token whose child is watched by the read loop,
    /// so dropping the client cancels the read loop.
111    on_close_cancel: DropGuard,
112
    /// Local socket address captured from the connection in `new`.
113    local_addr: SocketAddr,
    /// Peer socket address captured from the connection in `new`.
114    peer_addr: SocketAddr,
115
    /// Set by the first call to `close`; makes later calls a no-op.
116    closed: AtomicBool,
117}
118
119impl Client {
120    fn register_handler(&self, response_context: SenderContext) -> u32 {
121        let mut map = self.opaque_map.lock().unwrap();
122
123        let opaque = self.current_opaque.fetch_add(1, Ordering::SeqCst);
124
125        map.insert(opaque, response_context);
126
127        opaque
128    }
129
130    async fn drain_opaque_map(opaque_map: Arc<std::sync::Mutex<OpaqueMap>>) {
131        let mut senders = vec![];
132        {
133            let mut guard = opaque_map.lock().unwrap();
134            guard.drain().for_each(|(_, v)| {
135                senders.push(v);
136            });
137        }
138
139        for sender in senders {
140            sender
141                .sender
142                .send(Err(Error::new_cancelled_error(
143                    CancellationErrorKind::ClosedInFlight,
144                )))
145                .await
146                .unwrap_or_default();
147        }
148    }
149
150    async fn on_read_loop_close(
151        client_id: &str,
152        stream: FramedRead<ReadHalf<Box<dyn Stream>>, KeyValueCodec>,
153        opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
154        on_read_loop_close: OnReadLoopCloseHandler,
155    ) {
156        drop(stream);
157
158        Self::drain_opaque_map(opaque_map).await;
159
160        if on_read_loop_close.send(()).is_err() {
161            error!("{} failed to notify read loop closure", &client_id);
162        }
163
164        debug!("{client_id} read loop shut down");
165    }
166
167    async fn read_loop(
168        mut stream: FramedRead<ReadHalf<Box<dyn Stream>>, KeyValueCodec>,
169        opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
170        mut opts: ReadLoopOptions,
171    ) {
172        loop {
173            select! {
174                (_) = opts.on_close_cancel.cancelled() => {
175                    Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
176                    return;
177                },
178                (next) = stream.next() => {
179                    match next {
180                        Some(input) => {
181                            match input {
182                                Ok(mut packet) => {
183                                    if packet.magic == Magic::ServerReq {
184
185                                        trace!(
186                                            "Handling server request on {}. Opcode={}",
187                                            opts.client_id,
188                                            packet.op_code,
189                                        );
190
191                                        (opts.unsolicited_packet_handler)(packet).await;
192                                        continue;
193                                    }
194
195                                    trace!(
196                                        "Resolving response on {}. Opcode={}. Opaque={}. Status={}",
197                                        opts.client_id,
198                                        packet.op_code,
199                                        packet.opaque,
200                                        packet.status,
201                                    );
202
203                                    let opaque = packet.opaque;
204
205                                    let requests: Arc<std::sync::Mutex<OpaqueMap>> = Arc::clone(&opaque_map);
206                                    let context = {
207                                        let mut map = requests.lock().unwrap();
208                                        map.remove(&opaque)
209                                    };
210
211                                    if let Some(mut context) = context {
212                                        let sender = &context.sender;
213
214                                        if let Some(value) = &packet.value {
215                                            if !opts.disable_decompression && (packet.datatype & u8::from(DataTypeFlag::Compressed) != 0) {
216                                                let mut decoder = Decoder::new();
217                                                let new_value = match decoder
218                                                    .decompress_vec(value)
219                                                     {
220                                                        Ok(v) => v,
221                                                        Err(e) => {
222                                                            match sender.send(Err(Error::new_decompression_error().with(e))).await{
223                                                                Ok(_) => {}
224                                                                Err(e) => {
225                                                                     debug!("Sending response to caller failed: {e}");
226                                                                }
227                                                            };
228                                                         continue;
229                                                        }
230                                                    };
231
232                                                packet.datatype &= !u8::from(DataTypeFlag::Compressed);
233                                                packet.value = Some(Bytes::from(new_value));
234                                            }
235                                        }
236
237                                        if context.is_persistent {
238                                            {
239                                                let mut map = requests.lock().unwrap();
240                                                map.insert(opaque, context.clone());
241                                            }
242                                        }
243
244                                        let resp = ClientResponse::new(packet, context.context);
245                                        match sender.send(Ok(resp)).await {
246                                            Ok(_) => {}
247                                            Err(e) => {
248                                                debug!("Sending response to caller failed: {e}");
249                                                Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
250                                                return;
251                                            }
252                                        };
253                                    } else if let Some(ref orphan_handler) = opts.orphan_handler {
254                                        orphan_handler(
255                                            packet,
256                                            OrphanContext {
257                                                client_id: opts.client_id.clone(),
258                                                local_addr: opts.local_addr,
259                                                peer_addr: opts.peer_addr,
260                                            },
261                                        );
262                                    }
263                                    drop(requests);
264                                }
265                                Err(e) => {
266                                    warn!("{} failed to read frame {}", opts.client_id, e);
267                                    Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
268                                    return;
269                                }
270                            }
271                        }
272                        None => {
273                            Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
274                            return;
275                        }
276                    }
277                }
278            }
279        }
280    }
281
282    fn split_stream<StreamType: AsyncRead + AsyncWrite + Send + Unpin>(
283        stream: StreamType,
284    ) -> (ReadHalf<StreamType>, WriteHalf<StreamType>) {
285        tokio::io::split(stream)
286    }
287}
288
289/// A guard that removes an opaque entry from the map when dropped, unless disarmed.
290/// This prevents opaque map leaks if the dispatch future is cancelled/dropped at an
291/// `.await` point before a `ClientPendingOp` is created to take over cleanup.
292struct DispatchOpaqueGuard {
293    opaque: u32,
294    opaque_map: Option<Arc<std::sync::Mutex<OpaqueMap>>>,
295}
296
297impl DispatchOpaqueGuard {
298    fn new(opaque: u32, opaque_map: Arc<std::sync::Mutex<OpaqueMap>>) -> Self {
299        Self {
300            opaque,
301            opaque_map: Some(opaque_map),
302        }
303    }
304
305    /// Disarm the guard so that dropping it will not remove the opaque entry.
306    fn disarm(&mut self) {
307        self.opaque_map = None;
308    }
309}
310
311impl Drop for DispatchOpaqueGuard {
312    fn drop(&mut self) {
313        if let Some(opaque_map) = self.opaque_map.take() {
314            let mut map = opaque_map.lock().unwrap();
315            map.remove(&self.opaque);
316        }
317    }
318}
319
320#[async_trait]
321impl Dispatcher for Client {
322    fn new(conn: ConnectionType, opts: DispatcherOptions) -> Self {
323        let local_addr = *conn.local_addr();
324        let peer_addr = *conn.peer_addr();
325
326        let (r, w) = tokio::io::split(conn.into_inner());
327
328        let codec = KeyValueCodec::default();
329        let reader = FramedRead::new(r, codec);
330        let writer = FramedWrite::new(w, codec);
331
332        let cancel_token = CancellationToken::new();
333        let cancel_child = cancel_token.child_token();
334        let cancel_guard = cancel_token.drop_guard();
335
336        let opaque_map = Arc::new(std::sync::Mutex::new(OpaqueMap::default()));
337
338        let read_opaque_map = Arc::clone(&opaque_map);
339        let read_uuid = opts.id.clone();
340
341        tokio::spawn(async move {
342            Client::read_loop(
343                reader,
344                read_opaque_map,
345                ReadLoopOptions {
346                    client_id: read_uuid,
347                    unsolicited_packet_handler: opts.unsolicited_packet_handler,
348                    orphan_handler: opts.orphan_handler,
349                    on_read_close_handler: opts.on_read_close_tx,
350                    on_close_cancel: cancel_child,
351                    disable_decompression: opts.disable_decompression,
352                    local_addr,
353                    peer_addr,
354                },
355            )
356            .await;
357        });
358
359        Self {
360            current_opaque: AtomicU32::new(1),
361            opaque_map,
362            client_id: opts.id,
363
364            on_close_cancel: cancel_guard,
365
366            writer: Mutex::new(writer),
367
368            local_addr,
369            peer_addr,
370
371            closed: AtomicBool::new(false),
372        }
373    }
374
375    async fn dispatch<'a>(
376        &self,
377        mut packet: RequestPacket<'a>,
378        is_persistent: bool,
379        response_context: Option<ResponseContext>,
380    ) -> error::Result<ClientPendingOp> {
381        let (response_tx, response_rx) = mpsc::channel(1);
382
383        let opaque = self.register_handler(SenderContext {
384            sender: response_tx,
385            is_persistent,
386            context: response_context,
387        });
388        packet.opaque = Some(opaque);
389        let op_code = packet.op_code;
390
391        // Create a guard that will remove the opaque entry from the map if the future is
392        // dropped before we successfully construct a ClientPendingOp (which takes over
393        // cleanup responsibility via its own Drop impl).
394        let mut opaque_guard = DispatchOpaqueGuard::new(opaque, self.opaque_map.clone());
395
396        trace!(
397            "Writing request on {}. Opcode={}. Opaque={}",
398            &self.client_id,
399            packet.op_code,
400            opaque,
401        );
402
403        let mut writer = self.writer.lock().await;
404        match writer.send(packet).await {
405            Ok(_) => {
406                // Disarm the guard — the ClientPendingOp now owns cleanup responsibility.
407                opaque_guard.disarm();
408                Ok(ClientPendingOp::new(
409                    opaque,
410                    self.opaque_map.clone(),
411                    response_rx,
412                    is_persistent,
413                ))
414            }
415            Err(e) => {
416                debug!(
417                    "{} failed to write packet {} {} {}",
418                    self.client_id, opaque, op_code, e
419                );
420
421                // opaque_guard will remove the entry from the opaque map when dropped.
422
423                Err(Error::new_dispatch_error(opaque, op_code, Box::new(e)))
424            }
425        }
426    }
427
428    async fn close(&self) -> error::Result<()> {
429        if self.closed.swap(true, Ordering::SeqCst) {
430            return Ok(());
431        }
432
433        info!("Closing client {}", self.client_id);
434
435        let mut close_err = None;
436        let mut writer = self.writer.lock().await;
437        match writer.close().await {
438            Ok(_) => {}
439            Err(e) => {
440                close_err = Some(e);
441            }
442        };
443
444        Self::drain_opaque_map(self.opaque_map.clone()).await;
445
446        if let Some(e) = close_err {
447            return Err(Error::new_close_error(e.to_string(), Box::new(e)));
448        }
449
450        Ok(())
451    }
452}
453
/// Logs at info level when a client is dropped; this impl does nothing else.
454impl Drop for Client {
455    fn drop(&mut self) {
        // The fields (including the `on_close_cancel` drop guard) are dropped by the
        // compiler-generated drop glue after this body has run.
456        info!("Dropping client {}", self.client_id);
457    }
458}