use std::fmt::Write;
use crate::{
BigInt, Nullable, Text,
expression::{Column, Expression, In, InList, LowerIn},
lower::{Data, Instructions, JsonContainsLower, LowerCtx, Params},
param::Param,
};
/// One `->`-separated component of a JSON path literal such as `"dining->meal"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JsonPathSegment<'a> {
    /// An object member name: ASCII letters, digits and `_`, not made only of digits.
    Key(&'a str),
    /// An array position, kept as its original all-digit text.
    Index(&'a str),
}

/// Splits a `->`-separated JSON path into segments, borrowing from `path`.
///
/// A segment made up only of ASCII digits becomes [`JsonPathSegment::Index`].
/// Anything else becomes [`JsonPathSegment::Key`] and may only contain ASCII
/// alphanumerics and `_`.
///
/// # Panics
///
/// Panics if `path` is empty, if any segment is empty (for example `"a->->b"` or
/// `"a->"`), or if a key segment contains an unsupported character. Segments are
/// checked left to right, and the first offending one determines the message.
fn parse_json_path(path: &str) -> Vec<JsonPathSegment<'_>> {
    assert!(!path.is_empty(), "json path cannot be empty");

    let mut parsed = Vec::new();
    for segment in path.split("->") {
        assert!(!segment.is_empty(), "json path contains an empty segment");

        let all_digits = segment.bytes().all(|b| b.is_ascii_digit());
        let next = if all_digits {
            JsonPathSegment::Index(segment)
        } else {
            let identifier_like = segment
                .bytes()
                .all(|b| b == b'_' || b.is_ascii_alphanumeric());
            assert!(
                identifier_like,
                "json path segment `{segment}` contains unsupported characters",
            );
            JsonPathSegment::Key(segment)
        };
        parsed.push(next);
    }
    parsed
}
/// Appends `value` to `out` as a double-quoted JSON string literal.
///
/// `"` and `\` are backslash-escaped. The five control characters that JSON gives a
/// short form (`\b`, `\f`, `\n`, `\r`, `\t`) use that form. Every other character
/// below U+0020 becomes a lowercase `\u00xx` escape. All remaining characters,
/// including non-ASCII text, are copied through unchanged.
fn write_json_string(out: &mut String, value: &str) {
    // The dedicated two-character escape for `c`, if JSON defines one.
    fn short_escape(c: char) -> Option<&'static str> {
        let escape = match c {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\u{08}' => "\\b",
            '\u{0C}' => "\\f",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            _ => return None,
        };
        Some(escape)
    }

    out.push('"');
    for c in value.chars() {
        if let Some(escape) = short_escape(c) {
            out.push_str(escape);
        } else if u32::from(c) < 0x20 {
            // The remaining C0 control characters have no short form in JSON.
            let _ = write!(out, "\\u{:04x}", u32::from(c));
        } else {
            out.push(c);
        }
    }
    out.push('"');
}
/// A Rust value that can be rendered as JSON text, used as the right-hand side of
/// the JSON containment expressions in this module.
#[doc(hidden)]
pub trait JsonLiteral {
    /// Appends the JSON encoding of `self` to `out`, leaving existing contents in place.
    fn write_json(&self, out: &mut String);

    /// Reports whether `self` renders as a JSON array or object rather than a scalar.
    ///
    /// Scalars are the default. Container implementations override this. The flag is
    /// forwarded as `JsonContainsLower::compound` when a `JsonContains` is lowered.
    fn is_compound(&self) -> bool {
        false
    }
}
// Borrowed text renders as a quoted, escaped JSON string (a scalar).
impl JsonLiteral for str {
    fn write_json(&self, out: &mut String) {
        write_json_string(out, self);
    }
}

// Owned text uses exactly the same encoding; `&String` deref-coerces to `&str`.
impl JsonLiteral for String {
    fn write_json(&self, out: &mut String) {
        write_json_string(out, self);
    }
}
impl JsonLiteral for &str {
fn write_json(&self, out: &mut String) {
(*self).write_json(out);
}
}
// Booleans map directly onto the JSON keywords `true` / `false`.
impl JsonLiteral for bool {
    fn write_json(&self, out: &mut String) {
        let keyword = match *self {
            true => "true",
            false => "false",
        };
        out.push_str(keyword);
    }
}
macro_rules! impl_json_number {
($($ty:ty),+ $(,)?) => {
$(
impl JsonLiteral for $ty {
fn write_json(&self, out: &mut String) {
let _ = write!(out, "{}", self);
}
}
)+
};
}
impl_json_number!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64);
// Every sequence type renders through this single impl: `[`, the elements'
// JSON separated by `,` with no whitespace, then `]`. Sequences are JSON arrays,
// so they report themselves as compound.
impl<T> JsonLiteral for [T]
where
    T: JsonLiteral,
{
    fn write_json(&self, out: &mut String) {
        out.push('[');
        for (index, item) in self.iter().enumerate() {
            if index > 0 {
                out.push(',');
            }
            item.write_json(out);
        }
        out.push(']');
    }

    fn is_compound(&self) -> bool {
        true
    }
}

