Skip to main content

laminar_sql/planner/
channel_derivation.rs

//! Channel type derivation from query plan analysis.
//!
//! This module analyzes query plans to automatically determine whether each
//! source needs an SPSC (single-consumer) or a broadcast (multi-consumer)
//! channel.
//!
//! # Key Principle
//!
//! **Users never specify broadcast mode.** The planner analyzes MVs and sources
//! to derive the correct channel type automatically:
//!
//! - If a source is consumed by at most 1 MV → SPSC (optimal). A source with
//!   no consumers at all also defaults to SPSC.
//! - If a source is consumed by 2+ MVs → Broadcast
//!
//! # Example
//!
//! ```sql
//! CREATE SOURCE trades (...);
//!
//! CREATE MATERIALIZED VIEW vwap AS
//!   SELECT symbol, SUM(price * volume) / SUM(volume) AS vwap
//!   FROM trades
//!   GROUP BY symbol, TUMBLE(ts, INTERVAL '1' MINUTE);
//!
//! CREATE MATERIALIZED VIEW max_price AS
//!   SELECT symbol, MAX(price)
//!   FROM trades
//!   GROUP BY symbol, TUMBLE(ts, INTERVAL '1' MINUTE);
//! ```
//!
//! In this example, the `trades` source has 2 consumers (`vwap` and
//! `max_price`), so the planner derives `Broadcast { consumer_count: 2 }` for
//! `trades`.
32
33use std::collections::HashMap;
34
/// The channel flavor the planner picks for a source.
///
/// The variant is never chosen by the user; it is derived from the number of
/// downstream consumers that read from the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-producer / single-consumer channel.
    ///
    /// The optimal choice for a source with exactly one downstream MV:
    /// no cloning overhead, lock-free single producer/consumer.
    Spsc,

    /// Broadcast channel for a source that feeds several downstream MVs.
    ///
    /// Values are cloned to each consumer.
    Broadcast {
        /// How many downstream consumers read from the source.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Reports whether this is the [`DerivedChannelType::Broadcast`] variant.
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match self {
            Self::Broadcast { .. } => true,
            Self::Spsc => false,
        }
    }

    /// Returns the number of consumers this channel type stands for.
    ///
    /// `Spsc` always reports `1`; `Broadcast` reports its stored
    /// `consumer_count`.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        if let Self::Broadcast { consumer_count } = *self {
            consumer_count
        } else {
            1
        }
    }
}
73
/// A registered streaming source, as seen by channel derivation.
///
/// MVs refer to a source by its `name`.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Name of the source (e.g., "trades", "orders").
    pub name: String,
    /// Watermark column used for event time processing, if one is configured.
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Builds a source definition that has no watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            watermark_column: None,
        }
    }

    /// Builds a source definition whose watermark column is set to
    /// `watermark_column`.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        let name = name.into();
        let watermark_column = Some(watermark_column.into());
        Self {
            name,
            watermark_column,
        }
    }
}
104
/// A materialized view, as seen by channel derivation.
///
/// Models a continuous query together with the names of the sources it
/// consumes.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the MV (e.g., "vwap", "max_price").
    pub name: String,
    /// Names of the sources this MV reads from.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Builds an MV definition from a name and an already-collected list of
    /// source names.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name = name.into();
        Self { name, source_refs }
    }

    /// Builds an MV definition that reads from exactly one source.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        let name = name.into();
        let source_refs = Vec::from([source.into()]);
        Self { name, source_refs }
    }
}
135
136/// Analyzes query plan to determine channel types per source.
137///
138/// Examines how many MVs consume each source and determines whether
139/// SPSC or Broadcast channels are needed.
140///
141/// # Arguments
142///
143/// * `sources` - Registered source definitions
144/// * `mvs` - Materialized view definitions
145///
146/// # Returns
147///
148/// A map from source name to derived channel type.
149///
150/// # Example
151///
152/// ```rust,ignore
153/// use laminar_sql::planner::channel_derivation::*;
154///
155/// let sources = vec![
156///     SourceDefinition::new("trades"),
157///     SourceDefinition::new("orders"),
158/// ];
159///
160/// let mvs = vec![
161///     MvDefinition::from_source("vwap", "trades"),
162///     MvDefinition::from_source("max_price", "trades"),
163///     MvDefinition::from_source("order_count", "orders"),
164/// ];
165///
166/// let channel_types = derive_channel_types(&sources, &mvs);
167///
168/// // trades has 2 consumers → Broadcast
169/// assert_eq!(
170///     channel_types.get("trades"),
171///     Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
172/// );
173///
174/// // orders has 1 consumer → SPSC
175/// assert_eq!(
176///     channel_types.get("orders"),
177///     Some(&DerivedChannelType::Spsc)
178/// );
179/// ```
180#[must_use]
181pub fn derive_channel_types(
182    sources: &[SourceDefinition],
183    mvs: &[MvDefinition],
184) -> HashMap<String, DerivedChannelType> {
185    let consumer_counts = count_consumers_per_source(mvs);
186
187    sources
188        .iter()
189        .map(|source| {
190            let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
191            let channel_type = if count <= 1 {
192                DerivedChannelType::Spsc
193            } else {
194                DerivedChannelType::Broadcast {
195                    consumer_count: count,
196                }
197            };
198            (source.name.clone(), channel_type)
199        })
200        .collect()
201}
202
203/// Counts how many MVs read from each source.
204fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
205    let mut counts: HashMap<String, usize> = HashMap::new();
206
207    for mv in mvs {
208        for source_ref in &mv.source_refs {
209            *counts.entry(source_ref.clone()).or_insert(0) += 1;
210        }
211    }
212
213    counts
214}
215
216/// Analyzes a single MV to extract its source references.
217///
218/// This is a helper for parsing SQL queries to find referenced sources.
219/// In practice, this would integrate with the SQL parser to extract
220/// table references from FROM clauses.
221///
222/// # Arguments
223///
224/// * `mv_name` - The MV name
225/// * `source_tables` - Tables referenced in the query
226///
227/// # Returns
228///
229/// An `MvDefinition` with the extracted source references.
230#[must_use]
231pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
232    MvDefinition::new(
233        mv_name.to_string(),
234        source_tables.iter().map(|s| (*s).to_string()).collect(),
235    )
236}
237
238/// Channel derivation result with additional metadata.
239#[derive(Debug, Clone)]
240pub struct ChannelDerivationResult {
241    /// Derived channel types per source.
242    pub channel_types: HashMap<String, DerivedChannelType>,
243    /// Sources with no consumers (orphaned).
244    pub orphaned_sources: Vec<String>,
245    /// Total broadcast channels needed.
246    pub broadcast_count: usize,
247    /// Total SPSC channels needed.
248    pub spsc_count: usize,
249}
250
251/// Derives channel types with additional analysis metadata.
252///
253/// Returns a result that includes orphaned sources (sources with no consumers)
254/// and counts of each channel type.
255#[must_use]
256pub fn derive_channel_types_detailed(
257    sources: &[SourceDefinition],
258    mvs: &[MvDefinition],
259) -> ChannelDerivationResult {
260    let channel_types = derive_channel_types(sources, mvs);
261
262    let orphaned_sources: Vec<String> = channel_types
263        .iter()
264        .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
265        .filter(|(name, _)| {
266            // Check if this source actually has any consumers
267            !mvs.iter().any(|mv| mv.source_refs.contains(*name))
268        })
269        .map(|(name, _)| name.clone())
270        .collect();
271
272    let broadcast_count = channel_types
273        .values()
274        .filter(|ct| ct.is_broadcast())
275        .count();
276
277    let spsc_count = channel_types.len() - broadcast_count;
278
279    ChannelDerivationResult {
280        channel_types,
281        orphaned_sources,
282        broadcast_count,
283        spsc_count,
284    }
285}
286
287// ===========================================================================
288// Tests
289// ===========================================================================
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one watermark-less source definition per name.
    fn sources_named(names: &[&str]) -> Vec<SourceDefinition> {
        names
            .iter()
            .map(|&name| SourceDefinition::new(name))
            .collect()
    }

    /// Builds single-source MVs from `(mv_name, source_name)` pairs.
    fn single_source_mvs(pairs: &[(&str, &str)]) -> Vec<MvDefinition> {
        pairs
            .iter()
            .map(|&(mv, source)| MvDefinition::from_source(mv, source))
            .collect()
    }

    #[test]
    fn test_derive_single_consumer_spsc() {
        let sources = sources_named(&["trades"]);
        let mvs = single_source_mvs(&[("vwap", "trades")]);

        let derived = derive_channel_types(&sources, &mvs);

        assert_eq!(derived.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let sources = sources_named(&["trades"]);
        let mvs = single_source_mvs(&[("vwap", "trades"), ("max_price", "trades")]);

        let derived = derive_channel_types(&sources, &mvs);

        let expected = DerivedChannelType::Broadcast { consumer_count: 2 };
        assert_eq!(derived.get("trades"), Some(&expected));
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = sources_named(&["trades", "orders"]);
        let mvs = single_source_mvs(&[
            ("vwap", "trades"),
            ("max_price", "trades"),
            ("order_count", "orders"),
        ]);

        let derived = derive_channel_types(&sources, &mvs);

        // Two MVs read `trades`, so it must be broadcast.
        let expected_trades = DerivedChannelType::Broadcast { consumer_count: 2 };
        assert_eq!(derived.get("trades"), Some(&expected_trades));

        // Only one MV reads `orders`, so SPSC is enough.
        assert_eq!(derived.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let sources = sources_named(&["orphan"]);
        let mvs: Vec<MvDefinition> = Vec::new();

        let derived = derive_channel_types(&sources, &mvs);

        // An unconsumed source falls back to the SPSC default.
        assert_eq!(derived.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = sources_named(&["orders", "payments"]);
        let joined = MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        );
        let mvs = vec![joined];

        let derived = derive_channel_types(&sources, &mvs);

        // Each source is read by the single join MV only.
        for source_name in ["orders", "payments"] {
            assert_eq!(derived.get(source_name), Some(&DerivedChannelType::Spsc));
        }
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let broadcast = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert!(broadcast.is_broadcast());
        assert_eq!(broadcast.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let plain = SourceDefinition::new("trades");
        assert_eq!(plain.name, "trades");
        assert!(plain.watermark_column.is_none());

        let with_watermark = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(with_watermark.name, "trades");
        assert_eq!(
            with_watermark.watermark_column,
            Some("event_time".to_string())
        );
    }

    #[test]
    fn test_mv_definition() {
        let single = MvDefinition::from_source("vwap", "trades");
        assert_eq!(single.name, "vwap");
        assert_eq!(single.source_refs, vec!["trades"]);

        let multi = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(multi.name, "join_result");
        assert_eq!(multi.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let analyzed = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(analyzed.name, "my_mv");
        assert_eq!(analyzed.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = sources_named(&["trades", "orders", "orphan"]);
        let mvs = single_source_mvs(&[
            ("vwap", "trades"),
            ("max_price", "trades"),
            ("order_count", "orders"),
        ]);

        let detailed = derive_channel_types_detailed(&sources, &mvs);

        assert_eq!(detailed.broadcast_count, 1); // trades
        assert_eq!(detailed.spsc_count, 2); // orders, orphan
        assert!(detailed.orphaned_sources.contains(&"orphan".to_string()));
    }

    #[test]
    fn test_three_consumers() {
        let sources = sources_named(&["events"]);
        let mvs = single_source_mvs(&[("mv1", "events"), ("mv2", "events"), ("mv3", "events")]);

        let derived = derive_channel_types(&sources, &mvs);

        let expected = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert_eq!(derived.get("events"), Some(&expected));
    }
}