use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
/// Scope that a recall result is attributed to.
///
/// Serialized and deserialized using `snake_case` variant names. The derived
/// `Ord`/`PartialOrd` follow declaration order, so `Profile` compares lowest
/// and `Clone` compares highest.
// NOTE(review): the declaration order looks like broadest-to-narrowest scope;
// confirm before relying on the ordering for precedence decisions.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "snake_case")]
pub(crate) enum RecallScope {
Profile,
Repo,
Pod,
Branch,
Clone,
}
impl RecallScope {
    /// Parses a scope identifier into the matching [`RecallScope`] variant.
    ///
    /// Matching is exact: the input is neither trimmed nor case-folded.
    /// `"locality"` is accepted as an alias for `"repo"`. Any other input
    /// yields `None`.
    pub(crate) fn parse(raw: &str) -> Option<Self> {
        let scope = match raw {
            "clone" => Self::Clone,
            "branch" => Self::Branch,
            "pod" => Self::Pod,
            "locality" | "repo" => Self::Repo,
            "profile" => Self::Profile,
            _ => return None,
        };
        Some(scope)
    }
}
/// Name and kind of a recall provider, as returned by
/// `RecallProvider::descriptor`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct ProviderDescriptor {
/// Provider name; reported as `configured_provider` / `used_provider` in
/// `StartRecallView`.
pub(crate) name: String,
/// Provider kind; reported as `configured_provider_kind` /
/// `used_provider_kind` in `StartRecallView`.
pub(crate) kind: String,
}
/// Set of operations a recall provider declares support for.
///
/// The derived `Default` leaves every capability disabled.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct RecallProviderCapabilities {
/// Provider supports `RecallProvider::search`.
pub(crate) search: bool,
/// Provider supports `RecallProvider::describe`.
pub(crate) describe: bool,
/// Provider supports `RecallProvider::expand`.
pub(crate) expand: bool,
/// Provider supports `RecallProvider::health`.
pub(crate) health: bool,
}
impl RecallProviderCapabilities {
    /// Returns a capability set with every operation enabled.
    #[cfg_attr(not(test), allow(dead_code))]
    pub(crate) fn all() -> Self {
        let enabled = true;
        Self {
            search: enabled,
            describe: enabled,
            expand: enabled,
            health: enabled,
        }
    }