// A `Vec` renders exactly like the slice it owns.
impl<T> JsonLiteral for Vec<T>
where
    T: JsonLiteral,
{
    fn write_json(&self, out: &mut String) {
        self.as_slice().write_json(out);
    }

    fn is_compound(&self) -> bool {
        true
    }
}

// A fixed-size array renders exactly like the slice it contains.
impl<T, const N: usize> JsonLiteral for [T; N]
where
    T: JsonLiteral,
{
    fn write_json(&self, out: &mut String) {
        self.as_slice().write_json(out);
    }

    fn is_compound(&self) -> bool {
        true
    }
}
// An arbitrary `serde_json::Value` is emitted in its compact serialized form.
#[cfg(feature = "serde_json")]
impl JsonLiteral for serde_json::Value {
    fn write_json(&self, out: &mut String) {
        // `Value`'s non-alternate `Display` is the compact serialization produced by
        // `serde_json::to_string`. Serializing a `Value` cannot fail, and writing
        // into a `String` cannot fail either.
        let _ = write!(out, "{self}");
    }

    fn is_compound(&self) -> bool {
        self.is_array() || self.is_object()
    }
}
/// Text pulled out of a JSON document at a `->`-separated path.
///
/// Built by [`JsonPathExt::json_path`]. As an expression it has type `Nullable<Text>`.
pub struct JsonPath<E> {
    /// Expression holding the JSON document.
    inner: E,
    /// Path literal as written (for example `"dining->meal"`).
    path: &'static str,
}

/// Containment test of `value` against the JSON found at `path`.
///
/// Built by `JsonPath::json_contains`, or by `JsonPath::json_doesnt_contain` when
/// `negated` is set.
pub struct JsonContains<E, V> {
    /// Expression holding the JSON document.
    inner: E,
    /// Path literal inherited from the originating `JsonPath`.
    path: &'static str,
    /// Value that is rendered to JSON text when lowered.
    value: V,
    /// Whether the test is inverted.
    negated: bool,
}

/// Key-presence test for `path` in the JSON document.
///
/// Built by `JsonPath::json_contains_key`, or by `JsonPath::json_doesnt_contain_key`
/// when `negated` is set.
pub struct JsonContainsKey<E> {
    /// Expression holding the JSON document.
    inner: E,
    /// Path literal inherited from the originating `JsonPath`.
    path: &'static str,
    /// Whether the test is inverted.
    negated: bool,
}

