#![doc = include_str!("../ARCHITECTURE.md")]
#![allow(unused_assignments)]
use std::sync::OnceLock;
use std::{collections::HashMap, path::PathBuf, str::FromStr};
use anstream::adapter::strip_str;
use semver::Version;
use serde::{Deserialize, Serialize};
use strum::VariantNames;
pub use error_message::{ErrorMessage, ErrorMessages, SourceLocation};
pub use prqlc_parser::error::{Error, ErrorSource, Errors, MessageKind, Reason, WithErrorInfo};
pub use prqlc_parser::lexer::lr;
pub use prqlc_parser::parser::pr;
pub use prqlc_parser::span::Span;
mod codegen;
pub mod debug;
mod error_message;
pub mod ir;
pub mod parser;
pub mod semantic;
pub mod sql;
#[cfg(feature = "cli")]
pub mod utils;
#[cfg(not(feature = "cli"))]
pub(crate) mod utils;
/// Crate-wide result alias whose error type defaults to [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
/// Returns the version of the PRQL compiler.
///
/// If the `PRQL_VERSION_OVERRIDE` environment variable is set, its value is parsed and
/// returned; it is read on every call and never cached. Otherwise the version is computed
/// once — from the git description captured at build time (`VERGEN_GIT_DESCRIBE`), falling
/// back to the Cargo package version — and cached for later calls.
///
/// # Panics
///
/// Panics if `PRQL_VERSION_OVERRIDE` is set but is not valid semver, or if neither the git
/// description nor the Cargo package version parses as semver.
pub fn compiler_version() -> Version {
    // Checked before the cache so an override always wins, even after the cache is filled.
    // Because of this early return, the cached initializer never needs to consult it.
    if let Ok(prql_version_override) = std::env::var("PRQL_VERSION_OVERRIDE") {
        return Version::parse(&prql_version_override).unwrap_or_else(|e| {
            panic!("Could not parse PRQL version {prql_version_override}\n{e}")
        });
    }

    static COMPILER_VERSION: OnceLock<Version> = OnceLock::new();
    COMPILER_VERSION
        .get_or_init(|| {
            let git_version = env!("VERGEN_GIT_DESCRIBE");
            let cargo_version = env!("CARGO_PKG_VERSION");
            Version::parse(git_version)
                .or_else(|e| {
                    log::info!("Could not parse git version number {git_version}\n{e}");
                    Version::parse(cargo_version)
                })
                .unwrap_or_else(|e| {
                    panic!("Could not parse prqlc version number {cargo_version}\n{e}")
                })
        })
        .clone()
}
/// Compiles a PRQL query string into SQL.
///
/// The query is parsed, resolved and lowered to RQ, and then rendered as SQL according to
/// `options`. Errors from any stage are composed against the source text. When
/// `options.display` is [`DisplayOptions::Plain`], ANSI escape sequences are stripped from
/// each message's `display` text.
pub fn compile(prql: &str, options: &Options) -> Result<String, ErrorMessages> {
    let sources = SourceTree::from(prql);

    let outcome = parser::parse(&sources)
        .and_then(|pl| {
            semantic::resolve_and_lower(pl, &[], None)
                .map_err(|err| err.with_source(ErrorSource::NameResolver).into())
        })
        .and_then(|rq| {
            sql::compile(rq, options).map_err(|err| err.with_source(ErrorSource::SQL).into())
        });

    outcome.map_err(|err| {
        let mut messages = ErrorMessages::from(err).composed(&sources);
        match options.display {
            DisplayOptions::AnsiColor => {}
            DisplayOptions::Plain => {
                for message in &mut messages.inner {
                    message.display = message
                        .display
                        .take()
                        .map(|text| strip_str(&text).to_string());
                }
            }
        }
        messages
    })
}
/// A compilation target.
///
/// Only SQL exists as a target. `Sql(None)` means no specific dialect was requested; as a
/// target name it is spelled `sql.any`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Target {
    /// SQL, optionally restricted to one dialect.
    Sql(Option<sql::Dialect>),
}

