use std::sync::{Arc, Mutex};
use chrono::{DateTime, Utc};
use cortex_core::{DecayJobId, EpisodeId, MemoryId, PrincipleId};
use cortex_memory::decay::{
DecayJob, DecayJobKind, DecayJobState, SummaryMethod, SUMMARY_METHOD_NONE_WIRE,
};
use cortex_store::repo::{DecayJobRecord, DecayJobRepo};
use cortex_store::Pool;
use serde_json::{json, Value};
use tracing::info;
use crate::tool_handler::{GateId, ToolError, ToolHandler};
/// Wire names accepted in the `kind` parameter of `cortex_decay_schedule`.
const VALID_KINDS: &[&str] = &[
    // `source_ids` are parsed as episode ids.
    "episode_compression",
    // `source_ids` are parsed as memory ids.
    "candidate_compression",
    // `source_ids` must hold exactly one principle id.
    "expired_principle_review",
];

/// Wire names recognised in the `summary_method` parameter.
///
/// `llm_summary` is listed so it is reported as "unsupported via MCP" rather
/// than "unknown"; only `deterministic_concatenate` is actually accepted.
const VALID_SUMMARY_METHODS: &[&str] = &[
    "deterministic_concatenate",
    "llm_summary",
];
/// MCP tool handler behind `cortex_decay_schedule`: validates a scheduling
/// request and persists it as a pending decay job.
///
/// The store pool is shared behind `Arc<Mutex<_>>`; `call` locks it to insert
/// the job record.
#[derive(Debug)]
pub struct CortexDecayScheduleTool {
    /// Shared store pool used to persist scheduled decay jobs.
    pool: Arc<Mutex<Pool>>,
}

