intent_engine/search.rs
//! Search utilities for intent-engine
//!
//! This module provides:
//! 1. CJK (Chinese, Japanese, Korean) search utilities for detecting when to use
//!    LIKE fallback vs FTS5 trigram search
//! 2. Unified search across tasks and events
//!
//! **Background**: SQLite FTS5 with trigram tokenizer requires at least 3 consecutive
//! characters to match. This is problematic for CJK languages where single-character
//! or two-character searches are common (e.g., "用户", "认证").
//!
//! **Solution**: For short CJK queries, we fall back to LIKE search which supports
//! any length substring matching, albeit slower.
14
/// Check if a character is a CJK character
///
/// Returns `true` for Han ideographs (the unified block, extensions A–I and
/// both compatibility-ideograph blocks), Hiragana, Katakana and precomposed
/// Hangul syllables. Bopomofo, Hangul Jamo and CJK punctuation are not
/// treated as CJK by this function.
pub fn is_cjk_char(c: char) -> bool {
    matches!(
        u32::from(c),
        // Hiragana + Katakana (two adjacent Japanese blocks)
        0x3040..=0x30FF
        // CJK Extension A
        | 0x3400..=0x4DBF
        // CJK Unified Ideographs (most common Chinese characters)
        | 0x4E00..=0x9FFF
        // Hangul Syllables (Korean)
        | 0xAC00..=0xD7AF
        // CJK Compatibility Ideographs
        | 0xF900..=0xFAFF
        // CJK Extension B
        | 0x20000..=0x2A6DF
        // CJK Extensions C, D, E and F (adjacent blocks)
        | 0x2A700..=0x2EBEF
        // CJK Extension I
        | 0x2EBF0..=0x2EE5F
        // CJK Compatibility Ideographs Supplement
        | 0x2F800..=0x2FA1F
        // CJK Extensions G and H (adjacent blocks)
        | 0x30000..=0x323AF
    )
}
37
38/// Determine if a query should use LIKE fallback instead of FTS5 trigram
39///
40/// Returns `true` if:
41/// - Query is a single CJK character, OR
42/// - Query is two CJK characters
43///
44/// Trigram tokenizer requires 3+ characters for matching, so we use LIKE
45/// for shorter CJK queries to ensure they work.
46pub fn needs_like_fallback(query: &str) -> bool {
47 let chars: Vec<char> = query.chars().collect();
48
49 // Single-character CJK
50 if chars.len() == 1 && is_cjk_char(chars[0]) {
51 return true;
52 }
53
54 // Two-character all-CJK
55 // This is optional - could also let trigram handle it, but trigram
56 // needs minimum 3 chars so two-char CJK won't work well
57 if chars.len() == 2 && chars.iter().all(|c| is_cjk_char(*c)) {
58 return true;
59 }
60
61 false
62}
63
/// Turn arbitrary user input into a literal FTS5 phrase
///
/// FTS5 gives meaning to bare keywords and punctuation (`AND`, `OR`, `NOT`,
/// `*`, quoted phrases, ...). Quoting the whole input makes FTS5 read it as
/// one phrase, so characters such as `#`, `*`, `+` or `-` cannot act as
/// query syntax. Two steps are applied:
/// 1. Every `"` inside the input is doubled (`"` -> `""`)
/// 2. The result is surrounded by a single pair of `"`
///
/// # Arguments
/// * `query` - Raw text to search for
///
/// # Returns
/// A new `String` holding the quoted phrase; an empty input yields `""`
///
/// # Example
/// ```ignore
/// use crate::search::escape_fts5;
///
/// assert_eq!(escape_fts5("user \"admin\" role"), "\"user \"\"admin\"\" role\"");
/// assert_eq!(escape_fts5("#123"), "\"#123\"");
/// ```
pub fn escape_fts5(query: &str) -> String {
    // Two bytes for the surrounding quotes; embedded quotes grow it further.
    let mut phrase = String::with_capacity(query.len() + 2);
    phrase.push('"');
    for ch in query.chars() {
        if ch == '"' {
            // Emit the extra quote that escapes the one pushed below.
            phrase.push('"');
        }
        phrase.push(ch);
    }
    phrase.push('"');
    phrase
}
94
95// ============================================================================
96// Unified Search
97// ============================================================================
98
99use crate::db::models::{Event, PaginatedSearchResults, SearchResult, Task};
100use crate::error::Result;
101use crate::tasks::TaskManager;
102use sqlx::{Row, SqlitePool};
103
104pub struct SearchManager<'a> {
105 pool: &'a SqlitePool,
106}
107
108impl<'a> SearchManager<'a> {
109 pub fn new(pool: &'a SqlitePool) -> Self {
110 Self { pool }
111 }
112
113 /// Unified search across tasks and events with pagination support
114 ///
115 /// This is the new unified search method that replaces unified_search().
116 /// Key improvements:
117 /// - Pagination support (limit, offset)
118 /// - Flexible result inclusion (tasks, events)
119 /// - Optional priority-based secondary sorting
120 /// - Returns PaginatedSearchResults with metadata
121 ///
122 /// # Parameters
123 /// - `query`: FTS5 search query string
124 /// - `include_tasks`: Whether to search in tasks (default: true)
125 /// - `include_events`: Whether to search in events (default: true)
126 /// - `limit`: Maximum number of results per source (default: 20)
127 /// - `offset`: Number of results to skip (default: 0)
128 /// - `sort_by_priority`: Enable priority-based secondary sorting (default: false)
129 ///
130 /// # Returns
131 /// PaginatedSearchResults with mixed task and event results, ordered by relevance (FTS5 rank)
132 pub async fn search(
133 &self,
134 query: &str,
135 include_tasks: bool,
136 include_events: bool,
137 limit: Option<i64>,
138 offset: Option<i64>,
139 sort_by_priority: bool,
140 ) -> Result<PaginatedSearchResults> {
141 let limit = limit.unwrap_or(20);
142 let offset = offset.unwrap_or(0);
143
144 // Handle empty or whitespace-only queries
145 if query.trim().is_empty() {
146 return Ok(PaginatedSearchResults {
147 results: Vec::new(),
148 total_tasks: 0,
149 total_events: 0,
150 has_more: false,
151 limit,
152 offset,
153 });
154 }
155
156 // Handle queries with no searchable content (only special characters)
157 let has_searchable = query.chars().any(|c| c.is_alphanumeric() || is_cjk_char(c));
158 if !has_searchable {
159 return Ok(PaginatedSearchResults {
160 results: Vec::new(),
161 total_tasks: 0,
162 total_events: 0,
163 has_more: false,
164 limit,
165 offset,
166 });
167 }
168
169 // Escape FTS5 special characters
170 let escaped_query = escape_fts5(query);
171
172 let mut total_tasks: i64 = 0;
173 let mut total_events: i64 = 0;
174 let mut all_results: Vec<(SearchResult, f64)> = Vec::new();
175
176 // Check if we need LIKE fallback for short CJK queries
177 let use_like_fallback = needs_like_fallback(query);
178
179 if use_like_fallback {
180 // LIKE fallback path for short CJK queries (1-2 chars)
181 let like_pattern = format!("%{}%", query);
182
183 // Search tasks if enabled
184 if include_tasks {
185 // Get total count
186 let count_result = sqlx::query_scalar::<_, i64>(
187 "SELECT COUNT(*) FROM tasks WHERE name LIKE ? OR spec LIKE ?",
188 )
189 .bind(&like_pattern)
190 .bind(&like_pattern)
191 .fetch_one(self.pool)
192 .await?;
193 total_tasks = count_result;
194
195 // Build ORDER BY clause
196 let order_by = if sort_by_priority {
197 "ORDER BY COALESCE(priority, 0) ASC, id ASC"
198 } else {
199 "ORDER BY id ASC"
200 };
201
202 // Query tasks with pagination
203 let task_query = format!(
204 r#"
205 SELECT
206 id,
207 parent_id,
208 name,
209 spec,
210 status,
211 complexity,
212 priority,
213 first_todo_at,
214 first_doing_at,
215 first_done_at,
216 active_form,
217 owner,
218 metadata
219 FROM tasks
220 WHERE name LIKE ? OR spec LIKE ?
221 {}
222 LIMIT ? OFFSET ?
223 "#,
224 order_by
225 );
226
227 let rows = sqlx::query(&task_query)
228 .bind(&like_pattern)
229 .bind(&like_pattern)
230 .bind(limit)
231 .bind(offset)
232 .fetch_all(self.pool)
233 .await?;
234
235 for row in rows {
236 let task = Task {
237 id: row.get("id"),
238 parent_id: row.get("parent_id"),
239 name: row.get("name"),
240 spec: row.get("spec"),
241 status: row.get("status"),
242 complexity: row.get("complexity"),
243 priority: row.get("priority"),
244 first_todo_at: row.get("first_todo_at"),
245 first_doing_at: row.get("first_doing_at"),
246 first_done_at: row.get("first_done_at"),
247 active_form: row.get("active_form"),
248 owner: row.get("owner"),
249 metadata: row.get("metadata"),
250 };
251
252 // Determine match field and create snippet
253 let (match_field, match_snippet) = if task.name.contains(query) {
254 ("name".to_string(), task.name.clone())
255 } else if let Some(ref spec) = task.spec {
256 if spec.contains(query) {
257 ("spec".to_string(), spec.clone())
258 } else {
259 ("name".to_string(), task.name.clone())
260 }
261 } else {
262 ("name".to_string(), task.name.clone())
263 };
264
265 all_results.push((
266 SearchResult::Task {
267 task,
268 match_snippet,
269 match_field,
270 },
271 1.0, // Constant rank for LIKE results
272 ));
273 }
274 }
275
276 // Search events if enabled
277 if include_events {
278 // Get total count
279 let count_result = sqlx::query_scalar::<_, i64>(
280 "SELECT COUNT(*) FROM events WHERE discussion_data LIKE ?",
281 )
282 .bind(&like_pattern)
283 .fetch_one(self.pool)
284 .await?;
285 total_events = count_result;
286
287 // Query events with pagination
288 let rows = sqlx::query(
289 r#"
290 SELECT
291 id,
292 task_id,
293 timestamp,
294 log_type,
295 discussion_data
296 FROM events
297 WHERE discussion_data LIKE ?
298 ORDER BY id ASC
299 LIMIT ? OFFSET ?
300 "#,
301 )
302 .bind(&like_pattern)
303 .bind(limit)
304 .bind(offset)
305 .fetch_all(self.pool)
306 .await?;
307
308 let task_mgr = TaskManager::new(self.pool);
309 for row in rows {
310 let event = Event {
311 id: row.get("id"),
312 task_id: row.get("task_id"),
313 timestamp: row.get("timestamp"),
314 log_type: row.get("log_type"),
315 discussion_data: row.get("discussion_data"),
316 };
317
318 // Create match snippet
319 let match_snippet = event.discussion_data.clone();
320
321 // Get task ancestry chain for this event
322 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
323
324 all_results.push((
325 SearchResult::Event {
326 event,
327 task_chain,
328 match_snippet,
329 },
330 1.0, // Constant rank for LIKE results
331 ));
332 }
333 }
334 } else {
335 // FTS5 path for longer queries (3+ chars)
336 // Search tasks if enabled
337 if include_tasks {
338 // Get total count
339 let count_result = sqlx::query_scalar::<_, i64>(
340 "SELECT COUNT(*) FROM tasks_fts WHERE tasks_fts MATCH ?",
341 )
342 .bind(&escaped_query)
343 .fetch_one(self.pool)
344 .await?;
345 total_tasks = count_result;
346
347 // Build ORDER BY clause
348 let order_by = if sort_by_priority {
349 "ORDER BY rank ASC, COALESCE(t.priority, 0) ASC, t.id ASC"
350 } else {
351 "ORDER BY rank ASC, t.id ASC"
352 };
353
354 // Query tasks with pagination
355 let task_query = format!(
356 r#"
357 SELECT
358 t.id,
359 t.parent_id,
360 t.name,
361 t.spec,
362 t.status,
363 t.complexity,
364 t.priority,
365 t.first_todo_at,
366 t.first_doing_at,
367 t.first_done_at,
368 t.active_form,
369 t.owner,
370 t.metadata,
371 COALESCE(
372 snippet(tasks_fts, 1, '**', '**', '...', 15),
373 snippet(tasks_fts, 0, '**', '**', '...', 15)
374 ) as match_snippet,
375 rank
376 FROM tasks_fts
377 INNER JOIN tasks t ON tasks_fts.rowid = t.id
378 WHERE tasks_fts MATCH ? AND t.deleted_at IS NULL
379 {}
380 LIMIT ? OFFSET ?
381 "#,
382 order_by
383 );
384
385 let rows = sqlx::query(&task_query)
386 .bind(&escaped_query)
387 .bind(limit)
388 .bind(offset)
389 .fetch_all(self.pool)
390 .await?;
391
392 for row in rows {
393 let task = Task {
394 id: row.get("id"),
395 parent_id: row.get("parent_id"),
396 name: row.get("name"),
397 spec: row.get("spec"),
398 status: row.get("status"),
399 complexity: row.get("complexity"),
400 priority: row.get("priority"),
401 first_todo_at: row.get("first_todo_at"),
402 first_doing_at: row.get("first_doing_at"),
403 first_done_at: row.get("first_done_at"),
404 active_form: row.get("active_form"),
405 owner: row.get("owner"),
406 metadata: row.get("metadata"),
407 };
408 let match_snippet: String = row.get("match_snippet");
409 let rank: f64 = row.get("rank");
410
411 // Determine match field based on snippet content
412 let match_field = if task
413 .spec
414 .as_ref()
415 .map(|s| match_snippet.to_lowercase().contains(&s.to_lowercase()))
416 .unwrap_or(false)
417 {
418 "spec".to_string()
419 } else {
420 "name".to_string()
421 };
422
423 all_results.push((
424 SearchResult::Task {
425 task,
426 match_snippet,
427 match_field,
428 },
429 rank,
430 ));
431 }
432 }
433
434 // Search events if enabled
435 if include_events {
436 // Get total count
437 let count_result = sqlx::query_scalar::<_, i64>(
438 "SELECT COUNT(*) FROM events_fts WHERE events_fts MATCH ?",
439 )
440 .bind(&escaped_query)
441 .fetch_one(self.pool)
442 .await?;
443 total_events = count_result;
444
445 // Query events with pagination
446 let rows = sqlx::query(
447 r#"
448 SELECT
449 e.id,
450 e.task_id,
451 e.timestamp,
452 e.log_type,
453 e.discussion_data,
454 snippet(events_fts, 0, '**', '**', '...', 15) as match_snippet,
455 rank
456 FROM events_fts
457 INNER JOIN events e ON events_fts.rowid = e.id
458 WHERE events_fts MATCH ?
459 ORDER BY rank ASC, e.id ASC
460 LIMIT ? OFFSET ?
461 "#,
462 )
463 .bind(&escaped_query)
464 .bind(limit)
465 .bind(offset)
466 .fetch_all(self.pool)
467 .await?;
468
469 let task_mgr = TaskManager::new(self.pool);
470 for row in rows {
471 let event = Event {
472 id: row.get("id"),
473 task_id: row.get("task_id"),
474 timestamp: row.get("timestamp"),
475 log_type: row.get("log_type"),
476 discussion_data: row.get("discussion_data"),
477 };
478 let match_snippet: String = row.get("match_snippet");
479 let rank: f64 = row.get("rank");
480
481 // Get task ancestry chain for this event
482 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
483
484 all_results.push((
485 SearchResult::Event {
486 event,
487 task_chain,
488 match_snippet,
489 },
490 rank,
491 ));
492 }
493 }
494 } // End of else block (FTS5 path)
495
496 // Sort all results by rank (relevance)
497 all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
498
499 // Extract results without rank
500 let results: Vec<SearchResult> =
501 all_results.into_iter().map(|(result, _)| result).collect();
502
503 // Calculate has_more
504 let total_count = total_tasks + total_events;
505 let has_more = offset + (results.len() as i64) < total_count;
506
507 Ok(PaginatedSearchResults {
508 results,
509 total_tasks,
510 total_events,
511 has_more,
512 limit,
513 offset,
514 })
515 }
516}
517
/// Unit tests for the CJK detection helpers and FTS5 escaping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_cjk_char() {
        // Chinese characters
        assert!(is_cjk_char('中'));
        assert!(is_cjk_char('文'));
        assert!(is_cjk_char('认'));
        assert!(is_cjk_char('证'));

        // Japanese Hiragana
        assert!(is_cjk_char('あ'));
        assert!(is_cjk_char('い'));

        // Japanese Katakana
        assert!(is_cjk_char('ア'));
        assert!(is_cjk_char('イ'));

        // Korean Hangul
        assert!(is_cjk_char('가'));
        assert!(is_cjk_char('나'));

        // Non-CJK
        assert!(!is_cjk_char('a'));
        assert!(!is_cjk_char('A'));
        assert!(!is_cjk_char('1'));
        assert!(!is_cjk_char(' '));
        assert!(!is_cjk_char('.'));
    }

    #[test]
    fn test_needs_like_fallback() {
        // Single CJK character - needs fallback
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("认"));
        assert!(needs_like_fallback("あ"));
        assert!(needs_like_fallback("가"));

        // Two CJK characters - needs fallback
        assert!(needs_like_fallback("中文"));
        assert!(needs_like_fallback("认证"));
        assert!(needs_like_fallback("用户"));

        // Three+ CJK characters - can use FTS5
        assert!(!needs_like_fallback("用户认"));
        assert!(!needs_like_fallback("用户认证"));

        // English - can use FTS5
        assert!(!needs_like_fallback("JWT"));
        assert!(!needs_like_fallback("auth"));
        assert!(!needs_like_fallback("a")); // Single ASCII char, not CJK

        // Mixed - can use FTS5
        assert!(!needs_like_fallback("JWT认证"));
        assert!(!needs_like_fallback("API接口"));
    }

    #[test]
    fn test_needs_like_fallback_mixed_cjk_ascii() {
        // Two characters: one CJK + one ASCII - should NOT need fallback
        // because not all chars are CJK
        assert!(!needs_like_fallback("中a"));
        assert!(!needs_like_fallback("a中"));
        assert!(!needs_like_fallback("認1"));

        // Three+ characters with mixed CJK/ASCII - can use FTS5
        assert!(!needs_like_fallback("中文API"));
        assert!(!needs_like_fallback("JWT认证系统"));
        assert!(!needs_like_fallback("API中文文档"));
    }

    #[test]
    fn test_needs_like_fallback_edge_cases() {
        // Empty string - no fallback needed
        assert!(!needs_like_fallback(""));

        // Whitespace only (one, two and three characters) - no fallback
        assert!(!needs_like_fallback(" "));
        assert!(!needs_like_fallback("  "));
        assert!(!needs_like_fallback("   "));

        // Single non-CJK - no fallback
        assert!(!needs_like_fallback("1"));
        assert!(!needs_like_fallback("@"));
        assert!(!needs_like_fallback("\t"));

        // Two non-CJK - no fallback
        assert!(!needs_like_fallback("ab"));
        assert!(!needs_like_fallback("12"));
    }

    #[test]
    fn test_needs_like_fallback_supplementary_plane() {
        // Characters outside the BMP are still one `char` each, so the
        // length check counts them correctly.
        assert!(needs_like_fallback("\u{20000}"));
        assert!(needs_like_fallback("\u{20000}\u{2A700}"));
        assert!(!needs_like_fallback("\u{20000}\u{2A700}\u{2B740}"));
        assert!(!needs_like_fallback("\u{20000}a"));
    }

    #[test]
    fn test_is_cjk_char_extension_ranges() {
        // CJK Extension A (U+3400..U+4DBF)
        assert!(is_cjk_char('\u{3400}')); // First char of Extension A
        assert!(is_cjk_char('\u{4DBF}')); // Last char of Extension A

        // CJK Unified Ideographs (U+4E00..U+9FFF) - common range
        assert!(is_cjk_char('\u{4E00}')); // First common CJK
        assert!(is_cjk_char('\u{9FFF}')); // Last common CJK

        // Characters just outside ranges - should NOT be CJK
        assert!(!is_cjk_char('\u{33FF}')); // Just before Extension A
        assert!(!is_cjk_char('\u{4DC0}')); // Just after Extension A
        assert!(!is_cjk_char('\u{4DFF}')); // Just before Unified Ideographs
        assert!(!is_cjk_char('\u{A000}')); // Just after Unified Ideographs
    }

    #[test]
    fn test_is_cjk_char_supplementary_ranges() {
        // CJK Extension B (U+20000..U+2A6DF)
        assert!(is_cjk_char('\u{20000}')); // First char of Extension B
        assert!(is_cjk_char('\u{2A6DF}')); // Last char of Extension B

        // CJK Extensions C-F (U+2A700..U+2EBEF)
        assert!(is_cjk_char('\u{2A700}')); // First char of Extension C
        assert!(is_cjk_char('\u{2B740}')); // First char of Extension D
        assert!(is_cjk_char('\u{2B820}')); // First char of Extension E
        assert!(is_cjk_char('\u{2CEB0}')); // First char of Extension F
        assert!(is_cjk_char('\u{2EBEF}')); // Last char of Extension F

        // Characters just outside ranges - should NOT be CJK
        assert!(!is_cjk_char('\u{1FFFF}')); // Just before Extension B
        assert!(!is_cjk_char('\u{2A6E0}')); // Gap between Extensions B and C
    }

    #[test]
    fn test_is_cjk_char_japanese() {
        // Hiragana range (U+3040..U+309F)
        assert!(is_cjk_char('\u{3040}')); // First Hiragana
        assert!(is_cjk_char('ひ')); // Middle Hiragana
        assert!(is_cjk_char('\u{309F}')); // Last Hiragana

        // Katakana range (U+30A0..U+30FF)
        assert!(is_cjk_char('\u{30A0}')); // First Katakana
        assert!(is_cjk_char('カ')); // Middle Katakana
        assert!(is_cjk_char('\u{30FF}')); // Last Katakana

        // Just outside Japanese ranges
        assert!(!is_cjk_char('\u{303F}')); // Before Hiragana
        assert!(!is_cjk_char('\u{3100}')); // After Katakana (Bopomofo, not CJK by our definition)
    }

    #[test]
    fn test_is_cjk_char_korean() {
        // Hangul Syllables (U+AC00..U+D7AF)
        assert!(is_cjk_char('\u{AC00}')); // First Hangul syllable (가)
        assert!(is_cjk_char('한')); // Middle Hangul
        assert!(is_cjk_char('\u{D7AF}')); // Last Hangul syllable

        // Just outside Korean range
        assert!(!is_cjk_char('\u{ABFF}')); // Before Hangul
        assert!(!is_cjk_char('\u{D7B0}')); // After Hangul
    }

    #[test]
    fn test_escape_fts5_basic() {
        // No quotes - wrapped in quotes for literal search
        assert_eq!(escape_fts5("hello world"), "\"hello world\"");
        assert_eq!(escape_fts5("JWT authentication"), "\"JWT authentication\"");

        // Single quote (not escaped by this function, only double quotes)
        assert_eq!(escape_fts5("user's task"), "\"user's task\"");
    }

    #[test]
    fn test_escape_fts5_double_quotes() {
        // Single double quote - escaped and wrapped
        assert_eq!(escape_fts5("\"admin\""), "\"\"\"admin\"\"\"");

        // Multiple double quotes
        assert_eq!(
            escape_fts5("\"user\" and \"admin\""),
            "\"\"\"user\"\" and \"\"admin\"\"\""
        );

        // Double quotes at different positions
        assert_eq!(
            escape_fts5("start \"middle\" end"),
            "\"start \"\"middle\"\" end\""
        );
        assert_eq!(escape_fts5("\"start"), "\"\"\"start\"");
        assert_eq!(escape_fts5("end\""), "\"end\"\"\"");
    }

    #[test]
    fn test_escape_fts5_complex_queries() {
        // Mixed quotes and special characters
        assert_eq!(
            escape_fts5("search for \"exact phrase\" here"),
            "\"search for \"\"exact phrase\"\" here\""
        );

        // Empty string - still wrapped
        assert_eq!(escape_fts5(""), "\"\"");

        // Only quotes
        assert_eq!(escape_fts5("\""), "\"\"\"\"");
        assert_eq!(escape_fts5("\"\""), "\"\"\"\"\"\"");
        assert_eq!(escape_fts5("\"\"\""), "\"\"\"\"\"\"\"\"");
    }

    #[test]
    fn test_escape_fts5_special_chars() {
        // FTS5 special characters should be wrapped and treated as literals
        assert_eq!(escape_fts5("#123"), "\"#123\"");
        assert_eq!(escape_fts5("task*"), "\"task*\"");
        assert_eq!(escape_fts5("+keyword"), "\"+keyword\"");
        assert_eq!(escape_fts5("-exclude"), "\"-exclude\"");
        assert_eq!(escape_fts5("a AND b"), "\"a AND b\"");
        assert_eq!(escape_fts5("a OR b"), "\"a OR b\"");
    }

    #[test]
    fn test_escape_fts5_cjk_with_quotes() {
        // CJK text with quotes - escaped and wrapped
        assert_eq!(
            escape_fts5("用户\"管理员\"权限"),
            "\"用户\"\"管理员\"\"权限\""
        );
        assert_eq!(escape_fts5("\"認証\"システム"), "\"\"\"認証\"\"システム\"");

        // Mixed CJK and English with quotes
        assert_eq!(
            escape_fts5("API\"接口\"documentation"),
            "\"API\"\"接口\"\"documentation\""
        );
    }

    #[test]
    fn test_needs_like_fallback_unicode_normalization() {
        // Precomposed CJK characters only: this does not exercise combining
        // sequences, it just pins the behavior for plain ideographs.

        // Standard CJK characters
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("日"));

        // Two CJK characters
        assert!(needs_like_fallback("中日"));
        assert!(needs_like_fallback("認證"));
    }
}