1fn map_error(error: crate::Error) -> crate::wire::ErrorEnvelope {
2 error.into()
3}
4
/// Namespace path held as an ordered list of string segments.
#[derive(Debug, Clone)]
pub struct NamespaceIdent(pub Vec<String>);

impl NamespaceIdent {
    /// Returns the identifier that has no segments at all.
    pub fn root() -> Self {
        Self(Vec::new())
    }

    /// Builds a table identifier: a copy of this namespace's segments followed by
    /// `table_name` as the final segment. The namespace itself is left untouched.
    pub fn as_table_id(&self, table_name: &str) -> Vec<String> {
        self.0
            .iter()
            .cloned()
            .chain(std::iter::once(table_name.to_owned()))
            .collect()
    }
}
22
23pub fn resolve_namespace(
27 namespace: Option<&str>,
28) -> Result<NamespaceIdent, crate::wire::ErrorEnvelope> {
29 match namespace {
30 None | Some(crate::wire::DEFAULT_NAMESPACE) => Ok(NamespaceIdent::root()),
31 Some(other) => Err(map_error(crate::Error::namespace_unknown(other))),
32 }
33}
34
35fn map_storage(error: anyhow::Error) -> crate::wire::ErrorEnvelope {
36 if let Some(conflict) = error.downcast_ref::<crate::substrate::ConflictExhausted>() {
39 return map_error(crate::Error::Conflict {
40 attempts: conflict.attempts,
41 });
42 }
43 map_error(crate::Error::Storage(error))
44}
45
46mod ingest_handler {
47 use anyhow::Result;
48 use tokio_stream::StreamExt;
49
50 use crate::{
51 adapter::{Adapter, AdapterYield, SkipOracle, SkipReason},
52 sessions::{IngestEvent, IngestSummary, IngestValidator, OutcomeStatus, RowOutcome, Store},
53 wire::{
54 ErrorBody, ErrorCode, IngestEnvelope, IngestRequest, IngestResponse, IngestResult,
55 IngestStatus, validate_protocol,
56 },
57 };
58
59 use super::{map_error, map_storage};
60
/// Largest number of events accepted in a single ingest request; `pond_ingest`
/// rejects any batch longer than this with a validation error.
pub const MAX_INGEST_EVENTS: usize = 1000;
63
/// Progress notification passed to the `on_event` callback of `ingest_adapter`.
#[derive(Debug, Clone)]
pub enum SyncEvent {
    /// Emitted once, right after adapter discovery. `total` is `None` when
    /// `discover()` failed (the failure is only logged at debug level).
    Discovered { total: Option<usize> },
    /// Emitted once per finished, skipped, or failed session.
    SessionDone(SessionOutcome),
}
80
/// Final report for one session, carried by `SyncEvent::SessionDone`.
#[derive(Debug, Clone)]
pub struct SessionOutcome {
    /// Project of the session; `None` when an adapter error arrived with no session in flight.
    pub project: Option<String>,
    /// Session identifier; `None` when an adapter error arrived with no session in flight.
    pub session_id: Option<String>,
    /// Number of message events seen while the session was in flight (0 for skipped sessions).
    pub messages: usize,
    /// How the session ended.
    pub status: SyncStatus,
}
94
/// Terminal state of a session as reported by `ingest_adapter`.
#[derive(Debug, Clone)]
pub enum SyncStatus {
    /// The session row was not rejected and no events were dropped.
    Ok,
    /// The session row was not rejected, but some of its events were dropped.
    Partial {
        /// Number of events dropped for this session.
        dropped_events: usize,
        /// Message of the first drop that was recorded, when one was available.
        first_drop_reason: Option<String>,
    },
    /// The adapter skipped the input as unsupported, or yielded an error while no
    /// session was in flight.
    Skipped {
        /// Adapter-provided skip reason or the error text.
        reason: String,
    },
    /// The session row itself came back with an error status from the validator.
    Rejected {
        /// Error message of the session row, or a generic fallback text.
        reason: String,
    },
    /// The adapter reported `SkipReason::Fresh` for the session.
    Fresh,
    /// The adapter reported `SkipReason::Empty` for the session.
    Empty,
}
130
/// Bookkeeping for the session whose events `ingest_adapter` is currently receiving.
#[derive(Debug, Default)]
struct InFlight {
    /// Project cloned from the session event.
    project: Option<String>,
    /// Identifier cloned from the session event.
    session_id: String,
    /// Message events seen so far for this session.
    messages: usize,
    /// Events dropped so far for this session (validator errors and adapter errors).
    dropped_events: usize,
    /// Message of the first recorded drop, if any.
    first_drop_reason: Option<String>,
    /// Stream index of the session event; used later to find the session's row outcome.
    session_index: usize,
}
147
/// A session whose events have all been received but whose `SessionDone` event is
/// held back until `drain_pending_dones` runs after a validator flush or finish.
#[derive(Debug)]
struct PendingDone {
    /// Project copied from the in-flight record.
    project: Option<String>,
    /// Session identifier copied from the in-flight record.
    session_id: String,
    /// Message events counted for the session.
    messages: usize,
    /// Events dropped for the session.
    dropped_events: usize,
    /// Message of the first recorded drop, if any.
    first_drop_reason: Option<String>,
    /// Stream index of the session event, matched against `RowOutcome::index`.
    session_index: usize,
}
161
/// `ingest_adapter` flushes the validator once `pending_substreams()` reaches this value.
const ADAPTER_FLUSH_BATCH: usize = 100;
169
/// Streams every item an adapter yields into the store and reports progress per session.
///
/// `on_event` receives one `SyncEvent::Discovered` up front and then one
/// `SyncEvent::SessionDone` per session. Sessions the adapter skips are reported
/// immediately; ingested sessions are reported only after the validator flush (or
/// the final `finish`) that follows them, via `drain_pending_dones`.
///
/// Adapter-level errors never abort the run: they are counted as a dropped event
/// of the in-flight session, or as a skipped file when no session is in flight.
///
/// # Errors
/// Returns an error only when `validator.push`, `validator.flush`, or
/// `validator.finish` fails.
pub async fn ingest_adapter<F>(
    store: &Store,
    adapter: &dyn Adapter,
    oracle: &dyn SkipOracle,
    mut on_event: F,
) -> Result<IngestSummary>
where
    F: FnMut(SyncEvent),
{
    let mut summary = IngestSummary::default();
    // Snapshot the truncation counter so only truncations from this run are reported.
    let truncations_before = crate::adapter::extract::truncated_values_count();
    // Discovery is best-effort: a failure is logged and reported as `total: None`.
    let total = adapter
        .discover()
        .await
        .map_err(|error| tracing::debug!(%error, "adapter discover failed"))
        .ok();
    on_event(SyncEvent::Discovered { total });

    let mut events = adapter.events_with(oracle);
    let mut validator = IngestValidator::default();
    // Running index of events handed to the validator; also identifies session rows.
    let mut index = 0usize;
    let mut in_flight: Option<InFlight> = None;
    let mut pending_dones: std::collections::VecDeque<PendingDone> =
        std::collections::VecDeque::new();
    // Timing accumulators for the perf log emitted at the end.
    let mut decode_total = std::time::Duration::ZERO;
    let mut decode_count = 0u64;
    let mut validator_total = std::time::Duration::ZERO;
    let mut validator_count = 0u64;
    let run_started = std::time::Instant::now();

    loop {
        let decode_start = std::time::Instant::now();
        let next = events.next().await;
        decode_total += decode_start.elapsed();
        // Note: the final poll that returns `None` is counted as a decode call too.
        decode_count += 1;
        let event = match next {
            Some(event) => event,
            None => break,
        };
        match event {
            // The adapter decided not to ingest this session: count it and report right away.
            Ok(AdapterYield::Skipped {
                session_id,
                project,
                reason,
            }) => {
                let status = match reason {
                    SkipReason::Fresh => {
                        summary.skipped_fresh += 1;
                        SyncStatus::Fresh
                    }
                    SkipReason::Empty => {
                        summary.skipped_empty += 1;
                        SyncStatus::Empty
                    }
                    SkipReason::Unsupported(reason) => {
                        summary.skipped_files += 1;
                        SyncStatus::Skipped { reason }
                    }
                };
                on_event(SyncEvent::SessionDone(SessionOutcome {
                    project,
                    session_id,
                    messages: 0,
                    status,
                }));
            }
            Ok(AdapterYield::Event(event)) => {
                // A new session event closes the previous session: queue its report
                // until the next flush tells us whether the session row was accepted.
                if matches!(&event, IngestEvent::Session(_))
                    && let Some(prev) = in_flight.take()
                {
                    pending_dones.push_back(PendingDone {
                        project: prev.project,
                        session_id: prev.session_id,
                        messages: prev.messages,
                        dropped_events: prev.dropped_events,
                        first_drop_reason: prev.first_drop_reason,
                        session_index: prev.session_index,
                    });
                }
                let event_index = index;
                match &event {
                    IngestEvent::Session(session) => {
                        in_flight = Some(InFlight {
                            project: Some((*session.project).clone()),
                            session_id: session.id.clone(),
                            messages: 0,
                            dropped_events: 0,
                            first_drop_reason: None,
                            session_index: event_index,
                        });
                    }
                    IngestEvent::Message(_) => {
                        if let Some(slot) = in_flight.as_mut() {
                            slot.messages += 1;
                        }
                    }
                    IngestEvent::Part(_) => {}
                }

                let validator_start = std::time::Instant::now();
                let push_outcomes = validator.push(store, index, event).await?;
                validator_total += validator_start.elapsed();
                validator_count += 1;
                for outcome in &push_outcomes {
                    // Non-session rows that failed are attributed to the in-flight session
                    // as dropped events; session-row failures are handled when draining.
                    if matches!(outcome.status, OutcomeStatus::Error)
                        && outcome.kind != "session"
                        && let Some(slot) = in_flight.as_mut()
                    {
                        slot.dropped_events += 1;
                        if slot.first_drop_reason.is_none() {
                            slot.first_drop_reason =
                                outcome.error.as_ref().map(|err| err.message.clone());
                        }
                    }
                }
                summary.add_outcomes(&push_outcomes);
                index += 1;

                // Flush in batches so buffered work stays bounded and finished sessions
                // can be reported while the run is still going.
                if validator.pending_substreams() >= ADAPTER_FLUSH_BATCH {
                    let flush_start = std::time::Instant::now();
                    let (flush_outcomes, flush_counts) = validator.flush(store).await?;
                    validator_total += flush_start.elapsed();
                    validator_count += 1;
                    summary.add_outcomes_errors_only(&flush_outcomes);
                    summary.add_batch(&flush_counts);
                    drain_pending_dones(&mut pending_dones, &flush_outcomes, &mut on_event);
                }
            }
            Err(error) => {
                tracing::debug!(
                    %error,
                    "adapter event error (per-line drop by design)"
                );
                match in_flight.as_mut() {
                    // Error inside a session: record it as one dropped event of that session.
                    Some(slot) => {
                        slot.dropped_events += 1;
                        if slot.first_drop_reason.is_none() {
                            slot.first_drop_reason = Some(error.to_string());
                        }
                        summary.dropped_events += 1;
                    }
                    // Error with no session open: report it as a skipped, anonymous session.
                    None => {
                        summary.skipped_files += 1;
                        on_event(SyncEvent::SessionDone(SessionOutcome {
                            project: None,
                            session_id: None,
                            messages: 0,
                            status: SyncStatus::Skipped {
                                reason: error.to_string(),
                            },
                        }));
                    }
                }
            }
        }
    }

    // The stream ended: the last open session (if any) is finished as well.
    if let Some(prev) = in_flight.take() {
        pending_dones.push_back(PendingDone {
            project: prev.project,
            session_id: prev.session_id,
            messages: prev.messages,
            dropped_events: prev.dropped_events,
            first_drop_reason: prev.first_drop_reason,
            session_index: prev.session_index,
        });
    }
    let validator_start = std::time::Instant::now();
    let (final_outcomes, final_counts) = validator.finish(store).await?;
    validator_total += validator_start.elapsed();
    validator_count += 1;
    summary.add_outcomes_errors_only(&final_outcomes);
    summary.add_batch(&final_counts);
    drain_pending_dones(&mut pending_dones, &final_outcomes, &mut on_event);

    summary.truncated_values = crate::adapter::extract::truncated_values_count()
        .saturating_sub(truncations_before) as usize;

    // Whatever wall time was spent neither decoding nor validating is reported as "other".
    let total = run_started.elapsed();
    let other = total
        .saturating_sub(decode_total)
        .saturating_sub(validator_total);
    tracing::info!(
        target: "pond::perf",
        total_ms = total.as_millis() as u64,
        decode_ms = decode_total.as_millis() as u64,
        validator_ms = validator_total.as_millis() as u64,
        other_ms = other.as_millis() as u64,
        decode_calls = decode_count,
        validator_calls = validator_count,
        rows_inserted = summary.inserted as u64,
        rows_matched = summary.matched as u64,
        dropped_events = summary.dropped_events as u64,
        dropped_sessions = summary.dropped_sessions as u64,
        skipped_files = summary.skipped_files as u64,
        skipped_fresh = summary.skipped_fresh as u64,
        truncated_values = summary.truncated_values as u64,
        "ingest_adapter complete"
    );
    Ok(summary)
}
423
424 fn drain_pending_dones<F>(
431 queue: &mut std::collections::VecDeque<PendingDone>,
432 outcomes: &[RowOutcome],
433 on_event: &mut F,
434 ) where
435 F: FnMut(SyncEvent),
436 {
437 let mut session_outcome_by_index: std::collections::HashMap<usize, &RowOutcome> =
440 std::collections::HashMap::new();
441 for outcome in outcomes {
442 if outcome.kind == "session" {
443 session_outcome_by_index.insert(outcome.index, outcome);
444 }
445 }
446
447 while let Some(done) = queue.pop_front() {
448 let session_outcome = session_outcome_by_index.get(&done.session_index).copied();
449 let rejection_reason = session_outcome.and_then(|outcome| {
450 if matches!(outcome.status, OutcomeStatus::Error) {
451 Some(
452 outcome
453 .error
454 .as_ref()
455 .map(|err| err.message.clone())
456 .unwrap_or_else(|| "session-level rejection".to_owned()),
457 )
458 } else {
459 None
460 }
461 });
462 let status = if let Some(reason) = rejection_reason {
463 SyncStatus::Rejected { reason }
464 } else if done.dropped_events > 0 {
465 SyncStatus::Partial {
466 dropped_events: done.dropped_events,
467 first_drop_reason: done.first_drop_reason,
468 }
469 } else {
470 SyncStatus::Ok
471 };
472 on_event(SyncEvent::SessionDone(SessionOutcome {
473 project: done.project,
474 session_id: Some(done.session_id),
475 messages: done.messages,
476 status,
477 }));
478 }
479 }
480
481 pub async fn pond_ingest(store: &Store, request: IngestRequest) -> IngestEnvelope {
487 if let Err(envelope) = validate_protocol(request.protocol_version) {
488 return IngestEnvelope::Error(envelope);
489 }
490 if let Err(envelope) = super::resolve_namespace(request.namespace.as_deref()) {
491 return IngestEnvelope::Error(envelope);
492 }
493 if request.events.is_empty() {
494 return IngestEnvelope::Error(map_error(crate::Error::validation_field(
495 "events must be a non-empty array",
496 "events",
497 Some(serde_json::json!([])),
498 Some("non-empty array".to_owned()),
499 )));
500 }
501 if request.events.len() > MAX_INGEST_EVENTS {
502 return IngestEnvelope::Error(map_error(crate::Error::validation_field(
503 format!("ingest batch exceeds the event cap: at most {MAX_INGEST_EVENTS} events"),
504 "events",
505 Some(serde_json::json!(request.events.len())),
506 Some(format!("at most {MAX_INGEST_EVENTS} events")),
507 )));
508 }
509
510 match ingest_events(store, request.events).await {
511 Ok(outcomes) => {
512 let mut accepted = 0;
513 let mut rejected = 0;
514 for outcome in &outcomes {
515 match outcome.status {
516 OutcomeStatus::Inserted | OutcomeStatus::Matched => accepted += 1,
517 OutcomeStatus::Error => rejected += 1,
518 }
519 }
520 let results = outcomes
521 .into_iter()
522 .map(outcome_to_result)
523 .collect::<Vec<_>>();
524 IngestEnvelope::Success(IngestResponse {
525 accepted,
526 rejected,
527 results,
528 })
529 }
530 Err(failure) => IngestEnvelope::Error(map_storage(failure)),
531 }
532 }
533
534 pub async fn ingest_events(store: &Store, events: Vec<IngestEvent>) -> Result<Vec<RowOutcome>> {
540 let mut validator = IngestValidator::default();
541 let mut outcomes = Vec::with_capacity(events.len());
542 for (index, event) in events.into_iter().enumerate() {
543 let mut chunk = validator.push(store, index, event).await?;
544 outcomes.append(&mut chunk);
545 }
546 let (mut tail, _counts) = validator.finish(store).await?;
549 outcomes.append(&mut tail);
550 outcomes.sort_by_key(|outcome| outcome.index);
551 Ok(outcomes)
552 }
553
554 fn outcome_to_result(outcome: RowOutcome) -> IngestResult {
555 let (status, error) = match (outcome.status, outcome.error) {
556 (OutcomeStatus::Inserted, _) => (IngestStatus::Inserted, None),
557 (OutcomeStatus::Matched, _) => (IngestStatus::Matched, None),
558 (OutcomeStatus::Error, error) => {
559 let body = error
560 .map(|err| {
561 let mut details = serde_json::Map::new();
562 if let Some(field) = err.field {
563 details.insert("field".to_owned(), serde_json::json!(field));
564 }
565 if let Some(reason) = err.reason {
566 details.insert("reason".to_owned(), serde_json::json!(reason));
567 }
568 ErrorBody {
569 code: ErrorCode::ValidationFailed,
570 message: err.message,
571 details: serde_json::Value::Object(details),
572 }
573 })
574 .unwrap_or_else(|| ErrorBody {
575 code: ErrorCode::ValidationFailed,
576 message: "ingest failed".to_owned(),
577 details: serde_json::json!({}),
578 });
579 (IngestStatus::Error, Some(body))
580 }
581 };
582 IngestResult {
583 index: outcome.index,
584 kind: outcome.kind.to_owned(),
585 pk: outcome.pk,
586 status,
587 error,
588 }
589 }
590}
591
592pub use crate::sessions::{IngestEvent, IngestSummary, IngestValidator, search_text};
593pub use ingest_handler::{
594 MAX_INGEST_EVENTS, SessionOutcome, SyncEvent, SyncStatus, ingest_adapter, ingest_events,
595 pond_ingest,
596};
597
598mod export_handler {
599 use anyhow::{Context, Result};
610 use tokio::io::{AsyncWrite, AsyncWriteExt};
611
612 use crate::sessions::{IngestEvent, Store};
613
/// Counts of the events `pond_export` wrote, by event kind.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ExportSummary {
    /// Number of session events written.
    pub sessions: usize,
    /// Number of message events written.
    pub messages: usize,
    /// Number of part events written.
    pub parts: usize,
}
620
621 pub async fn pond_export<W>(
622 store: &Store,
623 session_filter: Option<&str>,
624 writer: &mut W,
625 ) -> Result<ExportSummary>
626 where
627 W: AsyncWrite + Unpin,
628 {
629 let mut session_ids = match session_filter {
630 Some(id) => vec![id.to_owned()],
631 None => store.session_ids().await?,
632 };
633 session_ids.sort();
634
635 let mut summary = ExportSummary::default();
636 for session_id in session_ids {
637 let Some(stored) = store
638 .get_session(&session_id)
639 .await
640 .with_context(|| format!("export: failed to load session {session_id}"))?
641 else {
642 if session_filter.is_some() {
643 anyhow::bail!("export: session not found: {session_id}");
644 }
645 continue;
646 };
647 write_event(writer, &IngestEvent::Session(stored.session)).await?;
648 summary.sessions += 1;
649 for message_with_parts in stored.messages {
650 write_event(writer, &IngestEvent::Message(message_with_parts.message)).await?;
651 summary.messages += 1;
652 for part in message_with_parts.parts {
653 write_event(writer, &IngestEvent::Part(part)).await?;
654 summary.parts += 1;
655 }
656 }
657 }
658 writer.flush().await.context("export: flush failed")?;
659 Ok(summary)
660 }
661
662 async fn write_event<W>(writer: &mut W, event: &IngestEvent) -> Result<()>
663 where
664 W: AsyncWrite + Unpin,
665 {
666 let line = serde_json::to_string(event).context("export: serialize event")?;
667 writer
668 .write_all(line.as_bytes())
669 .await
670 .context("export: write event")?;
671 writer
672 .write_all(b"\n")
673 .await
674 .context("export: write newline")?;
675 Ok(())
676 }
677}
678
679pub use export_handler::{ExportSummary, pond_export};
680
681mod restore_handler {
682 use anyhow::{Context, Result, bail};
689
690 use crate::sessions::{SessionWithMessages, Store};
691
692 pub async fn restore_lineage(
693 store: &Store,
694 session_id: &str,
695 ) -> Result<Vec<SessionWithMessages>> {
696 let Some(parent) = store.get_session(session_id).await? else {
697 bail!("export: session not found: {session_id}");
698 };
699 let mut sessions = vec![parent];
700 for child in store.child_sessions(session_id).await? {
701 if !store.child_sessions(&child.id).await?.is_empty() {
702 bail!(
703 "adapter-lineage-complete-restore supports one subagent level; session {} has child sessions",
704 child.id
705 );
706 }
707 let child_id = child.id;
708 let stored = store
709 .get_session(&child_id)
710 .await?
711 .with_context(|| format!("export: child session disappeared: {child_id}"))?;
712 sessions.push(stored);
713 }
714 Ok(sessions)
715 }
716}
717
718pub use restore_handler::restore_lineage;
719
720mod get_handler {
721 use crate::{
722 sessions::{GetLookup, MessageViewParams, RetrievedMessage, SessionViewParams, Store},
723 wire::{
724 GetEnvelope, GetRequest, GetResponse, GetResult, GetSession, MessageView, PartSummary,
725 ResponseMode, ResponsePart, validate_protocol,
726 },
727 };
728
729 use super::{map_error, map_storage};
730
731 fn to_message_view(message: RetrievedMessage, verbatim: bool) -> MessageView {
736 if verbatim {
737 return MessageView {
738 id: message.id,
739 role: message.role,
740 timestamp: message.timestamp,
741 text: None,
742 content: None,
743 parts_summary: Vec::new(),
744 parts: Some(
745 message
746 .parts
747 .into_iter()
748 .map(ResponsePart::from_part)
749 .collect(),
750 ),
751 };
752 }
753 let parts_summary = message
754 .parts
755 .iter()
756 .filter_map(|part| PartSummary::for_kind(&part.kind))
757 .collect();
758 MessageView {
759 id: message.id,
760 role: message.role,
761 timestamp: message.timestamp,
762 text: message.text,
763 content: message.content,
764 parts_summary,
765 parts: None,
766 }
767 }
768
/// Value passed as `budget_bytes` in both `SessionViewParams` and `MessageViewParams`.
/// Presumably a response-size budget enforced by the store; confirm there.
const BUDGET_BYTES: usize = 200_000;
774
775 pub async fn pond_get(store: &Store, request: GetRequest) -> GetEnvelope {
776 if let Err(error) = validate_protocol(request.protocol_version) {
777 return GetEnvelope::Error(error);
778 }
779 if let Err(envelope) = super::resolve_namespace(request.namespace.as_deref()) {
780 return GetEnvelope::Error(envelope);
781 }
782
783 let response = match (&request.session_id, &request.message_id) {
784 (Some(session_id), None) => session_result(store, session_id, &request).await,
785 (None, Some(message_id)) => message_result(store, message_id, &request).await,
786 (Some(_), Some(_)) => Err(map_error(crate::Error::validation_field(
787 "session_id and message_id are mutually exclusive",
788 "message_id",
789 request.message_id.clone().map(serde_json::Value::String),
790 Some("omit when session_id is present".to_owned()),
791 ))),
792 (None, None) => Err(map_error(crate::Error::validation(
793 "one of session_id or message_id is required",
794 ))),
795 };
796
797 match response {
798 Ok(response) => GetEnvelope::Success(response),
799 Err(error) => GetEnvelope::Error(error),
800 }
801 }
802
803 fn unknown_after_id(request: &GetRequest, anchor_of: &str) -> crate::wire::ErrorEnvelope {
806 map_error(crate::Error::validation_field(
807 "after_id not found (stale or mistyped continuation anchor)",
808 "after_id",
809 request.after_id.clone().map(serde_json::Value::String),
810 Some(format!("a {anchor_of} from a prior page of this read")),
811 ))
812 }
813
814 async fn session_result(
815 store: &Store,
816 session_id: &str,
817 request: &GetRequest,
818 ) -> Result<GetResponse, crate::wire::ErrorEnvelope> {
819 let params = SessionViewParams {
820 mode: request.response_mode,
821 after_id: request.after_id.as_deref(),
822 limit: request.limit,
823 budget_bytes: BUDGET_BYTES,
824 session_from: request.session_from,
825 };
826 let view = match store
827 .session_view(session_id, params)
828 .await
829 .map_err(map_storage)?
830 {
831 GetLookup::NotFound => {
832 return Err(map_error(crate::Error::not_found(
833 "session",
834 serde_json::json!(session_id),
835 format!("session not found: {session_id}"),
836 )));
837 }
838 GetLookup::UnknownAfterId => return Err(unknown_after_id(request, "message id")),
839 GetLookup::Found(view) => view,
840 };
841 let verbatim = matches!(request.response_mode, ResponseMode::Verbatim);
842 Ok(GetResponse {
843 session: GetSession::from_session(&view.session),
844 result: GetResult::Session {
845 messages: view
846 .messages
847 .into_iter()
848 .map(|message| to_message_view(message, verbatim))
849 .collect(),
850 messages_remaining: view.messages_remaining,
851 },
852 })
853 }
854
855 async fn message_result(
856 store: &Store,
857 message_id: &str,
858 request: &GetRequest,
859 ) -> Result<GetResponse, crate::wire::ErrorEnvelope> {
860 let params = MessageViewParams {
861 context_depth: request.context_depth,
862 after_id: request.after_id.as_deref(),
863 limit: request.limit,
864 budget_bytes: BUDGET_BYTES,
865 };
866 let view = match store
867 .message_view(message_id, params)
868 .await
869 .map_err(map_storage)?
870 {
871 GetLookup::NotFound => {
872 return Err(map_error(crate::Error::not_found(
873 "message",
874 serde_json::json!(message_id),
875 format!("message not found: {message_id}"),
876 )));
877 }
878 GetLookup::UnknownAfterId => return Err(unknown_after_id(request, "part id")),
879 GetLookup::Found(view) => view,
880 };
881 let target = MessageView {
884 id: view.target.id,
885 role: view.target.role,
886 timestamp: view.target.timestamp,
887 text: None,
888 content: None,
889 parts_summary: Vec::new(),
890 parts: None,
891 };
892 Ok(GetResponse {
893 session: GetSession::from_session(&view.session),
894 result: GetResult::Message {
895 target,
896 target_parts: view
897 .target_parts
898 .into_iter()
899 .map(ResponsePart::from_part)
900 .collect(),
901 target_parts_remaining: view.target_parts_remaining,
902 siblings: view
903 .siblings
904 .into_iter()
905 .map(|sibling| to_message_view(sibling, false))
906 .collect(),
907 },
908 })
909 }
910}
911
912pub use get_handler::pond_get;
913
914mod search_handler {
915 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
920
921 use crate::{
922 Clock, SystemClock,
923 embed::{Embedder, LazyEmbedder, format_query},
924 sessions::{MessageKey, MessageMeta, Store},
925 substrate::{Predicate, ScalarValue},
926 wire::{
927 ErrorEnvelope, PartSummary, ProjectFilter, Role, SearchCursor, SearchEnvelope,
928 SearchFilters, SearchRequest, SearchResponse, SearchResult, SearchSession,
929 validate_protocol,
930 },
931 };
932 use chrono::NaiveDate;
933 use std::collections::HashMap;
934
935 use super::{map_error, map_storage};
936
/// Retrieval strategy used by `run_search`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Run full-text and vector retrieval together and fuse the two result lists.
    Hybrid,
    /// Full-text retrieval only.
    Fts,
    /// Vector retrieval only.
    Vector,
}
947
/// Validated, normalised form of a search request, produced by `plan_search`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchPlan {
    /// Retrieval strategy. `plan_search` stores the mode it is given; `run_search`
    /// and `explain_search_plan` overwrite it with the effective mode.
    pub mode: SearchMode,
    /// Query text after trimming (taken from the cursor when one was supplied).
    pub query: String,
    /// Trimmed, non-empty message id to search "similar to", if any.
    pub similar_to: Option<String>,
    /// Storage predicate built from `filters` by `build_filter`.
    pub filter: Predicate,
    /// The wire filters the plan was built from.
    pub filters: SearchFilters,
    /// Candidate count requested from full-text search: `limit * 5`, at least 50.
    pub pool: usize,
    /// Candidate count requested from vector search: twice `pool`.
    pub vector_pool: usize,
    /// Requested limit, capped at `LIMIT_CAP`.
    pub limit: usize,
    /// Paging offset taken from the cursor; 0 when no cursor was supplied.
    pub offset: usize,
    /// Minimum score from the filters; `run_search` drops candidates scoring below it.
    pub min_score: f64,
}
964
/// Upper bound applied to a request's `limit` in `plan_search`.
const LIMIT_CAP: usize = 200;
/// NOTE(review): usage is outside this view; presumably caps hits shown per session.
const MAX_MATCHES_PER_SESSION: usize = 3;
/// NOTE(review): usage is outside this view; presumably a byte budget for a search response.
const SEARCH_BUDGET_BYTES: usize = 60_000;
/// NOTE(review): usage is outside this view; presumably the snippet length per hit.
const HIT_SNIPPET_CHARS: usize = 600;
/// Sum of the two fusion weights. NOTE(review): usage is outside this view;
/// presumably used to normalise fused scores.
const SCORE_DENOMINATOR: f64 = FTS_FUSION_WEIGHT + VECTOR_FUSION_WEIGHT;

