1#![doc = include_str!("../ARCHITECTURE.md")]
38#![allow(unused_assignments)]
43use std::sync::OnceLock;
104use std::{collections::HashMap, path::PathBuf, str::FromStr};
105
106use anstream::adapter::strip_str;
107use semver::Version;
108use serde::{Deserialize, Serialize};
109use strum::VariantNames;
110
111pub use error_message::{ErrorMessage, ErrorMessages, SourceLocation};
112pub use prqlc_parser::error::{Error, ErrorSource, Errors, MessageKind, Reason, WithErrorInfo};
113pub use prqlc_parser::lexer::lr;
114pub use prqlc_parser::parser::pr;
115pub use prqlc_parser::span::Span;
116
mod codegen;
pub mod debug;
mod error_message;
pub mod ir;
pub mod parser;
pub mod semantic;
pub mod sql;
// `utils` is part of the public API only when the `cli` feature is enabled;
// otherwise it is crate-private.
#[cfg(feature = "cli")]
pub mod utils;
#[cfg(not(feature = "cli"))]
pub(crate) mod utils;

/// Crate-wide result type; the error type defaults to [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
130
131pub fn compiler_version() -> Version {
143 if let Ok(prql_version_override) = std::env::var("PRQL_VERSION_OVERRIDE") {
144 return Version::parse(&prql_version_override).unwrap_or_else(|e| {
145 panic!("Could not parse PRQL version {prql_version_override}\n{e}")
146 });
147 };
148
149 static COMPILER_VERSION: OnceLock<Version> = OnceLock::new();
150 COMPILER_VERSION
151 .get_or_init(|| {
152 if let Ok(prql_version_override) = std::env::var("PRQL_VERSION_OVERRIDE") {
153 return Version::parse(&prql_version_override).unwrap_or_else(|e| {
154 panic!("Could not parse PRQL version {prql_version_override}\n{e}")
155 });
156 }
157 let git_version = env!("VERGEN_GIT_DESCRIBE");
158 let cargo_version = env!("CARGO_PKG_VERSION");
159 Version::parse(git_version)
160 .or_else(|e| {
161 log::info!("Could not parse git version number {git_version}\n{e}");
162 Version::parse(cargo_version)
163 })
164 .unwrap_or_else(|e| {
165 panic!("Could not parse prqlc version number {cargo_version}\n{e}")
166 })
167 })
168 .clone()
169}
170
171pub fn compile(prql: &str, options: &Options) -> Result<String, ErrorMessages> {
195 let sources = SourceTree::from(prql);
196
197 Ok(&sources)
198 .and_then(parser::parse)
199 .and_then(|ast| {
200 semantic::resolve_and_lower(ast, &[], None)
201 .map_err(|e| e.with_source(ErrorSource::NameResolver).into())
202 })
203 .and_then(|rq| {
204 sql::compile(rq, options).map_err(|e| e.with_source(ErrorSource::SQL).into())
205 })
206 .map_err(|e| {
207 let error_messages = ErrorMessages::from(e).composed(&sources);
208 match options.display {
209 DisplayOptions::AnsiColor => error_messages,
210 DisplayOptions::Plain => ErrorMessages {
211 inner: error_messages
212 .inner
213 .into_iter()
214 .map(|e| ErrorMessage {
215 display: e.display.map(|s| strip_str(&s).to_string()),
216 ..e
217 })
218 .collect(),
219 },
220 }
221 })
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
225pub enum Target {
226 Sql(Option<sql::Dialect>),
228}
229
230impl Default for Target {
231 fn default() -> Self {
232 Self::Sql(None)
233 }
234}
235
236impl Target {
237 pub fn names() -> Vec<String> {
238 let mut names = vec!["sql.any".to_string()];
239
240 let dialects = sql::Dialect::VARIANTS;
241 names.extend(dialects.iter().map(|d| format!("sql.{d}")));
242
243 names
244 }
245}
246
247impl FromStr for Target {
248 type Err = Error;
249
250 fn from_str(s: &str) -> Result<Target, Self::Err> {
251 if let Some(dialect) = s.strip_prefix("sql.") {
252 if dialect == "any" {
253 return Ok(Target::Sql(None));
254 }
255
256 if let Ok(dialect) = sql::Dialect::from_str(dialect) {
257 return Ok(Target::Sql(Some(dialect)));
258 }
259 }
260
261 Err(Error::new(Reason::NotFound {
262 name: format!("{s:?}"),
263 namespace: "target".to_string(),
264 }))
265 }
266}
267
/// Options passed to [`compile`] and [`rq_to_sql`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Options {
    /// NOTE(review): presumably toggles pretty-printing of the generated SQL.
    /// Confirm in `sql::compile`. Defaults to `true`; cleared by
    /// [`Options::no_format`].
    pub format: bool,

    /// The compilation target. Defaults to `Target::Sql(None)`, i.e. `sql.any`.
    pub target: Target,

    /// NOTE(review): presumably controls whether a signature comment is added
    /// to the output. Confirm in `sql::compile`. Defaults to `true`; cleared by
    /// [`Options::no_signature`].
    pub signature_comment: bool,

    /// Deprecated in favor of `display`; the builder
    /// [`Options::with_color`] is marked deprecated. [`compile`] chooses how to
    /// render error messages from `display`, not from this field. Defaults
    /// to `true`.
    pub color: bool,

    /// How [`compile`] renders error messages. `Plain` strips ANSI escape
    /// codes. Defaults to [`DisplayOptions::AnsiColor`].
    pub display: DisplayOptions,
}
301
302impl Default for Options {
303 fn default() -> Self {
304 Self {
305 format: true,
306 target: Target::Sql(None),
307 signature_comment: true,
308 color: true,
309 display: DisplayOptions::AnsiColor,
310 }
311 }
312}
313
314impl Options {
315 pub fn with_format(mut self, format: bool) -> Self {
316 self.format = format;
317 self
318 }
319
320 pub fn no_format(self) -> Self {
321 self.with_format(false)
322 }
323
324 pub fn with_signature_comment(mut self, signature_comment: bool) -> Self {
325 self.signature_comment = signature_comment;
326 self
327 }
328
329 pub fn no_signature(self) -> Self {
330 self.with_signature_comment(false)
331 }
332
333 pub fn with_target(mut self, target: Target) -> Self {
334 self.target = target;
335 self
336 }
337
338 #[deprecated(note = "`color` is replaced by `display`; see `Options` docs for more details")]
339 pub fn with_color(mut self, color: bool) -> Self {
340 self.color = color;
341 self
342 }
343
344 pub fn with_display(mut self, display: DisplayOptions) -> Self {
345 self.display = display;
346 self
347 }
348}
349
/// How [`compile`] renders error messages.
///
/// Parses from snake_case names (`"plain"`, `"ansi_color"`) via `FromStr`.
#[derive(Debug, Clone, Serialize, Deserialize, strum::EnumString)]
#[strum(serialize_all = "snake_case")]
#[non_exhaustive]
pub enum DisplayOptions {
    /// ANSI escape codes are stripped from each message's rendered text.
    Plain,
    /// Rendered messages are returned unchanged, including ANSI color codes.
    AnsiColor,
}
359
// Attaches README.md as docs to a doctest-only item, so the README's code
// blocks are compiled and run as doctests. This is a plain comment so it does
// not become part of that doc text.
#[doc = include_str!("../README.md")]
#[cfg(doctest)]
pub struct ReadmeDoctests;
363
364pub fn prql_to_tokens(prql: &str) -> Result<lr::Tokens, ErrorMessages> {
366 prqlc_parser::lexer::lex_source(prql).map_err(|e| {
367 e.into_iter()
368 .map(|e| e.into())
369 .collect::<Vec<ErrorMessage>>()
370 .into()
371 })
372}
373
374pub fn prql_to_pl(prql: &str) -> Result<pr::ModuleDef, ErrorMessages> {
377 let source_tree = SourceTree::from(prql);
378 prql_to_pl_tree(&source_tree)
379}
380
381pub fn prql_to_pl_tree(prql: &SourceTree) -> Result<pr::ModuleDef, ErrorMessages> {
383 parser::parse(prql).map_err(|e| ErrorMessages::from(e).composed(prql))
384}
385
386pub fn pl_to_rq(pl: pr::ModuleDef) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
389 semantic::resolve_and_lower(pl, &[], None)
390 .map_err(|e| e.with_source(ErrorSource::NameResolver).into())
391}
392
393pub fn pl_to_rq_tree(
395 pl: pr::ModuleDef,
396 main_path: &[String],
397 database_module_path: &[String],
398) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
399 semantic::resolve_and_lower(pl, main_path, Some(database_module_path))
400 .map_err(|e| e.with_source(ErrorSource::NameResolver).into())
401}
402
403pub fn rq_to_sql(rq: ir::rq::RelationalQuery, options: &Options) -> Result<String, ErrorMessages> {
405 sql::compile(rq, options).map_err(|e| e.with_source(ErrorSource::SQL).into())
406}
407
408pub fn pl_to_prql(pl: &pr::ModuleDef) -> Result<String, ErrorMessages> {
410 Ok(codegen::WriteSource::write(&pl.stmts, codegen::WriteOpt::default()).unwrap())
411}
412
413pub mod json {
415 use super::*;
416
417 pub fn from_pl(pl: &pr::ModuleDef) -> Result<String, ErrorMessages> {
419 serde_json::to_string(pl).map_err(convert_json_err)
420 }
421
422 pub fn to_pl(json: &str) -> Result<pr::ModuleDef, ErrorMessages> {
424 serde_json::from_str(json).map_err(convert_json_err)
425 }
426
427 pub fn from_rq(rq: &ir::rq::RelationalQuery) -> Result<String, ErrorMessages> {
429 serde_json::to_string(rq).map_err(convert_json_err)
430 }
431
432 pub fn to_rq(json: &str) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
434 serde_json::from_str(json).map_err(convert_json_err)
435 }
436
437 fn convert_json_err(err: serde_json::Error) -> ErrorMessages {
438 ErrorMessages::from(Error::new_simple(err.to_string()))
439 }
440}
441
/// A collection of PRQL sources keyed by path.
///
/// Each source is also registered under a numeric `u16` id, which can be looked
/// up with [`SourceTree::get_path`].
#[derive(Debug, Clone, Default, Serialize)]
pub struct SourceTree {
    /// Optional root path of the tree. The `single` constructor and the `From`
    /// conversion set it to `None`.
    pub root: Option<PathBuf>,

