Skip to main content

post_cortex_mcp/
update_context.rs

1//! MCP-side adapter for context-update writes.
2//!
3//! Phase 6 of the single-entrypoint migration: every MCP-driven write
4//! flows through [`post_cortex_memory::services::MemoryServiceImpl`] —
5//! the canonical [`PostCortexService`] implementation. This module only
6//! translates the LLM-friendly wire format (HashMap content + typed
7//! `entities` / `relations` arrays) into the canonical
8//! [`UpdateContextRequest`]; validation, persistence, and metadata
9//! shaping all happen inside the service.
10
11use anyhow::{anyhow, Result};
12use post_cortex_core::core::context_update::{
13    CodeReference, EntityData, EntityRelationship, EntityType, RelationType, UpdateContent,
14    UpdateType,
15};
16use post_cortex_core::core::timeout_utils::with_mcp_timeout;
17use post_cortex_core::services::{
18    BulkUpdateContextRequest as ServiceBulkRequest, PostCortexService,
19    UpdateContextRequest as ServiceUpdateRequest,
20};
21use std::collections::HashMap;
22use tracing::{debug, error, info, instrument, warn};
23use uuid::Uuid;
24
25use crate::{
26    get_service, ContextUpdateItem, EntityItem, MCPToolResult, RelationItem,
27};
28
29/// Parse the LLM-provided `interaction_type + content` HashMap into a
30/// typed [`UpdateContent`] using the same key-resolution conventions the
31/// old MCP path used. Returns `Err` if the `interaction_type` is unknown
32/// so callers can surface a precise error.
33fn build_content(
34    interaction_type: &str,
35    content: &HashMap<String, String>,
36) -> Result<(UpdateType, UpdateContent)> {
37    let extract_extras = |exclude_keys: &[&str]| -> Vec<String> {
38        content
39            .iter()
40            .filter(|(k, _)| !exclude_keys.contains(&k.as_str()))
41            .map(|(k, v)| format!("{}: {}", k, v))
42            .collect()
43    };
44
45    let resolve_slot = |preferred: &[&str], fallback_keys: &[&str]| -> String {
46        for k in preferred.iter().chain(fallback_keys.iter()) {
47            if let Some(v) = content.get(*k) {
48                if !v.trim().is_empty() {
49                    return v.clone();
50                }
51            }
52        }
53        String::new()
54    };
55
56    let (update_type, title, description, details, implications) = match interaction_type {
57        "qa" => {
58            let title = resolve_slot(&["question"], &["title"]);
59            let description = resolve_slot(&["answer"], &["description"]);
60            let details = extract_extras(&["question", "answer", "title", "description"]);
61            (UpdateType::QuestionAnswered, title, description, details, vec![])
62        }
63        "code_change" => {
64            let title = resolve_slot(&["file_path", "file"], &["title", "description"]);
65            let description = resolve_slot(
66                &["changes", "diff", "change_type", "change"],
67                &["description"],
68            );
69            let details = extract_extras(&[
70                "file_path",
71                "file",
72                "title",
73                "description",
74                "changes",
75                "diff",
76                "change_type",
77                "change",
78            ]);
79            (
80                UpdateType::CodeChanged,
81                title,
82                description,
83                details,
84                vec!["Code functionality updated".to_string()],
85            )
86        }
87        "problem_solved" => {
88            let title = resolve_slot(&["problem"], &["title"]);
89            let description = resolve_slot(&["solution"], &["description"]);
90            let details = extract_extras(&["problem", "solution", "title", "description"]);
91            (
92                UpdateType::ProblemSolved,
93                title,
94                description,
95                details,
96                vec!["Problem resolved".to_string()],
97            )
98        }
99        "decision_made" => {
100            let title = resolve_slot(&["decision"], &["title"]);
101            let description = resolve_slot(&["rationale"], &["description"]);
102            let details = extract_extras(&["decision", "rationale", "title", "description"]);
103            (UpdateType::DecisionMade, title, description, details, vec![])
104        }
105        "requirement_added" => {
106            let title = resolve_slot(&["requirement"], &["title"]);
107            let description = resolve_slot(&["description"], &[]);
108            let details = extract_extras(&["requirement", "priority", "title", "description"]);
109            (
110                UpdateType::RequirementAdded,
111                title,
112                description,
113                details,
114                vec![],
115            )
116        }
117        "concept_defined" => {
118            let title = resolve_slot(&["concept"], &["title"]);
119            let description = resolve_slot(&["definition"], &["description"]);
120            let details = extract_extras(&["concept", "definition", "title", "description"]);
121            (UpdateType::ConceptDefined, title, description, details, vec![])
122        }
123        other => return Err(anyhow!("Unknown interaction type: {}", other)),
124    };
125
126    Ok((
127        update_type,
128        UpdateContent {
129            title,
130            description,
131            details,
132            examples: vec![],
133            implications,
134        },
135    ))
136}
137
138/// Parse an MCP `entity_type` string (lowercase) into [`EntityType`].
139/// Unknown values default to `Concept` to match the gRPC parser.
140fn parse_entity_type(s: &str) -> EntityType {
141    match s.to_lowercase().as_str() {
142        "technology" => EntityType::Technology,
143        "concept" => EntityType::Concept,
144        "problem" => EntityType::Problem,
145        "solution" => EntityType::Solution,
146        "decision" => EntityType::Decision,
147        "code_component" | "codecomponent" => EntityType::CodeComponent,
148        _ => EntityType::Concept,
149    }
150}
151
152/// Parse an MCP `relation_type` string (lowercase) into [`RelationType`].
153/// Returns `None` for unknown values — the caller surfaces this as an
154/// `InvalidArgument` error rather than silently defaulting.
155fn parse_relation_type(s: &str) -> Option<RelationType> {
156    match s.to_lowercase().as_str() {
157        "required_by" | "requiredby" => Some(RelationType::RequiredBy),
158        "leads_to" | "leadsto" => Some(RelationType::LeadsTo),
159        "related_to" | "relatedto" => Some(RelationType::RelatedTo),
160        "conflicts_with" | "conflictswith" => Some(RelationType::ConflictsWith),
161        "depends_on" | "dependson" => Some(RelationType::DependsOn),
162        "implements" => Some(RelationType::Implements),
163        "caused_by" | "causedby" => Some(RelationType::CausedBy),
164        "solves" => Some(RelationType::Solves),
165        _ => None,
166    }
167}
168
169fn entities_to_domain(items: &[EntityItem]) -> Vec<EntityData> {
170    let now = chrono::Utc::now();
171    items
172        .iter()
173        .map(|e| EntityData {
174            name: e.name.clone(),
175            entity_type: parse_entity_type(&e.entity_type),
176            first_mentioned: now,
177            last_mentioned: now,
178            mention_count: 1,
179            importance_score: 1.0,
180            description: None,
181        })
182        .collect()
183}
184
185fn relations_to_domain(items: &[RelationItem]) -> Result<Vec<EntityRelationship>> {
186    let mut out = Vec::with_capacity(items.len());
187    for (i, r) in items.iter().enumerate() {
188        let rt = parse_relation_type(&r.relation_type).ok_or_else(|| {
189            anyhow!(
190                "relation[{i}]: unknown relation_type {:?}; valid values are: \
191                 required_by, leads_to, related_to, conflicts_with, depends_on, implements, caused_by, solves",
192                r.relation_type
193            )
194        })?;
195        out.push(EntityRelationship {
196            from_entity: r.from_entity.clone(),
197            to_entity: r.to_entity.clone(),
198            relation_type: rt,
199            context: r.context.clone(),
200        });
201    }
202    Ok(out)
203}
204
205/// Build a canonical [`ServiceUpdateRequest`] from the MCP wire payload.
206fn build_request(
207    session_id: Uuid,
208    interaction_type: &str,
209    content: &HashMap<String, String>,
210    entities: &[EntityItem],
211    relations: &[RelationItem],
212    code_reference: Option<CodeReference>,
213) -> Result<ServiceUpdateRequest> {
214    let (update_type, update_content) = build_content(interaction_type, content)?;
215    Ok(ServiceUpdateRequest {
216        session_id,
217        interaction_type: update_type,
218        content: update_content,
219        entities: entities_to_domain(entities),
220        relations: relations_to_domain(relations)?,
221        code_reference,
222    })
223}
224
225/// Record a single context update via the canonical
226/// [`PostCortexService::update_context`] path.
227#[instrument(skip(content, entities, relations), fields(
228    session_id = %session_id,
229    interaction_type = %interaction_type,
230    entities_count = entities.len(),
231    relations_count = relations.len(),
232    has_code_reference = code_reference.is_some()
233))]
234pub async fn update_conversation_context(
235    interaction_type: String,
236    content: HashMap<String, String>,
237    entities: Vec<EntityItem>,
238    relations: Vec<RelationItem>,
239    code_reference: Option<CodeReference>,
240    session_id: Uuid,
241) -> Result<MCPToolResult> {
242    info!("MCP-TOOLS: update_conversation_context() called");
243    let service = get_service().await?;
244
245    let req = match build_request(
246        session_id,
247        &interaction_type,
248        &content,
249        &entities,
250        &relations,
251        code_reference,
252    ) {
253        Ok(r) => r,
254        Err(e) => {
255            error!("update_conversation_context: bad input — {}", e);
256            return Ok(MCPToolResult::error(e.to_string()));
257        }
258    };
259
260    let result = with_mcp_timeout(async {
261        match service.update_context(req).await {
262            Ok(resp) => {
263                debug!(
264                    "update_conversation_context: persisted entry {} in session {}",
265                    resp.entry_id, resp.session_id
266                );
267                Ok(MCPToolResult::success(
268                    "Context updated successfully".to_string(),
269                    None,
270                ))
271            }
272            Err(e) => {
273                warn!("update_conversation_context: service rejected — {}", e);
274                Ok(MCPToolResult::error(e.to_string()))
275            }
276        }
277    })
278    .await;
279
280    match result {
281        Ok(r) => r,
282        Err(timeout_error) => {
283            error!(
284                "TIMEOUT: update_conversation_context — session: {}, error: {}",
285                session_id, timeout_error
286            );
287            Ok(MCPToolResult::error(format!(
288                "Operation timed out: {}",
289                timeout_error
290            )))
291        }
292    }
293}
294
295/// Record multiple context updates in a single batch via the canonical
296/// service. Items that fail translation or persistence are reported in
297/// the response payload — the rest still land, matching the legacy
298/// gRPC bulk semantics.
299pub async fn bulk_update_conversation_context(
300    updates: Vec<ContextUpdateItem>,
301    session_id: Uuid,
302) -> Result<MCPToolResult> {
303    info!(
304        "MCP-TOOLS: bulk_update_conversation_context() called with {} updates for session {}",
305        updates.len(),
306        session_id
307    );
308
309    let service = get_service().await?;
310
311    let mut requests = Vec::with_capacity(updates.len());
312    let mut error_count = 0;
313    let mut errors: Vec<String> = Vec::new();
314    for (index, item) in updates.iter().enumerate() {
315        match build_request(
316            session_id,
317            &item.interaction_type,
318            &item.content,
319            &item.entities,
320            &item.relations,
321            item.code_reference.clone(),
322        ) {
323            Ok(req) => requests.push(req),
324            Err(e) => {
325                error_count += 1;
326                errors.push(format!("Update {}: {}", index, e));
327            }
328        }
329    }
330
331    // Persist via the canonical bulk method when every translation
332    // succeeded; otherwise fall back to per-item calls so partial
333    // failure semantics are preserved.
334    let success_count = if errors.is_empty() {
335        match service
336            .bulk_update_context(ServiceBulkRequest {
337                session_id,
338                updates: requests,
339            })
340            .await
341        {
342            Ok(resp) => resp.entry_ids.len(),
343            Err(e) => {
344                errors.push(format!("Bulk persist failed: {}", e));
345                error_count += 1;
346                0
347            }
348        }
349    } else {
350        // At least one item failed translation: keep the legacy
351        // "best effort" behaviour and persist the good ones one at a
352        // time so the caller still gets partial progress.
353        let mut count = 0;
354        for (offset, req) in requests.into_iter().enumerate() {
355            match service.update_context(req).await {
356                Ok(_) => count += 1,
357                Err(e) => {
358                    error_count += 1;
359                    errors.push(format!("Update (translated index {offset}): {}", e));
360                }
361            }
362        }
363        count
364    };
365
366    let message = if error_count == 0 {
367        format!(
368            "Bulk update completed successfully: {} updates added",
369            success_count
370        )
371    } else {
372        format!(
373            "Bulk update completed with errors: {} succeeded, {} failed",
374            success_count, error_count
375        )
376    };
377
378    Ok(MCPToolResult::success(
379        message,
380        Some(serde_json::json!({
381            "success_count": success_count,
382            "error_count": error_count,
383            "errors": errors,
384        })),
385    ))
386}