/// Weight given to the full-text list when `run_search` fuses results in hybrid mode.
const FTS_FUSION_WEIGHT: f64 = 0.135;
/// Weight given to the vector list when `run_search` fuses results in hybrid mode.
const VECTOR_FUSION_WEIGHT: f64 = 1.0;
984
985 fn encode_search_cursor(cursor: &SearchCursor) -> String {
986 #[allow(clippy::expect_used)]
987 let bytes = serde_json::to_vec(cursor).expect("search cursor encodes as JSON");
988 URL_SAFE_NO_PAD.encode(bytes)
989 }
990
991 fn decode_search_cursor(raw: &str) -> Result<SearchCursor, ErrorEnvelope> {
992 let bytes = URL_SAFE_NO_PAD.decode(raw).map_err(|_| {
993 map_error(crate::Error::validation_field(
994 "cursor is malformed (expected opaque value from a prior response)",
995 "cursor",
996 Some(serde_json::json!(raw)),
997 Some("opaque base64url".to_owned()),
998 ))
999 })?;
1000 serde_json::from_slice(&bytes).map_err(|_| {
1001 map_error(crate::Error::validation_field(
1002 "cursor is malformed (decode failed)",
1003 "cursor",
1004 Some(serde_json::json!(raw)),
1005 Some("opaque cursor from a prior response".to_owned()),
1006 ))
1007 })
1008 }
1009
1010 pub async fn pond_search(
1019 store: &Store,
1020 embedder: &LazyEmbedder,
1021 request: SearchRequest,
1022 search: &crate::config::SearchConfig,
1023 ) -> SearchEnvelope {
1024 match run_search(store, embedder, request, search, &SystemClock).await {
1025 Ok(response) => SearchEnvelope::Success(response),
1026 Err(envelope) => SearchEnvelope::Error(envelope),
1027 }
1028 }
1029
1030 pub async fn explain_search_plan(
1031 store: &Store,
1032 embedder: &LazyEmbedder,
1033 request: SearchRequest,
1034 search: &crate::config::SearchConfig,
1035 ) -> Result<String, ErrorEnvelope> {
1036 let override_mode = request.mode_override.map(wire_mode_to_internal);
1037 let mut plan = plan_search(request, SearchMode::Fts)?;
1038 plan.mode = resolve_effective_mode(store, override_mode).await?;
1039 let mut out = String::new();
1040 if matches!(plan.mode, SearchMode::Fts | SearchMode::Hybrid) {
1041 let fts = store
1042 .explain_fts_plan(&plan.query, plan.pool, &plan.filter)
1043 .await
1044 .map_err(map_storage)?;
1045 out.push_str("fts:\n");
1046 out.push_str(&fts);
1047 out.push('\n');
1048 }
1049 if matches!(plan.mode, SearchMode::Vector | SearchMode::Hybrid) {
1050 let backend = load_embedder(embedder).await?;
1051 let vector = embed_query(backend.as_ref(), &plan.query)?;
1052 let vector_plan = store
1053 .explain_vector_plan(&vector, plan.vector_pool, &plan.filter, Some(search))
1054 .await
1055 .map_err(map_storage)?;
1056 out.push_str("vector:\n");
1057 out.push_str(&vector_plan);
1058 out.push('\n');
1059 }
1060 Ok(out)
1061 }
1062
/// Executes a search request end to end: plan, retrieve candidates, score, and page.
///
/// The mode handed to `plan_search` is only a placeholder. The effective mode is
/// `Vector` whenever the plan carries `similar_to`; otherwise it comes from
/// `resolve_effective_mode`. `_clock` is not used in this function.
///
/// # Errors
/// Returns an `ErrorEnvelope` for planning failures, storage failures (mapped with
/// `map_storage`), embedder-load or embedding failures, and a not-found error when
/// `similar_to` names a message that has no stored vector.
async fn run_search(
    store: &Store,
    embedder: &LazyEmbedder,
    request: SearchRequest,
    search: &crate::config::SearchConfig,
    _clock: &dyn Clock,
) -> Result<SearchResponse, ErrorEnvelope> {
    // Read the override before `request` is moved into `plan_search`.
    let override_mode = request.mode_override.map(wire_mode_to_internal);
    let mut plan = plan_search(request, SearchMode::Fts)?;

    if plan.similar_to.is_some() {
        // A similar-to search has a vector but possibly no query text, so it is vector-only.
        plan.mode = SearchMode::Vector;
    } else {
        plan.mode = resolve_effective_mode(store, override_mode).await?;
    }

    let candidates = match plan.mode {
        SearchMode::Fts => {
            let hits = store
                .fts_search(&plan.query, plan.pool, &plan.filter)
                .await
                .map_err(map_storage)?;
            normalize_fts(hits)
        }
        SearchMode::Hybrid => {
            let backend = load_embedder(embedder).await?;
            let vector = embed_query(backend.as_ref(), &plan.query)?;
            // Both retrievers are awaited together; the first error wins.
            let fts_fut = async {
                store
                    .fts_search(&plan.query, plan.pool, &plan.filter)
                    .await
                    .map_err(map_storage)
            };
            let vector_fut = async {
                store
                    .vector_search(&vector, plan.vector_pool, &plan.filter, Some(search))
                    .await
                    .map_err(map_storage)
            };
            let (fts, vector_raw) = tokio::try_join!(fts_fut, vector_fut)?;
            // Full-text scores are scaled relative to the best full-text score.
            let fts_max = fts.iter().map(|(_, s)| *s).fold(0.0_f32, f32::max);
            let fts_entries: Vec<(MessageKey, f64)> = fts
                .into_iter()
                .map(|(key, score)| {
                    let normed = if fts_max > 0.0 {
                        f64::from(score / fts_max)
                    } else {
                        0.0
                    };
                    (key, normed)
                })
                .collect();
            // Vector hits are scored by position only: 1.0 for the first hit, falling
            // linearly with rank. The raw vector score is ignored here.
            let vec_n = vector_raw.len() as f64;
            let vector_entries: Vec<(MessageKey, f64)> = vector_raw
                .into_iter()
                .enumerate()
                .map(|(idx, (key, _))| {
                    let normed = if vec_n > 0.0 {
                        1.0 - (idx as f64 / vec_n)
                    } else {
                        0.0
                    };
                    (key, normed)
                })
                .collect();
            let lists = [
                RankedList {
                    retriever: RetrieverKind::Fts,
                    entries: fts_entries,
                    weight: FTS_FUSION_WEIGHT,
                },
                RankedList {
                    retriever: RetrieverKind::Vector,
                    entries: vector_entries,
                    weight: VECTOR_FUSION_WEIGHT,
                },
            ];
            fuse_arms(&lists)
                .into_iter()
                .map(|hit| Candidate {
                    session_id: hit.key.session_id,
                    message_id: hit.key.message_id,
                    base_score: hit.score,
                })
                .collect()
        }
        SearchMode::Vector => {
            // Similar-to searches reuse the stored vector of the referenced message;
            // plain vector searches embed the query text.
            let vector = if let Some(similar_id) = &plan.similar_to {
                let stored = store
                    .message_vector_by_id(similar_id)
                    .await
                    .map_err(map_storage)?;
                let Some(vector) = stored else {
                    return Err(map_error(crate::Error::not_found(
                        "message",
                        serde_json::json!(similar_id),
                        format!(
                            "no embedded message with id {similar_id} (the message may not \
                             exist, or it exists but is not yet embedded - run `pond embed`)"
                        ),
                    )));
                };
                vector
            } else {
                let backend = load_embedder(embedder).await?;
                embed_query(backend.as_ref(), &plan.query)?
            };
            let vector_raw = store
                .vector_search(&vector, plan.vector_pool, &plan.filter, Some(search))
                .await
                .map_err(map_storage)?;
            normalize_vector(vector_raw)
        }
    };

    if candidates.is_empty() {
        return Ok(empty_response());
    }

    // Fetch metadata for all candidates in one store call.
    let keys = candidates
        .iter()
        .map(|candidate| MessageKey {
            session_id: candidate.session_id.clone(),
            message_id: candidate.message_id.clone(),
        })
        .collect::<Vec<_>>();
    let metas = store
        .message_metas_by_keys(&keys)
        .await
        .map_err(map_storage)?;
    let meta_index = metas
        .iter()
        .map(|meta| ((meta.session_id.as_str(), meta.message_id.as_str()), meta))
        .collect::<std::collections::HashMap<_, _>>();

    let mut scored = Vec::with_capacity(candidates.len());
    for candidate in candidates {
        // Candidates without metadata are dropped silently.
        let Some(meta) =
            meta_index.get(&(candidate.session_id.as_str(), candidate.message_id.as_str()))
        else {
            continue;
        };
        let score = candidate.base_score;
        if score < plan.min_score {
            continue;
        }
        scored.push(ScoredHit {
            meta: (*meta).clone(),
            score,
        });
    }
    // Highest score first; ties (and incomparable scores) fall back to session id,
    // then message id, so the order is deterministic.
    scored.sort_by(|left, right| {
        right
            .score
            .partial_cmp(&left.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| left.meta.session_id.cmp(&right.meta.session_id))
            .then_with(|| left.meta.message_id.cmp(&right.meta.message_id))
    });

    let matched_total = scored.len();
    let sessions = build_sessions(store, &scored, &plan.query).await?;
    page_sessions(sessions, matched_total, &plan)
}
1258
1259 async fn resolve_effective_mode(
1264 store: &Store,
1265 override_mode: Option<SearchMode>,
1266 ) -> Result<SearchMode, ErrorEnvelope> {
1267 if let Some(mode) = override_mode {
1268 return Ok(mode);
1269 }
1270 let has = store.has_embeddings().await.map_err(map_storage)?;
1271 Ok(if has {
1272 SearchMode::Hybrid
1273 } else {
1274 SearchMode::Fts
1275 })
1276 }
1277
1278 async fn load_embedder(
1282 embedder: &LazyEmbedder,
1283 ) -> Result<std::sync::Arc<dyn Embedder>, ErrorEnvelope> {
1284 embedder.get().await.map_err(|error| {
1285 map_error(crate::Error::internal(format!(
1286 "embedder load failed: {error}"
1287 )))
1288 })
1289 }
1290
1291 pub fn plan_search(
1292 request: SearchRequest,
1293 mode: SearchMode,
1294 ) -> Result<SearchPlan, ErrorEnvelope> {
1295 validate_protocol(request.protocol_version)?;
1296
1297 let _ns = super::resolve_namespace(request.namespace.as_deref())?;
1298
1299 let cursor = match request.cursor.as_deref() {
1300 Some(raw) => Some(decode_search_cursor(raw)?),
1301 None => None,
1302 };
1303 let (query_raw, similar_raw, filters, offset) = match cursor {
1304 Some(cursor) => (
1305 cursor.query,
1306 cursor.similar_to,
1307 cursor.filters,
1308 cursor.offset,
1309 ),
1310 None => (request.query, request.similar_to, request.filters, 0),
1311 };
1312 let query = query_raw.trim().to_owned();
1313 let similar_to = similar_raw
1314 .as_ref()
1315 .map(|id| id.trim().to_owned())
1316 .filter(|id| !id.is_empty());
1317 if similar_to.is_none() && query.is_empty() {
1318 return Err(map_error(crate::Error::validation_field(
1319 "query must be non-empty after trim",
1320 "query",
1321 Some(serde_json::json!(query_raw)),
1322 Some("non-empty string after trim, or pass `similar_to`".to_owned()),
1323 )));
1324 }
1325 if request.limit == 0 {
1326 return Err(map_error(crate::Error::validation_field(
1327 "limit must be at least 1",
1328 "limit",
1329 Some(serde_json::json!(request.limit)),
1330 Some("integer >= 1".to_owned()),
1331 )));
1332 }
1333 let limit = request.limit.min(LIMIT_CAP);
1334 let min_score = filters.min_score;
1335 let filter = build_filter(&filters)?;
1336 let pool = limit.saturating_mul(5).max(50);
1339 Ok(SearchPlan {
1340 mode,
1341 query,
1342 similar_to,
1343 filter,
1344 filters,
1345 pool,
1346 vector_pool: pool.saturating_mul(2),
1347 limit,
1348 offset,
1349 min_score,
1350 })
1351 }
1352
/// Identifies which retriever produced a ranked list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetrieverKind {
    /// Vector (embedding) retrieval.
    Vector,
    /// Full-text retrieval.
    Fts,
}

