Skip to main content

locus_sdk/application/
memory_transform.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use chrono::Utc;
5use locus_core_rs::domain::contracts::NodeStore;
6use locus_core_rs::domain::models::{NodeQuery, NodeUpsertStatus};
7
8use crate::application::ai_router::route_embedding;
9use crate::application::memory_filters::{build_session_filter, node_matches_common_filters};
10use crate::domain::ai::{AiProviderRegistry, AiTask, EmbedRequest, ProviderPolicy};
11use crate::domain::memory::{
12    MemoryTransformOperation, MemoryTransformRequest, MemoryTransformResult, clamp_batch_size,
13    clamp_nodes,
14};
15
16pub struct MemoryTransformService {
17    store: Arc<dyn NodeStore>,
18    providers: Arc<dyn AiProviderRegistry>,
19}
20
21impl MemoryTransformService {
22    /// Create a transform service with storage and provider registry dependencies.
23    pub fn new(store: Arc<dyn NodeStore>, providers: Arc<dyn AiProviderRegistry>) -> Self {
24        Self { store, providers }
25    }
26
27    /// Execute a bulk memory transform operation.
28    ///
29    /// The current implementation supports embedding backfill with optional
30    /// dry-run behavior, batch control, and bounded failure reporting.
31    pub async fn execute(&self, request: &MemoryTransformRequest) -> Result<MemoryTransformResult> {
32        let started_at = Utc::now();
33        let max_nodes = clamp_nodes(if request.max_nodes == 0 {
34            5000
35        } else {
36            request.max_nodes
37        });
38        let batch_size = clamp_batch_size(if request.batch_size == 0 {
39            100
40        } else {
41            request.batch_size
42        });
43
44        let single_session = request
45            .scope
46            .session_ids
47            .as_deref()
48            .filter(|sessions| sessions.len() == 1)
49            .and_then(|sessions| sessions.first().cloned());
50
51        let nodes = self
52            .store
53            .query_nodes_async(NodeQuery {
54                limit: max_nodes,
55                session_id: single_session,
56                from_utc: request.scope.from_utc,
57                to_utc: request.scope.to_utc,
58                tiers: request.scope.tiers.clone(),
59            })
60            .await?;
61
62        let session_filter = build_session_filter(&request.scope);
63
64        let mut selected = nodes
65            .into_iter()
66            .filter(|node| {
67                node_matches_common_filters(node, &request.scope, &request.filter, session_filter.as_ref())
68            })
69            .collect::<Vec<_>>();
70
71        if request.operation == MemoryTransformOperation::EmbedBackfill {
72            selected.retain(|node| node.embedding.as_ref().is_none_or(|values| values.is_empty()));
73        }
74
75        let mut result = MemoryTransformResult {
76            scanned: selected.len(),
77            selected: selected.len(),
78            started_at,
79            completed_at: started_at,
80            ..Default::default()
81        };
82
83        if request.dry_run {
84            result.updated = result.selected;
85            result.completed_at = Utc::now();
86            return Ok(result);
87        }
88
89        for chunk in selected.chunks(batch_size) {
90            for mut node in chunk.iter().cloned() {
91                let Some(embedding_input) = build_embedding_input(node.context_summary.as_deref(), &node.session_id)
92                else {
93                    result.skipped += 1;
94                    continue;
95                };
96
97                let embed_request = EmbedRequest {
98                    text: embedding_input,
99                    task: AiTask::SemanticEmbedding,
100                    provider_id: request.provider_id.clone(),
101                    model: request.model.clone(),
102                    policy: if request.provider_id.is_some() {
103                        ProviderPolicy::Required
104                    } else {
105                        ProviderPolicy::Auto
106                    },
107                };
108
109                let vector = match route_embedding(self.providers.as_ref(), &embed_request).await {
110                    Ok(values) if !values.is_empty() => values,
111                    Ok(_) => {
112                        result.failed += 1;
113                        push_failure(
114                            &mut result.failures,
115                            format!("{}: embedding provider returned empty vector", node.sync_key),
116                        );
117                        continue;
118                    }
119                    Err(err) => {
120                        result.failed += 1;
121                        push_failure(
122                            &mut result.failures,
123                            format!("{}: embedding failed: {err}", node.sync_key),
124                        );
125                        continue;
126                    }
127                };
128
129                node.embedding_dimensions = Some(vector.len());
130                node.embedding_model = request
131                    .model
132                    .clone()
133                    .or_else(|| request.provider_id.clone())
134                    .or_else(|| Some("sdk-memory-transform".to_string()));
135                node.embedding = Some(vector);
136                node.embedded_at = Some(Utc::now());
137                node.updated_at = Utc::now();
138
139                match self.store.upsert_node_async(node).await {
140                    Ok(status) => match status.status {
141                        NodeUpsertStatus::Created | NodeUpsertStatus::Updated => result.updated += 1,
142                        NodeUpsertStatus::Duplicate => result.duplicate += 1,
143                        NodeUpsertStatus::Skipped => result.skipped += 1,
144                    },
145                    Err(err) => {
146                        result.failed += 1;
147                        push_failure(&mut result.failures, format!("store upsert failed: {err}"));
148                    }
149                }
150            }
151        }
152
153        result.completed_at = Utc::now();
154        Ok(result)
155    }
156}
157
/// Compose the text sent to the embedding provider for one node.
///
/// Both inputs are whitespace-trimmed first. The summary, when non-empty, comes
/// first; a non-empty session id is appended as a `session_id:<id>` line (or
/// used alone when there is no summary). Returns `None` when both are empty
/// after trimming, meaning there is nothing to embed.
fn build_embedding_input(context_summary: Option<&str>, session_id: &str) -> Option<String> {
    // A missing summary is treated the same as a blank one.
    let summary = context_summary.map(str::trim).unwrap_or_default();
    let session = session_id.trim();

    match (summary.is_empty(), session.is_empty()) {
        (false, false) => Some(format!("{summary}\nsession_id:{session}")),
        (false, true) => Some(summary.to_string()),
        (true, false) => Some(format!("session_id:{session}")),
        (true, true) => None,
    }
}
184
/// Append a failure description unless the report already holds the maximum
/// number of entries; further reasons are silently dropped once the cap is hit.
fn push_failure(failures: &mut Vec<String>, reason: String) {
    /// Upper bound on the number of failure messages retained.
    const MAX_REPORTED_FAILURES: usize = 100;

    if failures.len() >= MAX_REPORTED_FAILURES {
        return;
    }
    failures.push(reason);
}
190
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use async_trait::async_trait;
    use chrono::Utc;
    use locus_core_rs::{InMemoryNodeStore, NodeStore};
    use locus_core_rs::domain::models::{AvecState, NodeQuery, SttpNode};

    use super::{MemoryTransformService, build_embedding_input, push_failure};
    use crate::domain::ai::{
        AiCapability, AiProvider, EmbedRequest, ScoreAvecRequest,
    };
    use crate::domain::memory::{MemoryTransformOperation, MemoryTransformRequest};
    use crate::infrastructure::registry::InMemoryAiProviderRegistry;

    /// Provider stub that returns a fixed three-element vector for every embed call.
    struct MockEmbeddingProvider;

    #[async_trait]
    impl AiProvider for MockEmbeddingProvider {
        fn provider_id(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &'static [AiCapability] {
            &[AiCapability::SemanticEmbedding]
        }

        async fn embed_semantic(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            Ok(vec![0.2, 0.3, 0.4])
        }

        async fn embed_avec(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            Ok(vec![0.2, 0.3, 0.4])
        }

        async fn score_avec(&self, _request: &ScoreAvecRequest) -> Result<AvecState> {
            Ok(AvecState::zero())
        }
    }

    /// Build a service backed by the given store and a registry holding only the mock provider.
    fn service_with_mock_provider(store: Arc<dyn NodeStore>) -> MemoryTransformService {
        let mut providers = InMemoryAiProviderRegistry::new();
        providers.register(MockEmbeddingProvider);
        MemoryTransformService::new(store, Arc::new(providers))
    }

    #[tokio::test]
    async fn dry_run_reports_selected_without_writes() {
        let store: Arc<dyn NodeStore> = Arc::new(InMemoryNodeStore::new());
        let node = test_node("dry-run", None);
        store
            .upsert_node_async(node)
            .await
            .expect("upsert should succeed");

        let service = service_with_mock_provider(store.clone());

        let request = MemoryTransformRequest {
            operation: MemoryTransformOperation::EmbedBackfill,
            dry_run: true,
            max_nodes: 100,
            batch_size: 10,
            ..Default::default()
        };

        let result = service.execute(&request).await.expect("transform should succeed");

        assert_eq!(result.selected, 1);
        assert_eq!(result.updated, 1);
        assert_eq!(result.failed, 0);

        // The "without writes" half of the contract: the stored node must still
        // have no embedding after a dry run.
        let nodes = store
            .query_nodes_async(NodeQuery {
                limit: 10,
                session_id: Some("dry-run".to_string()),
                ..Default::default()
            })
            .await
            .expect("query should succeed");

        assert_eq!(nodes.len(), 1);
        assert!(nodes[0].embedding.as_ref().is_none_or(|v| v.is_empty()));
    }

    #[tokio::test]
    async fn embed_backfill_updates_missing_embedding_nodes() {
        let store: Arc<dyn NodeStore> = Arc::new(InMemoryNodeStore::new());
        let node = test_node("backfill", None);
        store
            .upsert_node_async(node)
            .await
            .expect("upsert should succeed");

        let service = service_with_mock_provider(store.clone());

        let request = MemoryTransformRequest {
            operation: MemoryTransformOperation::EmbedBackfill,
            dry_run: false,
            max_nodes: 100,
            batch_size: 10,
            ..Default::default()
        };

        let result = service.execute(&request).await.expect("transform should succeed");

        assert_eq!(result.updated, 1);
        assert_eq!(result.failed, 0);

        let nodes = store
            .query_nodes_async(NodeQuery {
                limit: 10,
                session_id: Some("backfill".to_string()),
                ..Default::default()
            })
            .await
            .expect("query should succeed");

        assert_eq!(nodes.len(), 1);
        assert!(nodes[0].embedding.as_ref().is_some_and(|v| !v.is_empty()));
    }

    #[tokio::test]
    async fn embed_backfill_ignores_nodes_with_existing_embedding() {
        let store: Arc<dyn NodeStore> = Arc::new(InMemoryNodeStore::new());
        let node = test_node("already-embedded", Some(vec![0.9, 0.8]));
        store
            .upsert_node_async(node)
            .await
            .expect("upsert should succeed");

        let service = service_with_mock_provider(store);

        let request = MemoryTransformRequest {
            operation: MemoryTransformOperation::EmbedBackfill,
            dry_run: false,
            max_nodes: 100,
            batch_size: 10,
            ..Default::default()
        };

        let result = service.execute(&request).await.expect("transform should succeed");

        assert_eq!(result.selected, 0);
        assert_eq!(result.updated, 0);
        assert_eq!(result.failed, 0);
    }

    #[test]
    fn embedding_input_trims_and_combines_parts() {
        assert_eq!(
            build_embedding_input(Some("  summary  "), " s1 "),
            Some("summary\nsession_id:s1".to_string())
        );
        assert_eq!(build_embedding_input(Some("summary"), "  "), Some("summary".to_string()));
        assert_eq!(build_embedding_input(None, "s1"), Some("session_id:s1".to_string()));
        assert_eq!(build_embedding_input(Some("   "), ""), None);
    }

    #[test]
    fn push_failure_caps_reported_failures() {
        let mut failures = Vec::new();
        for i in 0..150 {
            push_failure(&mut failures, i.to_string());
        }

        assert_eq!(failures.len(), 100);
        assert_eq!(failures[99], "99");
    }

    /// Build a minimal node for the given session, optionally pre-populated with an embedding.
    fn test_node(session_id: &str, embedding: Option<Vec<f32>>) -> SttpNode {
        let now = Utc::now();
        let user = AvecState {
            stability: 0.6,
            friction: 0.4,
            logic: 0.8,
            autonomy: 0.7,
        };
        let model = AvecState {
            stability: 0.5,
            friction: 0.3,
            logic: 0.9,
            autonomy: 0.6,
        };

        SttpNode {
            raw: format!("raw:{session_id}"),
            session_id: session_id.to_string(),
            tier: "raw".to_string(),
            timestamp: now,
            compression_depth: 1,
            parent_node_id: None,
            sync_key: format!("{}:{}", session_id, now.timestamp_nanos_opt().unwrap_or_default()),
            updated_at: now,
            source_metadata: None,
            context_summary: Some("summary".to_string()),
            embedding_dimensions: embedding.as_ref().map(|v| v.len()),
            embedding_model: embedding.as_ref().map(|_| "existing".to_string()),
            embedding,
            embedded_at: None,
            user_avec: user,
            model_avec: model,
            compression_avec: Some(model),
            rho: 0.9,
            kappa: 0.8,
            psi: 2.5,
        }
    }
}