1use async_trait::async_trait;
20use bytes::Bytes;
21use futures::{SinkExt, TryFutureExt};
22use log::{debug, error, trace, warn};
23use snap::raw::Decoder;
24use std::backtrace::Backtrace;
25use std::cell::RefCell;
26use std::collections::HashMap;
27use std::io::empty;
28use std::net::SocketAddr;
29use std::pin::pin;
30use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
31use std::sync::Arc;
32use std::thread::spawn;
33use std::{env, mem};
34use tokio::io::{AsyncRead, AsyncWrite, Join, ReadHalf, WriteHalf};
35use tokio::select;
36use tokio::sync::mpsc::unbounded_channel;
37use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
38use tokio::sync::{mpsc, oneshot, Mutex, MutexGuard, RwLock};
39use tokio::task::JoinHandle;
40use tokio_stream::StreamExt;
41use tokio_util::codec::{FramedRead, FramedWrite};
42use tokio_util::sync::{CancellationToken, DropGuard};
43use uuid::Uuid;
44
45use crate::memdx::client_response::ClientResponse;
46use crate::memdx::codec::KeyValueCodec;
47use crate::memdx::connection::{ConnectionType, Stream};
48use crate::memdx::datatype::DataTypeFlag;
49use crate::memdx::dispatcher::{
50 Dispatcher, DispatcherOptions, OnConnectionCloseHandler, OrphanResponseHandler,
51 UnsolicitedPacketHandler,
52};
53use crate::memdx::error;
54use crate::memdx::error::{CancellationErrorKind, Error};
55use crate::memdx::hello_feature::HelloFeature::DataType;
56use crate::memdx::magic::Magic;
57use crate::memdx::opcode::OpCode;
58use crate::memdx::packet::{RequestPacket, ResponsePacket};
59use crate::memdx::pendingop::ClientPendingOp;
60use crate::memdx::subdoc::SubdocRequestInfo;
61use crate::orphan_reporter::OrphanContext;
62
63pub(crate) type ResponseSender = Sender<error::Result<ClientResponse>>;
64pub(crate) type OpaqueMap = HashMap<u32, Arc<SenderContext>>;
65
66#[derive(Debug, Clone)]
67pub struct ResponseContext {
68 pub cas: Option<u64>,
69 pub subdoc_info: Option<SubdocRequestInfo>,
70 pub is_persistent: bool,
71 pub scope_name: Option<String>,
72 pub collection_name: Option<String>,
73}
74
75#[derive(Debug, Clone)]
76pub(crate) struct SenderContext {
77 pub sender: ResponseSender,
78 pub context: Arc<ResponseContext>,
79}
80
81struct ReadLoopOptions {
82 pub client_id: String,
83 pub unsolicited_packet_handler: UnsolicitedPacketHandler,
84 pub orphan_handler: Option<OrphanResponseHandler>,
85 pub on_connection_close_tx: OnConnectionCloseHandler,
86 pub on_close_cancel: CancellationToken,
87 pub disable_decompression: bool,
88 pub local_addr: SocketAddr,
89 pub peer_addr: SocketAddr,
90}
91
92#[derive(Debug)]
93struct ClientReadHandle {
94 read_handle: JoinHandle<()>,
95}
96
97impl ClientReadHandle {
98 pub async fn await_completion(&mut self) {
99 (&mut self.read_handle).await.unwrap_or_default()
100 }
101}
102
103#[derive(Debug)]
104pub struct Client {
105 current_opaque: AtomicU32,
106 opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
107
108 client_id: String,
109
110 writer: Mutex<FramedWrite<WriteHalf<Box<dyn Stream>>, KeyValueCodec>>,
111 on_close_cancel: DropGuard,
112
113 local_addr: SocketAddr,
114 peer_addr: SocketAddr,
115
116 closed: AtomicBool,
117}
118
119impl Client {
120 fn register_handler(&self, response_context: SenderContext) -> u32 {
121 let mut map = self.opaque_map.lock().unwrap();
122
123 let opaque = self.current_opaque.fetch_add(1, Ordering::SeqCst);
124
125 map.insert(opaque, Arc::new(response_context));
126
127 opaque
128 }
129
130 async fn drain_opaque_map(opaque_map: Arc<std::sync::Mutex<OpaqueMap>>) {
131 let mut senders = vec![];
132 {
133 let mut guard = opaque_map.lock().unwrap();
134 guard.drain().for_each(|(_, v)| {
135 senders.push(v);
136 });
137 }
138
139 for sender in senders {
140 sender
141 .sender
142 .send(Err(Error::new_cancelled_error(
143 CancellationErrorKind::ClosedInFlight,
144 )))
145 .await
146 .unwrap_or_default();
147 }
148 }
149
150 async fn on_read_loop_close(
151 client_id: &str,
152 stream: FramedRead<ReadHalf<Box<dyn Stream>>, KeyValueCodec>,
153 opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
154 on_connection_close: OnConnectionCloseHandler,
155 ) {
156 drop(stream);
157
158 Self::drain_opaque_map(opaque_map).await;
159
160 on_connection_close().await;
161
162 debug!("{client_id} read loop shut down");
163 }
164
165 async fn read_loop(
166 mut stream: FramedRead<ReadHalf<Box<dyn Stream>>, KeyValueCodec>,
167 opaque_map: Arc<std::sync::Mutex<OpaqueMap>>,
168 mut opts: ReadLoopOptions,
169 ) {
170 loop {
171 select! {
172 (_) = opts.on_close_cancel.cancelled() => {
173 Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_connection_close_tx).await;
174 return;
175 },
176 (next) = stream.next() => {
177 match next {
178 Some(input) => {
179 match input {
180 Ok(mut packet) => {
181 if packet.magic == Magic::ServerReq {
182
183 trace!(
184 "Handling server request on {}. Opcode={}",
185 opts.client_id,
186 packet.op_code,
187 );
188
189 (opts.unsolicited_packet_handler)(packet).await;
190 continue;
191 }
192
193 trace!(
194 "Resolving response on {}. Opcode={}. Opaque={}. Status={}",
195 opts.client_id,
196 packet.op_code,
197 packet.opaque,
198 packet.status,
199 );
200
201 let opaque = packet.opaque;
202
203 let requests: Arc<std::sync::Mutex<OpaqueMap>> = Arc::clone(&opaque_map);
204 let context = {
205 let map = requests.lock().unwrap();
206
207 let t = map.get(&opaque);
208
209 t.map(Arc::clone)
210 };
211
212 if let Some(context) = context {
213 let sender = &context.sender;
214
215 if let Some(value) = &packet.value {
216 if !opts.disable_decompression && (packet.datatype & u8::from(DataTypeFlag::Compressed) != 0) {
217 let mut decoder = Decoder::new();
218 let new_value = match decoder
219 .decompress_vec(value)
220 {
221 Ok(v) => v,
222 Err(e) => {
223 match sender.send(Err(Error::new_decompression_error().with(e))).await{
224 Ok(_) => {}
225 Err(e) => {
226 debug!("Sending response to caller failed: {e}");
227 }
228 };
229 continue;
230 }
231 };
232
233 packet.datatype &= !u8::from(DataTypeFlag::Compressed);
234 packet.value = Some(new_value);
235 }
236 }
237
238 if !context.context.is_persistent {
239 {
240 let mut map = requests.lock().unwrap();
241 map.remove(&opaque);
242 }
243 }
244
245 let resp = ClientResponse::new(packet, context.context.clone());
246 match sender.send(Ok(resp)).await {
247 Ok(_) => {}
248 Err(e) => {
249 debug!("Sending response to caller failed: {e}");
250 }
251 };
252 drop(context);
253 } else if let Some(ref orphan_handler) = opts.orphan_handler {
254 orphan_handler(
255 packet,
256 OrphanContext {
257 client_id: opts.client_id.clone(),
258 local_addr: opts.local_addr,
259 peer_addr: opts.peer_addr,
260 },
261 );
262 }
263 drop(requests);
264 }
265 Err(e) => {
266 warn!("{} failed to read frame {}", opts.client_id, e);
267 }
268 }
269 }
270 None => {
271 Self::on_read_loop_close(&opts.client_id, stream, opaque_map, opts.on_connection_close_tx).await;
272 return;
273 }
274 }
275 }
276 }
277 }
278 }
279
280 fn split_stream<StreamType: AsyncRead + AsyncWrite + Send + Unpin>(
281 stream: StreamType,
282 ) -> (ReadHalf<StreamType>, WriteHalf<StreamType>) {
283 tokio::io::split(stream)
284 }
285}
286
287#[async_trait]
288impl Dispatcher for Client {
289 fn new(conn: ConnectionType, opts: DispatcherOptions) -> Self {
290 let local_addr = *conn.local_addr();
291 let peer_addr = *conn.peer_addr();
292
293 let (r, w) = tokio::io::split(conn.into_inner());
294
295 let codec = KeyValueCodec::default();
296 let reader = FramedRead::new(r, codec);
297 let writer = FramedWrite::new(w, codec);
298
299 let cancel_token = CancellationToken::new();
300 let cancel_child = cancel_token.child_token();
301 let cancel_guard = cancel_token.drop_guard();
302
303 let opaque_map = Arc::new(std::sync::Mutex::new(OpaqueMap::default()));
304
305 let read_opaque_map = Arc::clone(&opaque_map);
306 let read_uuid = opts.id.clone();
307
308 tokio::spawn(async move {
309 Client::read_loop(
310 reader,
311 read_opaque_map,
312 ReadLoopOptions {
313 client_id: read_uuid,
314 unsolicited_packet_handler: opts.unsolicited_packet_handler,
315 orphan_handler: opts.orphan_handler,
316 on_connection_close_tx: opts.on_connection_close_handler,
317 on_close_cancel: cancel_child,
318 disable_decompression: opts.disable_decompression,
319 local_addr,
320 peer_addr,
321 },
322 )
323 .await;
324 });
325
326 Self {
327 current_opaque: AtomicU32::new(1),
328 opaque_map,
329 client_id: opts.id,
330
331 on_close_cancel: cancel_guard,
332
333 writer: Mutex::new(writer),
334
335 local_addr,
336 peer_addr,
337
338 closed: AtomicBool::new(false),
339 }
340 }
341
342 async fn dispatch<'a>(
343 &self,
344 mut packet: RequestPacket<'a>,
345 response_context: Option<ResponseContext>,
346 ) -> error::Result<ClientPendingOp> {
347 let (response_tx, response_rx) = mpsc::channel(1);
348 let context = response_context.unwrap_or(ResponseContext {
349 cas: packet.cas,
350 subdoc_info: None,
351 is_persistent: false,
352 scope_name: None,
353 collection_name: None,
354 });
355 let is_persistent = context.is_persistent;
356 let opaque = self.register_handler(SenderContext {
357 sender: response_tx,
358 context: Arc::new(context),
359 });
360 packet.opaque = Some(opaque);
361 let op_code = packet.op_code;
362
363 trace!(
364 "Writing request on {}. Opcode={}. Opaque={}",
365 &self.client_id,
366 packet.op_code,
367 opaque,
368 );
369
370 let mut writer = self.writer.lock().await;
371 match writer.send(packet).await {
372 Ok(_) => Ok(ClientPendingOp::new(
373 opaque,
374 self.opaque_map.clone(),
375 response_rx,
376 is_persistent,
377 )),
378 Err(e) => {
379 debug!(
380 "{} failed to write packet {} {} {}",
381 self.client_id, opaque, op_code, e
382 );
383
384 let requests: Arc<std::sync::Mutex<OpaqueMap>> = Arc::clone(&self.opaque_map);
385 {
386 let mut map = requests.lock().unwrap();
387 map.remove(&opaque);
388 }
389
390 Err(Error::new_dispatch_error(opaque, op_code, Box::new(e)))
391 }
392 }
393 }
394
395 async fn close(&self) -> error::Result<()> {
396 if self.closed.swap(true, Ordering::SeqCst) {
397 return Ok(());
398 }
399
400 debug!("Client {} closing", self.client_id);
401
402 let mut close_err = None;
403 let mut writer = self.writer.lock().await;
404 match writer.close().await {
405 Ok(_) => {}
406 Err(e) => {
407 close_err = Some(e);
408 }
409 };
410
411 Self::drain_opaque_map(self.opaque_map.clone()).await;
412
413 if let Some(e) = close_err {
414 return Err(Error::new_close_error(e.to_string(), Box::new(e)));
415 }
416
417 Ok(())
418 }
419}
420
421impl Drop for Client {
422 fn drop(&mut self) {
423 debug!("Client {} exiting", self.client_id);
424 }
425}