1fn map_error(error: crate::Error) -> crate::wire::ErrorEnvelope {
5 error.into()
6}
7
/// Identifier of a namespace, stored as its ordered path segments.
///
/// An empty segment list denotes the root namespace (see [`NamespaceIdent::root`]).
#[derive(Debug, Clone)]
pub struct NamespaceIdent(pub Vec<String>);

impl NamespaceIdent {
    /// Returns the root namespace, i.e. an identifier with no segments.
    pub fn root() -> Self {
        Self(Vec::new())
    }

    /// Builds a table identifier: this namespace's segments followed by `table_name`.
    ///
    /// The namespace itself is left untouched; a freshly allocated vector is returned.
    pub fn as_table_id(&self, table_name: &str) -> Vec<String> {
        // Reserve the slot for the table name up front. Cloning the segments
        // and pushing afterwards would typically trigger a second allocation.
        let mut id = Vec::with_capacity(self.0.len() + 1);
        id.extend(self.0.iter().cloned());
        id.push(table_name.to_owned());
        id
    }
}
25
26pub fn resolve_namespace(
30 namespace: Option<&str>,
31) -> Result<NamespaceIdent, crate::wire::ErrorEnvelope> {
32 match namespace {
33 None | Some(crate::wire::DEFAULT_NAMESPACE) => Ok(NamespaceIdent::root()),
34 Some(other) => Err(map_error(crate::Error::namespace_unknown(other))),
35 }
36}
37
38fn map_storage(error: anyhow::Error) -> crate::wire::ErrorEnvelope {
39 if let Some(conflict) = error.downcast_ref::<crate::substrate::ConflictExhausted>() {
42 return map_error(crate::Error::Conflict {
43 attempts: conflict.attempts,
44 });
45 }
46 map_error(crate::Error::Storage(error))
47}
48
49mod ingest_handler {
50 use anyhow::Result;
51 use tokio_stream::StreamExt;
52
53 use crate::{
54 adapter::{Adapter, AdapterYield, SkipOracle, SkipReason},
55 sessions::{IngestEvent, IngestSummary, IngestValidator, OutcomeStatus, RowOutcome, Store},
56 wire::{
57 ErrorBody, ErrorCode, IngestEnvelope, IngestRequest, IngestResponse, IngestResult,
58 IngestStatus, validate_protocol,
59 },
60 };
61
62 use super::{map_error, map_storage};
63
/// Largest number of events a single `pond_ingest` request may carry; a
/// request with more events is rejected with a validation error.
64 pub const MAX_INGEST_EVENTS: usize = 1000;
66
/// Progress notification handed to the `on_event` callback of `ingest_adapter`.
67 #[derive(Debug, Clone)]
74 pub enum SyncEvent {
 /// Emitted once, before any adapter event is processed. `total` is the
 /// result of the adapter's `discover()` call, or `None` when that call failed.
75 Discovered { total: Option<usize> },
 /// One session (or one skipped / unreadable input) has reached its final status.
79 SessionDone(SessionOutcome),
 /// A batch of `count` inputs was skipped at once, all with the same `status`.
82 SkippedBulk { status: SyncStatus, count: usize },
84 }
85
/// Final report for a single session, carried by `SyncEvent::SessionDone`.
86 #[derive(Debug, Clone)]
88 pub struct SessionOutcome {
 /// Project of the session; `None` when an adapter error was reported while
 /// no session was in flight.
89 pub project: Option<String>,
 /// Session id; `None` when an adapter error was reported while no session
 /// was in flight.
91 pub session_id: Option<String>,
 /// Number of message events counted for the session (0 for skipped inputs).
94 pub messages: usize,
 /// How the session ended up.
97 pub status: SyncStatus,
98 }
99
/// Final status of a session or skipped input as reported by `ingest_adapter`.
100 #[derive(Debug, Clone)]
113 pub enum SyncStatus {
 /// The session row was not rejected and no event was dropped.
114 Ok,
 /// The session row was not rejected, but some of its events were dropped.
115 Partial {
 /// Number of events dropped for this session.
116 dropped_events: usize,
 /// Message of the first drop that was recorded, when one was available.
117 first_drop_reason: Option<String>,
120 },
 /// The input was skipped: either the adapter reported it as unsupported, or
 /// an adapter error occurred while no session was in flight.
121 Skipped {
122 reason: String,
123 },
 /// The session row itself came back with an error outcome.
124 Rejected {
125 reason: String,
126 },
 /// The adapter skipped the input with `SkipReason::Fresh`.
127 Fresh,
 /// The adapter skipped the input with `SkipReason::Empty`.
130 Empty,
134 }
135
/// Bookkeeping for the session whose events are currently being streamed by
/// `ingest_adapter`.
136 #[derive(Debug, Default)]
137 struct InFlight {
138 project: Option<String>,
139 session_id: String,
 /// Message events seen so far for this session.
140 messages: usize,
 /// Events of this session that were dropped so far.
141 dropped_events: usize,
 /// Reason recorded for the first dropped event, if any.
146 first_drop_reason: Option<String>,
 /// Event index under which the session event was pushed to the validator.
147 session_index: usize,
151 }
152
/// A session whose events have all been streamed and which is queued until
/// `drain_pending_dones` reports it. Mirrors the fields of `InFlight`.
153 #[derive(Debug)]
158 struct PendingDone {
159 project: Option<String>,
160 session_id: String,
161 messages: usize,
162 dropped_events: usize,
163 first_drop_reason: Option<String>,
 /// Event index of the session event; used to find the session's row outcome.
164 session_index: usize,
165 }
166
/// `ingest_adapter` flushes the validator as soon as its number of pending
/// substreams reaches this value.
167 const ADAPTER_FLUSH_BATCH: usize = 100;
174
/// Streams everything `adapter` yields into `store` and reports progress
/// through `on_event`.
///
/// `on_event` first receives one `SyncEvent::Discovered`, then one
/// `SessionDone` / `SkippedBulk` per session or skip as they are settled.
/// The returned `IngestSummary` aggregates the counters of the whole run.
///
/// # Errors
///
/// Errors from the validator's `push`, `flush` and `finish` calls are
/// propagated. Errors yielded by the adapter stream are not: they are logged
/// at debug level and counted as dropped events or skipped files.
175 pub async fn ingest_adapter<F>(
186 store: &Store,
187 adapter: &dyn Adapter,
188 oracle: &dyn SkipOracle,
189 mut on_event: F,
190 ) -> Result<IngestSummary>
191 where
192 F: FnMut(SyncEvent),
193 {
194 let mut summary = IngestSummary::default();
 // Snapshot of the truncation counter; the delta is stored in the summary at the end.
195 let truncations_before = crate::adapter::extract::truncated_values_count();
 // A failing `discover()` is only logged; the total is then reported as `None`.
196 let total = adapter
200 .discover()
201 .await
202 .map_err(|error| tracing::debug!(%error, "adapter discover failed"))
203 .ok();
204 on_event(SyncEvent::Discovered { total });
205
206 let mut events = adapter.events_with(oracle);
207 let mut validator = IngestValidator::default();
 // Index handed to `validator.push`; incremented once per pushed event.
208 let mut index = 0usize;
 // Session currently being streamed, if any.
212 let mut in_flight: Option<InFlight> = None;
 // Sessions that are complete but not yet reported via `drain_pending_dones`.
213 let mut pending_dones: std::collections::VecDeque<PendingDone> =
217 std::collections::VecDeque::new();
 // Timing accumulators, only used for the perf log line at the end.
218 let mut decode_total = std::time::Duration::ZERO;
223 let mut decode_count = 0u64;
224 let mut validator_total = std::time::Duration::ZERO;
225 let mut validator_count = 0u64;
226 let run_started = std::time::Instant::now();
227
228 loop {
 // Time spent waiting for the adapter stream is accounted as "decode".
 // The final poll that returns `None` is counted as well.
229 let decode_start = std::time::Instant::now();
230 let next = events.next().await;
231 decode_total += decode_start.elapsed();
232 decode_count += 1;
233 let event = match next {
234 Some(event) => event,
235 None => break,
236 };
237 match event {
 // A single skipped input: bump the matching counter and report it right away.
238 Ok(AdapterYield::Skipped {
239 session_id,
240 project,
241 reason,
242 }) => {
243 let status = match reason {
244 SkipReason::Fresh => {
245 summary.skipped_fresh += 1;
246 SyncStatus::Fresh
247 }
248 SkipReason::Empty => {
249 summary.skipped_empty += 1;
250 SyncStatus::Empty
251 }
252 SkipReason::Unsupported(reason) => {
253 summary.skipped_files += 1;
254 SyncStatus::Skipped { reason }
255 }
256 };
257 on_event(SyncEvent::SessionDone(SessionOutcome {
258 project,
259 session_id,
260 messages: 0,
261 status,
262 }));
263 }
 // Same as above, but for `count` inputs reported in one go.
264 Ok(AdapterYield::SkippedBatch { reason, count }) => {
265 let status = match reason {
266 SkipReason::Fresh => {
267 summary.skipped_fresh += count;
268 SyncStatus::Fresh
269 }
270 SkipReason::Empty => {
271 summary.skipped_empty += count;
272 SyncStatus::Empty
273 }
274 SkipReason::Unsupported(reason) => {
275 summary.skipped_files += count;
276 SyncStatus::Skipped { reason }
277 }
278 };
279 on_event(SyncEvent::SkippedBulk { status, count });
280 }
281 Ok(AdapterYield::Event(event)) => {
 // A new session event closes the previous session: queue it for reporting.
282 if matches!(&event, IngestEvent::Session(_))
287 && let Some(prev) = in_flight.take()
288 {
289 pending_dones.push_back(PendingDone {
290 project: prev.project,
291 session_id: prev.session_id,
292 messages: prev.messages,
293 dropped_events: prev.dropped_events,
294 first_drop_reason: prev.first_drop_reason,
295 session_index: prev.session_index,
296 });
297 }
298 let event_index = index;
 // Update the in-flight bookkeeping before the event is moved into the validator.
299 match &event {
300 IngestEvent::Session(session) => {
301 in_flight = Some(InFlight {
302 project: Some((*session.project).clone()),
303 session_id: session.id.clone(),
304 messages: 0,
305 dropped_events: 0,
306 first_drop_reason: None,
307 session_index: event_index,
308 });
309 }
310 IngestEvent::Message(_) => {
311 if let Some(slot) = in_flight.as_mut() {
312 slot.messages += 1;
313 }
314 }
315 IngestEvent::Part(_) => {}
316 }
317
318 let validator_start = std::time::Instant::now();
319 let push_outcomes = validator.push(store, index, event).await?;
320 validator_total += validator_start.elapsed();
321 validator_count += 1;
 // Error outcomes for non-session rows count as dropped events of the
 // in-flight session; session rows are handled in `drain_pending_dones`.
322 for outcome in &push_outcomes {
329 if matches!(outcome.status, OutcomeStatus::Error)
330 && outcome.kind != "session"
331 && let Some(slot) = in_flight.as_mut()
332 {
333 slot.dropped_events += 1;
334 if slot.first_drop_reason.is_none() {
335 slot.first_drop_reason =
336 outcome.error.as_ref().map(|err| err.message.clone());
337 }
338 }
339 }
340 summary.add_outcomes(&push_outcomes);
341 index += 1;
342
 // Flush once enough substreams are pending, then report every queued session.
343 if validator.pending_substreams() >= ADAPTER_FLUSH_BATCH {
348 let flush_start = std::time::Instant::now();
349 let (flush_outcomes, flush_counts) = validator.flush(store).await?;
350 validator_total += flush_start.elapsed();
351 validator_count += 1;
352 summary.add_outcomes_errors_only(&flush_outcomes);
356 summary.add_batch(&flush_counts);
357 drain_pending_dones(&mut pending_dones, &flush_outcomes, &mut on_event);
358 }
359 }
 // Adapter-side errors never abort the run.
360 Err(error) => {
361 tracing::debug!(
367 %error,
368 "adapter event error (per-line drop by design)"
369 );
370 match in_flight.as_mut() {
 // Inside a session: charge the drop to that session.
371 Some(slot) => {
372 slot.dropped_events += 1;
376 if slot.first_drop_reason.is_none() {
377 slot.first_drop_reason = Some(error.to_string());
378 }
379 summary.dropped_events += 1;
380 }
 // Outside any session: report an anonymous skipped input.
381 None => {
382 summary.skipped_files += 1;
387 on_event(SyncEvent::SessionDone(SessionOutcome {
388 project: None,
389 session_id: None,
390 messages: 0,
391 status: SyncStatus::Skipped {
392 reason: error.to_string(),
393 },
394 }));
395 }
396 }
397 }
398 }
399 }
400
 // The stream is exhausted: queue the last session and settle everything
 // that is still pending in the validator.
401 if let Some(prev) = in_flight.take() {
402 pending_dones.push_back(PendingDone {
403 project: prev.project,
404 session_id: prev.session_id,
405 messages: prev.messages,
406 dropped_events: prev.dropped_events,
407 first_drop_reason: prev.first_drop_reason,
408 session_index: prev.session_index,
409 });
410 }
411 let validator_start = std::time::Instant::now();
412 let (final_outcomes, final_counts) = validator.finish(store).await?;
413 validator_total += validator_start.elapsed();
414 validator_count += 1;
415 summary.add_outcomes_errors_only(&final_outcomes);
416 summary.add_batch(&final_counts);
417 drain_pending_dones(&mut pending_dones, &final_outcomes, &mut on_event);
418
 // Truncations that happened during this run only.
419 summary.truncated_values = crate::adapter::extract::truncated_values_count()
420 .saturating_sub(truncations_before) as usize;
421
 // `other` is whatever wall time was spent outside decode and validator calls.
422 let total = run_started.elapsed();
423 let other = total
424 .saturating_sub(decode_total)
425 .saturating_sub(validator_total);
426 tracing::info!(
427 target: "pond::perf",
428 total_ms = total.as_millis() as u64,
429 decode_ms = decode_total.as_millis() as u64,
430 validator_ms = validator_total.as_millis() as u64,
431 other_ms = other.as_millis() as u64,
432 decode_calls = decode_count,
433 validator_calls = validator_count,
434 rows_inserted = summary.inserted as u64,
435 rows_matched = summary.matched as u64,
436 dropped_events = summary.dropped_events as u64,
437 dropped_sessions = summary.dropped_sessions as u64,
438 skipped_files = summary.skipped_files as u64,
439 skipped_fresh = summary.skipped_fresh as u64,
440 truncated_values = summary.truncated_values as u64,
441 "ingest_adapter complete"
442 );
443 Ok(summary)
444 }
445
446 fn drain_pending_dones<F>(
453 queue: &mut std::collections::VecDeque<PendingDone>,
454 outcomes: &[RowOutcome],
455 on_event: &mut F,
456 ) where
457 F: FnMut(SyncEvent),
458 {
459 let mut session_outcome_by_index: std::collections::HashMap<usize, &RowOutcome> =
462 std::collections::HashMap::new();
463 for outcome in outcomes {
464 if outcome.kind == "session" {
465 session_outcome_by_index.insert(outcome.index, outcome);
466 }
467 }
468
469 while let Some(done) = queue.pop_front() {
470 let session_outcome = session_outcome_by_index.get(&done.session_index).copied();
471 let rejection_reason = session_outcome.and_then(|outcome| {
472 if matches!(outcome.status, OutcomeStatus::Error) {
473 Some(
474 outcome
475 .error
476 .as_ref()
477 .map(|err| err.message.clone())
478 .unwrap_or_else(|| "session-level rejection".to_owned()),
479 )
480 } else {
481 None
482 }
483 });
484 let status = if let Some(reason) = rejection_reason {
485 SyncStatus::Rejected { reason }
486 } else if done.dropped_events > 0 {
487 SyncStatus::Partial {
488 dropped_events: done.dropped_events,
489 first_drop_reason: done.first_drop_reason,
490 }
491 } else {
492 SyncStatus::Ok
493 };
494 on_event(SyncEvent::SessionDone(SessionOutcome {
495 project: done.project,
496 session_id: Some(done.session_id),
497 messages: done.messages,
498 status,
499 }));
500 }
501 }
502
503 pub async fn pond_ingest(store: &Store, request: IngestRequest) -> IngestEnvelope {
509 if let Err(envelope) = validate_protocol(request.protocol_version) {
510 return IngestEnvelope::Error(envelope);
511 }
512 if let Err(envelope) = super::resolve_namespace(request.namespace.as_deref()) {
513 return IngestEnvelope::Error(envelope);
514 }
515 if request.events.is_empty() {
516 return IngestEnvelope::Error(map_error(crate::Error::validation_field(
517 "events must be a non-empty array",
518 "events",
519 Some(serde_json::json!([])),
520 Some("non-empty array".to_owned()),
521 )));
522 }
523 if request.events.len() > MAX_INGEST_EVENTS {
524 return IngestEnvelope::Error(map_error(crate::Error::validation_field(
525 format!("ingest batch exceeds the event cap: at most {MAX_INGEST_EVENTS} events"),
526 "events",
527 Some(serde_json::json!(request.events.len())),
528 Some(format!("at most {MAX_INGEST_EVENTS} events")),
529 )));
530 }
531
532 match ingest_events(store, request.events).await {
533 Ok(outcomes) => {
534 let mut accepted = 0;
535 let mut rejected = 0;
536 for outcome in &outcomes {
537 match outcome.status {
538 OutcomeStatus::Inserted | OutcomeStatus::Matched => accepted += 1,
539 OutcomeStatus::Error => rejected += 1,
540 }
541 }
542 let results = outcomes
543 .into_iter()
544 .map(outcome_to_result)
545 .collect::<Vec<_>>();
546 IngestEnvelope::Success(IngestResponse {
547 accepted,
548 rejected,
549 results,
550 })
551 }
552 Err(failure) => IngestEnvelope::Error(map_storage(failure)),
553 }
554 }
555
556 pub async fn ingest_events(store: &Store, events: Vec<IngestEvent>) -> Result<Vec<RowOutcome>> {
562 let mut validator = IngestValidator::default();
563 let mut outcomes = Vec::with_capacity(events.len());
564 for (index, event) in events.into_iter().enumerate() {
565 let mut chunk = validator.push(store, index, event).await?;
566 outcomes.append(&mut chunk);
567 }
568 let (mut tail, _counts) = validator.finish(store).await?;
571 outcomes.append(&mut tail);
572 outcomes.sort_by_key(|outcome| outcome.index);
573 Ok(outcomes)
574 }
575
576 fn outcome_to_result(outcome: RowOutcome) -> IngestResult {
577 let (status, error) = match (outcome.status, outcome.error) {
578 (OutcomeStatus::Inserted, _) => (IngestStatus::Inserted, None),
579 (OutcomeStatus::Matched, _) => (IngestStatus::Matched, None),
580 (OutcomeStatus::Error, error) => {
581 let body = error
582 .map(|err| {
583 let mut details = serde_json::Map::new();
584 if let Some(field) = err.field {
585 details.insert("field".to_owned(), serde_json::json!(field));
586 }
587 if let Some(reason) = err.reason {
588 details.insert("reason".to_owned(), serde_json::json!(reason));
589 }
590 ErrorBody {
591 code: ErrorCode::ValidationFailed,
592 message: err.message,
593 details: serde_json::Value::Object(details),
594 }
595 })
596 .unwrap_or_else(|| ErrorBody {
597 code: ErrorCode::ValidationFailed,
598 message: "ingest failed".to_owned(),
599 details: serde_json::json!({}),
600 });
601 (IngestStatus::Error, Some(body))
602 }
603 };
604 IngestResult {
605 index: outcome.index,
606 kind: outcome.kind.to_owned(),
607 pk: outcome.pk,
608 status,
609 error,
610 }
611 }
612}
613
614pub use crate::sessions::{IngestEvent, IngestSummary, IngestValidator, search_text};
615pub use ingest_handler::{
616 MAX_INGEST_EVENTS, SessionOutcome, SyncEvent, SyncStatus, ingest_adapter, ingest_events,
617 pond_ingest,
618};
619
620mod export_handler {
621 use anyhow::{Context, Result};
632 use tokio::io::{AsyncWrite, AsyncWriteExt};
633
634 use crate::sessions::{IngestEvent, Store};
635
/// Number of events written by one `pond_export` run, per event kind.
636 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
637 pub struct ExportSummary {
 /// Session events written.
638 pub sessions: usize,
 /// Message events written.
639 pub messages: usize,
 /// Part events written.
640 pub parts: usize,
641 }
642
643 pub async fn pond_export<W>(
644 store: &Store,
645 session_filter: Option<&str>,
646 writer: &mut W,
647 ) -> Result<ExportSummary>
648 where
649 W: AsyncWrite + Unpin,
650 {
651 let mut session_ids = match session_filter {
652 Some(id) => vec![id.to_owned()],
653 None => store.session_ids().await?,
654 };
655 session_ids.sort();
656
657 let mut summary = ExportSummary::default();
658 for session_id in session_ids {
659 let Some(stored) = store
660 .get_session(&session_id)
661 .await
662 .with_context(|| format!("export: failed to load session {session_id}"))?
663 else {
664 if session_filter.is_some() {
665 anyhow::bail!("export: session not found: {session_id}");
666 }
667 continue;
668 };
669 write_event(writer, &IngestEvent::Session(stored.session)).await?;
670 summary.sessions += 1;
671 for message_with_parts in stored.messages {
672 write_event(writer, &IngestEvent::Message(message_with_parts.message)).await?;
673 summary.messages += 1;
674 for part in message_with_parts.parts {
675 write_event(writer, &IngestEvent::Part(part)).await?;
676 summary.parts += 1;
677 }
678 }
679 }
680 writer.flush().await.context("export: flush failed")?;
681 Ok(summary)
682 }
683
684 async fn write_event<W>(writer: &mut W, event: &IngestEvent) -> Result<()>
685 where
686 W: AsyncWrite + Unpin,
687 {
688 let line = serde_json::to_string(event).context("export: serialize event")?;
689 writer
690 .write_all(line.as_bytes())
691 .await
692 .context("export: write event")?;
693 writer
694 .write_all(b"\n")
695 .await
696 .context("export: write newline")?;
697 Ok(())
698 }
699}
700
701pub use export_handler::{ExportSummary, pond_export};
702
703mod restore_handler {
704 use anyhow::{Context, Result, bail};
711
712 use crate::sessions::{SessionWithMessages, Store};
713
714 pub async fn restore_lineage(
715 store: &Store,
716 session_id: &str,
717 ) -> Result<Vec<SessionWithMessages>> {
718 let Some(parent) = store.get_session(session_id).await? else {
719 bail!("export: session not found: {session_id}");
720 };
721 let mut sessions = vec![parent];
722 for child in store.child_sessions(session_id).await? {
723 if !store.child_sessions(&child.id).await?.is_empty() {
724 bail!(
725 "adapter-lineage-complete-restore supports one subagent level; session {} has child sessions",
726 child.id
727 );
728 }
729 let child_id = child.id;
730 let stored = store
731 .get_session(&child_id)
732 .await?
733 .with_context(|| format!("export: child session disappeared: {child_id}"))?;
734 sessions.push(stored);
735 }
736 Ok(sessions)
737 }
738}
739
740pub use restore_handler::restore_lineage;
741
742mod get_handler {
743 use crate::{
744 sessions::{GetLookup, MessageViewParams, RetrievedMessage, SessionViewParams, Store},
745 wire::{
746 GetEnvelope, GetRequest, GetResponse, GetResult, GetSession, MessageView, PartSummary,
747 ResponsePart, validate_protocol,
748 },
749 };
750
751 use super::{map_error, map_storage};
752
753 fn to_message_view(message: RetrievedMessage) -> MessageView {
758 let parts_summary = message
759 .parts
760 .iter()
761 .filter_map(|part| PartSummary::for_kind(&part.kind))
762 .collect();
763 MessageView {
764 id: message.id,
765 role: message.role,
766 timestamp: message.timestamp,
767 text: message.text,
768 content: message.content,
769 parts_summary,
770 }
771 }
772
/// Value passed as `budget_bytes` to the store for both session and message views.
773 const BUDGET_BYTES: usize = 200_000;
779
780 pub async fn pond_get(store: &Store, request: GetRequest) -> GetEnvelope {
781 if let Err(error) = validate_protocol(request.protocol_version) {
782 return GetEnvelope::Error(error);
783 }
784 if let Err(envelope) = super::resolve_namespace(request.namespace.as_deref()) {
785 return GetEnvelope::Error(envelope);
786 }
787
788 let response = match (&request.session_id, &request.message_id) {
789 (Some(session_id), None) => session_result(store, session_id, &request).await,
790 (None, Some(message_id)) => message_result(store, message_id, &request).await,
791 (Some(_), Some(_)) => Err(map_error(crate::Error::validation_field(
792 "session_id and message_id are mutually exclusive",
793 "message_id",
794 request.message_id.clone().map(serde_json::Value::String),
795 Some("omit when session_id is present".to_owned()),
796 ))),
797 (None, None) => Err(map_error(crate::Error::validation(
798 "one of session_id or message_id is required",
799 ))),
800 };
801
802 match response {
803 Ok(response) => GetEnvelope::Success(response),
804 Err(error) => GetEnvelope::Error(error),
805 }
806 }
807
808 fn unknown_anchor(field: &str, value: Option<&str>) -> crate::wire::ErrorEnvelope {
811 map_error(crate::Error::validation_field(
812 format!("{field} not found (stale or mistyped pagination anchor)"),
813 field,
814 value.map(|v| serde_json::Value::String(v.to_owned())),
815 Some("a message id from a prior page of this read".to_owned()),
816 ))
817 }
818
819 async fn session_result(
820 store: &Store,
821 session_id: &str,
822 request: &GetRequest,
823 ) -> Result<GetResponse, crate::wire::ErrorEnvelope> {
824 if request.session_after_message_id.is_some() && request.session_before_message_id.is_some()
825 {
826 return Err(map_error(crate::Error::validation_field(
827 "session_after_message_id and session_before_message_id are mutually exclusive",
828 "session_before_message_id",
829 request
830 .session_before_message_id
831 .clone()
832 .map(serde_json::Value::String),
833 Some("set only one pagination anchor".to_owned()),
834 )));
835 }
836 let params = SessionViewParams {
837 after_message_id: request.session_after_message_id.as_deref(),
838 before_message_id: request.session_before_message_id.as_deref(),
839 limit: request.session_limit,
840 budget_bytes: BUDGET_BYTES,
841 session_from: request.session_from,
842 };
843 let view = match store
844 .session_view(session_id, params)
845 .await
846 .map_err(map_storage)?
847 {
848 GetLookup::NotFound => {
849 return Err(map_error(crate::Error::not_found(
850 "session",
851 serde_json::json!(session_id),
852 format!("session not found: {session_id}"),
853 )));
854 }
855 GetLookup::UnknownAnchor => {
856 let (field, value) = match &request.session_after_message_id {
857 Some(value) => ("session_after_message_id", Some(value.as_str())),
858 None => (
859 "session_before_message_id",
860 request.session_before_message_id.as_deref(),
861 ),
862 };
863 return Err(unknown_anchor(field, value));
864 }
865 GetLookup::Found(view) => view,
866 };
867 Ok(GetResponse {
868 session: GetSession::from_session(&view.session),
869 result: GetResult::Session {
870 messages: view.messages.into_iter().map(to_message_view).collect(),
871 before_remaining: view.before_remaining,
872 after_remaining: view.after_remaining,
873 },
874 })
875 }
876
877 async fn message_result(
878 store: &Store,
879 message_id: &str,
880 request: &GetRequest,
881 ) -> Result<GetResponse, crate::wire::ErrorEnvelope> {
882 let params = MessageViewParams {
883 context_before: request.message_context_before,
884 context_after: request.message_context_after,
885 budget_bytes: BUDGET_BYTES,
886 };
887 let view = match store
888 .message_view(message_id, params)
889 .await
890 .map_err(map_storage)?
891 {
892 GetLookup::NotFound => {
893 return Err(map_error(crate::Error::not_found(
894 "message",
895 serde_json::json!(message_id),
896 format!("message not found: {message_id}"),
897 )));
898 }
899 GetLookup::UnknownAnchor => {
901 return Err(map_error(crate::Error::internal(
902 "message_view returned UnknownAnchor for an anchorless lookup",
903 )));
904 }
905 GetLookup::Found(view) => view,
906 };
907 let target = MessageView {
910 id: view.target.id,
911 role: view.target.role,
912 timestamp: view.target.timestamp,
913 text: None,
914 content: None,
915 parts_summary: Vec::new(),
916 };
917 Ok(GetResponse {
918 session: GetSession::from_session(&view.session),
919 result: GetResult::Message {
920 target,
921 target_parts: view
922 .target_parts
923 .into_iter()
924 .map(ResponsePart::from_part)
925 .collect(),
926 target_parts_remaining: view.target_parts_remaining,
927 siblings: view.siblings.into_iter().map(to_message_view).collect(),
928 },
929 })
930 }
931}
932
933pub use get_handler::pond_get;
934
935mod search_handler {
936 use crate::{
941 Clock, SystemClock,
942 embed::{Embedder, LazyEmbedder, format_query},
943 sessions::{MessageKey, MessageMeta, SearchHit, Store},
944 substrate::{Predicate, ScalarValue},
945 wire::{
946 ErrorEnvelope, PartSummary, ProjectFilter, Role, SearchEnvelope, SearchFilters,
947 SearchRequest, SearchResponse, SearchResult, SearchSession, SortBy, validate_protocol,
948 },
949 };
950 use chrono::{DateTime, NaiveDate, Utc};
951 use std::collections::HashMap;
952
953 use super::{map_error, map_storage};
954
/// Retrieval strategy used to produce search candidates.
955 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
960 pub enum SearchMode {
 /// Full-text search (`Store::fts_search`).
961 Fts,
 /// Embedding search (`Store::vector_search`). `resolve_effective_mode`
 /// turns this into `Fts` when the store has no embeddings.
962 Vector,
963 }
964
/// Validated, normalised form of a `SearchRequest`, produced by `plan_search`.
965 #[derive(Debug, Clone, PartialEq)]
966 pub struct SearchPlan {
 /// Requested mode; `run_search` replaces it with the effective mode.
967 pub mode: SearchMode,
 /// The request's query with surrounding whitespace trimmed; never empty.
968 pub query: String,
 /// Scope predicate built from `filters` by `build_scope_filter`.
969 pub filter: Predicate,
 /// The filters exactly as they were sent in the request.
975 pub filters: SearchFilters,
976 pub sort_by: SortBy,
 /// Number of candidates requested from full-text search.
977 pub pool: usize,
 /// Number of candidates requested from vector search.
978 pub vector_pool: usize,
 /// Requested limit, capped at `LIMIT_CAP`.
979 pub limit: usize,
 /// Copied from `filters.min_score`.
980 pub min_score: f64,
 /// Whether hits from sessions whose id contains '/' are removed.
981 pub exclude_subagents: bool,
985 }
986
/// Upper bound applied to the request's `limit` in `plan_search`.
987 const LIMIT_CAP: usize = 200;
/// Texts with at most this many characters are returned whole by
/// `hit_payload`; longer texts are cut to a window of this many characters.
988 const HIT_SNIPPET_CHARS: usize = 600;
992
/// Largest value `recency_boost` can return (reached at an age of zero).
993 const RECENCY_BOOST_MAGNITUDE: f64 = 0.02;
/// Age in days after which the recency boost has decayed by a factor of e.
1007 const RECENCY_BOOST_SCALE_DAYS: f64 = 30.0;
1008
1009 pub async fn pond_search(
1018 store: &Store,
1019 embedder: &LazyEmbedder,
1020 request: SearchRequest,
1021 search: &crate::config::SearchConfig,
1022 ) -> SearchEnvelope {
1023 match run_search(store, embedder, request, search, &SystemClock).await {
1024 Ok(response) => SearchEnvelope::Success(response),
1025 Err(envelope) => SearchEnvelope::Error(envelope),
1026 }
1027 }
1028
1029 pub async fn explain_search_plan(
1030 store: &Store,
1031 embedder: &LazyEmbedder,
1032 request: SearchRequest,
1033 search: &crate::config::SearchConfig,
1034 ) -> Result<String, ErrorEnvelope> {
1035 let mut plan = plan_search(request)?;
1036 plan.mode = resolve_effective_mode(store, plan.mode).await?;
1037 let mut out = String::new();
1038 match plan.mode {
1039 SearchMode::Fts => {
1040 let fts = store
1041 .explain_fts_plan(&plan.query, plan.pool, &plan.filter)
1042 .await
1043 .map_err(map_storage)?;
1044 out.push_str("fts:\n");
1045 out.push_str(&fts);
1046 out.push('\n');
1047 }
1048 SearchMode::Vector => {
1049 let backend = load_embedder(embedder).await?;
1050 let vector = embed_query(backend.as_ref(), &plan.query)?;
1051 let vector_plan = store
1052 .explain_vector_plan(&vector, plan.vector_pool, &plan.filter, Some(search))
1053 .await
1054 .map_err(map_storage)?;
1055 out.push_str("vector:\n");
1056 out.push_str(&vector_plan);
1057 out.push('\n');
1058 }
1059 }
1060 Ok(out)
1061 }
1062
/// Executes a search request end to end: plan, fetch candidates, select the
/// top hits, load their metadata, order them and build the paged response.
///
/// `clock` supplies "now" for the recency boost.
///
/// # Errors
///
/// Returns the envelope of the first failing step: request validation,
/// `min_score` used in fts mode, embedder loading / query embedding, or any
/// store call.
1063 async fn run_search(
1064 store: &Store,
1065 embedder: &LazyEmbedder,
1066 request: SearchRequest,
1067 search: &crate::config::SearchConfig,
1068 clock: &dyn Clock,
1069 ) -> Result<SearchResponse, ErrorEnvelope> {
1070 let mut plan = plan_search(request)?;
1071
 // Vector mode degrades to fts when the store holds no embeddings.
1072 plan.mode = resolve_effective_mode(store, plan.mode).await?;
1075
 // Checked after mode resolution, so it also applies to a degraded vector request.
1076 if matches!(plan.mode, SearchMode::Fts) && plan.min_score > 0.0 {
1080 return Err(map_error(crate::Error::validation_field(
1081 "min_score is not supported in fts mode (BM25 scores are unbounded \
1082 and not comparable across queries); use vector mode or drop min_score",
1083 "min_score",
1084 Some(serde_json::json!(plan.min_score)),
1085 Some("0 in fts mode".to_owned()),
1086 )));
1087 }
1088
 // Candidate retrieval for the effective mode; subagent hits are removed
 // before the scores are normalised.
1089 let candidates_fut = async {
1096 match plan.mode {
1097 SearchMode::Fts => {
1098 let mut hits = store
1099 .fts_search(&plan.query, plan.pool, &plan.filter)
1100 .await
1101 .map_err(map_storage)?;
1102 retain_non_subagents(&mut hits, plan.exclude_subagents);
1103 Ok(normalize_fts(hits))
1104 }
1105 SearchMode::Vector => {
1109 let backend = load_embedder(embedder).await?;
1110 let vector = embed_query(backend.as_ref(), &plan.query)?;
1111 let mut vector_raw = store
1112 .vector_search(&vector, plan.vector_pool, &plan.filter, Some(search))
1113 .await
1114 .map_err(map_storage)?;
1115 retain_non_subagents(&mut vector_raw, plan.exclude_subagents);
1116 Ok(normalize_vector(vector_raw))
1117 }
1118 }
1119 };
1120 let scope_fut = async {
1121 store
1122 .searchable_in_scope(&plan.filter)
1123 .await
1124 .map_err(map_storage)
1125 };
 // Both futures are driven concurrently; the first error aborts the search.
1126 let (candidates, searchable_in_scope) = tokio::try_join!(candidates_fut, scope_fut)?;
1127
1128 if candidates.is_empty() {
1129 return Ok(empty_response(searchable_in_scope));
1130 }
1131
1132 let (selected, total_sessions, matched_total) =
1138 select_top_hits(candidates, plan.min_score, plan.limit);
1139 if selected.is_empty() {
1140 return Ok(empty_response(searchable_in_scope));
1141 }
1142
 // `Some` only when every selected candidate carries a rowid; otherwise the
 // metadata is looked up by (session_id, message_id) keys.
1143 let rowids: Option<Vec<u64>> = selected.iter().map(|candidate| candidate.rowid).collect();
1149 let metas = match &rowids {
1150 Some(rowids) => store
1151 .message_metas_by_rowids(rowids)
1152 .await
1153 .map_err(map_storage)?,
1154 None => {
1155 let keys = selected
1156 .iter()
1157 .map(|candidate| MessageKey {
1158 session_id: candidate.session_id.clone(),
1159 message_id: candidate.message_id.clone(),
1160 })
1161 .collect::<Vec<_>>();
1162 store
1163 .message_metas_by_keys(&keys)
1164 .await
1165 .map_err(map_storage)?
1166 }
1167 };
1168 let meta_index = metas
1169 .iter()
1170 .map(|meta| ((meta.session_id.as_str(), meta.message_id.as_str()), meta))
1171 .collect::<std::collections::HashMap<_, _>>();
1172
1173 let now = clock.now();
1179 let mut scored = Vec::with_capacity(selected.len());
1180 for candidate in selected {
 // Candidates without metadata are silently left out of the result.
1181 let Some(meta) =
1182 meta_index.get(&(candidate.session_id.as_str(), candidate.message_id.as_str()))
1183 else {
1184 continue;
1185 };
 // The ordering key depends on the sort mode; the recency boost is only
 // added for relevance-sorted vector hits.
1186 let order_score = match (plan.sort_by, plan.mode) {
1187 (SortBy::Recency, _) => recency_rank(meta.timestamp),
1188 (SortBy::Relevance, SearchMode::Vector) => {
1189 candidate.base_score + recency_boost(meta.timestamp, now)
1190 }
1191 (SortBy::Relevance, SearchMode::Fts) => candidate.base_score,
1192 };
1193 scored.push(ScoredHit {
1194 meta: (*meta).clone(),
1195 display_score: candidate.base_score,
1196 order_score,
1197 });
1198 }
 // Highest order score first; ties (and incomparable scores) fall back to
 // ascending session id, then ascending message id.
1199 scored.sort_by(|left, right| {
1200 right
1201 .order_score
1202 .partial_cmp(&left.order_score)
1203 .unwrap_or(std::cmp::Ordering::Equal)
1204 .then_with(|| left.meta.session_id.cmp(&right.meta.session_id))
1205 .then_with(|| left.meta.message_id.cmp(&right.meta.message_id))
1206 });
1207
1208 let sessions = build_sessions(store, &scored, &plan.query).await?;
1209 page_sessions(
1210 sessions,
1211 matched_total,
1212 total_sessions,
1213 searchable_in_scope,
1214 &plan,
1215 )
1216 }
1217
1218 fn recency_boost(ts: DateTime<Utc>, now: DateTime<Utc>) -> f64 {
1223 let age_days = (now - ts).num_seconds().max(0) as f64 / 86_400.0;
1224 RECENCY_BOOST_MAGNITUDE * (-age_days / RECENCY_BOOST_SCALE_DAYS).exp()
1225 }
1226
1227 fn recency_rank(ts: DateTime<Utc>) -> f64 {
1230 ts.timestamp() as f64
1231 }
1232
1233 async fn resolve_effective_mode(
1237 store: &Store,
1238 requested: SearchMode,
1239 ) -> Result<SearchMode, ErrorEnvelope> {
1240 if matches!(requested, SearchMode::Fts) {
1241 return Ok(SearchMode::Fts);
1242 }
1243 let has = store.has_embeddings().await.map_err(map_storage)?;
1244 Ok(if has {
1245 SearchMode::Vector
1246 } else {
1247 SearchMode::Fts
1248 })
1249 }
1250
1251 async fn load_embedder(
1255 embedder: &LazyEmbedder,
1256 ) -> Result<std::sync::Arc<dyn Embedder>, ErrorEnvelope> {
1257 embedder.get().await.map_err(|error| {
1258 map_error(crate::Error::internal(format!(
1259 "embedder load failed: {error}"
1260 )))
1261 })
1262 }
1263
1264 pub fn plan_search(request: SearchRequest) -> Result<SearchPlan, ErrorEnvelope> {
1265 validate_protocol(request.protocol_version)?;
1266
1267 let _ns = super::resolve_namespace(request.namespace.as_deref())?;
1268
1269 let mode = match request.mode {
1270 crate::wire::SearchModeWire::Fts => SearchMode::Fts,
1271 crate::wire::SearchModeWire::Vector => SearchMode::Vector,
1272 };
1273 let sort_by = request.sort_by;
1274 let filters = request.filters;
1275 let query = request.query.trim().to_owned();
1276 if query.is_empty() {
1277 return Err(map_error(crate::Error::validation_field(
1278 "query must be non-empty after trim",
1279 "query",
1280 Some(serde_json::json!(request.query)),
1281 Some("non-empty string after trim".to_owned()),
1282 )));
1283 }
1284 if request.limit == 0 {
1285 return Err(map_error(crate::Error::validation_field(
1286 "limit must be at least 1",
1287 "limit",
1288 Some(serde_json::json!(request.limit)),
1289 Some("integer >= 1".to_owned()),
1290 )));
1291 }
1292 let limit = request.limit.min(LIMIT_CAP);
1293 let min_score = filters.min_score;
1294 let filter = build_scope_filter(&filters)?;
1295 let exclude_subagents = default_excludes_subagents(&filters);
1296 let mut pool = limit.saturating_mul(5).max(50);
1301 let mut vector_pool = pool.saturating_mul(2);
1302 if exclude_subagents {
1303 pool = pool.saturating_mul(3) / 2;
1304 vector_pool = vector_pool.saturating_mul(3) / 2;
1305 }
1306 Ok(SearchPlan {
1307 mode,
1308 query,
1309 filter,
1310 filters,
1311 sort_by,
1312 pool,
1313 vector_pool,
1314 limit,
1315 min_score,
1316 exclude_subagents,
1317 })
1318 }
1319
/// Returns the part of `session_id` before its first '/', or the whole id
/// when it contains no '/'.
fn session_root(session_id: &str) -> &str {
    session_id
        .split_once('/')
        .map_or(session_id, |(root, _rest)| root)
}
1330
1331 fn retain_non_subagents(hits: &mut Vec<SearchHit>, exclude: bool) {
1336 if exclude {
1337 hits.retain(|hit| !hit.key.session_id.contains('/'));
1338 }
1339 }
1340
/// Minimum length, in characters, of a query term that `query_snippet`
/// prefers as an anchor. When at least one term is this long, shorter terms
/// are ignored while looking for the anchor position.
1341 const ANCHOR_MIN_TERM_CHARS: usize = 4;
1346
1347 pub fn hit_payload(text: &str, query: &str) -> String {
1352 let chars_len = text.chars().count();
1353 if chars_len <= HIT_SNIPPET_CHARS {
1354 return text.to_owned();
1355 }
1356 query_snippet(text, query)
1357 }
1358
1359 fn query_snippet(text: &str, query: &str) -> String {
1375 let lower_text = text.to_lowercase();
1376 let terms: Vec<String> = query
1377 .split_whitespace()
1378 .filter(|term| !term.is_empty())
1379 .map(str::to_lowercase)
1380 .collect();
1381 let any_informative = terms
1382 .iter()
1383 .any(|term| term.chars().count() >= ANCHOR_MIN_TERM_CHARS);
1384 let hit = terms
1385 .iter()
1386 .filter(|term| !any_informative || term.chars().count() >= ANCHOR_MIN_TERM_CHARS)
1387 .filter_map(|term| lower_text.find(term.as_str()))
1388 .min();
1389 let chars: Vec<char> = text.chars().collect();
1390 let center = hit
1394 .map(|byte| lower_text[..byte].chars().count())
1395 .unwrap_or(0);
1396 let half = HIT_SNIPPET_CHARS / 2;
1397 let start = center.saturating_sub(half);
1398 let end = (start + HIT_SNIPPET_CHARS).min(chars.len());
1399 let start = end.saturating_sub(HIT_SNIPPET_CHARS);
1400 let mut snippet = String::new();
1404 if start > 0 {
1405 snippet.push_str(&format!("[{start} chars before] "));
1406 }
1407 snippet.extend(&chars[start..end]);
1408 if end < chars.len() {
1409 snippet.push_str(&format!(
1410 " [+{} more chars; pond_get for full]",
1411 chars.len() - end
1412 ));
1413 }
1414 snippet
1415 }
1416
/// A search hit whose raw score has been converted to a normalized
/// `base_score` (see `normalize_fts` and `normalize_vector`).
struct Candidate {
    /// Row id carried over from the originating `SearchHit`, when it had one.
    rowid: Option<u64>,
    /// Session id taken from the hit's key.
    session_id: String,
    /// Message id taken from the hit's key.
    message_id: String,
    /// Normalized score used by `select_top_hits` for filtering and ordering.
    base_score: f64,
}
1423
/// A hit paired with its message metadata and the two scores used downstream.
struct ScoredHit {
    /// Metadata of the matched message.
    meta: MessageMeta,
    /// Score reported to the caller; clamped to `[0.0, 1.0]` in
    /// `to_search_result`.
    display_score: f64,
    /// Score used for ordering: `build_sessions` ranks a session by the
    /// highest `order_score` among its hits.
    order_score: f64,
}
1434
1435 impl ScoredHit {
1436 fn to_search_result(
1437 &self,
1438 query: &str,
1439 summaries: &HashMap<(String, String), Vec<PartSummary>>,
1440 ) -> Result<SearchResult, ErrorEnvelope> {
1441 let text = hit_payload(&self.meta.search_text, query);
1442 let role = match self.meta.role.as_str() {
1443 "system" => Role::System,
1444 "user" => Role::User,
1445 "assistant" => Role::Assistant,
1446 "tool" => Role::Tool,
1447 other => {
1448 return Err(map_error(crate::Error::internal(format!(
1449 "stored message has unknown role: {other}"
1450 ))));
1451 }
1452 };
1453 let parts_summary = if matches!(role, Role::User) {
1456 summaries
1457 .get(&(self.meta.session_id.clone(), self.meta.message_id.clone()))
1458 .cloned()
1459 .unwrap_or_default()
1460 } else {
1461 Vec::new()
1462 };
1463 Ok(SearchResult {
1464 message_id: self.meta.message_id.clone(),
1465 role,
1466 timestamp: self.meta.timestamp,
1467 text,
1468 score: self.display_score.clamp(0.0, 1.0),
1469 parts_summary,
1470 })
1471 }
1472 }
1473
1474 fn normalize_fts(hits: Vec<SearchHit>) -> Vec<Candidate> {
1475 let max = hits.iter().map(|hit| hit.score).fold(0.0_f32, f32::max);
1476 hits.into_iter()
1477 .map(|hit| Candidate {
1478 rowid: hit.rowid,
1479 session_id: hit.key.session_id,
1480 message_id: hit.key.message_id,
1481 base_score: if max > 0.0 {
1482 f64::from(hit.score / max)
1483 } else {
1484 0.0
1485 },
1486 })
1487 .collect()
1488 }
1489
1490 fn normalize_vector(hits: Vec<SearchHit>) -> Vec<Candidate> {
1494 hits.into_iter()
1495 .map(|hit| Candidate {
1496 rowid: hit.rowid,
1497 session_id: hit.key.session_id,
1498 message_id: hit.key.message_id,
1499 base_score: 1.0 - f64::from(hit.score),
1500 })
1501 .collect()
1502 }
1503
1504 fn embed_query(embedder: &dyn Embedder, query: &str) -> Result<Vec<f32>, ErrorEnvelope> {
1505 let prompt = format_query(query);
1506 let vectors =
1510 tokio::task::block_in_place(|| embedder.embed(&[prompt])).map_err(|error_value| {
1511 map_error(crate::Error::internal(format!(
1512 "failed to embed query: {error_value}"
1513 )))
1514 })?;
1515 vectors.into_iter().next().ok_or_else(|| {
1516 map_error(crate::Error::internal(
1517 "embedder returned no vector for query",
1518 ))
1519 })
1520 }
1521
1522 fn select_top_hits(
1532 mut candidates: Vec<Candidate>,
1533 min_score: f64,
1534 limit: usize,
1535 ) -> (Vec<Candidate>, usize, usize) {
1536 if min_score > 0.0 {
1542 candidates.retain(|candidate| candidate.base_score >= min_score);
1543 }
1544 let matched_total = candidates.len();
1545 candidates.sort_by(|left, right| {
1546 right
1547 .base_score
1548 .partial_cmp(&left.base_score)
1549 .unwrap_or(std::cmp::Ordering::Equal)
1550 .then_with(|| left.session_id.cmp(&right.session_id))
1551 .then_with(|| left.message_id.cmp(&right.message_id))
1552 });
1553 let (total_sessions, keep) = {
1556 let mut order: Vec<&str> = Vec::new();
1557 let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
1558 for candidate in &candidates {
1559 let root = session_root(&candidate.session_id);
1560 if seen.insert(root) {
1561 order.push(root);
1562 }
1563 }
1564 let total = order.len();
1565 let keep: std::collections::HashSet<String> =
1566 order.into_iter().take(limit).map(str::to_owned).collect();
1567 (total, keep)
1568 };
1569 let selected = candidates
1570 .into_iter()
1571 .filter(|candidate| keep.contains(session_root(&candidate.session_id)))
1572 .collect();
1573 (selected, total_sessions, matched_total)
1574 }
1575
1576 async fn build_sessions(
1577 store: &Store,
1578 scored: &[ScoredHit],
1579 query: &str,
1580 ) -> Result<Vec<SearchSession>, ErrorEnvelope> {
1581 use std::collections::BTreeMap;
1582
1583 struct Acc {
1584 project: String,
1585 source_agent: String,
1586 matched_count: usize,
1587 rank: f64,
1590 matches: Vec<(DateTime<Utc>, SearchResult)>,
1591 }
1592 let mut user_ids_by_session: BTreeMap<String, Vec<String>> = BTreeMap::new();
1596 for hit in scored {
1597 if hit.meta.role == "user" {
1598 user_ids_by_session
1599 .entry(hit.meta.session_id.clone())
1600 .or_default()
1601 .push(hit.meta.message_id.clone());
1602 }
1603 }
1604 let mut summaries: HashMap<(String, String), Vec<PartSummary>> = HashMap::new();
1605 for (session_id, message_ids) in &user_ids_by_session {
1606 for (key, parts) in store
1607 .summary_parts_for_messages(session_id, message_ids)
1608 .await
1609 .map_err(map_storage)?
1610 {
1611 summaries.insert(
1612 key,
1613 parts
1614 .iter()
1615 .filter_map(|part| PartSummary::for_kind(&part.kind))
1616 .collect(),
1617 );
1618 }
1619 }
1620
1621 let mut groups: BTreeMap<String, Acc> = BTreeMap::new();
1622 for hit in scored {
1623 let root = session_root(&hit.meta.session_id).to_owned();
1624 let entry = groups.entry(root).or_insert_with(|| Acc {
1625 project: hit.meta.project.clone(),
1626 source_agent: hit.meta.source_agent.clone(),
1627 matched_count: 0,
1628 rank: f64::NEG_INFINITY,
1629 matches: Vec::new(),
1630 });
1631 entry.matched_count += 1;
1632 entry.rank = entry.rank.max(hit.order_score);
1633 entry
1634 .matches
1635 .push((hit.meta.timestamp, hit.to_search_result(query, &summaries)?));
1636 }
1637
1638 let session_ids = groups.keys().cloned().collect::<Vec<_>>();
1639 let counts = store
1640 .session_message_counts(&session_ids)
1641 .await
1642 .map_err(map_storage)?;
1643
1644 let mut result = groups
1649 .into_iter()
1650 .map(|(session_id, mut acc)| {
1651 acc.matches.sort_by(|left, right| {
1652 right
1653 .0
1654 .cmp(&left.0)
1655 .then_with(|| left.1.message_id.cmp(&right.1.message_id))
1656 });
1657 let matches = acc.matches.into_iter().map(|(_, result)| result).collect();
1658 (
1659 acc.rank,
1660 SearchSession {
1661 session_messages_count: counts
1662 .get(&session_id)
1663 .copied()
1664 .unwrap_or_default(),
1665 session_id,
1666 project: acc.project,
1667 source_agent: acc.source_agent,
1668 matched_message_count: acc.matched_count,
1669 matches,
1670 },
1671 )
1672 })
1673 .collect::<Vec<_>>();
1674 result.sort_by(|left, right| {
1675 right
1676 .0
1677 .partial_cmp(&left.0)
1678 .unwrap_or(std::cmp::Ordering::Equal)
1679 .then_with(|| left.1.session_id.cmp(&right.1.session_id))
1680 });
1681 Ok(result.into_iter().map(|(_, session)| session).collect())
1682 }
1683
1684 fn page_sessions(
1685 sessions: Vec<SearchSession>,
1686 matched_total: usize,
1687 total_sessions: usize,
1688 searchable_in_scope: usize,
1689 plan: &SearchPlan,
1690 ) -> Result<SearchResponse, ErrorEnvelope> {
1691 let emitted: Vec<SearchSession> = sessions.into_iter().take(plan.limit).collect();
1698 let has_more = total_sessions > emitted.len();
1699
1700 Ok(SearchResponse {
1701 sessions: emitted,
1702 matched_total,
1703 searchable_in_scope,
1704 has_more,
1705 })
1706 }
1707
1708 fn build_scope_clauses(filters: &SearchFilters) -> Result<Vec<Predicate>, ErrorEnvelope> {
1712 let mut clauses = Vec::new();
1713
1714 match &filters.project {
1715 None => {}
1716 Some(ProjectFilter::Contains(value)) => {
1717 clauses.push(Predicate::LikeContains("project", value.clone()));
1718 }
1719 Some(ProjectFilter::Regex(pattern)) => {
1720 clauses.push(Predicate::Regex("project", pattern.clone()));
1721 }
1722 }
1723
1724 if let Some(session_id) = &filters.session_id {
1725 clauses.push(Predicate::Eq("session_id", session_id.clone().into()));
1726 }
1727 if let Some(from_date) = &filters.from_date {
1728 clauses.push(Predicate::Gte(
1729 "timestamp",
1730 ScalarValue::Raw(date_bound(from_date, "filters.from_date", false)?),
1731 ));
1732 }
1733 if let Some(to_date) = &filters.to_date {
1734 clauses.push(Predicate::Lte(
1735 "timestamp",
1736 ScalarValue::Raw(date_bound(to_date, "filters.to_date", true)?),
1737 ));
1738 }
1739
1740 Ok(clauses)
1741 }
1742
1743 pub fn build_scope_filter(filters: &SearchFilters) -> Result<Predicate, ErrorEnvelope> {
1747 Ok(Predicate::And(build_scope_clauses(filters)?))
1748 }
1749
1750 pub fn default_excludes_subagents(filters: &SearchFilters) -> bool {
1755 filters.session_id.is_none()
1756 }
1757
1758 fn date_bound(date: &str, field: &str, end_of_day: bool) -> Result<String, ErrorEnvelope> {
1761 NaiveDate::parse_from_str(date, "%Y-%m-%d").map_err(|_| {
1762 map_error(crate::Error::validation_field(
1763 format!("{field} must be in YYYY-MM-DD format; got {date}"),
1764 field,
1765 Some(serde_json::json!(date)),
1766 Some("YYYY-MM-DD".to_owned()),
1767 ))
1768 })?;
1769 let time = if end_of_day { "23:59:59" } else { "00:00:00" };
1770 Ok(format!("timestamp '{date} {time}'"))
1771 }
1772
1773 fn empty_response(searchable_in_scope: usize) -> SearchResponse {
1774 SearchResponse {
1775 sessions: Vec::new(),
1776 matched_total: 0,
1777 searchable_in_scope,
1778 has_more: false,
1779 }
1780 }
1781
#[cfg(test)]
mod grouping_helpers_tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]

    use super::*;

    /// `session_root` keeps plain ids and cuts everything from the first `/`.
    #[test]
    fn session_root_strips_agent_suffix_for_claude_code_subagents() {
        let cases = [
            (
                "94a50f23-1234-5678-9abc-def012345678",
                "94a50f23-1234-5678-9abc-def012345678",
            ),
            (
                "94a50f23-1234-5678-9abc-def012345678/agent-abc123",
                "94a50f23-1234-5678-9abc-def012345678",
            ),
            ("root/a/b", "root"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(session_root(input), *expected);
        }
    }

    /// Hits with a `/` in the session id are dropped only when excluding.
    #[test]
    fn retain_non_subagents_drops_slash_ids_only_when_excluding() {
        fn hit_for(sid: &str) -> SearchHit {
            SearchHit {
                rowid: None,
                key: crate::sessions::MessageKey {
                    session_id: sid.to_owned(),
                    message_id: "m1".to_owned(),
                },
                score: 1.0_f32,
            }
        }
        let base: Vec<SearchHit> = ["root-a", "root-b/agent-x", "root-c"]
            .iter()
            .map(|sid| hit_for(sid))
            .collect();

        let mut excluded = base.clone();
        retain_non_subagents(&mut excluded, true);
        let surviving: Vec<&str> = excluded
            .iter()
            .map(|hit| hit.key.session_id.as_str())
            .collect();
        assert_eq!(surviving, ["root-a", "root-c"]);

        let mut kept = base;
        retain_non_subagents(&mut kept, false);
        assert_eq!(kept.len(), 3);
    }
}
1827}
1828
1829pub use search_handler::{
1830 SearchMode, SearchPlan, build_scope_filter, default_excludes_subagents, explain_search_plan,
1831 hit_payload, plan_search, pond_search,
1832};
1833
1834#[cfg(test)]
1835mod tests {
1836 #![allow(clippy::expect_used, clippy::unwrap_used)]
1837
1838 use super::*;
1839 use crate::wire::{ProjectFilter, SearchFilters, SearchRequest};
1840 use chrono::Utc;
1841
1842 fn search_request(query: &str) -> SearchRequest {
1843 SearchRequest {
1844 protocol_version: crate::PROTOCOL_VERSION,
1845 namespace: Some("local".to_owned()),
1846 query: query.to_owned(),
1847 mode: crate::wire::SearchModeWire::Vector,
1848 sort_by: crate::wire::SortBy::Relevance,
1849 filters: SearchFilters::default(),
1850 limit: 20,
1851 }
1852 }
1853
/// Text within the snippet budget comes back unchanged.
#[test]
fn hit_payload_returns_short_text_in_full() {
    let short = "a short message body";
    assert_eq!(
        hit_payload(short, "message"),
        short,
        "small text is returned as-is"
    );
}
1860
/// Long text is cut to a bounded window that still contains the match.
#[test]
fn hit_payload_windows_long_text_around_the_query_term() {
    let mut body = "a".repeat(2000);
    body.push_str("NEEDLE");
    body.push_str(&"b".repeat(394));

    let text = hit_payload(&body, "needle");
    let width = text.chars().count();
    assert!(
        text.contains("NEEDLE"),
        "text is the match-windowed snippet: {text}"
    );
    assert!(
        width <= 600 + 64,
        "snippet window is bounded by HIT_SNIPPET_CHARS plus markers: {}",
        width
    );
}
1878
/// Lowercasing `İ` changes the byte length of the text; the snippet must
/// still be built without panicking and still contain the matched term.
#[test]
fn hit_payload_snippet_survives_case_folding_that_changes_byte_length() {
    let mut body = String::from("İÉÉÉ");
    body.push_str(&"a".repeat(2100));

    let text = hit_payload(&body, "ééé");
    assert!(
        text.contains("ÉÉÉ"),
        "snippet windows on the matched term: {text}"
    );
}
1891
/// `restore_lineage` accepts a parent with direct children but rejects a
/// session graph that nests more than one level below the requested root.
#[tokio::test]
async fn restore_lineage_rejects_a_graph_nesting_deeper_than_one_level() {
    use crate::adapter::Extracted;
    use crate::sessions::Store;
    use crate::wire::{ProviderOptions, Session};
    use tempfile::TempDir;

    // Minimal session fixture; only the id and parent link vary per call.
    let session = |id: &str, parent: Option<&str>| Session {
        id: id.to_owned(),
        parent_session_id: parent.map(str::to_owned),
        parent_message_id: None,
        source_agent: "claude-code".to_owned(),
        created_at: Utc::now(),
        project: Extracted::from_test_value("/tmp/pond".to_owned()),
        options: ProviderOptions::new(),
    };

    let dir = TempDir::new().unwrap();
    let store = Store::open_local(dir.path()).await.unwrap();
    // Seed a three-deep chain: a -> b -> c.
    store
        .upsert_sessions(&[
            session("a", None),
            session("b", Some("a")),
            session("c", Some("b")),
        ])
        .await
        .unwrap();

    // From `a`, `c` sits two levels down, which must be reported as an error.
    let err = restore_lineage(&store, "a").await.unwrap_err();
    assert!(
        err.to_string().contains("one subagent level"),
        "expected the deeper-graph error, got: {err}"
    );

    // From `b`, only the single level `b -> c` is involved, so it succeeds.
    let lineage = restore_lineage(&store, "b").await.unwrap();
    let ids: Vec<&str> = lineage.iter().map(|s| s.session.id.as_str()).collect();
    assert_eq!(ids, ["b", "c"]);
}
1933
/// Every populated filter shows up in the rendered predicate, nothing else
/// does, and default filters render to an empty string.
#[test]
fn build_scope_filter_pushes_down_each_predicate_and_handles_empty() {
    let populated = SearchFilters {
        project: Some(ProjectFilter::Contains("/Users/me/pond".to_owned())),
        session_id: Some("01HXY".to_owned()),
        from_date: Some("2026-01-01".to_owned()),
        to_date: Some("2026-05-01".to_owned()),
        min_score: 0.0,
    };
    let sql = build_scope_filter(&populated).unwrap().to_lance();
    let expected_fragments = [
        "project LIKE '%/Users/me/pond%'",
        "session_id = '01HXY'",
        "timestamp >=",
        "timestamp <=",
    ];
    for fragment in expected_fragments.iter() {
        assert!(sql.contains(fragment));
    }
    assert!(!sql.contains("source_agent"));

    let unfiltered = build_scope_filter(&SearchFilters::default()).unwrap();
    assert_eq!(unfiltered.to_lance(), "",);
}
1960
/// A date that is not in `YYYY-MM-DD` order is rejected.
#[test]
fn build_scope_filter_rejects_bad_date() {
    let outcome = build_scope_filter(&SearchFilters {
        from_date: Some(String::from("01-01-2026")),
        ..SearchFilters::default()
    });
    assert!(outcome.is_err());
}
1969
/// `_` in a project substring filter is escaped and the rendered predicate
/// declares the escape character.
#[test]
fn build_scope_filter_escapes_like_wildcards() {
    let project = ProjectFilter::Contains(String::from("/Users/me/my_project"));
    let filters = SearchFilters {
        project: Some(project),
        ..SearchFilters::default()
    };
    let sql = build_scope_filter(&filters).unwrap().to_lance();

    let escapes_underscore = sql.contains(r"my\_project");
    assert!(escapes_underscore, "underscore must be escaped: {sql}");

    let declares_escape = sql.contains(r"ESCAPE '\'");
    assert!(
        declares_escape,
        "predicate must declare the escape char: {sql}"
    );
}
1988
/// Walks `plan_search` through three requests: an oversized limit, a tiny
/// limit, and a session-scoped request.
#[test]
fn plan_search_shapes_request_for_each_planning_input() {
    // Oversized limit, unscoped: the limit is capped and, because subagents
    // are excluded, both pools are widened by 3/2 (1000 -> 1500, 2000 -> 3000).
    let mut oversized = search_request(" vector memory ");
    oversized.limit = 500;
    oversized.filters.min_score = 0.42;
    let capped = plan_search(oversized).unwrap();
    assert_eq!(capped.mode, SearchMode::Vector);
    assert_eq!(capped.query, "vector memory");
    assert_eq!(capped.limit, 200);
    assert!(capped.exclude_subagents);
    assert_eq!(capped.pool, 1500);
    assert_eq!(capped.vector_pool, 3000);
    assert_eq!(capped.min_score, 0.42);

    // Tiny limit: the pool floor of 50 applies before the 3/2 widening.
    let mut tiny = search_request("tiny pool");
    tiny.mode = crate::wire::SearchModeWire::Fts;
    tiny.limit = 1;
    let floored = plan_search(tiny).unwrap();
    assert_eq!(floored.mode, SearchMode::Fts);
    assert_eq!(floored.limit, 1);
    assert_eq!(floored.pool, 75);
    assert_eq!(floored.vector_pool, 150);

    // Session-scoped: subagents are not excluded, so no widening happens.
    let mut scoped = search_request("filtered");
    scoped.filters.project = Some(ProjectFilter::Contains("/Users/me/pond".to_owned()));
    scoped.filters.session_id = Some("01HXY".to_owned());
    let scoped_plan = plan_search(scoped).unwrap();
    assert!(!scoped_plan.exclude_subagents);
    assert_eq!(scoped_plan.pool, 100);
    assert_eq!(scoped_plan.vector_pool, 200);
    let rendered = scoped_plan.filter.to_lance();
    assert!(rendered.contains("project LIKE"));
    assert!(rendered.contains("session_id = '01HXY'"));
}
2031
/// Blank queries, a zero limit and an unknown namespace are each rejected
/// with the matching error code and detail field.
#[test]
fn plan_search_rejects_invalid_composition_before_execution() {
    let mut request = search_request(" ");

    // Query that is empty after trimming.
    let blank_query = plan_search(request.clone()).unwrap_err().error;
    assert_eq!(blank_query.code, crate::wire::ErrorCode::ValidationFailed);
    assert_eq!(blank_query.details["field"], "query");

    // Valid query, zero limit.
    request.query = "valid".to_owned();
    request.limit = 0;
    let zero_limit = plan_search(request.clone()).unwrap_err().error;
    assert_eq!(zero_limit.details["field"], "limit");

    // Valid query and limit, unknown namespace.
    request.limit = 1;
    request.namespace = Some("remote".to_owned());
    let unknown_namespace = plan_search(request).unwrap_err().error;
    assert_eq!(
        unknown_namespace.code,
        crate::wire::ErrorCode::NamespaceUnknown
    );
    assert_eq!(unknown_namespace.details["namespace"], "remote");
}
2050}
2051
2052#[cfg(test)]
2053mod get_tests {
2054 #![allow(clippy::expect_used, clippy::unwrap_used)]
2055
2056 use crate::sessions::Store;
2057 use crate::wire::{
2058 GetEnvelope, GetRequest, GetResult, IngestEnvelope, IngestRequest, Message, Part, PartKind,
2059 Provenance, ProviderOptions, Session, SessionFrom,
2060 };
2061 use chrono::{TimeZone, Utc};
2062 use tempfile::TempDir;
2063
2064 fn text_part(session_id: &str, message_id: &str, part_id: &str, body: &str) -> Part {
2065 Part {
2066 session_id: session_id.to_owned(),
2067 id: part_id.to_owned(),
2068 message_id: message_id.to_owned(),
2069 ordinal: 0,
2070 provenance: Provenance::Conversational,
2071 options: ProviderOptions::new(),
2072 kind: PartKind::Text {
2073 text: crate::adapter::extract_str(&serde_json::json!({ "x": body }), "x"),
2074 },
2075 }
2076 }
2077
2078 async fn ingest(store: &Store, events: Vec<super::IngestEvent>) {
2079 let envelope = super::pond_ingest(
2080 store,
2081 IngestRequest {
2082 protocol_version: crate::PROTOCOL_VERSION,
2083 namespace: Some("local".to_owned()),
2084 events,
2085 },
2086 )
2087 .await;
2088 assert!(
2089 matches!(envelope, IngestEnvelope::Success(_)),
2090 "ingest should succeed: {envelope:?}"
2091 );
2092 }
2093
2094 fn session(id: &str, project_marker: &str) -> Session {
2095 Session {
2096 id: id.to_owned(),
2097 parent_session_id: None,
2098 parent_message_id: None,
2099 source_agent: "claude-code".to_owned(),
2100 created_at: Utc.with_ymd_and_hms(2026, 1, 1, 0, 0, 0).unwrap(),
2101 project: crate::adapter::extract_str(&serde_json::json!({ "x": project_marker }), "x")
2102 .unwrap(),
2103 options: ProviderOptions::new(),
2104 }
2105 }
2106
/// A session too large for one page can be continued with
/// `session_after_message_id`, and the continuation shares no message with
/// the first page.
#[tokio::test(flavor = "multi_thread")]
async fn pond_get_paginates_session_via_after_message_id() -> anyhow::Result<()> {
    let temp = TempDir::new()?;
    let store = Store::open_local(temp.path()).await?;
    let session_id = "paginate-session";

    // 16 characters repeated 5000 times: 80,000 characters per message.
    let huge_text = "abc def ghi jkl ".repeat(5000);
    let mut events = vec![super::IngestEvent::Session(session(
        session_id,
        "pond-paginate",
    ))];
    // Three user messages, one minute apart, each with one huge text part.
    for index in 0..3 {
        let message_id = format!("paginate-msg-{index}");
        events.push(super::IngestEvent::Message(Message::User {
            id: message_id.clone(),
            session_id: session_id.to_owned(),
            timestamp: Utc
                .with_ymd_and_hms(2026, 1, 1, 0, index as u32 + 1, 0)
                .unwrap(),
            options: ProviderOptions::new(),
        }));
        events.push(super::IngestEvent::Part(text_part(
            session_id,
            &message_id,
            &format!("paginate-part-{index}"),
            &huge_text,
        )));
    }
    ingest(&store, events).await;

    // Same request for every page; only the continuation cursor changes.
    // `session_limit` (1000) is far above the three ingested messages.
    let page_request = |after: Option<String>| GetRequest {
        protocol_version: crate::PROTOCOL_VERSION,
        namespace: Some("local".to_owned()),
        session_id: Some(session_id.to_owned()),
        message_id: None,
        session_limit: 1000,
        session_from: SessionFrom::Start,
        session_after_message_id: after,
        session_before_message_id: None,
        message_context_before: 3,
        message_context_after: 3,
    };

    let GetEnvelope::Success(first) = super::pond_get(&store, page_request(None)).await else {
        panic!("first page must succeed");
    };
    let GetResult::Session {
        messages: first_messages,
        after_remaining,
        ..
    } = first.result
    else {
        panic!("first page is session-scope");
    };
    // The first page must leave messages behind for the continuation.
    assert!(after_remaining > 0, "long corpus must trip the page budget");
    // Continue from the last message of the first page.
    let after = first_messages.last().expect("non-empty page").id.clone();

    let GetEnvelope::Success(second) = super::pond_get(&store, page_request(Some(after))).await
    else {
        panic!("continuation page must succeed");
    };
    let GetResult::Session {
        messages: second_messages,
        ..
    } = second.result
    else {
        panic!("continuation is session-scope");
    };
    assert!(
        !second_messages.is_empty(),
        "continuation surfaces the rest"
    );
    // No message id may appear on both pages.
    let first_ids: std::collections::HashSet<&str> =
        first_messages.iter().map(|m| m.id.as_str()).collect();
    assert!(
        second_messages
            .iter()
            .all(|m| !first_ids.contains(m.id.as_str())),
        "session_after_message_id pages must be disjoint"
    );
    Ok(())
}
2195
/// With a session limit of 2 over five messages, `SessionFrom::End` returns
/// the newest two in chronological order and `SessionFrom::Start` the
/// oldest two, each reporting three messages remaining on the other side.
#[tokio::test(flavor = "multi_thread")]
async fn pond_get_session_from_end_returns_the_recent_tail() -> anyhow::Result<()> {
    let temp = TempDir::new()?;
    let store = Store::open_local(temp.path()).await?;
    let session_id = "tail-session";

    let mut events = vec![super::IngestEvent::Session(session(
        session_id,
        "pond-tail",
    ))];
    // Five short user messages, one minute apart, ids `tail-msg-0..=4`.
    for index in 0..5u32 {
        let message_id = format!("tail-msg-{index}");
        events.push(super::IngestEvent::Message(Message::User {
            id: message_id.clone(),
            session_id: session_id.to_owned(),
            timestamp: Utc.with_ymd_and_hms(2026, 1, 1, 0, index + 1, 0).unwrap(),
            options: ProviderOptions::new(),
        }));
        events.push(super::IngestEvent::Part(text_part(
            session_id,
            &message_id,
            &format!("tail-part-{index}"),
            &format!("message {index}"),
        )));
    }
    ingest(&store, events).await;

    // Same request for both directions; only `session_from` differs.
    let request = |from: SessionFrom| GetRequest {
        protocol_version: crate::PROTOCOL_VERSION,
        namespace: Some("local".to_owned()),
        session_id: Some(session_id.to_owned()),
        message_id: None,
        session_limit: 2,
        session_from: from,
        session_after_message_id: None,
        session_before_message_id: None,
        message_context_before: 3,
        message_context_after: 3,
    };
    // Unwraps a successful session-scope envelope into
    // (message ids, before_remaining, after_remaining).
    let page = |envelope: GetEnvelope| -> (Vec<String>, usize, usize) {
        let GetEnvelope::Success(response) = envelope else {
            panic!("get must succeed");
        };
        let GetResult::Session {
            messages,
            before_remaining,
            after_remaining,
        } = response.result
        else {
            panic!("session-scope result expected");
        };
        (
            messages.into_iter().map(|m| m.id).collect(),
            before_remaining,
            after_remaining,
        )
    };

    // Reading from the end: newest two messages, oldest of the pair first.
    let (end_ids, end_before, _) =
        page(super::pond_get(&store, request(SessionFrom::End)).await);
    assert_eq!(
        end_ids,
        ["tail-msg-3", "tail-msg-4"],
        "end returns the newest two, chronologically"
    );
    assert_eq!(end_before, 3, "three older messages precede the tail");

    // Reading from the start: oldest two messages.
    let (start_ids, _, start_after) =
        page(super::pond_get(&store, request(SessionFrom::Start)).await);
    assert_eq!(
        start_ids,
        ["tail-msg-0", "tail-msg-1"],
        "start returns the oldest two"
    );
    assert_eq!(start_after, 3, "three newer messages follow the head");
    Ok(())
}
2277}