1fn map_error(error: crate::Error) -> crate::wire::ErrorEnvelope {
5 error.into()
6}
7
/// A namespace path, one `String` per level; the empty path is the root.
#[derive(Debug, Clone)]
pub struct NamespaceIdent(pub Vec<String>);

impl NamespaceIdent {
    /// The root namespace (an empty path).
    pub fn root() -> Self {
        Self(Vec::new())
    }

    /// Returns this namespace's path with `table_name` appended as the final
    /// element, leaving `self` unchanged.
    pub fn as_table_id(&self, table_name: &str) -> Vec<String> {
        self.0
            .iter()
            .cloned()
            .chain(std::iter::once(table_name.to_owned()))
            .collect()
    }
}
25
26pub fn resolve_namespace(
30 namespace: Option<&str>,
31) -> Result<NamespaceIdent, crate::wire::ErrorEnvelope> {
32 match namespace {
33 None | Some(crate::wire::DEFAULT_NAMESPACE) => Ok(NamespaceIdent::root()),
34 Some(other) => Err(map_error(crate::Error::namespace_unknown(other))),
35 }
36}
37
38fn map_storage(error: anyhow::Error) -> crate::wire::ErrorEnvelope {
39 if let Some(conflict) = error.downcast_ref::<crate::substrate::ConflictExhausted>() {
42 return map_error(crate::Error::Conflict {
43 attempts: conflict.attempts,
44 });
45 }
46 map_error(crate::Error::Storage(error))
47}
48
49mod ingest_handler {
50 use anyhow::Result;
51 use tokio_stream::StreamExt;
52
53 use crate::{
54 adapter::{Adapter, AdapterYield, SkipOracle, SkipReason},
55 sessions::{IngestEvent, IngestSummary, IngestValidator, OutcomeStatus, RowOutcome, Store},
56 wire::{
57 ErrorBody, ErrorCode, IngestEnvelope, IngestRequest, IngestResponse, IngestResult,
58 IngestStatus, validate_protocol,
59 },
60 };
61
62 use super::{map_error, map_storage};
63
64 pub const MAX_INGEST_EVENTS: usize = 1000;
66
    /// Progress notifications delivered by [`ingest_adapter`] through its
    /// `on_event` callback.
    #[derive(Debug, Clone)]
    pub enum SyncEvent {
        /// Emitted once, before any events are consumed, with the value
        /// returned by the adapter's `discover()`; `None` when `discover()`
        /// failed (the failure is only logged at debug level).
        /// NOTE(review): presumably the number of items the adapter expects
        /// to yield — confirm against the `Adapter` contract.
        Discovered { total: Option<usize> },
        /// Emitted for every skipped item, every adapter error that arrives
        /// with no session in flight, and every ingested session. Ingested
        /// sessions are reported only after the validator flush or finish
        /// that follows their completion.
        SessionDone(SessionOutcome),
    }
83
    /// Final report for one session (or one skipped item), carried by
    /// [`SyncEvent::SessionDone`].
    #[derive(Debug, Clone)]
    pub struct SessionOutcome {
        /// Project of the session. `None` when an adapter error arrived with no
        /// session in flight, or when the adapter's skip report carried none.
        pub project: Option<String>,
        /// Session id. `None` in the same cases as `project`.
        pub session_id: Option<String>,
        /// Number of `IngestEvent::Message` events seen for this session;
        /// always 0 for skipped items.
        pub messages: usize,
        /// How the session ended.
        pub status: SyncStatus,
    }
97
    /// Terminal state of a session or skipped item reported by [`ingest_adapter`].
    #[derive(Debug, Clone)]
    pub enum SyncStatus {
        /// Ingested. No events were dropped, and the flushed outcomes examined
        /// for this session contained no session-level error.
        Ok,
        /// Ingested, but some of the session's events were dropped (adapter
        /// errors or non-session validator errors).
        Partial {
            /// How many events were dropped.
            dropped_events: usize,
            /// Message of the first drop, when one was available.
            first_drop_reason: Option<String>,
        },
        /// Not ingested. Either the adapter reported `SkipReason::Unsupported`,
        /// or an adapter error arrived with no session in flight. `reason` is
        /// the adapter's reason or the error text.
        Skipped {
            reason: String,
        },
        /// The session row itself failed in the validator. `reason` is the
        /// validator's error message, or "session-level rejection" when none
        /// was attached.
        Rejected {
            reason: String,
        },
        /// The adapter reported `SkipReason::Fresh` (presumably already
        /// up to date — confirm against `SkipOracle`); nothing was ingested.
        Fresh,
        /// The adapter reported `SkipReason::Empty`; nothing was ingested.
        Empty,
    }
133
    /// Running tally for the session currently being streamed by
    /// [`ingest_adapter`]. It is replaced when the next `Session` event arrives
    /// and converted into a [`PendingDone`] at that point.
    #[derive(Debug, Default)]
    struct InFlight {
        project: Option<String>,
        session_id: String,
        /// `Message` events seen since the session's `Session` event.
        messages: usize,
        /// Adapter errors plus non-session validator errors seen while this
        /// session was in flight.
        dropped_events: usize,
        /// Message of the first such drop.
        first_drop_reason: Option<String>,
        /// Validator index of this session's `Session` event. It is used to
        /// find the session row's outcome when the completion is reported.
        session_index: usize,
    }
150
    /// A completed session waiting to be reported. Completion means a later
    /// `Session` event arrived or the stream ended. Its `SessionDone` is held
    /// until the next validator flush or finish, so the status can reflect the
    /// session row's outcome. Fields mirror [`InFlight`].
    #[derive(Debug)]
    struct PendingDone {
        project: Option<String>,
        session_id: String,
        messages: usize,
        dropped_events: usize,
        first_drop_reason: Option<String>,
        session_index: usize,
    }
164
    /// [`ingest_adapter`] flushes the validator once it reports at least this
    /// many pending substreams. Completed sessions are reported after each
    /// flush.
    const ADAPTER_FLUSH_BATCH: usize = 100;
172
    /// Streams every event from `adapter` (filtered through `oracle`) into
    /// `store` via an [`IngestValidator`], and reports progress through
    /// `on_event`.
    ///
    /// Event sequence seen by `on_event`:
    /// - one `SyncEvent::Discovered` before the stream is read;
    /// - a `SessionDone` immediately for every adapter skip report and for
    ///   every adapter error that arrives with no session in flight;
    /// - a `SessionDone` for every ingested session, deferred until the
    ///   validator flush or final `finish` that follows the session's
    ///   completion.
    ///
    /// Because of this deferral, `SessionDone` events are not necessarily in
    /// stream order.
    ///
    /// # Errors
    /// Propagates failures of the validator's `push`, `flush`, or `finish`.
    /// Per-item adapter errors are logged at debug level and counted, not
    /// propagated.
    pub async fn ingest_adapter<F>(
        store: &Store,
        adapter: &dyn Adapter,
        oracle: &dyn SkipOracle,
        mut on_event: F,
    ) -> Result<IngestSummary>
    where
        F: FnMut(SyncEvent),
    {
        let mut summary = IngestSummary::default();
        // Snapshot of the global truncation counter; the difference at the end
        // is attributed to this run.
        let truncations_before = crate::adapter::extract::truncated_values_count();
        // Discovery failure is non-fatal: it is logged and reported as `None`.
        let total = adapter
            .discover()
            .await
            .map_err(|error| tracing::debug!(%error, "adapter discover failed"))
            .ok();
        on_event(SyncEvent::Discovered { total });

        let mut events = adapter.events_with(oracle);
        let mut validator = IngestValidator::default();
        // Position of the next event handed to the validator. It advances only
        // for `AdapterYield::Event`, not for skips or adapter errors.
        let mut index = 0usize;
        let mut in_flight: Option<InFlight> = None;
        // Completed sessions whose `SessionDone` waits for the next flush/finish.
        let mut pending_dones: std::collections::VecDeque<PendingDone> =
            std::collections::VecDeque::new();
        // Timing counters for the `pond::perf` log line emitted at the end.
        let mut decode_total = std::time::Duration::ZERO;
        let mut decode_count = 0u64;
        let mut validator_total = std::time::Duration::ZERO;
        let mut validator_count = 0u64;
        let run_started = std::time::Instant::now();

        loop {
            // Time spent waiting on the adapter stream. The count includes the
            // final call that returns `None`.
            let decode_start = std::time::Instant::now();
            let next = events.next().await;
            decode_total += decode_start.elapsed();
            decode_count += 1;
            let event = match next {
                Some(event) => event,
                None => break,
            };
            match event {
                Ok(AdapterYield::Skipped {
                    session_id,
                    project,
                    reason,
                }) => {
                    // Skips are reported immediately; they never reach the validator.
                    let status = match reason {
                        SkipReason::Fresh => {
                            summary.skipped_fresh += 1;
                            SyncStatus::Fresh
                        }
                        SkipReason::Empty => {
                            summary.skipped_empty += 1;
                            SyncStatus::Empty
                        }
                        SkipReason::Unsupported(reason) => {
                            summary.skipped_files += 1;
                            SyncStatus::Skipped { reason }
                        }
                    };
                    on_event(SyncEvent::SessionDone(SessionOutcome {
                        project,
                        session_id,
                        messages: 0,
                        status,
                    }));
                }
                Ok(AdapterYield::Event(event)) => {
                    // A new `Session` event closes the previous session. Its
                    // report is queued until the next flush so the session
                    // row's outcome is known.
                    if matches!(&event, IngestEvent::Session(_))
                        && let Some(prev) = in_flight.take()
                    {
                        pending_dones.push_back(PendingDone {
                            project: prev.project,
                            session_id: prev.session_id,
                            messages: prev.messages,
                            dropped_events: prev.dropped_events,
                            first_drop_reason: prev.first_drop_reason,
                            session_index: prev.session_index,
                        });
                    }
                    let event_index = index;
                    match &event {
                        IngestEvent::Session(session) => {
                            in_flight = Some(InFlight {
                                project: Some((*session.project).clone()),
                                session_id: session.id.clone(),
                                messages: 0,
                                dropped_events: 0,
                                first_drop_reason: None,
                                session_index: event_index,
                            });
                        }
                        IngestEvent::Message(_) => {
                            if let Some(slot) = in_flight.as_mut() {
                                slot.messages += 1;
                            }
                        }
                        IngestEvent::Part(_) => {}
                    }

                    let validator_start = std::time::Instant::now();
                    let push_outcomes = validator.push(store, index, event).await?;
                    validator_total += validator_start.elapsed();
                    validator_count += 1;
                    // Non-session row errors count as drops for the current
                    // session. Session-row errors are meant to surface as
                    // `Rejected` via `drain_pending_dones` instead.
                    // NOTE(review): `drain_pending_dones` only inspects the
                    // flush/finish batch it is handed. A session-kind error
                    // returned here by `push`, or by a flush that runs before
                    // the session is queued in `pending_dones`, is never
                    // consulted. That session would then be reported as
                    // `Ok`/`Partial` rather than `Rejected`. Confirm against
                    // `IngestValidator`.
                    for outcome in &push_outcomes {
                        if matches!(outcome.status, OutcomeStatus::Error)
                            && outcome.kind != "session"
                            && let Some(slot) = in_flight.as_mut()
                        {
                            slot.dropped_events += 1;
                            if slot.first_drop_reason.is_none() {
                                slot.first_drop_reason =
                                    outcome.error.as_ref().map(|err| err.message.clone());
                            }
                        }
                    }
                    summary.add_outcomes(&push_outcomes);
                    index += 1;

                    if validator.pending_substreams() >= ADAPTER_FLUSH_BATCH {
                        let flush_start = std::time::Instant::now();
                        let (flush_outcomes, flush_counts) = validator.flush(store).await?;
                        validator_total += flush_start.elapsed();
                        validator_count += 1;
                        // Flush outcomes contribute only their errors to the
                        // summary; the row counts come from `flush_counts`.
                        summary.add_outcomes_errors_only(&flush_outcomes);
                        summary.add_batch(&flush_counts);
                        drain_pending_dones(&mut pending_dones, &flush_outcomes, &mut on_event);
                    }
                }
                Err(error) => {
                    tracing::debug!(
                        %error,
                        "adapter event error (per-line drop by design)"
                    );
                    // With a session in flight, the error counts as one dropped
                    // event of that session. With none, it is reported on its
                    // own as a skipped item.
                    // NOTE(review): an error that arrives after a session's last
                    // event but before the next `Session` is charged to the
                    // previous session — confirm against the adapter contract.
                    match in_flight.as_mut() {
                        Some(slot) => {
                            slot.dropped_events += 1;
                            if slot.first_drop_reason.is_none() {
                                slot.first_drop_reason = Some(error.to_string());
                            }
                            summary.dropped_events += 1;
                        }
                        None => {
                            summary.skipped_files += 1;
                            on_event(SyncEvent::SessionDone(SessionOutcome {
                                project: None,
                                session_id: None,
                                messages: 0,
                                status: SyncStatus::Skipped {
                                    reason: error.to_string(),
                                },
                            }));
                        }
                    }
                }
            }
        }

        // The stream ended, so the last session (if any) is complete as well.
        if let Some(prev) = in_flight.take() {
            pending_dones.push_back(PendingDone {
                project: prev.project,
                session_id: prev.session_id,
                messages: prev.messages,
                dropped_events: prev.dropped_events,
                first_drop_reason: prev.first_drop_reason,
                session_index: prev.session_index,
            });
        }
        let validator_start = std::time::Instant::now();
        let (final_outcomes, final_counts) = validator.finish(store).await?;
        validator_total += validator_start.elapsed();
        validator_count += 1;
        summary.add_outcomes_errors_only(&final_outcomes);
        summary.add_batch(&final_counts);
        drain_pending_dones(&mut pending_dones, &final_outcomes, &mut on_event);

        summary.truncated_values = crate::adapter::extract::truncated_values_count()
            .saturating_sub(truncations_before) as usize;

        // Wall-clock breakdown: time not spent decoding or in the validator.
        let total = run_started.elapsed();
        let other = total
            .saturating_sub(decode_total)
            .saturating_sub(validator_total);
        tracing::info!(
            target: "pond::perf",
            total_ms = total.as_millis() as u64,
            decode_ms = decode_total.as_millis() as u64,
            validator_ms = validator_total.as_millis() as u64,
            other_ms = other.as_millis() as u64,
            decode_calls = decode_count,
            validator_calls = validator_count,
            rows_inserted = summary.inserted as u64,
            rows_matched = summary.matched as u64,
            dropped_events = summary.dropped_events as u64,
            dropped_sessions = summary.dropped_sessions as u64,
            skipped_files = summary.skipped_files as u64,
            skipped_fresh = summary.skipped_fresh as u64,
            truncated_values = summary.truncated_values as u64,
            "ingest_adapter complete"
        );
        Ok(summary)
    }
426
427 fn drain_pending_dones<F>(
434 queue: &mut std::collections::VecDeque<PendingDone>,
435 outcomes: &[RowOutcome],
436 on_event: &mut F,
437 ) where
438 F: FnMut(SyncEvent),
439 {
440 let mut session_outcome_by_index: std::collections::HashMap<usize, &RowOutcome> =
443 std::collections::HashMap::new();
444 for outcome in outcomes {
445 if outcome.kind == "session" {
446 session_outcome_by_index.insert(outcome.index, outcome);
447 }
448 }
449
450 while let Some(done) = queue.pop_front() {
451 let session_outcome = session_outcome_by_index.get(&done.session_index).copied();
452 let rejection_reason = session_outcome.and_then(|outcome| {
453 if matches!(outcome.status, OutcomeStatus::Error) {
454 Some(
455 outcome
456 .error
457 .as_ref()
458 .map(|err| err.message.clone())
459 .unwrap_or_else(|| "session-level rejection".to_owned()),
460 )
461 } else {
462 None
463 }
464 });
465 let status = if let Some(reason) = rejection_reason {
466 SyncStatus::Rejected { reason }
467 } else if done.dropped_events > 0 {
468 SyncStatus::Partial {
469 dropped_events: done.dropped_events,
470 first_drop_reason: done.first_drop_reason,
471 }
472 } else {
473 SyncStatus::Ok
474 };
475 on_event(SyncEvent::SessionDone(SessionOutcome {
476 project: done.project,
477 session_id: Some(done.session_id),
478 messages: done.messages,
479 status,
480 }));
481 }
482 }
483
484 pub async fn pond_ingest(store: &Store, request: IngestRequest) -> IngestEnvelope {
490 if let Err(envelope) = validate_protocol(request.protocol_version) {
491 return IngestEnvelope::Error(envelope);
492 }
493 if let Err(envelope) = super::resolve_namespace(request.namespace.as_deref()) {
494 return IngestEnvelope::Error(envelope);
495 }
496 if request.events.is_empty() {
497 return IngestEnvelope::Error(map_error(crate::Error::validation_field(
498 "events must be a non-empty array",
499 "events",
500 Some(serde_json::json!([])),
501 Some("non-empty array".to_owned()),
502 )));
503 }
504 if request.events.len() > MAX_INGEST_EVENTS {
505 return IngestEnvelope::Error(map_error(crate::Error::validation_field(
506 format!("ingest batch exceeds the event cap: at most {MAX_INGEST_EVENTS} events"),
507 "events",
508 Some(serde_json::json!(request.events.len())),
509 Some(format!("at most {MAX_INGEST_EVENTS} events")),
510 )));
511 }
512
513 match ingest_events(store, request.events).await {
514 Ok(outcomes) => {
515 let mut accepted = 0;
516 let mut rejected = 0;
517 for outcome in &outcomes {
518 match outcome.status {
519 OutcomeStatus::Inserted | OutcomeStatus::Matched => accepted += 1,
520 OutcomeStatus::Error => rejected += 1,
521 }
522 }
523 let results = outcomes
524 .into_iter()
525 .map(outcome_to_result)
526 .collect::<Vec<_>>();
527 IngestEnvelope::Success(IngestResponse {
528 accepted,
529 rejected,
530 results,
531 })
532 }
533 Err(failure) => IngestEnvelope::Error(map_storage(failure)),
534 }
535 }
536
537 pub async fn ingest_events(store: &Store, events: Vec<IngestEvent>) -> Result<Vec<RowOutcome>> {
543 let mut validator = IngestValidator::default();
544 let mut outcomes = Vec::with_capacity(events.len());
545 for (index, event) in events.into_iter().enumerate() {
546 let mut chunk = validator.push(store, index, event).await?;
547 outcomes.append(&mut chunk);
548 }
549 let (mut tail, _counts) = validator.finish(store).await?;
552 outcomes.append(&mut tail);
553 outcomes.sort_by_key(|outcome| outcome.index);
554 Ok(outcomes)
555 }
556
557 fn outcome_to_result(outcome: RowOutcome) -> IngestResult {
558 let (status, error) = match (outcome.status, outcome.error) {
559 (OutcomeStatus::Inserted, _) => (IngestStatus::Inserted, None),
560 (OutcomeStatus::Matched, _) => (IngestStatus::Matched, None),
561 (OutcomeStatus::Error, error) => {
562 let body = error
563 .map(|err| {
564 let mut details = serde_json::Map::new();
565 if let Some(field) = err.field {
566 details.insert("field".to_owned(), serde_json::json!(field));
567 }
568 if let Some(reason) = err.reason {
569 details.insert("reason".to_owned(), serde_json::json!(reason));
570 }
571 ErrorBody {
572 code: ErrorCode::ValidationFailed,
573 message: err.message,
574 details: serde_json::Value::Object(details),
575 }
576 })
577 .unwrap_or_else(|| ErrorBody {
578 code: ErrorCode::ValidationFailed,
579 message: "ingest failed".to_owned(),
580 details: serde_json::json!({}),
581 });
582 (IngestStatus::Error, Some(body))
583 }
584 };
585 IngestResult {
586 index: outcome.index,
587 kind: outcome.kind.to_owned(),
588 pk: outcome.pk,
589 status,
590 error,
591 }
592 }
593}
594
595pub use crate::sessions::{IngestEvent, IngestSummary, IngestValidator, search_text};
596pub use ingest_handler::{
597 MAX_INGEST_EVENTS, SessionOutcome, SyncEvent, SyncStatus, ingest_adapter, ingest_events,
598 pond_ingest,
599};
600
601mod export_handler {
602 use anyhow::{Context, Result};
613 use tokio::io::{AsyncWrite, AsyncWriteExt};
614
615 use crate::sessions::{IngestEvent, Store};
616
    /// Counts of records written by [`pond_export`], one per emitted event kind.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
    pub struct ExportSummary {
        /// `Session` records written.
        pub sessions: usize,
        /// `Message` records written.
        pub messages: usize,
        /// `Part` records written.
        pub parts: usize,
    }
623
624 pub async fn pond_export<W>(
625 store: &Store,
626 session_filter: Option<&str>,
627 writer: &mut W,
628 ) -> Result<ExportSummary>
629 where
630 W: AsyncWrite + Unpin,
631 {
632 let mut session_ids = match session_filter {
633 Some(id) => vec![id.to_owned()],
634 None => store.session_ids().await?,
635 };
636 session_ids.sort();
637
638 let mut summary = ExportSummary::default();
639 for session_id in session_ids {
640 let Some(stored) = store
641 .get_session(&session_id)
642 .await
643 .with_context(|| format!("export: failed to load session {session_id}"))?
644 else {
645 if session_filter.is_some() {
646 anyhow::bail!("export: session not found: {session_id}");
647 }
648 continue;
649 };
650 write_event(writer, &IngestEvent::Session(stored.session)).await?;
651 summary.sessions += 1;
652 for message_with_parts in stored.messages {
653 write_event(writer, &IngestEvent::Message(message_with_parts.message)).await?;
654 summary.messages += 1;
655 for part in message_with_parts.parts {
656 write_event(writer, &IngestEvent::Part(part)).await?;
657 summary.parts += 1;
658 }
659 }
660 }
661 writer.flush().await.context("export: flush failed")?;
662 Ok(summary)
663 }
664
665 async fn write_event<W>(writer: &mut W, event: &IngestEvent) -> Result<()>
666 where
667 W: AsyncWrite + Unpin,
668 {
669 let line = serde_json::to_string(event).context("export: serialize event")?;
670 writer
671 .write_all(line.as_bytes())
672 .await
673 .context("export: write event")?;
674 writer
675 .write_all(b"\n")
676 .await
677 .context("export: write newline")?;
678 Ok(())
679 }
680}
681
682pub use export_handler::{ExportSummary, pond_export};
683
684mod restore_handler {
685 use anyhow::{Context, Result, bail};
692
693 use crate::sessions::{SessionWithMessages, Store};
694
695 pub async fn restore_lineage(
696 store: &Store,
697 session_id: &str,
698 ) -> Result<Vec<SessionWithMessages>> {
699 let Some(parent) = store.get_session(session_id).await? else {
700 bail!("export: session not found: {session_id}");
701 };
702 let mut sessions = vec![parent];
703 for child in store.child_sessions(session_id).await? {
704 if !store.child_sessions(&child.id).await?.is_empty() {
705 bail!(
706 "adapter-lineage-complete-restore supports one subagent level; session {} has child sessions",
707 child.id
708 );
709 }
710 let child_id = child.id;
711 let stored = store
712 .get_session(&child_id)
713 .await?
714 .with_context(|| format!("export: child session disappeared: {child_id}"))?;
715 sessions.push(stored);
716 }
717 Ok(sessions)
718 }
719}
720
721pub use restore_handler::restore_lineage;
722
723mod get_handler {
724 use crate::{
725 sessions::{GetLookup, MessageViewParams, RetrievedMessage, SessionViewParams, Store},
726 wire::{
727 GetEnvelope, GetRequest, GetResponse, GetResult, GetSession, MessageView, PartSummary,
728 ResponseMode, ResponsePart, validate_protocol,
729 },
730 };
731
732 use super::{map_error, map_storage};
733
734 fn to_message_view(message: RetrievedMessage, verbatim: bool) -> MessageView {
739 if verbatim {
740 return MessageView {
741 id: message.id,
742 role: message.role,
743 timestamp: message.timestamp,
744 text: None,
745 content: None,
746 parts_summary: Vec::new(),
747 parts: Some(
748 message
749 .parts
750 .into_iter()
751 .map(ResponsePart::from_part)
752 .collect(),
753 ),
754 };
755 }
756 let parts_summary = message
757 .parts
758 .iter()
759 .filter_map(|part| PartSummary::for_kind(&part.kind))
760 .collect();
761 MessageView {
762 id: message.id,
763 role: message.role,
764 timestamp: message.timestamp,
765 text: message.text,
766 content: message.content,
767 parts_summary,
768 parts: None,
769 }
770 }
771
    /// Byte budget passed as `budget_bytes` to both `session_view` and
    /// `message_view`. How it is enforced is up to the store.
    const BUDGET_BYTES: usize = 200_000;
777
778 pub async fn pond_get(store: &Store, request: GetRequest) -> GetEnvelope {
779 if let Err(error) = validate_protocol(request.protocol_version) {
780 return GetEnvelope::Error(error);
781 }
782 if let Err(envelope) = super::resolve_namespace(request.namespace.as_deref()) {
783 return GetEnvelope::Error(envelope);
784 }
785
786 let response = match (&request.session_id, &request.message_id) {
787 (Some(session_id), None) => session_result(store, session_id, &request).await,
788 (None, Some(message_id)) => message_result(store, message_id, &request).await,
789 (Some(_), Some(_)) => Err(map_error(crate::Error::validation_field(
790 "session_id and message_id are mutually exclusive",
791 "message_id",
792 request.message_id.clone().map(serde_json::Value::String),
793 Some("omit when session_id is present".to_owned()),
794 ))),
795 (None, None) => Err(map_error(crate::Error::validation(
796 "one of session_id or message_id is required",
797 ))),
798 };
799
800 match response {
801 Ok(response) => GetEnvelope::Success(response),
802 Err(error) => GetEnvelope::Error(error),
803 }
804 }
805
806 fn unknown_after_id(request: &GetRequest, anchor_of: &str) -> crate::wire::ErrorEnvelope {
809 map_error(crate::Error::validation_field(
810 "after_id not found (stale or mistyped continuation anchor)",
811 "after_id",
812 request.after_id.clone().map(serde_json::Value::String),
813 Some(format!("a {anchor_of} from a prior page of this read")),
814 ))
815 }
816
817 async fn session_result(
818 store: &Store,
819 session_id: &str,
820 request: &GetRequest,
821 ) -> Result<GetResponse, crate::wire::ErrorEnvelope> {
822 let params = SessionViewParams {
823 mode: request.response_mode,
824 after_id: request.after_id.as_deref(),
825 limit: request.limit,
826 budget_bytes: BUDGET_BYTES,
827 session_from: request.session_from,
828 };
829 let view = match store
830 .session_view(session_id, params)
831 .await
832 .map_err(map_storage)?
833 {
834 GetLookup::NotFound => {
835 return Err(map_error(crate::Error::not_found(
836 "session",
837 serde_json::json!(session_id),
838 format!("session not found: {session_id}"),
839 )));
840 }
841 GetLookup::UnknownAfterId => return Err(unknown_after_id(request, "message id")),
842 GetLookup::Found(view) => view,
843 };
844 let verbatim = matches!(request.response_mode, ResponseMode::Verbatim);
845 Ok(GetResponse {
846 session: GetSession::from_session(&view.session),
847 result: GetResult::Session {
848 messages: view
849 .messages
850 .into_iter()
851 .map(|message| to_message_view(message, verbatim))
852 .collect(),
853 messages_remaining: view.messages_remaining,
854 },
855 })
856 }
857
858 async fn message_result(
859 store: &Store,
860 message_id: &str,
861 request: &GetRequest,
862 ) -> Result<GetResponse, crate::wire::ErrorEnvelope> {
863 let params = MessageViewParams {
864 context_depth: request.context_depth,
865 mode: request.response_mode,
866 after_id: request.after_id.as_deref(),
867 limit: request.limit,
868 budget_bytes: BUDGET_BYTES,
869 };
870 let view = match store
871 .message_view(message_id, params)
872 .await
873 .map_err(map_storage)?
874 {
875 GetLookup::NotFound => {
876 return Err(map_error(crate::Error::not_found(
877 "message",
878 serde_json::json!(message_id),
879 format!("message not found: {message_id}"),
880 )));
881 }
882 GetLookup::UnknownAfterId => return Err(unknown_after_id(request, "part id")),
883 GetLookup::Found(view) => view,
884 };
885 let target = MessageView {
888 id: view.target.id,
889 role: view.target.role,
890 timestamp: view.target.timestamp,
891 text: None,
892 content: None,
893 parts_summary: Vec::new(),
894 parts: None,
895 };
896 Ok(GetResponse {
897 session: GetSession::from_session(&view.session),
898 result: GetResult::Message {
899 target,
900 target_parts: view
901 .target_parts
902 .into_iter()
903 .map(ResponsePart::from_part)
904 .collect(),
905 target_parts_remaining: view.target_parts_remaining,
906 siblings: view
907 .siblings
908 .into_iter()
909 .map(|sibling| to_message_view(sibling, false))
910 .collect(),
911 },
912 })
913 }
914}
915
916pub use get_handler::pond_get;
917
918mod search_handler {
919 use crate::{
924 Clock, SystemClock,
925 embed::{Embedder, LazyEmbedder, format_query},
926 sessions::{MessageKey, MessageMeta, Store},
927 substrate::{Predicate, ScalarValue},
928 wire::{
929 ErrorEnvelope, PartSummary, ProjectFilter, Role, SearchEnvelope, SearchFilters,
930 SearchRequest, SearchResponse, SearchResult, SearchSession, validate_protocol,
931 },
932 };
933 use chrono::NaiveDate;
934 use std::collections::HashMap;
935
936 use super::{map_error, map_storage};
937
    /// Retrieval strategy for a search.
    ///
    /// Unless the request overrides it, the mode is `Hybrid` when the store
    /// has embeddings and `Fts` otherwise (see `resolve_effective_mode`).
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum SearchMode {
        /// Full-text and vector retrieval fused by `fuse_arms`.
        Hybrid,
        /// Full-text retrieval only.
        Fts,
        /// Vector retrieval only; requires loading the embedder.
        Vector,
    }
948
    /// A validated, normalized search request produced by [`plan_search`].
    #[derive(Debug, Clone, PartialEq)]
    pub struct SearchPlan {
        /// Retrieval mode. Callers overwrite it with the effective mode.
        pub mode: SearchMode,
        /// The request query, trimmed and guaranteed non-empty.
        pub query: String,
        /// Storage predicate built from `filters` by `build_filter`.
        pub filter: Predicate,
        /// The request's filters as received.
        pub filters: SearchFilters,
        /// FTS candidate pool size: `max(limit * 5, 50)` (saturating).
        pub pool: usize,
        /// Vector candidate pool size: `pool * 2` (saturating).
        pub vector_pool: usize,
        /// Requested limit, clamped to `LIMIT_CAP`; always at least 1.
        pub limit: usize,
        /// Candidates whose score is below this are dropped by `run_search`.
        pub min_score: f64,
    }
960
    /// Upper clamp for a request's `limit` (see `plan_search`).
    const LIMIT_CAP: usize = 200;
    /// Matches kept per session in multi-session searches. Single-session
    /// searches use the plan's `limit` instead.
    const MAX_MATCHES_PER_SESSION: usize = 3;
    /// NOTE(review): not referenced in the visible code; presumably a response
    /// size budget used by the session building/paging helpers — confirm.
    const SEARCH_BUDGET_BYTES: usize = 60_000;
    /// Texts with at most this many chars are returned whole by `hit_payload`;
    /// longer texts are reduced to a query snippet.
    const HIT_SNIPPET_CHARS: usize = 600;
    /// Sum of the fusion weights, which is the maximum possible fused score.
    /// NOTE(review): not referenced in the visible code; presumably used to
    /// rescale fused scores — confirm.
    const SCORE_DENOMINATOR: f64 = FTS_FUSION_WEIGHT + VECTOR_FUSION_WEIGHT;

    /// Weight of the full-text arm in hybrid fusion.
    const FTS_FUSION_WEIGHT: f64 = 0.3;
    /// Weight of the vector arm in hybrid fusion.
    const VECTOR_FUSION_WEIGHT: f64 = 1.0;
983
984 pub async fn pond_search(
993 store: &Store,
994 embedder: &LazyEmbedder,
995 request: SearchRequest,
996 search: &crate::config::SearchConfig,
997 ) -> SearchEnvelope {
998 match run_search(store, embedder, request, search, &SystemClock).await {
999 Ok(response) => SearchEnvelope::Success(response),
1000 Err(envelope) => SearchEnvelope::Error(envelope),
1001 }
1002 }
1003
1004 pub async fn explain_search_plan(
1005 store: &Store,
1006 embedder: &LazyEmbedder,
1007 request: SearchRequest,
1008 search: &crate::config::SearchConfig,
1009 ) -> Result<String, ErrorEnvelope> {
1010 let override_mode = request.mode_override.map(wire_mode_to_internal);
1011 let mut plan = plan_search(request, SearchMode::Fts)?;
1012 plan.mode = resolve_effective_mode(store, override_mode).await?;
1013 let mut out = String::new();
1014 if matches!(plan.mode, SearchMode::Fts | SearchMode::Hybrid) {
1015 let fts = store
1016 .explain_fts_plan(&plan.query, plan.pool, &plan.filter)
1017 .await
1018 .map_err(map_storage)?;
1019 out.push_str("fts:\n");
1020 out.push_str(&fts);
1021 out.push('\n');
1022 }
1023 if matches!(plan.mode, SearchMode::Vector | SearchMode::Hybrid) {
1024 let backend = load_embedder(embedder).await?;
1025 let vector = embed_query(backend.as_ref(), &plan.query)?;
1026 let vector_plan = store
1027 .explain_vector_plan(&vector, plan.vector_pool, &plan.filter, Some(search))
1028 .await
1029 .map_err(map_storage)?;
1030 out.push_str("vector:\n");
1031 out.push_str(&vector_plan);
1032 out.push('\n');
1033 }
1034 Ok(out)
1035 }
1036
    /// Executes a search request end to end.
    ///
    /// Steps:
    /// 1. Validate the request, then resolve the effective mode.
    /// 2. Gather candidates for that mode, concurrently with counting the
    ///    searchable rows in scope.
    /// 3. Join candidates to message metadata, drop those without metadata or
    ///    below `min_score`, and sort by score descending, then
    ///    `(session_id, message_id)`.
    /// 4. Build and page the session groups.
    ///
    /// `_clock` is currently unused.
    ///
    /// # Errors
    /// Returns the wire error envelope for validation, embedder, embedding or
    /// storage failures.
    async fn run_search(
        store: &Store,
        embedder: &LazyEmbedder,
        request: SearchRequest,
        search: &crate::config::SearchConfig,
        _clock: &dyn Clock,
    ) -> Result<SearchResponse, ErrorEnvelope> {
        let override_mode = request.mode_override.map(wire_mode_to_internal);
        let mut plan = plan_search(request, SearchMode::Fts)?;

        plan.mode = resolve_effective_mode(store, override_mode).await?;

        // A session filter means hits all come from one session, so hybrid
        // fusion groups per message. Otherwise it groups per session root.
        let single_session = plan.filters.session_id.is_some();
        let fusion_key = if single_session {
            FusionKey::Message
        } else {
            FusionKey::SessionRoot
        };

        let candidates_fut = async {
            match plan.mode {
                SearchMode::Fts => {
                    let hits = store
                        .fts_search(&plan.query, plan.pool, &plan.filter)
                        .await
                        .map_err(map_storage)?;
                    Ok(normalize_fts(hits))
                }
                SearchMode::Hybrid => {
                    // The query is embedded before the two arms run concurrently.
                    let backend = load_embedder(embedder).await?;
                    let vector = embed_query(backend.as_ref(), &plan.query)?;
                    let fts_fut = async {
                        store
                            .fts_search(&plan.query, plan.pool, &plan.filter)
                            .await
                            .map_err(map_storage)
                    };
                    let vector_fut = async {
                        store
                            .vector_search(&vector, plan.vector_pool, &plan.filter, Some(search))
                            .await
                            .map_err(map_storage)
                    };
                    let (fts, vector_raw) = tokio::try_join!(fts_fut, vector_fut)?;
                    let fts_entries: Vec<(MessageKey, f64)> = fts
                        .into_iter()
                        .map(|(key, score)| (key, f64::from(score)))
                        .collect();
                    // Distances become `1 - distance`, so larger is better in
                    // both arms, as `fuse_arms` expects.
                    let vector_entries: Vec<(MessageKey, f64)> = vector_raw
                        .into_iter()
                        .map(|(key, distance)| (key, 1.0 - f64::from(distance)))
                        .collect();
                    let lists = [
                        RankedList {
                            retriever: RetrieverKind::Fts,
                            entries: fts_entries,
                            weight: FTS_FUSION_WEIGHT,
                        },
                        RankedList {
                            retriever: RetrieverKind::Vector,
                            entries: vector_entries,
                            weight: VECTOR_FUSION_WEIGHT,
                        },
                    ];
                    // NOTE(review): fused scores range over
                    // [0, FTS_FUSION_WEIGHT + VECTOR_FUSION_WEIGHT] and are
                    // compared against `plan.min_score` below without
                    // rescaling. Confirm that min_score is meant for that scale.
                    Ok(fuse_arms(&lists, fusion_key)
                        .into_iter()
                        .map(|hit| Candidate {
                            session_id: hit.key.session_id,
                            message_id: hit.key.message_id,
                            base_score: hit.score,
                        })
                        .collect())
                }
                SearchMode::Vector => {
                    let backend = load_embedder(embedder).await?;
                    let vector = embed_query(backend.as_ref(), &plan.query)?;
                    let vector_raw = store
                        .vector_search(&vector, plan.vector_pool, &plan.filter, Some(search))
                        .await
                        .map_err(map_storage)?;
                    Ok(normalize_vector(vector_raw))
                }
            }
        };
        let scope_fut = async {
            store
                .searchable_in_scope(&plan.filter)
                .await
                .map_err(map_storage)
        };
        let (candidates, searchable_in_scope) = tokio::try_join!(candidates_fut, scope_fut)?;

        if candidates.is_empty() {
            return Ok(empty_response(searchable_in_scope));
        }

        // Batch-load metadata for every candidate, then index it by
        // (session_id, message_id) for the join below.
        let keys = candidates
            .iter()
            .map(|candidate| MessageKey {
                session_id: candidate.session_id.clone(),
                message_id: candidate.message_id.clone(),
            })
            .collect::<Vec<_>>();
        let metas = store
            .message_metas_by_keys(&keys)
            .await
            .map_err(map_storage)?;
        let meta_index = metas
            .iter()
            .map(|meta| ((meta.session_id.as_str(), meta.message_id.as_str()), meta))
            .collect::<std::collections::HashMap<_, _>>();

        let mut scored = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            // Candidates without metadata are silently dropped.
            let Some(meta) =
                meta_index.get(&(candidate.session_id.as_str(), candidate.message_id.as_str()))
            else {
                continue;
            };
            let score = candidate.base_score;
            if score < plan.min_score {
                continue;
            }
            scored.push(ScoredHit {
                meta: (*meta).clone(),
                score,
            });
        }
        // Score descending (incomparable scores treated as equal), then ids
        // ascending, for a deterministic order.
        scored.sort_by(|left, right| {
            right
                .score
                .partial_cmp(&left.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| left.meta.session_id.cmp(&right.meta.session_id))
                .then_with(|| left.meta.message_id.cmp(&right.meta.message_id))
        });

        let matched_total = scored.len();
        let matches_cap = if single_session {
            plan.limit
        } else {
            MAX_MATCHES_PER_SESSION
        };
        let sessions = build_sessions(store, &scored, &plan.query, matches_cap).await?;
        page_sessions(sessions, matched_total, searchable_in_scope, &plan)
    }
