use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Every failure that can occur while loading, parsing, validating, or
/// executing a stored query.
#[derive(Debug, Error)]
pub enum StoredQueryError {
    /// The file does not start with `---` or never closes its front matter.
    #[error("invalid query file: missing front matter delimiters (---)")]
    MissingFrontMatter,

    /// The YAML front matter could not be deserialized.
    #[error("failed to parse query metadata: {0}")]
    InvalidMetadata(#[from] serde_yaml::Error),

    /// Nothing but whitespace follows the front matter.
    #[error("query file has no SQL body")]
    EmptyQuery,

    /// An I/O failure; also used to report a stored query that was not found.
    #[error("failed to read query file: {0}")]
    Read(#[from] std::io::Error),

    /// A parameter with no fallback value was not supplied.
    #[error("parameter '{name}' is required")]
    MissingParam { name: String },

    /// A parameter value has the wrong JSON type or could not be parsed.
    #[error("parameter '{name}': expected {expected}, got '{value}'")]
    InvalidParamType { name: String, expected: String, value: String },

    /// A numeric parameter is smaller than its declared `min`.
    #[error("parameter '{name}': value {value} is below minimum {min}")]
    BelowMin { name: String, value: f64, min: f64 },

    /// A numeric parameter is larger than its declared `max`.
    #[error("parameter '{name}': value {value} exceeds maximum {max}")]
    AboveMax { name: String, value: f64, max: f64 },

    /// A parameter value is not in its declared `choices` list.
    #[error("parameter '{name}': '{value}' is not one of the allowed values: {choices}")]
    InvalidChoice { name: String, value: String, choices: String },

    /// A string parameter does not match its declared `pattern` (or the
    /// pattern itself is invalid).
    #[error("parameter '{name}': value '{value}' does not match pattern '{pattern}'")]
    PatternMismatch { name: String, value: String, pattern: String },

    /// No home directory could be determined for the user queries directory.
    #[error("no queries directory found")]
    NoQueriesDir,

    /// Step ordering could not be resolved; `name` lists the affected steps.
    #[error("step '{name}' referenced in SQL but not defined in steps")]
    UndefinedStep { name: String },

    /// A step declared in the metadata has no `-- step:` SQL section.
    #[error("step '{name}' defined in metadata but has no SQL (missing `-- step: {name}`)")]
    MissingStepSql { name: String },

    /// The SQL body contains a `-- step:` marker for an undeclared step.
    #[error("SQL has `-- step: {name}` marker but '{name}' is not defined in steps")]
    UnknownStepMarker { name: String },

    /// A step returned no rows, so `@name.field` cannot be substituted.
    #[error("step '{name}' returned no results, cannot resolve @{name}.{field}")]
    EmptyStepResult { name: String, field: String },

    /// A step's result lacks the field referenced as `@name.field`.
    #[error("field '{field}' not found in step '{name}' result")]
    StepFieldNotFound { name: String, field: String },
}
/// One step of a multi-step query, as declared under `steps:` in the front matter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepDef {
    /// Step identifier, used by `-- step: <name>` markers and `@<name>.<field>`
    /// references in later steps.
    pub name: String,
    /// Name of the container this step's SQL targets.
    pub container: String,
}
/// Declared type of a stored-query parameter (the `type:` key).
///
/// Serialized in lowercase: `string`, `number`, `bool`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ParamType {
    /// Arbitrary text.
    String,
    /// Integer or floating-point number.
    Number,
    /// Boolean flag.
    Bool,
}
impl std::fmt::Display for ParamType {
    /// Writes the lowercase keyword used for this type in query files.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            ParamType::String => "string",
            ParamType::Number => "number",
            ParamType::Bool => "bool",
        };
        f.write_str(keyword)
    }
}
/// Declaration of one query parameter from the `params:` list.
///
/// Field order is significant: it determines the key order when the metadata
/// is serialized back to YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamDef {
    /// Parameter name; bound in SQL as `@<name>`.
    pub name: String,
    /// Declared value type (the `type:` key).
    #[serde(rename = "type")]
    pub param_type: ParamType,
    /// Optional human-readable description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Value used when the caller does not supply one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default: Option<serde_json::Value>,
    /// Allowed values; a single-entry list also acts as an implicit default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub choices: Option<Vec<serde_json::Value>>,
    /// Explicit required flag; when absent, `is_required` derives it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub required: Option<bool>,
    /// Inclusive lower bound for numeric values.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub min: Option<f64>,
    /// Inclusive upper bound for numeric values.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max: Option<f64>,
    /// Regular expression that string values must match (unanchored unless
    /// the pattern itself uses `^`/`$`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pattern: Option<String>,
}
impl ParamDef {
pub fn is_required(&self) -> bool {
self.required
.unwrap_or_else(|| self.default.is_none() && self.choices.is_none())
}
pub fn validate(&self, value: &serde_json::Value) -> Result<(), StoredQueryError> {
match self.param_type {
ParamType::String => {
if !value.is_string() {
return Err(StoredQueryError::InvalidParamType {
name: self.name.clone(),
expected: "string".into(),
value: value.to_string(),
});
}
}
ParamType::Number => {
if !value.is_number() {
return Err(StoredQueryError::InvalidParamType {
name: self.name.clone(),
expected: "number".into(),
value: value.to_string(),
});
}
}
ParamType::Bool => {
if !value.is_boolean() {
return Err(StoredQueryError::InvalidParamType {
name: self.name.clone(),
expected: "bool".into(),
value: value.to_string(),
});
}
}
}
if let Some(num) = value.as_f64() {
if let Some(min) = self.min {
if num < min {
return Err(StoredQueryError::BelowMin {
name: self.name.clone(),
value: num,
min,
});
}
}
if let Some(max) = self.max {
if num > max {
return Err(StoredQueryError::AboveMax {
name: self.name.clone(),
value: num,
max,
});
}
}
}
if let Some(ref choices) = self.choices {
if !choices.contains(value) {
let choices_str = choices
.iter()
.map(|c| match c {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
})
.collect::<Vec<_>>()
.join(", ");
return Err(StoredQueryError::InvalidChoice {
name: self.name.clone(),
value: match value {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
},
choices: choices_str,
});
}
}
if let (Some(pattern), Some(s)) = (&self.pattern, value.as_str()) {
let re = regex::Regex::new(pattern).map_err(|_| StoredQueryError::PatternMismatch {
name: self.name.clone(),
value: s.to_string(),
pattern: pattern.clone(),
})?;
if !re.is_match(s) {
return Err(StoredQueryError::PatternMismatch {
name: self.name.clone(),
value: s.to_string(),
pattern: pattern.clone(),
});
}
}
Ok(())
}
}
/// YAML front matter of a `.cosq` file.
///
/// Field order is significant: it determines the key order when the metadata
/// is serialized back to YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredQueryMetadata {
    /// Human-readable description (the only mandatory key).
    pub description: String,
    /// Target database, if pinned by the query.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub database: Option<String>,
    /// Target container for single-step queries, if pinned.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub container: Option<String>,
    /// Step declarations; present only for multi-step queries.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub steps: Option<Vec<StepDef>>,
    /// Parameter declarations, in declaration order.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub params: Vec<ParamDef>,
    /// Inline output template (presumably Jinja-style, judging by the test
    /// fixtures; rendering happens outside this module).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
    /// Path to an external template file. NOTE(review): the resolution base
    /// is not defined here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template_file: Option<String>,
    /// Provenance: what generated this query (not interpreted here).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generated_by: Option<String>,
    /// Provenance: the input it was generated from (not interpreted here).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generated_from: Option<String>,
}
/// A parsed `.cosq` stored query: its metadata plus its SQL.
#[derive(Debug, Clone)]
pub struct StoredQuery {
    /// Query name (the file stem when loaded from disk).
    pub name: String,
    /// Parsed front matter.
    pub metadata: StoredQueryMetadata,
    /// SQL body for single-step queries; empty for multi-step queries.
    pub sql: String,
    /// SQL per step name for multi-step queries; empty for single-step queries.
    pub step_queries: BTreeMap<String, String>,
}
impl StoredQuery {
pub fn parse(name: &str, contents: &str) -> Result<Self, StoredQueryError> {
let (metadata, raw_sql) = parse_front_matter(contents)?;
let raw_sql = raw_sql.trim().to_string();
if raw_sql.is_empty() {
return Err(StoredQueryError::EmptyQuery);
}
if let Some(ref steps) = metadata.steps {
let step_queries = parse_step_sql(&raw_sql)?;
let step_names: std::collections::HashSet<&str> =
steps.iter().map(|s| s.name.as_str()).collect();
for step in steps {
if !step_queries.contains_key(&step.name) {
return Err(StoredQueryError::MissingStepSql {
name: step.name.clone(),
});
}
}
for sql_name in step_queries.keys() {
if !step_names.contains(sql_name.as_str()) {
return Err(StoredQueryError::UnknownStepMarker {
name: sql_name.clone(),
});
}
}
Ok(Self {
name: name.to_string(),
metadata,
sql: String::new(),
step_queries,
})
} else {
Ok(Self {
name: name.to_string(),
metadata,
sql: raw_sql,
step_queries: BTreeMap::new(),
})
}
}
pub fn load(path: &Path) -> Result<Self, StoredQueryError> {
let name = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_string();
let contents = std::fs::read_to_string(path)?;
Self::parse(&name, &contents)
}
pub fn is_multi_step(&self) -> bool {
self.metadata.steps.is_some()
}
pub fn to_file_contents(&self) -> Result<String, serde_yaml::Error> {
let yaml = serde_yaml::to_string(&self.metadata)?;
if self.is_multi_step() {
let mut sql_body = String::new();
for (i, step) in self.metadata.steps.as_ref().unwrap().iter().enumerate() {
if i > 0 {
sql_body.push('\n');
}
sql_body.push_str(&format!("-- step: {}\n", step.name));
if let Some(sql) = self.step_queries.get(&step.name) {
sql_body.push_str(sql);
if !sql.ends_with('\n') {
sql_body.push('\n');
}
}
}
Ok(format!("---\n{}---\n{}", yaml, sql_body))
} else {
Ok(format!("---\n{}---\n{}\n", yaml, self.sql))
}
}
pub fn find_step_references(sql: &str, step_names: &[String]) -> Vec<(String, String)> {
let re = regex::Regex::new(r"@(\w+)\.(\w+)").unwrap();
let mut refs = Vec::new();
for cap in re.captures_iter(sql) {
let step_name = cap[1].to_string();
let field_name = cap[2].to_string();
if step_names.contains(&step_name) {
refs.push((step_name, field_name));
}
}
refs
}
pub fn execution_order(&self) -> Result<Vec<Vec<String>>, StoredQueryError> {
let steps = match &self.metadata.steps {
Some(s) => s,
None => return Ok(vec![vec![]]),
};
let step_names: Vec<String> = steps.iter().map(|s| s.name.clone()).collect();
let mut deps: BTreeMap<String, std::collections::HashSet<String>> = BTreeMap::new();
for step in steps {
let sql = self
.step_queries
.get(&step.name)
.cloned()
.unwrap_or_default();
let step_refs = Self::find_step_references(&sql, &step_names);
let dep_names: std::collections::HashSet<String> =
step_refs.into_iter().map(|(name, _)| name).collect();
deps.insert(step.name.clone(), dep_names);
}
let mut layers = Vec::new();
let mut resolved: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut remaining: Vec<String> = step_names;
while !remaining.is_empty() {
let layer: Vec<String> = remaining
.iter()
.filter(|name| {
deps.get(*name)
.map(|d| d.iter().all(|dep| resolved.contains(dep)))
.unwrap_or(true)
})
.cloned()
.collect();
if layer.is_empty() {
return Err(StoredQueryError::UndefinedStep {
name: remaining.join(", "),
});
}
for name in &layer {
resolved.insert(name.clone());
}
remaining.retain(|name| !resolved.contains(name));
layers.push(layer);
}
Ok(layers)
}
pub fn resolve_params(
&self,
provided: &BTreeMap<String, String>,
) -> Result<BTreeMap<String, serde_json::Value>, StoredQueryError> {
let mut resolved = BTreeMap::new();
for param in &self.metadata.params {
let value = if let Some(raw) = provided.get(¶m.name) {
parse_param_value(¶m.name, ¶m.param_type, raw)?
} else if let Some(ref default) = param.default {
default.clone()
} else if let Some(ref choices) = param.choices {
if choices.len() == 1 {
choices[0].clone()
} else if param.is_required() {
return Err(StoredQueryError::MissingParam {
name: param.name.clone(),
});
} else {
continue;
}
} else if param.is_required() {
return Err(StoredQueryError::MissingParam {
name: param.name.clone(),
});
} else {
continue;
};
param.validate(&value)?;
resolved.insert(param.name.clone(), value);
}
Ok(resolved)
}
pub fn build_cosmos_params(
resolved: &BTreeMap<String, serde_json::Value>,
) -> Vec<serde_json::Value> {
resolved
.iter()
.map(|(name, value)| {
serde_json::json!({
"name": format!("@{name}"),
"value": value
})
})
.collect()
}
}
/// Public entry point for converting a raw command-line string into a typed
/// JSON parameter value.
///
/// It applies the same rules as parameter resolution inside [`StoredQuery`].
/// No constraint validation (`min`, `max`, `choices`, `pattern`) is done here.
///
/// # Errors
///
/// [`StoredQueryError::InvalidParamType`] if `raw` is not a valid value of
/// `param_type`.
pub fn parse_param_value_public(
    name: &str,
    param_type: &ParamType,
    raw: &str,
) -> Result<serde_json::Value, StoredQueryError> {
    let parsed = parse_param_value(name, param_type, raw)?;
    Ok(parsed)
}
/// Converts a raw string (e.g. from the command line) into a JSON value of the
/// declared `param_type`.
///
/// - `string`: taken verbatim.
/// - `number`: parsed as `i64`, then `u64` (so large integers stay exact),
///   then as a finite `f64`.
/// - `bool`: `true`/`1`/`yes` or `false`/`0`/`no`, case-insensitive.
///
/// # Errors
///
/// [`StoredQueryError::InvalidParamType`] if `raw` cannot be interpreted as
/// `param_type`. This includes `NaN`, infinities, and floats that overflow
/// to infinity (e.g. `1e400`); JSON cannot represent those, and `json!` would
/// otherwise silently turn them into `null`.
fn parse_param_value(
    name: &str,
    param_type: &ParamType,
    raw: &str,
) -> Result<serde_json::Value, StoredQueryError> {
    let invalid = |expected: &str| StoredQueryError::InvalidParamType {
        name: name.to_string(),
        expected: expected.into(),
        value: raw.to_string(),
    };
    match param_type {
        ParamType::String => Ok(serde_json::Value::String(raw.to_string())),
        ParamType::Number => {
            if let Ok(i) = raw.parse::<i64>() {
                Ok(serde_json::json!(i))
            } else if let Ok(u) = raw.parse::<u64>() {
                Ok(serde_json::json!(u))
            } else {
                // `Number::from_f64` returns `None` for non-finite values,
                // so they are rejected instead of becoming `null`.
                raw.parse::<f64>()
                    .ok()
                    .and_then(serde_json::Number::from_f64)
                    .map(serde_json::Value::Number)
                    .ok_or_else(|| invalid("number"))
            }
        }
        ParamType::Bool => match raw.to_lowercase().as_str() {
            "true" | "1" | "yes" => Ok(serde_json::Value::Bool(true)),
            "false" | "0" | "no" => Ok(serde_json::Value::Bool(false)),
            _ => Err(invalid("bool (true/false)")),
        },
    }
}
/// Splits a multi-step SQL body into sections keyed by step name.
///
/// A section starts at a line whose trimmed text begins with `-- step:`; the
/// rest of that line (trimmed) is the step name. Parsing rules:
/// - text before the first marker is ignored;
/// - a section whose SQL is blank after trimming is omitted, so the caller
///   reports that step as missing;
/// - if a name repeats, a later non-empty section replaces the earlier one.
///
/// This function never returns an error; the `Result` keeps call sites uniform.
fn parse_step_sql(raw_sql: &str) -> Result<BTreeMap<String, String>, StoredQueryError> {
    // Stores the finished section (if any) under its name, skipping blank SQL.
    fn flush(steps: &mut BTreeMap<String, String>, section: Option<String>, body: &str) {
        if let Some(section) = section {
            let sql = body.trim();
            if !sql.is_empty() {
                steps.insert(section, sql.to_string());
            }
        }
    }

    let mut steps = BTreeMap::new();
    let mut section: Option<String> = None;
    let mut body = String::new();
    for line in raw_sql.lines() {
        match line.trim().strip_prefix("-- step:") {
            Some(marker) => {
                flush(&mut steps, section.take(), &body);
                section = Some(marker.trim().to_string());
                body.clear();
            }
            None if section.is_some() => {
                body.push_str(line);
                body.push('\n');
            }
            None => {}
        }
    }
    flush(&mut steps, section, &body);
    Ok(steps)
}
/// Splits a query file into its YAML front matter and the text that follows it.
///
/// After optional leading whitespace and an optional UTF-8 byte-order mark,
/// the file must begin with `---`. The front matter ends at the first later
/// line that starts with `---`. Everything after those closing three dashes,
/// including the rest of that line, is returned unmodified as the body.
///
/// # Errors
///
/// - [`StoredQueryError::MissingFrontMatter`] if either delimiter is absent.
/// - [`StoredQueryError::InvalidMetadata`] if the YAML does not deserialize.
fn parse_front_matter(contents: &str) -> Result<(StoredQueryMetadata, String), StoredQueryError> {
    // U+FEFF (BOM) is not whitespace, so `trim_start` alone would leave it in
    // front of `---`, and BOM-prefixed files would be rejected.
    let trimmed = contents.trim_start_matches('\u{feff}').trim_start();
    let after_first = trimmed
        .strip_prefix("---")
        .ok_or(StoredQueryError::MissingFrontMatter)?;
    let (yaml_str, rest) = after_first
        .split_once("\n---")
        .ok_or(StoredQueryError::MissingFrontMatter)?;
    let metadata: StoredQueryMetadata = serde_yaml::from_str(yaml_str)?;
    Ok((metadata, rest.to_string()))
}
/// Returns the per-user queries directory, `<home>/.cosq/queries`.
///
/// The directory is not required to exist.
///
/// # Errors
///
/// [`StoredQueryError::NoQueriesDir`] if the home directory cannot be determined.
pub fn user_queries_dir() -> Result<PathBuf, StoredQueryError> {
    let home = dirs::home_dir().ok_or(StoredQueryError::NoQueriesDir)?;
    Ok(home.join(".cosq").join("queries"))
}
/// Returns the project-level queries directory, `<cwd>/.cosq/queries`.
///
/// Returns `None` if the current directory cannot be determined. The
/// directory is not required to exist.
pub fn project_queries_dir() -> Option<PathBuf> {
    let cwd = std::env::current_dir().ok()?;
    Some(cwd.join(".cosq").join("queries"))
}
/// Loads every `.cosq` query from the user directory and then from the
/// project directory, returning them sorted by name.
///
/// A project-level query replaces a user-level query with the same name.
/// Directories that do not exist (or cannot be located) are skipped.
///
/// # Errors
///
/// Propagates a failure to read an existing directory. Individual files that
/// fail to parse are skipped by `load_queries_from_dir` with a warning.
pub fn list_stored_queries() -> Result<Vec<StoredQuery>, StoredQueryError> {
    // Order matters: later directories overwrite earlier entries.
    let search_dirs = [user_queries_dir().ok(), project_queries_dir()];
    let mut by_name = BTreeMap::new();
    for dir in search_dirs.iter().flatten().filter(|dir| dir.is_dir()) {
        load_queries_from_dir(dir, &mut by_name)?;
    }
    Ok(by_name.into_values().collect())
}
/// Returns `(name, description)` pairs for all stored queries, sorted by name.
///
/// If loading the queries fails, this falls back to listing the `.cosq` file
/// stems in the user and project directories, with `None` descriptions.
/// It never fails.
pub fn list_query_names() -> Vec<(String, Option<String>)> {
    match list_stored_queries() {
        Ok(queries) => queries
            .into_iter()
            .map(|q| (q.name, Some(q.metadata.description)))
            .collect(),
        Err(_) => {
            let search_dirs = [user_queries_dir().ok(), project_queries_dir()];
            let mut names = BTreeMap::new();
            for dir in search_dirs.iter().flatten().filter(|dir| dir.is_dir()) {
                collect_names_from_dir(dir, &mut names);
            }
            names.into_keys().map(|name| (name, None)).collect()
        }
    }
}
/// Adds the file stem of every `*.cosq` entry in `dir` to `names`.
///
/// This is best-effort: an unreadable directory, unreadable entries, and
/// non-UTF-8 stems are silently skipped.
fn collect_names_from_dir(dir: &Path, names: &mut BTreeMap<String, ()>) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    let stems = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.extension().is_some_and(|ext| ext == "cosq"))
        .filter_map(|path| {
            path.file_stem()
                .and_then(|stem| stem.to_str())
                .map(str::to_owned)
        });
    names.extend(stems.map(|stem| (stem, ())));
}
/// Locates and loads the stored query `name`, with or without the `.cosq`
/// extension.
///
/// The project directory (`<cwd>/.cosq/queries`) is searched first, then the
/// user directory. Only regular files are considered (via `Path::is_file`),
/// so a directory that happens to be named `<name>.cosq` cannot shadow a real
/// query file.
///
/// # Errors
///
/// - [`StoredQueryError::NoQueriesDir`] if the project lookup misses and no
///   home directory is available.
/// - [`StoredQueryError::Read`] with [`std::io::ErrorKind::NotFound`] if
///   neither directory has the file.
/// - Any error from [`StoredQuery::load`].
pub fn find_stored_query(name: &str) -> Result<StoredQuery, StoredQueryError> {
    let filename = if name.ends_with(".cosq") {
        name.to_string()
    } else {
        format!("{name}.cosq")
    };
    if let Some(path) = project_queries_dir()
        .map(|dir| dir.join(&filename))
        .filter(|path| path.is_file())
    {
        return StoredQuery::load(&path);
    }
    let path = user_queries_dir()?.join(&filename);
    if path.is_file() {
        return StoredQuery::load(&path);
    }
    Err(StoredQueryError::Read(std::io::Error::new(
        std::io::ErrorKind::NotFound,
        format!("stored query '{name}' not found"),
    )))
}
/// Computes where the query `name` lives, or would be saved.
///
/// `.cosq` is appended if not already present. With `project_level` the path
/// is under `<cwd>/.cosq/queries`, otherwise under `<home>/.cosq/queries`.
/// The file is not required to exist.
///
/// # Errors
///
/// [`StoredQueryError::NoQueriesDir`] if the chosen base directory cannot be
/// determined.
pub fn query_file_path(name: &str, project_level: bool) -> Result<PathBuf, StoredQueryError> {
    let mut filename = name.to_string();
    if !filename.ends_with(".cosq") {
        filename.push_str(".cosq");
    }
    let base = if project_level {
        project_queries_dir().ok_or(StoredQueryError::NoQueriesDir)?
    } else {
        user_queries_dir()?
    };
    Ok(base.join(filename))
}
/// Loads every `*.cosq` file in `dir` into `queries`, keyed by query name.
///
/// Entries already in `queries` are overwritten when a name repeats. Files
/// that fail to load are reported on stderr and skipped.
///
/// # Errors
///
/// [`StoredQueryError::Read`] if the directory, or one of its entries, cannot
/// be read.
fn load_queries_from_dir(
    dir: &Path,
    queries: &mut BTreeMap<String, StoredQuery>,
) -> Result<(), StoredQueryError> {
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if !path.extension().is_some_and(|ext| ext == "cosq") {
            continue;
        }
        match StoredQuery::load(&path) {
            Ok(query) => {
                queries.insert(query.name.clone(), query);
            }
            Err(err) => eprintln!("Warning: skipping {}: {}", path.display(), err),
        }
    }
    Ok(())
}
// Unit tests for stored-query parsing, parameter resolution/validation,
// serialization round-trips, and multi-step step handling.
#[cfg(test)]
mod tests {
use super::*;
// Single-step query with one defaulted numeric parameter.
const EXAMPLE_QUERY: &str = r#"---
description: Find users who signed up recently
database: mydb
container: users
params:
- name: days
type: number
description: Number of days to look back
default: 30
---
SELECT c.id, c.email, c.displayName, c.createdAt
FROM c
WHERE c.createdAt >= DateTimeAdd("dd", -@days, GetCurrentDateTime())
ORDER BY c.createdAt DESC
"#;
// Query exercising `choices` on a string parameter and `min`/`max` on a number.
const QUERY_WITH_CHOICES: &str = r#"---
description: List orders by status
database: shop-db
container: orders
params:
- name: status
type: string
description: Order status
choices: ["pending", "shipped", "delivered"]
default: "pending"
- name: limit
type: number
default: 50
min: 1
max: 1000
---
SELECT TOP @limit * FROM c WHERE c.status = @status
"#;
// Query carrying an inline output template in its metadata.
const QUERY_WITH_TEMPLATE: &str = r#"---
description: Orders summary
params:
- name: status
type: string
default: "pending"
template: |
Orders ({{ status }}):
{% for doc in documents %}
{{ loop.index }}. #{{ doc.id }} — ${{ doc.total }}
{% endfor %}
---
SELECT c.id, c.total FROM c WHERE c.status = @status
"#;
// Front matter fields and the SQL body are parsed from a basic file.
#[test]
fn test_parse_basic_query() {
let query = StoredQuery::parse("recent-users", EXAMPLE_QUERY).unwrap();
assert_eq!(query.name, "recent-users");
assert_eq!(
query.metadata.description,
"Find users who signed up recently"
);
assert_eq!(query.metadata.database.as_deref(), Some("mydb"));
assert_eq!(query.metadata.container.as_deref(), Some("users"));
assert_eq!(query.metadata.params.len(), 1);
assert_eq!(query.metadata.params[0].name, "days");
assert_eq!(query.metadata.params[0].param_type, ParamType::Number);
assert_eq!(
query.metadata.params[0].default,
Some(serde_json::json!(30))
);
assert!(query.sql.contains("SELECT"));
assert!(query.sql.contains("@days"));
}
// `choices`, `min` and `max` deserialize into the parameter definitions.
#[test]
fn test_parse_query_with_choices() {
let query = StoredQuery::parse("orders", QUERY_WITH_CHOICES).unwrap();
assert_eq!(query.metadata.params.len(), 2);
let status_param = &query.metadata.params[0];
assert_eq!(status_param.name, "status");
assert_eq!(
status_param.choices.as_ref().unwrap(),
&vec![
serde_json::json!("pending"),
serde_json::json!("shipped"),
serde_json::json!("delivered"),
]
);
let limit_param = &query.metadata.params[1];
assert_eq!(limit_param.min, Some(1.0));
assert_eq!(limit_param.max, Some(1000.0));
}
// The template block scalar is kept as a raw string.
#[test]
fn test_parse_query_with_template() {
let query = StoredQuery::parse("orders-summary", QUERY_WITH_TEMPLATE).unwrap();
assert!(query.metadata.template.is_some());
assert!(
query
.metadata
.template
.as_ref()
.unwrap()
.contains("{% for doc in documents %}")
);
}
// Omitted parameters fall back to their declared default.
#[test]
fn test_resolve_params_with_defaults() {
let query = StoredQuery::parse("recent-users", EXAMPLE_QUERY).unwrap();
let provided = BTreeMap::new();
let resolved = query.resolve_params(&provided).unwrap();
assert_eq!(resolved.get("days"), Some(&serde_json::json!(30)));
}
// Supplied strings are parsed to the declared type (integer here).
#[test]
fn test_resolve_params_with_cli_values() {
let query = StoredQuery::parse("recent-users", EXAMPLE_QUERY).unwrap();
let mut provided = BTreeMap::new();
provided.insert("days".to_string(), "7".to_string());
let resolved = query.resolve_params(&provided).unwrap();
assert_eq!(resolved.get("days"), Some(&serde_json::json!(7)));
}
// Values above `max` are rejected.
#[test]
fn test_resolve_params_validation_range() {
let query = StoredQuery::parse("orders", QUERY_WITH_CHOICES).unwrap();
let mut provided = BTreeMap::new();
provided.insert("status".to_string(), "pending".to_string());
provided.insert("limit".to_string(), "5000".to_string());
let result = query.resolve_params(&provided);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
}
// Values outside `choices` are rejected.
#[test]
fn test_resolve_params_validation_choices() {
let query = StoredQuery::parse("orders", QUERY_WITH_CHOICES).unwrap();
let mut provided = BTreeMap::new();
provided.insert("status".to_string(), "invalid".to_string());
provided.insert("limit".to_string(), "10".to_string());
let result = query.resolve_params(&provided);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("not one of the allowed")
);
}
// Resolved parameters become `{name: "@x", value}` objects.
#[test]
fn test_build_cosmos_params() {
let mut resolved = BTreeMap::new();
resolved.insert("days".to_string(), serde_json::json!(7));
resolved.insert("status".to_string(), serde_json::json!("active"));
let params = StoredQuery::build_cosmos_params(&resolved);
assert_eq!(params.len(), 2);
let days_param = params.iter().find(|p| p["name"] == "@days").unwrap();
assert_eq!(days_param["value"], 7);
let status_param = params.iter().find(|p| p["name"] == "@status").unwrap();
assert_eq!(status_param["value"], "active");
}
// Serializing then re-parsing a single-step query preserves its content.
#[test]
fn test_roundtrip_serialization() {
let query = StoredQuery::parse("test", EXAMPLE_QUERY).unwrap();
let contents = query.to_file_contents().unwrap();
let reparsed = StoredQuery::parse("test", &contents).unwrap();
assert_eq!(reparsed.metadata.description, query.metadata.description);
assert_eq!(reparsed.metadata.database, query.metadata.database);
assert_eq!(reparsed.metadata.params.len(), query.metadata.params.len());
assert_eq!(reparsed.sql, query.sql);
}
// Files without a leading `---` block are rejected.
#[test]
fn test_missing_front_matter() {
let result = StoredQuery::parse("bad", "SELECT * FROM c");
assert!(matches!(result, Err(StoredQueryError::MissingFrontMatter)));
}
// Front matter with no SQL body is rejected.
#[test]
fn test_empty_query() {
let contents = "---\ndescription: empty\n---\n";
let result = StoredQuery::parse("empty", contents);
assert!(matches!(result, Err(StoredQueryError::EmptyQuery)));
}
// A parameter with no default or choices is required and must be supplied.
#[test]
fn test_param_required_without_default() {
let contents = r#"---
description: test
params:
- name: id
type: string
---
SELECT * FROM c WHERE c.id = @id
"#;
let query = StoredQuery::parse("test", contents).unwrap();
assert!(query.metadata.params[0].is_required());
let result = query.resolve_params(&BTreeMap::new());
assert!(matches!(result, Err(StoredQueryError::MissingParam { .. })));
}
// Boolean parsing accepts `true`/`false` and the `yes` alias.
#[test]
fn test_parse_bool_param() {
let value = parse_param_value("active", &ParamType::Bool, "true").unwrap();
assert_eq!(value, serde_json::Value::Bool(true));
let value = parse_param_value("active", &ParamType::Bool, "false").unwrap();
assert_eq!(value, serde_json::Value::Bool(false));
let value = parse_param_value("active", &ParamType::Bool, "yes").unwrap();
assert_eq!(value, serde_json::Value::Bool(true));
}
// `pattern` accepts a matching string and rejects a non-matching one.
#[test]
fn test_param_with_pattern() {
let contents = r#"---
description: test
params:
- name: email
type: string
pattern: "^[^@]+@[^@]+$"
---
SELECT * FROM c WHERE c.email = @email
"#;
let query = StoredQuery::parse("test", contents).unwrap();
let mut provided = BTreeMap::new();
provided.insert("email".to_string(), "user@example.com".to_string());
assert!(query.resolve_params(&provided).is_ok());
let mut bad = BTreeMap::new();
bad.insert("email".to_string(), "not-an-email".to_string());
assert!(query.resolve_params(&bad).is_err());
}
// A query with no params resolves to an empty map.
#[test]
fn test_query_no_params() {
let contents = r#"---
description: Simple count
---
SELECT VALUE COUNT(1) FROM c
"#;
let query = StoredQuery::parse("count", contents).unwrap();
assert!(query.metadata.params.is_empty());
let resolved = query.resolve_params(&BTreeMap::new()).unwrap();
assert!(resolved.is_empty());
}
// Two independent steps (no `@step.field` references between them).
const MULTI_STEP_PARALLEL: &str = r#"---
description: Order with line items
database: mydb
params:
- name: orderId
type: string
steps:
- name: header
container: order-headers
- name: lines
container: order-lines
template: |
Order: {{ header[0].orderId }}
{% for line in lines %}
{{ line.productName }}
{% endfor %}
---
-- step: header
SELECT * FROM c WHERE c.orderId = @orderId
-- step: lines
SELECT * FROM c WHERE c.orderId = @orderId ORDER BY c.lineNumber
"#;
// Two chained steps: `orders` references `@customer.id`.
const MULTI_STEP_CHAIN: &str = r#"---
description: Customer orders by name
database: mydb
params:
- name: customerName
type: string
steps:
- name: customer
container: customers
- name: orders
container: orders
template: |
Customer: {{ customer[0].name }}
{% for order in orders %}
{{ order.orderId }}
{% endfor %}
---
-- step: customer
SELECT TOP 1 * FROM c WHERE c.name = @customerName
-- step: orders
SELECT * FROM c WHERE c.customerId = @customer.id ORDER BY c.date DESC
"#;
// Step declarations and per-step SQL are parsed; `sql` stays empty.
#[test]
fn test_parse_multi_step_parallel() {
let query = StoredQuery::parse("order-detail", MULTI_STEP_PARALLEL).unwrap();
assert!(query.is_multi_step());
assert!(query.sql.is_empty());
let steps = query.metadata.steps.as_ref().unwrap();
assert_eq!(steps.len(), 2);
assert_eq!(steps[0].name, "header");
assert_eq!(steps[0].container, "order-headers");
assert_eq!(steps[1].name, "lines");
assert_eq!(steps[1].container, "order-lines");
assert_eq!(query.step_queries.len(), 2);
assert!(query.step_queries["header"].contains("@orderId"));
assert!(query.step_queries["lines"].contains("ORDER BY"));
}
// Step references are kept verbatim in the step SQL.
#[test]
fn test_parse_multi_step_chain() {
let query = StoredQuery::parse("customer-orders", MULTI_STEP_CHAIN).unwrap();
assert!(query.is_multi_step());
assert!(query.step_queries["orders"].contains("@customer.id"));
}
// Independent steps share a single execution layer.
#[test]
fn test_multi_step_execution_order_parallel() {
let query = StoredQuery::parse("order-detail", MULTI_STEP_PARALLEL).unwrap();
let layers = query.execution_order().unwrap();
assert_eq!(layers.len(), 1);
assert_eq!(layers[0].len(), 2);
}
// A dependent step is scheduled in a later layer than its dependency.
#[test]
fn test_multi_step_execution_order_chain() {
let query = StoredQuery::parse("customer-orders", MULTI_STEP_CHAIN).unwrap();
let layers = query.execution_order().unwrap();
assert_eq!(layers.len(), 2);
assert_eq!(layers[0], vec!["customer"]);
assert_eq!(layers[1], vec!["orders"]);
}
// Only `@step.field` references to known step names are reported.
#[test]
fn test_find_step_references() {
let step_names = vec!["customer".to_string(), "orders".to_string()];
let sql = "SELECT * FROM c WHERE c.customerId = @customer.id AND c.status = @status";
let refs = StoredQuery::find_step_references(sql, &step_names);
assert_eq!(refs, vec![("customer".to_string(), "id".to_string())]);
}
// Plain `@param` references are not step references.
#[test]
fn test_find_step_references_no_matches() {
let step_names = vec!["customer".to_string()];
let sql = "SELECT * FROM c WHERE c.status = @status";
let refs = StoredQuery::find_step_references(sql, &step_names);
assert!(refs.is_empty());
}
// Serializing then re-parsing a multi-step query preserves every step's SQL.
#[test]
fn test_multi_step_roundtrip() {
let query = StoredQuery::parse("order-detail", MULTI_STEP_PARALLEL).unwrap();
let contents = query.to_file_contents().unwrap();
let reparsed = StoredQuery::parse("order-detail", &contents).unwrap();
assert!(reparsed.is_multi_step());
assert_eq!(
reparsed.metadata.steps.as_ref().unwrap().len(),
query.metadata.steps.as_ref().unwrap().len()
);
assert_eq!(reparsed.step_queries.len(), query.step_queries.len());
for (name, sql) in &query.step_queries {
assert_eq!(reparsed.step_queries[name], *sql);
}
}
// A declared step without a `-- step:` section is rejected.
#[test]
fn test_multi_step_missing_sql() {
let contents = r#"---
description: test
steps:
- name: step1
container: c1
- name: step2
container: c2
---
-- step: step1
SELECT * FROM c
"#;
let result = StoredQuery::parse("test", contents);
assert!(matches!(
result,
Err(StoredQueryError::MissingStepSql { .. })
));
}
// A `-- step:` marker for an undeclared step is rejected.
#[test]
fn test_multi_step_unknown_marker() {
let contents = r#"---
description: test
steps:
- name: step1
container: c1
---
-- step: step1
SELECT * FROM c
-- step: unknown
SELECT * FROM c
"#;
let result = StoredQuery::parse("test", contents);
assert!(matches!(
result,
Err(StoredQueryError::UnknownStepMarker { .. })
));
}
// Files without `steps` keep the single-query representation.
#[test]
fn test_single_step_backward_compat() {
let query = StoredQuery::parse("recent-users", EXAMPLE_QUERY).unwrap();
assert!(!query.is_multi_step());
assert!(!query.sql.is_empty());
assert!(query.step_queries.is_empty());
}
}