1use crate::actions::{EvaluationContext, RequestRule, ResponseRule};
2use crate::errors::{DoorkeeperError, ProxyError, WorkerError};
3use crate::frame::{
4 self, read_response_frame, write_frame, FrameOpcode, FrameParams, RequestFrame, ResponseFrame,
5};
6use crate::{RequestOpcode, TargetShard};
7use bytes::Bytes;
8use scylla_cql::frame::types::read_string_multimap;
9use std::collections::HashMap;
10use std::fmt::Display;
11use std::future::Future;
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
14use std::sync::{Arc, Mutex};
15use tokio::io::{AsyncRead, AsyncWrite};
16use tokio::net::{TcpListener, TcpSocket, TcpStream};
17use tokio::sync::mpsc::error::TryRecvError;
18use tokio::sync::{broadcast, mpsc};
19use tracing::{debug, error, info, trace, warn};
20
21type FinishWaiter = mpsc::Receiver<()>;
23type FinishGuard = mpsc::Sender<()>;
24
25type TerminateNotifier = tokio::sync::broadcast::Receiver<()>;
27type TerminateSignaler = tokio::sync::broadcast::Sender<()>;
28
29type ConnectionCloseNotifier = tokio::sync::broadcast::Receiver<()>;
32type ConnectionCloseSignaler = tokio::sync::broadcast::Sender<()>;
33
34type ErrorPropagator = mpsc::UnboundedSender<ProxyError>;
37type ErrorSink = mpsc::UnboundedReceiver<ProxyError>;
38
39static HARDCODED_OPTIONS_PARAMS: FrameParams = FrameParams {
40 flags: 0,
41 version: 0x04,
42 stream: 0,
43};
44
/// How a proxied node's doorkeeper learns the number of shards of the real node.
#[derive(Debug, Clone, Copy)]
pub enum ShardAwareness {
    /// No shard count is used; connections to the node use an arbitrary local port.
    Unaware,
    /// The shard count is queried from the node (`SCYLLA_NR_SHARDS` in its OPTIONS reply);
    /// if that fails or no shard info is reported, the proxy behaves as shard-unaware.
    QueryNode,
    /// The given shard count is used without querying the node.
    FixedNum(u16),
}