1213
1214 async fn resolve_effective_mode(
1219 store: &Store,
1220 override_mode: Option<SearchMode>,
1221 ) -> Result<SearchMode, ErrorEnvelope> {
1222 if let Some(mode) = override_mode {
1223 return Ok(mode);
1224 }
1225 let has = store.has_embeddings().await.map_err(map_storage)?;
1226 Ok(if has {
1227 SearchMode::Hybrid
1228 } else {
1229 SearchMode::Fts
1230 })
1231 }
1232
1233 async fn load_embedder(
1237 embedder: &LazyEmbedder,
1238 ) -> Result<std::sync::Arc<dyn Embedder>, ErrorEnvelope> {
1239 embedder.get().await.map_err(|error| {
1240 map_error(crate::Error::internal(format!(
1241 "embedder load failed: {error}"
1242 )))
1243 })
1244 }
1245
1246 pub fn plan_search(
1247 request: SearchRequest,
1248 mode: SearchMode,
1249 ) -> Result<SearchPlan, ErrorEnvelope> {
1250 validate_protocol(request.protocol_version)?;
1251
1252 let _ns = super::resolve_namespace(request.namespace.as_deref())?;
1253
1254 let filters = request.filters;
1255 let query = request.query.trim().to_owned();
1256 if query.is_empty() {
1257 return Err(map_error(crate::Error::validation_field(
1258 "query must be non-empty after trim",
1259 "query",
1260 Some(serde_json::json!(request.query)),
1261 Some("non-empty string after trim".to_owned()),
1262 )));
1263 }
1264 if request.limit == 0 {
1265 return Err(map_error(crate::Error::validation_field(
1266 "limit must be at least 1",
1267 "limit",
1268 Some(serde_json::json!(request.limit)),
1269 Some("integer >= 1".to_owned()),
1270 )));
1271 }
1272 let limit = request.limit.min(LIMIT_CAP);
1273 let min_score = filters.min_score;
1274 let filter = build_filter(&filters)?;
1275 let pool = limit.saturating_mul(5).max(50);
1278 Ok(SearchPlan {
1279 mode,
1280 query,
1281 filter,
1282 filters,
1283 pool,
1284 vector_pool: pool.saturating_mul(2),
1285 limit,
1286 min_score,
1287 })
1288 }
1289
1290 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
1291 pub enum RetrieverKind {
1292 Vector,
1293 Fts,
1294 }
1295
1296 impl RetrieverKind {
1297 fn as_wire(self) -> &'static str {
1298 match self {
1299 Self::Vector => "vector",
1300 Self::Fts => "fts",
1301 }
1302 }
1303 }
1304
    /// One retrieval arm's candidates, as input to [`fuse_arms`].
    pub struct RankedList {
        /// Which arm produced the list.
        pub retriever: RetrieverKind,
        /// `(message, raw score)` pairs, where a larger raw score is better.
        /// Within a group, only the first entry of the list is counted, so this
        /// presumably should be ordered best-first — confirm with the producers.
        pub entries: Vec<(MessageKey, f64)>,
        /// Multiplier applied to this arm's min-max normalized scores.
        pub weight: f64,
    }
