intent_engine/
search.rs

1//! Search utilities for intent-engine
2//!
3//! This module provides:
4//! 1. CJK (Chinese, Japanese, Korean) search utilities for detecting when to use
5//!    LIKE fallback vs FTS5 trigram search
6//! 2. Unified search across tasks and events
7//!
8//! **Background**: SQLite FTS5 with trigram tokenizer requires at least 3 consecutive
9//! characters to match. This is problematic for CJK languages where single-character
10//! or two-character searches are common (e.g., "用户", "认证").
11//!
//! **Solution**: For short CJK queries, we fall back to LIKE search, which supports
//!    substring matching of any length, albeit more slowly.
14
/// Returns `true` when `c` falls inside one of the script blocks this module
/// treats as CJK for search purposes.
///
/// Recognised blocks: CJK Unified Ideographs, CJK Extensions A through F,
/// Hiragana, Katakana and Hangul Syllables. Neighbouring East-Asian blocks
/// (for example Bopomofo, Hangul Jamo or CJK Compatibility Ideographs) are not
/// included and return `false`.
pub fn is_cjk_char(c: char) -> bool {
    // Inclusive (first, last) code-point bounds of every recognised block.
    const CJK_BLOCKS: [(u32, u32); 10] = [
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs (most common Chinese characters)
        (0x3400, 0x4DBF),   // CJK Extension A
        (0x20000, 0x2A6DF), // CJK Extension B
        (0x2A700, 0x2B73F), // CJK Extension C
        (0x2B740, 0x2B81F), // CJK Extension D
        (0x2B820, 0x2CEAF), // CJK Extension E
        (0x2CEB0, 0x2EBEF), // CJK Extension F
        (0x3040, 0x309F),   // Hiragana (Japanese)
        (0x30A0, 0x30FF),   // Katakana (Japanese)
        (0xAC00, 0xD7AF),   // Hangul Syllables (Korean)
    ];

    let code_point = u32::from(c);
    CJK_BLOCKS
        .iter()
        .any(|&(first, last)| first <= code_point && code_point <= last)
}
37
38/// Determine if a query should use LIKE fallback instead of FTS5 trigram
39///
40/// Returns `true` if:
41/// - Query is a single CJK character, OR
42/// - Query is two CJK characters
43///
44/// Trigram tokenizer requires 3+ characters for matching, so we use LIKE
45/// for shorter CJK queries to ensure they work.
46pub fn needs_like_fallback(query: &str) -> bool {
47    let chars: Vec<char> = query.chars().collect();
48
49    // Single-character CJK
50    if chars.len() == 1 && is_cjk_char(chars[0]) {
51        return true;
52    }
53
54    // Two-character all-CJK
55    // This is optional - could also let trigram handle it, but trigram
56    // needs minimum 3 chars so two-char CJK won't work well
57    if chars.len() == 2 && chars.iter().all(|c| is_cjk_char(*c)) {
58        return true;
59    }
60
61    false
62}
63
/// Escapes double quotes in `query` for use in an FTS5 `MATCH` expression.
///
/// Every `"` is doubled to `""`, the SQL-style escape FTS5 uses inside
/// double-quoted strings. No other character is altered, so FTS5 syntax such as
/// `AND`, `OR`, `NOT` or `*` in the input is passed through untouched.
///
/// NOTE(review): FTS5 only reads `""` as an escaped quote *inside* a
/// double-quoted string, and this function does not add surrounding quotes.
/// Input containing other FTS5 punctuation may therefore still produce an FTS5
/// syntax error when the result is used as a bare query. Confirm whether
/// callers need phrase-quoting.
///
/// # Arguments
/// * `query` - The raw query text to escape
///
/// # Returns
/// A new string equal to `query` with each `"` replaced by `""`.
///
/// # Example
/// ```ignore
/// use crate::search::escape_fts5;
///
/// let escaped = escape_fts5("user \"admin\" role");
/// assert_eq!(escaped, "user \"\"admin\"\" role");
/// ```
pub fn escape_fts5(query: &str) -> String {
    let mut escaped = String::with_capacity(query.len());
    for ch in query.chars() {
        if ch == '"' {
            escaped.push_str("\"\"");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
86
87// ============================================================================
88// Unified Search
89// ============================================================================
90
91use crate::db::models::{Event, PaginatedSearchResults, SearchResult, Task};
92use crate::error::Result;
93use crate::tasks::TaskManager;
94use sqlx::{Row, SqlitePool};
95
96pub struct SearchManager<'a> {
97    pool: &'a SqlitePool,
98}
99
100impl<'a> SearchManager<'a> {
101    pub fn new(pool: &'a SqlitePool) -> Self {
102        Self { pool }
103    }
104
105    /// Unified search across tasks and events with pagination support
106    ///
107    /// This is the new unified search method that replaces unified_search().
108    /// Key improvements:
109    /// - Pagination support (limit, offset)
110    /// - Flexible result inclusion (tasks, events)
111    /// - Optional priority-based secondary sorting
112    /// - Returns PaginatedSearchResults with metadata
113    ///
114    /// # Parameters
115    /// - `query`: FTS5 search query string
116    /// - `include_tasks`: Whether to search in tasks (default: true)
117    /// - `include_events`: Whether to search in events (default: true)
118    /// - `limit`: Maximum number of results per source (default: 20)
119    /// - `offset`: Number of results to skip (default: 0)
120    /// - `sort_by_priority`: Enable priority-based secondary sorting (default: false)
121    ///
122    /// # Returns
123    /// PaginatedSearchResults with mixed task and event results, ordered by relevance (FTS5 rank)
124    pub async fn search(
125        &self,
126        query: &str,
127        include_tasks: bool,
128        include_events: bool,
129        limit: Option<i64>,
130        offset: Option<i64>,
131        sort_by_priority: bool,
132    ) -> Result<PaginatedSearchResults> {
133        let limit = limit.unwrap_or(20);
134        let offset = offset.unwrap_or(0);
135
136        // Handle empty or whitespace-only queries
137        if query.trim().is_empty() {
138            return Ok(PaginatedSearchResults {
139                results: Vec::new(),
140                total_tasks: 0,
141                total_events: 0,
142                has_more: false,
143                limit,
144                offset,
145            });
146        }
147
148        // Handle queries with no searchable content (only special characters)
149        let has_searchable = query.chars().any(|c| c.is_alphanumeric() || is_cjk_char(c));
150        if !has_searchable {
151            return Ok(PaginatedSearchResults {
152                results: Vec::new(),
153                total_tasks: 0,
154                total_events: 0,
155                has_more: false,
156                limit,
157                offset,
158            });
159        }
160
161        // Escape FTS5 special characters
162        let escaped_query = escape_fts5(query);
163
164        let mut total_tasks: i64 = 0;
165        let mut total_events: i64 = 0;
166        let mut all_results: Vec<(SearchResult, f64)> = Vec::new();
167
168        // Check if we need LIKE fallback for short CJK queries
169        let use_like_fallback = needs_like_fallback(query);
170
171        if use_like_fallback {
172            // LIKE fallback path for short CJK queries (1-2 chars)
173            let like_pattern = format!("%{}%", query);
174
175            // Search tasks if enabled
176            if include_tasks {
177                // Get total count
178                let count_result = sqlx::query_scalar::<_, i64>(
179                    "SELECT COUNT(*) FROM tasks WHERE name LIKE ? OR spec LIKE ?",
180                )
181                .bind(&like_pattern)
182                .bind(&like_pattern)
183                .fetch_one(self.pool)
184                .await?;
185                total_tasks = count_result;
186
187                // Build ORDER BY clause
188                let order_by = if sort_by_priority {
189                    "ORDER BY COALESCE(priority, 0) ASC, id ASC"
190                } else {
191                    "ORDER BY id ASC"
192                };
193
194                // Query tasks with pagination
195                let task_query = format!(
196                    r#"
197                    SELECT
198                        id,
199                        parent_id,
200                        name,
201                        spec,
202                        status,
203                        complexity,
204                        priority,
205                        first_todo_at,
206                        first_doing_at,
207                        first_done_at,
208                        active_form,
209                        owner
210                    FROM tasks
211                    WHERE name LIKE ? OR spec LIKE ?
212                    {}
213                    LIMIT ? OFFSET ?
214                    "#,
215                    order_by
216                );
217
218                let rows = sqlx::query(&task_query)
219                    .bind(&like_pattern)
220                    .bind(&like_pattern)
221                    .bind(limit)
222                    .bind(offset)
223                    .fetch_all(self.pool)
224                    .await?;
225
226                for row in rows {
227                    let task = Task {
228                        id: row.get("id"),
229                        parent_id: row.get("parent_id"),
230                        name: row.get("name"),
231                        spec: row.get("spec"),
232                        status: row.get("status"),
233                        complexity: row.get("complexity"),
234                        priority: row.get("priority"),
235                        first_todo_at: row.get("first_todo_at"),
236                        first_doing_at: row.get("first_doing_at"),
237                        first_done_at: row.get("first_done_at"),
238                        active_form: row.get("active_form"),
239                        owner: row.get("owner"),
240                    };
241
242                    // Determine match field and create snippet
243                    let (match_field, match_snippet) = if task.name.contains(query) {
244                        ("name".to_string(), task.name.clone())
245                    } else if let Some(ref spec) = task.spec {
246                        if spec.contains(query) {
247                            ("spec".to_string(), spec.clone())
248                        } else {
249                            ("name".to_string(), task.name.clone())
250                        }
251                    } else {
252                        ("name".to_string(), task.name.clone())
253                    };
254
255                    all_results.push((
256                        SearchResult::Task {
257                            task,
258                            match_snippet,
259                            match_field,
260                        },
261                        1.0, // Constant rank for LIKE results
262                    ));
263                }
264            }
265
266            // Search events if enabled
267            if include_events {
268                // Get total count
269                let count_result = sqlx::query_scalar::<_, i64>(
270                    "SELECT COUNT(*) FROM events WHERE discussion_data LIKE ?",
271                )
272                .bind(&like_pattern)
273                .fetch_one(self.pool)
274                .await?;
275                total_events = count_result;
276
277                // Query events with pagination
278                let rows = sqlx::query(
279                    r#"
280                    SELECT
281                        id,
282                        task_id,
283                        timestamp,
284                        log_type,
285                        discussion_data
286                    FROM events
287                    WHERE discussion_data LIKE ?
288                    ORDER BY id ASC
289                    LIMIT ? OFFSET ?
290                    "#,
291                )
292                .bind(&like_pattern)
293                .bind(limit)
294                .bind(offset)
295                .fetch_all(self.pool)
296                .await?;
297
298                let task_mgr = TaskManager::new(self.pool);
299                for row in rows {
300                    let event = Event {
301                        id: row.get("id"),
302                        task_id: row.get("task_id"),
303                        timestamp: row.get("timestamp"),
304                        log_type: row.get("log_type"),
305                        discussion_data: row.get("discussion_data"),
306                    };
307
308                    // Create match snippet
309                    let match_snippet = event.discussion_data.clone();
310
311                    // Get task ancestry chain for this event
312                    let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
313
314                    all_results.push((
315                        SearchResult::Event {
316                            event,
317                            task_chain,
318                            match_snippet,
319                        },
320                        1.0, // Constant rank for LIKE results
321                    ));
322                }
323            }
324        } else {
325            // FTS5 path for longer queries (3+ chars)
326            // Search tasks if enabled
327            if include_tasks {
328                // Get total count
329                let count_result = sqlx::query_scalar::<_, i64>(
330                    "SELECT COUNT(*) FROM tasks_fts WHERE tasks_fts MATCH ?",
331                )
332                .bind(&escaped_query)
333                .fetch_one(self.pool)
334                .await?;
335                total_tasks = count_result;
336
337                // Build ORDER BY clause
338                let order_by = if sort_by_priority {
339                    "ORDER BY rank ASC, COALESCE(t.priority, 0) ASC, t.id ASC"
340                } else {
341                    "ORDER BY rank ASC, t.id ASC"
342                };
343
344                // Query tasks with pagination
345                let task_query = format!(
346                    r#"
347                SELECT
348                    t.id,
349                    t.parent_id,
350                    t.name,
351                    t.spec,
352                    t.status,
353                    t.complexity,
354                    t.priority,
355                    t.first_todo_at,
356                    t.first_doing_at,
357                    t.first_done_at,
358                    t.active_form,
359                    t.owner,
360                    COALESCE(
361                        snippet(tasks_fts, 1, '**', '**', '...', 15),
362                        snippet(tasks_fts, 0, '**', '**', '...', 15)
363                    ) as match_snippet,
364                    rank
365                FROM tasks_fts
366                INNER JOIN tasks t ON tasks_fts.rowid = t.id
367                WHERE tasks_fts MATCH ?
368                {}
369                LIMIT ? OFFSET ?
370                "#,
371                    order_by
372                );
373
374                let rows = sqlx::query(&task_query)
375                    .bind(&escaped_query)
376                    .bind(limit)
377                    .bind(offset)
378                    .fetch_all(self.pool)
379                    .await?;
380
381                for row in rows {
382                    let task = Task {
383                        id: row.get("id"),
384                        parent_id: row.get("parent_id"),
385                        name: row.get("name"),
386                        spec: row.get("spec"),
387                        status: row.get("status"),
388                        complexity: row.get("complexity"),
389                        priority: row.get("priority"),
390                        first_todo_at: row.get("first_todo_at"),
391                        first_doing_at: row.get("first_doing_at"),
392                        first_done_at: row.get("first_done_at"),
393                        active_form: row.get("active_form"),
394                        owner: row.get("owner"),
395                    };
396                    let match_snippet: String = row.get("match_snippet");
397                    let rank: f64 = row.get("rank");
398
399                    // Determine match field based on snippet content
400                    let match_field = if task
401                        .spec
402                        .as_ref()
403                        .map(|s| match_snippet.to_lowercase().contains(&s.to_lowercase()))
404                        .unwrap_or(false)
405                    {
406                        "spec".to_string()
407                    } else {
408                        "name".to_string()
409                    };
410
411                    all_results.push((
412                        SearchResult::Task {
413                            task,
414                            match_snippet,
415                            match_field,
416                        },
417                        rank,
418                    ));
419                }
420            }
421
422            // Search events if enabled
423            if include_events {
424                // Get total count
425                let count_result = sqlx::query_scalar::<_, i64>(
426                    "SELECT COUNT(*) FROM events_fts WHERE events_fts MATCH ?",
427                )
428                .bind(&escaped_query)
429                .fetch_one(self.pool)
430                .await?;
431                total_events = count_result;
432
433                // Query events with pagination
434                let rows = sqlx::query(
435                    r#"
436                SELECT
437                    e.id,
438                    e.task_id,
439                    e.timestamp,
440                    e.log_type,
441                    e.discussion_data,
442                    snippet(events_fts, 0, '**', '**', '...', 15) as match_snippet,
443                    rank
444                FROM events_fts
445                INNER JOIN events e ON events_fts.rowid = e.id
446                WHERE events_fts MATCH ?
447                ORDER BY rank ASC, e.id ASC
448                LIMIT ? OFFSET ?
449                "#,
450                )
451                .bind(&escaped_query)
452                .bind(limit)
453                .bind(offset)
454                .fetch_all(self.pool)
455                .await?;
456
457                let task_mgr = TaskManager::new(self.pool);
458                for row in rows {
459                    let event = Event {
460                        id: row.get("id"),
461                        task_id: row.get("task_id"),
462                        timestamp: row.get("timestamp"),
463                        log_type: row.get("log_type"),
464                        discussion_data: row.get("discussion_data"),
465                    };
466                    let match_snippet: String = row.get("match_snippet");
467                    let rank: f64 = row.get("rank");
468
469                    // Get task ancestry chain for this event
470                    let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
471
472                    all_results.push((
473                        SearchResult::Event {
474                            event,
475                            task_chain,
476                            match_snippet,
477                        },
478                        rank,
479                    ));
480                }
481            }
482        } // End of else block (FTS5 path)
483
484        // Sort all results by rank (relevance)
485        all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
486
487        // Extract results without rank
488        let results: Vec<SearchResult> =
489            all_results.into_iter().map(|(result, _)| result).collect();
490
491        // Calculate has_more
492        let total_count = total_tasks + total_events;
493        let has_more = offset + (results.len() as i64) < total_count;
494
495        Ok(PaginatedSearchResults {
496            results,
497            total_tasks,
498            total_events,
499            has_more,
500            limit,
501            offset,
502        })
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_cjk_char() {
        // Chinese characters
        assert!(is_cjk_char('中'));
        assert!(is_cjk_char('文'));
        assert!(is_cjk_char('认'));
        assert!(is_cjk_char('证'));

        // Japanese Hiragana
        assert!(is_cjk_char('あ'));
        assert!(is_cjk_char('い'));

        // Japanese Katakana
        assert!(is_cjk_char('ア'));
        assert!(is_cjk_char('イ'));

        // Korean Hangul
        assert!(is_cjk_char('가'));
        assert!(is_cjk_char('나'));

        // Non-CJK
        assert!(!is_cjk_char('a'));
        assert!(!is_cjk_char('A'));
        assert!(!is_cjk_char('1'));
        assert!(!is_cjk_char(' '));
        assert!(!is_cjk_char('.'));
    }

    #[test]
    fn test_needs_like_fallback() {
        // Single CJK character - needs fallback
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("认"));
        assert!(needs_like_fallback("あ"));
        assert!(needs_like_fallback("가"));

        // Two CJK characters - needs fallback
        assert!(needs_like_fallback("中文"));
        assert!(needs_like_fallback("认证"));
        assert!(needs_like_fallback("用户"));

        // Three+ CJK characters - can use FTS5
        assert!(!needs_like_fallback("用户认"));
        assert!(!needs_like_fallback("用户认证"));

        // English - can use FTS5
        assert!(!needs_like_fallback("JWT"));
        assert!(!needs_like_fallback("auth"));
        assert!(!needs_like_fallback("a")); // Single ASCII char, not CJK

        // Mixed - can use FTS5
        assert!(!needs_like_fallback("JWT认证"));
        assert!(!needs_like_fallback("API接口"));
    }

    #[test]
    fn test_needs_like_fallback_mixed_cjk_ascii() {
        // Two characters: one CJK + one ASCII - should NOT need fallback
        // because not all chars are CJK
        assert!(!needs_like_fallback("中a"));
        assert!(!needs_like_fallback("a中"));
        assert!(!needs_like_fallback("認1"));

        // Three+ characters with mixed CJK/ASCII - can use FTS5
        assert!(!needs_like_fallback("中文API"));
        assert!(!needs_like_fallback("JWT认证系统"));
        assert!(!needs_like_fallback("API中文文档"));
    }

    #[test]
    fn test_needs_like_fallback_edge_cases() {
        // Empty string - no fallback needed
        assert!(!needs_like_fallback(""));

        // Whitespace only - no fallback
        assert!(!needs_like_fallback(" "));
        assert!(!needs_like_fallback("  "));
        assert!(!needs_like_fallback("\t"));

        // Single non-CJK - no fallback
        assert!(!needs_like_fallback("1"));
        assert!(!needs_like_fallback("@"));

        // Two non-CJK - no fallback
        assert!(!needs_like_fallback("ab"));
        assert!(!needs_like_fallback("12"));

        // Whitespace counts as a non-CJK character, so callers must trim first
        assert!(!needs_like_fallback(" 中"));
        assert!(!needs_like_fallback("中 "));
    }

    #[test]
    fn test_needs_like_fallback_counts_chars_not_bytes() {
        // Supplementary-plane ideographs take 4 bytes in UTF-8 but are one char each
        assert!(needs_like_fallback("\u{20000}"));
        assert!(needs_like_fallback("\u{20000}\u{20001}"));
        assert!(!needs_like_fallback("\u{20000}\u{20001}\u{20002}"));
    }

    #[test]
    fn test_is_cjk_char_extension_ranges() {
        // CJK Extension A (U+3400..U+4DBF)
        assert!(is_cjk_char('\u{3400}')); // First char of Extension A
        assert!(is_cjk_char('\u{4DBF}')); // Last char of Extension A

        // CJK Unified Ideographs (U+4E00..U+9FFF) - common range
        assert!(is_cjk_char('\u{4E00}')); // First common CJK
        assert!(is_cjk_char('\u{9FFF}')); // Last common CJK

        // Characters just outside ranges - should NOT be CJK
        assert!(!is_cjk_char('\u{33FF}')); // Just before Extension A
        assert!(!is_cjk_char('\u{4DC0}')); // Just after Extension A
        assert!(!is_cjk_char('\u{4DFF}')); // Just before Unified Ideographs
        assert!(!is_cjk_char('\u{A000}')); // Just after Unified Ideographs
    }

    #[test]
    fn test_is_cjk_char_supplementary_extensions() {
        // CJK Extension B (U+20000..U+2A6DF)
        assert!(is_cjk_char('\u{20000}'));
        assert!(is_cjk_char('\u{2A6DF}'));

        // Extensions C-F form one contiguous span (U+2A700..U+2EBEF)
        assert!(is_cjk_char('\u{2A700}')); // First of Extension C
        assert!(is_cjk_char('\u{2B73F}')); // Last of Extension C
        assert!(is_cjk_char('\u{2B740}')); // First of Extension D
        assert!(is_cjk_char('\u{2B81F}')); // Last of Extension D
        assert!(is_cjk_char('\u{2B820}')); // First of Extension E
        assert!(is_cjk_char('\u{2CEAF}')); // Last of Extension E
        assert!(is_cjk_char('\u{2CEB0}')); // First of Extension F
        assert!(is_cjk_char('\u{2EBEF}')); // Last of Extension F

        // Outside the supplementary ranges
        assert!(!is_cjk_char('\u{1FFFF}')); // Just before Extension B
        assert!(!is_cjk_char('\u{2A6E0}')); // Gap between Extension B and C
        assert!(!is_cjk_char('\u{2A6FF}')); // Gap between Extension B and C
        assert!(!is_cjk_char('\u{2EBF0}')); // Just after Extension F
    }

    #[test]
    fn test_is_cjk_char_japanese() {
        // Hiragana range (U+3040..U+309F)
        assert!(is_cjk_char('\u{3040}')); // First Hiragana
        assert!(is_cjk_char('ひ')); // Middle Hiragana
        assert!(is_cjk_char('\u{309F}')); // Last Hiragana

        // Katakana range (U+30A0..U+30FF)
        assert!(is_cjk_char('\u{30A0}')); // First Katakana
        assert!(is_cjk_char('カ')); // Middle Katakana
        assert!(is_cjk_char('\u{30FF}')); // Last Katakana

        // Just outside Japanese ranges
        assert!(!is_cjk_char('\u{303F}')); // Before Hiragana
        assert!(!is_cjk_char('\u{3100}')); // After Katakana (Bopomofo, not CJK by our definition)
    }

    #[test]
    fn test_is_cjk_char_korean() {
        // Hangul Syllables (U+AC00..U+D7AF)
        assert!(is_cjk_char('\u{AC00}')); // First Hangul syllable (가)
        assert!(is_cjk_char('한')); // Middle Hangul
        assert!(is_cjk_char('\u{D7AF}')); // Last Hangul syllable

        // Just outside Korean range
        assert!(!is_cjk_char('\u{ABFF}')); // Before Hangul
        assert!(!is_cjk_char('\u{D7B0}')); // After Hangul
    }

    #[test]
    fn test_escape_fts5_basic() {
        // No quotes - no escaping needed
        assert_eq!(escape_fts5("hello world"), "hello world");
        assert_eq!(escape_fts5("JWT authentication"), "JWT authentication");

        // Single quote (not escaped by this function, only double quotes)
        assert_eq!(escape_fts5("user's task"), "user's task");
    }

    #[test]
    fn test_escape_fts5_double_quotes() {
        // Single double quote
        assert_eq!(escape_fts5("\"admin\""), "\"\"admin\"\"");

        // Multiple double quotes
        assert_eq!(
            escape_fts5("\"user\" and \"admin\""),
            "\"\"user\"\" and \"\"admin\"\""
        );

        // Double quotes at different positions
        assert_eq!(
            escape_fts5("start \"middle\" end"),
            "start \"\"middle\"\" end"
        );
        assert_eq!(escape_fts5("\"start"), "\"\"start");
        assert_eq!(escape_fts5("end\""), "end\"\"");
    }

    #[test]
    fn test_escape_fts5_complex_queries() {
        // Mixed quotes and special characters
        assert_eq!(
            escape_fts5("search for \"exact phrase\" here"),
            "search for \"\"exact phrase\"\" here"
        );

        // Empty string
        assert_eq!(escape_fts5(""), "");

        // Only quotes
        assert_eq!(escape_fts5("\""), "\"\"");
        assert_eq!(escape_fts5("\"\""), "\"\"\"\"");
        assert_eq!(escape_fts5("\"\"\""), "\"\"\"\"\"\"");
    }

    #[test]
    fn test_escape_fts5_cjk_with_quotes() {
        // CJK text with quotes
        assert_eq!(escape_fts5("用户\"管理员\"权限"), "用户\"\"管理员\"\"权限");
        assert_eq!(escape_fts5("\"認証\"システム"), "\"\"認証\"\"システム");

        // Mixed CJK and English with quotes
        assert_eq!(
            escape_fts5("API\"接口\"documentation"),
            "API\"\"接口\"\"documentation"
        );
    }

    #[test]
    fn test_needs_like_fallback_other_ideographs() {
        // Kanji and traditional Chinese ideographs behave like simplified Chinese
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("日"));

        // Two ideographs
        assert!(needs_like_fallback("中日"));
        assert!(needs_like_fallback("認證"));
    }
}