impl Default for Target {
    /// SQL without a specific dialect.
    fn default() -> Self {
        Target::Sql(None)
    }
}
impl Target {
    /// Lists target names.
    ///
    /// The list starts with `sql.any`, followed by `sql.<dialect>` for every name in
    /// `sql::Dialect::VARIANTS`.
    pub fn names() -> Vec<String> {
        std::iter::once(String::from("sql.any"))
            .chain(
                sql::Dialect::VARIANTS
                    .iter()
                    .map(|dialect| format!("sql.{dialect}")),
            )
            .collect()
    }
}
impl FromStr for Target {
type Err = Error;
fn from_str(s: &str) -> Result<Target, Self::Err> {
if let Some(dialect) = s.strip_prefix("sql.") {
if dialect == "any" {
return Ok(Target::Sql(None));
}
if let Ok(dialect) = sql::Dialect::from_str(dialect) {
return Ok(Target::Sql(Some(dialect)));
}
}
Err(Error::new(Reason::NotFound {
name: format!("{s:?}"),
namespace: "target".to_string(),
}))
}
}
/// Options controlling compilation (see [`compile`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Options {
    /// Output formatting flag, passed through to SQL generation.
    /// NOTE(review): presumably controls pretty-printing of the SQL — confirm in `sql`.
    pub format: bool,
    /// The target to compile to.
    pub target: Target,
    /// Signature-comment flag, passed through to SQL generation.
    /// NOTE(review): presumably adds a generated-by comment to the output — confirm in `sql`.
    pub signature_comment: bool,
    /// Deprecated: superseded by `display`.
    pub color: bool,
    /// How error messages returned by [`compile`] are rendered.
    pub display: DisplayOptions,
}