1316
1317 fn wire_mode_to_internal(wire: crate::wire::SearchModeWire) -> SearchMode {
1320 match wire {
1321 crate::wire::SearchModeWire::Fts => SearchMode::Fts,
1322 crate::wire::SearchModeWire::Vector => SearchMode::Vector,
1323 crate::wire::SearchModeWire::Hybrid => SearchMode::Hybrid,
1324 }
1325 }
1326
    /// One fused group produced by [`fuse_arms`].
    #[derive(Debug, Clone, PartialEq)]
    pub struct FusedHit {
        /// The group's representative message: the entry with the largest
        /// single weighted contribution.
        pub key: MessageKey,
        /// Sum of the weighted, min-max normalized contributions of every arm
        /// that matched the group.
        pub score: f64,
        /// Wire names of the contributing arms, in list order.
        pub matched_via: Vec<String>,
    }
1334
1335 fn session_root(session_id: &str) -> &str {
1340 match session_id.find('/') {
1341 Some(idx) => &session_id[..idx],
1342 None => session_id,
1343 }
1344 }
1345
    /// How [`fuse_arms`] groups entries before summing contributions.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum FusionKey {
        /// One group per session root (the session id before its first `/`).
        SessionRoot,
        /// One group per `(session_id, message_id)`.
        Message,
    }
