Skip to main content

laminar_sql/planner/
channel_derivation.rs

1//! Channel type derivation from query plan analysis.
2
3#[allow(clippy::disallowed_types)] // cold path: query planning
4use std::collections::HashMap;
5
/// Channel flavour the planner picks for a source after counting its readers.
///
/// The variant is derived automatically from how many downstream materialized
/// views consume the source (see `derive_channel_types`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-producer / single-consumer channel.
    ///
    /// Chosen when a source has at most one downstream MV, so values can be
    /// handed over without being cloned.
    Spsc,

    /// Fan-out channel that delivers a clone of every value to each consumer.
    ///
    /// Chosen when two or more downstream MVs read the same source.
    Broadcast {
        /// How many downstream consumers share this channel.
        consumer_count: usize,
    },
}
27
28impl DerivedChannelType {
29    /// Returns true if this is a broadcast channel.
30    #[must_use]
31    pub fn is_broadcast(&self) -> bool {
32        matches!(self, DerivedChannelType::Broadcast { .. })
33    }
34
35    /// Returns the consumer count.
36    #[must_use]
37    pub fn consumer_count(&self) -> usize {
38        match self {
39            DerivedChannelType::Spsc => 1,
40            DerivedChannelType::Broadcast { consumer_count } => *consumer_count,
41        }
42    }
43}
44
/// A registered streaming source that materialized views can read from.
///
/// Channel derivation only looks at `name`; the watermark column is carried
/// along as optional metadata.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Name that MVs use to refer to this source (e.g. "trades", "orders").
    pub name: String,
    /// Event-time watermark column, if the source declares one.
    pub watermark_column: Option<String>,
}
55
56impl SourceDefinition {
57    /// Creates a new source definition.
58    #[must_use]
59    pub fn new(name: impl Into<String>) -> Self {
60        Self {
61            name: name.into(),
62            watermark_column: None,
63        }
64    }
65
66    /// Creates a source definition with a watermark column.
67    #[must_use]
68    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
69        Self {
70            name: name.into(),
71            watermark_column: Some(watermark_column.into()),
72        }
73    }
74}
75
/// A materialized view (continuous query) as seen by channel derivation.
///
/// Only the list of sources it reads matters for sizing channels; the name
/// identifies the MV.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the materialized view (e.g. "vwap", "max_price").
    pub name: String,
    /// Names of the sources this MV reads from.
    pub source_refs: Vec<String>,
}
86
87impl MvDefinition {
88    /// Creates a new MV definition.
89    #[must_use]
90    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
91        Self {
92            name: name.into(),
93            source_refs,
94        }
95    }
96
97    /// Creates an MV definition that reads from a single source.
98    #[must_use]
99    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
100        Self {
101            name: name.into(),
102            source_refs: vec![source.into()],
103        }
104    }
105}
106
107/// Analyzes query plan to determine channel types per source.
108///
109/// Examines how many MVs consume each source and determines whether
110/// SPSC or Broadcast channels are needed.
111///
112/// # Arguments
113///
114/// * `sources` - Registered source definitions
115/// * `mvs` - Materialized view definitions
116///
117/// # Returns
118///
119/// A map from source name to derived channel type.
120///
121/// # Example
122///
123/// ```rust,ignore
124/// use laminar_sql::planner::channel_derivation::*;
125///
126/// let sources = vec![
127///     SourceDefinition::new("trades"),
128///     SourceDefinition::new("orders"),
129/// ];
130///
131/// let mvs = vec![
132///     MvDefinition::from_source("vwap", "trades"),
133///     MvDefinition::from_source("max_price", "trades"),
134///     MvDefinition::from_source("order_count", "orders"),
135/// ];
136///
137/// let channel_types = derive_channel_types(&sources, &mvs);
138///
139/// // trades has 2 consumers → Broadcast
140/// assert_eq!(
141///     channel_types.get("trades"),
142///     Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
143/// );
144///
145/// // orders has 1 consumer → SPSC
146/// assert_eq!(
147///     channel_types.get("orders"),
148///     Some(&DerivedChannelType::Spsc)
149/// );
150/// ```
151#[must_use]
152pub fn derive_channel_types(
153    sources: &[SourceDefinition],
154    mvs: &[MvDefinition],
155) -> HashMap<String, DerivedChannelType> {
156    let consumer_counts = count_consumers_per_source(mvs);
157
158    sources
159        .iter()
160        .map(|source| {
161            let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
162            let channel_type = if count <= 1 {
163                DerivedChannelType::Spsc
164            } else {
165                DerivedChannelType::Broadcast {
166                    consumer_count: count,
167                }
168            };
169            (source.name.clone(), channel_type)
170        })
171        .collect()
172}
173
174/// Counts how many MVs read from each source.
175fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
176    let mut counts: HashMap<String, usize> = HashMap::with_capacity(mvs.len());
177
178    for mv in mvs {
179        for source_ref in &mv.source_refs {
180            *counts.entry(source_ref.clone()).or_insert(0) += 1;
181        }
182    }
183
184    counts
185}
186
187/// Analyzes a single MV to extract its source references.
188///
189/// This is a helper for parsing SQL queries to find referenced sources.
190/// In practice, this would integrate with the SQL parser to extract
191/// table references from FROM clauses.
192///
193/// # Arguments
194///
195/// * `mv_name` - The MV name
196/// * `source_tables` - Tables referenced in the query
197///
198/// # Returns
199///
200/// An `MvDefinition` with the extracted source references.
201#[must_use]
202pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
203    MvDefinition::new(
204        mv_name.to_string(),
205        source_tables.iter().map(|s| (*s).to_string()).collect(),
206    )
207}
208
209/// Channel derivation result with additional metadata.
210#[derive(Debug, Clone)]
211pub struct ChannelDerivationResult {
212    /// Derived channel types per source.
213    pub channel_types: HashMap<String, DerivedChannelType>,
214    /// Sources with no consumers (orphaned).
215    pub orphaned_sources: Vec<String>,
216    /// Total broadcast channels needed.
217    pub broadcast_count: usize,
218    /// Total SPSC channels needed.
219    pub spsc_count: usize,
220}
221
222/// Derives channel types with additional analysis metadata.
223///
224/// Returns a result that includes orphaned sources (sources with no consumers)
225/// and counts of each channel type.
226#[must_use]
227pub fn derive_channel_types_detailed(
228    sources: &[SourceDefinition],
229    mvs: &[MvDefinition],
230) -> ChannelDerivationResult {
231    let channel_types = derive_channel_types(sources, mvs);
232
233    let orphaned_sources: Vec<String> = channel_types
234        .iter()
235        .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
236        .filter(|(name, _)| {
237            // Check if this source actually has any consumers
238            !mvs.iter().any(|mv| mv.source_refs.contains(*name))
239        })
240        .map(|(name, _)| name.clone())
241        .collect();
242
243    let broadcast_count = channel_types
244        .values()
245        .filter(|ct| ct.is_broadcast())
246        .count();
247
248    let spsc_count = channel_types.len() - broadcast_count;
249
250    ChannelDerivationResult {
251        channel_types,
252        orphaned_sources,
253        broadcast_count,
254        spsc_count,
255    }
256}
257
258// ===========================================================================
259// Tests
260// ===========================================================================
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_single_consumer_spsc() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![MvDefinition::from_source("vwap", "trades")];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        // trades: 2 consumers → Broadcast
        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );

        // orders: 1 consumer → SPSC
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let sources = vec![SourceDefinition::new("orphan")];
        let mvs: Vec<MvDefinition> = vec![];

        let channel_types = derive_channel_types(&sources, &mvs);

        // No consumers → SPSC (default)
        assert_eq!(channel_types.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = vec![
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let mvs = vec![MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        )];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Both sources have 1 consumer → SPSC
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
        assert_eq!(
            channel_types.get("payments"),
            Some(&DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let broadcast = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert!(broadcast.is_broadcast());
        assert_eq!(broadcast.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let source = SourceDefinition::new("trades");
        assert_eq!(source.name, "trades");
        assert!(source.watermark_column.is_none());

        let source_wm = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(source_wm.name, "trades");
        assert_eq!(source_wm.watermark_column, Some("event_time".to_string()));
    }

    #[test]
    fn test_mv_definition() {
        let mv = MvDefinition::from_source("vwap", "trades");
        assert_eq!(mv.name, "vwap");
        assert_eq!(mv.source_refs, vec!["trades"]);

        let mv_multi = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(mv_multi.name, "join_result");
        assert_eq!(mv_multi.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let mv = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(mv.name, "my_mv");
        assert_eq!(mv.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
            SourceDefinition::new("orphan"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let result = derive_channel_types_detailed(&sources, &mvs);

        assert_eq!(result.broadcast_count, 1); // trades
        assert_eq!(result.spsc_count, 2); // orders, orphan
        assert!(result.orphaned_sources.contains(&"orphan".to_string()));
    }

    #[test]
    fn test_three_consumers() {
        let sources = vec![SourceDefinition::new("events")];
        let mvs = vec![
            MvDefinition::from_source("mv1", "events"),
            MvDefinition::from_source("mv2", "events"),
            MvDefinition::from_source("mv3", "events"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("events"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 3 })
        );
    }

    #[test]
    fn test_unregistered_source_refs_ignored() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("ghost_mv", "ghost"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        // Only registered sources get a channel; the "ghost" ref is dropped.
        assert_eq!(channel_types.len(), 1);
        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
        assert!(!channel_types.contains_key("ghost"));
    }

    #[test]
    fn test_orphan_spsc_reports_one_consumer() {
        let sources = vec![SourceDefinition::new("orphan")];
        let mvs: Vec<MvDefinition> = vec![];

        let channel_types = derive_channel_types(&sources, &mvs);

        // An orphaned source is SPSC, and SPSC always reports a count of 1.
        assert_eq!(channel_types["orphan"].consumer_count(), 1);
    }

    #[test]
    fn test_detailed_empty_inputs() {
        let result = derive_channel_types_detailed(&[], &[]);

        assert!(result.channel_types.is_empty());
        assert!(result.orphaned_sources.is_empty());
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 0);
    }

    #[test]
    fn test_detailed_no_orphans_when_all_consumed() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mvs = vec![MvDefinition::new(
            "join",
            vec!["trades".to_string(), "orders".to_string()],
        )];

        let result = derive_channel_types_detailed(&sources, &mvs);

        assert!(result.orphaned_sources.is_empty());
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 2);
    }

    #[test]
    fn test_detailed_multiple_orphans_listed_once() {
        let sources = vec![
            SourceDefinition::new("b_orphan"),
            SourceDefinition::new("trades"),
            SourceDefinition::new("a_orphan"),
            SourceDefinition::new("b_orphan"),
        ];
        let mvs = vec![MvDefinition::from_source("vwap", "trades")];

        let result = derive_channel_types_detailed(&sources, &mvs);

        // Compare order-insensitively; each orphan must appear exactly once.
        let mut orphans = result.orphaned_sources.clone();
        orphans.sort();
        assert_eq!(orphans, vec!["a_orphan".to_string(), "b_orphan".to_string()]);
        assert_eq!(result.spsc_count, 3);
        assert_eq!(result.broadcast_count, 0);
    }
}