impl ShardAwareness {
    /// Returns `true` for every variant except [`ShardAwareness::Unaware`].
    pub fn is_aware(&self) -> bool {
        match self {
            Self::Unaware => false,
            Self::QueryNode | Self::FixedNum(_) => true,
        }
    }
}
65
66enum NodeType {
86 Real {
87 real_addr: SocketAddr,
88 shard_awareness: ShardAwareness,
89 response_rules: Option<Vec<ResponseRule>>,
90 },
91 Simulated,
92}
93
94pub struct Node {
95 proxy_addr: SocketAddr,
96 request_rules: Option<Vec<RequestRule>>,
97 node_type: NodeType,
98}
99
100impl Node {
101 pub fn new(
103 real_addr: SocketAddr,
104 proxy_addr: SocketAddr,
105 shard_awareness: ShardAwareness,
106 request_rules: Option<Vec<RequestRule>>,
107 response_rules: Option<Vec<ResponseRule>>,
108 ) -> Self {
109 Self {
110 proxy_addr,
111 request_rules,
112 node_type: NodeType::Real {
113 real_addr,
114 shard_awareness,
115 response_rules,
116 },
117 }
118 }
119
120 pub fn new_dry_mode(proxy_addr: SocketAddr, request_rules: Option<Vec<RequestRule>>) -> Self {
122 Self {
123 proxy_addr,
124 request_rules,
125 node_type: NodeType::Simulated,
126 }
127 }
128
129 pub fn builder() -> NodeBuilder {
130 NodeBuilder {
131 real_addr: None,
132 proxy_addr: None,
133 shard_awareness: None,
134 request_rules: None,
135 response_rules: None,
136 }
137 }
138}
139
140pub struct NodeBuilder {
141 real_addr: Option<SocketAddr>,
142 proxy_addr: Option<SocketAddr>,
143 shard_awareness: Option<ShardAwareness>,
144 request_rules: Option<Vec<RequestRule>>,
145 response_rules: Option<Vec<ResponseRule>>,
146}
147
148impl NodeBuilder {
149 pub fn real_address(mut self, real_addr: SocketAddr) -> Self {
150 self.real_addr = Some(real_addr);
151 self
152 }
153
154 pub fn proxy_address(mut self, proxy_addr: SocketAddr) -> Self {
155 self.proxy_addr = Some(proxy_addr);
156 self
157 }
158
159 pub fn shard_awareness(mut self, shard_awareness: ShardAwareness) -> Self {
160 self.shard_awareness = Some(shard_awareness);
161 self
162 }
163
164 pub fn request_rules(mut self, request_rules: Vec<RequestRule>) -> Self {
165 self.request_rules = Some(request_rules);
166 self
167 }
168
169 pub fn response_rules(mut self, response_rules: Vec<ResponseRule>) -> Self {
170 self.response_rules = Some(response_rules);
171 self
172 }
173
174 pub fn build(self) -> Node {
176 Node {
177 proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
178 request_rules: self.request_rules,
179 node_type: NodeType::Real {
180 real_addr: self.real_addr.expect("Real addr is required!"),
181 shard_awareness: self.shard_awareness.expect("Shard awareness is required!"),
182 response_rules: self.response_rules,
183 },
184 }
185 }
186
187 pub fn build_dry_mode(self) -> Node {
189 Node {
190 proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
191 request_rules: self.request_rules,
192 node_type: NodeType::Simulated,
193 }
194 }
195}
196
/// `Display` adapter for an optional real node address; prints `<dry mode>` for `None`.
#[derive(Clone, Copy)]
struct DisplayableRealAddrOption(Option<SocketAddr>);
impl Display for DisplayableRealAddrOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            Some(addr) => write!(f, "{}", addr),
            None => f.write_str("<dry mode>"),
        }
    }
}
208
209#[derive(Clone, Copy)]
210struct DisplayableShard(Option<TargetShard>);
211impl Display for DisplayableShard {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 if let Some(shard) = self.0 {
214 write!(f, "shard {}", shard)
215 } else {
216 write!(f, "unknown shard")
217 }
218 }
219}
220
221enum InternalNode {
222 Real {
223 real_addr: SocketAddr,
224 proxy_addr: SocketAddr,
225 shard_awareness: ShardAwareness,
226 request_rules: Arc<Mutex<Vec<RequestRule>>>,
227 response_rules: Arc<Mutex<Vec<ResponseRule>>>,
228 },
229 Simulated {
230 proxy_addr: SocketAddr,
231 request_rules: Arc<Mutex<Vec<RequestRule>>>,
232 },
233}
234
235impl InternalNode {
236 fn proxy_addr(&self) -> SocketAddr {
237 match *self {
238 InternalNode::Real { proxy_addr, .. } => proxy_addr,
239 InternalNode::Simulated { proxy_addr, .. } => proxy_addr,
240 }
241 }
242 fn real_addr(&self) -> Option<SocketAddr> {
243 match *self {
244 InternalNode::Real { real_addr, .. } => Some(real_addr),
245 InternalNode::Simulated { .. } => None,
246 }
247 }
248 fn request_rules(&self) -> &Arc<Mutex<Vec<RequestRule>>> {
249 match self {
250 InternalNode::Real { request_rules, .. } => request_rules,
251 InternalNode::Simulated { request_rules, .. } => request_rules,
252 }
253 }
254}
255
256impl From<Node> for InternalNode {
257 fn from(node: Node) -> Self {
258 match node.node_type {
259 NodeType::Real {
260 real_addr,
261 shard_awareness,
262 response_rules,
263 } => InternalNode::Real {
264 real_addr,
265 proxy_addr: node.proxy_addr,
266 shard_awareness,
267 request_rules: node
268 .request_rules
269 .map(|rules| Arc::new(Mutex::new(rules)))
270 .unwrap_or_default(),
271 response_rules: response_rules
272 .map(|rules| Arc::new(Mutex::new(rules)))
273 .unwrap_or_default(),
274 },
275 NodeType::Simulated => InternalNode::Simulated {
276 proxy_addr: node.proxy_addr,
277 request_rules: node
278 .request_rules
279 .map(|rules| Arc::new(Mutex::new(rules)))
280 .unwrap_or_default(),
281 },
282 }
283 }
284}
285
286pub struct ProxyBuilder {
287 nodes: Vec<Node>,
288}
289
290impl ProxyBuilder {
291 pub fn with_node(mut self, node: Node) -> ProxyBuilder {
292 self.nodes.push(node);
293 self
294 }
295
296 pub fn build(self) -> Proxy {
297 Proxy::new(self.nodes)
298 }
299}
300
301pub struct Proxy {
302 nodes: Vec<InternalNode>,
303}
304
305impl Proxy {
306 pub fn new(nodes: impl IntoIterator<Item = Node>) -> Self {
307 Proxy {
308 nodes: nodes.into_iter().map(|node| node.into()).collect(),
309 }
310 }
311
312 pub fn builder() -> ProxyBuilder {
313 ProxyBuilder { nodes: vec![] }
314 }
315
316 pub fn translation_map(&self) -> HashMap<SocketAddr, SocketAddr> {
320 let mut translation_map = HashMap::new();
321 for node in self.nodes.iter() {
322 if let &InternalNode::Real {
323 real_addr,
324 proxy_addr,
325 ..
326 } = node
327 {
328 translation_map.insert(real_addr, proxy_addr);
329 let shard_aware_real_addr = SocketAddr::new(real_addr.ip(), 19042);
330 translation_map.insert(shard_aware_real_addr, proxy_addr);
331 }
332 }
333 translation_map
334 }
335
336 pub async fn run(self) -> Result<RunningProxy, DoorkeeperError> {
339 let (terminate_signaler, _t) = tokio::sync::broadcast::channel(1);
340 let (finish_guard, finish_waiter) = mpsc::channel(1);
341
342 let (error_propagator, error_sink) = mpsc::unbounded_channel();
343 let (doorkeepers, running_nodes): (Vec<_>, Vec<RunningNode>) = self
344 .nodes
345 .into_iter()
346 .map(|node| {
347 let running = {
348 let (request_rules, response_rules) = match node {
349 InternalNode::Real {
350 ref request_rules,
351 ref response_rules,
352 ..
353 } => (request_rules, Some(response_rules)),
354 InternalNode::Simulated {
355 ref request_rules, ..
356 } => (request_rules, None),
357 };
358 RunningNode {
359 request_rules: request_rules.clone(),
360 response_rules: response_rules.cloned(),
361 }
362 };
363 (
364 Doorkeeper::spawn(
365 node,
366 terminate_signaler.clone(),
367 finish_guard.clone(),
368 error_propagator.clone(),
369 ),
370 running,
371 )
372 })
373 .unzip();
374
375 for doorkeeper in doorkeepers {
376 doorkeeper.await?; }
378
379 Ok(RunningProxy {
380 terminate_signaler,
381 finish_waiter,
382 running_nodes,
383 error_sink,
384 })
385 }
386}
387
388pub struct RunningNode {
390 request_rules: Arc<Mutex<Vec<RequestRule>>>,
391 response_rules: Option<Arc<Mutex<Vec<ResponseRule>>>>,
392}
393
394impl RunningNode {
395 pub fn change_request_rules(&mut self, rules: Option<Vec<RequestRule>>) {
397 *self.request_rules.lock().unwrap() = rules.unwrap_or_default();
398 }
399
400 pub fn change_response_rules(&mut self, rules: Option<Vec<ResponseRule>>) {
402 *self
403 .response_rules
404 .as_ref()
405 .expect("No response rules on a simulated node!")
406 .lock()
407 .unwrap() = rules.unwrap_or_default();
408 }
409}
410
411pub struct RunningProxy {
413 terminate_signaler: TerminateSignaler,
414 finish_waiter: FinishWaiter,
415 pub running_nodes: Vec<RunningNode>,
416 error_sink: ErrorSink,
417}
418
419impl RunningProxy {
420 pub fn turn_off_rules(&mut self) {
422 for (request_rules, response_rules) in self
423 .running_nodes
424 .iter_mut()
425 .map(|node| (&node.request_rules, &node.response_rules))
426 {
427 request_rules.lock().unwrap().clear();
428 if let Some(response_rules) = response_rules {
429 response_rules.lock().unwrap().clear();
430 }
431 }
432 }
433
434 pub fn sanity_check(&mut self) -> Result<(), ProxyError> {
437 match self.error_sink.try_recv() {
438 Ok(err) => Err(err),
439 Err(TryRecvError::Empty) => Ok(()),
440 Err(TryRecvError::Disconnected) => {
441 Err(ProxyError::SanityCheckFailure)
443 }
444 }
445 }
446
447 pub async fn wait_for_error(&mut self) -> Option<ProxyError> {
449 self.error_sink.recv().await
450 }
451
452 pub async fn finish(mut self) -> Result<(), ProxyError> {
455 self.terminate_signaler.send(()).map_err(|err| {
456 ProxyError::AwaitFinishFailure(format!(
457 "Send error in terminate_signaler: {} (bug!)",
458 err
459 ))
460 })?;
461 info!("Sent finish signal to proxy workers.");
462
463 std::mem::drop(self.terminate_signaler);
465
466 if self.finish_waiter.recv().await.is_some() {
467 unreachable!();
468 };
469 info!("All workers have finished.");
470
471 match self.error_sink.try_recv() {
472 Ok(err) => Err(err),
473 Err(TryRecvError::Disconnected) => Ok(()),
474 Err(TryRecvError::Empty) => {
475 unreachable!("Worker await logic bug!");
477 }
478 }
479 }
480}
481
482struct Doorkeeper {
487 node: InternalNode,
488 listener: TcpListener,
489 terminate_signaler: TerminateSignaler,
490 finish_guard: FinishGuard,
491 shards_count: Option<u16>,
492 error_propagator: ErrorPropagator,
493}
494
495impl Doorkeeper {
496 async fn spawn(
497 node: InternalNode,
498 terminate_signaler: TerminateSignaler,
499 finish_guard: FinishGuard,
500 error_propagator: ErrorPropagator,
501 ) -> Result<(), DoorkeeperError> {
502 let listener = TcpListener::bind(node.proxy_addr())
503 .await
504 .map_err(|err| DoorkeeperError::DriverConnectionAttempt(node.proxy_addr(), err))?;
505
506 if let InternalNode::Real {
507 shard_awareness,
508 real_addr,
509 ..
510 } = node
511 {
512 info!(
513 "Spawned a {} doorkeeper for pair real:{} - proxy:{}.",
514 if shard_awareness.is_aware() {
515 "shard-aware"
516 } else {
517 "shard-unaware"
518 },
519 real_addr,
520 node.proxy_addr(),
521 );
522 } else {
523 info!(
524 "Spawned a dry-mode doorkeeper for proxy:{}.",
525 node.proxy_addr(),
526 )
527 };
528
529 let doorkeeper = Doorkeeper {
530 shards_count: None, node,
532 listener,
533 terminate_signaler,
534 finish_guard,
535 error_propagator,
536 };
537 tokio::task::spawn(doorkeeper.run());
538 Ok(())
539 }
540
541 async fn run(mut self) {
542 self.update_shards_count().await;
543 let mut own_terminate_notifier = self.terminate_signaler.subscribe();
544 let (connection_close_tx, _connection_close_rx) = broadcast::channel::<()>(2);
545 let mut connection_no: usize = 0;
546 loop {
547 tokio::select! {
548 res = self.accept_connection(&connection_close_tx, connection_no) => {
549 match res {
550 Ok(()) => connection_no += 1,
551 Err(err) => {
552 error!(
553 "Error in doorkeeper with addr {} for node {}: {}",
554 self.node.proxy_addr(),
555 DisplayableRealAddrOption(self.node.real_addr()),
556 err
557 );
558 let _ = self.error_propagator.send(err.into());
559 break;
560 },
561 }
562 },
563 _terminate = own_terminate_notifier.recv() => break
564 }
565 }
566 debug!(
567 "Doorkeeper exits: proxy {}, node {}.",
568 self.node.proxy_addr(),
569 DisplayableRealAddrOption(self.node.real_addr())
570 );
571 }
572
573 async fn update_shards_count(&mut self) {
574 if let InternalNode::Real {
575 real_addr,
576 shard_awareness,
577 ..
578 } = self.node
579 {
580 self.shards_count = match shard_awareness {
581 ShardAwareness::Unaware => None,
582 ShardAwareness::FixedNum(shards_num) => Some(shards_num),
583 ShardAwareness::QueryNode => match self.obtain_shards_count(real_addr).await {
584 Ok(shards) => Some(shards),
585 Err(DoorkeeperError::ObtainingShardNumberNoShardInfo) => {
587 info!(
588 "Doorkeeper with addr {} found no shard info in node {}; falling back to ShardAwareness::Unaware",
589 self.node.proxy_addr(),
590 DisplayableRealAddrOption(self.node.real_addr()),
591 );
592 None
593 }
594 Err(e) => {
595 error!(
596 "Error in doorkeeper with addr {} while querying shard info from node {}: {}",
597 self.node.proxy_addr(),
598 DisplayableRealAddrOption(self.node.real_addr()),
599 e
600 );
601 None
602 }
603 },
604 }
605 }
606 }
607
608 async fn spawn_workers(
609 &mut self,
610 driver_addr: SocketAddr,
611 connection_close_tx: &ConnectionCloseSignaler,
612 connection_no: usize,
613 driver_stream: TcpStream,
614 cluster_stream: Option<TcpStream>,
615 shard: Option<TargetShard>,
616 ) {
617 let (driver_read, driver_write) = driver_stream.into_split();
618
619 let new_worker = || ProxyWorker {
620 terminate_notifier: self.terminate_signaler.subscribe(),
621 finish_guard: self.finish_guard.clone(),
622 connection_close_notifier: connection_close_tx.subscribe(),
623 error_propagator: self.error_propagator.clone(),
624 driver_addr,
625 real_addr: self.node.real_addr(),
626 proxy_addr: self.node.proxy_addr(),
627 shard,
628 };
629
630 let (tx_request, rx_request) = mpsc::unbounded_channel::<RequestFrame>();
631 let (tx_response, rx_response) = mpsc::unbounded_channel::<ResponseFrame>();
632 let (tx_cluster, rx_cluster) = mpsc::unbounded_channel::<RequestFrame>();
633 let (tx_driver, rx_driver) = mpsc::unbounded_channel::<ResponseFrame>();
634 let event_register_flag = Arc::new(AtomicBool::new(false));
635
636 tokio::task::spawn(new_worker().receiver_from_driver(driver_read, tx_request));
637 tokio::task::spawn(new_worker().sender_to_driver(
638 driver_write,
639 rx_driver,
640 connection_close_tx.subscribe(),
641 self.terminate_signaler.subscribe(),
642 ));
643 tokio::task::spawn(new_worker().request_processor(
644 rx_request,
645 tx_driver.clone(),
646 tx_cluster.clone(),
647 connection_no,
648 self.node.request_rules().clone(),
649 connection_close_tx.clone(),
650 event_register_flag.clone(),
651 ));
652 if let InternalNode::Real {
653 ref response_rules, ..
654 } = self.node
655 {
656 let (cluster_read, cluster_write) = cluster_stream.unwrap().into_split();
657 tokio::task::spawn(new_worker().sender_to_cluster(
658 cluster_write,
659 rx_cluster,
660 connection_close_tx.subscribe(),
661 self.terminate_signaler.subscribe(),
662 ));
663 tokio::task::spawn(new_worker().receiver_from_cluster(cluster_read, tx_response));
664 tokio::task::spawn(new_worker().response_processor(
665 rx_response,
666 tx_driver,
667 tx_cluster,
668 connection_no,
669 response_rules.clone(),
670 connection_close_tx.clone(),
671 event_register_flag.clone(),
672 ));
673 }
674 debug!(
675 "Doorkeeper with addr {} of node {} spawned workers.",
676 self.node.proxy_addr(),
677 DisplayableRealAddrOption(self.node.real_addr())
678 );
679 }
680
681 async fn accept_connection(
682 &mut self,
683 connection_close_tx: &ConnectionCloseSignaler,
684 connection_no: usize,
685 ) -> Result<(), DoorkeeperError> {
686 let (driver_stream, driver_addr) = self.make_driver_stream(connection_no).await?;
687 let (cluster_stream, shard) = match self.node {
688 InternalNode::Real { real_addr, .. } => {
689 let (cluster_stream, shard) =
690 self.make_cluster_stream(driver_addr, real_addr).await?;
691 (Some(cluster_stream), shard)
692 }
693 InternalNode::Simulated { .. } => (None, None),
694 };
695
696 self.spawn_workers(
697 driver_addr,
698 connection_close_tx,
699 connection_no,
700 driver_stream,
701 cluster_stream,
702 shard,
703 )
704 .await;
705
706 Ok(())
707 }
708
709 async fn make_driver_stream(
710 &mut self,
711 connection_no: usize,
712 ) -> Result<(TcpStream, SocketAddr), DoorkeeperError> {
713 let (driver_stream, driver_addr) =
714 self.listener.accept().await.map_err(|err| {
715 DoorkeeperError::DriverConnectionAttempt(self.node.proxy_addr(), err)
716 })?;
717 info!(
718 "Connected driver from {} to {}, connection no={}.",
719 driver_addr,
720 self.node.proxy_addr(),
721 connection_no
722 );
723 Ok((driver_stream, driver_addr))
724 }
725
726 async fn make_cluster_stream(
727 &mut self,
728 driver_addr: SocketAddr,
729 real_addr: SocketAddr,
730 ) -> Result<(TcpStream, Option<TargetShard>), DoorkeeperError> {
731 let mut cluster_stream = if let Some(shards) = self.shards_count {
732 let socket = match self.node.proxy_addr().ip() {
733 std::net::IpAddr::V4(_) => TcpSocket::new_v4(),
734 std::net::IpAddr::V6(_) => TcpSocket::new_v6(),
735 }
736 .map_err(DoorkeeperError::SocketCreate)?;
737
738 let shard_preserving_addr = {
739 let mut desired_addr =
740 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), driver_addr.port());
741 while socket.bind(desired_addr).is_err() {
742 let next_port = self.next_port_to_same_shard(desired_addr.port());
744 if next_port == driver_addr.port() {
745 return Err(DoorkeeperError::NoMorePorts);
746 }
747 desired_addr.set_port(next_port);
748 }
749 desired_addr
750 };
751
752 socket.connect(real_addr).await.map(|ok| {
753 info!(
754 "Connected to the cluster from {} at {}, intended shard {}.",
755 ok.local_addr().unwrap(),
756 real_addr,
757 shard_preserving_addr.port() % shards
758 );
759 ok
760 })
761 } else {
762 TcpStream::connect(real_addr).await.map(|ok| {
763 info!("Connected to the cluster at {}.", real_addr);
764 ok
765 })
766 }
767 .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
768
769 let shard = if self.shards_count.is_some() {
776 self.obtain_shard_number(real_addr, &mut cluster_stream)
777 .await?
778 } else {
779 None
780 };
781
782 Ok((cluster_stream, shard))
783 }
784
785 fn next_port_to_same_shard(&self, port: u16) -> u16 {
786 port.wrapping_add(self.shards_count.unwrap())
787 }
788
789 async fn get_supported_options(
790 connection: &mut TcpStream,
791 ) -> Result<HashMap<String, Vec<String>>, DoorkeeperError> {
792 write_frame(
793 HARDCODED_OPTIONS_PARAMS,
794 FrameOpcode::Request(RequestOpcode::Options),
795 &Bytes::new(),
796 connection,
797 )
798 .await
799 .map_err(DoorkeeperError::ObtainingShardNumber)?;
800
801 let supported_frame = read_response_frame(connection)
802 .await
803 .map_err(DoorkeeperError::ObtainingShardNumberFrame)?;
804
805 let options = read_string_multimap(&mut supported_frame.body.as_ref())
806 .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?;
807
808 Ok(options)
809 }
810
811 async fn obtain_shards_count(&self, real_addr: SocketAddr) -> Result<u16, DoorkeeperError> {
812 let mut connection = TcpStream::connect(real_addr)
813 .await
814 .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
815 let options = Self::get_supported_options(&mut connection).await?;
816 let nr_shards_entry = options.get("SCYLLA_NR_SHARDS");
817 let shards = match nr_shards_entry
818 .and_then(|vec| vec.first())
819 .ok_or(DoorkeeperError::ObtainingShardNumberNoShardInfo)?
820 .parse::<u16>()
821 .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)?
822 {
823 0u16 => Err(DoorkeeperError::ObtainingShardNumberGotZero),
824 num => Ok(num),
825 }?;
826 info!("Obtained shards number on node {}: {}", real_addr, shards);
827 Ok(shards)
828 }
829
830 async fn obtain_shard_number(
831 &self,
832 real_addr: SocketAddr,
833 connection: &mut TcpStream,
834 ) -> Result<Option<TargetShard>, DoorkeeperError> {
835 let options = Self::get_supported_options(connection).await?;
836 let shard_entry = options.get("SCYLLA_SHARD");
837 let shard = shard_entry
838 .and_then(|vec| vec.first())
839 .map(|s| {
840 s.parse::<u16>()
841 .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)
842 })
843 .transpose()?;
844 info!("Connected to node {}, shard {:?}", real_addr, shard);
845 Ok(shard)
846 }
847}
848
849struct ProxyWorker {
850 terminate_notifier: TerminateNotifier,
851 finish_guard: FinishGuard,
852 connection_close_notifier: ConnectionCloseNotifier,
853 error_propagator: ErrorPropagator,
854 driver_addr: SocketAddr,
855 real_addr: Option<SocketAddr>,
856 proxy_addr: SocketAddr,
857 shard: Option<TargetShard>,
858}
859
860impl ProxyWorker {
861 fn exit(self, duty: &'static str) {
862 debug!(
863 "Worker exits: [driver: {}, proxy: {}, node: {}, {}]::{}.",
864 self.driver_addr,
865 self.proxy_addr,
866 DisplayableRealAddrOption(self.real_addr),
867 DisplayableShard(self.shard),
868 duty
869 );
870 std::mem::drop(self.finish_guard);
871 }
872
873 async fn run_until_interrupted<F, Fut>(mut self, worker_name: &'static str, f: F)
874 where
875 F: FnOnce(SocketAddr, SocketAddr, Option<SocketAddr>) -> Fut,
876 Fut: Future<Output = Result<(), ProxyError>>,
877 {
878 let fut = f(self.driver_addr, self.proxy_addr, self.real_addr);
879
880 tokio::select! {
881 result = fut => {
882 if let Err(err) = result {
883 let _ = self.error_propagator.send(err);
885 }
886 }
887 _ = self.terminate_notifier.recv() => (),
888 _ = self.connection_close_notifier.recv() => (),
889 }
890 self.exit(worker_name);
891 }
892
893 async fn receiver_from_driver(
894 self,
895 mut read_half: (impl AsyncRead + Unpin),
896 request_processor_tx: mpsc::UnboundedSender<RequestFrame>,
897 ) {
898 let shard = self.shard;
899 self.run_until_interrupted(
900 "receiver_from_driver",
901 |driver_addr, proxy_addr, _real_addr| async move {
902 loop {
903 let frame = frame::read_request_frame(&mut read_half)
904 .await
905 .map_err(|err| {
906 warn!("Request reception from {} error: {}", driver_addr, err);
907 WorkerError::DriverDisconnected(driver_addr)
908 })?;
909
910 debug!(
911 "Intercepted Driver ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
912 driver_addr,
913 proxy_addr,
914 DisplayableShard(shard),
915 &frame.opcode
916 );
917 if request_processor_tx.send(frame).is_err() {
918 warn!("request_processor had exited.");
919 return Result::<(), ProxyError>::Ok(());
920 }
921 }
922 },
923 )
924 .await
925 }
926
927 async fn receiver_from_cluster(
928 self,
929 mut read_half: (impl AsyncRead + Unpin),
930 response_processor_tx: mpsc::UnboundedSender<ResponseFrame>,
931 ) {
932 let shard = self.shard;
933 self.run_until_interrupted(
934 "receiver_from_cluster",
935 |driver_addr, _proxy_addr, real_addr| async move {
936 let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
937 loop {
938 let frame =
939 frame::read_response_frame(&mut read_half)
940 .await
941 .map_err(|err| {
942 warn!("Response reception from {} error: {}", real_addr, err);
943 WorkerError::NodeDisconnected(real_addr)
944 })?;
945
946 debug!(
947 "Intercepted Cluster ({}) -> Driver ({}) ({}) frame. opcode: {:?}.",
948 real_addr,
949 driver_addr,
950 DisplayableShard(shard),
951 &frame.opcode
952 );
953
954 if response_processor_tx.send(frame).is_err() {
955 warn!("response_processor had exited.");
956 return Ok::<(), ProxyError>(());
957 }
958 }
959 },
960 )
961 .await;
962 }
963
964 async fn sender_to_driver(
965 self,
966 mut write_half: (impl AsyncWrite + Unpin),
967 mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
968 mut connection_close_notifier: ConnectionCloseNotifier,
969 mut terminate_notifier: TerminateNotifier,
970 ) {
971 let shard = self.shard;
972 self.run_until_interrupted(
973 "sender_to_driver",
974 |driver_addr, proxy_addr, _real_addr| async move {
975 loop {
976 let response = match responses_rx.recv().await {
977 Some(response) => response,
978 None => {
979 if terminate_notifier.try_recv().is_err()
980 && connection_close_notifier.try_recv().is_err()
981 {
982 warn!("Response processor had exited");
983 }
984 return Ok(());
985 }
986 };
987
988 debug!(
989 "Sending Proxy ({}) -> Driver ({}) ({}) frame. opcode: {:?}.",
990 proxy_addr,
991 driver_addr,
992 DisplayableShard(shard),
993 &response.opcode
994 );
995 if response.write(&mut write_half).await.is_err() {
996 if terminate_notifier.try_recv().is_err()
997 && connection_close_notifier.try_recv().is_err()
998 {
999 warn!("Driver dropped connection");
1000 return Err(WorkerError::DriverDisconnected(driver_addr).into());
1001 }
1002 return Ok(());
1003 }
1004 }
1005 },
1006 )
1007 .await;
1008 }
1009
1010 async fn sender_to_cluster(
1011 self,
1012 mut write_half: (impl AsyncWrite + Unpin),
1013 mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1014 mut connection_close_notifier: ConnectionCloseNotifier,
1015 mut terminate_notifier: TerminateNotifier,
1016 ) {
1017 let shard = self.shard;
1018 self.run_until_interrupted(
1019 "sender_to_driver",
1020 |_driver_addr, proxy_addr, real_addr| async move {
1021 let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1022 loop {
1023 let request = match requests_rx.recv().await {
1024 Some(request) => request,
1025 None => {
1026 if terminate_notifier.try_recv().is_err()
1027 && connection_close_notifier.try_recv().is_err()
1028 {
1029 warn!("Request processor had exited");
1030 }
1031 return Ok(());
1032 }
1033 };
1034
1035 debug!(
1036 "Sending Proxy ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1037 proxy_addr,
1038 real_addr,
1039 DisplayableShard(shard),
1040 &request.opcode
1041 );
1042
1043 if request.write(&mut write_half).await.is_err() {
1044 if terminate_notifier.try_recv().is_err()
1045 && connection_close_notifier.try_recv().is_err()
1046 {
1047 warn!("Node {} dropped connection", real_addr);
1048 return Err(WorkerError::NodeDisconnected(real_addr).into());
1049 }
1050 return Ok(());
1051 }
1052 }
1053 },
1054 )
1055 .await;
1056 }
1057
1058 #[allow(clippy::too_many_arguments)]
1059 async fn request_processor(
1060 self,
1061 mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1062 driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1063 cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1064 connection_no: usize,
1065 request_rules: Arc<Mutex<Vec<RequestRule>>>,
1066 connection_close_signaler: ConnectionCloseSignaler,
1067 event_registered_flag: Arc<AtomicBool>,
1068 ) {
1069 let shard = self.shard;
1070 self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
1071 'mainloop: loop {
1072 match requests_rx.recv().await {
1073 Some(request) => {
1074 if request.opcode == RequestOpcode::Register {
1075 event_registered_flag.store(true, Ordering::Relaxed);
1076 }
1077 let ctx = EvaluationContext {
1078 connection_seq_no: connection_no,
1079 opcode: FrameOpcode::Request(request.opcode),
1080 frame_body: request.body.clone(),
1081 connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1082 };
1083 let mut guard = request_rules.lock().unwrap();
1084 '_ruleloop: for (i, request_rule) in guard.iter_mut().enumerate() {
1085 if request_rule.0.eval(&ctx) {
1086 info!("Applying rule no={} to request ({} -> {} ({})).", i, driver_addr, DisplayableRealAddrOption(real_addr), DisplayableShard(shard));
1087 debug!("-> Applied rule: {:?}", request_rule);
1088 debug!("-> To request: {:?}", ctx.opcode);
1089 trace!("{:?}", request);
1090
1091 if let Some(ref tx) = request_rule.1.feedback_channel {
1092 tx.send((request.clone(), shard)).unwrap_or_else(|err|
1093 warn!("Could not send received request as feedback: {}", err)
1094 );
1095 }
1096
1097 let request_rule = request_rule.clone();
1098 let to_addressee_action = request_rule.1.to_addressee;
1099 let to_sender_action = request_rule.1.to_sender;
1100 let drop_connection_action = request_rule.1.drop_connection;
1101
1102 let cluster_tx_clone = cluster_tx.clone();
1103 let request_clone = request.clone();
1104 let pass_action = async move {
1105 if let Some(ref pass_action) = to_addressee_action {
1106 if let Some(time) = pass_action.delay {
1107 tokio::time::sleep(time).await;
1108 }
1109 let passed_frame = match pass_action.msg_processor {
1110 Some(ref processor) => processor(request_clone),
1111 None => request_clone,
1112 };
1113 let _ = cluster_tx_clone.send(passed_frame);
1114 };
1115 };
1116
1117 let driver_tx_clone = driver_tx.clone();
1118 let request_clone = request.clone();
1119 let forge_action = async move {
1120 if let Some(ref forge_action) = to_sender_action {
1121 if let Some(time) = forge_action.delay {
1122 tokio::time::sleep(time).await;
1123 }
1124 let forged_frame = {
1125 let processor = forge_action.msg_processor.as_ref()
1126 .expect("Frame processor is required to forge a frame.");
1127 processor(request_clone)
1128 };
1129 let _ = driver_tx_clone.send(forged_frame);
1130 };
1131 };
1132
1133 let connection_close_signaler_clone =
1134 connection_close_signaler.clone();
1135 let drop_action = async move {
1136 if let Some(ref delay) = drop_connection_action {
1137 if let Some(ref time) = delay {
1138 tokio::time::sleep(*time).await;
1139 }
1140 info!(
1142 "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1143 driver_addr,
1144 DisplayableRealAddrOption(real_addr),
1145 DisplayableShard(shard),
1146 );
1147 let _ = connection_close_signaler_clone.send(());
1148 }
1149 };
1150
1151 tokio::task::spawn(async {
1152 futures::join!(pass_action, forge_action, drop_action);
1153 });
1154
1155 continue 'mainloop; }
1157 }
1158 let _ = cluster_tx.send(request); }
1160 None => return Ok(()),
1161 }
1162 }
1163 })
1164 .await;
1165 }
1166
1167 #[allow(clippy::too_many_arguments)]
1168 async fn response_processor(
1169 self,
1170 mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
1171 driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1172 cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1173 connection_no: usize,
1174 response_rules: Arc<Mutex<Vec<ResponseRule>>>,
1175 connection_close_signaler: ConnectionCloseSignaler,
1176 event_registered_flag: Arc<AtomicBool>,
1177 ) {
1178 let shard = self.shard;
1179 self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
1180 'mainloop: loop {
1181 match responses_rx.recv().await {
1182 Some(response) => {
1183 let ctx = EvaluationContext {
1184 connection_seq_no: connection_no,
1185 opcode: FrameOpcode::Response(response.opcode),
1186 frame_body: response.body.clone(),
1187 connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1188 };
1189 let mut guard = response_rules.lock().unwrap();
1190 '_ruleloop: for (i, response_rule) in guard.iter_mut().enumerate() {
1191 if response_rule.0.eval(&ctx) {
1192 info!("Applying rule no={} to request ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard));
1193 debug!("-> Applied rule: {:?}", response_rule);
1194 debug!("-> To response: {:?}", ctx.opcode);
1195 trace!("{:?}", response);
1196
1197 if let Some(ref tx) = response_rule.1.feedback_channel {
1198 tx.send((response.clone(), shard)).unwrap_or_else(|err| warn!(
1199 "Could not send received response as feedback: {}", err
1200 ));
1201 }
1202
1203 let response_rule = response_rule.clone();
1204 let to_addressee_action = response_rule.1.to_addressee;
1205 let to_sender_action = response_rule.1.to_sender;
1206 let drop_connection_action = response_rule.1.drop_connection;
1207
1208 let response_clone = response.clone();
1209 let driver_tx_clone = driver_tx.clone();
1210 let pass_action = async move {
1211 if let Some(ref pass_action) = to_addressee_action {
1212 if let Some(time) = pass_action.delay {
1213 tokio::time::sleep(time).await;
1214 }
1215 let passed_frame = match pass_action.msg_processor {
1216 Some(ref processor) => processor(response_clone),
1217 None => response_clone,
1218 };
1219 let _ = driver_tx_clone.send(passed_frame);
1220 };
1221 };
1222
1223 let response_clone = response.clone();
1224 let cluster_tx_clone = cluster_tx.clone();
1225 let forge_action = async move {
1226 if let Some(ref forge_action) = to_sender_action {
1227 if let Some(time) = forge_action.delay {
1228 tokio::time::sleep(time).await;
1229 }
1230 let forged_frame = {
1231 let processor = forge_action.msg_processor.as_ref()
1232 .expect("Frame processor is required to forge a frame.");
1233 processor(response_clone)
1234 };
1235 let _ = cluster_tx_clone.send(forged_frame);
1236 };
1237 };
1238
1239 let connection_close_signaler_clone =
1240 connection_close_signaler.clone();
1241 let drop_action = async move {
1242 if let Some(ref delay) = drop_connection_action {
1243 if let Some(ref time) = delay {
1244 tokio::time::sleep(*time).await;
1245 }
1246 info!(
1248 "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1249 driver_addr,
1250 real_addr.expect("BUG: response rules are unavailable for dry-mode proxy!"),
1251 DisplayableShard(shard)
1252 );
1253 let _ = connection_close_signaler_clone.send(());
1254 }
1255 };
1256
1257 tokio::task::spawn(async {
1258 futures::join!(pass_action, forge_action, drop_action);
1259 });
1260
1261 continue 'mainloop;
1262 }
1263 }
1264 let _ = driver_tx.send(response); }
1266 None => return Ok(()),
1267 }
1268 }
1269 })
1270 .await
1271 }
1272}
1273
#[doc(hidden)]
/// Returns a loopback IPv4 address (`127.a.b.c`) not handed out by any earlier call
/// in this process.
///
/// A process-wide counter starting at 4242 supplies the lower three octets and is
/// incremented on every call. The panic message below indicates the pool is meant
/// for tests.
///
/// # Panics
/// Panics once the counter no longer fits in the 24 bits of the lower three octets.
pub fn get_exclusive_local_address() -> IpAddr {
    static ADDRESS_LOWER_THREE_OCTETS: AtomicU32 = AtomicU32::new(4242);

    let counter = ADDRESS_LOWER_THREE_OCTETS.fetch_add(1, Ordering::Relaxed);
    assert!(
        counter <= (u32::MAX >> 8),
        "Loopback address pool for tests depleted"
    );
    // Big-endian order puts the most significant byte first. After the check above,
    // that byte is always zero.
    let [_, second, third, fourth] = counter.to_be_bytes();
    IpAddr::V4(Ipv4Addr::new(127, second, third, fourth))
}
1292
1293#[cfg(test)]
1294mod tests {
1295 use super::*;
1296 use crate::frame::{read_frame, read_request_frame, FrameType};
1297 use crate::{
1298 setup_tracing, Condition, Reaction as _, RequestReaction, ResponseOpcode, ResponseReaction,
1299 };
1300 use assert_matches::assert_matches;
1301 use bytes::{BufMut, BytesMut};
1302 use futures::future::{join, join3};
1303 use rand::RngCore;
1304 use scylla_cql::frame::frame_errors::FrameError;
1305 use scylla_cql::frame::types::write_string_multimap;
1306 use std::collections::HashMap;
1307 use std::mem;
1308 use std::str::FromStr;
1309 use std::time::Duration;
1310 use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
1311 use tokio::sync::oneshot;
1312
1313 fn random_body() -> Bytes {
1314 let body_len = (rand::random::<u32>() % 1000) as usize;
1315 let mut body = BytesMut::zeroed(body_len);
1316 rand::thread_rng().fill_bytes(body.as_mut());
1317 body.freeze()
1318 }
1319
1320 async fn respond_with_supported(
1321 conn: &mut TcpStream,
1322 supported_options: &HashMap<String, Vec<String>>,
1323 ) {
1324 let RequestFrame {
1325 params: recvd_params,
1326 opcode: recvd_opcode,
1327 body: recvd_body,
1328 } = read_request_frame(conn).await.unwrap();
1329 assert_eq!(recvd_params, HARDCODED_OPTIONS_PARAMS);
1330 assert_eq!(recvd_opcode, RequestOpcode::Options);
1331 assert_eq!(recvd_body, Bytes::new()); let mut body = BytesMut::new();
1334 write_string_multimap(supported_options, &mut body).unwrap();
1335
1336 let body = body.freeze();
1337
1338 write_frame(
1339 HARDCODED_OPTIONS_PARAMS.for_response(),
1340 FrameOpcode::Response(ResponseOpcode::Supported),
1341 &body,
1342 conn,
1343 )
1344 .await
1345 .unwrap();
1346 }
1347
1348 fn supported_shards_count(shards_count: u16) -> HashMap<String, Vec<String>> {
1349 let mut sharded_info = HashMap::new();
1350 sharded_info.insert(
1351 String::from("SCYLLA_NR_SHARDS"),
1352 vec![shards_count.to_string()],
1353 );
1354 sharded_info
1355 }
1356
1357 fn supported_shard_number(shard_num: TargetShard) -> HashMap<String, Vec<String>> {
1358 let mut sharded_info = HashMap::new();
1359 sharded_info.insert(String::from("SCYLLA_SHARD"), vec![shard_num.to_string()]);
1360 sharded_info
1361 }
1362
1363 async fn respond_with_shards_count(conn: &mut TcpStream, shards_count: u16) {
1364 respond_with_supported(conn, &supported_shards_count(shards_count)).await;
1365 }
1366
1367 async fn respond_with_shard_num(conn: &mut TcpStream, shard_num: TargetShard) {
1368 respond_with_supported(conn, &supported_shard_number(shard_num)).await;
1369 }
1370
1371 fn next_local_address_with_port(port: u16) -> SocketAddr {
1372 SocketAddr::new(get_exclusive_local_address(), port)
1373 }
1374
1375 async fn identity_proxy_does_not_mutate_frames(shard_awareness: ShardAwareness) {
1376 let node1_real_addr = next_local_address_with_port(9876);
1377 let node1_proxy_addr = next_local_address_with_port(9876);
1378 let proxy = Proxy::new([Node::new(
1379 node1_real_addr,
1380 node1_proxy_addr,
1381 shard_awareness,
1382 None,
1383 None,
1384 )]);
1385 let running_proxy = proxy.run().await.unwrap();
1386
1387 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1388
1389 let params = FrameParams {
1390 flags: 0,
1391 version: 0x04,
1392 stream: 0,
1393 };
1394 let opcode = FrameOpcode::Request(RequestOpcode::Options);
1395
1396 let body = random_body();
1397
1398 let send_frame_to_shard = async {
1399 let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
1400
1401 write_frame(params, opcode, &body, &mut conn).await.unwrap();
1402 conn
1403 };
1404
1405 let mock_node_action = async {
1406 if let ShardAwareness::QueryNode = shard_awareness {
1407 respond_with_shards_count(&mut mock_node_listener.accept().await.unwrap().0, 1)
1408 .await;
1409 }
1410 let (mut conn, _) = mock_node_listener.accept().await.unwrap();
1411 if shard_awareness.is_aware() {
1412 respond_with_shard_num(&mut conn, 1).await;
1413 }
1414 let RequestFrame {
1415 params: recvd_params,
1416 opcode: recvd_opcode,
1417 body: recvd_body,
1418 } = read_request_frame(&mut conn).await.unwrap();
1419 assert_eq!(recvd_params, params);
1420 assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1421 assert_eq!(recvd_body, body);
1422 conn
1423 };
1424
1425 let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1427 running_proxy.finish().await.unwrap();
1428 }
1429
1430 #[tokio::test]
1431 #[ntest::timeout(1000)]
1432 async fn identity_shard_unaware_proxy_does_not_mutate_frames() {
1433 setup_tracing();
1434 identity_proxy_does_not_mutate_frames(ShardAwareness::Unaware).await
1435 }
1436
1437 #[tokio::test]
1438 #[ntest::timeout(1000)]
1439 async fn identity_shard_aware_proxy_does_not_mutate_frames() {
1440 setup_tracing();
1441 identity_proxy_does_not_mutate_frames(ShardAwareness::QueryNode).await
1442 }
1443
    // Checks the proxy's choice of source port for a node with a fixed number of
    // shards. The proxy must connect to the node from a source port on the same shard
    // as the driver's source port, i.e. both ports congruent modulo the shard count.
    // Repeated for shard counts 1 through 5.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn shard_aware_proxy_is_transparent_for_connection_to_shards() {
        setup_tracing();
        // Runs the check once against a fresh proxy configured with `shards_num` shards.
        async fn test_for_shards_num(shards_num: u16) {
            let node1_real_addr = next_local_address_with_port(9876);
            let node1_proxy_addr = next_local_address_with_port(9876);
            let proxy = Proxy::new([Node::new(
                node1_real_addr,
                node1_proxy_addr,
                ShardAwareness::FixedNum(shards_num),
                None,
                None,
            )]);
            let running_proxy = proxy.run().await.unwrap();

            let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

            // Hands the driver's actual (OS-assigned) source address to the mock node.
            let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

            // Despite its name, this only connects; no frame is written.
            let send_frame_to_shard = async {
                let socket = TcpSocket::new_v4().unwrap();
                // Port 0 lets the OS pick an ephemeral source port.
                socket
                    .bind(SocketAddr::from_str("0.0.0.0:0").unwrap())
                    .unwrap();
                let conn = socket.connect(node1_proxy_addr).await.unwrap();
                driver_addr_tx.send(conn.local_addr().unwrap()).unwrap();
                conn
            };

            let mock_node_action = async {
                // `remote_addr` is the source address the proxy used to reach the node.
                let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                let driver_addr = driver_addr_rx.await.unwrap();
                assert_eq!(
                    driver_addr.port() % shards_num,
                    remote_addr.port() % shards_num
                );
                conn
            };

            // Keep both connections alive until the proxy has finished.
            let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
            running_proxy.finish().await.unwrap();
        }

        // NOTE: despite its name, `shard_num` is the shard *count* under test.
        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
1493
    // Checks that a `QueryNode` proxy learns the shard count from the node and then
    // connects to the node from a source port on the driver's shard. For every shard
    // count 1..=5 and every shard id below it, a fresh proxy is started. The mock node
    // answers the shard-count probe, and the driver is bound to a port chosen to land
    // on that shard.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn shard_aware_proxy_queries_shards_number() {
        setup_tracing();
        // Runs the check for every shard id of a node with `shards_num` shards.
        async fn test_for_shards_num(shards_num: u16) {
            for shard_num in 0..shards_num {
                let node1_real_addr = next_local_address_with_port(9876);
                let node1_proxy_addr = next_local_address_with_port(9876);
                let proxy = Proxy::new([Node::new(
                    node1_real_addr,
                    node1_proxy_addr,
                    ShardAwareness::QueryNode,
                    None,
                    None,
                )]);
                let running_proxy = proxy.run().await.unwrap();

                let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

                let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

                // `shards_num * 1234` is divisible by `shards_num`, so this port's
                // remainder modulo `shards_num` is exactly `shard_num`.
                let mock_driver_addr = next_local_address_with_port(shards_num * 1234 + shard_num);
                // Despite its name, this only binds and connects; no frame is written.
                let send_frame_to_shard = async {
                    let socket = TcpSocket::new_v4().unwrap();
                    socket
                        .bind(mock_driver_addr)
                        .unwrap_or_else(|_| panic!("driver_addr failed: {}", mock_driver_addr));
                    driver_addr_tx.send(socket.local_addr().unwrap()).unwrap();
                    socket.connect(node1_proxy_addr).await.unwrap()
                };

                let mock_node_action = async {
                    // First connection: the proxy's shard-count probe.
                    respond_with_shards_count(
                        &mut mock_node_listener.accept().await.unwrap().0,
                        shards_num,
                    )
                    .await;
                    // Second connection: the proxied one, checked for shard congruence.
                    let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                    let driver_addr = driver_addr_rx.await.unwrap();
                    assert_eq!(
                        driver_addr.port() % shards_num,
                        remote_addr.port() % shards_num
                    );
                    conn
                };

                let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
                running_proxy.finish().await.unwrap();
            }
        }

        // NOTE: despite its name, `shard_num` is the shard *count* under test.
        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
1549
    // Exercises request rules that forge responses on the node's behalf.
    //
    // The rules are evaluated in order:
    // - REGISTER requests get a forged EVENT response with body `test_msg`;
    // - requests whose body contains `this_shall_pass` pass through to the node
    //   untouched (noop);
    // - every other request gets a forged, empty READY response.
    //
    // The driver sends STARTUP, REGISTER and a passing EXECUTE. It must receive READY
    // and then EVENT, while the node must receive only the EXECUTE frame. After
    // `finish()`, both the driver and node connections must be closed (a read yields 0
    // bytes).
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn forger_proxy_forges_response() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                RequestRule(
                    Condition::True,
                    // Catch-all: answer everything else with an empty READY.
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        // STARTUP: caught by the catch-all rule -> forged READY.
        let params1 = FrameParams {
            flags: 3,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        // REGISTER: caught by the first rule -> forged EVENT.
        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        // EXECUTE whose body contains `this_shall_pass`: forwarded to the node.
        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn)
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn)
                .await
                .unwrap();
            write_frame(params3, opcode3, &body3, &mut conn)
                .await
                .unwrap();

            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params1.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Ready);
            assert_eq!(recvd_body, Bytes::new());

            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params2.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Event);
            assert_eq!(recvd_body, Bytes::from_static(test_msg));

            conn
        };

        // The node must see only the third (passing) frame.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params3);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode3);
            assert_eq!(recvd_body, body3);

            conn
        };

        let (mut node_conn, mut driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        running_proxy.finish().await.unwrap();

        // Finishing the proxy must close both sides (EOF -> Ok(0)).
        assert_matches!(driver_conn.read(&mut [0u8; 1]).await, Ok(0));
        assert_matches!(node_conn.read(&mut [0u8; 1]).await, Ok(0));
    }