1356
    /// Fuses several ranked retrieval lists into one ranking of groups.
    ///
    /// Steps:
    /// 1. Each arm's raw scores are min-max normalized to `[0, 1]` and
    ///    multiplied by the arm's weight.
    /// 2. An arm whose raw scores are all equal (including a single entry)
    ///    contributes 0 per entry, but is still listed in `matched_via`.
    /// 3. Entries are grouped by `key_by`, and only the first entry of a group
    ///    within one arm counts.
    /// 4. A group's score is the sum of its arms' contributions. Its
    ///    representative key is the entry with the largest single contribution;
    ///    ties keep the earlier one.
    ///
    /// The result is sorted by score descending, then key ascending.
    /// NOTE(review): a NaN raw score yields a NaN group score, which the sort
    /// treats as equal to everything.
    pub fn fuse_arms(lists: &[RankedList], key_by: FusionKey) -> Vec<FusedHit> {
        struct Group {
            score: f64,
            matched_via: Vec<String>,
            rep: MessageKey,
            rep_contribution: f64,
        }
        // The NUL separator keeps distinct (session, message) pairs from
        // concatenating to the same string.
        let group_key = |key: &MessageKey| match key_by {
            FusionKey::SessionRoot => session_root(&key.session_id).to_owned(),
            FusionKey::Message => format!("{}\u{0}{}", key.session_id, key.message_id),
        };
        let mut merged: std::collections::HashMap<String, Group> = std::collections::HashMap::new();
        for list in lists {
            if list.entries.is_empty() {
                continue;
            }
            // Per-arm raw score range, used for min-max normalization.
            let mut lo = f64::INFINITY;
            let mut hi = f64::NEG_INFINITY;
            for (_, raw) in &list.entries {
                if *raw < lo {
                    lo = *raw;
                }
                if *raw > hi {
                    hi = *raw;
                }
            }
            let range = hi - lo;
            // Each group is counted at most once per arm: its first entry in
            // list order.
            let mut seen_in_arm: std::collections::HashSet<String> =
                std::collections::HashSet::new();
            for (key, raw) in &list.entries {
                let group = group_key(key);
                if !seen_in_arm.insert(group.clone()) {
                    continue;
                }
                let norm = if range > 0.0 { (raw - lo) / range } else { 0.0 };
                let contribution = list.weight * norm;
                let entry = merged.entry(group).or_insert_with(|| Group {
                    score: 0.0,
                    matched_via: Vec::new(),
                    rep: key.clone(),
                    rep_contribution: f64::NEG_INFINITY,
                });
                entry.score += contribution;
                entry.matched_via.push(list.retriever.as_wire().to_owned());
                // Strictly greater: on ties the earlier representative stays.
                if contribution > entry.rep_contribution {
                    entry.rep = key.clone();
                    entry.rep_contribution = contribution;
                }
            }
        }
        let mut hits = merged
            .into_values()
            .map(|group| FusedHit {
                key: group.rep,
                score: group.score,
                matched_via: group.matched_via,
            })
            .collect::<Vec<_>>();
        // HashMap order is arbitrary; the key tie-break makes the output
        // deterministic.
        hits.sort_by(|left, right| {
            right
                .score
                .partial_cmp(&left.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| left.key.cmp(&right.key))
        });
        hits
    }
