Skip to main content

nexus_memory_agent/
derive.rs

1//! Derivation service - converts raw memories into explicit observations.
2
3use std::collections::HashSet;
4use std::sync::Arc;
5
6use chrono::Utc;
7use nexus_core::config::AgentConfig;
8use nexus_core::traits::EmbeddingService;
9use nexus_core::{
10    cognitive_level_from_metadata, infer_perspective, perspective_from_metadata, CognitiveLevel,
11    CognitiveMetadata, Memory, MemoryCategory, MemoryLaneType, PerspectiveKey, PerspectiveSource,
12};
13use nexus_llm::{ChatMessage, GenerateParams, LlmClient, LlmClientJson};
14use nexus_storage::models::EnqueueJobParams;
15use nexus_storage::repository::{
16    MemoryRepository, StoreMemoryParams, StoreMemoryWithLineageParams,
17};
18use serde_json::json;
19use tracing::{debug, info, warn};
20
21use crate::error::AgentError;
22use crate::prompts::{derive_user_prompt, DERIVE_SYSTEM_PROMPT};
23
/// Upper bound on completion tokens requested from the LLM for one derivation call.
const DERIVE_MAX_TOKENS: u32 = 4096;

/// Job type enqueued after derivation so the perspective gets reflected on.
const REFLECT_PERSPECTIVE_JOB: &str = "reflect_perspective";

/// Job type enqueued after derivation when the perspective carries a session key.
const DIGEST_SESSION_JOB: &str = "digest_session";

/// Value recorded as the generator of cognitive metadata produced by this service.
const DERIVE_GENERATED_BY: &str = "derive_service";

/// Lineage evidence role linking a derived observation back to its raw source memory.
const DERIVED_FROM_ROLE: &str = "derived_from";

/// Label a source memory must carry to be eligible for derivation.
const RAW_ACTIVITY_LABEL: &str = "raw-activity";

