1use async_trait::async_trait;
20use bytes::Bytes;
21use futures::{SinkExt, TryFutureExt};
22use snap::raw::Decoder;
23use std::backtrace::Backtrace;
24use std::cell::RefCell;
25use std::collections::HashMap;
26use std::io::empty;
27use std::net::SocketAddr;
28use std::pin::pin;
29use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
30use std::sync::Arc;
31use std::thread::spawn;
32use std::{env, mem};
33use tokio::io::{AsyncRead, AsyncWrite, Join, ReadHalf, WriteHalf};
34use tokio::select;
35use tokio::sync::mpsc::unbounded_channel;
36use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
37use tokio::sync::{mpsc, oneshot, Mutex, MutexGuard, RwLock};
38use tokio::task::JoinHandle;
39use tokio_stream::StreamExt;
40use tokio_util::codec::{FramedRead, FramedWrite};
41use tokio_util::sync::{CancellationToken, DropGuard};
42use tracing::{debug, error, info, trace, warn};
43use uuid::Uuid;
44
45use crate::memdx::client_response::ClientResponse;
46use crate::memdx::codec::KeyValueCodec;
47use crate::memdx::connection::{ConnectionType, Stream};
48use crate::memdx::datatype::DataTypeFlag;
49use crate::memdx::dispatcher::{
50 Dispatcher, DispatcherOptions, OnReadLoopCloseHandler, OrphanResponseHandler,
51 UnsolicitedPacketHandler,
52};
53use crate::memdx::error;
54use crate::memdx::error::{CancellationErrorKind, Error};
55use crate::memdx::hello_feature::HelloFeature::DataType;
56use crate::memdx::magic::Magic;
57use crate::memdx::opcode::OpCode;
58use crate::memdx::packet::{RequestPacket, ResponsePacket};
59use crate::memdx::pendingop::ClientPendingOp;
60use crate::memdx::subdoc::SubdocRequestInfo;
61use crate::orphan_reporter::OrphanContext;
62
63pub(crate) type ResponseSender = Sender<error::Result<ClientResponse>>;
64pub(crate) type OpaqueMap = HashMap<u32, SenderContext>;
65
66#[derive(Debug, Clone)]
67pub struct ResponseContext {
68 pub cas: Option<u64>,
69 pub subdoc_info: Option<SubdocRequestInfo>,
70 pub scope_name: Option<String>,
71 pub collection_name: Option<String>,
72}
73
74#[derive(Debug, Clone)]
75pub(crate) struct SenderContext {
76 pub sender: ResponseSender,
77 pub is_persistent: bool,
78 pub context: Option<ResponseContext>,
79}
80
81struct ReadLoopOptions {
82 pub client_id: String,
83 pub unsolicited_packet_handler: UnsolicitedPacketHandler,
84 pub orphan_handler: Option<OrphanResponseHandler>,
85 pub on_read_close_handler: OnReadLoopCloseHandler,
86 pub on_close_cancel: CancellationToken,
87 pub disable_decompression: bool,
88 pub local_addr: SocketAddr,
89 pub peer_addr: SocketAddr,
90}
91
92#[derive(Debug)]
93struct ClientReadHandle {
94 read_handle: JoinHandle<()>,
95}
96
97impl ClientReadHandle {
98 pub async fn await_completion(&mut self) {
99 (&mut self.read_handle).await.unwrap_or_default()
100 }
101}
102
103#[derive(Debug)]
104pub struct Client {
105 current_opaque: AtomicU32,
106 opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
107
108 client_id: String,
109
110 writer: Mutex<FramedWrite<WriteHalf<Box<dyn Stream>>, KeyValueCodec>>,
111 on_close_cancel: DropGuard,
112
113 local_addr: SocketAddr,
114 peer_addr: SocketAddr,
115
116 closed: AtomicBool,
117}
118
119impl Client {
120 fn register_handler(&self, response_context: SenderContext) -> u32 {
121 let mut map = self.opaque_map.lock().unwrap();
122
123 let opaque = self.current_opaque.fetch_add(1, Ordering::SeqCst);
124
125 map.insert(opaque, response_context);
126
127 opaque
128 }
129
130 async fn drain_opaque_map(opaque_map: Arc<std::sync::Mutex<OpaqueMap>>) {
131 let mut senders = vec![];
132 {
133 let mut guard = opaque_map.lock().unwrap();
134 guard.drain().for_each(|(_, v)| {
135 senders.push(v);
136 });
137 }
138
139 for sender in senders {
140 sender
141 .sender
142 .send(Err(Error::new_cancelled_error(
143 CancellationErrorKind::ClosedInFlight,
144 )))
145 .await
146 .unwrap_or_default();
147 }
148 }
149
150 async fn on_read_loop_close(
151 client_id: &str,
152 stream: FramedRead<ReadHalf<Box<dyn Stream>>, KeyValueCodec>,
153 opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
154 on_read_loop_close: OnReadLoopCloseHandler,
155 ) {
156 drop(stream);
157
158 Self::drain_opaque_map(opaque_map).await;
159
160 if on_read_loop_close.send(()).is_err() {
161 error!("{} failed to notify read loop closure", &client_id);
162 }
163
164 debug!("{client_id} read loop shut down");
165 }
166
167 async fn read_loop(
168 mut stream: FramedRead<ReadHalf<Box<dyn Stream>>, KeyValueCodec>,
169 opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
170 mut opts: ReadLoopOptions,
171 ) {
172 loop {
173 select! {
174 (_) = opts.on_close_cancel.cancelled() => {
175 Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
176 return;
177 },
178 (next) = stream.next() => {
179 match next {
180 Some(input) => {
181 match input {
182 Ok(mut packet) => {
183 if packet.magic == Magic::ServerReq {
184
185 trace!(
186 "Handling server request on {}. Opcode={}",
187 opts.client_id,
188 packet.op_code,
189 );
190
191 (opts.unsolicited_packet_handler)(packet).await;
192 continue;
193 }
194
195 trace!(
196 "Resolving response on {}. Opcode={}. Opaque={}. Status={}",
197 opts.client_id,
198 packet.op_code,
199 packet.opaque,
200 packet.status,
201 );
202
203 let opaque = packet.opaque;
204
205 let requests: Arc<std::sync::Mutex<OpaqueMap>> = Arc::clone(&opaque_map);
206 let context = {
207 let mut map = requests.lock().unwrap();
208 map.remove(&opaque)
209 };
210
211 if let Some(mut context) = context {
212 let sender = &context.sender;
213
214 if let Some(value) = &packet.value {
215 if !opts.disable_decompression && (packet.datatype & u8::from(DataTypeFlag::Compressed) != 0) {
216 let mut decoder = Decoder::new();
217 let new_value = match decoder
218 .decompress_vec(value)
219 {
220 Ok(v) => v,
221 Err(e) => {
222 match sender.send(Err(Error::new_decompression_error().with(e))).await{
223 Ok(_) => {}
224 Err(e) => {
225 debug!("Sending response to caller failed: {e}");
226 }
227 };
228 continue;
229 }
230 };
231
232 packet.datatype &= !u8::from(DataTypeFlag::Compressed);
233 packet.value = Some(Bytes::from(new_value));
234 }
235 }
236
237 if context.is_persistent {
238 {
239 let mut map = requests.lock().unwrap();
240 map.insert(opaque, context.clone());
241 }
242 }
243
244 let resp = ClientResponse::new(packet, context.context);
245 match sender.send(Ok(resp)).await {
246 Ok(_) => {}
247 Err(e) => {
248 debug!("Sending response to caller failed: {e}");
249 Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
250 return;
251 }
252 };
253 } else if let Some(ref orphan_handler) = opts.orphan_handler {
254 orphan_handler(
255 packet,
256 OrphanContext {
257 client_id: opts.client_id.clone(),
258 local_addr: opts.local_addr,
259 peer_addr: opts.peer_addr,
260 },
261 );
262 }
263 drop(requests);
264 }
265 Err(e) => {
266 warn!("{} failed to read frame {}", opts.client_id, e);
267 Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
268 return;
269 }
270 }
271 }
272 None => {
273 Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_read_close_handler).await;
274 return;
275 }
276 }
277 }
278 }
279 }
280 }
281
282 fn split_stream<StreamType: AsyncRead + AsyncWrite + Send + Unpin>(
283 stream: StreamType,
284 ) -> (ReadHalf<StreamType>, WriteHalf<StreamType>) {
285 tokio::io::split(stream)
286 }
287}
288
289struct DispatchOpaqueGuard {
293 opaque: u32,
294 opaque_map: Option<Arc<std::sync::Mutex<OpaqueMap>>>,
295}
296
297impl DispatchOpaqueGuard {
298 fn new(opaque: u32, opaque_map: Arc<std::sync::Mutex<OpaqueMap>>) -> Self {
299 Self {
300 opaque,
301 opaque_map: Some(opaque_map),
302 }
303 }
304
305 fn disarm(&mut self) {
307 self.opaque_map = None;
308 }
309}
310
311impl Drop for DispatchOpaqueGuard {
312 fn drop(&mut self) {
313 if let Some(opaque_map) = self.opaque_map.take() {
314 let mut map = opaque_map.lock().unwrap();
315 map.remove(&self.opaque);
316 }
317 }
318}
319
320#[async_trait]
321impl Dispatcher for Client {
322 fn new(conn: ConnectionType, opts: DispatcherOptions) -> Self {
323 let local_addr = *conn.local_addr();
324 let peer_addr = *conn.peer_addr();
325
326 let (r, w) = tokio::io::split(conn.into_inner());
327
328 let codec = KeyValueCodec::default();
329 let reader = FramedRead::new(r, codec);
330 let writer = FramedWrite::new(w, codec);
331
332 let cancel_token = CancellationToken::new();
333 let cancel_child = cancel_token.child_token();
334 let cancel_guard = cancel_token.drop_guard();
335
336 let opaque_map = Arc::new(std::sync::Mutex::new(OpaqueMap::default()));
337
338 let read_opaque_map = Arc::clone(&opaque_map);
339 let read_uuid = opts.id.clone();
340
341 tokio::spawn(async move {
342 Client::read_loop(
343 reader,
344 read_opaque_map,
345 ReadLoopOptions {
346 client_id: read_uuid,
347 unsolicited_packet_handler: opts.unsolicited_packet_handler,
348 orphan_handler: opts.orphan_handler,
349 on_read_close_handler: opts.on_read_close_tx,
350 on_close_cancel: cancel_child,
351 disable_decompression: opts.disable_decompression,
352 local_addr,
353 peer_addr,
354 },
355 )
356 .await;
357 });
358
359 Self {
360 current_opaque: AtomicU32::new(1),
361 opaque_map,
362 client_id: opts.id,
363
364 on_close_cancel: cancel_guard,
365
366 writer: Mutex::new(writer),
367
368 local_addr,
369 peer_addr,
370
371 closed: AtomicBool::new(false),
372 }
373 }
374
375 async fn dispatch<'a>(
376 &self,
377 mut packet: RequestPacket<'a>,
378 is_persistent: bool,
379 response_context: Option<ResponseContext>,
380 ) -> error::Result<ClientPendingOp> {
381 let (response_tx, response_rx) = mpsc::channel(1);
382
383 let opaque = self.register_handler(SenderContext {
384 sender: response_tx,
385 is_persistent,
386 context: response_context,
387 });
388 packet.opaque = Some(opaque);
389 let op_code = packet.op_code;
390
391 let mut opaque_guard = DispatchOpaqueGuard::new(opaque, self.opaque_map.clone());
395
396 trace!(
397 "Writing request on {}. Opcode={}. Opaque={}",
398 &self.client_id,
399 packet.op_code,
400 opaque,
401 );
402
403 let mut writer = self.writer.lock().await;
404 match writer.send(packet).await {
405 Ok(_) => {
406 opaque_guard.disarm();
408 Ok(ClientPendingOp::new(
409 opaque,
410 self.opaque_map.clone(),
411 response_rx,
412 is_persistent,
413 ))
414 }
415 Err(e) => {
416 debug!(
417 "{} failed to write packet {} {} {}",
418 self.client_id, opaque, op_code, e
419 );
420
421 Err(Error::new_dispatch_error(opaque, op_code, Box::new(e)))
424 }
425 }
426 }
427
428 async fn close(&self) -> error::Result<()> {
429 if self.closed.swap(true, Ordering::SeqCst) {
430 return Ok(());
431 }
432
433 info!("Closing client {}", self.client_id);
434
435 let mut close_err = None;
436 let mut writer = self.writer.lock().await;
437 match writer.close().await {
438 Ok(_) => {}
439 Err(e) => {
440 close_err = Some(e);
441 }
442 };
443
444 Self::drain_opaque_map(self.opaque_map.clone()).await;
445
446 if let Some(e) = close_err {
447 return Err(Error::new_close_error(e.to_string(), Box::new(e)));
448 }
449
450 Ok(())
451 }
452}
453
454impl Drop for Client {
455 fn drop(&mut self) {
456 info!("Dropping client {}", self.client_id);
457 }
458}