1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
12use super::rbac::{AccessControl, Action, Permission, Resource};
13use super::workspace::{MemberId, SessionVisibility, TeamId};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct TeamSearchQuery {
22 pub text: String,
24 pub team_id: TeamId,
26 pub providers: Option<Vec<String>>,
28 pub members: Option<Vec<MemberId>>,
30 pub from_date: Option<DateTime<Utc>>,
32 pub to_date: Option<DateTime<Utc>>,
34 pub tags: Option<Vec<String>>,
36 pub include_archived: bool,
38 pub search_content: bool,
40 pub limit: usize,
42 pub offset: usize,
44 pub sort_by: SearchSortField,
46 pub sort_order: SortOrder,
48}
49
50impl Default for TeamSearchQuery {
51 fn default() -> Self {
52 Self {
53 text: String::new(),
54 team_id: Uuid::nil(),
55 providers: None,
56 members: None,
57 from_date: None,
58 to_date: None,
59 tags: None,
60 include_archived: false,
61 search_content: true,
62 limit: 20,
63 offset: 0,
64 sort_by: SearchSortField::Relevance,
65 sort_order: SortOrder::Descending,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72#[serde(rename_all = "snake_case")]
73pub enum SearchSortField {
74 Relevance,
75 CreatedAt,
76 UpdatedAt,
77 MessageCount,
78 Title,
79}
80
81#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
83#[serde(rename_all = "snake_case")]
84pub enum SortOrder {
85 Ascending,
86 Descending,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct TeamSearchResult {
92 pub sessions: Vec<TeamSessionResult>,
94 pub total_count: usize,
96 pub took_ms: u64,
98 pub facets: SearchFacets,
100 pub suggestions: Vec<String>,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct TeamSessionResult {
107 pub session_id: String,
109 pub title: String,
111 pub owner_id: MemberId,
113 pub owner_name: String,
115 pub provider: String,
117 pub model: Option<String>,
119 pub message_count: u32,
121 pub created_at: DateTime<Utc>,
123 pub updated_at: DateTime<Utc>,
125 pub tags: Vec<String>,
127 pub archived: bool,
129 pub score: f32,
131 pub highlights: Vec<SearchHighlight>,
133 pub visibility: SessionVisibility,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct SearchHighlight {
140 pub field: String,
142 pub snippet: String,
144 pub message_index: Option<usize>,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct SearchFacets {
151 pub providers: HashMap<String, usize>,
153 pub members: HashMap<String, MemberFacet>,
155 pub tags: HashMap<String, usize>,
157 pub date_histogram: Vec<DateBucket>,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct MemberFacet {
164 pub member_id: MemberId,
165 pub display_name: String,
166 pub count: usize,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct DateBucket {
172 pub date: DateTime<Utc>,
173 pub count: usize,
174}
175
176pub struct TeamSearchEngine {
182 access_control: AccessControl,
184}
185
186impl TeamSearchEngine {
187 pub fn new(access_control: AccessControl) -> Self {
189 Self { access_control }
190 }
191
192 pub async fn search(
194 &self,
195 query: TeamSearchQuery,
196 searcher_id: MemberId,
197 sessions: &[SessionData],
198 ) -> TeamSearchResult {
199 let start = std::time::Instant::now();
200
201 let accessible_sessions: Vec<&SessionData> = sessions
203 .iter()
204 .filter(|s| self.can_view_session(searcher_id, query.team_id, s))
205 .collect();
206
207 let mut matched_sessions: Vec<TeamSessionResult> = accessible_sessions
209 .iter()
210 .filter_map(|s| self.match_session(s, &query))
211 .collect();
212
213 let total_count = matched_sessions.len();
215
216 self.sort_results(&mut matched_sessions, query.sort_by, query.sort_order);
218
219 let paginated: Vec<TeamSessionResult> = matched_sessions
221 .into_iter()
222 .skip(query.offset)
223 .take(query.limit)
224 .collect();
225
226 let facets = self.calculate_facets(&accessible_sessions, &query);
228
229 let suggestions = self.generate_suggestions(&query.text);
231
232 TeamSearchResult {
233 sessions: paginated,
234 total_count,
235 took_ms: start.elapsed().as_millis() as u64,
236 facets,
237 suggestions,
238 }
239 }
240
241 fn can_view_session(&self, user_id: MemberId, team_id: TeamId, session: &SessionData) -> bool {
243 if session.owner_id == user_id {
245 return true;
246 }
247
248 match session.visibility {
250 SessionVisibility::Private => false,
251 SessionVisibility::TeamOnly | SessionVisibility::Public => {
252 let resource = Resource::Session {
254 team_id,
255 session_id: session.session_id.clone(),
256 owner_id: session.owner_id,
257 };
258 matches!(
259 self.access_control.check(user_id, &resource, Action::View),
260 super::rbac::AccessDecision::Allow
261 )
262 }
263 }
264 }
265
266 fn match_session(&self, session: &SessionData, query: &TeamSearchQuery) -> Option<TeamSessionResult> {
268 if let Some(providers) = &query.providers {
270 if !providers.contains(&session.provider) {
271 return None;
272 }
273 }
274
275 if let Some(members) = &query.members {
277 if !members.contains(&session.owner_id) {
278 return None;
279 }
280 }
281
282 if let Some(from) = query.from_date {
284 if session.created_at < from {
285 return None;
286 }
287 }
288 if let Some(to) = query.to_date {
289 if session.created_at > to {
290 return None;
291 }
292 }
293
294 if let Some(tags) = &query.tags {
296 if !tags.iter().any(|t| session.tags.contains(t)) {
297 return None;
298 }
299 }
300
301 if !query.include_archived && session.archived {
303 return None;
304 }
305
306 let (score, highlights) = self.calculate_relevance(session, &query.text, query.search_content);
308
309 if !query.text.is_empty() && score < 0.1 {
311 return None;
312 }
313
314 Some(TeamSessionResult {
315 session_id: session.session_id.clone(),
316 title: session.title.clone(),
317 owner_id: session.owner_id,
318 owner_name: session.owner_name.clone(),
319 provider: session.provider.clone(),
320 model: session.model.clone(),
321 message_count: session.message_count,
322 created_at: session.created_at,
323 updated_at: session.updated_at,
324 tags: session.tags.clone(),
325 archived: session.archived,
326 score,
327 highlights,
328 visibility: session.visibility,
329 })
330 }
331
332 fn calculate_relevance(
334 &self,
335 session: &SessionData,
336 query_text: &str,
337 search_content: bool,
338 ) -> (f32, Vec<SearchHighlight>) {
339 if query_text.is_empty() {
340 return (1.0, vec![]);
341 }
342
343 let query_lower = query_text.to_lowercase();
344 let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
345 let mut score = 0.0;
346 let mut highlights = vec![];
347
348 let title_lower = session.title.to_lowercase();
350 for term in &query_terms {
351 if title_lower.contains(term) {
352 score += 3.0;
353 highlights.push(SearchHighlight {
354 field: "title".to_string(),
355 snippet: self.highlight_text(&session.title, term),
356 message_index: None,
357 });
358 }
359 }
360
361 for tag in &session.tags {
363 let tag_lower = tag.to_lowercase();
364 for term in &query_terms {
365 if tag_lower.contains(term) {
366 score += 2.0;
367 highlights.push(SearchHighlight {
368 field: "tags".to_string(),
369 snippet: tag.clone(),
370 message_index: None,
371 });
372 }
373 }
374 }
375
376 if query_terms.iter().any(|t| session.provider.to_lowercase().contains(t)) {
378 score += 1.0;
379 }
380
381 if search_content {
383 for (idx, message) in session.messages.iter().enumerate() {
384 let content_lower = message.content.to_lowercase();
385 for term in &query_terms {
386 if content_lower.contains(term) {
387 score += 0.5;
388 if highlights.len() < 5 {
390 highlights.push(SearchHighlight {
391 field: "content".to_string(),
392 snippet: self.extract_snippet(&message.content, term),
393 message_index: Some(idx),
394 });
395 }
396 }
397 }
398 }
399 }
400
401 let max_possible = (query_terms.len() as f32) * 5.0;
403 let normalized_score = (score / max_possible).min(1.0);
404
405 (normalized_score, highlights)
406 }
407
408 fn highlight_text(&self, text: &str, term: &str) -> String {
410 let lower = text.to_lowercase();
411 if let Some(pos) = lower.find(term) {
412 let before = &text[..pos];
413 let matched = &text[pos..pos + term.len()];
414 let after = &text[pos + term.len()..];
415 format!("{}**{}**{}", before, matched, after)
416 } else {
417 text.to_string()
418 }
419 }
420
421 fn extract_snippet(&self, content: &str, term: &str) -> String {
423 let lower = content.to_lowercase();
424 if let Some(pos) = lower.find(term) {
425 let start = pos.saturating_sub(50);
426 let end = (pos + term.len() + 50).min(content.len());
427
428 let mut snippet = String::new();
429 if start > 0 {
430 snippet.push_str("...");
431 }
432 snippet.push_str(&content[start..end]);
433 if end < content.len() {
434 snippet.push_str("...");
435 }
436 snippet
437 } else {
438 content.chars().take(100).collect()
439 }
440 }
441
442 fn sort_results(
444 &self,
445 results: &mut [TeamSessionResult],
446 sort_by: SearchSortField,
447 order: SortOrder,
448 ) {
449 results.sort_by(|a, b| {
450 let cmp = match sort_by {
451 SearchSortField::Relevance => a.score.partial_cmp(&b.score).unwrap(),
452 SearchSortField::CreatedAt => a.created_at.cmp(&b.created_at),
453 SearchSortField::UpdatedAt => a.updated_at.cmp(&b.updated_at),
454 SearchSortField::MessageCount => a.message_count.cmp(&b.message_count),
455 SearchSortField::Title => a.title.cmp(&b.title),
456 };
457
458 match order {
459 SortOrder::Ascending => cmp,
460 SortOrder::Descending => cmp.reverse(),
461 }
462 });
463 }
464
465 fn calculate_facets(&self, sessions: &[&SessionData], _query: &TeamSearchQuery) -> SearchFacets {
467 let mut providers: HashMap<String, usize> = HashMap::new();
468 let mut members: HashMap<String, MemberFacet> = HashMap::new();
469 let mut tags: HashMap<String, usize> = HashMap::new();
470 let mut date_counts: HashMap<String, usize> = HashMap::new();
471
472 for session in sessions {
473 *providers.entry(session.provider.clone()).or_insert(0) += 1;
475
476 let member_key = session.owner_id.to_string();
478 members
479 .entry(member_key.clone())
480 .and_modify(|f| f.count += 1)
481 .or_insert(MemberFacet {
482 member_id: session.owner_id,
483 display_name: session.owner_name.clone(),
484 count: 1,
485 });
486
487 for tag in &session.tags {
489 *tags.entry(tag.clone()).or_insert(0) += 1;
490 }
491
492 let month_key = session.created_at.format("%Y-%m").to_string();
494 *date_counts.entry(month_key).or_insert(0) += 1;
495 }
496
497 let mut date_histogram: Vec<DateBucket> = date_counts
499 .into_iter()
500 .filter_map(|(date_str, count)| {
501 let date = chrono::NaiveDate::parse_from_str(&format!("{}-01", date_str), "%Y-%m-%d")
502 .ok()?;
503 Some(DateBucket {
504 date: DateTime::from_naive_utc_and_offset(
505 date.and_hms_opt(0, 0, 0)?,
506 Utc,
507 ),
508 count,
509 })
510 })
511 .collect();
512 date_histogram.sort_by(|a, b| a.date.cmp(&b.date));
513
514 SearchFacets {
515 providers,
516 members,
517 tags,
518 date_histogram,
519 }
520 }
521
522 fn generate_suggestions(&self, query: &str) -> Vec<String> {
524 let mut suggestions = vec![];
526
527 if !query.is_empty() {
528 suggestions.push(format!("{} provider:copilot", query));
530 suggestions.push(format!("{} provider:cursor", query));
531
532 suggestions.push(format!("{} from:last-week", query));
534 }
535
536 suggestions
537 }
538}
539
540#[derive(Debug, Clone, Serialize, Deserialize)]
542pub struct SessionData {
543 pub session_id: String,
544 pub title: String,
545 pub owner_id: MemberId,
546 pub owner_name: String,
547 pub provider: String,
548 pub model: Option<String>,
549 pub message_count: u32,
550 pub created_at: DateTime<Utc>,
551 pub updated_at: DateTime<Utc>,
552 pub tags: Vec<String>,
553 pub archived: bool,
554 pub visibility: SessionVisibility,
555 pub messages: Vec<MessageData>,
556}
557
558#[derive(Debug, Clone, Serialize, Deserialize)]
560pub struct MessageData {
561 pub role: String,
562 pub content: String,
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-archived, team-visible session holding one user message
    /// about Rust and the single tag `rust`.
    fn create_test_session(
        id: &str,
        title: &str,
        owner_id: MemberId,
        provider: &str,
    ) -> SessionData {
        let only_message = MessageData {
            role: String::from("user"),
            content: String::from("Hello, how do I write a Rust function?"),
        };

        SessionData {
            session_id: id.to_owned(),
            title: title.to_owned(),
            owner_id,
            owner_name: String::from("Test User"),
            provider: provider.to_owned(),
            model: Some(String::from("gpt-4")),
            message_count: 10,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            tags: vec![String::from("rust")],
            archived: false,
            visibility: SessionVisibility::TeamOnly,
            messages: vec![only_message],
        }
    }

    /// The searcher owns both sessions, so a query for "rust" must return at
    /// least one of them.
    #[tokio::test]
    async fn test_team_search() {
        let engine = TeamSearchEngine::new(AccessControl::new());
        let owner_id = Uuid::new_v4();

        let fixtures = [
            ("1", "Rust Programming Help", "copilot"),
            ("2", "Python Tutorial", "cursor"),
        ];
        let sessions: Vec<SessionData> = fixtures
            .iter()
            .map(|&(id, title, provider)| create_test_session(id, title, owner_id, provider))
            .collect();

        let query = TeamSearchQuery {
            text: String::from("rust"),
            team_id: Uuid::new_v4(),
            limit: 10,
            ..Default::default()
        };

        let outcome = engine.search(query, owner_id, &sessions).await;
        assert!(!outcome.sessions.is_empty());
    }
}