1use std::{fmt, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use rand::{Rng, RngCore};
5use tokio::sync::mpsc;
6
7#[cfg(test)]
8use crate::setup_tracing;
9
10use crate::{
11 TargetShard,
12 frame::{FrameOpcode, FrameParams, RequestFrame, RequestOpcode, ResponseFrame, ResponseOpcode},
13};
14use scylla_cql::frame::response::error::DbError;
15
/// Predicate over a single frame, evaluated with [`Condition::eval`].
///
/// Conditions form a tree: the logical variants (`Not`, `And`, `Or`) combine
/// other conditions, the remaining variants are leaves that inspect the
/// evaluation context of the frame.
#[derive(Debug, Clone)]
pub enum Condition {
    /// Always true.
    True,

    /// Always false.
    False,

    /// Logical negation of the inner condition.
    Not(Box<Condition>),

    /// Logical conjunction; the second condition is evaluated only if the first is true.
    And(Box<Condition>, Box<Condition>),

    /// Logical disjunction; the second condition is evaluated only if the first is false.
    Or(Box<Condition>, Box<Condition>),

    /// True iff the connection sequence number of the evaluated frame equals the given number.
    ConnectionSeqNo(usize),

    /// True iff the frame is a request with the given opcode.
    /// Evaluating this variant against a response frame panics.
    RequestOpcode(RequestOpcode),

    /// True iff the frame is a response with the given opcode.
    /// Evaluating this variant against a request frame panics.
    ResponseOpcode(ResponseOpcode),

    /// Matches when the frame body contains the given bytes, compared exactly.
    BodyContainsCaseSensitive(Box<[u8]>),

    /// Matches when the frame body contains the given bytes, compared ignoring
    /// ASCII letter case; see [`Condition::eval`] for the exact matching rules.
    BodyContainsCaseInsensitive(Box<[u8]>),

    /// True with the given probability, drawn anew on every evaluation.
    RandomWithProbability(f64),

    /// True for the given number of evaluations, then always false.
    /// The remaining budget is stored in the variant and decremented on each `true`.
    TrueForLimitedTimes(usize),

    /// True iff the evaluation context reports `connection_has_events` for the connection.
    ConnectionRegisteredAnyEvent,
}
55
/// Per-frame data that a [`Condition`] is evaluated against.
pub(crate) struct EvaluationContext {
    /// Sequence number of the connection; compared by `Condition::ConnectionSeqNo`.
    pub(crate) connection_seq_no: usize,
    /// Flag read by `Condition::ConnectionRegisteredAnyEvent`.
    // NOTE(review): presumably set once the connection registered for events — confirm at the call site.
    pub(crate) connection_has_events: bool,
    /// Opcode of the frame, tagged as either a request or a response opcode.
    pub(crate) opcode: FrameOpcode,
    /// Body of the frame, searched by the `BodyContains*` conditions.
    pub(crate) frame_body: Bytes,
}
63
64impl Condition {
65 pub(crate) fn eval(&mut self, ctx: &EvaluationContext) -> bool {
66 match self {
67 Condition::True => true,
68
69 Condition::False => false,
70
71 Condition::Not(c) => !c.eval(ctx),
72
73 Condition::And(c1, c2) => c1.eval(ctx) && c2.eval(ctx),
74
75 Condition::Or(c1, c2) => c1.eval(ctx) || c2.eval(ctx),
76
77 Condition::ConnectionSeqNo(no) => *no == ctx.connection_seq_no,
78
79 Condition::RequestOpcode(op1) => match ctx.opcode {
80 FrameOpcode::Request(op2) => *op1 == op2,
81 FrameOpcode::Response(_) => panic!(
82 "Invalid type applied in rule condition: driver request opcode in cluster context"
83 ),
84 },
85
86 Condition::ResponseOpcode(op1) => match ctx.opcode {
87 FrameOpcode::Request(_) => panic!(
88 "Invalid type applied in rule condition: cluster response opcode in driver context"
89 ),
90 FrameOpcode::Response(op2) => *op1 == op2,
91 },
92
93 Condition::BodyContainsCaseSensitive(pattern) => ctx
94 .frame_body
95 .windows(pattern.len())
96 .any(|window| *window == **pattern),
97
98 Condition::BodyContainsCaseInsensitive(pattern) => std::str::from_utf8(pattern)
99 .map(|pattern_str| {
100 ctx.frame_body.windows(pattern.len()).any(|window| {
101 std::str::from_utf8(window)
102 .map(|window_str| str::eq_ignore_ascii_case(window_str, pattern_str))
103 .unwrap_or(false)
104 })
105 })
106 .unwrap_or(false),
107 Condition::RandomWithProbability(probability) => rand::rng().random_bool(*probability),
108
109 Condition::TrueForLimitedTimes(times) => {
110 let val = *times > 0;
111 if val {
112 *times -= 1;
113 }
114 val
115 }
116
117 Condition::ConnectionRegisteredAnyEvent => ctx.connection_has_events,
118 }
119 }
120
121 #[expect(clippy::should_implement_trait)]
123 pub fn not(c: Self) -> Self {
124 Condition::Not(Box::new(c))
125 }
126
127 pub fn and(self, c2: Self) -> Self {
129 Self::And(Box::new(self), Box::new(c2))
130 }
131
132 pub fn or(self, c2: Self) -> Self {
134 Self::Or(Box::new(self), Box::new(c2))
135 }
136
137 pub fn all(cs: impl IntoIterator<Item = Self>) -> Self {
139 let mut cs = cs.into_iter();
140 match cs.next() {
141 None => Self::True, Some(mut c) => {
143 for head in cs {
144 c = head.and(c);
145 }
146 c
147 }
148 }
149 }
150
151 pub fn any(cs: impl IntoIterator<Item = Self>) -> Self {
153 let mut cs = cs.into_iter();
154 match cs.next() {
155 None => Self::False, Some(mut c) => {
157 for head in cs {
158 c = head.or(c);
159 }
160 c
161 }
162 }
163 }
164}
165
/// Common constructors for the reactions a rule can trigger on a frame.
///
/// Implemented in this file by [`RequestReaction`] and [`ResponseReaction`].
/// The descriptions below state which fields those implementations fill in.
pub trait Reaction: Sized {
    /// Type of the frame the reaction is applied to.
    type Incoming;
    /// Type of the frame that a forged reply consists of.
    type Returning;

    /// Reaction that only sets `to_addressee`, with no delay and no frame processor.
    fn noop() -> Self;

    /// Reaction with every field left empty: nothing is sent to either side.
    fn drop_frame() -> Self;

    /// Like [`Reaction::noop`], but the `to_addressee` action carries the delay `time`.
    fn delay(time: Duration) -> Self;

    /// Reaction that only sets `to_sender`, with `f` producing the returned frame
    /// from the incoming one and no delay.
    fn forge_response(f: Arc<dyn Fn(Self::Incoming) -> Self::Returning + Send + Sync>) -> Self;

    /// Like [`Reaction::forge_response`], but the `to_sender` action carries the delay `time`.
    fn forge_response_with_delay(
        time: Duration,
        f: Arc<dyn Fn(Self::Incoming) -> Self::Returning + Send + Sync>,
    ) -> Self;

    /// Reaction that only sets `to_addressee`, with `f` mapping the incoming frame
    /// to the frame of the same type and no delay.
    fn transform_frame(f: Arc<dyn Fn(Self::Incoming) -> Self::Incoming + Send + Sync>) -> Self;

    /// Reaction that only requests dropping the connection, without a delay
    /// (`drop_connection` set to `Some(None)`).
    fn drop_connection() -> Self;

    /// Reaction that only requests dropping the connection, with the delay `time`
    /// (`drop_connection` set to `Some(Some(time))`).
    fn drop_connection_with_delay(time: Duration) -> Self;

    /// Returns the same reaction with `tx` stored as its feedback channel,
    /// replacing any channel set before.
    // NOTE(review): presumably the proxy sends the frame (and target shard) over `tx`
    // when the reaction is performed — the sending side is not visible in this file.
    fn with_feedback_when_performed(
        self,
        tx: mpsc::UnboundedSender<(Self::Incoming, Option<TargetShard>)>,
    ) -> Self;
}
215
/// Shared `Debug` rendering for the reaction structs.
///
/// Writes a struct named `reaction_type` with the fields `to_addressee`,
/// `to_sender`, `drop_connection` and `feedback_channel`. The channel itself is
/// never printed; only a placeholder string telling whether one is present.
fn fmt_reaction(
    f: &mut std::fmt::Formatter<'_>,
    reaction_type: &str,
    to_addressee: &dyn fmt::Debug,
    to_sender: &dyn fmt::Debug,
    drop_connection: &dyn fmt::Debug,
    has_feedback_channel: bool,
) -> std::fmt::Result {
    // Rendered through `Debug` of `&str`, hence quoted in the output.
    let feedback_placeholder = if has_feedback_channel {
        "Some(<feedback_channel>)"
    } else {
        "None"
    };

    let mut builder = f.debug_struct(reaction_type);
    builder.field("to_addressee", to_addressee);
    builder.field("to_sender", to_sender);
    builder.field("drop_connection", drop_connection);
    builder.field("feedback_channel", &feedback_placeholder);
    builder.finish()
}
238
/// Reaction to a request frame, as carried by a [`RequestRule`].
#[derive(Clone)]
pub struct RequestReaction {
    /// Action producing the request frame passed on to the addressee; `None` sends nothing there.
    pub to_addressee: Option<Action<RequestFrame, RequestFrame>>,
    /// Action producing a response frame returned to the sender; `None` sends nothing back.
    pub to_sender: Option<Action<RequestFrame, ResponseFrame>>,
    /// `None`: no connection drop requested; `Some(None)`: drop without a delay;
    /// `Some(Some(d))`: drop with the delay `d`.
    pub drop_connection: Option<Option<Duration>>,
    /// Optional channel accepting `(request frame, target shard)` pairs as feedback.
    pub feedback_channel: Option<mpsc::UnboundedSender<(RequestFrame, Option<TargetShard>)>>,
}
246
247impl fmt::Debug for RequestReaction {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 fmt_reaction(
250 f,
251 "RequestReaction",
252 &self.to_addressee,
253 &self.to_sender,
254 &self.drop_connection,
255 self.feedback_channel.is_some(),
256 )
257 }
258}
259
/// Reaction to a response frame, as carried by a [`ResponseRule`].
#[derive(Clone)]
pub struct ResponseReaction {
    /// Action producing the response frame passed on to the addressee; `None` sends nothing there.
    pub to_addressee: Option<Action<ResponseFrame, ResponseFrame>>,
    /// Action producing a request frame returned to the sender; `None` sends nothing back.
    pub to_sender: Option<Action<ResponseFrame, RequestFrame>>,
    /// `None`: no connection drop requested; `Some(None)`: drop without a delay;
    /// `Some(Some(d))`: drop with the delay `d`.
    pub drop_connection: Option<Option<Duration>>,
    /// Optional channel accepting `(response frame, target shard)` pairs as feedback.
    pub feedback_channel: Option<mpsc::UnboundedSender<(ResponseFrame, Option<TargetShard>)>>,
}
267
268impl fmt::Debug for ResponseReaction {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 fmt_reaction(
271 f,
272 "ResponseReaction",
273 &self.to_addressee,
274 &self.to_sender,
275 &self.drop_connection,
276 self.feedback_channel.is_some(),
277 )
278 }
279}
280
281impl Reaction for RequestReaction {
282 type Incoming = RequestFrame;
283 type Returning = ResponseFrame;
284
285 fn noop() -> Self {
286 RequestReaction {
287 to_addressee: Some(Action {
288 delay: None,
289 msg_processor: None,
290 }),
291 to_sender: None,
292 drop_connection: None,
293 feedback_channel: None,
294 }
295 }
296
297 fn drop_frame() -> Self {
298 RequestReaction {
299 to_addressee: None,
300 to_sender: None,
301 drop_connection: None,
302 feedback_channel: None,
303 }
304 }
305
306 fn delay(time: Duration) -> Self {
307 RequestReaction {
308 to_addressee: Some(Action {
309 delay: Some(time),
310 msg_processor: None,
311 }),
312 to_sender: None,
313 drop_connection: None,
314 feedback_channel: None,
315 }
316 }
317
318 fn forge_response(f: Arc<dyn Fn(Self::Incoming) -> Self::Returning + Send + Sync>) -> Self {
319 RequestReaction {
320 to_addressee: None,
321 to_sender: Some(Action {
322 delay: None,
323 msg_processor: Some(f),
324 }),
325 drop_connection: None,
326 feedback_channel: None,
327 }
328 }
329
330 fn forge_response_with_delay(
331 time: Duration,
332 f: Arc<dyn Fn(Self::Incoming) -> Self::Returning + Send + Sync>,
333 ) -> Self {
334 RequestReaction {
335 to_addressee: None,
336 to_sender: Some(Action {
337 delay: Some(time),
338 msg_processor: Some(f),
339 }),
340 drop_connection: None,
341 feedback_channel: None,
342 }
343 }
344
345 fn transform_frame(f: Arc<dyn Fn(Self::Incoming) -> Self::Incoming + Send + Sync>) -> Self {
346 RequestReaction {
347 to_addressee: Some(Action {
348 delay: None,
349 msg_processor: Some(f),
350 }),
351 to_sender: None,
352 drop_connection: None,
353 feedback_channel: None,
354 }
355 }
356
357 fn drop_connection() -> Self {
358 RequestReaction {
359 to_addressee: None,
360 to_sender: None,
361 drop_connection: Some(None),
362 feedback_channel: None,
363 }
364 }
365
366 fn drop_connection_with_delay(time: Duration) -> Self {
367 RequestReaction {
368 to_addressee: None,
369 to_sender: None,
370 drop_connection: Some(Some(time)),
371 feedback_channel: None,
372 }
373 }
374
375 fn with_feedback_when_performed(
376 self,
377 tx: mpsc::UnboundedSender<(Self::Incoming, Option<TargetShard>)>,
378 ) -> Self {
379 Self {
380 feedback_channel: Some(tx),
381 ..self
382 }
383 }
384}
385
386impl RequestReaction {
387 pub fn forge_with_error_lazy(gen_error: Box<dyn Fn() -> DbError + Send + Sync>) -> Self {
388 Self::forge_with_error_lazy_delay(gen_error, None)
389 }
390 pub fn forge_with_error_lazy_delay(
393 gen_error: Box<dyn Fn() -> DbError + Send + Sync>,
394 delay: Option<Duration>,
395 ) -> Self {
396 RequestReaction {
397 to_addressee: None,
398 to_sender: Some(Action {
399 delay,
400 msg_processor: Some(Arc::new(move |request: RequestFrame| {
401 ResponseFrame::forged_error(request.params.for_response(), gen_error(), None)
402 .unwrap()
403 })),
404 }),
405 drop_connection: None,
406 feedback_channel: None,
407 }
408 }
409
410 pub fn forge_with_error(error: DbError) -> Self {
411 Self::forge_with_error_and_message(error, Some("Proxy-triggered error.".into()))
412 }
413
414 pub fn forge_with_error_and_message(error: DbError, msg: Option<String>) -> Self {
416 ResponseFrame::forged_error(
418 FrameParams {
419 version: 0,
420 flags: 0,
421 stream: 0,
422 },
423 error.clone(),
424 None,
425 )
426 .unwrap_or_else(|_| panic!("Invalid DbError provided: {error:#?}"));
427
428 RequestReaction {
429 to_addressee: None,
430 to_sender: Some(Action {
431 delay: None,
432 msg_processor: Some(Arc::new(move |request: RequestFrame| {
433 ResponseFrame::forged_error(
434 request.params.for_response(),
435 error.clone(),
436 msg.as_deref(),
437 )
438 .unwrap()
439 })),
440 }),
441 drop_connection: None,
442 feedback_channel: None,
443 }
444 }
445
446 pub fn forge() -> ResponseForger {
447 ResponseForger
448 }
449}
450
451pub mod example_db_errors {
452 use bytes::Bytes;
453 use scylla_cql::{
454 Consistency,
455 frame::response::error::{DbError, WriteType},
456 };
457
458 pub fn syntax_error() -> DbError {
459 DbError::SyntaxError
460 }
461 pub fn invalid() -> DbError {
462 DbError::Invalid
463 }
464 pub fn already_exists() -> DbError {
465 DbError::AlreadyExists {
466 keyspace: "proxy".into(),
467 table: "worker".into(),
468 }
469 }
470 pub fn function_failure() -> DbError {
471 DbError::FunctionFailure {
472 keyspace: "proxy".into(),
473 function: "fibonacci".into(),
474 arg_types: vec!["n".into()],
475 }
476 }
477 pub fn authentication_error() -> DbError {
478 DbError::AuthenticationError
479 }
480 pub fn unauthorized() -> DbError {
481 DbError::Unauthorized
482 }
483 pub fn config_error() -> DbError {
484 DbError::ConfigError
485 }
486 pub fn unavailable() -> DbError {
487 DbError::Unavailable {
488 consistency: Consistency::One,
489 required: 2,
490 alive: 1,
491 }
492 }
493 pub fn overloaded() -> DbError {
494 DbError::Overloaded
495 }
496 pub fn is_bootstrapping() -> DbError {
497 DbError::IsBootstrapping
498 }
499 pub fn truncate_error() -> DbError {
500 DbError::TruncateError
501 }
502 pub fn read_timeout() -> DbError {
503 DbError::ReadTimeout {
504 consistency: Consistency::One,
505 received: 2,
506 required: 3,
507 data_present: true,
508 }
509 }
510 pub fn write_timeout() -> DbError {
511 DbError::WriteTimeout {
512 consistency: Consistency::One,
513 received: 2,
514 required: 3,
515 write_type: WriteType::UnloggedBatch,
516 }
517 }
518 pub fn read_failure() -> DbError {
519 DbError::ReadFailure {
520 consistency: Consistency::One,
521 received: 2,
522 required: 3,
523 data_present: true,
524 numfailures: 1,
525 }
526 }
527 pub fn write_failure() -> DbError {
528 DbError::WriteFailure {
529 consistency: Consistency::One,
530 received: 2,
531 required: 3,
532 write_type: WriteType::UnloggedBatch,
533 numfailures: 1,
534 }
535 }
536 pub fn unprepared() -> DbError {
537 DbError::Unprepared {
538 statement_id: Bytes::from_static(b"21372137"),
539 }
540 }
541 pub fn server_error() -> DbError {
542 DbError::ServerError
543 }
544 pub fn protocol_error() -> DbError {
545 DbError::ProtocolError
546 }
547 pub fn other(num: i32) -> DbError {
548 DbError::Other(num)
549 }
550}
551
/// Stateless helper, obtained from `RequestReaction::forge()`, whose methods
/// build reactions that forge error responses.
pub struct ResponseForger;
553
554impl ResponseForger {
555 pub fn syntax_error(&self) -> RequestReaction {
556 RequestReaction::forge_with_error(example_db_errors::syntax_error())
557 }
558 pub fn invalid(&self) -> RequestReaction {
559 RequestReaction::forge_with_error(example_db_errors::invalid())
560 }
561 pub fn already_exists(&self) -> RequestReaction {
562 RequestReaction::forge_with_error(example_db_errors::already_exists())
563 }
564 pub fn function_failure(&self) -> RequestReaction {
565 RequestReaction::forge_with_error(example_db_errors::function_failure())
566 }
567 pub fn authentication_error(&self) -> RequestReaction {
568 RequestReaction::forge_with_error(example_db_errors::authentication_error())
569 }
570 pub fn unauthorized(&self) -> RequestReaction {
571 RequestReaction::forge_with_error(example_db_errors::unauthorized())
572 }
573 pub fn config_error(&self) -> RequestReaction {
574 RequestReaction::forge_with_error(example_db_errors::config_error())
575 }
576 pub fn unavailable(&self) -> RequestReaction {
577 RequestReaction::forge_with_error(example_db_errors::unavailable())
578 }
579 pub fn overloaded(&self) -> RequestReaction {
580 RequestReaction::forge_with_error(example_db_errors::overloaded())
581 }
582 pub fn is_bootstrapping(&self) -> RequestReaction {
583 RequestReaction::forge_with_error(example_db_errors::is_bootstrapping())
584 }
585 pub fn truncate_error(&self) -> RequestReaction {
586 RequestReaction::forge_with_error(example_db_errors::truncate_error())
587 }
588 pub fn read_timeout(&self) -> RequestReaction {
589 RequestReaction::forge_with_error(example_db_errors::read_timeout())
590 }
591 pub fn write_timeout(&self) -> RequestReaction {
592 RequestReaction::forge_with_error(example_db_errors::write_timeout())
593 }
594 pub fn read_failure(&self) -> RequestReaction {
595 RequestReaction::forge_with_error(example_db_errors::read_failure())
596 }
597 pub fn write_failure(&self) -> RequestReaction {
598 RequestReaction::forge_with_error(example_db_errors::write_failure())
599 }
600 pub fn unprepared(&self) -> RequestReaction {
601 RequestReaction::forge_with_error(example_db_errors::unprepared())
602 }
603 pub fn server_error(&self) -> RequestReaction {
604 RequestReaction::forge_with_error(example_db_errors::server_error())
605 }
606 pub fn protocol_error(&self) -> RequestReaction {
607 RequestReaction::forge_with_error(example_db_errors::protocol_error())
608 }
609 pub fn other(&self, num: i32) -> RequestReaction {
610 RequestReaction::forge_with_error(example_db_errors::other(num))
611 }
612 pub fn random_error(&self) -> RequestReaction {
613 self.random_error_with_delay(None)
614 }
615 pub fn random_error_with_delay(&self, delay: Option<Duration>) -> RequestReaction {
616 static ERRORS: &[fn() -> DbError] = &[
617 example_db_errors::invalid,
618 example_db_errors::already_exists,
619 example_db_errors::function_failure,
620 example_db_errors::authentication_error,
621 example_db_errors::unauthorized,
622 example_db_errors::config_error,
623 example_db_errors::unavailable,
624 example_db_errors::overloaded,
625 example_db_errors::is_bootstrapping,
626 example_db_errors::truncate_error,
627 example_db_errors::read_timeout,
628 example_db_errors::write_timeout,
629 example_db_errors::write_failure,
630 example_db_errors::unprepared,
631 example_db_errors::server_error,
632 example_db_errors::protocol_error,
633 || example_db_errors::other(2137),
634 ];
635 RequestReaction::forge_with_error_lazy_delay(
636 Box::new(|| ERRORS[rand::rng().next_u32() as usize % ERRORS.len()]()),
637 delay,
638 )
639 }
640}
641
642impl Reaction for ResponseReaction {
643 type Incoming = ResponseFrame;
644 type Returning = RequestFrame;
645
646 fn noop() -> Self {
647 ResponseReaction {
648 to_addressee: Some(Action {
649 delay: None,
650 msg_processor: None,
651 }),
652 to_sender: None,
653 drop_connection: None,
654 feedback_channel: None,
655 }
656 }
657
658 fn drop_frame() -> Self {
659 ResponseReaction {
660 to_addressee: None,
661 to_sender: None,
662 drop_connection: None,
663 feedback_channel: None,
664 }
665 }
666
667 fn delay(time: Duration) -> Self {
668 ResponseReaction {
669 to_addressee: Some(Action {
670 delay: Some(time),
671 msg_processor: None,
672 }),
673 to_sender: None,
674 drop_connection: None,
675 feedback_channel: None,
676 }
677 }
678
679 fn forge_response(f: Arc<dyn Fn(Self::Incoming) -> Self::Returning + Send + Sync>) -> Self {
680 ResponseReaction {
681 to_addressee: None,
682 to_sender: Some(Action {
683 delay: None,
684 msg_processor: Some(f),
685 }),
686 drop_connection: None,
687 feedback_channel: None,
688 }
689 }
690
691 fn forge_response_with_delay(
692 time: Duration,
693 f: Arc<dyn Fn(Self::Incoming) -> Self::Returning + Send + Sync>,
694 ) -> Self {
695 ResponseReaction {
696 to_addressee: None,
697 to_sender: Some(Action {
698 delay: Some(time),
699 msg_processor: Some(f),
700 }),
701 drop_connection: None,
702 feedback_channel: None,
703 }
704 }
705
706 fn transform_frame(f: Arc<dyn Fn(Self::Incoming) -> Self::Incoming + Send + Sync>) -> Self {
707 ResponseReaction {
708 to_addressee: Some(Action {
709 delay: None,
710 msg_processor: Some(f),
711 }),
712 to_sender: None,
713 drop_connection: None,
714 feedback_channel: None,
715 }
716 }
717
718 fn drop_connection() -> Self {
719 ResponseReaction {
720 to_addressee: None,
721 to_sender: None,
722 drop_connection: Some(None),
723 feedback_channel: None,
724 }
725 }
726
727 fn drop_connection_with_delay(time: Duration) -> Self {
728 ResponseReaction {
729 to_addressee: None,
730 to_sender: None,
731 drop_connection: Some(Some(time)),
732 feedback_channel: None,
733 }
734 }
735
736 fn with_feedback_when_performed(
737 self,
738 tx: mpsc::UnboundedSender<(Self::Incoming, Option<TargetShard>)>,
739 ) -> Self {
740 Self {
741 feedback_channel: Some(tx),
742 ..self
743 }
744 }
745}
746
/// One half of a reaction: an optional delay plus an optional function that
/// turns the incoming `TFrom` frame into the `TTo` frame.
#[derive(Clone)]
pub struct Action<TFrom, TTo> {
    /// Delay attached to the action, if any.
    pub delay: Option<Duration>,
    /// Function producing the outgoing frame; `None` means no processor is set.
    pub msg_processor: Option<Arc<dyn Fn(TFrom) -> TTo + Send + Sync>>,
}