1681
1682 #[tokio::test]
1683 #[ntest::timeout(1000)]
1684 async fn ad_hoc_rules_changing() {
1685 setup_tracing();
1686 let node1_real_addr = next_local_address_with_port(9876);
1687 let node1_proxy_addr = next_local_address_with_port(9876);
1688 let proxy = Proxy::new([Node::new(
1689 node1_real_addr,
1690 node1_proxy_addr,
1691 ShardAwareness::Unaware,
1692 None,
1693 None,
1694 )]);
1695 let mut running_proxy = proxy.run().await.unwrap();
1696
1697 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1698
1699 let params = FrameParams {
1700 flags: 0,
1701 version: 0x04,
1702 stream: 0,
1703 };
1704 let opcode = FrameOpcode::Request(RequestOpcode::Options);
1705
1706 let body = random_body();
1707
1708 let (mut driver, mut node) = {
1709 let results = join(
1710 TcpStream::connect(node1_proxy_addr),
1711 mock_node_listener.accept(),
1712 )
1713 .await;
1714 (results.0.unwrap(), results.1.unwrap().0)
1715 };
1716
1717 async fn request(
1718 driver: &mut TcpStream,
1719 node: &mut TcpStream,
1720 params: FrameParams,
1721 opcode: FrameOpcode,
1722 body: &Bytes,
1723 ) -> Result<RequestFrame, FrameError> {
1724 let (send_res, recv_res) = join(
1725 write_frame(params, opcode, &body.clone(), driver),
1726 read_request_frame(node),
1727 )
1728 .await;
1729 send_res.unwrap();
1730 recv_res
1731 }
1732 {
1733 let RequestFrame {
1735 params: recvd_params,
1736 opcode: recvd_opcode,
1737 body: recvd_body,
1738 } = request(&mut driver, &mut node, params, opcode, &body)
1739 .await
1740 .unwrap();
1741 assert_eq!(recvd_params, params);
1742 assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1743 assert_eq!(recvd_body, body);
1744 }
1745 running_proxy.running_nodes[0].change_request_rules(Some(vec![RequestRule(
1746 Condition::True,
1747 RequestReaction::drop_frame(),
1748 )]));
1749
1750 {
1751 tokio::select! {
1753 res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
1754 _ = tokio::time::sleep(std::time::Duration::from_millis(20)) => (),
1755 };
1756 }
1757
1758 running_proxy.turn_off_rules();
1759
1760 {
1761 let RequestFrame {
1763 params: recvd_params,
1764 opcode: recvd_opcode,
1765 body: recvd_body,
1766 } = request(&mut driver, &mut node, params, opcode, &body)
1767 .await
1768 .unwrap();
1769 assert_eq!(recvd_params, params);
1770 assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1771 assert_eq!(recvd_body, body);
1772 }
1773
1774 running_proxy.finish().await.unwrap();
1775 }
1776
1777 #[tokio::test]
1778 #[ntest::timeout(2000)]
1779 async fn limited_times_condition_expires() {
1780 setup_tracing();
1781 const FAILING_TRIES: usize = 4;
1782 const PASSING_TRIES: usize = 5;
1783
1784 let node1_real_addr = next_local_address_with_port(9876);
1785 let node1_proxy_addr = next_local_address_with_port(9876);
1786 let proxy = Proxy::new([Node::new(
1787 node1_real_addr,
1788 node1_proxy_addr,
1789 ShardAwareness::Unaware,
1790 Some(vec![
1791 RequestRule(
1792 Condition::not(Condition::TrueForLimitedTimes(
1794 FAILING_TRIES + PASSING_TRIES,
1795 )),
1796 RequestReaction::drop_frame(),
1797 ),
1798 RequestRule(
1799 Condition::not(Condition::TrueForLimitedTimes(FAILING_TRIES)),
1801 RequestReaction::noop(),
1802 ),
1803 RequestRule(
1804 Condition::True,
1806 RequestReaction::drop_frame(),
1807 ),
1808 ]),
1809 None,
1810 )]);
1811 let running_proxy = proxy.run().await.unwrap();
1812
1813 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1814
1815 let params = FrameParams {
1816 flags: 0,
1817 version: 0x04,
1818 stream: 0,
1819 };
1820 let opcode = FrameOpcode::Request(RequestOpcode::Options);
1821 let body = random_body();
1822
1823 let (mut driver, mut node) = {
1824 let results = join(
1825 TcpStream::connect(node1_proxy_addr),
1826 mock_node_listener.accept(),
1827 )
1828 .await;
1829 (results.0.unwrap(), results.1.unwrap().0)
1830 };
1831
1832 async fn request(
1833 driver: &mut TcpStream,
1834 node: &mut TcpStream,
1835 params: FrameParams,
1836 opcode: FrameOpcode,
1837 body: &Bytes,
1838 ) -> Result<RequestFrame, FrameError> {
1839 let (send_res, recv_res) = join(
1840 write_frame(params, opcode, &body.clone(), driver),
1841 read_request_frame(node),
1842 )
1843 .await;
1844 send_res.unwrap();
1845 recv_res
1846 }
1847
1848 for _ in 0..FAILING_TRIES {
1849 tokio::select! {
1850 res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
1851 _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
1852 };
1853 }
1854
1855 for _ in 0..PASSING_TRIES {
1856 let RequestFrame {
1857 params: recvd_params,
1858 opcode: recvd_opcode,
1859 body: recvd_body,
1860 } = request(&mut driver, &mut node, params, opcode, &body)
1861 .await
1862 .unwrap();
1863 assert_eq!(recvd_params, params);
1864 assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1865 assert_eq!(recvd_body, body);
1866 }
1867
1868 for _ in 0..3 {
1869 tokio::select! {
1871 res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
1872 _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
1873 };
1874 }
1875
1876 running_proxy.finish().await.unwrap();
1877 }
1878
    // Checks that rules configured with feedback channels report the frames they act
    // on. Both a request rule and a response rule drop every frame but send a copy to
    // their channel. The test then checks the copies match what the driver and node
    // sent.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn proxy_reports_requests_and_responses_as_feedback() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![RequestRule(
                Condition::True,
                RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
            )]),
            Some(vec![ResponseRule(
                Condition::True,
                ResponseReaction::drop_frame().with_feedback_when_performed(response_feedback_tx),
            )]),
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);

        // The same body is used for the request and the (unsolicited) response.
        let body = random_body();

        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
            write_frame(params, request_opcode, &body, &mut conn)
                .await
                .unwrap();
            conn
        };

        // The node writes a response without waiting for the (dropped) request.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            write_frame(params.for_response(), response_opcode, &body, &mut conn)
                .await
                .unwrap();
            conn
        };

        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
        assert_eq!(feedback_request.params, params);
        assert_eq!(
            FrameOpcode::Request(feedback_request.opcode),
            request_opcode
        );
        assert_eq!(feedback_request.body, body);
        let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
        assert_eq!(feedback_response.params, params.for_response());
        assert_eq!(
            FrameOpcode::Response(feedback_response.opcode),
            response_opcode
        );
        assert_eq!(feedback_response.body, body);

        running_proxy.finish().await.unwrap();
    }
