Skip to main content

laminar_sql/planner/
channel_derivation.rs

//! Channel type derivation from query plan analysis.
//!
//! This module analyzes query plans to automatically determine whether sources
//! need SPSC (single-consumer) or broadcast (multiple-consumer) channels.
//!
//! # Key Principle
//!
//! **Users never specify broadcast mode.** The planner analyzes MVs and sources
//! to derive the correct channel type automatically:
//!
//! - If a source is consumed by exactly 1 MV → SPSC (optimal)
//! - If a source is consumed by 2+ MVs → Broadcast
//! - If a source has no consumers at all, it also defaults to SPSC
//!
//! # Example
//!
//! ```sql
//! CREATE SOURCE trades (...);
//!
//! CREATE MATERIALIZED VIEW vwap AS
//!   SELECT symbol, SUM(price * volume) / SUM(volume) AS vwap
//!   FROM trades
//!   GROUP BY symbol, TUMBLE(ts, INTERVAL '1' MINUTE);
//!
//! CREATE MATERIALIZED VIEW max_price AS
//!   SELECT symbol, MAX(price)
//!   FROM trades
//!   GROUP BY symbol, TUMBLE(ts, INTERVAL '1' MINUTE);
//! ```
//!
//! In this example, the `trades` source has 2 consumers (`vwap` and `max_price`),
//! so the planner derives `Broadcast { consumer_count: 2 }` for `trades`.
32
33#[allow(clippy::disallowed_types)] // cold path: query planning
34use std::collections::HashMap;
35
/// Channel flavour the planner selects for a source.
///
/// The variant is chosen from the number of downstream consumers a source
/// has; it is derived, never configured by the user.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// One consumer: use an SPSC channel.
    ///
    /// The best fit for a source feeding exactly one downstream MV:
    /// lock-free single producer/consumer with no cloning overhead.
    Spsc,

    /// Several consumers: use a Broadcast channel.
    ///
    /// Chosen when a source feeds more than one downstream MV; each value
    /// is cloned to every consumer.
    Broadcast {
        /// How many downstream consumers receive each value.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Reports whether this is the [`DerivedChannelType::Broadcast`] variant.
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match self {
            Self::Broadcast { .. } => true,
            Self::Spsc => false,
        }
    }

    /// Number of consumers attached to the channel.
    ///
    /// `Spsc` always reports `1`; `Broadcast` reports its stored count.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        if let Self::Broadcast { consumer_count } = *self {
            consumer_count
        } else {
            1
        }
    }
}
74
/// A registered streaming source, as seen by channel derivation.
///
/// Sources are the inputs that materialized views consume from.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Name of the source (e.g., "trades", "orders").
    pub name: String,
    /// Watermark column used for event time processing, if one is set.
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Builds a source definition with no watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            watermark_column: None,
        }
    }

    /// Builds a source definition whose watermark column is set to
    /// `watermark_column`.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        let mut definition = Self::new(name);
        definition.watermark_column = Some(watermark_column.into());
        definition
    }
}
105
/// A materialized view, as seen by channel derivation.
///
/// Models a continuous query together with the sources it consumes from.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the MV (e.g., "vwap", "max_price").
    pub name: String,
    /// Names of the sources this MV reads from.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Builds an MV definition from a name and the list of sources it reads.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name = name.into();
        Self { name, source_refs }
    }

    /// Builds an MV definition whose only input is `source`.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        let name: String = name.into();
        let only_source: String = source.into();
        Self::new(name, vec![only_source])
    }
}
136
137/// Analyzes query plan to determine channel types per source.
138///
139/// Examines how many MVs consume each source and determines whether
140/// SPSC or Broadcast channels are needed.
141///
142/// # Arguments
143///
144/// * `sources` - Registered source definitions
145/// * `mvs` - Materialized view definitions
146///
147/// # Returns
148///
149/// A map from source name to derived channel type.
150///
151/// # Example
152///
153/// ```rust,ignore
154/// use laminar_sql::planner::channel_derivation::*;
155///
156/// let sources = vec![
157///     SourceDefinition::new("trades"),
158///     SourceDefinition::new("orders"),
159/// ];
160///
161/// let mvs = vec![
162///     MvDefinition::from_source("vwap", "trades"),
163///     MvDefinition::from_source("max_price", "trades"),
164///     MvDefinition::from_source("order_count", "orders"),
165/// ];
166///
167/// let channel_types = derive_channel_types(&sources, &mvs);
168///
169/// // trades has 2 consumers → Broadcast
170/// assert_eq!(
171///     channel_types.get("trades"),
172///     Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
173/// );
174///
175/// // orders has 1 consumer → SPSC
176/// assert_eq!(
177///     channel_types.get("orders"),
178///     Some(&DerivedChannelType::Spsc)
179/// );
180/// ```
181#[must_use]
182pub fn derive_channel_types(
183    sources: &[SourceDefinition],
184    mvs: &[MvDefinition],
185) -> HashMap<String, DerivedChannelType> {
186    let consumer_counts = count_consumers_per_source(mvs);
187
188    sources
189        .iter()
190        .map(|source| {
191            let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
192            let channel_type = if count <= 1 {
193                DerivedChannelType::Spsc
194            } else {
195                DerivedChannelType::Broadcast {
196                    consumer_count: count,
197                }
198            };
199            (source.name.clone(), channel_type)
200        })
201        .collect()
202}
203
204/// Counts how many MVs read from each source.
205fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
206    let mut counts: HashMap<String, usize> = HashMap::with_capacity(mvs.len());
207
208    for mv in mvs {
209        for source_ref in &mv.source_refs {
210            *counts.entry(source_ref.clone()).or_insert(0) += 1;
211        }
212    }
213
214    counts
215}
216
217/// Analyzes a single MV to extract its source references.
218///
219/// This is a helper for parsing SQL queries to find referenced sources.
220/// In practice, this would integrate with the SQL parser to extract
221/// table references from FROM clauses.
222///
223/// # Arguments
224///
225/// * `mv_name` - The MV name
226/// * `source_tables` - Tables referenced in the query
227///
228/// # Returns
229///
230/// An `MvDefinition` with the extracted source references.
231#[must_use]
232pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
233    MvDefinition::new(
234        mv_name.to_string(),
235        source_tables.iter().map(|s| (*s).to_string()).collect(),
236    )
237}
238
239/// Channel derivation result with additional metadata.
240#[derive(Debug, Clone)]
241pub struct ChannelDerivationResult {
242    /// Derived channel types per source.
243    pub channel_types: HashMap<String, DerivedChannelType>,
244    /// Sources with no consumers (orphaned).
245    pub orphaned_sources: Vec<String>,
246    /// Total broadcast channels needed.
247    pub broadcast_count: usize,
248    /// Total SPSC channels needed.
249    pub spsc_count: usize,
250}
251
252/// Derives channel types with additional analysis metadata.
253///
254/// Returns a result that includes orphaned sources (sources with no consumers)
255/// and counts of each channel type.
256#[must_use]
257pub fn derive_channel_types_detailed(
258    sources: &[SourceDefinition],
259    mvs: &[MvDefinition],
260) -> ChannelDerivationResult {
261    let channel_types = derive_channel_types(sources, mvs);
262
263    let orphaned_sources: Vec<String> = channel_types
264        .iter()
265        .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
266        .filter(|(name, _)| {
267            // Check if this source actually has any consumers
268            !mvs.iter().any(|mv| mv.source_refs.contains(*name))
269        })
270        .map(|(name, _)| name.clone())
271        .collect();
272
273    let broadcast_count = channel_types
274        .values()
275        .filter(|ct| ct.is_broadcast())
276        .count();
277
278    let spsc_count = channel_types.len() - broadcast_count;
279
280    ChannelDerivationResult {
281        channel_types,
282        orphaned_sources,
283        broadcast_count,
284        spsc_count,
285    }
286}
287
288// ===========================================================================
289// Tests
290// ===========================================================================
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected channel type for a source fanned out to `consumer_count` MVs.
    fn broadcast(consumer_count: usize) -> DerivedChannelType {
        DerivedChannelType::Broadcast { consumer_count }
    }

    /// Builds one single-source MV per entry of `mv_names`, all reading `source`.
    fn mvs_reading(source: &str, mv_names: &[&str]) -> Vec<MvDefinition> {
        mv_names
            .iter()
            .map(|&mv_name| MvDefinition::from_source(mv_name, source))
            .collect()
    }

    #[test]
    fn test_derive_single_consumer_spsc() {
        let derived = derive_channel_types(
            &[SourceDefinition::new("trades")],
            &mvs_reading("trades", &["vwap"]),
        );

        assert_eq!(
            derived.get("trades").copied(),
            Some(DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let derived = derive_channel_types(
            &[SourceDefinition::new("trades")],
            &mvs_reading("trades", &["vwap", "max_price"]),
        );

        assert_eq!(derived.get("trades").copied(), Some(broadcast(2)));
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = [
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mut mvs = mvs_reading("trades", &["vwap", "max_price"]);
        mvs.extend(mvs_reading("orders", &["order_count"]));

        let derived = derive_channel_types(&sources, &mvs);

        // Two MVs read `trades`, so it must be broadcast.
        assert_eq!(derived.get("trades").copied(), Some(broadcast(2)));

        // Only one MV reads `orders`, so SPSC is enough.
        assert_eq!(
            derived.get("orders").copied(),
            Some(DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derive_no_consumers() {
        let derived = derive_channel_types(&[SourceDefinition::new("orphan")], &[]);

        // A source nobody reads falls back to SPSC.
        assert_eq!(
            derived.get("orphan").copied(),
            Some(DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = [
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let join_mv = MvDefinition::new(
            "order_payments",
            vec![String::from("orders"), String::from("payments")],
        );

        let derived = derive_channel_types(&sources, &[join_mv]);

        // Each source is read by the single join MV only, so both are SPSC.
        for source_name in ["orders", "payments"] {
            assert_eq!(
                derived.get(source_name).copied(),
                Some(DerivedChannelType::Spsc)
            );
        }
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        let fan_out = broadcast(3);

        assert!(!spsc.is_broadcast());
        assert!(fan_out.is_broadcast());

        assert_eq!(spsc.consumer_count(), 1);
        assert_eq!(fan_out.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let plain = SourceDefinition::new("trades");
        assert_eq!(plain.name, "trades");
        assert!(plain.watermark_column.is_none());

        let watermarked = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(watermarked.name, "trades");
        assert_eq!(watermarked.watermark_column.as_deref(), Some("event_time"));
    }

    #[test]
    fn test_mv_definition() {
        let single = MvDefinition::from_source("vwap", "trades");
        assert_eq!(single.name, "vwap");
        assert_eq!(single.source_refs, vec!["trades"]);

        let joined = MvDefinition::new(
            "join_result",
            vec![String::from("orders"), String::from("payments")],
        );
        assert_eq!(joined.name, "join_result");
        assert_eq!(joined.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let analyzed = analyze_mv_sources("my_mv", &["table1", "table2"]);

        assert_eq!(analyzed.name, "my_mv");
        assert_eq!(analyzed.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = [
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
            SourceDefinition::new("orphan"),
        ];
        let mut mvs = mvs_reading("trades", &["vwap", "max_price"]);
        mvs.extend(mvs_reading("orders", &["order_count"]));

        let ChannelDerivationResult {
            orphaned_sources,
            broadcast_count,
            spsc_count,
            ..
        } = derive_channel_types_detailed(&sources, &mvs);

        assert_eq!(broadcast_count, 1); // trades
        assert_eq!(spsc_count, 2); // orders, orphan
        assert!(orphaned_sources.iter().any(|name| name == "orphan"));
    }

    #[test]
    fn test_three_consumers() {
        let derived = derive_channel_types(
            &[SourceDefinition::new("events")],
            &mvs_reading("events", &["mv1", "mv2", "mv3"]),
        );

        assert_eq!(derived.get("events").copied(), Some(broadcast(3)));
    }
}