Skip to main content

nexus_memory_agent/
derive.rs

1//! Derivation service - converts raw memories into explicit observations.
2
3use std::collections::HashSet;
4use std::sync::Arc;
5
6use nexus_core::config::AgentConfig;
7use nexus_core::traits::EmbeddingService;
8use nexus_core::{
9    cognitive_level_from_metadata, infer_perspective, perspective_from_metadata, CognitiveLevel,
10    CognitiveMetadata, Memory, MemoryCategory, MemoryLaneType, PerspectiveKey, PerspectiveSource,
11};
12use nexus_llm::{ChatMessage, GenerateParams, LlmClient, LlmClientJson};
13use nexus_storage::models::EnqueueJobParams;
14use nexus_storage::repository::{
15    MemoryRepository, StoreMemoryParams, StoreMemoryWithLineageParams,
16};
17use serde_json::json;
18use tracing::{debug, info, warn};
19
20use crate::error::AgentError;
21use crate::prompts::{derive_user_prompt, DERIVE_SYSTEM_PROMPT};
22
23const DERIVE_MAX_TOKENS: u32 = 4096;
24const REFLECT_PERSPECTIVE_JOB: &str = "reflect_perspective";
25const DIGEST_SESSION_JOB: &str = "digest_session";
26const DERIVE_GENERATED_BY: &str = "derive_service";
27const DERIVED_FROM_ROLE: &str = "derived_from";
28const RAW_ACTIVITY_LABEL: &str = "raw-activity";
29const LOW_SIGNAL_LABEL: &str = "low-signal";
30
31pub struct DeriveService {
32    config: AgentConfig,
33    llm: Arc<dyn LlmClient>,
34    embeddings: Option<Arc<dyn EmbeddingService>>,
35}
36
37#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub struct DerivedObservation {
39    pub content: String,
40    pub category: String,
41    pub memory_lane_type: Option<String>,
42    pub labels: Vec<String>,
43    pub confidence: f32,
44}
45
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47struct DerivedObservationEnvelope {
48    observations: Vec<DerivedObservation>,
49}
50
51impl DeriveService {
52    pub fn new(
53        config: AgentConfig,
54        llm: Arc<dyn LlmClient>,
55        embeddings: Option<Arc<dyn EmbeddingService>>,
56    ) -> Self {
57        Self {
58            config,
59            llm,
60            embeddings,
61        }
62    }
63
64    pub async fn derive_memory(
65        &self,
66        memory: &Memory,
67        repo: &MemoryRepository,
68    ) -> Result<Vec<i64>, AgentError> {
69        self.derive_memory_with_perspective(memory, None, repo)
70            .await
71    }
72
73    pub async fn derive_memory_with_perspective(
74        &self,
75        memory: &Memory,
76        queued_perspective: Option<&PerspectiveKey>,
77        repo: &MemoryRepository,
78    ) -> Result<Vec<i64>, AgentError> {
79        if !is_derivable_source(memory) {
80            return Ok(Vec::new());
81        }
82
83        let existing_ids = existing_derived_ids(repo, memory.id).await?;
84        if !existing_ids.is_empty() {
85            debug!(
86                memory_id = memory.id,
87                derived_count = existing_ids.len(),
88                "Reusing existing derived observations"
89            );
90            return Ok(existing_ids);
91        }
92
93        let perspective = queued_perspective
94            .cloned()
95            .or_else(|| perspective_from_metadata(&memory.metadata))
96            .unwrap_or_else(|| {
97                infer_perspective(
98                    PerspectiveSource::HookIngest,
99                    self.config.namespace.clone(),
100                    None,
101                    None,
102                )
103            });
104
105        let observations = match self.derive_with_llm(memory).await {
106            Ok(observations) => observations,
107            Err(error) => {
108                warn!(memory_id = memory.id, %error, "LLM derivation failed, using fallback");
109                fallback_observations(memory)
110            }
111        };
112
113        let observations = normalize_observations(observations);
114        if observations.is_empty() {
115            debug!(memory_id = memory.id, "No explicit observations derived");
116            return Ok(Vec::new());
117        }
118
119        // --- PHASE 10 OPTIMIZATION: BATCH EMBEDDING ---
120        let mut derived_ids = Vec::with_capacity(observations.len());
121
122        // Prepare batch for embedding
123        let contents: Vec<String> = observations.iter().map(|o| o.content.clone()).collect();
124        let mut embeddings_map: std::collections::HashMap<usize, (Vec<f32>, String)> =
125            std::collections::HashMap::new();
126
127        if let Some(service) = self.embeddings.as_deref() {
128            match service.embed_batch(&contents).await {
129                Ok(results) if results.len() == contents.len() => {
130                    for (i, vec) in results.into_iter().enumerate() {
131                        embeddings_map.insert(i, (vec, service.model_name().to_string()));
132                    }
133                }
134                Ok(results) => {
135                    warn!(
136                        "embed_batch returned {} results for {} inputs in derive pipeline",
137                        results.len(),
138                        contents.len()
139                    );
140                }
141                Err(error) => {
142                    warn!(%error, "Batch embedding failed, falling back to individual or no embeddings");
143                }
144            }
145        }
146
147        for (i, observation) in observations.into_iter().enumerate() {
148            let category = MemoryCategory::parse(&observation.category).unwrap_or(memory.category);
149            let memory_lane_type = observation
150                .memory_lane_type
151                .as_deref()
152                .and_then(MemoryLaneType::parse);
153            let metadata = derive_metadata(memory, &perspective, observation.confidence);
154
155            // Check batch results
156            let (embedding, embedding_model) = if let Some((vec, model)) = embeddings_map.get(&i) {
157                (Some(vec.clone()), Some(model.clone()))
158            } else {
159                (None, None)
160            };
161
162            let derived = repo
163                .store_with_lineage(StoreMemoryWithLineageParams {
164                    store: StoreMemoryParams {
165                        namespace_id: memory.namespace_id,
166                        content: &observation.content,
167                        category: &category,
168                        memory_lane_type: memory_lane_type.as_ref(),
169                        labels: &observation.labels,
170                        metadata: &metadata,
171                        embedding: embedding.as_deref(),
172                        embedding_model: embedding_model.as_deref(),
173                    },
174                    source_memory_ids: &[memory.id],
175                    evidence_role: DERIVED_FROM_ROLE,
176                })
177                .await
178                .map_err(|error| AgentError::Storage(error.to_string()))?;
179            derived_ids.push(derived.id);
180        }
181
182        enqueue_follow_up_jobs(repo, memory, &perspective, &derived_ids, &self.config).await?;
183
184        info!(
185            memory_id = memory.id,
186            derived_count = derived_ids.len(),
187            "Derived explicit observations from raw memory"
188        );
189        Ok(derived_ids)
190    }
191
192    async fn derive_with_llm(
193        &self,
194        memory: &Memory,
195    ) -> Result<Vec<DerivedObservation>, AgentError> {
196        let params = GenerateParams {
197            messages: vec![
198                ChatMessage::system(DERIVE_SYSTEM_PROMPT),
199                ChatMessage::user(derive_user_prompt(memory)),
200            ],
201            max_tokens: DERIVE_MAX_TOKENS,
202            temperature: 0.1,
203            json_mode: true,
204        };
205
206        let envelope: DerivedObservationEnvelope = self
207            .llm
208            .generate_json(params)
209            .await
210            .map_err(|error| AgentError::Llm(error.to_string()))?;
211
212        Ok(envelope.observations)
213    }
214}
215
216fn derive_metadata(
217    source: &Memory,
218    perspective: &PerspectiveKey,
219    confidence: f32,
220) -> serde_json::Value {
221    let mut cognitive = CognitiveMetadata::new(
222        CognitiveLevel::Explicit,
223        perspective.observer.clone(),
224        perspective.subject.clone(),
225        perspective.session_key.clone(),
226        DERIVE_GENERATED_BY,
227    );
228    cognitive.source_memory_ids = vec![source.id];
229    cognitive.confidence = Some(confidence.max(0.70));
230    cognitive.merge_into(&sanitized_source_metadata(source))
231}
232
233fn fallback_observations(memory: &Memory) -> Vec<DerivedObservation> {
234    let summary = memory
235        .metadata
236        .get("agent")
237        .and_then(|agent| agent.get("summary"))
238        .and_then(serde_json::Value::as_str)
239        .map(str::trim)
240        .filter(|summary| summary.len() >= 16 && !looks_like_noise(summary))
241        .map(ToString::to_string);
242
243    let content = memory
244        .content
245        .split_whitespace()
246        .collect::<Vec<_>>()
247        .join(" ");
248
249    let candidate = summary.unwrap_or(content);
250    if candidate.is_empty() || looks_like_noise(&candidate) {
251        return Vec::new();
252    }
253
254    vec![DerivedObservation {
255        content: candidate,
256        category: memory.category.to_string(),
257        memory_lane_type: memory.memory_lane_type.as_ref().map(ToString::to_string),
258        labels: explicit_labels_from_source(memory),
259        confidence: 0.70,
260    }]
261}
262
263fn sanitized_source_metadata(source: &Memory) -> serde_json::Value {
264    let mut sanitized = serde_json::Map::new();
265
266    if let Some(agent) = source
267        .metadata
268        .get("agent")
269        .and_then(serde_json::Value::as_object)
270    {
271        let mut agent_sanitized = serde_json::Map::new();
272        for key in [
273            "summary",
274            "entities",
275            "topics",
276            "importance_score",
277            "source",
278            "generated_by",
279        ] {
280            if let Some(value) = agent.get(key) {
281                agent_sanitized.insert(key.to_string(), value.clone());
282            }
283        }
284        if !agent_sanitized.is_empty() {
285            sanitized.insert(
286                "agent".to_string(),
287                serde_json::Value::Object(agent_sanitized),
288            );
289        }
290    }
291
292    serde_json::Value::Object(sanitized)
293}
294
295fn explicit_labels_from_source(source: &Memory) -> Vec<String> {
296    let mut labels: Vec<String> = source
297        .labels
298        .iter()
299        .filter(|label| {
300            !label.eq_ignore_ascii_case(RAW_ACTIVITY_LABEL)
301                && !label.eq_ignore_ascii_case(LOW_SIGNAL_LABEL)
302        })
303        .cloned()
304        .collect();
305    dedupe_labels(&mut labels);
306    labels
307}
308
309fn normalize_observations(observations: Vec<DerivedObservation>) -> Vec<DerivedObservation> {
310    let mut seen = HashSet::new();
311    let mut normalized = Vec::new();
312
313    for mut observation in observations {
314        observation.content = observation.content.trim().to_string();
315        if observation.content.is_empty() || observation.confidence < 0.70 {
316            continue;
317        }
318
319        let fingerprint = observation.content.to_lowercase();
320        if !seen.insert(fingerprint) {
321            continue;
322        }
323
324        observation.labels.retain(|label| {
325            !label.eq_ignore_ascii_case(RAW_ACTIVITY_LABEL)
326                && !label.eq_ignore_ascii_case(LOW_SIGNAL_LABEL)
327        });
328        dedupe_labels(&mut observation.labels);
329        normalized.push(observation);
330    }
331
332    normalized
333}
334
/// Removes case-insensitive duplicate labels in place, keeping the first occurrence (and its
/// original spelling) and preserving order.
fn dedupe_labels(labels: &mut Vec<String>) {
    let mut lowercase_seen: HashSet<String> = HashSet::with_capacity(labels.len());
    labels.retain(|label| {
        let key = label.to_lowercase();
        if lowercase_seen.contains(&key) {
            false
        } else {
            lowercase_seen.insert(key);
            true
        }
    });
}
339
/// Returns `true` when `text` contains nothing but ASCII punctuation and whitespace.
/// The empty string counts as noise.
fn looks_like_noise(text: &str) -> bool {
    !text
        .chars()
        .any(|c| !c.is_ascii_punctuation() && !c.is_whitespace())
}
344
345fn is_derivable_source(memory: &Memory) -> bool {
346    memory.category == MemoryCategory::Session
347        && memory.labels.iter().any(|l| l == RAW_ACTIVITY_LABEL)
348        && memory.metadata.get("cognitive").is_some()
349        && cognitive_level_from_metadata(&memory.metadata) == CognitiveLevel::Raw
350}
351
352async fn existing_derived_ids(
353    repo: &MemoryRepository,
354    source_id: i64,
355) -> Result<Vec<i64>, AgentError> {
356    let lineage = repo
357        .load_lineage(source_id)
358        .await
359        .map_err(|error| AgentError::Storage(error.to_string()))?;
360
361    Ok(lineage
362        .into_iter()
363        .filter(|entry| {
364            entry.source_memory_id == source_id && entry.evidence_role == DERIVED_FROM_ROLE
365        })
366        .map(|entry| entry.derived_memory_id)
367        .collect())
368}
369
370async fn enqueue_follow_up_jobs(
371    repo: &MemoryRepository,
372    source: &Memory,
373    perspective: &PerspectiveKey,
374    derived_ids: &[i64],
375    _config: &AgentConfig,
376) -> Result<(), AgentError> {
377    let perspective_json = serde_json::to_value(perspective).ok();
378
379    // Enqueue one reflect job per perspective (not per derived ID).
380    // reflect_perspective_cycle processes all candidates in the perspective group,
381    // so duplicate jobs are redundant work.
382    if !derived_ids.is_empty() {
383        let payload = json!({
384            "derived_count": derived_ids.len(),
385            "source": "derive_follow_up",
386            "reason": "new_explicit_observations",
387        });
388        repo.enqueue_job(EnqueueJobParams {
389            namespace_id: source.namespace_id,
390            job_type: REFLECT_PERSPECTIVE_JOB,
391            priority: 110,
392            perspective: perspective_json.as_ref(),
393            payload: &payload,
394        })
395        .await
396        .map_err(|error| AgentError::Storage(error.to_string()))?;
397    }
398
399    // Only enqueue session digest when session_key is present
400    if let Some(ref session_key) = perspective.session_key {
401        let payload = json!({
402            "session_key": session_key,
403            "reason": "post_derivation_rollover",
404        });
405        repo.enqueue_job(EnqueueJobParams {
406            namespace_id: source.namespace_id,
407            job_type: DIGEST_SESSION_JOB,
408            priority: 120,
409            perspective: perspective_json.as_ref(),
410            payload: &payload,
411        })
412        .await
413        .map_err(|error| AgentError::Storage(error.to_string()))?;
414    }
415
416    Ok(())
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use nexus_llm::GenerateResponse;
    use nexus_storage::repository::NamespaceRepository;
    use sqlx::sqlite::SqlitePoolOptions;

    use nexus_core::cognitive_level_from_metadata;

    /// Scripted LLM client. Each `generate` call pops the next queued response and panics
    /// when none is left, so a test with N queued responses also asserts at most N LLM calls.
    struct MockLlmClient {
        responses: Mutex<VecDeque<nexus_llm::Result<GenerateResponse>>>,
    }

    impl MockLlmClient {
        fn new(responses: Vec<nexus_llm::Result<GenerateResponse>>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl LlmClient for MockLlmClient {
        async fn generate(&self, _params: GenerateParams) -> nexus_llm::Result<GenerateResponse> {
            self.responses
                .lock()
                .expect("mock responses poisoned")
                .pop_front()
                .expect("mock response missing")
        }

        fn provider_name(&self) -> String {
            "mock".to_string()
        }

        fn model_name(&self) -> String {
            "mock-model".to_string()
        }
    }

    /// Opens a single-connection in-memory SQLite database, runs migrations and creates a
    /// test namespace. The returned pool must be kept alive for the database to persist.
    async fn setup_repo() -> (sqlx::SqlitePool, MemoryRepository, i64) {
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap();
        nexus_storage::migrations::run_migrations(&pool)
            .await
            .unwrap();
        let namespace_repo = NamespaceRepository::new(pool.clone());
        let namespace = namespace_repo
            .get_or_create("derive-test", "derive-test")
            .await
            .unwrap();
        (pool.clone(), MemoryRepository::new(pool), namespace.id)
    }

    fn derive_response() -> GenerateResponse {
        let envelope = DerivedObservationEnvelope {
            observations: vec![DerivedObservation {
                content: "Explicit observation from derivation.".to_string(),
                category: "session".to_string(),
                memory_lane_type: Some("process".to_string()),
                labels: vec!["derived".to_string()],
                confidence: 0.95,
            }],
        };
        GenerateResponse {
            content: serde_json::to_string(&envelope).unwrap(),
            model: "mock-model".to_string(),
            usage: None,
        }
    }

    /// Metadata marking a memory as raw-level for the `claude` perspective in `sess-1`.
    fn raw_metadata() -> serde_json::Value {
        json!({
            "cognitive": {
                "level": "raw",
                "observer": "claude",
                "subject": "claude",
                "session_key": "sess-1",
            }
        })
    }

    /// Stores a raw session memory that satisfies `is_derivable_source`.
    async fn store_raw_memory(repo: &MemoryRepository, namespace_id: i64) -> Memory {
        let metadata = raw_metadata();
        repo.store(StoreMemoryParams {
            namespace_id,
            content: "Raw implementation log.",
            category: &MemoryCategory::Session,
            memory_lane_type: None,
            labels: &["raw-activity".to_string()],
            metadata: &metadata,
            embedding: None,
            embedding_model: None,
        })
        .await
        .unwrap()
    }

    fn observation(content: &str, confidence: f32, labels: &[&str]) -> DerivedObservation {
        DerivedObservation {
            content: content.to_string(),
            category: "session".to_string(),
            memory_lane_type: None,
            labels: labels.iter().map(|label| label.to_string()).collect(),
            confidence,
        }
    }

    #[tokio::test]
    async fn test_derive_memory_persists_explicit_observations_and_jobs() {
        let (_pool, repo, namespace_id) = setup_repo().await;
        let service = DeriveService::new(
            AgentConfig::default(),
            Arc::new(MockLlmClient::new(vec![Ok(derive_response())])),
            None,
        );
        let raw = store_raw_memory(&repo, namespace_id).await;

        let derived_ids = service.derive_memory(&raw, &repo).await.unwrap();
        assert_eq!(derived_ids.len(), 1);

        let derived = repo.get_by_id(derived_ids[0]).await.unwrap().unwrap();
        assert_eq!(
            cognitive_level_from_metadata(&derived.metadata),
            CognitiveLevel::Explicit
        );

        let jobs = repo
            .list_jobs(namespace_id, None, None, 10, 0)
            .await
            .unwrap();
        assert!(jobs.iter().any(|j| j.job_type == REFLECT_PERSPECTIVE_JOB));
        assert!(jobs.iter().any(|j| j.job_type == DIGEST_SESSION_JOB));
    }

    #[tokio::test]
    async fn test_derive_memory_reuses_existing_observations() {
        let (_pool, repo, namespace_id) = setup_repo().await;
        // Only one scripted response: a second LLM call would panic inside the mock.
        let service = DeriveService::new(
            AgentConfig::default(),
            Arc::new(MockLlmClient::new(vec![Ok(derive_response())])),
            None,
        );
        let raw = store_raw_memory(&repo, namespace_id).await;

        let first = service.derive_memory(&raw, &repo).await.unwrap();
        let second = service.derive_memory(&raw, &repo).await.unwrap();

        assert_eq!(first.len(), 1);
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_derive_memory_skips_non_raw_sources() {
        let (_pool, repo, namespace_id) = setup_repo().await;
        // No scripted responses: any LLM call would panic inside the mock.
        let service = DeriveService::new(
            AgentConfig::default(),
            Arc::new(MockLlmClient::new(Vec::new())),
            None,
        );
        let metadata = raw_metadata();
        let unlabeled = repo
            .store(StoreMemoryParams {
                namespace_id,
                content: "Raw implementation log without the raw-activity label.",
                category: &MemoryCategory::Session,
                memory_lane_type: None,
                labels: &[],
                metadata: &metadata,
                embedding: None,
                embedding_model: None,
            })
            .await
            .unwrap();

        let derived_ids = service.derive_memory(&unlabeled, &repo).await.unwrap();
        assert!(derived_ids.is_empty());
    }

    #[test]
    fn test_normalize_observations_filters_and_dedupes() {
        let normalized = normalize_observations(vec![
            observation("  Keep me.  ", 0.9, &["raw-activity", "Topic", "topic"]),
            observation("keep me.", 0.95, &[]),
            observation("Too unsure.", 0.5, &[]),
            observation("   ", 0.9, &[]),
        ]);

        assert_eq!(normalized.len(), 1);
        assert_eq!(normalized[0].content, "Keep me.");
        assert_eq!(normalized[0].labels, vec!["Topic".to_string()]);
    }
}