1951
    // Checks error reporting around peer disconnects. Closing the driver side must
    // surface `WorkerError::DriverDisconnected`, and closing the node side must surface
    // `WorkerError::NodeDisconnected`, both via `wait_for_error()`. `sanity_check()` is
    // expected to succeed at each checkpoint. NOTE(review): presumably this is because
    // the reported errors were already consumed by `wait_for_error()` — confirm.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn sanity_check_reports_errors() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            None,
            None,
        )]);
        let mut running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        // Writes bytes that are not a valid frame.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            conn.write_all(b"uselessJunk").await.unwrap();
            conn
        };

        let mock_node_action = async {
            let (conn, _) = mock_node_listener.accept().await.unwrap();
            conn
        };

        let (node_conn, driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        running_proxy.sanity_check().unwrap();

        mem::drop(driver_conn);
        assert_matches!(
            running_proxy.wait_for_error().await,
            Some(ProxyError::Worker(WorkerError::DriverDisconnected(_)))
        );
        running_proxy.sanity_check().unwrap();

        mem::drop(node_conn);
        assert_matches!(
            running_proxy.wait_for_error().await,
            Some(ProxyError::Worker(WorkerError::NodeDisconnected(_)))
        );
        running_proxy.sanity_check().unwrap();

        // The result of finishing is deliberately ignored here.
        let _ = running_proxy.finish().await;
    }