1458
    /// Minimum length, in characters, for a query term to count as an
    /// "informative" anchor in `query_snippet`. Shorter terms are only used as
    /// anchors when no term reaches this length.
    const ANCHOR_MIN_TERM_CHARS: usize = 4;
1464
1465 pub fn hit_payload(text: &str, query: &str) -> String {
1470 let chars_len = text.chars().count();
1471 if chars_len <= HIT_SNIPPET_CHARS {
1472 return text.to_owned();
1473 }
1474 query_snippet(text, query)
1475 }
1476
1477 fn query_snippet(text: &str, query: &str) -> String {
1493 let lower_text = text.to_lowercase();
1494 let terms: Vec<String> = query
1495 .split_whitespace()
1496 .filter(|term| !term.is_empty())
1497 .map(str::to_lowercase)
1498 .collect();
1499 let any_informative = terms
1500 .iter()
1501 .any(|term| term.chars().count() >= ANCHOR_MIN_TERM_CHARS);
1502 let hit = terms
1503 .iter()
1504 .filter(|term| !any_informative || term.chars().count() >= ANCHOR_MIN_TERM_CHARS)
1505 .filter_map(|term| lower_text.find(term.as_str()))
1506 .min();
1507 let chars: Vec<char> = text.chars().collect();
1508 let center = hit
1512 .map(|byte| lower_text[..byte].chars().count())
1513 .unwrap_or(0);
1514 let half = HIT_SNIPPET_CHARS / 2;
1515 let start = center.saturating_sub(half);
1516 let end = (start + HIT_SNIPPET_CHARS).min(chars.len());
1517 let start = end.saturating_sub(HIT_SNIPPET_CHARS);
1518 let mut snippet = String::new();
1522 if start > 0 {
1523 snippet.push_str(&format!("[{start} chars before] "));
1524 }
1525 snippet.extend(&chars[start..end]);
1526 if end < chars.len() {
1527 snippet.push_str(&format!(
1528 " [+{} more chars; pond_get for full]",
1529 chars.len() - end
1530 ));
1531 }
1532 snippet
1533 }
1534
    /// A retrieved message, identified by session and message id, carrying the
    /// base score computed by `normalize_fts` or `normalize_vector`.
    struct Candidate {
        session_id: String,
        message_id: String,
        /// FTS: the score divided by the batch's largest score (or 0.0 when no
        /// score is positive). Vector: `1.0 - distance`.
        base_score: f64,
    }