/// Length of the JSON value at `path`, typed as `BigInt`.
///
/// Built by `JsonPath::json_length`.
pub struct JsonLength<E> {
    /// Expression holding the JSON document.
    inner: E,
    /// Path literal inherited from the originating `JsonPath`.
    path: &'static str,
}
/// Entry point for building JSON expressions over a text column.
pub trait JsonPathExt: Sized {
    /// Navigates into the JSON stored in `self` along `path`.
    ///
    /// `path` is a `->`-separated list of object keys and array indices, for example
    /// `"dining->meal"`.
    ///
    /// # Panics
    ///
    /// The implementations in this module validate `path` eagerly. They panic if it
    /// is empty, has an empty segment, or has a key segment containing characters
    /// other than ASCII letters, digits and `_`.
    fn json_path(self, path: &'static str) -> JsonPath<Self>;
}

// Validates `path` up front, so a malformed literal fails where it is written
// rather than when the query is lowered, then hands it back unchanged.
fn checked_json_path(path: &'static str) -> &'static str {
    // Only the validation panics matter here; the parsed segments are discarded.
    parse_json_path(path);
    path
}

impl<M> JsonPathExt for Column<M, Text> {
    fn json_path(self, path: &'static str) -> JsonPath<Self> {
        let path = checked_json_path(path);
        JsonPath { inner: self, path }
    }
}

impl<M> JsonPathExt for Column<M, Nullable<Text>> {
    fn json_path(self, path: &'static str) -> JsonPath<Self> {
        let path = checked_json_path(path);
        JsonPath { inner: self, path }
    }
}
impl<E> JsonPath<E> {
    /// Builds an `IN (...)` membership test of the extracted text against `list`
    /// (see `In::one_of`).
    pub fn one_of<R>(self, list: R) -> InList<Self, R>
    where
        Self: Expression<Type = Nullable<Text>>,
        R: LowerIn<Nullable<Text>>,
    {
        In::one_of(self, list)
    }

    /// Tests whether the JSON at this path contains `value`.
    ///
    /// When lowered, `value` is rendered to JSON text and bound as one text parameter.
    pub fn json_contains<V>(self, value: V) -> JsonContains<E, V>
    where
        V: JsonLiteral,
    {
        self.into_contains(value, false)
    }

    /// Inverted form of [`JsonPath::json_contains`].
    pub fn json_doesnt_contain<V>(self, value: V) -> JsonContains<E, V>
    where
        V: JsonLiteral,
    {
        self.into_contains(value, true)
    }

    /// Tests for the presence of this path in the document.
    pub fn json_contains_key(self) -> JsonContainsKey<E> {
        self.into_contains_key(false)
    }

    /// Inverted form of [`JsonPath::json_contains_key`].
    pub fn json_doesnt_contain_key(self) -> JsonContainsKey<E> {
        self.into_contains_key(true)
    }

    /// Builds a `BigInt` expression for the length of the JSON value at this path.
    pub fn json_length(self) -> JsonLength<E> {
        let Self { inner, path } = self;
        JsonLength { inner, path }
    }

    // Shared constructor behind `json_contains` / `json_doesnt_contain`.
    fn into_contains<V>(self, value: V, negated: bool) -> JsonContains<E, V> {
        let Self { inner, path } = self;
        JsonContains {
            inner,
            path,
            value,
            negated,
        }
    }

    // Shared constructor behind `json_contains_key` / `json_doesnt_contain_key`.
    fn into_contains_key(self, negated: bool) -> JsonContainsKey<E> {
        let Self { inner, path } = self;
        JsonContainsKey {
            inner,
            path,
            negated,
        }
    }
}
// Lowers a JSON path lookup: the document expression is lowered first, and its
// result is handed to the text-extraction step together with the raw path literal.
#[qraft_expression_macro::as_expression]
impl<E> Expression for JsonPath<E>
where
    E: Expression,
{
    type Type = Nullable<Text>;
    fn lower(&self, ctx: &mut LowerCtx) -> usize {
        let document = self.inner.lower(ctx);
        ctx.lower_json_extract_text(document, self.path)
    }
}
// Lowers a JSON containment test.
// The candidate value is serialized to JSON text and interned. That text is pushed
// as a single text parameter, and the containment step then combines it with the
// lowered document.
#[qraft_expression_macro::as_expression]
impl<E, V> Expression for JsonContains<E, V>
where
    E: Expression,
    V: JsonLiteral,
{
    type Type = crate::Bool;
    fn lower(&self, ctx: &mut LowerCtx) -> usize {
        let document = self.inner.lower(ctx);

        let encoded = {
            let mut buf = String::new();
            self.value.write_json(&mut buf);
            buf
        };
        let text = ctx.data.intern_text(&encoded);
        let param = ctx.params.push_param(Param::Text(Some(text)));
        ctx.instrs.push_param(param);

        let spec = JsonContainsLower {
            // NOTE(review): presumably the number of instructions emitted for the
            // right-hand side (the single param push above) — confirm against
            // `lower_json_contains`.
            rhs: 1,
            path: self.path,
            negated: self.negated,
            compound: self.value.is_compound(),
        };
        ctx.lower_json_contains(document, spec)
    }
}
// Lowers a key-presence test: the document first, then the path probe, which
// carries the negation flag.
#[qraft_expression_macro::as_expression]
impl<E> Expression for JsonContainsKey<E>
where
    E: Expression,
{
    type Type = crate::Bool;
    fn lower(&self, ctx: &mut LowerCtx) -> usize {
        let document = self.inner.lower(ctx);
        ctx.lower_json_contains_key(document, self.path, self.negated)
    }
}
// Lowers a JSON length measurement: the document first, then the length step at
// the given path.
#[qraft_expression_macro::as_expression]
impl<E> Expression for JsonLength<E>
where
    E: Expression,
{
    type Type = BigInt;
    fn lower(&self, ctx: &mut LowerCtx) -> usize {
        let document = self.inner.lower(ctx);
        ctx.lower_json_length(document, self.path)
    }
}
#[cfg(test)]
mod tests {
    use super::{JsonLiteral, JsonPathExt, JsonPathSegment, parse_json_path, write_json_string};
    use crate::{
        MySql, Postgres, Sqlite, Text,
        expression::{EqExt, OrderExt, col},
        query::select_all,
    };
    #[test]
    fn json_path_eq_sqlite_sql() {
        let preferences = col::<Text>("preferences");
        let sql = select_all()
            .from("users")
            .filter(preferences.json_path("dining->meal").eq("salad"))
            .to_debug_sql::<Sqlite>();
        assert_eq!(
            sql,
            r#"select * from "users" where json_extract("preferences", '$.dining.meal') = ?; params=["salad"]"#
        );
    }
    #[test]
    fn json_path_eq_postgres_sql() {
        let preferences = col::<Text>("preferences");
        let sql = select_all()
            .from("users")
            .filter(preferences.json_path("dining->meal").eq("salad"))
            .to_debug_sql::<Postgres>();
        assert_eq!(
            sql,
            r#"select * from "users" where jsonb_extract_path_text(("preferences")::jsonb, 'dining', 'meal') = $1; params=["salad"]"#
        );
    }
    #[test]
    fn json_path_eq_mariadb_sql() {
        let preferences = col::<Text>("preferences");
        let sql = select_all()
            .from("users")
            .filter(preferences.json_path("dining->meal").eq("salad"))
            .to_debug_sql::<MySql>();
        assert_eq!(
            sql,
            "select * from `users` where json_unquote(json_extract(`preferences`, '$.dining.meal')) = ?; params=[\"salad\"]"
        );
    }
    #[test]
    fn json_path_one_of_sqlite_sql() {
        let preferences = col::<Text>("preferences");
        let sql = select_all()
            .from("users")
            .filter(
                preferences
                    .json_path("dining->meal")
                    .one_of(["pasta", "salad"]),
            )
            .to_debug_sql::<Sqlite>();
        assert_eq!(
            sql,
            r#"select * from "users" where json_extract("preferences", '$.dining.meal') in (?, ?); params=["pasta", "salad"]"#
        );
    }
    #[test]
    fn json_contains_key_and_length_sqlite_sql() {
        let options = col::<Text>("options");
        let contains = select_all()
            .from("users")
            .filter(options.json_path("languages").json_contains_key())
            .to_debug_sql::<Sqlite>();
        let length = select_all()
            .from("users")
            .filter(options.json_path("languages").json_length().gt(1_i64))
            .to_debug_sql::<Sqlite>();
        assert_eq!(
            contains,
            r#"select * from "users" where json_type("options", '$.languages') is not null; params=[]"#
        );
        assert_eq!(
            length,
            r#"select * from "users" where json_array_length(json_extract("options", '$.languages')) > ?; params=[1]"#
        );
    }
    #[test]
    fn json_contains_sqlite_sql() {
        let options = col::<Text>("options");
        let sql = select_all()
            .from("users")
            .filter(options.json_path("languages").json_contains("en"))
            .to_debug_sql::<Sqlite>();
        assert_eq!(
            sql,
            r#"select * from "users" where exists (select 1 from json_each(json_extract("options", '$.languages')) where json_each.value = json_extract(?, '$')); params=[""en""]"#
        );
    }

    /// Renders any JSON literal into a fresh string.
    fn render(value: &impl JsonLiteral) -> String {
        let mut out = String::new();
        value.write_json(&mut out);
        out
    }

    #[test]
    fn parse_json_path_splits_keys_and_indices() {
        assert_eq!(
            parse_json_path("dining->meal->0"),
            vec![
                JsonPathSegment::Key("dining"),
                JsonPathSegment::Key("meal"),
                JsonPathSegment::Index("0"),
            ]
        );
    }

    #[test]
    #[should_panic(expected = "json path cannot be empty")]
    fn parse_json_path_rejects_empty_path() {
        let _ = parse_json_path("");
    }

    #[test]
    #[should_panic(expected = "json path contains an empty segment")]
    fn parse_json_path_rejects_empty_segment() {
        let _ = parse_json_path("dining->->meal");
    }

    #[test]
    #[should_panic(expected = "contains unsupported characters")]
    fn parse_json_path_rejects_unsupported_characters() {
        let _ = parse_json_path("dining->main course");
    }

    #[test]
    fn write_json_string_escapes_specials() {
        let mut out = String::new();
        write_json_string(&mut out, "say \"hi\"\\\n\t\u{01}é");
        assert_eq!(out, r#""say \"hi\"\\\n\t\u0001é""#);
    }

    #[test]
    fn json_literal_renders_containers() {
        assert_eq!(render(&vec!["a", "b"]), r#"["a","b"]"#);
        assert_eq!(render(&[true, false]), "[true,false]");
        assert_eq!(render(&vec![vec![1_i64, 2], Vec::new()]), "[[1,2],[]]");
        assert!(vec![1_i64].is_compound());
        assert!(!"a".is_compound());
        assert!(!42_i64.is_compound());
    }
}