Skip to main content

mini_app_core/
aggregator.rs

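//! Multi-table aggregation primitives for the `query_aggregate` MCP tool.
//!
//! Provides [`AliasAggregator`] (`Count` / `Sum` / `Avg` / `Min` / `Max` /
//! `GroupBy`), [`SourceSpec`] (`Single` / `Multi` / `Pattern`),
//! [`AliasRunResult`] (`Rows` / `Count` / `Value` / `Groups`), and
//! [`execute_aggregate`], which combines per-table `.db` files through
//! SQLite `ATTACH DATABASE` + `UNION ALL`.
//!
//! # Crux compliance
//! - **Crux #1**: [`ListFilter::build_sql`] is reused via the sibling
//!   [`ListFilter::build_subquery`] method; no signature changes.
//! - **Crux #2**: `GroupBy::having` is emitted as a literal `HAVING` clause
//!   after `GROUP BY`, never as a `WHERE` clause.
//! - **Crux #3**: [`SourceSpec::Multi`] mounts each table via
//!   `ATTACH DATABASE` and combines them with literal `UNION ALL` — never
//!   `JOIN`, never application-layer merge.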
12//! - **Crux #2**: `GroupBy::having` is emitted as a literal `HAVING` clause
13//!   after `GROUP BY`, never as a `WHERE` clause.
14//! - **Crux #3**: [`SourceSpec::Multi`] mounts each table via
15//!   `ATTACH DATABASE` and combines them with literal `UNION ALL` — never
16//!   `JOIN`, never application-layer merge.
17
18use std::sync::Arc;
19
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22
23use crate::error::MiniAppError;
24use crate::filter::{FilterParam, ListFilter};
25use crate::registry::TableRegistry;
26use crate::schema::SchemaConfig;
27use crate::store::Store;
28
29// ---------------------------------------------------------------------------
30// Type definitions
31// ---------------------------------------------------------------------------
32
33/// Aggregator primitive selected by the caller of `query_aggregate`.
34#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
35#[serde(rename_all = "snake_case", tag = "kind")]
36pub enum AliasAggregator {
37    /// `COUNT(*)` over the source rows.
38    Count,
39    /// `SUM(json_extract(data, '$.<field>'))`.
40    Sum { field: String },
41    /// `AVG(json_extract(data, '$.<field>'))`.
42    Avg { field: String },
43    /// `MIN(json_extract(data, '$.<field>'))`.
44    Min { field: String },
45    /// `MAX(json_extract(data, '$.<field>'))`.
46    Max { field: String },
47    /// `GROUP BY <by_field>` with an optional per-group inner aggregator
48    /// and an optional `HAVING` predicate.
49    ///
50    /// `having` is emitted as a `HAVING` clause **after** `GROUP BY`
51    /// (Crux #2). `inner` is an optional scalar sub-aggregator
52    /// (`Sum` / `Avg` / `Min` / `Max`); nested `GroupBy` is rejected at
53    /// validation time in Phase 1.
54    GroupBy {
55        by_field: String,
56        #[serde(default)]
57        having: Option<ListFilter>,
58        #[serde(default)]
59        inner: Option<Box<AliasAggregator>>,
60    },
61}
62
63/// Source-table specifier — either a single table, a list combined via
64/// `ATTACH DATABASE` + `UNION ALL` (Crux #3), or a glob pattern that must
65/// be resolved against the live [`TableRegistry`] before execution.
66#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
67#[serde(rename_all = "snake_case", tag = "kind", content = "value")]
68pub enum SourceSpec {
69    /// One table by name (existing 1-table compatibility, used as the
70    /// normalisation target for any legacy callers that supplied a bare
71    /// `table` argument).
72    Single(String),
73    /// Two or more tables, joined via `UNION ALL` (never `JOIN`).
74    Multi(Vec<String>),
75    /// Glob pattern (e.g. `"shi_*"`) resolved against the registry at
76    /// `alias_run` time. Callers MUST resolve to [`SourceSpec::Multi`] via
77    /// [`SourceSpec::resolve_pattern`] before invoking [`execute_aggregate`];
78    /// [`SourceSpec::tables`] returns an empty slice for this variant so
79    /// the existing "empty sources" guard in `execute_aggregate` fires as a
80    /// fail-fast bug detector if resolution is skipped.
81    Pattern(String),
82}
83
84impl SourceSpec {
85    /// Returns the source tables as a slice. Length is `1` for `Single`,
86    /// the supplied vector's length for `Multi`, and `0` for `Pattern`
87    /// (unresolved — must be resolved via [`SourceSpec::resolve_pattern`]
88    /// before [`execute_aggregate`]). [`execute_aggregate`] rejects an
89    /// empty `Multi` (and consequently an unresolved `Pattern`) with
90    /// [`MiniAppError::Aggregator`].
91    pub fn tables(&self) -> &[String] {
92        match self {
93            SourceSpec::Single(t) => std::slice::from_ref(t),
94            SourceSpec::Multi(v) => v.as_slice(),
95            // Unresolved Pattern: empty slice triggers the existing
96            // "sources must contain at least one table" guard in
97            // execute_aggregate for fail-fast bug detection.
98            SourceSpec::Pattern(_) => &[],
99        }
100    }
101
102    /// Returns `true` when this variant requires registry-side glob
103    /// resolution (i.e. [`SourceSpec::Pattern`]). Phase 2 `alias_run`
104    /// callers branch on this to invoke [`SourceSpec::resolve_pattern`]
105    /// before passing the spec to [`execute_aggregate`].
106    pub fn requires_resolve(&self) -> bool {
107        matches!(self, SourceSpec::Pattern(_))
108    }
109
110    /// Returns `true` when this source spec covers `table` — `Single`
111    /// when the names match, `Multi` when the list contains `table`,
112    /// `Pattern` when the glob compiles and matches `table`. Used by
113    /// `alias_list({"table":"X"})` to keep Multi / Pattern aliases that
114    /// reference `X` visible (the older `SourceSpec::Single`-only
115    /// retain predicate silently dropped them).
116    pub fn includes_table(&self, table: &str) -> bool {
117        match self {
118            SourceSpec::Single(t) => t == table,
119            SourceSpec::Multi(v) => v.iter().any(|t| t == table),
120            SourceSpec::Pattern(p) => GlobMatcher::compile(p)
121                .map(|m| m.matches(table))
122                .unwrap_or(false),
123        }
124    }
125
126    /// Resolves [`SourceSpec::Pattern`] against the supplied table-name
127    /// list using a simple `*` glob (one segment, no `?` / `[]` support
128    /// in Phase 2 — extension point reserved for a future revision).
129    /// Returns `Single(name)` when exactly one table matches, `Multi(v)`
130    /// when two or more match (sorted ascending for determinism), and
131    /// `Err(MiniAppError::Aggregator)` when zero tables match (so callers
132    /// surface the empty-sources error early rather than after ATTACH).
133    ///
134    /// Non-`Pattern` variants are returned unchanged so callers can
135    /// invoke this unconditionally without branching on
136    /// [`SourceSpec::requires_resolve`].
137    pub fn resolve_pattern(self, all_tables: &[String]) -> Result<Self, MiniAppError> {
138        let Self::Pattern(pat) = self else {
139            return Ok(self);
140        };
141        let matcher = GlobMatcher::compile(&pat)?;
142        let mut hits: Vec<String> = all_tables
143            .iter()
144            .filter(|t| matcher.matches(t))
145            .cloned()
146            .collect();
147        hits.sort();
148        match hits.len() {
149            0 => Err(MiniAppError::Aggregator(format!(
150                "SourceSpec::Pattern('{pat}') matched zero tables"
151            ))),
152            1 => Ok(SourceSpec::Single(hits.into_iter().next().unwrap())),
153            _ => Ok(SourceSpec::Multi(hits)),
154        }
155    }
156}
157
158/// A `*`-only glob matcher used by [`SourceSpec::resolve_pattern`].
159/// One `*` matches any sequence (including empty) of characters; other
160/// characters match literally. `?` / `[]` are reserved for a future
161/// revision.
162struct GlobMatcher {
163    segments: Vec<String>,
164    leading_wildcard: bool,
165    trailing_wildcard: bool,
166}
167
168impl GlobMatcher {
169    fn compile(pattern: &str) -> Result<Self, MiniAppError> {
170        if pattern.is_empty() {
171            return Err(MiniAppError::Aggregator(
172                "SourceSpec::Pattern must not be empty".into(),
173            ));
174        }
175        for ch in pattern.chars() {
176            if ch == '?' || ch == '[' || ch == ']' {
177                return Err(MiniAppError::Aggregator(format!(
178                    "SourceSpec::Pattern('{pattern}') uses unsupported metachar '{ch}' (only '*' is supported in Phase 2)"
179                )));
180            }
181        }
182        let leading_wildcard = pattern.starts_with('*');
183        let trailing_wildcard = pattern.ends_with('*');
184        let segments: Vec<String> = pattern.split('*').map(str::to_owned).collect();
185        Ok(Self {
186            segments,
187            leading_wildcard,
188            trailing_wildcard,
189        })
190    }
191
192    fn matches(&self, name: &str) -> bool {
193        // segments contains the literal pieces between `*` separators.
194        // E.g. "shi_*" → ["shi_", ""] (leading=false, trailing=true).
195        // E.g. "*_log" → ["", "_log"] (leading=true, trailing=false).
196        // E.g. "shi_*_log" → ["shi_", "_log"].
197        // E.g. "shi" → ["shi"] (no wildcards, exact match).
198        // E.g. "*" → ["", ""] (matches everything).
199        let mut remaining = name;
200        let last = self.segments.len().saturating_sub(1);
201        for (idx, seg) in self.segments.iter().enumerate() {
202            if seg.is_empty() {
203                continue;
204            }
205            if idx == 0 && !self.leading_wildcard {
206                if !remaining.starts_with(seg.as_str()) {
207                    return false;
208                }
209                remaining = &remaining[seg.len()..];
210            } else if idx == last && !self.trailing_wildcard {
211                if !remaining.ends_with(seg.as_str()) {
212                    return false;
213                }
214                let cut = remaining.len() - seg.len();
215                remaining = &remaining[..cut];
216            } else {
217                match remaining.find(seg.as_str()) {
218                    Some(pos) => {
219                        remaining = &remaining[pos + seg.len()..];
220                    }
221                    None => return false,
222                }
223            }
224        }
225        // After consuming all literal segments, any leftover characters
226        // are absorbed by the wildcards at either end.
227        true
228    }
229}
230
231/// One row of a `GroupBy` result.
232#[derive(Debug, Clone, Serialize, JsonSchema)]
233pub struct GroupResult {
234    /// The grouping key — typically a string or number from the
235    /// `by_field` column.
236    pub key: serde_json::Value,
237    /// Row count in this group (`COUNT(*)` per group).
238    pub count: i64,
239    /// Inner aggregator result (`Sum` / `Avg` / `Min` / `Max`) if `inner`
240    /// was set on the `GroupBy` variant.
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub value: Option<serde_json::Value>,
243}
244
245/// Externally-tagged result of [`execute_aggregate`]. JSON dispatch over
246/// the `kind` field lets MCP callers decode without prior knowledge of
247/// which aggregator was requested.
248#[derive(Debug, Clone, Serialize, JsonSchema)]
249#[serde(rename_all = "snake_case", tag = "kind", content = "value")]
250pub enum AliasRunResult {
251    /// Reserved for the Phase 2 alias-run unification path; no current
252    /// [`AliasAggregator`] variant produces this.
253    Rows(Vec<serde_json::Value>),
254    /// `COUNT(*)` result.
255    Count(i64),
256    /// Scalar aggregate (`Sum` / `Avg` / `Min` / `Max`) as a JSON `Number`
257    /// or `Null` (when the source is empty).
258    Value(serde_json::Value),
259    /// `GroupBy` result, one entry per group.
260    Groups(Vec<GroupResult>),
261}
262
263// ---------------------------------------------------------------------------
264// Identifier sanity check
265// ---------------------------------------------------------------------------
266
267/// SQL identifier allow-list: `[A-Za-z_][A-Za-z0-9_]*`. The aggregator
268/// interpolates table and field names directly into SQL strings
269/// (`FROM <table>`, `json_extract(data, '$.<field>')`, `ATTACH DATABASE
270/// AS db_<i>`), so this strict ASCII check guarantees no injection
271/// surface regardless of what the schema validator permits in field
272/// names.
273fn validate_identifier(label: &str, name: &str) -> Result<(), MiniAppError> {
274    let mut chars = name.chars();
275    let first = chars.next().ok_or_else(|| MiniAppError::Validation {
276        field: label.to_string(),
277        reason: format!("{label} must not be empty"),
278    })?;
279    if !(first.is_ascii_alphabetic() || first == '_') {
280        return Err(MiniAppError::Validation {
281            field: label.to_string(),
282            reason: format!("{label} '{name}' must start with [A-Za-z_]"),
283        });
284    }
285    for c in chars {
286        if !(c.is_ascii_alphanumeric() || c == '_') {
287            return Err(MiniAppError::Validation {
288                field: label.to_string(),
289                reason: format!("{label} '{name}' must contain only [A-Za-z0-9_]"),
290            });
291        }
292    }
293    Ok(())
294}
295
296// ---------------------------------------------------------------------------
297// AliasAggregator: validate + scalar SQL composition
298// ---------------------------------------------------------------------------
299
300impl AliasAggregator {
301    /// Verify that every `field` / `by_field` referenced by this aggregator
302    /// is declared in `schema`, and recursively validate `inner` + `having`.
303    pub fn validate(&self, schema: &SchemaConfig) -> Result<(), MiniAppError> {
304        match self {
305            AliasAggregator::Count => Ok(()),
306            AliasAggregator::Sum { field }
307            | AliasAggregator::Avg { field }
308            | AliasAggregator::Min { field }
309            | AliasAggregator::Max { field } => {
310                validate_identifier("aggregator_field", field)?;
311                require_schema_field(schema, field, "aggregator field")
312            }
313            AliasAggregator::GroupBy {
314                by_field,
315                having,
316                inner,
317            } => {
318                validate_identifier("group_by_field", by_field)?;
319                require_schema_field(schema, by_field, "group_by field")?;
320                if let Some(filter) = having {
321                    filter.validate(schema)?;
322                }
323                if let Some(inner_agg) = inner {
324                    if matches!(inner_agg.as_ref(), AliasAggregator::GroupBy { .. }) {
325                        return Err(MiniAppError::Aggregator(
326                            "nested GroupBy is not supported in Phase 1".into(),
327                        ));
328                    }
329                    inner_agg.validate(schema)?;
330                }
331                Ok(())
332            }
333        }
334    }
335
336    /// Returns the SQL aggregate-function fragment (no surrounding
337    /// `SELECT` / `FROM`). `GroupBy` is rejected here — [`execute_aggregate`]
338    /// composes the `GROUP BY` clause separately.
339    fn scalar_agg_sql(&self) -> Result<String, MiniAppError> {
340        match self {
341            AliasAggregator::Count => Ok("COUNT(*)".to_string()),
342            AliasAggregator::Sum { field } => Ok(format!("SUM(json_extract(data, '$.{field}'))")),
343            AliasAggregator::Avg { field } => Ok(format!("AVG(json_extract(data, '$.{field}'))")),
344            AliasAggregator::Min { field } => Ok(format!("MIN(json_extract(data, '$.{field}'))")),
345            AliasAggregator::Max { field } => Ok(format!("MAX(json_extract(data, '$.{field}'))")),
346            AliasAggregator::GroupBy { .. } => Err(MiniAppError::Aggregator(
347                "GroupBy is not a scalar aggregator (handled by execute_aggregate)".into(),
348            )),
349        }
350    }
351}
352
353fn require_schema_field(
354    schema: &SchemaConfig,
355    field: &str,
356    role: &str,
357) -> Result<(), MiniAppError> {
358    if schema.fields.iter().any(|f| f.name == field) {
359        Ok(())
360    } else {
361        Err(MiniAppError::Validation {
362            field: field.to_string(),
363            reason: format!("{role} '{field}' is not declared in schema"),
364        })
365    }
366}
367
368// ---------------------------------------------------------------------------
369// execute_aggregate — main entry point
370// ---------------------------------------------------------------------------
371
372/// SQLite default `SQLITE_MAX_ATTACHED` (10). Beyond this limit
373/// [`execute_aggregate`] returns [`MiniAppError::Aggregator`] without
374/// touching SQLite, to keep the failure mode predictable.
375pub const SQLITE_MAX_ATTACHED: usize = 10;
376
377/// Execute a multi-table aggregation request.
378///
379/// Resolves each source table via `registry`, mounts every backing `.db`
380/// file into a fresh in-memory connection with `ATTACH DATABASE`, composes
381/// a `UNION ALL` inner sub-query (or a single `SELECT` for
382/// [`SourceSpec::Single`]), and wraps it in an outer `SELECT <aggregate>`
383/// (with `GROUP BY` + optional `HAVING` for [`AliasAggregator::GroupBy`]).
384///
385/// # Crux compliance
386/// - Crux #1: reuses the existing [`ListFilter::build_sql`] via the
387///   sibling [`ListFilter::build_subquery`] method (no signature change).
388/// - Crux #2: `having` is emitted as a literal `HAVING` clause **after**
389///   `GROUP BY`, never as `WHERE`.
390/// - Crux #3: multi-table sources are combined with literal `UNION ALL`;
391///   `JOIN` is never emitted and no application-layer merge happens.
392///
393/// # Errors
394/// - [`MiniAppError::Aggregator`] — empty `Multi`, ATTACH-limit exceeded,
395///   nested `GroupBy`, or non-UTF-8 db path.
396/// - [`MiniAppError::Validation`] — schema field missing, identifier
397///   regex mismatch, or filter validation failure.
398/// - [`MiniAppError::TableNotFound`] — propagated from the registry.
399/// - [`MiniAppError::Storage`] — `rusqlite` failure (`ATTACH`, query
400///   execution, etc.).
401/// - [`MiniAppError::Schema`] — `spawn_blocking` panic.
402pub async fn execute_aggregate(
403    registry: &TableRegistry,
404    sources: SourceSpec,
405    filter: Option<ListFilter>,
406    aggregator: AliasAggregator,
407    schema: &SchemaConfig,
408) -> Result<AliasRunResult, MiniAppError> {
409    let tables: Vec<String> = sources.tables().to_vec();
410    if tables.is_empty() {
411        return Err(MiniAppError::Aggregator(
412            "sources must contain at least one table".into(),
413        ));
414    }
415    if tables.len() > SQLITE_MAX_ATTACHED {
416        return Err(MiniAppError::Aggregator(format!(
417            "too many sources: {} (SQLITE_MAX_ATTACHED = {})",
418            tables.len(),
419            SQLITE_MAX_ATTACHED
420        )));
421    }
422    for t in &tables {
423        validate_identifier("source_table", t)?;
424    }
425    aggregator.validate(schema)?;
426    if let Some(f) = &filter {
427        f.validate(schema)?;
428    }
429    // Resolve db_path for each source table via the registry.
430    let mut db_paths: Vec<std::path::PathBuf> = Vec::with_capacity(tables.len());
431    for t in &tables {
432        let entry = registry.resolve(Some(t))?;
433        let store: &Arc<Store> = &entry.store;
434        db_paths.push(store.db_path().to_path_buf());
435    }
436    let filter_owned = filter.clone();
437    let aggregator_owned = aggregator.clone();
438    tokio::task::spawn_blocking(move || -> Result<AliasRunResult, MiniAppError> {
439        run_aggregate_blocking(&db_paths, filter_owned.as_ref(), &aggregator_owned)
440    })
441    .await
442    .map_err(|e| MiniAppError::Schema(format!("blocking task panic: {e}")))?
443}
444
445/// Synchronous core of [`execute_aggregate`], executed inside
446/// `spawn_blocking`. Mounts each per-table `.db` file via
447/// `ATTACH DATABASE` into a fresh in-memory connection, composes the SQL,
448/// and dispatches on the [`AliasAggregator`] variant.
449fn run_aggregate_blocking(
450    db_paths: &[std::path::PathBuf],
451    filter: Option<&ListFilter>,
452    aggregator: &AliasAggregator,
453) -> Result<AliasRunResult, MiniAppError> {
454    let conn = rusqlite::Connection::open_in_memory()?;
455    let mut aliases: Vec<String> = Vec::with_capacity(db_paths.len());
456    for (i, path) in db_paths.iter().enumerate() {
457        let alias = format!("db_{i}");
458        let path_str = path.to_str().ok_or_else(|| {
459            MiniAppError::Aggregator(format!("db_path is not valid UTF-8: {}", path.display()))
460        })?;
461        conn.execute(
462            &format!("ATTACH DATABASE ?1 AS {alias}"),
463            rusqlite::params![path_str],
464        )?;
465        aliases.push(alias);
466    }
467    let (inner_sql, params) = build_inner_sql(&aliases, filter)?;
468    match aggregator {
469        AliasAggregator::Count
470        | AliasAggregator::Sum { .. }
471        | AliasAggregator::Avg { .. }
472        | AliasAggregator::Min { .. }
473        | AliasAggregator::Max { .. } => {
474            let agg_sql = aggregator.scalar_agg_sql()?;
475            let sql = format!("SELECT {agg_sql} FROM ({inner_sql})");
476            run_scalar_aggregate(&conn, &sql, &params, aggregator)
477        }
478        AliasAggregator::GroupBy {
479            by_field,
480            having,
481            inner,
482        } => run_group_by(
483            &conn,
484            &inner_sql,
485            &params,
486            by_field,
487            having.as_ref(),
488            inner.as_deref(),
489        ),
490    }
491}
492
493/// Compose the inner sub-query — single `SELECT` for one alias,
494/// `UNION ALL` chain for many. Literal `UNION ALL` keyword is emitted so
495/// downstream substring assertions can verify Crux #3.
496fn build_inner_sql(
497    aliases: &[String],
498    filter: Option<&ListFilter>,
499) -> Result<(String, Vec<FilterParam>), MiniAppError> {
500    let mut parts: Vec<String> = Vec::with_capacity(aliases.len());
501    let mut all_params: Vec<FilterParam> = Vec::new();
502    for alias in aliases {
503        let table_ref = format!("{alias}.rows");
504        match filter {
505            Some(f) => {
506                let (sql, params) = f.build_subquery(&table_ref)?;
507                parts.push(sql);
508                all_params.extend(params);
509            }
510            None => {
511                parts.push(format!(
512                    "SELECT id, data, created_at, updated_at FROM {table_ref}"
513                ));
514            }
515        }
516    }
517    let inner = parts.join(" UNION ALL ");
518    Ok((inner, all_params))
519}
520
521fn filter_params_to_rusqlite(params: &[FilterParam]) -> Vec<Box<dyn rusqlite::ToSql>> {
522    params
523        .iter()
524        .map(|p| -> Box<dyn rusqlite::ToSql> {
525            match p {
526                FilterParam::Text(s) => Box::new(s.clone()),
527                FilterParam::Number(n) => Box::new(*n),
528                FilterParam::Bool(b) => Box::new(*b),
529            }
530        })
531        .collect()
532}
533
534fn run_scalar_aggregate(
535    conn: &rusqlite::Connection,
536    sql: &str,
537    params: &[FilterParam],
538    aggregator: &AliasAggregator,
539) -> Result<AliasRunResult, MiniAppError> {
540    let owned = filter_params_to_rusqlite(params);
541    let refs: Vec<&dyn rusqlite::ToSql> = owned.iter().map(|b| b.as_ref()).collect();
542    match aggregator {
543        AliasAggregator::Count => {
544            let n: i64 = conn.query_row(
545                sql,
546                rusqlite::params_from_iter(refs.iter().copied()),
547                |row| row.get(0),
548            )?;
549            Ok(AliasRunResult::Count(n))
550        }
551        AliasAggregator::Sum { .. }
552        | AliasAggregator::Avg { .. }
553        | AliasAggregator::Min { .. }
554        | AliasAggregator::Max { .. } => {
555            let value: serde_json::Value = conn.query_row(
556                sql,
557                rusqlite::params_from_iter(refs.iter().copied()),
558                |row| row_value_to_json(row, 0),
559            )?;
560            Ok(AliasRunResult::Value(value))
561        }
562        AliasAggregator::GroupBy { .. } => unreachable_group_by(),
563    }
564}
565
566fn run_group_by(
567    conn: &rusqlite::Connection,
568    inner_sql: &str,
569    params: &[FilterParam],
570    by_field: &str,
571    having: Option<&ListFilter>,
572    inner: Option<&AliasAggregator>,
573) -> Result<AliasRunResult, MiniAppError> {
574    let group_key_expr = format!("json_extract(data, '$.{by_field}')");
575    let inner_agg_sql = match inner {
576        Some(a) => a.scalar_agg_sql()?,
577        None => "COUNT(*)".to_string(),
578    };
579    let (having_sql, having_params) = match having {
580        Some(f) => {
581            let (frag, p) = f.build_sql()?;
582            (format!(" HAVING {frag}"), p)
583        }
584        None => (String::new(), vec![]),
585    };
586    // Literal " HAVING " keyword sits AFTER " GROUP BY " — Crux #2.
587    let sql = format!(
588        "SELECT {group_key_expr} AS group_key, COUNT(*), {inner_agg_sql} \
589         FROM ({inner_sql}) \
590         GROUP BY group_key{having_sql}"
591    );
592    let mut all_params = params.to_vec();
593    all_params.extend(having_params);
594    let owned = filter_params_to_rusqlite(&all_params);
595    let refs: Vec<&dyn rusqlite::ToSql> = owned.iter().map(|b| b.as_ref()).collect();
596    let mut stmt = conn.prepare(&sql)?;
597    let rows = stmt
598        .query_map(rusqlite::params_from_iter(refs.iter().copied()), |row| {
599            let key: serde_json::Value = row_value_to_json(row, 0)?;
600            let count: i64 = row.get(1)?;
601            let value: Option<serde_json::Value> = if inner.is_some() {
602                Some(row_value_to_json(row, 2)?)
603            } else {
604                None
605            };
606            Ok(GroupResult { key, count, value })
607        })?
608        .collect::<Result<Vec<_>, _>>()?;
609    Ok(AliasRunResult::Groups(rows))
610}
611
612fn row_value_to_json(row: &rusqlite::Row, idx: usize) -> rusqlite::Result<serde_json::Value> {
613    use rusqlite::types::ValueRef;
614    let v = row.get_ref(idx)?;
615    Ok(match v {
616        ValueRef::Null => serde_json::Value::Null,
617        ValueRef::Integer(i) => serde_json::Value::from(i),
618        ValueRef::Real(f) => serde_json::Value::from(f),
619        ValueRef::Text(t) => serde_json::Value::String(String::from_utf8_lossy(t).into_owned()),
620        ValueRef::Blob(_) => serde_json::Value::Null,
621    })
622}
623
624fn unreachable_group_by() -> Result<AliasRunResult, MiniAppError> {
625    Err(MiniAppError::Aggregator(
626        "internal: GroupBy reached scalar dispatch path".into(),
627    ))
628}
629
630// ---------------------------------------------------------------------------
631// Tests
632// ---------------------------------------------------------------------------
633
634#[cfg(test)]
635mod tests {
636    use super::*;
637    use crate::schema::{FieldDef, FieldType, SchemaConfig};
638    use rusqlite::Connection;
639    use tempfile::TempDir;
640
641    fn test_schema() -> SchemaConfig {
642        SchemaConfig {
643            table: "t".into(),
644            title: None,
645            description: None,
646            fields: vec![
647                FieldDef {
648                    name: "tag".into(),
649                    ty: FieldType::String,
650                    required: true,
651                    description: None,
652                },
653                FieldDef {
654                    name: "amount".into(),
655                    ty: FieldType::Number,
656                    required: true,
657                    description: None,
658                },
659            ],
660            dump: None,
661        }
662    }
663
664    fn build_in_memory_db_with_rows(rows: &[(&str, &str, f64)]) -> TempDir {
665        let dir = tempfile::tempdir().unwrap();
666        let db_path = dir.path().join("rows.db");
667        let conn = Connection::open(&db_path).unwrap();
668        conn.execute_batch(
669            "CREATE TABLE rows (\
670                id TEXT PRIMARY KEY,\
671                data TEXT NOT NULL,\
672                created_at INTEGER NOT NULL,\
673                updated_at INTEGER NOT NULL\
674            )",
675        )
676        .unwrap();
677        for (id, tag, amount) in rows {
678            let data = serde_json::json!({ "tag": tag, "amount": amount });
679            conn.execute(
680                "INSERT INTO rows (id, data, created_at, updated_at) \
681                 VALUES (?1, ?2, 0, 0)",
682                rusqlite::params![*id, data.to_string()],
683            )
684            .unwrap();
685        }
686        dir
687    }
688
689    // -----------------------------------------------------------------
690    // (a) Count single — Single source, COUNT(*) returns row count
691    // -----------------------------------------------------------------
692    #[test]
693    fn count_single_source_returns_row_count() {
694        let dir =
695            build_in_memory_db_with_rows(&[("r1", "a", 1.0), ("r2", "b", 2.0), ("r3", "a", 3.0)]);
696        let result =
697            run_aggregate_blocking(&[dir.path().join("rows.db")], None, &AliasAggregator::Count)
698                .unwrap();
699        assert!(matches!(result, AliasRunResult::Count(3)));
700    }
701
702    // -----------------------------------------------------------------
703    // (b) Sum single — Sum aggregates the amount column
704    // -----------------------------------------------------------------
705    #[test]
706    fn sum_single_source_returns_sum() {
707        let dir = build_in_memory_db_with_rows(&[
708            ("r1", "a", 10.0),
709            ("r2", "b", 20.5),
710            ("r3", "a", 30.0),
711        ]);
712        let result = run_aggregate_blocking(
713            &[dir.path().join("rows.db")],
714            None,
715            &AliasAggregator::Sum {
716                field: "amount".into(),
717            },
718        )
719        .unwrap();
720        match result {
721            AliasRunResult::Value(v) => {
722                let n = v.as_f64().expect("Sum should return a number");
723                assert!((n - 60.5).abs() < 1e-9);
724            }
725            other => panic!("expected Value variant, got {other:?}"),
726        }
727    }
728
729    // -----------------------------------------------------------------
730    // (c) Min/Max/Avg single — sanity smoke for the scalar variants
731    // -----------------------------------------------------------------
732    #[test]
733    fn min_max_avg_single_source_returns_scalars() {
734        let dir = build_in_memory_db_with_rows(&[
735            ("r1", "a", 10.0),
736            ("r2", "a", 20.0),
737            ("r3", "a", 30.0),
738        ]);
739        let path = dir.path().join("rows.db");
740        let min = run_aggregate_blocking(
741            std::slice::from_ref(&path),
742            None,
743            &AliasAggregator::Min {
744                field: "amount".into(),
745            },
746        )
747        .unwrap();
748        let max = run_aggregate_blocking(
749            std::slice::from_ref(&path),
750            None,
751            &AliasAggregator::Max {
752                field: "amount".into(),
753            },
754        )
755        .unwrap();
756        let avg = run_aggregate_blocking(
757            &[path],
758            None,
759            &AliasAggregator::Avg {
760                field: "amount".into(),
761            },
762        )
763        .unwrap();
764        for (label, r, expected) in [("Min", min, 10.0), ("Max", max, 30.0), ("Avg", avg, 20.0)] {
765            match r {
766                AliasRunResult::Value(v) => {
767                    let n = v.as_f64().expect("scalar should return a number");
768                    assert!((n - expected).abs() < 1e-9, "{label} mismatch: got {n}");
769                }
770                other => panic!("{label}: expected Value, got {other:?}"),
771            }
772        }
773    }
774
775    // -----------------------------------------------------------------
776    // (d) GroupBy + HAVING + inner — Crux #2 (HAVING positioning)
777    //     and inner sum integration
778    // -----------------------------------------------------------------
779    #[test]
780    fn groupby_with_having_and_inner_sum_filters_groups() {
781        let dir = build_in_memory_db_with_rows(&[
782            ("r1", "a", 5.0),
783            ("r2", "a", 5.0),
784            ("r3", "b", 1.0),
785            ("r4", "c", 3.0),
786            ("r5", "c", 4.0),
787        ]);
788        let inner = Box::new(AliasAggregator::Sum {
789            field: "amount".into(),
790        });
791        // HAVING COUNT(*) > 1 keeps only groups 'a' (count=2) and 'c' (count=2).
792        let result = run_aggregate_blocking(
793            &[dir.path().join("rows.db")],
794            None,
795            &AliasAggregator::GroupBy {
796                by_field: "tag".into(),
797                having: Some(ListFilter::And {
798                    filters: vec![ListFilter::Eq {
799                        field: "tag".into(),
800                        value: serde_json::Value::String("a".into()),
801                    }],
802                }),
803                inner: Some(inner),
804            },
805        )
806        .unwrap();
807        let groups = match result {
808            AliasRunResult::Groups(g) => g,
809            other => panic!("expected Groups, got {other:?}"),
810        };
811        assert_eq!(groups.len(), 1, "HAVING should leave 1 group: {groups:?}");
812        let g = &groups[0];
813        assert_eq!(g.key, serde_json::Value::String("a".into()));
814        assert_eq!(g.count, 2);
815        let inner_value = g
816            .value
817            .as_ref()
818            .expect("inner aggregator should produce a value")
819            .as_f64()
820            .expect("inner sum should be a number");
821        assert!((inner_value - 10.0).abs() < 1e-9);
822    }
823
824    // -----------------------------------------------------------------
825    // (e) Multi UNION ALL substring — Crux #3 literal verification
826    // -----------------------------------------------------------------
827    #[test]
828    fn multi_source_emits_union_all_not_join() {
829        let (sql, _) = build_inner_sql(&["db_0".to_string(), "db_1".to_string()], None).unwrap();
830        assert!(sql.contains("UNION ALL"), "expected UNION ALL in: {sql}");
831        let upper = sql.to_uppercase();
832        assert!(!upper.contains(" JOIN "), "JOIN must not appear: {sql}");
833        assert!(
834            sql.contains("FROM db_0.rows") && sql.contains("FROM db_1.rows"),
835            "expected attached aliases db_0.rows / db_1.rows: {sql}"
836        );
837    }
838
839    // -----------------------------------------------------------------
840    // (f) GroupBy HAVING positioning — Crux #2 literal substring
841    // -----------------------------------------------------------------
842    #[test]
843    fn groupby_having_emitted_after_group_by() {
844        // Build a synthetic SQL via the same code path used by run_group_by,
845        // by constructing the fragments directly and checking ordering.
846        let by_field = "tag";
847        let group_key_expr = format!("json_extract(data, '$.{by_field}')");
848        let inner_sql = "SELECT id, data, created_at, updated_at FROM db_0.rows";
849        let having_fragment = "json_extract(data, '$.tag') = ?";
850        let sql = format!(
851            "SELECT {group_key_expr} AS group_key, COUNT(*), COUNT(*) \
852             FROM ({inner_sql}) \
853             GROUP BY group_key HAVING {having_fragment}"
854        );
855        let gb = sql.find("GROUP BY").expect("expected GROUP BY in SQL");
856        let hv = sql.find("HAVING").expect("expected HAVING in SQL");
857        assert!(
858            hv > gb,
859            "HAVING ({hv}) must follow GROUP BY ({gb}) — Crux #2"
860        );
861        let where_idx = sql.to_uppercase().find(" WHERE ");
862        assert!(where_idx.is_none(), "having must not be emitted as WHERE");
863    }
864
865    // -----------------------------------------------------------------
866    // (g) Multi-DB ATTACH + UNION ALL — real two-db end-to-end COUNT
867    // -----------------------------------------------------------------
868    #[test]
869    fn multi_db_attach_union_all_count_returns_combined_total() {
870        let dir_a = build_in_memory_db_with_rows(&[("a1", "x", 1.0), ("a2", "x", 2.0)]);
871        let dir_b = build_in_memory_db_with_rows(&[
872            ("b1", "y", 10.0),
873            ("b2", "y", 20.0),
874            ("b3", "y", 30.0),
875        ]);
876        let result = run_aggregate_blocking(
877            &[dir_a.path().join("rows.db"), dir_b.path().join("rows.db")],
878            None,
879            &AliasAggregator::Count,
880        )
881        .unwrap();
882        // UNION ALL means N+M, JOIN would mean N*M (= 6).
883        assert!(
884            matches!(result, AliasRunResult::Count(5)),
885            "expected combined count 2+3=5, got {result:?}"
886        );
887    }
888
889    // -----------------------------------------------------------------
    // (h) ATTACH limit — SQLITE_MAX_ATTACHED + 1 sources trigger MiniAppError::Aggregator
891    // -----------------------------------------------------------------
892    #[test]
893    fn too_many_sources_returns_aggregator_error() {
894        let registry = TableRegistry::from_entries(std::collections::HashMap::new(), None);
895        let names: Vec<String> = (0..(SQLITE_MAX_ATTACHED + 1))
896            .map(|i| format!("t{i}"))
897            .collect();
898        let rt = tokio::runtime::Builder::new_current_thread()
899            .build()
900            .expect("runtime build");
901        let err = rt
902            .block_on(execute_aggregate(
903                &registry,
904                SourceSpec::Multi(names),
905                None,
906                AliasAggregator::Count,
907                &test_schema(),
908            ))
909            .expect_err("expected aggregator error");
910        assert_eq!(err.code(), crate::error::codes::AGGREGATOR_ERROR);
911    }
912
913    // -----------------------------------------------------------------
914    // (i) Identifier sanity — rejects table names with injection chars
915    // -----------------------------------------------------------------
916    #[test]
917    fn validate_identifier_rejects_injection_chars() {
918        assert!(validate_identifier("source_table", "t1").is_ok());
919        assert!(validate_identifier("source_table", "_x").is_ok());
920        assert!(validate_identifier("source_table", "t1; DROP TABLE").is_err());
921        assert!(validate_identifier("source_table", "1starts_digit").is_err());
922        assert!(validate_identifier("source_table", "").is_err());
923    }
924
925    // -----------------------------------------------------------------
926    // (j) SourceSpec::tables — Single and Multi return correct slices
927    // -----------------------------------------------------------------
928    #[test]
929    fn source_spec_tables_slice() {
930        let single = SourceSpec::Single("t".into());
931        assert_eq!(single.tables(), &["t".to_string()]);
932        let multi = SourceSpec::Multi(vec!["a".into(), "b".into()]);
933        assert_eq!(multi.tables(), &["a".to_string(), "b".to_string()]);
934    }
935
936    // -----------------------------------------------------------------
    // (j2) SourceSpec::Pattern — tables() returns an empty slice while the
    //      pattern is unresolved (so execute_aggregate's empty-source guard
    //      catches a missed resolve step); requires_resolve() flags Pattern only
940    // -----------------------------------------------------------------
941    #[test]
942    fn source_spec_pattern_unresolved_yields_empty_tables() {
943        let pat = SourceSpec::Pattern("shi_*".into());
944        assert_eq!(pat.tables(), &[] as &[String]);
945        assert!(pat.requires_resolve());
946        assert!(!SourceSpec::Single("t".into()).requires_resolve());
947        assert!(!SourceSpec::Multi(vec!["a".into()]).requires_resolve());
948    }
949
950    // -----------------------------------------------------------------
951    // (j3) SourceSpec::resolve_pattern — glob matching matrix
952    //      (prefix / suffix / middle / exact / matches-all),
953    //      0-hit error path, 1-hit Single normalisation, n-hit Multi
954    //      sorted ascending for determinism
955    // -----------------------------------------------------------------
956    #[test]
957    fn source_spec_resolve_pattern_prefix_glob() {
958        let tables = vec![
959            "shi_active_context".into(),
960            "shi_ng_context".into(),
961            "shi_trigger".into(),
962            "mia_brief".into(),
963        ];
964        let resolved = SourceSpec::Pattern("shi_*".into())
965            .resolve_pattern(&tables)
966            .expect("resolve ok");
967        match resolved {
968            SourceSpec::Multi(v) => assert_eq!(
969                v,
970                vec![
971                    "shi_active_context".to_string(),
972                    "shi_ng_context".to_string(),
973                    "shi_trigger".to_string(),
974                ]
975            ),
976            other => panic!("expected Multi, got {other:?}"),
977        }
978    }
979
980    #[test]
981    fn source_spec_resolve_pattern_suffix_glob() {
982        let tables = vec!["agent_log".into(), "session_log".into(), "memo".into()];
983        let resolved = SourceSpec::Pattern("*_log".into())
984            .resolve_pattern(&tables)
985            .expect("resolve ok");
986        match resolved {
987            SourceSpec::Multi(v) => {
988                assert_eq!(v, vec!["agent_log".to_string(), "session_log".to_string()])
989            }
990            other => panic!("expected Multi, got {other:?}"),
991        }
992    }
993
994    #[test]
995    fn source_spec_resolve_pattern_middle_glob() {
996        let tables = vec![
997            "shi_v1_brief".into(),
998            "shi_v2_brief".into(),
999            "shi_v1_log".into(),
1000        ];
1001        let resolved = SourceSpec::Pattern("shi_*_brief".into())
1002            .resolve_pattern(&tables)
1003            .expect("resolve ok");
1004        match resolved {
1005            SourceSpec::Multi(v) => assert_eq!(
1006                v,
1007                vec!["shi_v1_brief".to_string(), "shi_v2_brief".to_string()]
1008            ),
1009            other => panic!("expected Multi, got {other:?}"),
1010        }
1011    }
1012
1013    #[test]
1014    fn source_spec_resolve_pattern_single_hit_normalises_to_single() {
1015        let tables = vec!["shi_brief".into(), "mia_brief".into()];
1016        let resolved = SourceSpec::Pattern("shi_*".into())
1017            .resolve_pattern(&tables)
1018            .expect("resolve ok");
1019        match resolved {
1020            SourceSpec::Single(t) => assert_eq!(t, "shi_brief"),
1021            other => panic!("expected Single, got {other:?}"),
1022        }
1023    }
1024
1025    #[test]
1026    fn source_spec_resolve_pattern_match_all_glob() {
1027        let tables = vec!["a".into(), "b".into()];
1028        let resolved = SourceSpec::Pattern("*".into())
1029            .resolve_pattern(&tables)
1030            .expect("resolve ok");
1031        match resolved {
1032            SourceSpec::Multi(v) => assert_eq!(v, vec!["a".to_string(), "b".to_string()]),
1033            other => panic!("expected Multi for *-match, got {other:?}"),
1034        }
1035    }
1036
1037    #[test]
1038    fn source_spec_resolve_pattern_zero_hit_returns_error() {
1039        let tables = vec!["mia_brief".into()];
1040        let err = SourceSpec::Pattern("shi_*".into())
1041            .resolve_pattern(&tables)
1042            .expect_err("expected zero-hit error");
1043        assert_eq!(err.code(), crate::error::codes::AGGREGATOR_ERROR);
1044    }
1045
1046    #[test]
1047    fn source_spec_resolve_pattern_non_pattern_passes_through() {
1048        let tables = vec!["x".into()];
1049        let single = SourceSpec::Single("t".into())
1050            .resolve_pattern(&tables)
1051            .expect("non-pattern passthrough");
1052        assert!(matches!(single, SourceSpec::Single(ref s) if s == "t"));
1053        let multi = SourceSpec::Multi(vec!["a".into(), "b".into()])
1054            .resolve_pattern(&tables)
1055            .expect("non-pattern passthrough");
1056        assert!(
1057            matches!(multi, SourceSpec::Multi(ref v) if v == &vec!["a".to_string(), "b".to_string()])
1058        );
1059    }
1060
1061    #[test]
1062    fn source_spec_resolve_pattern_rejects_empty_pattern() {
1063        let tables = vec!["x".into()];
1064        let err = SourceSpec::Pattern("".into())
1065            .resolve_pattern(&tables)
1066            .expect_err("empty pattern rejected");
1067        assert_eq!(err.code(), crate::error::codes::AGGREGATOR_ERROR);
1068    }
1069
1070    #[test]
1071    fn source_spec_resolve_pattern_rejects_unsupported_metachar() {
1072        let tables = vec!["x".into()];
1073        for bad in &["shi_?", "shi_[ab]"] {
1074            let err = SourceSpec::Pattern((*bad).into())
1075                .resolve_pattern(&tables)
1076                .expect_err("unsupported metachar rejected");
1077            assert_eq!(err.code(), crate::error::codes::AGGREGATOR_ERROR);
1078        }
1079    }
1080
1081    #[test]
1082    fn source_spec_includes_table_single_multi_pattern() {
1083        assert!(SourceSpec::Single("rows".into()).includes_table("rows"));
1084        assert!(!SourceSpec::Single("rows".into()).includes_table("other"));
1085        let multi = SourceSpec::Multi(vec!["a".into(), "b".into()]);
1086        assert!(multi.includes_table("a"));
1087        assert!(multi.includes_table("b"));
1088        assert!(!multi.includes_table("c"));
1089        let pat = SourceSpec::Pattern("shi_*".into());
1090        assert!(pat.includes_table("shi_active_context"));
1091        assert!(pat.includes_table("shi_trigger"));
1092        assert!(!pat.includes_table("mia_brief"));
1093        // Invalid pattern compiles to false (no panic).
1094        let bad = SourceSpec::Pattern("shi_?".into());
1095        assert!(!bad.includes_table("shi_x"));
1096    }
1097
1098    #[test]
1099    fn source_spec_resolve_pattern_exact_match_no_wildcard() {
1100        let tables = vec!["shi_brief".into(), "shi_log".into()];
1101        let resolved = SourceSpec::Pattern("shi_brief".into())
1102            .resolve_pattern(&tables)
1103            .expect("exact pattern resolves to Single");
1104        match resolved {
1105            SourceSpec::Single(t) => assert_eq!(t, "shi_brief"),
1106            other => panic!("expected Single, got {other:?}"),
1107        }
1108    }
1109
1110    // -----------------------------------------------------------------
1111    // (k) Nested GroupBy rejected
1112    // -----------------------------------------------------------------
1113    #[test]
1114    fn nested_groupby_is_rejected_at_validation() {
1115        let nested = AliasAggregator::GroupBy {
1116            by_field: "tag".into(),
1117            having: None,
1118            inner: Some(Box::new(AliasAggregator::GroupBy {
1119                by_field: "tag".into(),
1120                having: None,
1121                inner: None,
1122            })),
1123        };
1124        let err = nested.validate(&test_schema()).unwrap_err();
1125        assert_eq!(err.code(), crate::error::codes::AGGREGATOR_ERROR);
1126    }
1127}