impl RetrieverKind {
    /// Lowercase name used when reporting which retrievers matched a hit.
    fn as_wire(self) -> &'static str {
        match self {
            RetrieverKind::Fts => "fts",
            RetrieverKind::Vector => "vector",
        }
    }
}
1367
/// One retriever's result list, as consumed by [`fuse_arms`].
pub struct RankedList {
    /// Retriever that produced `entries`; its wire name is appended to
    /// `FusedHit::matched_via` for every session this list contributes to.
    pub retriever: RetrieverKind,
    /// `(message key, raw score)` pairs. `fuse_arms` min-max normalizes the
    /// raw scores within this list and keeps only the first entry it
    /// encounters per session root.
    pub entries: Vec<(MessageKey, f64)>,
    /// Multiplier applied to each normalized score before it is added to the
    /// fused total.
    pub weight: f64,
}
1379
1380 fn wire_mode_to_internal(wire: crate::wire::SearchModeWire) -> SearchMode {
1383 match wire {
1384 crate::wire::SearchModeWire::Fts => SearchMode::Fts,
1385 crate::wire::SearchModeWire::Vector => SearchMode::Vector,
1386 crate::wire::SearchModeWire::Hybrid => SearchMode::Hybrid,
1387 }
1388 }
1389
/// One fused result produced by [`fuse_arms`]: a single entry per session root.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedHit {
    /// Key of the first message seen for this session root across the lists,
    /// in list order.
    pub key: MessageKey,
    /// Sum of the weighted, normalized contributions from every list that
    /// matched this session root.
    pub score: f64,
    /// Wire names of the retrievers that contributed, in list order.
    pub matched_via: Vec<String>,
}
1397
/// Returns the part of `session_id` before its first `/`, or the whole id
/// when it contains no `/`.
fn session_root(session_id: &str) -> &str {
    session_id
        .split_once('/')
        .map_or(session_id, |(root, _rest)| root)
}
1408
1409 pub fn fuse_arms(lists: &[RankedList]) -> Vec<FusedHit> {
1432 let mut merged: std::collections::HashMap<String, (f64, Vec<String>, MessageKey)> =
1433 std::collections::HashMap::new();
1434 for list in lists {
1435 if list.entries.is_empty() {
1436 continue;
1437 }
1438 let mut lo = f64::INFINITY;
1447 let mut hi = f64::NEG_INFINITY;
1448 for (_, raw) in &list.entries {
1449 if *raw < lo {
1450 lo = *raw;
1451 }
1452 if *raw > hi {
1453 hi = *raw;
1454 }
1455 }
1456 let range = hi - lo;
1457 let mut seen_in_arm: std::collections::HashSet<String> =
1462 std::collections::HashSet::new();
1463 for (key, raw) in &list.entries {
1464 let root = session_root(&key.session_id).to_owned();
1465 if !seen_in_arm.insert(root.clone()) {
1466 continue;
1467 }
1468 let norm = if range > 0.0 { (raw - lo) / range } else { 0.0 };
1469 let contribution = list.weight * norm;
1470 let entry = merged
1471 .entry(root)
1472 .or_insert_with(|| (0.0, Vec::new(), key.clone()));
1473 entry.0 += contribution;
1474 entry.1.push(list.retriever.as_wire().to_owned());
1475 }
1476 }
1477 let mut hits = merged
1478 .into_values()
1479 .map(|(score, matched_via, key)| FusedHit {
1480 key,
1481 score,
1482 matched_via,
1483 })
1484 .collect::<Vec<_>>();
1485 hits.sort_by(|left, right| {
1486 right
1487 .score
1488 .partial_cmp(&left.score)
1489 .unwrap_or(std::cmp::Ordering::Equal)
1490 .then_with(|| left.key.cmp(&right.key))
1491 });
1492 hits
1493 }
1494
/// Minimum length, in chars, for a query term to count as "informative" when
/// `query_snippet` picks the term to centre the snippet on. Shorter terms are
/// only used when no term reaches this length.
const ANCHOR_MIN_TERM_CHARS: usize = 4;
1500
1501 pub fn hit_payload(text: &str, query: &str) -> String {
1506 let chars_len = text.chars().count();
1507 if chars_len <= HIT_SNIPPET_CHARS {
1508 return text.to_owned();
1509 }
1510 query_snippet(text, query)
1511 }
1512
1513 fn query_snippet(text: &str, query: &str) -> String {
1529 let lower_text = text.to_lowercase();
1530 let terms: Vec<String> = query
1531 .split_whitespace()
1532 .filter(|term| !term.is_empty())
1533 .map(str::to_lowercase)
1534 .collect();
1535 let any_informative = terms
1536 .iter()
1537 .any(|term| term.chars().count() >= ANCHOR_MIN_TERM_CHARS);
1538 let hit = terms
1539 .iter()
1540 .filter(|term| !any_informative || term.chars().count() >= ANCHOR_MIN_TERM_CHARS)
1541 .filter_map(|term| lower_text.find(term.as_str()))
1542 .min();
1543 let chars: Vec<char> = text.chars().collect();
1544 let center = hit
1548 .map(|byte| lower_text[..byte].chars().count())
1549 .unwrap_or(0);
1550 let half = HIT_SNIPPET_CHARS / 2;
1551 let start = center.saturating_sub(half);
1552 let end = (start + HIT_SNIPPET_CHARS).min(chars.len());
1553 let start = end.saturating_sub(HIT_SNIPPET_CHARS);
1554 let mut snippet = String::new();
1558 if start > 0 {
1559 snippet.push_str(&format!("[{start} chars before] "));
1560 }
1561 snippet.extend(&chars[start..end]);
1562 if end < chars.len() {
1563 snippet.push_str(&format!(
1564 " [+{} more chars; pond_get for full]",
1565 chars.len() - end
1566 ));
1567 }
1568 snippet
1569 }
1570
/// A retrieved message with its per-retriever normalized score, as produced
/// by `normalize_fts` and `normalize_vector`.
struct Candidate {
    /// Session the message belongs to.
    session_id: String,
    /// Id of the message.
    message_id: String,
    /// Score after the retriever-specific normalization.
    base_score: f64,
}
1576
/// A message's stored metadata paired with its search score.
struct ScoredHit {
    /// Stored metadata of the matched message.
    meta: MessageMeta,
    /// Score before `normalize_score` is applied for the response.
    score: f64,
}
1581
1582 impl ScoredHit {
1583 fn to_search_result(
1584 &self,
1585 query: &str,
1586 summaries: &HashMap<(String, String), Vec<PartSummary>>,
1587 ) -> Result<SearchResult, ErrorEnvelope> {
1588 let text = hit_payload(&self.meta.search_text, query);
1589 let role = match self.meta.role.as_str() {
1590 "system" => Role::System,
1591 "user" => Role::User,
1592 "assistant" => Role::Assistant,
1593 "tool" => Role::Tool,
1594 other => {
1595 return Err(map_error(crate::Error::internal(format!(
1596 "stored message has unknown role: {other}"
1597 ))));
1598 }
1599 };
1600 let parts_summary = if matches!(role, Role::User) {
1603 summaries
1604 .get(&(self.meta.session_id.clone(), self.meta.message_id.clone()))
1605 .cloned()
1606 .unwrap_or_default()
1607 } else {
1608 Vec::new()
1609 };
1610 Ok(SearchResult {
1611 message_id: self.meta.message_id.clone(),
1612 role,
1613 timestamp: self.meta.timestamp,
1614 text,
1615 score: normalize_score(self.score),
1616 parts_summary,
1617 })
1618 }
1619 }
1620
1621 fn normalize_score(score: f64) -> f64 {
1622 (score / SCORE_DENOMINATOR).clamp(0.0, 1.0)
1623 }
1624
1625 fn normalize_fts(hits: Vec<(MessageKey, f32)>) -> Vec<Candidate> {
1626 let max = hits.iter().map(|(_, score)| *score).fold(0.0_f32, f32::max);
1627 hits.into_iter()
1628 .map(|(key, score)| Candidate {
1629 session_id: key.session_id,
1630 message_id: key.message_id,
1631 base_score: if max > 0.0 {
1632 f64::from(score / max)
1633 } else {
1634 0.0
1635 },
1636 })
1637 .collect()
1638 }
1639
1640 fn normalize_vector(hits: Vec<(MessageKey, f32)>) -> Vec<Candidate> {
1641 let n = hits.len() as f64;
1642 hits.into_iter()
1643 .enumerate()
1644 .map(|(idx, (key, _))| Candidate {
1645 session_id: key.session_id,
1646 message_id: key.message_id,
1647 base_score: if n > 0.0 { 1.0 - (idx as f64 / n) } else { 0.0 },
1648 })
1649 .collect()
1650 }
1651
1652 fn embed_query(embedder: &dyn Embedder, query: &str) -> Result<Vec<f32>, ErrorEnvelope> {
1653 let prompt = format_query(query);
1654 let vectors =
1658 tokio::task::block_in_place(|| embedder.embed(&[prompt])).map_err(|error_value| {
1659 map_error(crate::Error::internal(format!(
1660 "failed to embed query: {error_value}"
1661 )))
1662 })?;
1663 vectors.into_iter().next().ok_or_else(|| {
1664 map_error(crate::Error::internal(
1665 "embedder returned no vector for query",
1666 ))
1667 })
1668 }
1669
1670 async fn build_sessions(
1671 store: &Store,
1672 scored: &[ScoredHit],
1673 query: &str,
1674 ) -> Result<Vec<SearchSession>, ErrorEnvelope> {
1675 use std::collections::BTreeMap;
1676
1677 struct Acc {
1678 project: String,
1679 source_agent: String,
1680 matched_count: usize,
1681 matches: Vec<SearchResult>,
1682 }
1683 let mut user_ids_by_session: BTreeMap<String, Vec<String>> = BTreeMap::new();
1687 for hit in scored {
1688 if hit.meta.role == "user" {
1689 user_ids_by_session
1690 .entry(hit.meta.session_id.clone())
1691 .or_default()
1692 .push(hit.meta.message_id.clone());
1693 }
1694 }
1695 let mut summaries: HashMap<(String, String), Vec<PartSummary>> = HashMap::new();
1696 for (session_id, message_ids) in &user_ids_by_session {
1697 for (key, parts) in store
1698 .summary_parts_for_messages(session_id, message_ids)
1699 .await
1700 .map_err(map_storage)?
1701 {
1702 summaries.insert(
1703 key,
1704 parts
1705 .iter()
1706 .filter_map(|part| PartSummary::for_kind(&part.kind))
1707 .collect(),
1708 );
1709 }
1710 }
1711
1712 let mut groups: BTreeMap<String, Acc> = BTreeMap::new();
1713 for hit in scored {
1714 let root = session_root(&hit.meta.session_id).to_owned();
1715 let entry = groups.entry(root).or_insert_with(|| Acc {
1716 project: hit.meta.project.clone(),
1717 source_agent: hit.meta.source_agent.clone(),
1718 matched_count: 0,
1719 matches: Vec::new(),
1720 });
1721 entry.matched_count += 1;
1722 entry.matches.push(hit.to_search_result(query, &summaries)?);
1723 }
1724
1725 let session_ids = groups.keys().cloned().collect::<Vec<_>>();
1726 let counts = store
1727 .session_message_counts(&session_ids)
1728 .await
1729 .map_err(map_storage)?;
1730
1731 let mut result = groups
1732 .into_iter()
1733 .map(|(session_id, mut acc)| {
1734 acc.matches.sort_by(|left, right| {
1735 right
1736 .score
1737 .partial_cmp(&left.score)
1738 .unwrap_or(std::cmp::Ordering::Equal)
1739 .then_with(|| left.message_id.cmp(&right.message_id))
1740 });
1741 acc.matches.truncate(MAX_MATCHES_PER_SESSION);
1742 SearchSession {
1743 session_messages_count: counts.get(&session_id).copied().unwrap_or_default(),
1744 session_id,
1745 project: acc.project,
1746 source_agent: acc.source_agent,
1747 matched_message_count: acc.matched_count,
1748 matches: acc.matches,
1749 }
1750 })
1751 .collect::<Vec<_>>();
1752 result.sort_by(|left, right| {
1753 let left_score = left
1754 .matches
1755 .first()
1756 .map(|hit| hit.score)
1757 .unwrap_or_default();
1758 let right_score = right
1759 .matches
1760 .first()
1761 .map(|hit| hit.score)
1762 .unwrap_or_default();
1763 right_score
1764 .partial_cmp(&left_score)
1765 .unwrap_or(std::cmp::Ordering::Equal)
1766 .then_with(|| left.session_id.cmp(&right.session_id))
1767 });
1768 Ok(result)
1769 }
1770
1771 fn page_sessions(
1772 sessions: Vec<SearchSession>,
1773 matched_total: usize,
1774 plan: &SearchPlan,
1775 ) -> Result<SearchResponse, ErrorEnvelope> {
1776 if plan.offset >= sessions.len() {
1777 return Ok(SearchResponse {
1778 sessions: Vec::new(),
1779 matched_total,
1780 has_more: false,
1781 next_cursor: None,
1782 });
1783 }
1784
1785 let mut emitted = Vec::new();
1786 let mut used_bytes = 0usize;
1787 for session in sessions.iter().skip(plan.offset) {
1788 if emitted.len() >= plan.limit {
1789 break;
1790 }
1791 let bytes = serde_json::to_string(session)
1792 .map_err(|error| {
1793 map_error(crate::Error::internal(format!(
1794 "failed to size search response session: {error}"
1795 )))
1796 })?
1797 .len();
1798 if !emitted.is_empty() && used_bytes.saturating_add(bytes) > SEARCH_BUDGET_BYTES {
1799 break;
1800 }
1801 used_bytes = used_bytes.saturating_add(bytes);
1802 emitted.push(session.clone());
1803 }
1804
1805 let next_offset = plan.offset + emitted.len();
1806 let has_more = next_offset < sessions.len();
1807 let next_cursor = has_more.then(|| {
1808 encode_search_cursor(&SearchCursor {
1809 query: plan.query.clone(),
1810 similar_to: plan.similar_to.clone(),
1811 filters: plan.filters.clone(),
1812 offset: next_offset,
1813 })
1814 });
1815
1816 Ok(SearchResponse {
1817 sessions: emitted,
1818 matched_total,
1819 has_more,
1820 next_cursor,
1821 })
1822 }
1823
1824 pub fn build_filter(filters: &SearchFilters) -> Result<Predicate, ErrorEnvelope> {
1828 let mut clauses = Vec::new();
1829
1830 match &filters.project {
1831 None => {}
1832 Some(ProjectFilter::Contains(value)) => {
1833 clauses.push(Predicate::LikeContains("project", value.clone()));
1834 }
1835 Some(ProjectFilter::Regex(pattern)) => {
1836 clauses.push(Predicate::Regex("project", pattern.clone()));
1837 }
1838 }
1839
1840 if let Some(session_id) = &filters.session_id {
1841 clauses.push(Predicate::Eq("session_id", session_id.clone().into()));
1842 }
1843 if let Some(source_agent) = &filters.source_agent {
1844 clauses.push(Predicate::Eq("source_agent", source_agent.clone().into()));
1845 }
1846 if let Some(role) = &filters.role {
1847 if !matches!(role.as_str(), "user" | "assistant" | "system" | "tool") {
1848 return Err(map_error(crate::Error::validation_field(
1849 format!(
1850 "filters.role must be one of: user, assistant, system, tool; got {role}"
1851 ),
1852 "filters.role",
1853 Some(serde_json::json!(role)),
1854 Some("one of: user, assistant, system, tool".to_owned()),
1855 )));
1856 }
1857 clauses.push(Predicate::Eq("role", role.clone().into()));
1858 }
1859 if let Some(from_date) = &filters.from_date {
1860 clauses.push(Predicate::Gte(
1861 "timestamp",
1862 ScalarValue::Raw(date_bound(from_date, "filters.from_date", false)?),
1863 ));
1864 }
1865 if let Some(to_date) = &filters.to_date {
1866 clauses.push(Predicate::Lte(
1867 "timestamp",
1868 ScalarValue::Raw(date_bound(to_date, "filters.to_date", true)?),
1869 ));
1870 }
1871
1872 if !filters.include_subagents
1876 && filters.session_id.is_none()
1877 && filters.source_agent.is_none()
1878 {
1879 clauses.push(Predicate::Not(Box::new(Predicate::LikeContains(
1880 "source_agent",
1881 "/".to_owned(),
1882 ))));
1883 }
1884
1885 Ok(Predicate::And(clauses))
1886 }
1887
1888 fn date_bound(date: &str, field: &str, end_of_day: bool) -> Result<String, ErrorEnvelope> {
1891 NaiveDate::parse_from_str(date, "%Y-%m-%d").map_err(|_| {
1892 map_error(crate::Error::validation_field(
1893 format!("{field} must be in YYYY-MM-DD format; got {date}"),
1894 field,
1895 Some(serde_json::json!(date)),
1896 Some("YYYY-MM-DD".to_owned()),
1897 ))
1898 })?;
1899 let time = if end_of_day { "23:59:59" } else { "00:00:00" };
1900 Ok(format!("timestamp '{date} {time}'"))
1901 }
1902
/// Builds a response with no sessions, a zero match total, and no
/// continuation cursor.
fn empty_response() -> SearchResponse {
    SearchResponse {
        sessions: Vec::new(),
        matched_total: 0,
        has_more: false,
        next_cursor: None,
    }
}
1911
1912 #[cfg(test)]
1913 mod fusion_helpers_tests {
1914 #![allow(clippy::expect_used, clippy::unwrap_used)]
1915
1916 use super::*;
1917
/// `session_root` keeps an id without `/` as-is and otherwise returns
/// everything before the first `/`.
#[test]
fn session_root_strips_agent_suffix_for_claude_code_subagents() {
    let cases = [
        (
            "94a50f23-1234-5678-9abc-def012345678",
            "94a50f23-1234-5678-9abc-def012345678",
        ),
        (
            "94a50f23-1234-5678-9abc-def012345678/agent-abc123",
            "94a50f23-1234-5678-9abc-def012345678",
        ),
        ("root/a/b", "root"),
    ];
    for (session_id, expected_root) in cases {
        assert_eq!(session_root(session_id), expected_root);
    }
}
1931
/// Within one list only the first entry per session root counts, while a root
/// matched by both lists is credited by both.
#[test]
fn fuse_arms_dedupes_intra_arm_by_session_root_and_credits_cross_arm() {
    let mk = |sid: &str, mid: &str| crate::sessions::MessageKey {
        session_id: sid.to_owned(),
        message_id: mid.to_owned(),
    };
    // FTS raw scores span 5.0..=10.0. `session-A` is credited once via msg-1
    // (normalized 1.0); msg-2 and the `session-A/agent-x` entry share its root
    // and are skipped. `session-B` normalizes to (6 - 5) / 5 = 0.2.
    let fts = RankedList {
        retriever: RetrieverKind::Fts,
        entries: vec![
            (mk("session-A", "msg-1"), 10.0),
            (mk("session-A", "msg-2"), 9.0),
            (mk("session-B", "msg-3"), 6.0),
            (mk("session-A/agent-x", "msg-4"), 5.0),
        ],
        weight: 0.135,
    };
    // Vector raw scores span 0.6..=0.9: `session-B` normalizes to 1.0 and
    // `session-A` to 0.0.
    let vec_arm = RankedList {
        retriever: RetrieverKind::Vector,
        entries: vec![
            (mk("session-B", "msg-7"), 0.9),
            (mk("session-A", "msg-9"), 0.6),
        ],
        weight: 1.0,
    };
    let merged = fuse_arms(&[fts, vec_arm]);
    assert_eq!(merged.len(), 2);
    // session-B: 0.135 * 0.2 + 1.0 * 1.0 = 1.027, ahead of session-A at 0.135.
    assert_eq!(merged[0].key.session_id, "session-B");
    // The representative key is the first one seen for the root, which comes
    // from the FTS list because it is passed first.
    assert_eq!(merged[0].key.message_id, "msg-3");
    assert_eq!(merged[0].matched_via, vec!["fts", "vector"]);
    assert_eq!(merged[1].key.session_id, "session-A");
    assert_eq!(merged[1].key.message_id, "msg-1");
    assert_eq!(merged[1].matched_via, vec!["fts", "vector"]);
}
1979
/// A list whose raw scores are all equal has zero range and therefore
/// contributes nothing, leaving the other list to decide the order.
#[test]
fn fuse_arms_collapses_degenerate_tied_arm_to_zero_contribution() {
    let mk = |sid: &str, mid: &str| crate::sessions::MessageKey {
        session_id: sid.to_owned(),
        message_id: mid.to_owned(),
    };
    // Both FTS entries score 1.0, so the range is 0 and each normalizes to 0.
    let fts = RankedList {
        retriever: RetrieverKind::Fts,
        entries: vec![(mk("session-A", "a"), 1.0), (mk("session-B", "b"), 1.0)],
        weight: 0.135,
    };
    // Vector scores span 0.3..=0.9: session-A normalizes to 1.0, session-B to 0.0.
    let vec_arm = RankedList {
        retriever: RetrieverKind::Vector,
        entries: vec![(mk("session-A", "a"), 0.9), (mk("session-B", "b"), 0.3)],
        weight: 1.0,
    };
    let merged = fuse_arms(&[fts, vec_arm]);
    assert_eq!(merged[0].key.session_id, "session-A");
    // Only the vector list contributes: 1.0 for session-A, 0.0 for session-B.
    assert!((merged[0].score - 1.0).abs() < 1e-9);
    assert!(merged[1].score.abs() < 1e-9);
}
2007 }
2008}
2009
2010pub use search_handler::{
2011 FusedHit, RankedList, RetrieverKind, SearchMode, SearchPlan, build_filter, explain_search_plan,
2012 fuse_arms, hit_payload, plan_search, pond_search,
2013};
2014
2015#[cfg(test)]
2016mod tests {
2017 #![allow(clippy::expect_used, clippy::unwrap_used)]
2018
2019 use super::*;
2020 use crate::wire::{ProjectFilter, SearchFilters, SearchRequest};
2021 use chrono::Utc;
2022
/// Builds a baseline request for `query`: current protocol version, the
/// `local` namespace, default filters, limit 20, and no mode override,
/// `similar_to` or cursor. Tests mutate the fields they care about.
fn search_request(query: &str) -> SearchRequest {
    SearchRequest {
        protocol_version: crate::PROTOCOL_VERSION,
        namespace: Some("local".to_owned()),
        query: query.to_owned(),
        mode_override: None,
        similar_to: None,
        filters: SearchFilters::default(),
        limit: 20,
        cursor: None,
    }
}
2035
/// Shorthand for constructing a `MessageKey` from a session id and a message id.
fn key(session: &str, id: &str) -> crate::sessions::MessageKey {
    crate::sessions::MessageKey {
        session_id: session.to_owned(),
        message_id: id.to_owned(),
    }
}
2042
/// Sessions matched by both retrievers accumulate both contributions, and
/// `matched_via` lists exactly the retrievers that matched, in list order.
#[test]
fn fuse_arms_fuses_retrievers_and_reports_provenance() {
    let lists = [
        // Vector scores span 0.5..=0.9: a -> 1.0, b -> 0.5, c -> 0.0.
        RankedList {
            retriever: RetrieverKind::Vector,
            entries: vec![
                (key("session-a", "a"), 0.9),
                (key("session-b", "b"), 0.7),
                (key("session-c", "c"), 0.5),
            ],
            weight: 1.0,
        },
        // FTS scores span 4.0..=10.0: b -> 1.0, a -> 4/6, d -> 0.0, each
        // scaled by the 0.135 weight.
        RankedList {
            retriever: RetrieverKind::Fts,
            entries: vec![
                (key("session-b", "b"), 10.0),
                (key("session-a", "a"), 8.0),
                (key("session-d", "d"), 4.0),
            ],
            weight: 0.135,
        },
    ];
    let merged = fuse_arms(&lists);

    // a = 1.0 + 0.135 * 4/6 = 1.09 beats b = 0.5 + 0.135 = 0.635.
    assert_eq!(merged[0].key.session_id, "session-a");
    assert_eq!(merged[1].key.session_id, "session-b");
    assert_eq!(merged[0].matched_via, vec!["vector", "fts"]);
    assert!(merged[0].score > merged[1].score);

    // Sessions seen by a single retriever report only that retriever.
    let c = merged
        .iter()
        .find(|hit| hit.key.session_id == "session-c")
        .unwrap();
    assert_eq!(c.matched_via, vec!["vector"]);
    let d = merged
        .iter()
        .find(|hit| hit.key.session_id == "session-d")
        .unwrap();
    assert_eq!(d.matched_via, vec!["fts"]);
}
2093
/// Text that fits within the snippet size comes back unchanged.
#[test]
fn hit_payload_returns_short_text_in_full() {
    let body = "a short message body";
    assert_eq!(
        hit_payload(body, "message"),
        body,
        "small text is returned as-is"
    );
}
2100
/// Long text is cut down to a window that still contains the query match,
/// found case-insensitively.
#[test]
fn hit_payload_windows_long_text_around_the_query_term() {
    // 2000 filler chars, the needle, then 394 more filler chars.
    let mut body = "a".repeat(2000);
    body.push_str("NEEDLE");
    body.push_str(&"b".repeat(394));

    let snippet = hit_payload(&body, "needle");
    assert!(
        snippet.contains("NEEDLE"),
        "text is the match-windowed snippet: {snippet}"
    );
    let snippet_chars = snippet.chars().count();
    assert!(
        snippet_chars <= 600 + 64,
        "snippet window is bounded by HIT_SNIPPET_CHARS plus markers: {snippet_chars}"
    );
}
2118
/// Offsets found in the lowercased text must not be used as byte offsets into
/// the original text.
#[test]
fn hit_payload_snippet_survives_case_folding_that_changes_byte_length() {
    // `İ` (U+0130) lowercases to `i` followed by a combining dot, so the
    // lowercased text has a different byte length and char count than the
    // original in front of the match.
    let body = format!("İÉÉÉ{}", "a".repeat(2100));
    let text = hit_payload(&body, "ééé");
    assert!(
        text.contains("ÉÉÉ"),
        "snippet windows on the matched term: {text}"
    );
}
2131
/// `restore_lineage` fails for a root whose descendants nest two levels deep,
/// but succeeds for a root with a single level of children.
#[tokio::test]
async fn restore_lineage_rejects_a_graph_nesting_deeper_than_one_level() {
    use crate::adapter::Extracted;
    use crate::sessions::Store;
    use crate::wire::{ProviderOptions, Session};
    use tempfile::TempDir;

    // Builds a session with the given id and optional parent session id.
    let session = |id: &str, parent: Option<&str>| Session {
        id: id.to_owned(),
        parent_session_id: parent.map(str::to_owned),
        parent_message_id: None,
        source_agent: "claude-code".to_owned(),
        created_at: Utc::now(),
        project: Extracted::from_test_value("/tmp/pond".to_owned()),
        options: ProviderOptions::new(),
    };

    let dir = TempDir::new().unwrap();
    let store = Store::open_local(dir.path()).await.unwrap();
    // Chain a <- b <- c: `c` is a grandchild of `a`.
    store
        .upsert_sessions(&[
            session("a", None),
            session("b", Some("a")),
            session("c", Some("b")),
        ])
        .await
        .unwrap();

    // Starting from `a`, the chain is two levels deep and must be rejected.
    let err = restore_lineage(&store, "a").await.unwrap_err();
    assert!(
        err.to_string().contains("one subagent level"),
        "expected the deeper-graph error, got: {err}"
    );

    // Starting from `b`, only one level (`c`) sits below it.
    let lineage = restore_lineage(&store, "b").await.unwrap();
    let ids: Vec<&str> = lineage.iter().map(|s| s.session.id.as_str()).collect();
    assert_eq!(ids, ["b", "c"]);
}
2173
/// Every populated filter becomes a clause, and the subagent exclusion
/// appears only for unscoped searches that did not opt in to subagents.
#[test]
fn build_filter_pushes_down_each_predicate_and_handles_empty() {
    let filters = SearchFilters {
        project: Some(ProjectFilter::Contains("/Users/me/pond".to_owned())),
        session_id: Some("01HXY".to_owned()),
        source_agent: Some("claude-code".to_owned()),
        role: Some("assistant".to_owned()),
        from_date: Some("2026-01-01".to_owned()),
        to_date: Some("2026-05-01".to_owned()),
        min_score: 0.0,
        include_subagents: false,
    };
    let sql = build_filter(&filters).unwrap().to_lance();
    assert!(sql.contains("project LIKE '%/Users/me/pond%'"));
    assert!(sql.contains("session_id = '01HXY'"));
    assert!(sql.contains("source_agent = 'claude-code'"));
    assert!(sql.contains("role = 'assistant'"));
    assert!(sql.contains("timestamp >="));
    assert!(sql.contains("timestamp <="));
    // A session id / source agent filter suppresses the subagent exclusion.
    assert!(!sql.contains("NOT ("));

    // Default filters: the subagent exclusion is the only clause.
    assert_eq!(
        build_filter(&SearchFilters::default()).unwrap().to_lance(),
        "NOT (source_agent LIKE '%/%' ESCAPE '\\')",
    );
    // Opting in to subagents leaves no clause at all.
    assert_eq!(
        build_filter(&SearchFilters {
            include_subagents: true,
            ..SearchFilters::default()
        })
        .unwrap()
        .to_lance(),
        "",
    );
}
2210
/// An unknown role and a date that is not `YYYY-MM-DD` are both rejected.
#[test]
fn build_filter_rejects_bad_role_and_date() {
    let rejected = [
        SearchFilters {
            role: Some("wizard".to_owned()),
            ..SearchFilters::default()
        },
        SearchFilters {
            from_date: Some("01-01-2026".to_owned()),
            ..SearchFilters::default()
        },
    ];
    for filters in &rejected {
        assert!(build_filter(filters).is_err());
    }
}
2225
/// A `Contains` project filter must match `_` literally rather than as a
/// LIKE wildcard.
#[test]
fn build_filter_contains_escapes_like_wildcards() {
    let filters = SearchFilters {
        project: Some(ProjectFilter::Contains("/Users/me/my_project".to_owned())),
        ..SearchFilters::default()
    };
    let sql = build_filter(&filters).unwrap().to_lance();
    assert!(
        sql.contains(r"my\_project"),
        "underscore must be escaped: {sql}"
    );
    assert!(
        sql.contains(r"ESCAPE '\'"),
        "predicate must declare the escape char: {sql}"
    );
}
2244
/// Covers query trimming, limit capping, pool sizing and filter pushdown.
#[test]
fn plan_search_shapes_request_for_each_planning_input() {
    // Oversized limit: capped at 200, pool = 200 * 5, vector pool = pool * 2.
    let mut request = search_request(" vector memory ");
    request.limit = 500;
    request.filters.min_score = 0.42;
    let plan = plan_search(request, SearchMode::Hybrid).unwrap();
    assert_eq!(plan.mode, SearchMode::Hybrid);
    assert_eq!(plan.query, "vector memory");
    assert_eq!(plan.limit, 200);
    assert_eq!(plan.pool, 1000);
    assert_eq!(plan.vector_pool, 2000);
    assert_eq!(plan.min_score, 0.42);

    // Tiny limit: the pool never drops below 50.
    let mut request = search_request("tiny pool");
    request.limit = 1;
    let plan = plan_search(request, SearchMode::Fts).unwrap();
    assert_eq!(plan.mode, SearchMode::Fts);
    assert_eq!(plan.limit, 1);
    assert_eq!(plan.pool, 50);
    assert_eq!(plan.vector_pool, 100);

    // Filters on the request end up in the plan's predicate.
    let mut request = search_request("filtered");
    request.filters.project = Some(ProjectFilter::Contains("/Users/me/pond".to_owned()));
    request.filters.role = Some("assistant".to_owned());
    let plan = plan_search(request, SearchMode::Fts).unwrap();
    let sql = plan.filter.to_lance();
    assert!(sql.contains("project LIKE"));
    assert!(sql.contains("role = 'assistant'"));
}
2276
/// A blank query, a zero limit and an unknown namespace are each rejected
/// with an error that names the offending input.
#[test]
fn plan_search_rejects_invalid_composition_before_execution() {
    // Whitespace-only query with no `similar_to`.
    let mut blank = search_request(" ");
    let error = plan_search(blank.clone(), SearchMode::Fts)
        .unwrap_err()
        .error;
    assert_eq!(error.code, crate::wire::ErrorCode::ValidationFailed);
    assert_eq!(error.details["field"], "query");

    // Valid query, zero limit.
    blank.query = "valid".to_owned();
    blank.limit = 0;
    let error = plan_search(blank.clone(), SearchMode::Fts)
        .unwrap_err()
        .error;
    assert_eq!(error.details["field"], "limit");

    // Valid query and limit, but a namespace other than the default one.
    blank.limit = 1;
    blank.namespace = Some("remote".to_owned());
    let error = plan_search(blank, SearchMode::Fts).unwrap_err().error;
    assert_eq!(error.code, crate::wire::ErrorCode::NamespaceUnknown);
    assert_eq!(error.details["namespace"], "remote");
}
2299}