impl CortexDecayScheduleTool {
    /// Creates the tool around an already-shared store pool.
    #[must_use]
    pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
        CortexDecayScheduleTool { pool }
    }
}
impl ToolHandler for CortexDecayScheduleTool {
    fn name(&self) -> &'static str {
        "cortex_decay_schedule"
    }

    fn gate_set(&self) -> &'static [GateId] {
        &[GateId::SessionWrite]
    }

    /// Validates `params`, builds a pending `DecayJob`, and inserts it through
    /// `DecayJobRepo`.
    ///
    /// Expected params:
    /// - `kind` (required, non-empty): one of `VALID_KINDS`.
    /// - `summary_method` (required, non-empty): one of `VALID_SUMMARY_METHODS`.
    ///   It is validated for every kind, even `expired_principle_review`, which
    ///   does not use it. Over MCP only `deterministic_concatenate` is accepted.
    /// - `source_ids` (required): JSON array of id strings. Compression kinds need
    ///   at least one id. `expired_principle_review` needs exactly one principle id.
    /// - `scheduled_for` (optional): RFC3339 timestamp. Absent or empty means "now".
    ///
    /// On success returns `{ "job_id", "kind", "state": "pending", "scheduled_for" }`.
    ///
    /// # Errors
    /// - `ToolError::InvalidParams` for any missing, unknown, or malformed parameter.
    /// - `ToolError::Internal` if the pool lock is poisoned or the insert fails.
    fn call(&self, params: Value) -> Result<Value, ToolError> {
        let kind_str = params["kind"]
            .as_str()
            .filter(|s| !s.is_empty())
            .ok_or_else(|| ToolError::InvalidParams("kind is required".into()))?;
        if !VALID_KINDS.contains(&kind_str) {
            return Err(ToolError::InvalidParams(format!(
                "kind must be one of {VALID_KINDS:?}, got `{kind_str}`"
            )));
        }

        let summary_method_str = params["summary_method"]
            .as_str()
            .filter(|s| !s.is_empty())
            .ok_or_else(|| ToolError::InvalidParams("summary_method is required".into()))?;
        if !VALID_SUMMARY_METHODS.contains(&summary_method_str) {
            return Err(ToolError::InvalidParams(format!(
                "summary_method must be one of {VALID_SUMMARY_METHODS:?}, got `{summary_method_str}`"
            )));
        }
        // `llm_summary` is a recognised method, but it is refused here and the
        // caller is pointed at the CLI, which takes an operator attestation.
        if summary_method_str == "llm_summary" {
            return Err(ToolError::InvalidParams(
                "summary_method `llm_summary` is not supported via MCP; \
                 use `cortex decay schedule --summary-method llm` via the CLI \
                 to supply an operator attestation"
                    .into(),
            ));
        }

        let source_ids_arr = params["source_ids"]
            .as_array()
            .ok_or_else(|| ToolError::InvalidParams("source_ids must be a JSON array".into()))?;

        let now = Utc::now();
        // An absent or empty `scheduled_for` defaults to the current time.
        let scheduled_for = match params["scheduled_for"].as_str() {
            None | Some("") => now,
            Some(raw) => DateTime::parse_from_rfc3339(raw)
                .map(|ts| ts.with_timezone(&Utc))
                .map_err(|err| {
                    ToolError::InvalidParams(format!(
                        "scheduled_for `{raw}` is not a valid RFC3339 timestamp: {err}"
                    ))
                })?,
        };

        // The checks above leave `deterministic_concatenate` as the only possible method.
        let summary_method = SummaryMethod::DeterministicConcatenate;
        let kind = match kind_str {
            "episode_compression" => {
                let ids = parse_source_ids::<EpisodeId>(source_ids_arr, "source_ids")?;
                DecayJobKind::EpisodeCompression {
                    source_episode_ids: ids,
                    summary_method,
                }
            }
            "candidate_compression" => {
                let ids = parse_source_ids::<MemoryId>(source_ids_arr, "source_ids")?;
                DecayJobKind::CandidateCompression {
                    source_memory_ids: ids,
                    summary_method,
                }
            }
            "expired_principle_review" => {
                if source_ids_arr.len() != 1 {
                    return Err(ToolError::InvalidParams(format!(
                        "expired_principle_review requires exactly one principle id in source_ids, \
                         got {}",
                        source_ids_arr.len()
                    )));
                }
                let raw = source_ids_arr[0].as_str().ok_or_else(|| {
                    ToolError::InvalidParams("source_ids entries must be strings".into())
                })?;
                let principle_id = raw.parse::<PrincipleId>().map_err(|err| {
                    ToolError::InvalidParams(format!(
                        "source_ids[0] `{raw}` is not a valid principle id: {err}"
                    ))
                })?;
                DecayJobKind::ExpiredPrincipleReview { principle_id }
            }
            _ => unreachable!("kind validated against VALID_KINDS above"),
        };

        let job_id = DecayJobId::new();
        let job = DecayJob {
            id: job_id,
            kind,
            state: DecayJobState::Pending,
            scheduled_for,
            created_at: now,
            created_by: "operator:mcp".to_owned(),
            updated_at: now,
        };
        let record: DecayJobRecord = job.into();

        // Scope the guard so the pool lock is released right after the insert
        // instead of being held while the response is built.
        {
            let pool = self
                .pool
                .lock()
                .map_err(|err| ToolError::Internal(format!("pool lock poisoned: {err}")))?;
            DecayJobRepo::new(&pool)
                .insert(&record)
                .map_err(|err| ToolError::Internal(format!("failed to insert decay job: {err}")))?;
        }

        // Log only after the insert has succeeded, so the log never reports a
        // job id that was not persisted.
        info!(
            job_id = %record.id,
            kind = %record.kind_wire,
            summary_method = %record.summary_method_wire,
            scheduled_for = %record.scheduled_for.to_rfc3339(),
            "cortex_decay_schedule via MCP"
        );

        Ok(json!({
            "job_id": record.id.to_string(),
            "kind": record.kind_wire,
            "state": "pending",
            "scheduled_for": record.scheduled_for.to_rfc3339(),
        }))
    }
}
/// Parses every element of `arr` as an id of type `T`, in order.
///
/// `field` is the parameter name used as the prefix in error messages (for
/// example `source_ids[2]`). Parsing stops at the first bad entry.
///
/// # Errors
/// Returns `ToolError::InvalidParams` if `arr` is empty, if an element is not
/// a JSON string, or if a string fails `T::from_str`.
fn parse_source_ids<T>(arr: &[Value], field: &str) -> Result<Vec<T>, ToolError>
where
    T: std::str::FromStr,
    T::Err: std::fmt::Display,
{
    if arr.is_empty() {
        return Err(ToolError::InvalidParams(format!(
            "{field} must contain at least one id"
        )));
    }

    let mut parsed: Vec<T> = Vec::with_capacity(arr.len());
    for (i, entry) in arr.iter().enumerate() {
        let raw = match entry.as_str() {
            Some(text) => text,
            None => {
                return Err(ToolError::InvalidParams(format!(
                    "{field}[{i}] must be a string"
                )));
            }
        };
        match raw.parse::<T>() {
            Ok(id) => parsed.push(id),
            Err(err) => {
                return Err(ToolError::InvalidParams(format!(
                    "{field}[{i}] `{raw}` is not a valid id: {err}"
                )));
            }
        }
    }
    Ok(parsed)
}
// Referencing `SUMMARY_METHOD_NONE_WIRE` here keeps its `use` import at the top
// of this file from being flagged as unused; nothing in this file reads the
// constant itself.
// NOTE(review): the name suggests the `DecayJob -> DecayJobRecord` `From` impl
// writes this wire value for kinds without a summary method (e.g.
// `expired_principle_review`). That impl is not visible here, so confirm before
// relying on it.
// The `allow(dead_code)` is kept deliberately so this never trips a
// warnings-as-errors build.
#[allow(dead_code)]
const _NONE_WIRE_USED_BY_FROM_IMPL: &str = SUMMARY_METHOD_NONE_WIRE;