1540
    /// A stored message paired with its search score.
    struct ScoredHit {
        /// Stored message metadata (ids, role, text, project, source agent, ...).
        meta: MessageMeta,
        /// Unscaled score; `to_search_result` rescales it with `normalize_score`.
        score: f64,
    }
1545
1546 impl ScoredHit {
1547 fn to_search_result(
1548 &self,
1549 query: &str,
1550 summaries: &HashMap<(String, String), Vec<PartSummary>>,
1551 ) -> Result<SearchResult, ErrorEnvelope> {
1552 let text = hit_payload(&self.meta.search_text, query);
1553 let role = match self.meta.role.as_str() {
1554 "system" => Role::System,
1555 "user" => Role::User,
1556 "assistant" => Role::Assistant,
1557 "tool" => Role::Tool,
1558 other => {
1559 return Err(map_error(crate::Error::internal(format!(
1560 "stored message has unknown role: {other}"
1561 ))));
1562 }
1563 };
1564 let parts_summary = if matches!(role, Role::User) {
1567 summaries
1568 .get(&(self.meta.session_id.clone(), self.meta.message_id.clone()))
1569 .cloned()
1570 .unwrap_or_default()
1571 } else {
1572 Vec::new()
1573 };
1574 Ok(SearchResult {
1575 message_id: self.meta.message_id.clone(),
1576 role,
1577 timestamp: self.meta.timestamp,
1578 text,
1579 score: normalize_score(self.score),
1580 parts_summary,
1581 })
1582 }
1583 }
1584
1585 fn normalize_score(score: f64) -> f64 {
1586 (score / SCORE_DENOMINATOR).clamp(0.0, 1.0)
1587 }
1588
1589 fn normalize_fts(hits: Vec<(MessageKey, f32)>) -> Vec<Candidate> {
1590 let max = hits.iter().map(|(_, score)| *score).fold(0.0_f32, f32::max);
1591 hits.into_iter()
1592 .map(|(key, score)| Candidate {
1593 session_id: key.session_id,
1594 message_id: key.message_id,
1595 base_score: if max > 0.0 {
1596 f64::from(score / max)
1597 } else {
1598 0.0
1599 },
1600 })
1601 .collect()
1602 }
1603
1604 fn normalize_vector(hits: Vec<(MessageKey, f32)>) -> Vec<Candidate> {
1608 hits.into_iter()
1609 .map(|(key, distance)| Candidate {
1610 session_id: key.session_id,
1611 message_id: key.message_id,
1612 base_score: 1.0 - f64::from(distance),
1613 })
1614 .collect()
1615 }
1616
1617 fn embed_query(embedder: &dyn Embedder, query: &str) -> Result<Vec<f32>, ErrorEnvelope> {
1618 let prompt = format_query(query);
1619 let vectors =
1623 tokio::task::block_in_place(|| embedder.embed(&[prompt])).map_err(|error_value| {
1624 map_error(crate::Error::internal(format!(
1625 "failed to embed query: {error_value}"
1626 )))
1627 })?;
1628 vectors.into_iter().next().ok_or_else(|| {
1629 map_error(crate::Error::internal(
1630 "embedder returned no vector for query",
1631 ))
1632 })
1633 }
1634
1635 async fn build_sessions(
1636 store: &Store,
1637 scored: &[ScoredHit],
1638 query: &str,
1639 matches_cap: usize,
1640 ) -> Result<Vec<SearchSession>, ErrorEnvelope> {
1641 use std::collections::BTreeMap;
1642
1643 struct Acc {
1644 project: String,
1645 source_agent: String,
1646 matched_count: usize,
1647 matches: Vec<SearchResult>,
1648 }
1649 let mut user_ids_by_session: BTreeMap<String, Vec<String>> = BTreeMap::new();
1653 for hit in scored {
1654 if hit.meta.role == "user" {
1655 user_ids_by_session
1656 .entry(hit.meta.session_id.clone())
1657 .or_default()
1658 .push(hit.meta.message_id.clone());
1659 }
1660 }
1661 let mut summaries: HashMap<(String, String), Vec<PartSummary>> = HashMap::new();
1662 for (session_id, message_ids) in &user_ids_by_session {
1663 for (key, parts) in store
1664 .summary_parts_for_messages(session_id, message_ids)
1665 .await
1666 .map_err(map_storage)?
1667 {
1668 summaries.insert(
1669 key,
1670 parts
1671 .iter()
1672 .filter_map(|part| PartSummary::for_kind(&part.kind))
1673 .collect(),
1674 );
1675 }
1676 }
1677
1678 let mut groups: BTreeMap<String, Acc> = BTreeMap::new();
1679 for hit in scored {
1680 let root = session_root(&hit.meta.session_id).to_owned();
1681 let entry = groups.entry(root).or_insert_with(|| Acc {
1682 project: hit.meta.project.clone(),
1683 source_agent: hit.meta.source_agent.clone(),
1684 matched_count: 0,
1685 matches: Vec::new(),
1686 });
1687 entry.matched_count += 1;
1688 entry.matches.push(hit.to_search_result(query, &summaries)?);
1689 }
1690
1691 let session_ids = groups.keys().cloned().collect::<Vec<_>>();
1692 let counts = store
1693 .session_message_counts(&session_ids)
1694 .await
1695 .map_err(map_storage)?;
1696
1697 let mut result = groups
1698 .into_iter()
1699 .map(|(session_id, mut acc)| {
1700 acc.matches.sort_by(|left, right| {
1701 right
1702 .score
1703 .partial_cmp(&left.score)
1704 .unwrap_or(std::cmp::Ordering::Equal)
1705 .then_with(|| left.message_id.cmp(&right.message_id))
1706 });
1707 acc.matches.truncate(matches_cap);
1708 SearchSession {
1709 session_messages_count: counts.get(&session_id).copied().unwrap_or_default(),
1710 session_id,
1711 project: acc.project,
1712 source_agent: acc.source_agent,
1713 matched_message_count: acc.matched_count,
1714 matches: acc.matches,
1715 }
1716 })
1717 .collect::<Vec<_>>();
1718 result.sort_by(|left, right| {
1719 let left_score = left
1720 .matches
1721 .first()
1722 .map(|hit| hit.score)
1723 .unwrap_or_default();
1724 let right_score = right
1725 .matches
1726 .first()
1727 .map(|hit| hit.score)
1728 .unwrap_or_default();
1729 right_score
1730 .partial_cmp(&left_score)
1731 .unwrap_or(std::cmp::Ordering::Equal)
1732 .then_with(|| left.session_id.cmp(&right.session_id))
1733 });
1734 Ok(result)
1735 }
1736
1737 fn page_sessions(
1738 sessions: Vec<SearchSession>,
1739 matched_total: usize,
1740 searchable_in_scope: usize,
1741 plan: &SearchPlan,
1742 ) -> Result<SearchResponse, ErrorEnvelope> {
1743 let mut emitted = Vec::new();
1744 let mut used_bytes = 0usize;
1745 for session in sessions.iter() {
1746 if emitted.len() >= plan.limit {
1747 break;
1748 }
1749 let bytes = serde_json::to_string(session)
1750 .map_err(|error| {
1751 map_error(crate::Error::internal(format!(
1752 "failed to size search response session: {error}"
1753 )))
1754 })?
1755 .len();
1756 if !emitted.is_empty() && used_bytes.saturating_add(bytes) > SEARCH_BUDGET_BYTES {
1757 break;
1758 }
1759 used_bytes = used_bytes.saturating_add(bytes);
1760 emitted.push(session.clone());
1761 }
1762
1763 let has_more = emitted.len() < sessions.len();
1768
1769 Ok(SearchResponse {
1770 sessions: emitted,
1771 matched_total,
1772 searchable_in_scope,
1773 has_more,
1774 })
1775 }
1776
1777 pub fn build_filter(filters: &SearchFilters) -> Result<Predicate, ErrorEnvelope> {
1781 let mut clauses = Vec::new();
1782
1783 match &filters.project {
1784 None => {}
1785 Some(ProjectFilter::Contains(value)) => {
1786 clauses.push(Predicate::LikeContains("project", value.clone()));
1787 }
1788 Some(ProjectFilter::Regex(pattern)) => {
1789 clauses.push(Predicate::Regex("project", pattern.clone()));
1790 }
1791 }
1792
1793 if let Some(session_id) = &filters.session_id {
1794 clauses.push(Predicate::Eq("session_id", session_id.clone().into()));
1795 }
1796 if let Some(source_agent) = &filters.source_agent {
1797 clauses.push(Predicate::Eq("source_agent", source_agent.clone().into()));
1798 }
1799 if let Some(from_date) = &filters.from_date {
1800 clauses.push(Predicate::Gte(
1801 "timestamp",
1802 ScalarValue::Raw(date_bound(from_date, "filters.from_date", false)?),
1803 ));
1804 }
1805 if let Some(to_date) = &filters.to_date {
1806 clauses.push(Predicate::Lte(
1807 "timestamp",
1808 ScalarValue::Raw(date_bound(to_date, "filters.to_date", true)?),
1809 ));
1810 }
1811
1812 if !filters.include_subagents
1816 && filters.session_id.is_none()
1817 && filters.source_agent.is_none()
1818 {
1819 clauses.push(Predicate::Not(Box::new(Predicate::LikeContains(
1820 "source_agent",
1821 "/".to_owned(),
1822 ))));
1823 }
1824
1825 Ok(Predicate::And(clauses))
1826 }
1827
1828 fn date_bound(date: &str, field: &str, end_of_day: bool) -> Result<String, ErrorEnvelope> {
1831 NaiveDate::parse_from_str(date, "%Y-%m-%d").map_err(|_| {
1832 map_error(crate::Error::validation_field(
1833 format!("{field} must be in YYYY-MM-DD format; got {date}"),
1834 field,
1835 Some(serde_json::json!(date)),
1836 Some("YYYY-MM-DD".to_owned()),
1837 ))
1838 })?;
1839 let time = if end_of_day { "23:59:59" } else { "00:00:00" };
1840 Ok(format!("timestamp '{date} {time}'"))
1841 }
1842
1843 fn empty_response(searchable_in_scope: usize) -> SearchResponse {
1844 SearchResponse {
1845 sessions: Vec::new(),
1846 matched_total: 0,
1847 searchable_in_scope,
1848 has_more: false,
1849 }
1850 }
1851
    // Unit tests for the fusion helpers `session_root` and `fuse_arms`.
    #[cfg(test)]
    mod fusion_helpers_tests {
        #![allow(clippy::expect_used, clippy::unwrap_used)]

        use super::*;

        #[test]
        fn session_root_strips_agent_suffix_for_claude_code_subagents() {
            // A plain id is its own root.
            assert_eq!(
                session_root("94a50f23-1234-5678-9abc-def012345678"),
                "94a50f23-1234-5678-9abc-def012345678",
            );
            // A subagent id `<root>/<agent>` maps back to its root.
            assert_eq!(
                session_root("94a50f23-1234-5678-9abc-def012345678/agent-abc123"),
                "94a50f23-1234-5678-9abc-def012345678",
            );
            // Only the first `/` separates root from suffix.
            assert_eq!(session_root("root/a/b"), "root");
        }

        #[test]
        fn fuse_arms_dedupes_intra_arm_by_session_root_and_credits_cross_arm() {
            let mk = |sid: &str, mid: &str| crate::sessions::MessageKey {
                session_id: sid.to_owned(),
                message_id: mid.to_owned(),
            };
            // FTS raw scores span 5..=10 (range 5). Keyed by session root, only the
            // first entry per root counts:
            // - session-A via msg-1: norm 1.0, contribution 0.135
            // - session-B via msg-3: norm 0.2, contribution 0.027
            // msg-2 and the `session-A/agent-x` subagent hit share root session-A
            // and are skipped.
            let fts = RankedList {
                retriever: RetrieverKind::Fts,
                entries: vec![
                    (mk("session-A", "msg-1"), 10.0),
                    (mk("session-A", "msg-2"), 9.0),
                    (mk("session-B", "msg-3"), 6.0),
                    (mk("session-A/agent-x", "msg-4"), 5.0),
                ],
                weight: 0.135,
            };
            // Vector raw scores span 0.6..=0.9: session-B/msg-7 normalises to 1.0,
            // session-A/msg-9 to 0.0.
            let vec_arm = RankedList {
                retriever: RetrieverKind::Vector,
                entries: vec![
                    (mk("session-B", "msg-7"), 0.9),
                    (mk("session-A", "msg-9"), 0.6),
                ],
                weight: 1.0,
            };
            let merged = fuse_arms(&[fts, vec_arm], FusionKey::SessionRoot);
            // Expected totals:
            // - session-B ~ 1.027; representative msg-7, its largest contribution
            // - session-A = 0.135; representative msg-1, since msg-9 contributed 0.0
            assert_eq!(merged.len(), 2);
            assert_eq!(merged[0].key.session_id, "session-B");
            assert_eq!(merged[0].key.message_id, "msg-7");
            assert_eq!(merged[0].matched_via, vec!["fts", "vector"]);
            assert_eq!(merged[1].key.session_id, "session-A");
            assert_eq!(merged[1].key.message_id, "msg-1");
            assert_eq!(merged[1].matched_via, vec!["fts", "vector"]);
        }

        #[test]
        fn fuse_arms_message_key_keeps_per_message_hits_within_one_session() {
            let mk = |sid: &str, mid: &str| crate::sessions::MessageKey {
                session_id: sid.to_owned(),
                message_id: mid.to_owned(),
            };
            // FTS range 4, weight 0.3: msg-1 -> 0.3, msg-2 -> 0.0.
            let fts = RankedList {
                retriever: RetrieverKind::Fts,
                entries: vec![
                    (mk("session-A", "msg-1"), 10.0),
                    (mk("session-A", "msg-2"), 6.0),
                ],
                weight: 0.3,
            };
            // Vector range 0.3, weight 1.0: msg-2 -> 1.0, msg-3 -> 0.0.
            let vec_arm = RankedList {
                retriever: RetrieverKind::Vector,
                entries: vec![
                    (mk("session-A", "msg-2"), 0.9),
                    (mk("session-A", "msg-3"), 0.6),
                ],
                weight: 1.0,
            };
            let merged = fuse_arms(&[fts, vec_arm], FusionKey::Message);
            // Expected order: msg-2 (1.0), msg-1 (0.3), msg-3 (0.0).
            assert_eq!(merged.len(), 3, "one fused hit per message, not per root");
            assert_eq!(merged[0].key.message_id, "msg-2");
            assert_eq!(merged[0].matched_via, vec!["fts", "vector"]);
            assert_eq!(merged[1].key.message_id, "msg-1");
            assert_eq!(merged[1].matched_via, vec!["fts"]);
        }

        #[test]
        fn fuse_arms_collapses_degenerate_tied_arm_to_zero_contribution() {
            let mk = |sid: &str, mid: &str| crate::sessions::MessageKey {
                session_id: sid.to_owned(),
                message_id: mid.to_owned(),
            };
            // Every FTS score is equal, so the range is 0 and both entries contribute 0.0.
            let fts = RankedList {
                retriever: RetrieverKind::Fts,
                entries: vec![(mk("session-A", "a"), 1.0), (mk("session-B", "b"), 1.0)],
                weight: 0.135,
            };
            // Vector: session-A normalises to 1.0, session-B to 0.0.
            let vec_arm = RankedList {
                retriever: RetrieverKind::Vector,
                entries: vec![(mk("session-A", "a"), 0.9), (mk("session-B", "b"), 0.3)],
                weight: 1.0,
            };
            let merged = fuse_arms(&[fts, vec_arm], FusionKey::SessionRoot);
            assert_eq!(merged[0].key.session_id, "session-A");
            assert!((merged[0].score - 1.0).abs() < 1e-9);
            assert!(merged[1].score.abs() < 1e-9);
        }
    }