/// Hand-written because the boxed closure has no `Debug` of its own: the
/// processor is shown only as a present/absent placeholder string.
impl<TFrom, TTo> std::fmt::Debug for Action<TFrom, TTo> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Rendered through `Debug` of `&str`, hence quoted in the output.
        let processor_placeholder = if self.msg_processor.is_some() {
            "Some(<closure>)"
        } else {
            "None"
        };

        let mut builder = f.debug_struct("Action");
        builder.field("delay", &self.delay);
        builder.field("msg_processor", &processor_placeholder);
        builder.finish()
    }
}
771
/// Rule for request frames: a [`Condition`] paired with the [`RequestReaction`]
/// associated with it.
// NOTE(review): presumably the reaction is performed when the condition evaluates
// to true — the code applying rules is not visible in this file.
#[derive(Clone, Debug)]
pub struct RequestRule(pub Condition, pub RequestReaction);
776
/// Rule for response frames: a [`Condition`] paired with the [`ResponseReaction`]
/// associated with it.
// NOTE(review): presumably the reaction is performed when the condition evaluates
// to true — the code applying rules is not visible in this file.
#[derive(Clone, Debug)]
pub struct ResponseRule(pub Condition, pub ResponseReaction);
781
/// A case-insensitive body condition matches text that differs only in ASCII
/// letter case, and rejects text that differs in any other character.
#[test]
fn condition_case_insensitive_matching() {
    setup_tracing();

    let ctx = EvaluationContext {
        connection_seq_no: 42,
        connection_has_events: false,
        opcode: FrameOpcode::Request(RequestOpcode::Options),
        frame_body: Bytes::from_static(b"\0\0x{0x223}Cassandra'sINEFFICIENCY\x12\x31"),
    };

    // Same text as in the body, up to letter case.
    let mut matching = Condition::BodyContainsCaseInsensitive(Box::from(
        b"cassandra'sInefficiency".as_slice(),
    ));
    // The apostrophe is missing, so no window of the body can match.
    let mut nonmatching = Condition::BodyContainsCaseInsensitive(Box::from(
        b"cassandrasInefficiency".as_slice(),
    ));

    assert!(matching.eval(&ctx));
    assert!(!nonmatching.eval(&ctx));
}