1#[allow(clippy::disallowed_types)] use std::collections::HashMap;
35
/// Channel flavour selected for a source, derived from how many MVs read from it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-producer/single-consumer channel; chosen when at most one MV consumes the source.
    Spsc,

    /// Fan-out channel delivering every record to several consumers.
    Broadcast {
        /// Number of consumers the broadcast channel has to serve.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Returns `true` only for the [`DerivedChannelType::Broadcast`] variant.
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match *self {
            Self::Spsc => false,
            Self::Broadcast { .. } => true,
        }
    }

    /// Number of consumers this channel is sized for.
    ///
    /// `Spsc` always reports `1` (even for a source nobody consumes); `Broadcast` reports the
    /// count it was built with.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        match *self {
            Self::Broadcast { consumer_count: n } => n,
            Self::Spsc => 1,
        }
    }
}
74
/// A source declared in the pipeline, identified by name.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Source name; MVs refer to the source by this exact string.
    pub name: String,
    /// Name of the watermark column, if one was declared.
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Creates a source definition without a watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            name,
            watermark_column: None,
        }
    }

    /// Creates a source definition whose watermark is taken from `watermark_column`.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        let mut definition = Self::new(name);
        definition.watermark_column = Some(watermark_column.into());
        definition
    }
}
105
/// An MV definition: its name plus the names of the sources it reads from.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the MV.
    pub name: String,
    /// Names of the sources this MV references, kept in the order the caller supplied.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Creates an MV definition from a name and an explicit list of source references.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name: String = name.into();
        Self { name, source_refs }
    }

    /// Creates an MV definition that references exactly one source.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        let name: String = name.into();
        let only_source: String = source.into();
        Self::new(name, vec![only_source])
    }
}
136
137#[must_use]
182pub fn derive_channel_types(
183 sources: &[SourceDefinition],
184 mvs: &[MvDefinition],
185) -> HashMap<String, DerivedChannelType> {
186 let consumer_counts = count_consumers_per_source(mvs);
187
188 sources
189 .iter()
190 .map(|source| {
191 let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
192 let channel_type = if count <= 1 {
193 DerivedChannelType::Spsc
194 } else {
195 DerivedChannelType::Broadcast {
196 consumer_count: count,
197 }
198 };
199 (source.name.clone(), channel_type)
200 })
201 .collect()
202}
203
204fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
206 let mut counts: HashMap<String, usize> = HashMap::with_capacity(mvs.len());
207
208 for mv in mvs {
209 for source_ref in &mv.source_refs {
210 *counts.entry(source_ref.clone()).or_insert(0) += 1;
211 }
212 }
213
214 counts
215}
216
217#[must_use]
232pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
233 MvDefinition::new(
234 mv_name.to_string(),
235 source_tables.iter().map(|s| (*s).to_string()).collect(),
236 )
237}
238
239#[derive(Debug, Clone)]
241pub struct ChannelDerivationResult {
242 pub channel_types: HashMap<String, DerivedChannelType>,
244 pub orphaned_sources: Vec<String>,
246 pub broadcast_count: usize,
248 pub spsc_count: usize,
250}
251
252#[must_use]
257pub fn derive_channel_types_detailed(
258 sources: &[SourceDefinition],
259 mvs: &[MvDefinition],
260) -> ChannelDerivationResult {
261 let channel_types = derive_channel_types(sources, mvs);
262
263 let orphaned_sources: Vec<String> = channel_types
264 .iter()
265 .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
266 .filter(|(name, _)| {
267 !mvs.iter().any(|mv| mv.source_refs.contains(*name))
269 })
270 .map(|(name, _)| name.clone())
271 .collect();
272
273 let broadcast_count = channel_types
274 .values()
275 .filter(|ct| ct.is_broadcast())
276 .count();
277
278 let spsc_count = channel_types.len() - broadcast_count;
279
280 ChannelDerivationResult {
281 channel_types,
282 orphaned_sources,
283 broadcast_count,
284 spsc_count,
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_single_consumer_spsc() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![MvDefinition::from_source("vwap", "trades")];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Two MVs read "trades" -> broadcast.
        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
        // One MV reads "orders" -> SPSC.
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let sources = vec![SourceDefinition::new("orphan")];
        let mvs: Vec<MvDefinition> = vec![];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Unconsumed sources still get an (SPSC) channel.
        assert_eq!(channel_types.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = vec![
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let mvs = vec![MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        )];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Each source has exactly one consumer (the join MV).
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
        assert_eq!(
            channel_types.get("payments"),
            Some(&DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derive_ignores_refs_to_undeclared_sources() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("ghost_mv", "ghost"),
            MvDefinition::from_source("vwap", "trades"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.len(), 1);
        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
        assert!(!channel_types.contains_key("ghost"));
    }

    #[test]
    fn test_derive_empty_inputs() {
        let channel_types = derive_channel_types(&[], &[]);
        assert!(channel_types.is_empty());

        let result = derive_channel_types_detailed(&[], &[]);
        assert!(result.channel_types.is_empty());
        assert!(result.orphaned_sources.is_empty());
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 0);
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let broadcast = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert!(broadcast.is_broadcast());
        assert_eq!(broadcast.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let source = SourceDefinition::new("trades");
        assert_eq!(source.name, "trades");
        assert!(source.watermark_column.is_none());

        let source_wm = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(source_wm.name, "trades");
        assert_eq!(source_wm.watermark_column, Some("event_time".to_string()));
    }

    #[test]
    fn test_mv_definition() {
        let mv = MvDefinition::from_source("vwap", "trades");
        assert_eq!(mv.name, "vwap");
        assert_eq!(mv.source_refs, vec!["trades"]);

        let mv_multi = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(mv_multi.name, "join_result");
        assert_eq!(mv_multi.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let mv = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(mv.name, "my_mv");
        assert_eq!(mv.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
            SourceDefinition::new("orphan"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let result = derive_channel_types_detailed(&sources, &mvs);

        // "trades" is broadcast; "orders" and "orphan" are SPSC.
        assert_eq!(result.broadcast_count, 1);
        assert_eq!(result.spsc_count, 2);
        assert_eq!(result.orphaned_sources, vec!["orphan".to_string()]);
    }

    #[test]
    fn test_detailed_no_orphans_when_all_sources_consumed() {
        let sources = vec![
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let mvs = vec![MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        )];

        let result = derive_channel_types_detailed(&sources, &mvs);

        assert!(result.orphaned_sources.is_empty());
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 2);
    }

    #[test]
    fn test_three_consumers() {
        let sources = vec![SourceDefinition::new("events")];
        let mvs = vec![
            MvDefinition::from_source("mv1", "events"),
            MvDefinition::from_source("mv2", "events"),
            MvDefinition::from_source("mv3", "events"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("events"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 3 })
        );
    }
}