// palpo_data/full_text_search/mod.rs

// Forked from https://github.com/diesel-rs/diesel_full_text_search/
2mod types {
3    use std::io::{BufRead, Cursor};
4
5    use byteorder::{NetworkEndian, ReadBytesExt};
6    use diesel::{Queryable, deserialize::FromSql, pg::Pg, sql_types::*};
7
8    #[derive(Clone, Copy, SqlType)]
9    #[diesel(postgres_type(oid = 3615, array_oid = 3645))]
10    pub struct TsQuery;
11
12    #[derive(Clone, Copy, SqlType)]
13    #[diesel(postgres_type(oid = 3614, array_oid = 3643))]
14    pub struct TsVector;
15    pub type Tsvector = TsVector;
16
17    pub trait TextOrNullableText {}
18
19    impl TextOrNullableText for Text {}
20    impl TextOrNullableText for Nullable<Text> {}
21    impl TextOrNullableText for TsVector {}
22    impl TextOrNullableText for Nullable<TsVector> {}
23
24    #[derive(SqlType)]
25    #[diesel(postgres_type(name = "regconfig"))]
26    pub struct RegConfig;
27
28    impl FromSql<TsVector, Pg> for PgTsVector {
29        fn from_sql(bytes: <Pg as diesel::backend::Backend>::RawValue<'_>) -> diesel::deserialize::Result<Self> {
30            let mut cursor = Cursor::new(bytes.as_bytes());
31
32            // From Postgres `tsvector.c`:
33            //
34            //     The binary format is as follows:
35            //
36            //     uint32   number of lexemes
37            //
38            //     for each lexeme:
39            //          lexeme text in client encoding, null-terminated
40            //          uint16  number of positions
41            //          for each position:
42            //              uint16 WordEntryPos
43
44            // Number of lexemes (uint32)
45            let num_lexemes = cursor.read_u32::<NetworkEndian>()?;
46
47            let mut entries = Vec::with_capacity(num_lexemes as usize);
48
49            for _ in 0..num_lexemes {
50                let mut lexeme = Vec::new();
51                cursor.read_until(0, &mut lexeme)?;
52                // Remove null terminator
53                lexeme.pop();
54                let lexeme = String::from_utf8(lexeme)?;
55
56                // Number of positions (uint16)
57                let num_positions = cursor.read_u16::<NetworkEndian>()?;
58
59                let mut positions = Vec::with_capacity(num_positions as usize);
60                for _ in 0..num_positions {
61                    positions.push(cursor.read_u16::<NetworkEndian>()?);
62                }
63
64                entries.push(PgTsVectorEntry { lexeme, positions });
65            }
66
67            Ok(PgTsVector { entries })
68        }
69    }
70
71    impl Queryable<TsVector, Pg> for PgTsVector {
72        type Row = Self;
73
74        fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
75            Ok(row)
76        }
77    }
78
79    #[derive(Debug, Clone, PartialEq)]
80    pub struct PgTsVector {
81        pub entries: Vec<PgTsVectorEntry>,
82    }
83
84    #[derive(Debug, Clone, PartialEq)]
85    pub struct PgTsVectorEntry {
86        pub lexeme: String,
87        pub positions: Vec<u16>,
88    }
89}
90
91pub mod configuration {
92    use crate::full_text_search::RegConfig;
93
94    use diesel::backend::Backend;
95    use diesel::deserialize::{self, FromSql, FromSqlRow};
96    use diesel::expression::{ValidGrouping, is_aggregate};
97    use diesel::pg::{Pg, PgValue};
98    use diesel::query_builder::{AstPass, QueryFragment, QueryId};
99    use diesel::serialize::{self, Output, ToSql};
100    use diesel::sql_types::Integer;
101    use diesel::{AppearsOnTable, Expression, QueryResult, SelectableExpression};
102
103    #[derive(Debug, PartialEq, Eq, diesel::expression::AsExpression, FromSqlRow)]
104    #[diesel(sql_type = RegConfig)]
105    pub struct TsConfiguration(pub u32);
106
107    impl TsConfiguration {
108        pub const SIMPLE: Self = Self(3748);
109        pub const DANISH: Self = Self(12824);
110        pub const DUTCH: Self = Self(12826);
111        pub const ENGLISH: Self = Self(12828);
112        pub const FINNISH: Self = Self(12830);
113        pub const FRENCH: Self = Self(12832);
114        pub const GERMAN: Self = Self(12834);
115        pub const HUNGARIAN: Self = Self(12836);
116        pub const ITALIAN: Self = Self(12838);
117        pub const NORWEGIAN: Self = Self(12840);
118        pub const PORTUGUESE: Self = Self(12842);
119        pub const ROMANIAN: Self = Self(12844);
120        pub const RUSSIAN: Self = Self(12846);
121        pub const SPANISH: Self = Self(12848);
122        pub const SWEDISH: Self = Self(12850);
123        pub const TURKISH: Self = Self(12852);
124    }
125
126    impl FromSql<RegConfig, Pg> for TsConfiguration
127    where
128        i32: FromSql<Integer, Pg>,
129    {
130        fn from_sql(bytes: PgValue) -> deserialize::Result<Self> {
131            <i32 as FromSql<Integer, Pg>>::from_sql(bytes).map(|oid| TsConfiguration(oid as u32))
132        }
133    }
134
135    impl ToSql<RegConfig, Pg> for TsConfiguration
136    where
137        i32: ToSql<Integer, Pg>,
138    {
139        fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
140            <i32 as ToSql<Integer, Pg>>::to_sql(&(self.0 as i32), &mut out.reborrow())
141        }
142    }
143
144    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
145    pub struct TsConfigurationByName(pub &'static str);
146
147    impl<DB> QueryFragment<DB> for TsConfigurationByName
148    where
149        DB: Backend,
150    {
151        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
152            out.push_sql(&format!("'{}'", &self.0));
153            Ok(())
154        }
155    }
156
157    impl<GB> ValidGrouping<GB> for TsConfigurationByName {
158        type IsAggregate = is_aggregate::Never;
159    }
160
161    impl QueryId for TsConfigurationByName {
162        const HAS_STATIC_QUERY_ID: bool = false;
163
164        type QueryId = ();
165    }
166
167    impl<QS> SelectableExpression<QS> for TsConfigurationByName where Self: Expression {}
168
169    impl<QS> AppearsOnTable<QS> for TsConfigurationByName where Self: Expression {}
170
171    impl Expression for TsConfigurationByName {
172        type SqlType = RegConfig;
173    }
174}
175
176#[allow(deprecated)]
177mod functions {
178    use crate::full_text_search::types::*;
179    use diesel::define_sql_function;
180    use diesel::sql_types::*;
181
182    define_sql_function!(fn length(x: TsVector) -> Integer);
183    define_sql_function!(fn numnode(x: TsQuery) -> Integer);
184    define_sql_function!(fn plainto_tsquery(x: Text) -> TsQuery);
185    define_sql_function! {
186        #[sql_name = "plainto_tsquery"]
187        fn plainto_tsquery_with_search_config(config: RegConfig, querytext: Text) -> TsQuery;
188    }
189    define_sql_function!(fn querytree(x: TsQuery) -> Text);
190    define_sql_function!(fn strip(x: TsVector) -> TsVector);
191    define_sql_function!(fn to_tsquery(x: Text) -> TsQuery);
192    define_sql_function! {
193        #[sql_name = "to_tsquery"]
194        fn to_tsquery_with_search_config(config: RegConfig, querytext: Text) -> TsQuery;
195    }
196    define_sql_function!(fn to_tsvector<T: TextOrNullableText + SingleValue>(x: T) -> TsVector);
197    define_sql_function! {
198        #[sql_name = "to_tsvector"]
199        fn to_tsvector_with_search_config<T: TextOrNullableText + SingleValue>(config: RegConfig, document_content: T) -> TsVector;
200    }
201    define_sql_function!(fn ts_headline(x: Text, y: TsQuery) -> Text);
202    define_sql_function! {
203        #[sql_name = "ts_headline"]
204        fn ts_headline_with_search_config(config: RegConfig, x: Text, y: TsQuery) -> Text;
205    }
206    define_sql_function!(fn ts_rank(x: TsVector, y: TsQuery) -> Float);
207    define_sql_function!(fn ts_rank_cd(x: TsVector, y: TsQuery) -> Float);
208    define_sql_function! {
209        #[sql_name = "ts_rank_cd"]
210        fn ts_rank_cd_weighted(w: Array<Float>, x: TsVector, y: TsQuery) -> Float;
211    }
212    define_sql_function! {
213        #[sql_name = "ts_rank_cd"]
214        fn ts_rank_cd_normalized(x: TsVector, y: TsQuery, n: Integer) -> Float;
215    }
216    define_sql_function! {
217        #[sql_name = "ts_rank_cd"]
218        fn ts_rank_cd_weighted_normalized(w: Array<Float>, x: TsVector, y: TsQuery, n: Integer) -> Float;
219    }
220    define_sql_function!(fn phraseto_tsquery(x: Text) -> TsQuery);
221    define_sql_function!(fn websearch_to_tsquery(x: Text) -> TsQuery);
222    define_sql_function! {
223        #[sql_name = "websearch_to_tsquery"]
224        fn websearch_to_tsquery_with_search_config(config: RegConfig, x: Text) -> TsQuery;
225    }
226    define_sql_function!(fn setweight(x: TsVector, w: CChar) -> TsVector);
227}
228
229mod dsl {
230    use crate::full_text_search::types::*;
231    use diesel::expression::{AsExpression, Expression};
232
233    mod predicates {
234        use crate::full_text_search::types::*;
235        use diesel::pg::Pg;
236
237        diesel::infix_operator!(Matches, " @@ ", backend: Pg);
238        diesel::infix_operator!(Concat, " || ", TsVector, backend: Pg);
239        diesel::infix_operator!(And, " && ", TsQuery, backend: Pg);
240        diesel::infix_operator!(Or, " || ", TsQuery, backend: Pg);
241        diesel::infix_operator!(Contains, " @> ", backend: Pg);
242        diesel::infix_operator!(ContainedBy, " <@ ", backend: Pg);
243    }
244
245    use self::predicates::*;
246
247    pub trait TsVectorExtensions: Expression<SqlType = TsVector> + Sized {
248        fn matches<T: AsExpression<TsQuery>>(self, other: T) -> Matches<Self, T::Expression> {
249            Matches::new(self, other.as_expression())
250        }
251
252        fn concat<T: AsExpression<TsVector>>(self, other: T) -> Concat<Self, T::Expression> {
253            Concat::new(self, other.as_expression())
254        }
255    }
256
257    pub trait TsQueryExtensions: Expression<SqlType = TsQuery> + Sized {
258        fn matches<T: AsExpression<TsVector>>(self, other: T) -> Matches<Self, T::Expression> {
259            Matches::new(self, other.as_expression())
260        }
261
262        fn and<T: AsExpression<TsQuery>>(self, other: T) -> And<Self, T::Expression> {
263            And::new(self, other.as_expression())
264        }
265
266        fn or<T: AsExpression<TsQuery>>(self, other: T) -> Or<Self, T::Expression> {
267            Or::new(self, other.as_expression())
268        }
269
270        fn contains<T: AsExpression<TsQuery>>(self, other: T) -> Contains<Self, T::Expression> {
271            Contains::new(self, other.as_expression())
272        }
273
274        fn contained_by<T: AsExpression<TsQuery>>(self, other: T) -> ContainedBy<Self, T::Expression> {
275            ContainedBy::new(self, other.as_expression())
276        }
277    }
278
279    impl<T: Expression<SqlType = TsVector>> TsVectorExtensions for T {}
280
281    impl<T: Expression<SqlType = TsQuery>> TsQueryExtensions for T {}
282}
283
284pub use self::dsl::*;
285pub use self::functions::*;
286pub use self::types::*;
287
#[cfg(test)]
mod tests {
    // Integration tests: they need a reachable PostgreSQL server whose connection
    // string is supplied through the `DATABASE_URL` environment variable.
    use super::{PgTsVector, PgTsVectorEntry, TsVector};
    use diesel::dsl::sql;
    use diesel::pg::PgConnection;
    use diesel::prelude::*;

