1use std::collections::HashMap;
34
/// Channel flavor selected for a source, derived from how many materialized
/// views read from it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-producer, single-consumer channel; chosen when at most one view
    /// reads the source.
    Spsc,

    /// Fan-out channel; chosen when two or more views read the same source.
    Broadcast {
        /// Number of consuming views.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Returns `true` only for the [`Broadcast`](DerivedChannelType::Broadcast) variant.
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match self {
            Self::Spsc => false,
            Self::Broadcast { .. } => true,
        }
    }

    /// Returns the number of consumers this channel is sized for.
    ///
    /// [`Spsc`](DerivedChannelType::Spsc) always reports `1`, including for a
    /// source that no view consumes.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        if let Self::Broadcast { consumer_count } = *self {
            consumer_count
        } else {
            1
        }
    }
}
73
/// A source table as seen by channel derivation.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Source name; materialized views reference the source by this name.
    pub name: String,
    /// Optional watermark column name; `None` for sources built with
    /// [`SourceDefinition::new`].
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Creates a source definition without a watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            watermark_column: None,
            name,
        }
    }

    /// Creates a source definition with the given watermark column.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        // Convert `name` first, then the column, mirroring argument order.
        let mut source = Self::new(name);
        source.watermark_column = Some(watermark_column.into());
        source
    }
}
104
/// A materialized view together with the names of the sources it reads.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the materialized view.
    pub name: String,
    /// Source names this view reads, exactly as supplied by the caller.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Creates a view definition reading from every name in `source_refs`.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name = name.into();
        Self { name, source_refs }
    }

    /// Convenience constructor for a view that reads exactly one source.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        let name = name.into();
        let only_source = source.into();
        Self::new(name, vec![only_source])
    }
}
135
136#[must_use]
181pub fn derive_channel_types(
182 sources: &[SourceDefinition],
183 mvs: &[MvDefinition],
184) -> HashMap<String, DerivedChannelType> {
185 let consumer_counts = count_consumers_per_source(mvs);
186
187 sources
188 .iter()
189 .map(|source| {
190 let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
191 let channel_type = if count <= 1 {
192 DerivedChannelType::Spsc
193 } else {
194 DerivedChannelType::Broadcast {
195 consumer_count: count,
196 }
197 };
198 (source.name.clone(), channel_type)
199 })
200 .collect()
201}
202
203fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
205 let mut counts: HashMap<String, usize> = HashMap::new();
206
207 for mv in mvs {
208 for source_ref in &mv.source_refs {
209 *counts.entry(source_ref.clone()).or_insert(0) += 1;
210 }
211 }
212
213 counts
214}
215
216#[must_use]
231pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
232 MvDefinition::new(
233 mv_name.to_string(),
234 source_tables.iter().map(|s| (*s).to_string()).collect(),
235 )
236}
237
238#[derive(Debug, Clone)]
240pub struct ChannelDerivationResult {
241 pub channel_types: HashMap<String, DerivedChannelType>,
243 pub orphaned_sources: Vec<String>,
245 pub broadcast_count: usize,
247 pub spsc_count: usize,
249}
250
251#[must_use]
256pub fn derive_channel_types_detailed(
257 sources: &[SourceDefinition],
258 mvs: &[MvDefinition],
259) -> ChannelDerivationResult {
260 let channel_types = derive_channel_types(sources, mvs);
261
262 let orphaned_sources: Vec<String> = channel_types
263 .iter()
264 .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
265 .filter(|(name, _)| {
266 !mvs.iter().any(|mv| mv.source_refs.contains(*name))
268 })
269 .map(|(name, _)| name.clone())
270 .collect();
271
272 let broadcast_count = channel_types
273 .values()
274 .filter(|ct| ct.is_broadcast())
275 .count();
276
277 let spsc_count = channel_types.len() - broadcast_count;
278
279 ChannelDerivationResult {
280 channel_types,
281 orphaned_sources,
282 broadcast_count,
283 spsc_count,
284 }
285}
286
// Unit tests for channel-type derivation and the definition helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_single_consumer_spsc() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![MvDefinition::from_source("vwap", "trades")];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Two views read `trades`, so it fans out.
        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
        // Only one view reads `orders`.
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let sources = vec![SourceDefinition::new("orphan")];
        let mvs: Vec<MvDefinition> = vec![];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Unconsumed sources still get an (SPSC) entry.
        assert_eq!(channel_types.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = vec![
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let mvs = vec![MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        )];

        let channel_types = derive_channel_types(&sources, &mvs);

        // One view reading two sources is a single consumer of each.
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
        assert_eq!(
            channel_types.get("payments"),
            Some(&DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derive_ignores_refs_to_unknown_sources() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("ghost_view", "ghost"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.len(), 1);
        assert_eq!(channel_types.get("ghost"), None);
        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let broadcast = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert!(broadcast.is_broadcast());
        assert_eq!(broadcast.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let source = SourceDefinition::new("trades");
        assert_eq!(source.name, "trades");
        assert!(source.watermark_column.is_none());

        let source_wm = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(source_wm.name, "trades");
        assert_eq!(source_wm.watermark_column, Some("event_time".to_string()));
    }

    #[test]
    fn test_mv_definition() {
        let mv = MvDefinition::from_source("vwap", "trades");
        assert_eq!(mv.name, "vwap");
        assert_eq!(mv.source_refs, vec!["trades"]);

        let mv_multi = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(mv_multi.name, "join_result");
        assert_eq!(mv_multi.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let mv = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(mv.name, "my_mv");
        assert_eq!(mv.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
            SourceDefinition::new("orphan"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let result = derive_channel_types_detailed(&sources, &mvs);

        // `trades` fans out to two views.
        assert_eq!(result.broadcast_count, 1);
        // `orders` (one consumer) and `orphan` (no consumers) are SPSC.
        assert_eq!(result.spsc_count, 2);
        assert!(result.orphaned_sources.contains(&"orphan".to_string()));
    }

    #[test]
    fn test_detailed_all_sources_orphaned() {
        let sources = vec![SourceDefinition::new("a"), SourceDefinition::new("b")];
        let mvs: Vec<MvDefinition> = vec![];

        let result = derive_channel_types_detailed(&sources, &mvs);

        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 2);
        assert_eq!(result.orphaned_sources.len(), 2);
        assert!(result.orphaned_sources.contains(&"a".to_string()));
        assert!(result.orphaned_sources.contains(&"b".to_string()));
    }

    #[test]
    fn test_detailed_empty_inputs() {
        let result = derive_channel_types_detailed(&[], &[]);

        assert!(result.channel_types.is_empty());
        assert!(result.orphaned_sources.is_empty());
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 0);
    }

    #[test]
    fn test_three_consumers() {
        let sources = vec![SourceDefinition::new("events")];
        let mvs = vec![
            MvDefinition::from_source("mv1", "events"),
            MvDefinition::from_source("mv2", "events"),
            MvDefinition::from_source("mv3", "events"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("events"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 3 })
        );
    }
}