/// Label propagated from a source memory onto every observation derived from it.
const LOW_SIGNAL_LABEL: &str = "low-signal";
31
32pub struct DeriveService {
33    config: AgentConfig,
34    llm: Arc<dyn LlmClient>,
35    embeddings: Option<Arc<dyn EmbeddingService>>,
36}
37
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
39pub struct DerivedObservation {
40    pub content: String,
41    pub category: String,
42    pub memory_lane_type: Option<String>,
43    pub labels: Vec<String>,
44    pub confidence: f32,
45}
46
47#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
48struct DerivedObservationEnvelope {
49    observations: Vec<DerivedObservation>,
50}
51
52impl DeriveService {
53    pub fn new(
54        config: AgentConfig,
55        llm: Arc<dyn LlmClient>,
56        embeddings: Option<Arc<dyn EmbeddingService>>,
57    ) -> Self {
58        Self {
59            config,
60            llm,
61            embeddings,
62        }
63    }
64
65    pub async fn derive_memory(
66        &self,
67        memory: &Memory,
68        repo: &MemoryRepository,
69    ) -> Result<Vec<i64>, AgentError> {
70        self.derive_memory_with_perspective(memory, None, repo)
71            .await
72    }
73
74    pub async fn derive_memory_with_perspective(
75        &self,
76        memory: &Memory,
77        queued_perspective: Option<&PerspectiveKey>,
78        repo: &MemoryRepository,
79    ) -> Result<Vec<i64>, AgentError> {
80        if !is_derivable_source(memory) {
81            return Ok(Vec::new());
82        }
83
84        let existing_ids = existing_derived_ids(repo, memory.id).await?;
85        if !existing_ids.is_empty() {
86            debug!(
87                memory_id = memory.id,
88                derived_count = existing_ids.len(),
89                "Reusing existing derived observations"
90            );
91            return Ok(existing_ids);
92        }
93
94        let perspective = queued_perspective
95            .cloned()
96            .or_else(|| perspective_from_metadata(&memory.metadata))
97            .unwrap_or_else(|| {
98                infer_perspective(
99                    PerspectiveSource::HookIngest,
100                    self.config.namespace.clone(),
101                    None,
102                    None,
103                )
104            });
105
106        let observations = match self.derive_with_llm(memory).await {
107            Ok(observations) => observations,
108            Err(error) => {
109                warn!(memory_id = memory.id, %error, "LLM derivation failed, using fallback");
110                fallback_observations(memory)
111            }
112        };
113
114        let mut observations = normalize_observations(observations);
115        if observations.is_empty() {
116            debug!(memory_id = memory.id, "No explicit observations derived");
117            return Ok(Vec::new());
118        }
119
120        // Propagate low-signal marker from source to all derived observations.
121        // Raw-activity is NOT propagated — it's only for the raw memories themselves.
122        // Low-signal derived memories should be excluded from reflection but still eligible for digest.
123        let has_low_signal = memory.labels.iter().any(|l| l == LOW_SIGNAL_LABEL);
124        for obs in &mut observations {
125            if has_low_signal && !obs.labels.iter().any(|l| l == LOW_SIGNAL_LABEL) {
126                obs.labels.push(LOW_SIGNAL_LABEL.to_string());
127            }
128        }
129
130        // --- PHASE 10 OPTIMIZATION: BATCH EMBEDDING ---
131        let mut derived_ids = Vec::with_capacity(observations.len());
132
133        // Prepare batch for embedding
134        let contents: Vec<String> = observations.iter().map(|o| o.content.clone()).collect();
135        let mut embeddings_map: std::collections::HashMap<usize, (Vec<f32>, String)> =
136            std::collections::HashMap::new();
137
138        if let Some(service) = self.embeddings.as_deref() {
139            match service.embed_batch(&contents).await {
140                Ok(results) if results.len() == contents.len() => {
141                    for (i, vec) in results.into_iter().enumerate() {
142                        embeddings_map.insert(i, (vec, service.model_name().to_string()));
143                    }
144                }
145                Ok(results) => {
146                    warn!(
147                        "embed_batch returned {} results for {} inputs in derive pipeline",
148                        results.len(),
149                        contents.len()
150                    );
151                }
152                Err(error) => {
153                    warn!(%error, "Batch embedding failed, falling back to individual or no embeddings");
154                }
155            }
156        }
157
158        for (i, observation) in observations.into_iter().enumerate() {
159            let category = MemoryCategory::parse(&observation.category).unwrap_or(memory.category);
160            let memory_lane_type = observation
161                .memory_lane_type
162                .as_deref()
163                .and_then(MemoryLaneType::parse);
164            let metadata = derive_metadata(memory, &perspective, observation.confidence);
165
166            // Check batch results
167            let (embedding, embedding_model) = if let Some((vec, model)) = embeddings_map.get(&i) {
168                (Some(vec.clone()), Some(model.clone()))
169            } else {
170                (None, None)
171            };
172
173            let derived = repo
174                .store_with_lineage(StoreMemoryWithLineageParams {
175                    store: StoreMemoryParams {
176                        namespace_id: memory.namespace_id,
177                        content: &observation.content,
178                        category: &category,
179                        memory_lane_type: memory_lane_type.as_ref(),
180                        labels: &observation.labels,
181                        metadata: &metadata,
182                        embedding: embedding.as_deref(),
183                        embedding_model: embedding_model.as_deref(),
184                    },
185                    source_memory_ids: &[memory.id],
186                    evidence_role: DERIVED_FROM_ROLE,
187                })
188                .await
189                .map_err(|error| AgentError::Storage(error.to_string()))?;
190            derived_ids.push(derived.id);
191        }
192
193        enqueue_follow_up_jobs(repo, memory, &perspective, &derived_ids, &self.config).await?;
194
195        info!(
196            memory_id = memory.id,
197            derived_count = derived_ids.len(),
198            "Derived explicit observations from raw memory"
199        );
200        Ok(derived_ids)
201    }
202
203    async fn derive_with_llm(
204        &self,
205        memory: &Memory,
206    ) -> Result<Vec<DerivedObservation>, AgentError> {
207        let params = GenerateParams {
208            messages: vec![
209                ChatMessage::system(DERIVE_SYSTEM_PROMPT),
210                ChatMessage::user(derive_user_prompt(memory)),
211            ],
212            max_tokens: DERIVE_MAX_TOKENS,
213            temperature: 0.1,
214            json_mode: true,
215        };
216
217        let envelope: DerivedObservationEnvelope = self
218            .llm
219            .generate_json(params)
220            .await
221            .map_err(|error| AgentError::Llm(error.to_string()))?;
222
223        Ok(envelope.observations)
224    }
225}
226
227fn derive_metadata(
228    source: &Memory,
229    perspective: &PerspectiveKey,
230    confidence: f32,
231) -> serde_json::Value {
232    let mut cognitive = CognitiveMetadata::new(
233        CognitiveLevel::Explicit,
234        perspective.observer.clone(),
235        perspective.subject.clone(),
236        perspective.session_key.clone(),
237        DERIVE_GENERATED_BY,
238    );
239    cognitive.source_memory_ids = vec![source.id];
240    cognitive.confidence = Some(confidence.max(0.70));
241    cognitive.times_reinforced = 0;
242    cognitive.times_contradicted = 0;
243    cognitive.derived_at = Some(Utc::now());
244    cognitive.generated_by = Some(DERIVE_GENERATED_BY.to_string());
245    cognitive.merge_into(&sanitized_source_metadata(source))
246}
247
248fn fallback_observations(memory: &Memory) -> Vec<DerivedObservation> {
249    let summary = memory
250        .metadata
251        .get("agent")
252        .and_then(|agent| agent.get("summary"))
253        .and_then(serde_json::Value::as_str)
254        .map(str::trim)
255        .filter(|summary| summary.len() >= 16 && !looks_like_noise(summary))
256        .map(ToString::to_string);
257
258    let content = memory
259        .content
260        .split_whitespace()
261        .collect::<Vec<_>>()
262        .join(" ");
263
264    let candidate = summary.unwrap_or(content);
265    if candidate.is_empty() || looks_like_noise(&candidate) {
266        return Vec::new();
267    }
268
269    vec![DerivedObservation {
270        content: candidate,
271        category: memory.category.to_string(),
272        memory_lane_type: memory.memory_lane_type.as_ref().map(ToString::to_string),
273        labels: explicit_labels_from_source(memory),
274        confidence: 0.70,
275    }]
276}
277
278fn sanitized_source_metadata(source: &Memory) -> serde_json::Value {
279    let mut sanitized = serde_json::Map::new();
280
281    if let Some(agent) = source
282        .metadata
283        .get("agent")
284        .and_then(serde_json::Value::as_object)
285    {
286        let mut agent_sanitized = serde_json::Map::new();
287        for key in [
288            "summary",
289            "entities",
290            "topics",
291            "importance_score",
292            "source",
293            "generated_by",
294        ] {
295            if let Some(value) = agent.get(key) {
296                agent_sanitized.insert(key.to_string(), value.clone());
297            }
298        }
299        if !agent_sanitized.is_empty() {
300            sanitized.insert(
301                "agent".to_string(),
302                serde_json::Value::Object(agent_sanitized),
303            );
304        }
305    }
306
307    serde_json::Value::Object(sanitized)
308}
309
310fn explicit_labels_from_source(source: &Memory) -> Vec<String> {
311    let mut labels = source.labels.clone();
312    dedupe_labels(&mut labels);
313    labels
314}
315
316fn normalize_observations(observations: Vec<DerivedObservation>) -> Vec<DerivedObservation> {
317    let mut seen = HashSet::new();
318    let mut normalized = Vec::new();
319
320    for mut observation in observations {
321        observation.content = observation.content.trim().to_string();
322        if observation.content.is_empty() || observation.confidence < 0.70 {
323            continue;
324        }
325
326        let fingerprint = observation.content.to_lowercase();
327        if !seen.insert(fingerprint) {
328            continue;
329        }
330
331        dedupe_labels(&mut observation.labels);
332        normalized.push(observation);
333    }
334
335    normalized
336}
337
/// Removes repeated labels in place, comparing case-insensitively.
///
/// The first occurrence of each label is kept with its original casing, and the
/// relative order of the surviving labels is unchanged.
fn dedupe_labels(labels: &mut Vec<String>) {
    let mut lowercase_seen: HashSet<String> = HashSet::new();
    labels.retain(|label| {
        let key = label.to_lowercase();
        // `insert` returns false when the key was already present.
        lowercase_seen.insert(key)
    });
}
342
/// Reports whether `text` carries no meaningful characters.
///
/// Text counts as noise when every character is ASCII punctuation or whitespace;
/// the empty string is therefore noise as well.
fn looks_like_noise(text: &str) -> bool {
    let is_meaningful = |c: char| !c.is_ascii_punctuation() && !c.is_whitespace();
    !text.chars().any(is_meaningful)
}
347
348fn is_derivable_source(memory: &Memory) -> bool {
349    memory.category == MemoryCategory::Session
350        && memory.labels.iter().any(|l| l == RAW_ACTIVITY_LABEL)
351        && memory.metadata.get("cognitive").is_some()
352        && cognitive_level_from_metadata(&memory.metadata) == CognitiveLevel::Raw
353}
354
355async fn existing_derived_ids(
356    repo: &MemoryRepository,
357    source_id: i64,
358) -> Result<Vec<i64>, AgentError> {
359    let lineage = repo
360        .load_lineage(source_id)
361        .await
362        .map_err(|error| AgentError::Storage(error.to_string()))?;
363
364    Ok(lineage
365        .into_iter()
366        .filter(|entry| {
367            entry.source_memory_id == source_id && entry.evidence_role == DERIVED_FROM_ROLE
368        })
369        .map(|entry| entry.derived_memory_id)
370        .collect())
371}
372
373async fn enqueue_follow_up_jobs(
374    repo: &MemoryRepository,
375    source: &Memory,
376    perspective: &PerspectiveKey,
377    derived_ids: &[i64],
378    _config: &AgentConfig,
379) -> Result<(), AgentError> {
380    let perspective_json = serde_json::to_value(perspective).ok();
381
382    // Enqueue one reflect job per perspective (not per derived ID).
383    // reflect_perspective_cycle processes all candidates in the perspective group,
384    // so duplicate jobs are redundant work.
385    if !derived_ids.is_empty() {
386        let payload = json!({
387            "derived_count": derived_ids.len(),
388            "source": "derive_follow_up",
389            "reason": "new_explicit_observations",
390        });
391        repo.enqueue_job(EnqueueJobParams {
392            namespace_id: source.namespace_id,
393            job_type: REFLECT_PERSPECTIVE_JOB,
394            priority: 110,
395            perspective: perspective_json.as_ref(),
396            payload: &payload,
397        })
398        .await
399        .map_err(|error| AgentError::Storage(error.to_string()))?;
400    }
401
402    // Only enqueue session digest when session_key is present
403    if let Some(ref session_key) = perspective.session_key {
404        let payload = json!({
405            "session_key": session_key,
406            "reason": "post_derivation_rollover",
407        });
408        repo.enqueue_job(EnqueueJobParams {
409            namespace_id: source.namespace_id,
410            job_type: DIGEST_SESSION_JOB,
411            priority: 120,
412            perspective: perspective_json.as_ref(),
413            payload: &payload,
414        })
415        .await
416        .map_err(|error| AgentError::Storage(error.to_string()))?;
417    }
418
419    Ok(())
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use nexus_llm::GenerateResponse;
    use nexus_storage::repository::NamespaceRepository;
    use sqlx::sqlite::SqlitePoolOptions;

    use nexus_core::cognitive_level_from_metadata;

    /// LLM test double that replays canned results, one per `generate` call, in order.
    struct MockLlmClient {
        queued: Mutex<VecDeque<nexus_llm::Result<GenerateResponse>>>,
    }

    impl MockLlmClient {
        /// Creates a mock that returns `responses` front to back.
        fn new(responses: Vec<nexus_llm::Result<GenerateResponse>>) -> Self {
            let queued = Mutex::new(responses.into_iter().collect::<VecDeque<_>>());
            Self { queued }
        }
    }

    #[async_trait]
    impl LlmClient for MockLlmClient {
        async fn generate(&self, _params: GenerateParams) -> nexus_llm::Result<GenerateResponse> {
            // The lock guard is a temporary of this statement, so it is released
            // before the result is inspected.
            let next = self
                .queued
                .lock()
                .expect("mock responses poisoned")
                .pop_front();
            next.expect("mock response missing")
        }

        fn provider_name(&self) -> String {
            String::from("mock")
        }

        fn model_name(&self) -> String {
            String::from("mock-model")
        }
    }

    /// Creates a migrated in-memory SQLite database with one namespace.
    ///
    /// Returns the pool, a memory repository on top of it, and the namespace id.
    async fn setup_repo() -> (sqlx::SqlitePool, MemoryRepository, i64) {
        // A single connection keeps every query on the same in-memory database.
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap();
        nexus_storage::migrations::run_migrations(&pool)
            .await
            .unwrap();

        let namespaces = NamespaceRepository::new(pool.clone());
        let namespace_id = namespaces
            .get_or_create("derive-test", "derive-test")
            .await
            .unwrap()
            .id;

        let memories = MemoryRepository::new(pool.clone());
        (pool, memories, namespace_id)
    }

    /// Canned LLM reply containing one high-confidence observation.
    fn derive_response() -> GenerateResponse {
        let observation = DerivedObservation {
            content: "Explicit observation from derivation.".to_string(),
            category: "session".to_string(),
            memory_lane_type: Some("process".to_string()),
            labels: vec!["derived".to_string()],
            confidence: 0.95,
        };
        let envelope = DerivedObservationEnvelope {
            observations: vec![observation],
        };

        GenerateResponse {
            content: serde_json::to_string(&envelope).unwrap(),
            model: "mock-model".to_string(),
            usage: None,
        }
    }

    #[tokio::test]
    async fn test_derive_memory_persists_explicit_observations_and_jobs() {
        let (_pool, repo, namespace_id) = setup_repo().await;

        let llm = Arc::new(MockLlmClient::new(vec![Ok(derive_response())]));
        let service = DeriveService::new(AgentConfig::default(), llm, None);

        // A raw session memory tagged `raw-activity` with a session key, so both
        // follow-up jobs are expected.
        let raw_metadata = json!({
            "cognitive": {
                "level": "raw",
                "observer": "claude",
                "subject": "claude",
                "session_key": "sess-1",
            }
        });
        let raw_labels = ["raw-activity".to_string()];
        let raw = repo
            .store(StoreMemoryParams {
                namespace_id,
                content: "Raw implementation log.",
                category: &MemoryCategory::Session,
                memory_lane_type: None,
                labels: &raw_labels,
                metadata: &raw_metadata,
                embedding: None,
                embedding_model: None,
            })
            .await
            .unwrap();

        let derived_ids = service.derive_memory(&raw, &repo).await.unwrap();
        assert_eq!(derived_ids.len(), 1);

        let stored = repo.get_by_id(derived_ids[0]).await.unwrap().unwrap();
        assert_eq!(
            cognitive_level_from_metadata(&stored.metadata),
            CognitiveLevel::Explicit
        );

        let jobs = repo
            .list_jobs(namespace_id, None, None, 10, 0)
            .await
            .unwrap();
        let has_job = |job_type: &str| jobs.iter().any(|job| job.job_type == job_type);
        assert!(has_job(REFLECT_PERSPECTIVE_JOB));
        assert!(has_job(DIGEST_SESSION_JOB));
    }
}