    /// Source contents, keyed by path.
    pub sources: HashMap<PathBuf, String>,

    /// Maps numeric source ids to paths. The constructors and `insert` assign
    /// ids starting from 1.
    source_ids: HashMap<u16, PathBuf>,
}
460
461impl SourceTree {
462 pub fn single(path: PathBuf, content: String) -> Self {
463 SourceTree {
464 sources: [(path.clone(), content)].into(),
465 source_ids: [(1, path)].into(),
466 root: None,
467 }
468 }
469
470 pub fn new<I>(iter: I, root: Option<PathBuf>) -> Self
471 where
472 I: IntoIterator<Item = (PathBuf, String)>,
473 {
474 let mut res = SourceTree {
475 sources: HashMap::new(),
476 source_ids: HashMap::new(),
477 root,
478 };
479
480 for (index, (path, content)) in iter.into_iter().enumerate() {
481 res.sources.insert(path.clone(), content);
482 res.source_ids.insert((index + 1) as u16, path);
483 }
484 res
485 }
486
487 pub fn insert(&mut self, path: PathBuf, content: String) {
488 let last_id = self.source_ids.keys().max().cloned().unwrap_or(0);
489 self.sources.insert(path.clone(), content);
490 self.source_ids.insert(last_id + 1, path);
491 }
492
493 pub fn get_path(&self, source_id: u16) -> Option<&PathBuf> {
494 self.source_ids.get(&source_id)
495 }
496}
497
498impl<S: ToString> From<S> for SourceTree {
499 fn from(source: S) -> Self {
500 SourceTree::single(PathBuf::from(""), source.to_string())
501 }
502}
503
504pub mod internal {
506 use super::*;
507
508 pub fn pl_to_lineage(
510 pl: pr::ModuleDef,
511 ) -> Result<semantic::reporting::FrameCollector, ErrorMessages> {
512 let ast = Some(pl.clone());
513
514 let root_module = semantic::resolve(pl).map_err(ErrorMessages::from)?;
515
516 let (main, _) = root_module.find_main_rel(&[]).unwrap();
517 let mut fc =
518 semantic::reporting::collect_frames(*main.clone().into_relation_var().unwrap());
519 fc.ast = ast;
520
521 Ok(fc)
522 }
523
524 pub mod json {
525 use super::*;
526
527 pub fn from_lineage(
529 fc: &semantic::reporting::FrameCollector,
530 ) -> Result<String, ErrorMessages> {
531 serde_json::to_string(fc).map_err(convert_json_err)
532 }
533
534 fn convert_json_err(err: serde_json::Error) -> ErrorMessages {
535 ErrorMessages::from(Error::new_simple(err.to_string()))
536 }
537 }
538}
539
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use insta::assert_debug_snapshot;

    use crate::pr::Ident;
    use crate::Target;

    /// Compiles `prql` with default options and no signature comment.
    /// Colors for `anstream` output are disabled globally first.
    pub fn compile(prql: &str) -> Result<String, super::ErrorMessages> {
        anstream::ColorChoice::Never.write_global();
        super::compile(prql, &super::Options::default().no_signature())
    }

    #[test]
    fn test_starts_with() {
        let a = Ident::from_path(vec!["a", "b", "c"]);

        let b = Ident::from_path(vec!["a", "b"]);
        let c = Ident::from_path(vec!["a", "b", "c", "d"]);
        let d = Ident::from_path(vec!["a", "b", "d"]);
        let e = Ident::from_path(vec!["a", "c"]);
        let f = Ident::from_path(vec!["b", "c"]);
        assert!(a.starts_with(&b));
        assert!(a.starts_with(&a));
        assert!(!a.starts_with(&c));
        assert!(!a.starts_with(&d));
        assert!(!a.starts_with(&e));
        assert!(!a.starts_with(&f));
    }

    #[test]
    fn test_target_from_str() {
        assert_debug_snapshot!(Target::from_str("sql.postgres"), @r"
        Ok(
            Sql(
                Some(
                    Postgres,
                ),
            ),
        )
        ");

        assert_debug_snapshot!(Target::from_str("sql.poostgres"), @r#"
        Err(
            Error {
                kind: Error,
                span: None,
                reason: NotFound {
                    name: "\"sql.poostgres\"",
                    namespace: "target",
                },
                hints: [],
                code: None,
            },
        )
        "#);

        assert_debug_snapshot!(Target::from_str("postgres"), @r#"
        Err(
            Error {
                kind: Error,
                span: None,
                reason: NotFound {
                    name: "\"postgres\"",
                    namespace: "target",
                },
                hints: [],
                code: None,
            },
        )
        "#);
    }

    #[test]
    fn test_target_names() {
        // Every name advertised by `Target::names` must parse back. Previously the
        // results were collected but never checked, so this test could not fail.
        for name in Target::names() {
            assert!(
                Target::from_str(&name).is_ok(),
                "target name {name:?} listed by `Target::names` does not parse"
            );
        }
    }

    #[test]
    fn test_sort_not_propagated_after_join() {
        // The `sort` inside the group only drives DISTINCT ON. No ORDER BY should
        // be emitted after the join.
        use insta::assert_snapshot;

        assert_snapshot!(
            super::compile(
                r#"
            prql target:sql.postgres

            from tracks
            group media_type_id (
              sort name
              take 1
            )
            join media_types (== media_type_id)
            select {
              tracks.track_id,
              media_types.name
            }
            "#,
                &super::Options::default().no_signature()
            ).unwrap(),
            @"
        WITH table_0 AS (
          SELECT
            DISTINCT ON (media_type_id) track_id,
            media_type_id,
            name
          FROM
            tracks
          ORDER BY
            media_type_id,
            name
        )
        SELECT
          table_0.track_id,
          media_types.name
        FROM
          table_0
          INNER JOIN media_types ON table_0.media_type_id = media_types.media_type_id
        "
        );
    }

    #[test]
    fn test_explicit_sort_after_distinct_on_preserved() {
        // An explicit `sort` after the group must survive the join as a final
        // ORDER BY.
        use insta::assert_snapshot;

        assert_snapshot!(
            super::compile(
                r#"
            prql target:sql.postgres

            from tracks
            group media_type_id (
              sort name
              take 1
            )
            sort media_type_id
            join media_types (== media_type_id)
            select {
              tracks.track_id,
              media_types.name
            }
            "#,
                &super::Options::default().no_signature()
            ).unwrap(),
            @"
        WITH table_0 AS (
          SELECT
            DISTINCT ON (media_type_id) track_id,
            media_type_id,
            name
          FROM
            tracks
          ORDER BY
            media_type_id,
            name
        ),
        table_1 AS (
          SELECT
            table_0.track_id,
            media_types.name,
            table_0.media_type_id
          FROM
            table_0
            INNER JOIN media_types ON table_0.media_type_id = media_types.media_type_id
        )
        SELECT
          track_id,
          name
        FROM
          table_1
        ORDER BY
          media_type_id
        "
        );
    }
}