2002
    // Checks that a delayed request does not hold up later ones. A rule delays only
    // the first request by `delay`, and the second request must reach the node within
    // `delay` of the test starting to send. The node reads a single frame, which must
    // be the second one.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn proxy_processes_requests_concurrently() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let delay = Duration::from_millis(30);

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![RequestRule(
                Condition::TrueForLimitedTimes(1),
                RequestReaction::delay(delay),
            )]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params1 = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Options);

        let body1 = random_body();

        let params2 = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        let body2 = random_body();

        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn)
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn)
                .await
                .unwrap();
            conn
        };

        // The first frame arrives only after the node side reads, so the first frame
        // observed must be the undelayed second request.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params2);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode2);
            assert_eq!(recvd_body, body2);
            conn
        };

        let (_node_conn, _driver_conn) =
            tokio::time::timeout(delay, join(mock_node_action, send_frame_to_shard))
                .await
                .expect("Request processing was not concurrent");
        running_proxy.finish().await.unwrap();
    }
2076
    // Checks that a dry-mode node (no real node behind it, no rules) accepts a
    // connection and a frame, and that the proxy still finishes cleanly afterwards.
    // Nothing is asserted about the frame itself.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn dry_mode_proxy_drops_incoming_frames() {
        setup_tracing();
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new_dry_mode(node1_proxy_addr, None)]);
        let running_proxy = proxy.run().await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);

        let body = random_body();

        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

        write_frame(params, opcode, &body, &mut conn).await.unwrap();
        // Give the proxy a moment to receive the frame before shutting down.
        tokio::time::sleep(Duration::from_millis(3)).await;
        running_proxy.finish().await.unwrap();
    }
