intent_engine/
search.rs

1//! Search utilities for intent-engine
2//!
3//! This module provides:
4//! 1. CJK (Chinese, Japanese, Korean) search utilities for detecting when to use
5//!    LIKE fallback vs FTS5 trigram search
6//! 2. Unified search across tasks and events
7//!
8//! **Background**: SQLite FTS5 with trigram tokenizer requires at least 3 consecutive
9//! characters to match. This is problematic for CJK languages where single-character
10//! or two-character searches are common (e.g., "用户", "认证").
11//!
//! **Solution**: For short CJK queries we fall back to a `LIKE` search, which supports
//! substring matching of any length, albeit more slowly.
14
/// Check whether a character belongs to one of the CJK script blocks that the
/// search code treats as "CJK" (used to decide on the short-query `LIKE` fallback).
///
/// Covered blocks:
/// - CJK Unified Ideographs and Extensions A–G
/// - CJK Compatibility Ideographs and CJK Compatibility Ideographs Supplement
/// - Hiragana, Katakana, Katakana Phonetic Extensions and halfwidth Katakana
/// - Hangul Syllables
///
/// Not covered: CJK Symbols and Punctuation (U+3000–U+303F), Bopomofo and Hangul Jamo.
pub fn is_cjk_char(c: char) -> bool {
    matches!(u32::from(c),
        // CJK Unified Ideographs (most common Chinese characters)
        0x4E00..=0x9FFF
        // CJK Extension A
        | 0x3400..=0x4DBF
        // CJK Extensions B-F (less common, but included for completeness)
        | 0x20000..=0x2A6DF
        | 0x2A700..=0x2B73F
        | 0x2B740..=0x2B81F
        | 0x2B820..=0x2CEAF
        | 0x2CEB0..=0x2EBEF
        // CJK Extension G
        | 0x30000..=0x3134F
        // CJK Compatibility Ideographs
        | 0xF900..=0xFAFF
        // CJK Compatibility Ideographs Supplement
        | 0x2F800..=0x2FA1F
        // Hiragana (Japanese)
        | 0x3040..=0x309F
        // Katakana (Japanese)
        | 0x30A0..=0x30FF
        // Katakana Phonetic Extensions (small katakana)
        | 0x31F0..=0x31FF
        // Halfwidth Katakana (within the Halfwidth and Fullwidth Forms block)
        | 0xFF66..=0xFF9F
        // Hangul Syllables (Korean)
        | 0xAC00..=0xD7AF
    )
}
37
38/// Determine if a query should use LIKE fallback instead of FTS5 trigram
39///
40/// Returns `true` if:
41/// - Query is a single CJK character, OR
42/// - Query is two CJK characters
43///
44/// Trigram tokenizer requires 3+ characters for matching, so we use LIKE
45/// for shorter CJK queries to ensure they work.
46pub fn needs_like_fallback(query: &str) -> bool {
47    let chars: Vec<char> = query.chars().collect();
48
49    // Single-character CJK
50    if chars.len() == 1 && is_cjk_char(chars[0]) {
51        return true;
52    }
53
54    // Two-character all-CJK
55    // This is optional - could also let trigram handle it, but trigram
56    // needs minimum 3 chars so two-char CJK won't work well
57    if chars.len() == 2 && chars.iter().all(|c| is_cjk_char(*c)) {
58        return true;
59    }
60
61    false
62}
63
/// Double every `"` in a user query before it is used in an FTS5 `MATCH`.
///
/// FTS5 delimits strings with double quotes and, SQL-style, writes a literal quote
/// inside a string as two quotes (`""`). Doubling every `"` means quotes always appear
/// in even-length runs, which keeps an unbalanced user quote from leaving an FTS5
/// string unterminated.
///
/// Nothing else is touched. Operators such as `AND`, `OR`, `NOT` and `*` reach FTS5 as
/// written, but a user-written `"phrase search"` does not: its quotes are doubled too.
/// Other characters that FTS5 rejects outside a string are not escaped and may still
/// produce a query syntax error.
///
/// # Arguments
/// * `query` - The query string to escape
///
/// # Returns
/// A new string equal to `query` with each `"` replaced by `""`.
///
/// # Example
/// ```ignore
/// use crate::search::escape_fts5;
///
/// let escaped = escape_fts5("user \"admin\" role");
/// assert_eq!(escaped, "user \"\"admin\"\" role");
/// ```
pub fn escape_fts5(query: &str) -> String {
    let mut escaped = String::with_capacity(query.len());
    for ch in query.chars() {
        if ch == '"' {
            // Emit the extra quote first; the original follows below.
            escaped.push('"');
        }
        escaped.push(ch);
    }
    escaped
}
86
87// ============================================================================
88// Unified Search
89// ============================================================================
90
91use crate::db::models::{Event, PaginatedSearchResults, SearchResult, Task};
92use crate::error::Result;
93use crate::tasks::TaskManager;
94use sqlx::{Row, SqlitePool};
95
96pub struct SearchManager<'a> {
97    pool: &'a SqlitePool,
98}
99
100impl<'a> SearchManager<'a> {
101    pub fn new(pool: &'a SqlitePool) -> Self {
102        Self { pool }
103    }
104
105    /// Unified search across tasks and events with pagination support
106    ///
107    /// This is the new unified search method that replaces unified_search().
108    /// Key improvements:
109    /// - Pagination support (limit, offset)
110    /// - Flexible result inclusion (tasks, events)
111    /// - Optional priority-based secondary sorting
112    /// - Returns PaginatedSearchResults with metadata
113    ///
114    /// # Parameters
115    /// - `query`: FTS5 search query string
116    /// - `include_tasks`: Whether to search in tasks (default: true)
117    /// - `include_events`: Whether to search in events (default: true)
118    /// - `limit`: Maximum number of results per source (default: 20)
119    /// - `offset`: Number of results to skip (default: 0)
120    /// - `sort_by_priority`: Enable priority-based secondary sorting (default: false)
121    ///
122    /// # Returns
123    /// PaginatedSearchResults with mixed task and event results, ordered by relevance (FTS5 rank)
124    pub async fn search(
125        &self,
126        query: &str,
127        include_tasks: bool,
128        include_events: bool,
129        limit: Option<i64>,
130        offset: Option<i64>,
131        sort_by_priority: bool,
132    ) -> Result<PaginatedSearchResults> {
133        let limit = limit.unwrap_or(20);
134        let offset = offset.unwrap_or(0);
135
136        // Handle empty or whitespace-only queries
137        if query.trim().is_empty() {
138            return Ok(PaginatedSearchResults {
139                results: Vec::new(),
140                total_tasks: 0,
141                total_events: 0,
142                has_more: false,
143                limit,
144                offset,
145            });
146        }
147
148        // Handle queries with no searchable content (only special characters)
149        let has_searchable = query.chars().any(|c| c.is_alphanumeric() || is_cjk_char(c));
150        if !has_searchable {
151            return Ok(PaginatedSearchResults {
152                results: Vec::new(),
153                total_tasks: 0,
154                total_events: 0,
155                has_more: false,
156                limit,
157                offset,
158            });
159        }
160
161        // Escape FTS5 special characters
162        let escaped_query = escape_fts5(query);
163
164        let mut total_tasks: i64 = 0;
165        let mut total_events: i64 = 0;
166        let mut all_results: Vec<(SearchResult, f64)> = Vec::new();
167
168        // Check if we need LIKE fallback for short CJK queries
169        let use_like_fallback = needs_like_fallback(query);
170
171        if use_like_fallback {
172            // LIKE fallback path for short CJK queries (1-2 chars)
173            let like_pattern = format!("%{}%", query);
174
175            // Search tasks if enabled
176            if include_tasks {
177                // Get total count
178                let count_result = sqlx::query_scalar::<_, i64>(
179                    "SELECT COUNT(*) FROM tasks WHERE name LIKE ? OR spec LIKE ?",
180                )
181                .bind(&like_pattern)
182                .bind(&like_pattern)
183                .fetch_one(self.pool)
184                .await?;
185                total_tasks = count_result;
186
187                // Build ORDER BY clause
188                let order_by = if sort_by_priority {
189                    "ORDER BY COALESCE(priority, 0) ASC, id ASC"
190                } else {
191                    "ORDER BY id ASC"
192                };
193
194                // Query tasks with pagination
195                let task_query = format!(
196                    r#"
197                    SELECT
198                        id,
199                        parent_id,
200                        name,
201                        spec,
202                        status,
203                        complexity,
204                        priority,
205                        first_todo_at,
206                        first_doing_at,
207                        first_done_at,
208                        active_form
209                    FROM tasks
210                    WHERE name LIKE ? OR spec LIKE ?
211                    {}
212                    LIMIT ? OFFSET ?
213                    "#,
214                    order_by
215                );
216
217                let rows = sqlx::query(&task_query)
218                    .bind(&like_pattern)
219                    .bind(&like_pattern)
220                    .bind(limit)
221                    .bind(offset)
222                    .fetch_all(self.pool)
223                    .await?;
224
225                for row in rows {
226                    let task = Task {
227                        id: row.get("id"),
228                        parent_id: row.get("parent_id"),
229                        name: row.get("name"),
230                        spec: row.get("spec"),
231                        status: row.get("status"),
232                        complexity: row.get("complexity"),
233                        priority: row.get("priority"),
234                        first_todo_at: row.get("first_todo_at"),
235                        first_doing_at: row.get("first_doing_at"),
236                        first_done_at: row.get("first_done_at"),
237                        active_form: row.get("active_form"),
238                    };
239
240                    // Determine match field and create snippet
241                    let (match_field, match_snippet) = if task.name.contains(query) {
242                        ("name".to_string(), task.name.clone())
243                    } else if let Some(ref spec) = task.spec {
244                        if spec.contains(query) {
245                            ("spec".to_string(), spec.clone())
246                        } else {
247                            ("name".to_string(), task.name.clone())
248                        }
249                    } else {
250                        ("name".to_string(), task.name.clone())
251                    };
252
253                    all_results.push((
254                        SearchResult::Task {
255                            task,
256                            match_snippet,
257                            match_field,
258                        },
259                        1.0, // Constant rank for LIKE results
260                    ));
261                }
262            }
263
264            // Search events if enabled
265            if include_events {
266                // Get total count
267                let count_result = sqlx::query_scalar::<_, i64>(
268                    "SELECT COUNT(*) FROM events WHERE discussion_data LIKE ?",
269                )
270                .bind(&like_pattern)
271                .fetch_one(self.pool)
272                .await?;
273                total_events = count_result;
274
275                // Query events with pagination
276                let rows = sqlx::query(
277                    r#"
278                    SELECT
279                        id,
280                        task_id,
281                        timestamp,
282                        log_type,
283                        discussion_data
284                    FROM events
285                    WHERE discussion_data LIKE ?
286                    ORDER BY id ASC
287                    LIMIT ? OFFSET ?
288                    "#,
289                )
290                .bind(&like_pattern)
291                .bind(limit)
292                .bind(offset)
293                .fetch_all(self.pool)
294                .await?;
295
296                let task_mgr = TaskManager::new(self.pool);
297                for row in rows {
298                    let event = Event {
299                        id: row.get("id"),
300                        task_id: row.get("task_id"),
301                        timestamp: row.get("timestamp"),
302                        log_type: row.get("log_type"),
303                        discussion_data: row.get("discussion_data"),
304                    };
305
306                    // Create match snippet
307                    let match_snippet = event.discussion_data.clone();
308
309                    // Get task ancestry chain for this event
310                    let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
311
312                    all_results.push((
313                        SearchResult::Event {
314                            event,
315                            task_chain,
316                            match_snippet,
317                        },
318                        1.0, // Constant rank for LIKE results
319                    ));
320                }
321            }
322        } else {
323            // FTS5 path for longer queries (3+ chars)
324            // Search tasks if enabled
325            if include_tasks {
326                // Get total count
327                let count_result = sqlx::query_scalar::<_, i64>(
328                    "SELECT COUNT(*) FROM tasks_fts WHERE tasks_fts MATCH ?",
329                )
330                .bind(&escaped_query)
331                .fetch_one(self.pool)
332                .await?;
333                total_tasks = count_result;
334
335                // Build ORDER BY clause
336                let order_by = if sort_by_priority {
337                    "ORDER BY rank ASC, COALESCE(t.priority, 0) ASC, t.id ASC"
338                } else {
339                    "ORDER BY rank ASC, t.id ASC"
340                };
341
342                // Query tasks with pagination
343                let task_query = format!(
344                    r#"
345                SELECT
346                    t.id,
347                    t.parent_id,
348                    t.name,
349                    t.spec,
350                    t.status,
351                    t.complexity,
352                    t.priority,
353                    t.first_todo_at,
354                    t.first_doing_at,
355                    t.first_done_at,
356                    t.active_form,
357                    COALESCE(
358                        snippet(tasks_fts, 1, '**', '**', '...', 15),
359                        snippet(tasks_fts, 0, '**', '**', '...', 15)
360                    ) as match_snippet,
361                    rank
362                FROM tasks_fts
363                INNER JOIN tasks t ON tasks_fts.rowid = t.id
364                WHERE tasks_fts MATCH ?
365                {}
366                LIMIT ? OFFSET ?
367                "#,
368                    order_by
369                );
370
371                let rows = sqlx::query(&task_query)
372                    .bind(&escaped_query)
373                    .bind(limit)
374                    .bind(offset)
375                    .fetch_all(self.pool)
376                    .await?;
377
378                for row in rows {
379                    let task = Task {
380                        id: row.get("id"),
381                        parent_id: row.get("parent_id"),
382                        name: row.get("name"),
383                        spec: row.get("spec"),
384                        status: row.get("status"),
385                        complexity: row.get("complexity"),
386                        priority: row.get("priority"),
387                        first_todo_at: row.get("first_todo_at"),
388                        first_doing_at: row.get("first_doing_at"),
389                        first_done_at: row.get("first_done_at"),
390                        active_form: row.get("active_form"),
391                    };
392                    let match_snippet: String = row.get("match_snippet");
393                    let rank: f64 = row.get("rank");
394
395                    // Determine match field based on snippet content
396                    let match_field = if task
397                        .spec
398                        .as_ref()
399                        .map(|s| match_snippet.to_lowercase().contains(&s.to_lowercase()))
400                        .unwrap_or(false)
401                    {
402                        "spec".to_string()
403                    } else {
404                        "name".to_string()
405                    };
406
407                    all_results.push((
408                        SearchResult::Task {
409                            task,
410                            match_snippet,
411                            match_field,
412                        },
413                        rank,
414                    ));
415                }
416            }
417
418            // Search events if enabled
419            if include_events {
420                // Get total count
421                let count_result = sqlx::query_scalar::<_, i64>(
422                    "SELECT COUNT(*) FROM events_fts WHERE events_fts MATCH ?",
423                )
424                .bind(&escaped_query)
425                .fetch_one(self.pool)
426                .await?;
427                total_events = count_result;
428
429                // Query events with pagination
430                let rows = sqlx::query(
431                    r#"
432                SELECT
433                    e.id,
434                    e.task_id,
435                    e.timestamp,
436                    e.log_type,
437                    e.discussion_data,
438                    snippet(events_fts, 0, '**', '**', '...', 15) as match_snippet,
439                    rank
440                FROM events_fts
441                INNER JOIN events e ON events_fts.rowid = e.id
442                WHERE events_fts MATCH ?
443                ORDER BY rank ASC, e.id ASC
444                LIMIT ? OFFSET ?
445                "#,
446                )
447                .bind(&escaped_query)
448                .bind(limit)
449                .bind(offset)
450                .fetch_all(self.pool)
451                .await?;
452
453                let task_mgr = TaskManager::new(self.pool);
454                for row in rows {
455                    let event = Event {
456                        id: row.get("id"),
457                        task_id: row.get("task_id"),
458                        timestamp: row.get("timestamp"),
459                        log_type: row.get("log_type"),
460                        discussion_data: row.get("discussion_data"),
461                    };
462                    let match_snippet: String = row.get("match_snippet");
463                    let rank: f64 = row.get("rank");
464
465                    // Get task ancestry chain for this event
466                    let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
467
468                    all_results.push((
469                        SearchResult::Event {
470                            event,
471                            task_chain,
472                            match_snippet,
473                        },
474                        rank,
475                    ));
476                }
477            }
478        } // End of else block (FTS5 path)
479
480        // Sort all results by rank (relevance)
481        all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
482
483        // Extract results without rank
484        let results: Vec<SearchResult> =
485            all_results.into_iter().map(|(result, _)| result).collect();
486
487        // Calculate has_more
488        let total_count = total_tasks + total_events;
489        let has_more = offset + (results.len() as i64) < total_count;
490
491        Ok(PaginatedSearchResults {
492            results,
493            total_tasks,
494            total_events,
495            has_more,
496            limit,
497            offset,
498        })
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_cjk_char() {
        // Chinese characters
        assert!(is_cjk_char('中'));
        assert!(is_cjk_char('文'));
        assert!(is_cjk_char('认'));
        assert!(is_cjk_char('证'));

        // Japanese Hiragana
        assert!(is_cjk_char('あ'));
        assert!(is_cjk_char('い'));

        // Japanese Katakana
        assert!(is_cjk_char('ア'));
        assert!(is_cjk_char('イ'));

        // Korean Hangul
        assert!(is_cjk_char('가'));
        assert!(is_cjk_char('나'));

        // Non-CJK
        assert!(!is_cjk_char('a'));
        assert!(!is_cjk_char('A'));
        assert!(!is_cjk_char('1'));
        assert!(!is_cjk_char(' '));
        assert!(!is_cjk_char('.'));
    }

    #[test]
    fn test_needs_like_fallback() {
        // Single CJK character - needs fallback
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("认"));
        assert!(needs_like_fallback("あ"));
        assert!(needs_like_fallback("가"));

        // Two CJK characters - needs fallback
        assert!(needs_like_fallback("中文"));
        assert!(needs_like_fallback("认证"));
        assert!(needs_like_fallback("用户"));

        // Three+ CJK characters - can use FTS5
        assert!(!needs_like_fallback("用户认"));
        assert!(!needs_like_fallback("用户认证"));

        // English - can use FTS5
        assert!(!needs_like_fallback("JWT"));
        assert!(!needs_like_fallback("auth"));
        assert!(!needs_like_fallback("a")); // Single ASCII char, not CJK

        // Mixed - can use FTS5
        assert!(!needs_like_fallback("JWT认证"));
        assert!(!needs_like_fallback("API接口"));
    }

    #[test]
    fn test_needs_like_fallback_mixed_cjk_ascii() {
        // Two characters: one CJK + one ASCII - should NOT need fallback
        // because not all chars are CJK
        assert!(!needs_like_fallback("中a"));
        assert!(!needs_like_fallback("a中"));
        assert!(!needs_like_fallback("認1"));

        // Three+ characters with mixed CJK/ASCII - can use FTS5
        assert!(!needs_like_fallback("中文API"));
        assert!(!needs_like_fallback("JWT认证系统"));
        assert!(!needs_like_fallback("API中文文档"));
    }

    #[test]
    fn test_needs_like_fallback_edge_cases() {
        // Empty string - no fallback needed
        assert!(!needs_like_fallback(""));

        // Whitespace only - no fallback
        assert!(!needs_like_fallback(" "));
        assert!(!needs_like_fallback("  "));

        // Single non-CJK - no fallback
        assert!(!needs_like_fallback("1"));
        assert!(!needs_like_fallback("@"));
        assert!(!needs_like_fallback("é")); // Non-ASCII, but not CJK

        // Two non-CJK - no fallback
        assert!(!needs_like_fallback("ab"));
        assert!(!needs_like_fallback("12"));
    }

    #[test]
    fn test_is_cjk_char_extension_ranges() {
        // CJK Extension A (U+3400..U+4DBF)
        assert!(is_cjk_char('\u{3400}')); // First char of Extension A
        assert!(is_cjk_char('\u{4DBF}')); // Last char of Extension A

        // CJK Unified Ideographs (U+4E00..U+9FFF) - common range
        assert!(is_cjk_char('\u{4E00}')); // First common CJK
        assert!(is_cjk_char('\u{9FFF}')); // Last common CJK

        // Characters just outside ranges - should NOT be CJK
        assert!(!is_cjk_char('\u{33FF}')); // Just before Extension A
        assert!(!is_cjk_char('\u{4DC0}')); // Just after Extension A
        assert!(!is_cjk_char('\u{4DFF}')); // Just before Unified Ideographs
        assert!(!is_cjk_char('\u{A000}')); // Just after Unified Ideographs
    }

    #[test]
    fn test_is_cjk_char_japanese() {
        // Hiragana range (U+3040..U+309F)
        assert!(is_cjk_char('\u{3040}')); // First Hiragana
        assert!(is_cjk_char('ひ')); // Middle Hiragana
        assert!(is_cjk_char('\u{309F}')); // Last Hiragana

        // Katakana range (U+30A0..U+30FF)
        assert!(is_cjk_char('\u{30A0}')); // First Katakana
        assert!(is_cjk_char('カ')); // Middle Katakana
        assert!(is_cjk_char('\u{30FF}')); // Last Katakana

        // Just outside Japanese ranges
        assert!(!is_cjk_char('\u{303F}')); // Before Hiragana
        assert!(!is_cjk_char('\u{3100}')); // After Katakana (Bopomofo, not CJK by our definition)
    }

    #[test]
    fn test_is_cjk_char_korean() {
        // Hangul Syllables (U+AC00..U+D7AF)
        assert!(is_cjk_char('\u{AC00}')); // First Hangul syllable (가)
        assert!(is_cjk_char('한')); // Middle Hangul
        assert!(is_cjk_char('\u{D7AF}')); // Last Hangul syllable

        // Just outside Korean range
        assert!(!is_cjk_char('\u{ABFF}')); // Before Hangul
        assert!(!is_cjk_char('\u{D7B0}')); // After Hangul
    }

    #[test]
    fn test_escape_fts5_basic() {
        // No quotes - no escaping needed
        assert_eq!(escape_fts5("hello world"), "hello world");
        assert_eq!(escape_fts5("JWT authentication"), "JWT authentication");

        // Single quote (not escaped by this function, only double quotes)
        assert_eq!(escape_fts5("user's task"), "user's task");
    }

    #[test]
    fn test_escape_fts5_double_quotes() {
        // One quoted word
        assert_eq!(escape_fts5("\"admin\""), "\"\"admin\"\"");

        // Multiple double quotes
        assert_eq!(
            escape_fts5("\"user\" and \"admin\""),
            "\"\"user\"\" and \"\"admin\"\""
        );

        // Double quotes at different positions
        assert_eq!(
            escape_fts5("start \"middle\" end"),
            "start \"\"middle\"\" end"
        );
        assert_eq!(escape_fts5("\"start"), "\"\"start");
        assert_eq!(escape_fts5("end\""), "end\"\"");
    }

    #[test]
    fn test_escape_fts5_complex_queries() {
        // Mixed quotes and special characters
        assert_eq!(
            escape_fts5("search for \"exact phrase\" here"),
            "search for \"\"exact phrase\"\" here"
        );

        // Empty string
        assert_eq!(escape_fts5(""), "");

        // Only quotes
        assert_eq!(escape_fts5("\""), "\"\"");
        assert_eq!(escape_fts5("\"\""), "\"\"\"\"");
        assert_eq!(escape_fts5("\"\"\""), "\"\"\"\"\"\"");
    }

    #[test]
    fn test_escape_fts5_cjk_with_quotes() {
        // CJK text with quotes
        assert_eq!(escape_fts5("用户\"管理员\"权限"), "用户\"\"管理员\"\"权限");
        assert_eq!(escape_fts5("\"認証\"システム"), "\"\"認証\"\"システム");

        // Mixed CJK and English with quotes
        assert_eq!(
            escape_fts5("API\"接口\"documentation"),
            "API\"\"接口\"\"documentation"
        );
    }

    #[test]
    fn test_needs_like_fallback_unicode_normalization() {
        // These only confirm the fallback for common ideographs; no alternative
        // Unicode normalization forms are exercised here.

        // Standard CJK characters
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("日"));

        // Two CJK characters
        assert!(needs_like_fallback("中日"));
        assert!(needs_like_fallback("認證"));
    }
}