Skip to main content

nexus_memory_agent/
util.rs

1//! Shared utility functions for agent services
2
3use nexus_core::traits::EmbeddingService;
4use nexus_core::{CognitiveLevel, CognitiveMetadata, Memory, PerspectiveKey};
5use nexus_llm::{GenerateResponse, TokenUsage};
6use nexus_storage::{MemoryRepository, MetricSample};
7use tracing::warn;
8
9#[derive(Debug, Clone)]
10pub struct CognitionSnapshot {
11    pub level: CognitiveLevel,
12    pub confidence: Option<f32>,
13    pub perspective: Option<PerspectiveKey>,
14    pub generated_by: Option<String>,
15    pub times_reinforced: i64,
16    pub raw_activity: bool,
17}
18
19impl CognitionSnapshot {
20    /// Build a cognition snapshot from a memory's metadata.
21    ///
22    /// Parses `CognitiveMetadata` exactly once and reads all fields from the parsed
23    /// struct.  Previous versions called `cognitive_level_from_metadata` and
24    /// `perspective_from_metadata` as fallbacks, each of which re-parsed the same
25    /// JSON — up to three redundant deserializations per call.
26    pub fn from_memory(memory: &Memory) -> Self {
27        let cognitive = CognitiveMetadata::from_metadata(&memory.metadata);
28        let level = cognitive
29            .as_ref()
30            .map_or(CognitiveLevel::Raw, |value| value.level);
31        let perspective = cognitive.as_ref().map(|value| PerspectiveKey {
32            observer: value.observer.clone(),
33            subject: value.subject.clone(),
34            session_key: value.session_key.clone(),
35        });
36        let raw_activity = memory.labels.iter().any(|label| label == "raw-activity")
37            || memory
38                .metadata
39                .get("raw_activity")
40                .and_then(serde_json::Value::as_bool)
41                .unwrap_or(false);
42
43        Self {
44            level,
45            confidence: cognitive.as_ref().and_then(|value| value.confidence),
46            perspective,
47            generated_by: cognitive
48                .as_ref()
49                .and_then(|value| value.generated_by.clone()),
50            times_reinforced: cognitive
51                .as_ref()
52                .map(|value| value.times_reinforced)
53                .unwrap_or(0),
54            raw_activity,
55        }
56    }
57
58    pub fn confidence_meets_threshold(&self) -> bool {
59        let confidence = self.confidence.unwrap_or(1.0);
60        match self.level {
61            CognitiveLevel::Explicit => confidence >= 0.70,
62            CognitiveLevel::Derived => confidence >= 0.75,
63            CognitiveLevel::Contradiction => confidence >= 0.80,
64            CognitiveLevel::SummaryShort | CognitiveLevel::SummaryLong | CognitiveLevel::Raw => {
65                true
66            }
67        }
68    }
69}
70
71/// Extract the `agent.summary` field from JSON metadata, falling back to a
72/// truncated content excerpt.
73pub fn extract_agent_summary(metadata: &str, content: &str, fallback_chars: usize) -> String {
74    #[derive(serde::Deserialize)]
75    struct AgentMeta {
76        summary: Option<String>,
77    }
78
79    #[derive(serde::Deserialize)]
80    struct Metadata {
81        agent: Option<AgentMeta>,
82    }
83
84    serde_json::from_str::<Metadata>(metadata)
85        .ok()
86        .and_then(|md| md.agent)
87        .and_then(|a| a.summary)
88        .unwrap_or_else(|| content.chars().take(fallback_chars).collect())
89}
90
91/// Persist a best-effort stage timing metric without impacting cognition flow.
92pub async fn record_stage_metric(
93    repo: &MemoryRepository,
94    namespace_id: i64,
95    metric_name: &str,
96    metric_value_ms: f64,
97    stage: &str,
98) {
99    let labels = serde_json::json!({
100        "namespace_id": namespace_id,
101        "stage": stage,
102        "unit": "ms",
103    });
104
105    if let Err(error) = repo
106        .record_metric(metric_name, metric_value_ms, &labels)
107        .await
108    {
109        warn!(
110            %error,
111            namespace_id,
112            metric_name,
113            stage,
114            "Failed to persist cognition stage metric"
115        );
116    }
117}
118
119pub fn stage_metric_sample(
120    namespace_id: i64,
121    metric_name: &str,
122    metric_value_ms: f64,
123    stage: &str,
124) -> MetricSample {
125    MetricSample {
126        metric_name: metric_name.to_string(),
127        metric_value: metric_value_ms,
128        labels: serde_json::json!({
129            "namespace_id": namespace_id,
130            "stage": stage,
131            "unit": "ms",
132        }),
133    }
134}
135
136/// Persist best-effort token usage metrics for a cognition stage.
137pub async fn record_token_usage_metrics(
138    repo: &MemoryRepository,
139    namespace_id: i64,
140    metric_prefix: &str,
141    stage: &str,
142    usage: Option<&TokenUsage>,
143) {
144    let Some(usage) = usage else {
145        return;
146    };
147
148    for (suffix, value) in [
149        ("prompt_tokens", usage.prompt_tokens as f64),
150        ("completion_tokens", usage.completion_tokens as f64),
151        ("total_tokens", usage.total_tokens as f64),
152    ] {
153        let metric_name = format!("{metric_prefix}.{suffix}");
154        let labels = serde_json::json!({
155            "namespace_id": namespace_id,
156            "stage": stage,
157            "unit": "tokens",
158        });
159
160        if let Err(error) = repo.record_metric(&metric_name, value, &labels).await {
161            warn!(
162                %error,
163                namespace_id,
164                metric_name,
165                stage,
166                "Failed to persist cognition token usage metric"
167            );
168        }
169    }
170}
171
172pub fn token_usage_metric_samples(
173    namespace_id: i64,
174    metric_prefix: &str,
175    stage: &str,
176    usage: Option<&TokenUsage>,
177) -> Vec<MetricSample> {
178    let Some(usage) = usage else {
179        return Vec::new();
180    };
181
182    [
183        ("prompt_tokens", usage.prompt_tokens as f64),
184        ("completion_tokens", usage.completion_tokens as f64),
185        ("total_tokens", usage.total_tokens as f64),
186    ]
187    .into_iter()
188    .map(|(suffix, value)| MetricSample {
189        metric_name: format!("{metric_prefix}.{suffix}"),
190        metric_value: value,
191        labels: serde_json::json!({
192            "namespace_id": namespace_id,
193            "stage": stage,
194            "unit": "tokens",
195        }),
196    })
197    .collect()
198}
199
200pub async fn flush_metric_samples(repo: &MemoryRepository, samples: &[MetricSample]) {
201    if samples.is_empty() {
202        return;
203    }
204
205    if let Err(error) = repo.record_metrics_batch(samples).await {
206        warn!(%error, count = samples.len(), "Failed to persist cognition metric batch");
207    }
208}
209
210/// Parse a JSON response using the same fenced-block tolerance as `LlmClientJson`.
211pub fn parse_json_response<T: serde::de::DeserializeOwned>(
212    response: &GenerateResponse,
213) -> Result<T, serde_json::Error> {
214    let content = response.content.trim();
215    let json_str = if content.starts_with("```") {
216        let start = content.find('\n').map(|i| i + 1).unwrap_or(3);
217        let end = content[start..]
218            .rfind("```")
219            .map(|i| start + i)
220            .unwrap_or(content.len());
221        if start >= end {
222            content
223        } else {
224            &content[start..end]
225        }
226    } else {
227        content
228    };
229
230    serde_json::from_str(json_str.trim())
231}
232
233/// Attempt to generate an embedding for `content` using the optional service.
234///
235/// Returns `(Some(vector), Some(model_name))` on success, `(None, None)` when
236/// the service is absent or the call fails (graceful degradation).
237pub async fn maybe_embed(
238    service: Option<&dyn EmbeddingService>,
239    content: &str,
240) -> (Option<Vec<f32>>, Option<String>) {
241    let Some(svc) = service else {
242        return (None, None);
243    };
244    match svc.embed(content).await {
245        Ok(vec) => (Some(vec), Some(svc.model_name().to_string())),
246        Err(error) => {
247            warn!(%error, "Embedding generation failed, storing without embedding");
248            (None, None)
249        }
250    }
251}