    /// Builds a capability set from a list of declared capability names.
    ///
    /// Each name is normalized with `normalize_capability_name` before it is
    /// compared against `search`, `describe`, `expand` and `health`. Names
    /// that match none of them are ignored; capabilities that are never
    /// declared stay disabled.
    pub(crate) fn from_declared(names: &[String]) -> Self {
        let declared: Vec<String> = names
            .iter()
            .map(|name| normalize_capability_name(name))
            .collect();
        let declares = |capability: &str| declared.iter().any(|name| name == capability);
        Self {
            search: declares("search"),
            describe: declares("describe"),
            expand: declares("expand"),
            health: declares("health"),
        }
    }
}
/// Handle identifying one item inside a provider.
///
/// Carried on every `RecallResult` and passed back to
/// `RecallProvider::describe` and `RecallProvider::expand`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct ProviderReference {
/// Must be non-blank for `RecallResult::validate` to succeed.
pub(crate) provider_name: String,
/// Must be non-blank for `RecallResult::validate` to succeed.
pub(crate) provider_kind: String,
/// Must be non-blank for `RecallResult::validate` to succeed.
pub(crate) provider_id: String,
/// Optional provider-side scope label; omitted from serialized output when
/// `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) provider_scope: Option<String>,
}
/// Optional origin information attached to a recall result.
///
/// Every field is omitted from serialized output when `None`; the derived
/// `Default` leaves all of them unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct RecallProvenance {
// NOTE(review): the exact format of these strings is not established in
// this file; confirm against the providers that populate them.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) source_ref: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) provider_uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) surface_path: Option<String>,
}
/// Optional time-related hints attached to a recall result.
///
/// Every field is omitted from serialized output when `None`; the derived
/// `Default` leaves all of them unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct RecallTemporalHints {
// NOTE(review): values are plain strings; the timestamp format is not
// enforced or documented in this file — confirm with the producers.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) observed_at: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) valid_at: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) updated_at: Option<String>,
}
/// A single item returned by a recall provider from `search`, `describe` or
/// `expand`.
///
/// Use `RecallResult::validate` to check the required fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct RecallResult {
/// Reference used to address this item in follow-up provider calls.
pub(crate) provider_ref: ProviderReference,
pub(crate) ccd_scope: RecallScope,
/// Must be non-blank for `validate` to succeed.
pub(crate) title: String,
/// Must be non-blank for `validate` to succeed.
pub(crate) summary: String,
/// Must be greater than zero for `validate` to succeed.
pub(crate) rank: u32,
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) preview: Option<String>,
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) detail: Option<String>,
/// Presumably signals that `RecallProvider::expand` is meaningful for this
/// item; the conformance harness only calls `expand` when it is `true`.
pub(crate) can_expand: bool,
/// Defaults to empty when absent on input; omitted from output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) memory_class_hints: Vec<String>,
/// Defaults to an all-`None` value when absent on input.
#[serde(default)]
pub(crate) provenance: RecallProvenance,
/// Defaults to an all-`None` value when absent on input.
#[serde(default)]
pub(crate) temporal: RecallTemporalHints,
}
impl RecallResult {
    /// Checks the invariants every provider-returned result must satisfy.
    ///
    /// # Errors
    ///
    /// Fails on the first violated rule, checked in this order: the provider
    /// reference's name, kind and id, then `title` and `summary`, must each
    /// contain non-whitespace text; finally `rank` must be greater than zero.
    pub(crate) fn validate(&self) -> Result<()> {
        // Checked in declaration order so the first reported problem is stable.
        let required_text = [
            (
                &self.provider_ref.provider_name,
                "provider_ref.provider_name must not be empty",
            ),
            (
                &self.provider_ref.provider_kind,
                "provider_ref.provider_kind must not be empty",
            ),
            (
                &self.provider_ref.provider_id,
                "provider_ref.provider_id must not be empty",
            ),
            (&self.title, "title must not be empty"),
            (&self.summary, "summary must not be empty"),
        ];
        for (value, complaint) in required_text {
            if value.trim().is_empty() {
                bail!(complaint);
            }
        }
        match self.rank {
            0 => bail!("rank must be greater than zero"),
            _ => Ok(()),
        }
    }
}
/// Health report returned by `RecallProvider::health`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct RecallHealthView {
/// Free-form status string; the conformance harness expects `"ok"`.
pub(crate) status: String,
/// Optional detail; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) message: Option<String>,
}
/// Limits applied to each stage of the start-recall pipeline.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct StartRecallBudget {
/// Maximum number of search results kept (the list is truncated to this).
pub(crate) search_limit: usize,
/// Maximum number of kept search results passed to `describe`; `0` skips
/// the describe stage.
pub(crate) describe_limit: usize,
/// Maximum number of described results passed to `expand`; `0` skips the
/// expand stage.
pub(crate) expand_limit: usize,
}
impl Default for StartRecallBudget {
    /// Default budget: keep 4 search results, describe 2, expand 1.
    fn default() -> Self {
        let (search_limit, describe_limit, expand_limit) = (4, 2, 1);
        Self {
            search_limit,
            describe_limit,
            expand_limit,
        }
    }
}
/// Outcome of a start-recall run.
///
/// Optional fields are omitted from serialized output when `None`, and list
/// fields are omitted when empty.
// NOTE(review): `status` is `&'static str` while the struct derives
// `Deserialize`; this presumably restricts deserialization to `'static`
// input — confirm that is intended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct StartRecallView {
/// One of the literals set in this file: `"missing"`, `"skipped"`,
/// `"disabled"`, `"loaded"`, `"fallback"` or `"error"`.
pub(crate) status: &'static str,
/// Name of the primary (configured) provider.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) configured_provider: Option<String>,
/// Kind of the primary (configured) provider.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) configured_provider_kind: Option<String>,
/// Name of the provider whose results are reported, if any succeeded.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) used_provider: Option<String>,
/// Kind of the provider whose results are reported, if any succeeded.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) used_provider_kind: Option<String>,
/// `true` when the fallback provider was attempted, whether or not it
/// succeeded.
pub(crate) fallback_used: bool,
/// Query that was sent to the provider(s).
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) query: Option<String>,
/// Budget the run was executed with.
pub(crate) budget: StartRecallBudget,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) search_results: Vec<RecallResult>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) described_results: Vec<RecallResult>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) expanded_results: Vec<RecallResult>,
/// Failure message for `"error"` views, or the reason given to `disabled`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) error: Option<String>,
/// Non-fatal messages; holds the primary provider's error when the
/// fallback succeeded.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) warnings: Vec<String>,
}
impl StartRecallView {
    /// View with status `"missing"`: no provider details, no query, the
    /// default budget, and no results, error or warnings.
    pub(crate) fn missing() -> Self {
        Self {
            status: "missing",
            budget: StartRecallBudget::default(),
            fallback_used: false,
            configured_provider: None,
            configured_provider_kind: None,
            used_provider: None,
            used_provider_kind: None,
            query: None,
            error: None,
            search_results: Vec::new(),
            described_results: Vec::new(),
            expanded_results: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Same as [`Self::missing`] except that the status is `"skipped"`.
    pub(crate) fn skipped() -> Self {
        let mut view = Self::missing();
        view.status = "skipped";
        view
    }

    /// View with status `"disabled"`.
    ///
    /// Records the configured provider's name and kind when a descriptor is
    /// given, keeps the supplied `budget`, and stores `reason` in `error`.
    /// All other fields match [`Self::missing`].
    pub(crate) fn disabled(
        descriptor: Option<ProviderDescriptor>,
        budget: StartRecallBudget,
        reason: Option<String>,
    ) -> Self {
        let (configured_provider, configured_provider_kind) = match descriptor {
            Some(configured) => (Some(configured.name), Some(configured.kind)),
            None => (None, None),
        };
        Self {
            status: "disabled",
            configured_provider,
            configured_provider_kind,
            budget,
            error: reason,
            ..Self::missing()
        }
    }
}
/// Interface implemented by every recall backend.
///
/// The operations take `&mut self`, so a provider may keep mutable state
/// between calls.
pub(crate) trait RecallProvider {
/// Returns the provider's name and kind.
fn descriptor(&self) -> ProviderDescriptor;
/// Returns which operations the provider declares support for.
fn capabilities(&self) -> RecallProviderCapabilities;
/// Reports provider health. Only exercised from tests in this file, hence
/// the dead-code allowance outside test builds.
#[cfg_attr(not(test), allow(dead_code))]
fn health(&mut self) -> Result<RecallHealthView>;
/// Returns the results matching `query`.
fn search(&mut self, query: &str) -> Result<Vec<RecallResult>>;
/// Returns the item identified by `reference`, or `None` if the provider
/// has nothing for it.
fn describe(&mut self, reference: &ProviderReference) -> Result<Option<RecallResult>>;
/// Returns further results derived from the item identified by
/// `reference`.
fn expand(&mut self, reference: &ProviderReference) -> Result<Vec<RecallResult>>;
}
/// Builds the start-recall query from a title and a list of immediate actions.
///
/// The title comes first, followed by the actions in order. Every piece is
/// trimmed, blank pieces are dropped, and the rest are joined with `" | "`.
/// Returns `None` when nothing non-blank remains.
pub(crate) fn build_start_query(title: &str, immediate_actions: &[String]) -> Option<String> {
    let segments: Vec<&str> = std::iter::once(title)
        .chain(immediate_actions.iter().map(String::as_str))
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .collect();
    (!segments.is_empty()).then(|| segments.join(" | "))
}
/// Runs start recall against `provider`, retrying with `fallback` on failure.
///
/// Outcomes, by `status`:
/// - `"disabled"`: `query` was `None` or blank; no provider operation runs.
/// - `"loaded"`: the primary provider succeeded.
/// - `"fallback"`: the primary failed and the fallback succeeded; the
///   primary's error message is reported in `warnings`.
/// - `"error"`: the primary failed and there was either no fallback or the
///   fallback failed too; `error` carries the message(s).
///
/// `fallback_used` is `true` whenever the fallback provider was attempted,
/// including when it also failed.
pub(crate) fn run_start_recall(
    provider: &mut dyn RecallProvider,
    fallback: Option<&mut dyn RecallProvider>,
    query: Option<String>,
    budget: StartRecallBudget,
) -> StartRecallView {
    let configured = provider.descriptor();
    let query = match query {
        Some(text) if !text.trim().is_empty() => text,
        _ => return StartRecallView::disabled(Some(configured), budget, None),
    };

    // Decide status, fallback flag, artifacts-or-message and warnings first,
    // so the view itself is assembled in exactly one place below.
    let (status, fallback_used, outcome, warnings): (&'static str, bool, Result<_, String>, Vec<String>) =
        match collect_start_recall(provider, &query, budget) {
            Ok(artifacts) => ("loaded", false, Ok(artifacts), Vec::new()),
            Err(error) => {
                let error_message = error.to_string();
                match fallback {
                    None => ("error", false, Err(error_message), Vec::new()),
                    Some(fallback_provider) => {
                        match collect_start_recall(fallback_provider, &query, budget) {
                            Ok(artifacts) => {
                                ("fallback", true, Ok(artifacts), vec![error_message])
                            }
                            Err(fallback_error) => (
                                "error",
                                true,
                                Err(format!(
                                    "{error_message}; fallback provider also failed: {fallback_error}"
                                )),
                                Vec::new(),
                            ),
                        }
                    }
                }
            }
        };

    let (
        used_provider,
        used_provider_kind,
        search_results,
        described_results,
        expanded_results,
        error,
    ) = match outcome {
        Ok((used, searched, described, expanded)) => (
            Some(used.name),
            Some(used.kind),
            searched,
            described,
            expanded,
            None,
        ),
        Err(message) => (
            None,
            None,
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Some(message),
        ),
    };

    StartRecallView {
        status,
        configured_provider: Some(configured.name),
        configured_provider_kind: Some(configured.kind),
        used_provider,
        used_provider_kind,
        fallback_used,
        query: Some(query),
        budget,
        search_results,
        described_results,
        expanded_results,
        error,
        warnings,
    }
}
/// Successful output of `collect_start_recall`, in order: the descriptor of
/// the provider that produced the data, the (truncated) search results, the
/// described results, and the expanded results.
type StartRecallArtifacts = (
ProviderDescriptor,
Vec<RecallResult>,
Vec<RecallResult>,
Vec<RecallResult>,
);
/// Runs the search → describe → expand pipeline against one provider.
///
/// 1. Fails immediately if the provider does not declare `search`.
/// 2. Searches, validates every returned result, then keeps at most
///    `budget.search_limit` of them.
/// 3. If `describe` is declared and `budget.describe_limit > 0`, describes the
///    first `describe_limit` kept results; items the provider answers with
///    `None` are skipped.
/// 4. If `expand` is declared and `budget.expand_limit > 0`, expands at most
///    `expand_limit` described results that are flagged `can_expand`.
///
/// # Errors
///
/// Returns the first provider error or validation failure encountered; no
/// partial results are returned in that case.
fn collect_start_recall(
    provider: &mut dyn RecallProvider,
    query: &str,
    budget: StartRecallBudget,
) -> Result<StartRecallArtifacts> {
    let descriptor = provider.descriptor();
    let capabilities = provider.capabilities();
    if !capabilities.search {
        bail!(
            "configured recall provider `{}` does not declare `search` capability",
            descriptor.name
        );
    }

    let mut search_results = provider.search(query)?;
    // Validate the full response before truncating, so a malformed item is
    // reported even if it falls outside the budget.
    for result in &search_results {
        result.validate()?;
    }
    search_results.truncate(budget.search_limit);

    let mut described_results = Vec::new();
    if capabilities.describe && budget.describe_limit > 0 {
        for result in search_results.iter().take(budget.describe_limit) {
            if let Some(described) = provider.describe(&result.provider_ref)? {
                described.validate()?;
                described_results.push(described);
            }
        }
    }

    let mut expanded_results = Vec::new();
    if capabilities.expand && budget.expand_limit > 0 {
        // Only ask the provider to expand items it flagged as expandable. The
        // conformance harness exercises `expand` solely for `can_expand`
        // results, so calling it on other items relies on unverified provider
        // behavior, and an error there would abort the whole recall.
        let expandable = described_results
            .iter()
            .filter(|result| result.can_expand)
            .take(budget.expand_limit);
        for result in expandable {
            let expanded = provider.expand(&result.provider_ref)?;
            for item in &expanded {
                item.validate()?;
            }
            expanded_results.extend(expanded);
        }
    }

    Ok((
        descriptor,
        search_results,
        described_results,
        expanded_results,
    ))
}
/// Canonicalizes a declared capability name: surrounding whitespace is
/// removed, ASCII letters are lowercased, and underscores become hyphens.
fn normalize_capability_name(name: &str) -> String {
    name.trim()
        .chars()
        .map(|ch| match ch {
            '_' => '-',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
#[cfg(test)]
pub(crate) mod harness {
    use super::{RecallProvider, Result};

    /// Conformance check for a provider implementation.
    ///
    /// Asserts that health reports `"ok"`, that every search result for
    /// `query` validates, and — when there is at least one result — that
    /// `describe` returns a valid item with the same `provider_id` as the
    /// first result. If that first result is flagged `can_expand`, its
    /// expansion must also consist solely of valid results.
    ///
    /// Provider and validation errors are returned; contract violations panic.
    pub(crate) fn assert_provider_conforms(
        provider: &mut dyn RecallProvider,
        query: &str,
    ) -> Result<()> {
        let health = provider.health()?;
        assert_eq!(health.status, "ok", "provider health must be ok");

        let hits = provider.search(query)?;
        hits.iter().try_for_each(|hit| hit.validate())?;

        // Nothing further to check when the provider found nothing.
        let Some(first) = hits.first() else {
            return Ok(());
        };

        let described = provider
            .describe(&first.provider_ref)?
            .expect("describe must return the searched result");
        described.validate()?;
        assert_eq!(
            described.provider_ref.provider_id, first.provider_ref.provider_id,
            "describe must return the referenced item"
        );

        if first.can_expand {
            let expanded = provider.expand(&first.provider_ref)?;
            expanded.iter().try_for_each(|item| item.validate())?;
        }
        Ok(())
    }

    /// Asserts that searching for `query` yields no results.
    ///
    /// Provider errors are returned; a non-empty result list panics.
    pub(crate) fn assert_provider_returns_empty_search(
        provider: &mut dyn RecallProvider,
        query: &str,
    ) -> Result<()> {
        let hits = provider.search(query)?;
        assert!(
            hits.is_empty(),
            "provider should return no search results for query `{query}`"
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use anyhow::anyhow;

    use super::*;

    /// Provider that serves a fixed list of results and never fails.
    #[derive(Clone)]
    struct StubProvider {
        search_results: Vec<RecallResult>,
    }

    /// Provider whose `search` always returns an error.
    struct FailingProvider {
        name: &'static str,
    }

    /// Descriptor of kind `"test"` shared by both fixtures.
    fn test_descriptor(name: &str) -> ProviderDescriptor {
        ProviderDescriptor {
            name: name.to_owned(),
            kind: "test".to_owned(),
        }
    }

    /// Health view reporting `"ok"` with no message.
    fn healthy() -> RecallHealthView {
        RecallHealthView {
            status: "ok".to_owned(),
            message: None,
        }
    }

    impl RecallProvider for StubProvider {
        fn descriptor(&self) -> ProviderDescriptor {
            test_descriptor("stub")
        }

        fn capabilities(&self) -> RecallProviderCapabilities {
            RecallProviderCapabilities::all()
        }

        fn health(&mut self) -> Result<RecallHealthView> {
            Ok(healthy())
        }

        fn search(&mut self, _query: &str) -> Result<Vec<RecallResult>> {
            Ok(self.search_results.clone())
        }

        fn describe(&mut self, reference: &ProviderReference) -> Result<Option<RecallResult>> {
            let wanted = &reference.provider_id;
            let hit = self
                .search_results
                .iter()
                .find(|candidate| &candidate.provider_ref.provider_id == wanted);
            Ok(hit.cloned())
        }

        fn expand(&mut self, _reference: &ProviderReference) -> Result<Vec<RecallResult>> {
            Ok(Vec::new())
        }
    }

    impl RecallProvider for FailingProvider {
        fn descriptor(&self) -> ProviderDescriptor {
            test_descriptor(self.name)
        }

        fn capabilities(&self) -> RecallProviderCapabilities {
            RecallProviderCapabilities::all()
        }

        fn health(&mut self) -> Result<RecallHealthView> {
            Ok(healthy())
        }

        fn search(&mut self, _query: &str) -> Result<Vec<RecallResult>> {
            Err(anyhow!("{} search failed", self.name))
        }

        fn describe(&mut self, _reference: &ProviderReference) -> Result<Option<RecallResult>> {
            unreachable!("describe is never called when search fails")
        }

        fn expand(&mut self, _reference: &ProviderReference) -> Result<Vec<RecallResult>> {
            unreachable!("expand is never called when search fails")
        }
    }

    /// Builds a valid, non-expandable result owned by the stub provider.
    fn stub_result(id: &str, rank: u32) -> RecallResult {
        let provider_ref = ProviderReference {
            provider_name: "stub".to_owned(),
            provider_kind: "test".to_owned(),
            provider_id: id.to_owned(),
            provider_scope: Some("clone".to_owned()),
        };
        let provenance = RecallProvenance {
            source_ref: Some("memory.md".to_owned()),
            ..RecallProvenance::default()
        };
        RecallResult {
            provider_ref,
            ccd_scope: RecallScope::Clone,
            title: format!("Title {id}"),
            summary: format!("Summary {id}"),
            rank,
            preview: Some(format!("Preview {id}")),
            detail: None,
            can_expand: false,
            memory_class_hints: Vec::new(),
            provenance,
            temporal: RecallTemporalHints::default(),
        }
    }

    #[test]
    fn start_recall_budget_truncates_search_and_describe() {
        let search_results = vec![
            stub_result("1", 1),
            stub_result("2", 2),
            stub_result("3", 3),
        ];
        let mut provider = StubProvider { search_results };
        let budget = StartRecallBudget {
            search_limit: 2,
            describe_limit: 1,
            expand_limit: 0,
        };

        let view = run_start_recall(
            &mut provider,
            None,
            Some("test query".to_owned()),
            budget,
        );

        assert_eq!(view.status, "loaded");
        assert_eq!(view.search_results.len(), 2);
        assert_eq!(view.described_results.len(), 1);
        assert!(view.expanded_results.is_empty());
    }

    #[test]
    fn start_recall_marks_attempted_fallback_when_both_providers_fail() {
        let mut primary = FailingProvider { name: "primary" };
        let mut fallback = FailingProvider { name: "fallback" };

        let view = run_start_recall(
            &mut primary,
            Some(&mut fallback),
            Some("test query".to_owned()),
            StartRecallBudget::default(),
        );

        assert_eq!(view.status, "error");
        assert!(view.fallback_used);
        let message = view.error.as_deref().expect("error");
        assert!(message.contains("fallback provider also failed"));
    }
}