impl Default for Options {
    /// Defaults:
    /// - formatting, signature comment and `color` enabled;
    /// - the dialect-agnostic SQL target;
    /// - ANSI-colored error display.
    fn default() -> Self {
        Options {
            target: Target::Sql(None),
            format: true,
            signature_comment: true,
            color: true,
            display: DisplayOptions::AnsiColor,
        }
    }
}
impl Options {
pub fn with_format(mut self, format: bool) -> Self {
self.format = format;
self
}
pub fn no_format(self) -> Self {
self.with_format(false)
}
pub fn with_signature_comment(mut self, signature_comment: bool) -> Self {
self.signature_comment = signature_comment;
self
}
pub fn no_signature(self) -> Self {
self.with_signature_comment(false)
}
pub fn with_target(mut self, target: Target) -> Self {
self.target = target;
self
}
#[deprecated(note = "`color` is replaced by `display`; see `Options` docs for more details")]
pub fn with_color(mut self, color: bool) -> Self {
self.color = color;
self
}
pub fn with_display(mut self, display: DisplayOptions) -> Self {
self.display = display;
self
}
}
/// How error messages produced by [`compile`] are rendered.
///
/// Parses from snake_case names (`plain`, `ansi_color`) via `FromStr`.
#[derive(Debug, Clone, Serialize, Deserialize, strum::EnumString)]
#[strum(serialize_all = "snake_case")]
#[non_exhaustive]
pub enum DisplayOptions {
    /// ANSI escape sequences are stripped from each error message's display text.
    Plain,
    /// Error message display text is left as produced (it may contain ANSI color codes).
    AnsiColor,
}
// Attaches the README as documentation so its code examples are run as doctests.
// The type exists only under `cfg(doctest)`.
#[doc = include_str!("../README.md")]
#[cfg(doctest)]
pub struct ReadmeDoctests;
/// Lexes PRQL source into tokens.
///
/// Each lexer error is converted into an [`ErrorMessage`].
pub fn prql_to_tokens(prql: &str) -> Result<lr::Tokens, ErrorMessages> {
    match prqlc_parser::lexer::lex_source(prql) {
        Ok(tokens) => Ok(tokens),
        Err(errors) => {
            let messages: Vec<ErrorMessage> = errors.into_iter().map(Into::into).collect();
            Err(messages.into())
        }
    }
}
/// Parses a single PRQL source string into PL.
///
/// Equivalent to [`prql_to_pl_tree`] on a one-file [`SourceTree`].
pub fn prql_to_pl(prql: &str) -> Result<pr::ModuleDef, ErrorMessages> {
    prql_to_pl_tree(&SourceTree::from(prql))
}
/// Parses every source in a [`SourceTree`] into a PL module.
///
/// Parse errors are composed against the tree's sources.
pub fn prql_to_pl_tree(prql: &SourceTree) -> Result<pr::ModuleDef, ErrorMessages> {
    match parser::parse(prql) {
        Ok(module) => Ok(module),
        Err(errors) => Err(ErrorMessages::from(errors).composed(prql)),
    }
}
/// Resolves and lowers PL into RQ, using the default main path and no database module.
///
/// Errors are tagged with [`ErrorSource::NameResolver`].
pub fn pl_to_rq(pl: pr::ModuleDef) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
    let lowered = semantic::resolve_and_lower(pl, &[], None);
    lowered.map_err(|err| {
        let tagged = err.with_source(ErrorSource::NameResolver);
        tagged.into()
    })
}
/// Resolves and lowers PL into RQ with an explicit main path and database module path.
///
/// Errors are tagged with [`ErrorSource::NameResolver`].
pub fn pl_to_rq_tree(
    pl: pr::ModuleDef,
    main_path: &[String],
    database_module_path: &[String],
) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
    let lowered = semantic::resolve_and_lower(pl, main_path, Some(database_module_path));
    lowered.map_err(|err| {
        let tagged = err.with_source(ErrorSource::NameResolver);
        tagged.into()
    })
}
/// Renders RQ as SQL according to `options`.
///
/// Errors are tagged with [`ErrorSource::SQL`].
pub fn rq_to_sql(rq: ir::rq::RelationalQuery, options: &Options) -> Result<String, ErrorMessages> {
    let rendered = sql::compile(rq, options);
    rendered.map_err(|err| {
        let tagged = err.with_source(ErrorSource::SQL);
        tagged.into()
    })
}
pub fn pl_to_prql(pl: &pr::ModuleDef) -> Result<String, ErrorMessages> {
Ok(codegen::WriteSource::write(&pl.stmts, codegen::WriteOpt::default()).unwrap())
}
pub mod json {
use super::*;
pub fn from_pl(pl: &pr::ModuleDef) -> Result<String, ErrorMessages> {
serde_json::to_string(pl).map_err(convert_json_err)
}
pub fn to_pl(json: &str) -> Result<pr::ModuleDef, ErrorMessages> {
serde_json::from_str(json).map_err(convert_json_err)
}
pub fn from_rq(rq: &ir::rq::RelationalQuery) -> Result<String, ErrorMessages> {
serde_json::to_string(rq).map_err(convert_json_err)
}
pub fn to_rq(json: &str) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
serde_json::from_str(json).map_err(convert_json_err)
}
fn convert_json_err(err: serde_json::Error) -> ErrorMessages {
ErrorMessages::from(Error::new_simple(err.to_string()))
}
}
/// A set of PRQL sources keyed by path, each with a numeric source id.
///
/// Ids start at 1 and are handed out in insertion order.
#[derive(Debug, Clone, Default, Serialize)]
pub struct SourceTree {
    /// Root path of the tree, if any.
    pub root: Option<PathBuf>,
    /// Source contents, keyed by path.
    pub sources: HashMap<PathBuf, String>,
    /// Maps source ids to their paths (see [`SourceTree::get_path`]).
    source_ids: HashMap<u16, PathBuf>,
}