1985}
1986
1987pub use search_handler::{
1988 FusedHit, FusionKey, RankedList, RetrieverKind, SearchMode, SearchPlan, build_filter,
1989 explain_search_plan, fuse_arms, hit_payload, plan_search, pond_search,
1990};
1991
// Tests for the public search helpers re-exported above: fusion, hit payloads,
// filter building and search planning, plus `restore_lineage`.
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]

    use super::*;
    use crate::wire::{ProjectFilter, SearchFilters, SearchRequest};
    use chrono::Utc;

    // Baseline request in the "local" namespace: no mode override, default
    // filters, limit 20. Individual tests mutate fields from here.
    fn search_request(query: &str) -> SearchRequest {
        SearchRequest {
            protocol_version: crate::PROTOCOL_VERSION,
            namespace: Some("local".to_owned()),
            query: query.to_owned(),
            mode_override: None,
            filters: SearchFilters::default(),
            limit: 20,
        }
    }

    // Shorthand for building a `MessageKey`.
    fn key(session: &str, id: &str) -> crate::sessions::MessageKey {
        crate::sessions::MessageKey {
            session_id: session.to_owned(),
            message_id: id.to_owned(),
        }
    }

    #[test]
    fn fuse_arms_fuses_retrievers_and_reports_provenance() {
        // Vector arm: range 0.4, so a -> 1.0, b -> 0.5, c -> 0.0.
        // FTS arm (weight 0.135): range 6, so b -> 0.135, a -> ~0.09, d -> 0.0.
        // Totals: a ~ 1.09 beats b ~ 0.635.
        let lists = [
            RankedList {
                retriever: RetrieverKind::Vector,
                entries: vec![
                    (key("session-a", "a"), 0.9),
                    (key("session-b", "b"), 0.7),
                    (key("session-c", "c"), 0.5),
                ],
                weight: 1.0,
            },
            RankedList {
                retriever: RetrieverKind::Fts,
                entries: vec![
                    (key("session-b", "b"), 10.0),
                    (key("session-a", "a"), 8.0),
                    (key("session-d", "d"), 4.0),
                ],
                weight: 0.135,
            },
        ];
        let merged = fuse_arms(&lists, FusionKey::SessionRoot);

        assert_eq!(merged[0].key.session_id, "session-a");
        assert_eq!(merged[1].key.session_id, "session-b");
        assert_eq!(merged[0].matched_via, vec!["vector", "fts"]);
        assert!(merged[0].score > merged[1].score);

        // Hits found by a single arm report only that arm.
        let c = merged
            .iter()
            .find(|hit| hit.key.session_id == "session-c")
            .unwrap();
        assert_eq!(c.matched_via, vec!["vector"]);
        let d = merged
            .iter()
            .find(|hit| hit.key.session_id == "session-d")
            .unwrap();
        assert_eq!(d.matched_via, vec!["fts"]);
    }

    #[test]
    fn hit_payload_returns_short_text_in_full() {
        let short = "a short message body";
        let text = hit_payload(short, "message");
        assert_eq!(text, short, "small text is returned as-is");
    }

    #[test]
    fn hit_payload_windows_long_text_around_the_query_term() {
        // 2400 chars with the match deep inside, so a window is required.
        let body = format!("{}NEEDLE{}", "a".repeat(2000), "b".repeat(394));
        let text = hit_payload(&body, "needle");
        assert!(
            text.contains("NEEDLE"),
            "text is the match-windowed snippet: {text}"
        );
        // Presumably HIT_SNIPPET_CHARS == 600; the extra 64 leaves room for the
        // "[N chars before]" / "[+N more chars ...]" markers.
        assert!(
            text.chars().count() <= 600 + 64,
            "snippet window is bounded by HIT_SNIPPET_CHARS plus markers: {}",
            text.chars().count()
        );
    }

    #[test]
    fn hit_payload_snippet_survives_case_folding_that_changes_byte_length() {
        // 'İ' lowercases to "i\u{307}" (more bytes and chars), so offsets in the
        // lowercased text differ from offsets in the original.
        let body = format!("İÉÉÉ{}", "a".repeat(2100));
        let text = hit_payload(&body, "ééé");
        assert!(
            text.contains("ÉÉÉ"),
            "snippet windows on the matched term: {text}"
        );
    }

    #[tokio::test]
    async fn restore_lineage_rejects_a_graph_nesting_deeper_than_one_level() {
        use crate::adapter::Extracted;
        use crate::sessions::Store;
        use crate::wire::{ProviderOptions, Session};
        use tempfile::TempDir;

        let session = |id: &str, parent: Option<&str>| Session {
            id: id.to_owned(),
            parent_session_id: parent.map(str::to_owned),
            parent_message_id: None,
            source_agent: "claude-code".to_owned(),
            created_at: Utc::now(),
            project: Extracted::from_test_value("/tmp/pond".to_owned()),
            options: ProviderOptions::new(),
        };

        // Chain a -> b -> c: two levels below "a", one level below "b".
        let dir = TempDir::new().unwrap();
        let store = Store::open_local(dir.path()).await.unwrap();
        store
            .upsert_sessions(&[
                session("a", None),
                session("b", Some("a")),
                session("c", Some("b")),
            ])
            .await
            .unwrap();

        let err = restore_lineage(&store, "a").await.unwrap_err();
        assert!(
            err.to_string().contains("one subagent level"),
            "expected the deeper-graph error, got: {err}"
        );

        // Starting from "b" the graph is only one level deep.
        let lineage = restore_lineage(&store, "b").await.unwrap();
        let ids: Vec<&str> = lineage.iter().map(|s| s.session.id.as_str()).collect();
        assert_eq!(ids, ["b", "c"]);
    }

    #[test]
    fn build_filter_pushes_down_each_predicate_and_handles_empty() {
        let filters = SearchFilters {
            project: Some(ProjectFilter::Contains("/Users/me/pond".to_owned())),
            session_id: Some("01HXY".to_owned()),
            source_agent: Some("claude-code".to_owned()),
            from_date: Some("2026-01-01".to_owned()),
            to_date: Some("2026-05-01".to_owned()),
            min_score: 0.0,
            include_subagents: false,
        };
        let sql = build_filter(&filters).unwrap().to_lance();
        assert!(sql.contains("project LIKE '%/Users/me/pond%'"));
        assert!(sql.contains("session_id = '01HXY'"));
        assert!(sql.contains("source_agent = 'claude-code'"));
        assert!(sql.contains("timestamp >="));
        assert!(sql.contains("timestamp <="));
        // Scoping to a session / source agent suppresses the subagent exclusion.
        assert!(!sql.contains("NOT ("));

        // Defaults: only the subagent-exclusion clause remains.
        assert_eq!(
            build_filter(&SearchFilters::default()).unwrap().to_lance(),
            "NOT (source_agent LIKE '%/%' ESCAPE '\\')",
        );
        // Including subagents with no other filters yields an empty filter.
        assert_eq!(
            build_filter(&SearchFilters {
                include_subagents: true,
                ..SearchFilters::default()
            })
            .unwrap()
            .to_lance(),
            "",
        );
    }

    #[test]
    fn build_filter_rejects_bad_date() {
        // DD-MM-YYYY is not accepted; only YYYY-MM-DD parses.
        let bad_date = SearchFilters {
            from_date: Some("01-01-2026".to_owned()),
            ..SearchFilters::default()
        };
        assert!(build_filter(&bad_date).is_err());
    }

    #[test]
    fn build_filter_contains_escapes_like_wildcards() {
        // `_` is a LIKE wildcard, so a literal underscore in the project path must be escaped.
        let filters = SearchFilters {
            project: Some(ProjectFilter::Contains("/Users/me/my_project".to_owned())),
            ..SearchFilters::default()
        };
        let sql = build_filter(&filters).unwrap().to_lance();
        assert!(
            sql.contains(r"my\_project"),
            "underscore must be escaped: {sql}"
        );
        assert!(
            sql.contains(r"ESCAPE '\'"),
            "predicate must declare the escape char: {sql}"
        );
    }

    #[test]
    fn plan_search_shapes_request_for_each_planning_input() {
        // Query is trimmed, limit is capped (LIMIT_CAP), pool = limit * 5 with a
        // floor of 50, and vector_pool = pool * 2.
        let mut request = search_request(" vector memory ");
        request.limit = 500;
        request.filters.min_score = 0.42;
        let plan = plan_search(request, SearchMode::Hybrid).unwrap();
        assert_eq!(plan.mode, SearchMode::Hybrid);
        assert_eq!(plan.query, "vector memory");
        assert_eq!(plan.limit, 200);
        assert_eq!(plan.pool, 1000);
        assert_eq!(plan.vector_pool, 2000);
        assert_eq!(plan.min_score, 0.42);

        // A tiny limit still gets the minimum pool of 50.
        let mut request = search_request("tiny pool");
        request.limit = 1;
        let plan = plan_search(request, SearchMode::Fts).unwrap();
        assert_eq!(plan.mode, SearchMode::Fts);
        assert_eq!(plan.limit, 1);
        assert_eq!(plan.pool, 50);
        assert_eq!(plan.vector_pool, 100);

        // Filters flow through into the planned predicate.
        let mut request = search_request("filtered");
        request.filters.project = Some(ProjectFilter::Contains("/Users/me/pond".to_owned()));
        request.filters.source_agent = Some("claude-code".to_owned());
        let plan = plan_search(request, SearchMode::Fts).unwrap();
        let sql = plan.filter.to_lance();
        assert!(sql.contains("project LIKE"));
        assert!(sql.contains("source_agent = 'claude-code'"));
    }

    #[test]
    fn plan_search_rejects_invalid_composition_before_execution() {
        // A whitespace-only query fails validation on the `query` field.
        let mut blank = search_request(" ");
        let error = plan_search(blank.clone(), SearchMode::Fts)
            .unwrap_err()
            .error;
        assert_eq!(error.code, crate::wire::ErrorCode::ValidationFailed);
        assert_eq!(error.details["field"], "query");

        // A limit of 0 fails validation on the `limit` field.
        blank.query = "valid".to_owned();
        blank.limit = 0;
        let error = plan_search(blank.clone(), SearchMode::Fts)
            .unwrap_err()
            .error;
        assert_eq!(error.details["field"], "limit");

        // An unknown namespace is rejected with NamespaceUnknown.
        blank.limit = 1;
        blank.namespace = Some("remote".to_owned());
        let error = plan_search(blank, SearchMode::Fts).unwrap_err().error;
        assert_eq!(error.code, crate::wire::ErrorCode::NamespaceUnknown);
        assert_eq!(error.details["namespace"], "remote");
    }
}