    /// Connects to `DATABASE_URL`, panicking if it is unset or unreachable.
    fn connect() -> PgConnection {
        let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
        PgConnection::establish(&database_url).expect("Error connecting to database")
    }

    /// Evaluates the SQL expression `expr` and decodes its `tsvector` result.
    fn load(expr: &str) -> PgTsVector {
        let mut conn = connect();
        diesel::select(sql::<TsVector>(expr))
            .get_result(&mut conn)
            .expect("Error executing query")
    }

    /// Shorthand for an expected [`PgTsVectorEntry`].
    fn entry(lexeme: &str, positions: &[u16]) -> PgTsVectorEntry {
        PgTsVectorEntry {
            lexeme: lexeme.to_owned(),
            positions: positions.to_vec(),
        }
    }

    #[test]
    fn test_tsvector_from_sql_with_positions() {
        // The configuration is explicit so that the stop words (`a`, `on`, `and`) are
        // dropped regardless of the server's `default_text_search_config`.
        let result = load("to_tsvector('english', 'a fat cat sat on a mat and ate a fat rat')");

        let expected = PgTsVector {
            entries: vec![
                entry("ate", &[9]),
                entry("cat", &[3]),
                entry("fat", &[2, 11]),
                entry("mat", &[7]),
                entry("rat", &[12]),
                entry("sat", &[4]),
            ],
        };

        assert_eq!(expected, result);
    }

    #[test]
    fn test_tsvector_from_sql_without_positions() {
        // A literal cast keeps every word verbatim and records no positions.
        let result = load("'a fat cat sat on a mat and ate a fat rat'::tsvector");

        let expected = PgTsVector {
            entries: vec![
                entry("a", &[]),
                entry("and", &[]),
                entry("ate", &[]),
                entry("cat", &[]),
                entry("fat", &[]),
                entry("mat", &[]),
                entry("on", &[]),
                entry("rat", &[]),
                entry("sat", &[]),
            ],
        };

        assert_eq!(expected, result);
    }
}