2101
    // Dry-mode counterpart of the response-forging test; there is no real node here.
    //
    // The rules are evaluated in order:
    // - REGISTER requests get a forged EVENT response with body `test_msg`;
    // - requests whose body contains `this_shall_pass` match a noop rule (with no node
    //   behind the proxy, no response is expected for them);
    // - every other request gets a forged, empty READY response.
    //
    // The driver must receive READY and then EVENT. After `finish()`, its connection
    // must be closed (a read yields 0 bytes).
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn dry_mode_forger_proxy_forges_response() {
        setup_tracing();
        let node1_proxy_addr = next_local_address_with_port(9876);

        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new_dry_mode(
            node1_proxy_addr,
            Some(vec![
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                RequestRule(
                    Condition::True,
                    // Catch-all: answer everything else with an empty READY.
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
        )]);
        let running_proxy = proxy.run().await.unwrap();

        // STARTUP: caught by the catch-all rule -> forged READY.
        let params1 = FrameParams {
            flags: 3,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        // REGISTER: caught by the first rule -> forged EVENT.
        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        // EXECUTE whose body contains `this_shall_pass`: matches the noop rule.
        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

        write_frame(params1, opcode1, &body1, &mut conn)
            .await
            .unwrap();
        write_frame(params2, opcode2, &body2, &mut conn)
            .await
            .unwrap();
        write_frame(params3, opcode3, &body3, &mut conn)
            .await
            .unwrap();

        let ResponseFrame {
            params: recvd_params,
            opcode: recvd_opcode,
            body: recvd_body,
        } = read_response_frame(&mut conn).await.unwrap();
        assert_eq!(recvd_params, params1.for_response());
        assert_eq!(recvd_opcode, ResponseOpcode::Ready);
        assert_eq!(recvd_body, Bytes::new());

        let ResponseFrame {
            params: recvd_params,
            opcode: recvd_opcode,
            body: recvd_body,
        } = read_response_frame(&mut conn).await.unwrap();
        assert_eq!(recvd_params, params2.for_response());
        assert_eq!(recvd_opcode, ResponseOpcode::Event);
        assert_eq!(recvd_body, Bytes::from_static(test_msg));

        running_proxy.finish().await.unwrap();

        // Finishing the proxy must close the driver connection (EOF -> Ok(0)).
        assert_matches!(conn.read(&mut [0u8; 1]).await, Ok(0));
    }
2206
2207 #[tokio::test]
2211 #[ntest::timeout(1000)]
2212 async fn proxy_reports_target_shard_as_feedback() {
2213 setup_tracing();
2214
2215 let node_port = 10101;
2216 let node_real_addr = next_local_address_with_port(node_port);
2217 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
2218
2219 let params = FrameParams {
2220 flags: 0,
2221 version: 0x04,
2222 stream: 0,
2223 };
2224 let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2225 let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2226
2227 let body = random_body();
2228
2229 for shards_count in 2..9 {
2230 let driver1_shard = shards_count - 1;
2232 let driver2_shard = shards_count - 2;
2233 let node_proxy_addr = next_local_address_with_port(node_port);
2234
2235 let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2236 let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2237
2238 let proxy = Proxy::new([Node::new(
2239 node_real_addr,
2240 node_proxy_addr,
2241 ShardAwareness::FixedNum(shards_count),
2242 Some(vec![RequestRule(
2243 Condition::True,
2244 RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2245 )]),
2246 Some(vec![ResponseRule(
2247 Condition::True,
2248 ResponseReaction::drop_frame()
2249 .with_feedback_when_performed(response_feedback_tx),
2250 )]),
2251 )]);
2252 let running_proxy = proxy.run().await.unwrap();
2253
2254 fn draw_source_port_for_shard(shards_count: u16, shard: u16) -> u16 {
2256 assert!(shard < shards_count);
2257 (49152 + shards_count - 1) / shards_count * shards_count + shard
2258 }
2259
2260 async fn bind_socket_for_shard(shards_count: u16, shard: u16) -> TcpSocket {
2261 let socket = TcpSocket::new_v4().unwrap();
2262 let initial_port = draw_source_port_for_shard(shards_count, shard);
2263
2264 let mut desired_addr =
2265 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), initial_port);
2266 while socket.bind(desired_addr).is_err() {
2267 let next_port = desired_addr.port().wrapping_add(shards_count);
2269 if next_port == initial_port {
2270 panic!("No more ports left");
2271 }
2272 desired_addr.set_port(next_port);
2273 }
2274
2275 socket
2276 }
2277
2278 let body_ref = &body;
2279 let send_frame_to_shard = |driver_shard: u16| async move {
2280 let socket = bind_socket_for_shard(shards_count, driver_shard).await;
2281 let mut conn = socket.connect(node_proxy_addr).await.unwrap();
2282
2283 write_frame(params, request_opcode, body_ref, &mut conn)
2284 .await
2285 .unwrap();
2286 conn
2287 };
2288
2289 let mock_driver1_action = send_frame_to_shard(driver1_shard);
2290 let mock_driver2_action = send_frame_to_shard(driver2_shard);
2291
2292 let mock_node_action = async {
2294 let mut conns_futs = (0..2)
2295 .map(|_| async {
2296 let (mut conn, driver_addr) = mock_node_listener.accept().await.unwrap();
2297 respond_with_shard_num(&mut conn, driver_addr.port() % shards_count).await;
2298 write_frame(params.for_response(), response_opcode, body_ref, &mut conn)
2299 .await
2300 .unwrap();
2301 conn
2302 })
2303 .collect::<Vec<_>>();
2304 let conn2 = conns_futs.pop().unwrap().await;
2305 let conn1 = conns_futs.pop().unwrap().await;
2306 (conn1, conn2)
2307 };
2308
2309 let (_node_conns, _driver1_conn, _driver2_conn) =
2311 join3(mock_node_action, mock_driver1_action, mock_driver2_action).await;
2312
2313 let assert_feedback_request = |feedback_request: RequestFrame| {
2314 assert_eq!(feedback_request.params, params);
2315 assert_eq!(
2316 FrameOpcode::Request(feedback_request.opcode),
2317 request_opcode
2318 );
2319 assert_eq!(feedback_request.body, body);
2320 };
2321
2322 let assert_feedback_response = |feedback_response: ResponseFrame| {
2323 assert_eq!(feedback_response.params, params.for_response());
2324 assert_eq!(
2325 FrameOpcode::Response(feedback_response.opcode),
2326 response_opcode
2327 );
2328 assert_eq!(feedback_response.body, body);
2329 };
2330
2331 let (feedback_request, shard1) = request_feedback_rx.recv().await.unwrap();
2332 assert_feedback_request(feedback_request);
2333 let (feedback_request, shard2) = request_feedback_rx.recv().await.unwrap();
2334 assert_feedback_request(feedback_request);
2335 let (feedback_response, shard3) = response_feedback_rx.recv().await.unwrap();
2336 assert_feedback_response(feedback_response);
2337 let (feedback_response, shard4) = response_feedback_rx.recv().await.unwrap();
2338 assert_feedback_response(feedback_response);
2339
2340 let mut expected_shards = [driver1_shard, driver1_shard, driver2_shard, driver2_shard];
2342 expected_shards.sort_unstable();
2343
2344 let mut got_shards = [
2345 shard1.unwrap(),
2346 shard2.unwrap(),
2347 shard3.unwrap(),
2348 shard4.unwrap(),
2349 ];
2350 got_shards.sort_unstable();
2351
2352 assert_eq!(expected_shards, got_shards);
2353
2354 running_proxy.finish().await.unwrap();
2355 }
2356 }
2357
2358 #[tokio::test]
2359 #[ntest::timeout(1000)]
2360 async fn proxy_ignores_control_connection_messages() {
2361 setup_tracing();
2362 let node1_real_addr = next_local_address_with_port(9876);
2363 let node1_proxy_addr = next_local_address_with_port(9876);
2364
2365 let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2366 let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2367 let proxy = Proxy::new([Node::new(
2368 node1_real_addr,
2369 node1_proxy_addr,
2370 ShardAwareness::Unaware,
2371 Some(vec![RequestRule(
2372 Condition::not(Condition::ConnectionRegisteredAnyEvent),
2373 RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
2374 )]),
2375 Some(vec![ResponseRule(
2376 Condition::not(Condition::ConnectionRegisteredAnyEvent),
2377 ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
2378 )]),
2379 )]);
2380 let running_proxy = proxy.run().await.unwrap();
2381
2382 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2383
2384 let (mut client_socket, mut server_socket) = join(
2385 async { TcpStream::connect(node1_proxy_addr).await.unwrap() },
2386 async { mock_node_listener.accept().await.unwrap().0 },
2387 )
2388 .await;
2389
2390 async fn perform_reqest_response<'a>(
2391 req_opcode: RequestOpcode,
2392 resp_opcode: ResponseOpcode,
2393 client_socket_ref: &'a mut TcpStream,
2394 server_socket_ref: &'a mut TcpStream,
2395 body_base: &'a str,
2396 ) {
2397 let params = FrameParams {
2398 flags: 0,
2399 version: 0x04,
2400 stream: 0,
2401 };
2402
2403 write_frame(
2404 params,
2405 FrameOpcode::Request(req_opcode),
2406 &(body_base.to_string() + "|request|").into(),
2407 client_socket_ref,
2408 )
2409 .await
2410 .unwrap();
2411
2412 let received_request = read_frame(server_socket_ref, FrameType::Request)
2413 .await
2414 .unwrap();
2415 assert_eq!(received_request.1, FrameOpcode::Request(req_opcode));
2416
2417 write_frame(
2418 params.for_response(),
2419 FrameOpcode::Response(resp_opcode),
2420 &(body_base.to_string() + "|response|").into(),
2421 server_socket_ref,
2422 )
2423 .await
2424 .unwrap();
2425
2426 let received_response = read_frame(client_socket_ref, FrameType::Response)
2427 .await
2428 .unwrap();
2429 assert_eq!(received_response.1, FrameOpcode::Response(resp_opcode));
2430 }
2431
2432 for i in 0..5 {
2434 perform_reqest_response(
2435 RequestOpcode::Query,
2436 ResponseOpcode::Result,
2437 &mut client_socket,
2438 &mut server_socket,
2439 &format!("message_before_{i}"),
2440 )
2441 .await
2442 }
2443
2444 perform_reqest_response(
2445 RequestOpcode::Register,
2446 ResponseOpcode::Result,
2447 &mut client_socket,
2448 &mut server_socket,
2449 "message_register",
2450 )
2451 .await;
2452
2453 for i in 0..5 {
2455 perform_reqest_response(
2456 RequestOpcode::Query,
2457 ResponseOpcode::Result,
2458 &mut client_socket,
2459 &mut server_socket,
2460 &format!("message_after_{i}"),
2461 )
2462 .await
2463 }
2464
2465 running_proxy.finish().await.unwrap();
2466
2467 for _ in 0..5 {
2468 let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
2469 assert_eq!(feedback_request.opcode, RequestOpcode::Query);
2470 let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
2471 assert_eq!(feedback_response.opcode, ResponseOpcode::Result);
2472 }
2473
2474 let _ = request_feedback_rx.try_recv().unwrap_err();
2476 let _ = response_feedback_rx.try_recv().unwrap_err();
2477 }
2478}