1use crate::{
7 api::{ConfigEvaluationResult, GateEvaluationResult},
8 config::StatsigClientConfig,
9 error::Result,
10 transport::StatsigTransport,
11 user::User,
12};
13use std::collections::HashMap;
14use tokio::sync::{mpsc, oneshot};
15use tracing::{error, info};
16
17#[derive(Debug)]
19pub enum BatchRequest {
20 CheckGates {
21 gate_names: Vec<String>,
22 user: User,
23 response_tx: oneshot::Sender<Result<Vec<GateEvaluationResult>>>,
24 },
25 GetConfigs {
26 config_names: Vec<String>,
27 user: User,
28 response_tx: oneshot::Sender<Result<Vec<ConfigEvaluationResult>>>,
29 },
30}
31
32pub struct BatchProcessor {
34 receiver: mpsc::Receiver<BatchRequest>,
35 shutdown_rx: tokio::sync::broadcast::Receiver<()>,
36}
37
38impl BatchProcessor {
39 pub fn new(
41 receiver: mpsc::Receiver<BatchRequest>,
42 shutdown_rx: tokio::sync::broadcast::Receiver<()>,
43 ) -> Self {
44 Self {
45 receiver,
46 shutdown_rx,
47 }
48 }
49
50 pub async fn run(mut self, transport: StatsigTransport, config: StatsigClientConfig) {
52 let mut interval = tokio::time::interval(config.batch_flush_interval);
53 let mut gate_requests = Vec::new();
54 let mut config_requests = Vec::new();
55
56 loop {
57 tokio::select! {
58 Some(request) = self.receiver.recv() => {
59 match request {
60 BatchRequest::CheckGates { .. } => gate_requests.push(request),
61 BatchRequest::GetConfigs { .. } => config_requests.push(request),
62 }
63
64 if gate_requests.len() >= config.batch_size || config_requests.len() >= config.batch_size {
66 Self::process_gate_batch(&transport, &mut gate_requests).await;
67 Self::process_config_batch(&transport, &mut config_requests).await;
68 }
69 }
70 _ = interval.tick() => {
71 if !gate_requests.is_empty() {
72 Self::process_gate_batch(&transport, &mut gate_requests).await;
73 }
74 if !config_requests.is_empty() {
75 Self::process_config_batch(&transport, &mut config_requests).await;
76 }
77 }
78 _ = self.shutdown_rx.recv() => {
79 info!("Batch processor shutting down");
80 break;
81 }
82 }
83 }
84 }
85
86 async fn process_gate_batch(transport: &StatsigTransport, requests: &mut Vec<BatchRequest>) {
88 if requests.is_empty() {
89 return;
90 }
91
92 let batch = std::mem::take(requests);
93
94 let mut user_groups: HashMap<String, Vec<_>> = HashMap::new();
96 for request in batch {
97 if let BatchRequest::CheckGates { user, .. } = &request {
98 let user_hash = Self::hash_user_for_batch(user);
99 user_groups.entry(user_hash).or_default().push(request);
100 }
101 }
102
103 for (_user_hash, group_requests) in user_groups {
104 if let Some(first_request) = group_requests.first() {
105 if let BatchRequest::CheckGates { user, .. } = first_request {
106 let all_gate_names: Vec<String> = group_requests
107 .iter()
108 .filter_map(|req| {
109 if let BatchRequest::CheckGates { gate_names, .. } = req {
110 Some(gate_names.clone())
111 } else {
112 None
113 }
114 })
115 .flatten()
116 .collect();
117
118 match transport.check_gates(all_gate_names, user).await {
119 Ok(results) => {
120 for request in group_requests {
122 if let BatchRequest::CheckGates {
123 gate_names,
124 response_tx,
125 ..
126 } = request
127 {
128 let filtered_results: Vec<GateEvaluationResult> = results
129 .iter()
130 .filter(|result| gate_names.contains(&result.name))
131 .cloned()
132 .collect();
133 let _ = response_tx.send(Ok(filtered_results));
134 }
135 }
136 }
137 Err(e) => {
138 error!("Failed to fetch gates from API: {:?}", e);
139 for request in group_requests {
141 if let BatchRequest::CheckGates { response_tx, .. } = request {
142 let _ = response_tx.send(Err(e.clone()));
143 }
144 }
145 }
146 }
147 }
148 }
149 }
150 }
151
152 async fn process_config_batch(transport: &StatsigTransport, requests: &mut Vec<BatchRequest>) {
154 if requests.is_empty() {
155 return;
156 }
157
158 let batch = std::mem::take(requests);
159
160 for request in batch {
162 if let BatchRequest::GetConfigs {
163 config_names,
164 user,
165 response_tx,
166 } = request
167 {
168 let results = Self::fetch_configs_from_api(transport, &config_names, &user).await;
169 let _ = response_tx.send(results);
170 }
171 }
172 }
173
174 async fn fetch_configs_from_api(
176 transport: &StatsigTransport,
177 config_names: &[String],
178 user: &User,
179 ) -> Result<Vec<ConfigEvaluationResult>> {
180 let mut results = Vec::new();
181
182 for config_name in config_names {
183 let config_result = transport.get_config(config_name, user).await?;
184 results.push(config_result);
185 }
186
187 Ok(results)
188 }
189
190 fn hash_user_for_batch(user: &User) -> String {
192 user.hash_for_cache()
193 }
194}