intent_engine/search.rs
1//! Search utilities for intent-engine
2//!
3//! This module provides:
4//! 1. CJK (Chinese, Japanese, Korean) search utilities for detecting when to use
5//! LIKE fallback vs FTS5 trigram search
6//! 2. Unified search across tasks and events
7//!
8//! **Background**: SQLite FTS5 with trigram tokenizer requires at least 3 consecutive
9//! characters to match. This is problematic for CJK languages where single-character
10//! or two-character searches are common (e.g., "用户", "认证").
11//!
//! **Solution**: For short CJK queries, we fall back to a LIKE search, which supports
//! substring matching of any length, albeit more slowly.
14
/// Check if a character is a CJK character.
///
/// Covers the CJK Unified Ideographs (basic block plus Extensions A–I), the CJK
/// Compatibility Ideographs (and their supplement), Japanese Hiragana and
/// Katakana, and precomposed Hangul syllables. Not covered, among others:
/// Bopomofo, CJK punctuation/symbols, half-width kana and conjoining Hangul Jamo.
pub fn is_cjk_char(c: char) -> bool {
    matches!(
        c as u32,
        // CJK Unified Ideographs (most common Chinese characters)
        0x4E00..=0x9FFF
        // CJK Extension A
        | 0x3400..=0x4DBF
        // CJK Extensions B-F, I, G, H (rare characters, included for completeness)
        | 0x20000..=0x2A6DF
        | 0x2A700..=0x2B73F
        | 0x2B740..=0x2B81F
        | 0x2B820..=0x2CEAF
        | 0x2CEB0..=0x2EBEF
        | 0x2EBF0..=0x2EE5F
        | 0x30000..=0x3134F
        | 0x31350..=0x323AF
        // CJK Compatibility Ideographs and Compatibility Ideographs Supplement
        | 0xF900..=0xFAFF
        | 0x2F800..=0x2FA1F
        // Hiragana (Japanese)
        | 0x3040..=0x309F
        // Katakana (Japanese)
        | 0x30A0..=0x30FF
        // Hangul Syllables (Korean)
        | 0xAC00..=0xD7AF
    )
}