impl SourceTree {
    /// Creates a tree holding a single source with id 1 and no root.
    pub fn single(path: PathBuf, content: String) -> Self {
        SourceTree {
            sources: [(path.clone(), content)].into(),
            source_ids: [(1, path)].into(),
            root: None,
        }
    }

    /// Creates a tree from `(path, content)` pairs, assigning ids 1, 2, … in iteration order.
    ///
    /// # Panics
    ///
    /// Panics if more than `u16::MAX` sources are supplied, because ids are `u16`.
    pub fn new<I>(iter: I, root: Option<PathBuf>) -> Self
    where
        I: IntoIterator<Item = (PathBuf, String)>,
    {
        let mut res = SourceTree {
            sources: HashMap::new(),
            source_ids: HashMap::new(),
            root,
        };
        for (index, (path, content)) in iter.into_iter().enumerate() {
            // An `as` cast would wrap past `u16::MAX` and silently overwrite earlier ids.
            let id = u16::try_from(index + 1)
                .expect("SourceTree supports at most u16::MAX sources");
            res.sources.insert(path.clone(), content);
            res.source_ids.insert(id, path);
        }
        res
    }

    /// Adds a source with the id one greater than the current maximum (1 for an empty tree).
    ///
    /// # Panics
    ///
    /// Panics if the current maximum id is already `u16::MAX`.
    pub fn insert(&mut self, path: PathBuf, content: String) {
        let last_id = self.source_ids.keys().max().copied().unwrap_or(0);
        let id = last_id
            .checked_add(1)
            .expect("SourceTree supports at most u16::MAX sources");
        self.sources.insert(path.clone(), content);
        self.source_ids.insert(id, path);
    }

    /// Returns the path registered for `source_id`, if any.
    pub fn get_path(&self, source_id: u16) -> Option<&PathBuf> {
        self.source_ids.get(&source_id)
    }
}
impl<S: ToString> From<S> for SourceTree {
fn from(source: S) -> Self {
SourceTree::single(PathBuf::from(""), source.to_string())
}
}
/// Lower-level helpers exposing intermediate compiler data (lineage).
pub mod internal {
    use super::*;

    /// Resolves a PL module and collects lineage frames for its main relation.
    ///
    /// The returned collector's `ast` field is set to a copy of the input PL.
    ///
    /// # Errors
    ///
    /// Returns an error if `semantic::resolve` fails.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if `find_main_rel` fails, or if the main declaration is not a
    /// relation variable.
    /// NOTE(review): these panics look reachable from user input — consider mapping them to
    /// errors once the error types of `find_main_rel` / `into_relation_var` are confirmed.
    pub fn pl_to_lineage(
        pl: pr::ModuleDef,
    ) -> Result<semantic::reporting::FrameCollector, ErrorMessages> {
        // Copy the PL before `resolve` takes ownership of it.
        let ast = Some(pl.clone());
        let root_module = semantic::resolve(pl).map_err(ErrorMessages::from)?;
        let (main, _) = root_module.find_main_rel(&[]).unwrap();
        let mut fc =
            semantic::reporting::collect_frames(*main.clone().into_relation_var().unwrap());
        fc.ast = ast;
        Ok(fc)
    }

    /// JSON serialization of lineage data.
    pub mod json {
        use super::*;

        /// Serializes a frame collector to a JSON string.
        pub fn from_lineage(
            fc: &semantic::reporting::FrameCollector,
        ) -> Result<String, ErrorMessages> {
            serde_json::to_string(fc).map_err(convert_json_err)
        }

        /// Wraps a `serde_json` error's message in `ErrorMessages`.
        fn convert_json_err(err: serde_json::Error) -> ErrorMessages {
            ErrorMessages::from(Error::new_simple(err.to_string()))
        }
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use insta::assert_debug_snapshot;

    use crate::pr::Ident;
    use crate::Target;

