1#[allow(clippy::disallowed_types)] use std::collections::HashMap;
5
/// How a source's output channel is provisioned, derived from the number of
/// materialized views (MVs) that read from the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-producer, single-consumer channel. Chosen when at most one MV
    /// references the source (including when none do).
    Spsc,

    /// Fan-out channel shared by several consumers. Chosen when two or more
    /// MVs reference the source.
    Broadcast {
        /// Number of MV references to the source.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Returns `true` only for [`DerivedChannelType::Broadcast`].
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match self {
            Self::Spsc => false,
            Self::Broadcast { .. } => true,
        }
    }

    /// Number of consumers this channel is provisioned for.
    ///
    /// [`DerivedChannelType::Spsc`] always reports `1`, even for a source that
    /// no MV references; `Broadcast` reports its stored `consumer_count`.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        if let Self::Broadcast { consumer_count } = *self {
            consumer_count
        } else {
            1
        }
    }
}
44
/// A source as declared to the channel-derivation pass.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Name that MV `source_refs` entries are matched against.
    pub name: String,
    /// Watermark column, if one was supplied via [`SourceDefinition::with_watermark`].
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Creates a source with no watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            watermark_column: None,
        }
    }

    /// Creates a source that records `watermark_column` as its watermark column.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        // Build the base first so `name` is converted before `watermark_column`,
        // matching the argument order.
        let base = Self::new(name);
        Self {
            watermark_column: Some(watermark_column.into()),
            ..base
        }
    }
}
75
/// A materialized view together with the names of the sources it reads.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the materialized view.
    pub name: String,
    /// Source names this MV reads. During channel derivation every entry
    /// counts as one consumer of the named source.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Creates an MV that reads every source listed in `source_refs`.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name = name.into();
        Self { name, source_refs }
    }

    /// Creates an MV that reads exactly one source.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        let name = name.into();
        Self::new(name, vec![source.into()])
    }
}
106
107#[must_use]
152pub fn derive_channel_types(
153 sources: &[SourceDefinition],
154 mvs: &[MvDefinition],
155) -> HashMap<String, DerivedChannelType> {
156 let consumer_counts = count_consumers_per_source(mvs);
157
158 sources
159 .iter()
160 .map(|source| {
161 let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
162 let channel_type = if count <= 1 {
163 DerivedChannelType::Spsc
164 } else {
165 DerivedChannelType::Broadcast {
166 consumer_count: count,
167 }
168 };
169 (source.name.clone(), channel_type)
170 })
171 .collect()
172}
173
174fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
176 let mut counts: HashMap<String, usize> = HashMap::with_capacity(mvs.len());
177
178 for mv in mvs {
179 for source_ref in &mv.source_refs {
180 *counts.entry(source_ref.clone()).or_insert(0) += 1;
181 }
182 }
183
184 counts
185}
186
187#[must_use]
202pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
203 MvDefinition::new(
204 mv_name.to_string(),
205 source_tables.iter().map(|s| (*s).to_string()).collect(),
206 )
207}
208
209#[derive(Debug, Clone)]
211pub struct ChannelDerivationResult {
212 pub channel_types: HashMap<String, DerivedChannelType>,
214 pub orphaned_sources: Vec<String>,
216 pub broadcast_count: usize,
218 pub spsc_count: usize,
220}
221
222#[must_use]
227pub fn derive_channel_types_detailed(
228 sources: &[SourceDefinition],
229 mvs: &[MvDefinition],
230) -> ChannelDerivationResult {
231 let channel_types = derive_channel_types(sources, mvs);
232
233 let orphaned_sources: Vec<String> = channel_types
234 .iter()
235 .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
236 .filter(|(name, _)| {
237 !mvs.iter().any(|mv| mv.source_refs.contains(*name))
239 })
240 .map(|(name, _)| name.clone())
241 .collect();
242
243 let broadcast_count = channel_types
244 .values()
245 .filter(|ct| ct.is_broadcast())
246 .count();
247
248 let spsc_count = channel_types.len() - broadcast_count;
249
250 ChannelDerivationResult {
251 channel_types,
252 orphaned_sources,
253 broadcast_count,
254 spsc_count,
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_single_consumer_spsc() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![MvDefinition::from_source("vwap", "trades")];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Two MVs read `trades`.
        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
        // Only one MV reads `orders`.
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let sources = vec![SourceDefinition::new("orphan")];
        let mvs: Vec<MvDefinition> = vec![];

        let channel_types = derive_channel_types(&sources, &mvs);

        // An unreferenced source still gets a (Spsc) channel.
        assert_eq!(channel_types.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = vec![
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let mvs = vec![MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        )];

        let channel_types = derive_channel_types(&sources, &mvs);

        // One MV reading two sources is a single consumer of each.
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
        assert_eq!(
            channel_types.get("payments"),
            Some(&DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derive_ignores_references_to_undeclared_sources() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![MvDefinition::new(
            "mv",
            vec!["trades".into(), "missing".into()],
        )];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.len(), 1);
        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
        assert!(!channel_types.contains_key("missing"));
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let broadcast = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert!(broadcast.is_broadcast());
        assert_eq!(broadcast.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let source = SourceDefinition::new("trades");
        assert_eq!(source.name, "trades");
        assert!(source.watermark_column.is_none());

        let source_wm = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(source_wm.name, "trades");
        assert_eq!(source_wm.watermark_column, Some("event_time".to_string()));
    }

    #[test]
    fn test_mv_definition() {
        let mv = MvDefinition::from_source("vwap", "trades");
        assert_eq!(mv.name, "vwap");
        assert_eq!(mv.source_refs, vec!["trades"]);

        let mv_multi = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(mv_multi.name, "join_result");
        assert_eq!(mv_multi.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let mv = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(mv.name, "my_mv");
        assert_eq!(mv.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
            SourceDefinition::new("orphan"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let result = derive_channel_types_detailed(&sources, &mvs);

        // `trades` is broadcast; `orders` and `orphan` are SPSC.
        assert_eq!(result.broadcast_count, 1);
        assert_eq!(result.spsc_count, 2);
        assert_eq!(result.channel_types.len(), 3);
        // `orders` has a consumer, so only `orphan` is reported.
        assert_eq!(result.orphaned_sources, vec!["orphan".to_string()]);
    }

    #[test]
    fn test_detailed_reports_every_orphan() {
        let sources = vec![
            SourceDefinition::new("zeta"),
            SourceDefinition::new("used"),
            SourceDefinition::new("alpha"),
        ];
        let mvs = vec![MvDefinition::from_source("mv", "used")];

        let result = derive_channel_types_detailed(&sources, &mvs);

        // Sorted before comparing so the assertion does not depend on ordering.
        let mut orphans = result.orphaned_sources.clone();
        orphans.sort();
        assert_eq!(orphans, vec!["alpha".to_string(), "zeta".to_string()]);
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 3);
    }

    #[test]
    fn test_three_consumers() {
        let sources = vec![SourceDefinition::new("events")];
        let mvs = vec![
            MvDefinition::from_source("mv1", "events"),
            MvDefinition::from_source("mv2", "events"),
            MvDefinition::from_source("mv3", "events"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("events"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 3 })
        );
    }
}