/// Determine if a query should use LIKE fallback instead of FTS5 trigram.
///
/// Returns `true` only when the query is one or two characters long and every
/// character is CJK (see [`is_cjk_char`]). The trigram tokenizer needs at least
/// three characters to match anything, so such short CJK queries (e.g. "用户",
/// "认证") must be answered with LIKE instead.
///
/// Short non-CJK queries (e.g. "ab") and mixed two-character queries (e.g. "中a")
/// return `false`.
/// NOTE(review): those cannot match through trigram FTS5 either; confirm that
/// returning no results for them is intended.
pub fn needs_like_fallback(query: &str) -> bool {
    // Count at most three chars: enough to tell "1 or 2" from "3+" without
    // walking (or allocating for) the whole query.
    let is_short = matches!(query.chars().take(3).count(), 1 | 2);
    is_short && query.chars().all(is_cjk_char)
}
63
/// Quote `query` so that FTS5 treats it as a single literal phrase.
///
/// FTS5 query syntax gives meaning to keywords such as `AND`/`OR`/`NOT` and to
/// characters such as `#`, `*`, `+` and `-`. Wrapping the whole input in double
/// quotes turns it into an FTS5 string (phrase). Inside such a string the only
/// special character is the double quote itself, which is escaped by doubling it.
///
/// # Arguments
/// * `query` - Raw user input to search for literally
///
/// # Returns
/// `query` with every `"` doubled, surrounded by one pair of double quotes.
///
/// # Example
/// ```ignore
/// use crate::search::escape_fts5;
///
/// let escaped = escape_fts5("user \"admin\" role");
/// assert_eq!(escaped, "\"user \"\"admin\"\" role\"");
///
/// let escaped = escape_fts5("#123");
/// assert_eq!(escaped, "\"#123\"");
/// ```
pub fn escape_fts5(query: &str) -> String {
    // Two bytes for the surrounding quotes; doubled quotes may grow it further.
    let mut quoted = String::with_capacity(query.len() + 2);
    quoted.push('"');
    for ch in query.chars() {
        if ch == '"' {
            // SQL-style escape: an embedded quote is written twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
94
95// ============================================================================
96// Unified Search
97// ============================================================================
98
99use crate::db::models::{Event, PaginatedSearchResults, SearchResult, Task};
100use crate::error::Result;
101use crate::tasks::TaskManager;
102use sqlx::{Row, SqlitePool};
103
104pub struct SearchManager<'a> {
105 pool: &'a SqlitePool,
106}
107
108impl<'a> SearchManager<'a> {
109 pub fn new(pool: &'a SqlitePool) -> Self {
110 Self { pool }
111 }
112
113 /// Unified search across tasks and events with pagination support
114 ///
115 /// This is the new unified search method that replaces unified_search().
116 /// Key improvements:
117 /// - Pagination support (limit, offset)
118 /// - Flexible result inclusion (tasks, events)
119 /// - Optional priority-based secondary sorting
120 /// - Returns PaginatedSearchResults with metadata
121 ///
122 /// # Parameters
123 /// - `query`: FTS5 search query string
124 /// - `include_tasks`: Whether to search in tasks (default: true)
125 /// - `include_events`: Whether to search in events (default: true)
126 /// - `limit`: Maximum number of results per source (default: 20)
127 /// - `offset`: Number of results to skip (default: 0)
128 /// - `sort_by_priority`: Enable priority-based secondary sorting (default: false)
129 ///
130 /// # Returns
131 /// PaginatedSearchResults with mixed task and event results, ordered by relevance (FTS5 rank)
132 pub async fn search(
133 &self,
134 query: &str,
135 include_tasks: bool,
136 include_events: bool,
137 limit: Option<i64>,
138 offset: Option<i64>,
139 sort_by_priority: bool,
140 ) -> Result<PaginatedSearchResults> {
141 let limit = limit.unwrap_or(20);
142 let offset = offset.unwrap_or(0);
143
144 // Handle empty or whitespace-only queries
145 if query.trim().is_empty() {
146 return Ok(PaginatedSearchResults {
147 results: Vec::new(),
148 total_tasks: 0,
149 total_events: 0,
150 has_more: false,
151 limit,
152 offset,
153 });
154 }
155
156 // Handle queries with no searchable content (only special characters)
157 let has_searchable = query.chars().any(|c| c.is_alphanumeric() || is_cjk_char(c));
158 if !has_searchable {
159 return Ok(PaginatedSearchResults {
160 results: Vec::new(),
161 total_tasks: 0,
162 total_events: 0,
163 has_more: false,
164 limit,
165 offset,
166 });
167 }
168
169 // Escape FTS5 special characters
170 let escaped_query = escape_fts5(query);
171
172 let mut total_tasks: i64 = 0;
173 let mut total_events: i64 = 0;
174 let mut all_results: Vec<(SearchResult, f64)> = Vec::new();
175
176 // Check if we need LIKE fallback for short CJK queries
177 let use_like_fallback = needs_like_fallback(query);
178
179 if use_like_fallback {
180 // LIKE fallback path for short CJK queries (1-2 chars)
181 let like_pattern = format!("%{}%", query);
182
183 // Search tasks if enabled
184 if include_tasks {
185 // Get total count
186 let count_result = sqlx::query_scalar::<_, i64>(
187 "SELECT COUNT(*) FROM tasks WHERE name LIKE ? OR spec LIKE ?",
188 )
189 .bind(&like_pattern)
190 .bind(&like_pattern)
191 .fetch_one(self.pool)
192 .await?;
193 total_tasks = count_result;
194
195 // Build ORDER BY clause
196 let order_by = if sort_by_priority {
197 "ORDER BY COALESCE(priority, 0) ASC, id ASC"
198 } else {
199 "ORDER BY id ASC"
200 };
201
202 // Query tasks with pagination
203 let task_query = format!(
204 r#"
205 SELECT
206 id,
207 parent_id,
208 name,
209 spec,
210 status,
211 complexity,
212 priority,
213 first_todo_at,
214 first_doing_at,
215 first_done_at,
216 active_form,
217 owner
218 FROM tasks
219 WHERE name LIKE ? OR spec LIKE ?
220 {}
221 LIMIT ? OFFSET ?
222 "#,
223 order_by
224 );
225
226 let rows = sqlx::query(&task_query)
227 .bind(&like_pattern)
228 .bind(&like_pattern)
229 .bind(limit)
230 .bind(offset)
231 .fetch_all(self.pool)
232 .await?;
233
234 for row in rows {
235 let task = Task {
236 id: row.get("id"),
237 parent_id: row.get("parent_id"),
238 name: row.get("name"),
239 spec: row.get("spec"),
240 status: row.get("status"),
241 complexity: row.get("complexity"),
242 priority: row.get("priority"),
243 first_todo_at: row.get("first_todo_at"),
244 first_doing_at: row.get("first_doing_at"),
245 first_done_at: row.get("first_done_at"),
246 active_form: row.get("active_form"),
247 owner: row.get("owner"),
248 };
249
250 // Determine match field and create snippet
251 let (match_field, match_snippet) = if task.name.contains(query) {
252 ("name".to_string(), task.name.clone())
253 } else if let Some(ref spec) = task.spec {
254 if spec.contains(query) {
255 ("spec".to_string(), spec.clone())
256 } else {
257 ("name".to_string(), task.name.clone())
258 }
259 } else {
260 ("name".to_string(), task.name.clone())
261 };
262
263 all_results.push((
264 SearchResult::Task {
265 task,
266 match_snippet,
267 match_field,
268 },
269 1.0, // Constant rank for LIKE results
270 ));
271 }
272 }
273
274 // Search events if enabled
275 if include_events {
276 // Get total count
277 let count_result = sqlx::query_scalar::<_, i64>(
278 "SELECT COUNT(*) FROM events WHERE discussion_data LIKE ?",
279 )
280 .bind(&like_pattern)
281 .fetch_one(self.pool)
282 .await?;
283 total_events = count_result;
284
285 // Query events with pagination
286 let rows = sqlx::query(
287 r#"
288 SELECT
289 id,
290 task_id,
291 timestamp,
292 log_type,
293 discussion_data
294 FROM events
295 WHERE discussion_data LIKE ?
296 ORDER BY id ASC
297 LIMIT ? OFFSET ?
298 "#,
299 )
300 .bind(&like_pattern)
301 .bind(limit)
302 .bind(offset)
303 .fetch_all(self.pool)
304 .await?;
305
306 let task_mgr = TaskManager::new(self.pool);
307 for row in rows {
308 let event = Event {
309 id: row.get("id"),
310 task_id: row.get("task_id"),
311 timestamp: row.get("timestamp"),
312 log_type: row.get("log_type"),
313 discussion_data: row.get("discussion_data"),
314 };
315
316 // Create match snippet
317 let match_snippet = event.discussion_data.clone();
318
319 // Get task ancestry chain for this event
320 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
321
322 all_results.push((
323 SearchResult::Event {
324 event,
325 task_chain,
326 match_snippet,
327 },
328 1.0, // Constant rank for LIKE results
329 ));
330 }
331 }
332 } else {
333 // FTS5 path for longer queries (3+ chars)
334 // Search tasks if enabled
335 if include_tasks {
336 // Get total count
337 let count_result = sqlx::query_scalar::<_, i64>(
338 "SELECT COUNT(*) FROM tasks_fts WHERE tasks_fts MATCH ?",
339 )
340 .bind(&escaped_query)
341 .fetch_one(self.pool)
342 .await?;
343 total_tasks = count_result;
344
345 // Build ORDER BY clause
346 let order_by = if sort_by_priority {
347 "ORDER BY rank ASC, COALESCE(t.priority, 0) ASC, t.id ASC"
348 } else {
349 "ORDER BY rank ASC, t.id ASC"
350 };
351
352 // Query tasks with pagination
353 let task_query = format!(
354 r#"
355 SELECT
356 t.id,
357 t.parent_id,
358 t.name,
359 t.spec,
360 t.status,
361 t.complexity,
362 t.priority,
363 t.first_todo_at,
364 t.first_doing_at,
365 t.first_done_at,
366 t.active_form,
367 t.owner,
368 COALESCE(
369 snippet(tasks_fts, 1, '**', '**', '...', 15),
370 snippet(tasks_fts, 0, '**', '**', '...', 15)
371 ) as match_snippet,
372 rank
373 FROM tasks_fts
374 INNER JOIN tasks t ON tasks_fts.rowid = t.id
375 WHERE tasks_fts MATCH ?
376 {}
377 LIMIT ? OFFSET ?
378 "#,
379 order_by
380 );
381
382 let rows = sqlx::query(&task_query)
383 .bind(&escaped_query)
384 .bind(limit)
385 .bind(offset)
386 .fetch_all(self.pool)
387 .await?;
388
389 for row in rows {
390 let task = Task {
391 id: row.get("id"),
392 parent_id: row.get("parent_id"),
393 name: row.get("name"),
394 spec: row.get("spec"),
395 status: row.get("status"),
396 complexity: row.get("complexity"),
397 priority: row.get("priority"),
398 first_todo_at: row.get("first_todo_at"),
399 first_doing_at: row.get("first_doing_at"),
400 first_done_at: row.get("first_done_at"),
401 active_form: row.get("active_form"),
402 owner: row.get("owner"),
403 };
404 let match_snippet: String = row.get("match_snippet");
405 let rank: f64 = row.get("rank");
406
407 // Determine match field based on snippet content
408 let match_field = if task
409 .spec
410 .as_ref()
411 .map(|s| match_snippet.to_lowercase().contains(&s.to_lowercase()))
412 .unwrap_or(false)
413 {
414 "spec".to_string()
415 } else {
416 "name".to_string()
417 };
418
419 all_results.push((
420 SearchResult::Task {
421 task,
422 match_snippet,
423 match_field,
424 },
425 rank,
426 ));
427 }
428 }
429
430 // Search events if enabled
431 if include_events {
432 // Get total count
433 let count_result = sqlx::query_scalar::<_, i64>(
434 "SELECT COUNT(*) FROM events_fts WHERE events_fts MATCH ?",
435 )
436 .bind(&escaped_query)
437 .fetch_one(self.pool)
438 .await?;
439 total_events = count_result;
440
441 // Query events with pagination
442 let rows = sqlx::query(
443 r#"
444 SELECT
445 e.id,
446 e.task_id,
447 e.timestamp,
448 e.log_type,
449 e.discussion_data,
450 snippet(events_fts, 0, '**', '**', '...', 15) as match_snippet,
451 rank
452 FROM events_fts
453 INNER JOIN events e ON events_fts.rowid = e.id
454 WHERE events_fts MATCH ?
455 ORDER BY rank ASC, e.id ASC
456 LIMIT ? OFFSET ?
457 "#,
458 )
459 .bind(&escaped_query)
460 .bind(limit)
461 .bind(offset)
462 .fetch_all(self.pool)
463 .await?;
464
465 let task_mgr = TaskManager::new(self.pool);
466 for row in rows {
467 let event = Event {
468 id: row.get("id"),
469 task_id: row.get("task_id"),
470 timestamp: row.get("timestamp"),
471 log_type: row.get("log_type"),
472 discussion_data: row.get("discussion_data"),
473 };
474 let match_snippet: String = row.get("match_snippet");
475 let rank: f64 = row.get("rank");
476
477 // Get task ancestry chain for this event
478 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
479
480 all_results.push((
481 SearchResult::Event {
482 event,
483 task_chain,
484 match_snippet,
485 },
486 rank,
487 ));
488 }
489 }
490 } // End of else block (FTS5 path)
491
492 // Sort all results by rank (relevance)
493 all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
494
495 // Extract results without rank
496 let results: Vec<SearchResult> =
497 all_results.into_iter().map(|(result, _)| result).collect();
498
499 // Calculate has_more
500 let total_count = total_tasks + total_events;
501 let has_more = offset + (results.len() as i64) < total_count;
502
503 Ok(PaginatedSearchResults {
504 results,
505 total_tasks,
506 total_events,
507 has_more,
508 limit,
509 offset,
510 })
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `is_cjk_char` classifies every char in `chars` as `expected`.
    fn assert_cjk(chars: &[char], expected: bool) {
        for &c in chars {
            assert_eq!(
                is_cjk_char(c),
                expected,
                "is_cjk_char({c:?}) (U+{:04X})",
                c as u32
            );
        }
    }

    /// Assert that `needs_like_fallback` returns `expected` for every query.
    fn assert_fallback(queries: &[&str], expected: bool) {
        for q in queries {
            assert_eq!(needs_like_fallback(q), expected, "needs_like_fallback({q:?})");
        }
    }

    /// Assert that `escape_fts5` maps each `(input, expected)` pair.
    fn assert_escapes(cases: &[(&str, &str)]) {
        for (input, expected) in cases {
            assert_eq!(escape_fts5(input), *expected, "escape_fts5({input:?})");
        }
    }

    #[test]
    fn test_is_cjk_char() {
        // Chinese ideographs, Hiragana, Katakana, Hangul syllables
        assert_cjk(&['中', '文', '认', '证', 'あ', 'い', 'ア', 'イ', '가', '나'], true);
        // ASCII letters, digits, whitespace and punctuation
        assert_cjk(&['a', 'A', '1', ' ', '.'], false);
    }

    #[test]
    fn test_needs_like_fallback() {
        // One or two CJK characters: too short for the trigram tokenizer.
        assert_fallback(&["中", "认", "あ", "가", "中文", "认证", "用户"], true);
        // Three or more characters (CJK, ASCII or mixed) and short ASCII use FTS5.
        assert_fallback(
            &["用户认", "用户认证", "JWT", "auth", "a", "JWT认证", "API接口"],
            false,
        );
    }

    #[test]
    fn test_needs_like_fallback_mixed_cjk_ascii() {
        // Two characters that are not *all* CJK never take the fallback.
        assert_fallback(&["中a", "a中", "認1"], false);
        // Longer mixed queries go through FTS5.
        assert_fallback(&["中文API", "JWT认证系统", "API中文文档"], false);
    }

    #[test]
    fn test_needs_like_fallback_edge_cases() {
        // Empty and whitespace-only input
        assert_fallback(&["", " ", "  ", "\t", "\n"], false);
        // Single non-CJK character
        assert_fallback(&["1", "@"], false);
        // Two non-CJK characters
        assert_fallback(&["ab", "12"], false);
    }

    #[test]
    fn test_is_cjk_char_extension_ranges() {
        // Boundaries of Extension A, the basic Unified block and Extension B
        assert_cjk(
            &['\u{3400}', '\u{4DBF}', '\u{4E00}', '\u{9FFF}', '\u{20000}', '\u{2A6DF}'],
            true,
        );
        // Just outside: before Ext A, Yijing hexagrams between the blocks, after Unified
        assert_cjk(&['\u{33FF}', '\u{4DC0}', '\u{4DFF}', '\u{A000}'], false);
    }

    #[test]
    fn test_is_cjk_char_japanese() {
        // Hiragana U+3040..U+309F and Katakana U+30A0..U+30FF, including both ends
        assert_cjk(&['\u{3040}', 'ひ', '\u{309F}', '\u{30A0}', 'カ', '\u{30FF}'], true);
        // Before Hiragana; U+3100 starts Bopomofo, which is not treated as CJK
        assert_cjk(&['\u{303F}', '\u{3100}'], false);
    }

    #[test]
    fn test_is_cjk_char_korean() {
        // Hangul Syllables U+AC00..U+D7AF, including both ends
        assert_cjk(&['\u{AC00}', '한', '\u{D7AF}'], true);
        // Just outside the syllable block
        assert_cjk(&['\u{ABFF}', '\u{D7B0}'], false);
    }

    #[test]
    fn test_escape_fts5_basic() {
        // No double quotes: just wrapped; single quotes are left alone.
        assert_escapes(&[
            ("hello world", "\"hello world\""),
            ("JWT authentication", "\"JWT authentication\""),
            ("user's task", "\"user's task\""),
        ]);
    }

    #[test]
    fn test_escape_fts5_double_quotes() {
        // Every embedded double quote is doubled, wherever it appears.
        assert_escapes(&[
            ("\"admin\"", "\"\"\"admin\"\"\""),
            ("\"user\" and \"admin\"", "\"\"\"user\"\" and \"\"admin\"\"\""),
            ("start \"middle\" end", "\"start \"\"middle\"\" end\""),
            ("\"start", "\"\"\"start\""),
            ("end\"", "\"end\"\"\""),
        ]);
    }

    #[test]
    fn test_escape_fts5_complex_queries() {
        assert_escapes(&[
            // Quoted phrase inside surrounding text
            (
                "search for \"exact phrase\" here",
                "\"search for \"\"exact phrase\"\" here\"",
            ),
            // Empty input is still wrapped
            ("", "\"\""),
            // Input made only of quotes
            ("\"", "\"\"\"\""),
            ("\"\"", "\"\"\"\"\"\""),
            ("\"\"\"", "\"\"\"\"\"\"\"\""),
        ]);
    }

    #[test]
    fn test_escape_fts5_special_chars() {
        // FTS5 operators end up inside a literal phrase.
        assert_escapes(&[
            ("#123", "\"#123\""),
            ("task*", "\"task*\""),
            ("+keyword", "\"+keyword\""),
            ("-exclude", "\"-exclude\""),
            ("a AND b", "\"a AND b\""),
            ("a OR b", "\"a OR b\""),
        ]);
    }

    #[test]
    fn test_escape_fts5_cjk_with_quotes() {
        // Quote escaping is independent of the surrounding script.
        assert_escapes(&[
            ("用户\"管理员\"权限", "\"用户\"\"管理员\"\"权限\""),
            ("\"認証\"システム", "\"\"\"認証\"\"システム\""),
            ("API\"接口\"documentation", "\"API\"\"接口\"\"documentation\""),
        ]);
    }

    #[test]
    fn test_needs_like_fallback_unicode_normalization() {
        // Traditional and Japanese-form ideographs live in the same Unified block,
        // so they behave exactly like simplified ones.
        assert_fallback(&["中", "日", "中日", "認證"], true);

        // Decomposed (NFD) Hangul is spelled with conjoining Jamo, which is not
        // classified as CJK. Only precomposed syllables trigger the fallback, so
        // callers should NFC-normalize input first.
        assert_fallback(&["\u{1100}\u{1161}"], false);
        assert_fallback(&["\u{AC00}"], true);
    }
}