    /// Compiles with colors disabled and no signature comment, for stable test output.
    pub fn compile(prql: &str) -> Result<String, super::ErrorMessages> {
        anstream::ColorChoice::Never.write_global();
        super::compile(prql, &super::Options::default().no_signature())
    }

    #[test]
    fn test_starts_with() {
        let a = Ident::from_path(vec!["a", "b", "c"]);
        let b = Ident::from_path(vec!["a", "b"]);
        let c = Ident::from_path(vec!["a", "b", "c", "d"]);
        let d = Ident::from_path(vec!["a", "b", "d"]);
        let e = Ident::from_path(vec!["a", "c"]);
        let f = Ident::from_path(vec!["b", "c"]);
        assert!(a.starts_with(&b));
        assert!(a.starts_with(&a));
        assert!(!a.starts_with(&c));
        assert!(!a.starts_with(&d));
        assert!(!a.starts_with(&e));
        assert!(!a.starts_with(&f));
    }

    #[test]
    fn test_target_from_str() {
        assert_debug_snapshot!(Target::from_str("sql.postgres"), @r"
        Ok(
            Sql(
                Some(
                    Postgres,
                ),
            ),
        )
        ");

        assert_debug_snapshot!(Target::from_str("sql.poostgres"), @r#"
        Err(
            Error {
                kind: Error,
                span: None,
                reason: NotFound {
                    name: "\"sql.poostgres\"",
                    namespace: "target",
                },
                hints: [],
                code: None,
            },
        )
        "#);

        assert_debug_snapshot!(Target::from_str("postgres"), @r#"
        Err(
            Error {
                kind: Error,
                span: None,
                reason: NotFound {
                    name: "\"postgres\"",
                    namespace: "target",
                },
                hints: [],
                code: None,
            },
        )
        "#);
    }

    /// Every name advertised by `Target::names` must parse back into a `Target`.
    #[test]
    fn test_target_names() {
        for name in Target::names() {
            assert!(
                Target::from_str(&name).is_ok(),
                "target name {name:?} returned by Target::names() does not parse"
            );
        }
    }

    #[test]
    fn test_sort_not_propagated_after_join() {
        use insta::assert_snapshot;
        assert_snapshot!(
            super::compile(
                r#"
                prql target:sql.postgres
                from tracks
                group media_type_id (
                    sort name
                    take 1
                )
                join media_types (== media_type_id)
                select {
                    tracks.track_id,
                    media_types.name
                }
                "#,
                &super::Options::default().no_signature()
            ).unwrap(),
            @"
        WITH table_0 AS (
          SELECT
            DISTINCT ON (media_type_id) track_id,
            media_type_id,
            name
          FROM
            tracks
          ORDER BY
            media_type_id,
            name
        )
        SELECT
          table_0.track_id,
          media_types.name
        FROM
          table_0
          INNER JOIN media_types ON table_0.media_type_id = media_types.media_type_id
        "
        );
    }

    #[test]
    fn test_explicit_sort_after_distinct_on_preserved() {
        use insta::assert_snapshot;
        assert_snapshot!(
            super::compile(
                r#"
                prql target:sql.postgres
                from tracks
                group media_type_id (
                    sort name
                    take 1
                )
                sort media_type_id
                join media_types (== media_type_id)
                select {
                    tracks.track_id,
                    media_types.name
                }
                "#,
                &super::Options::default().no_signature()
            ).unwrap(),
            @"
        WITH table_0 AS (
          SELECT
            DISTINCT ON (media_type_id) track_id,
            media_type_id,
            name
          FROM
            tracks
          ORDER BY
            media_type_id,
            name
        ),
        table_1 AS (
          SELECT
            table_0.track_id,
            media_types.name,
            table_0.media_type_id
          FROM
            table_0
            INNER JOIN media_types ON table_0.media_type_id = media_types.media_type_id
        )
        SELECT
          track_id,
          name
        FROM
          table_1
        ORDER BY
          media_type_id
        "
        );
    }
}