Skip to main content

ainl_agent_snapshot/
lib.rs

1//! Bounded agent graph snapshot + deterministic plan types for the planner protocol.
2//!
3//! Shared between ArmaraOS and `ainl-inference-server` via path or published crate dependency.
4
5mod builder;
6
7pub use builder::{build_snapshot, SnapshotError};
8
9use ainl_contracts::{ProcedureArtifact, ProcedureExecutionPlan, ProcedureStepKind};
10use ainl_memory::AinlMemoryNode;
11use serde::{Deserialize, Serialize};
12
/// Wire-format version written into [`AgentSnapshot::snapshot_version`].
///
/// The server rejects snapshots whose version it does not recognise.
pub const SNAPSHOT_SCHEMA_VERSION: u32 = 1;

/// Default cap on total plan wall-clock time, in milliseconds (one minute).
pub const DEFAULT_MAX_WALL_MS: u64 = 60 * 1_000;
/// Default ceiling on `LocalPatch` replans within a single plan execution.
pub const DEFAULT_MAX_REPLAN_CALLS: u32 = 3;
20
/// Bounded snapshot of one agent's memory graph, exchanged over the planner protocol.
///
/// On deserialization only `agent_id` and `snapshot_version` are required; every other
/// field falls back to an empty list or to [`PolicyCaps::default`] when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentSnapshot {
    /// Identifier of the agent this snapshot describes.
    pub agent_id: String,
    /// Payload schema version; compared against [`SNAPSHOT_SCHEMA_VERSION`] by the server.
    pub snapshot_version: u32,
    /// Persona memory nodes included in the snapshot.
    #[serde(default)]
    pub persona: Vec<AinlMemoryNode>,
    /// Episodic memory nodes included in the snapshot.
    #[serde(default)]
    pub episodic: Vec<AinlMemoryNode>,
    /// Semantic memory nodes included in the snapshot.
    #[serde(default)]
    pub semantic: Vec<AinlMemoryNode>,
    /// Procedural memory nodes included in the snapshot.
    #[serde(default)]
    pub procedural: Vec<AinlMemoryNode>,
    /// Tool names the agent may use (presumably enforced by the planner — confirm server-side).
    #[serde(default)]
    pub tool_allowlist: Vec<String>,
    /// Execution limits attached to this snapshot.
    #[serde(default)]
    pub policy_caps: PolicyCaps,
}
38
/// Execution limits carried inside an [`AgentSnapshot`].
///
/// Every field is optional on the wire: omitted numeric caps are filled by the serde
/// `default_*` functions (the same ones [`PolicyCaps::default`] uses) and an omitted
/// `deny_tools` becomes an empty list.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PolicyCaps {
    /// Step cap for a plan (default 32).
    #[serde(default = "default_max_steps")]
    pub max_steps: u32,
    /// Depth cap for a plan (default 8); exact meaning of "depth" is defined by the consumer.
    #[serde(default = "default_max_depth")]
    pub max_depth: u32,
    /// Total plan wall-clock cap in milliseconds (default [`DEFAULT_MAX_WALL_MS`]).
    #[serde(default = "default_max_wall_ms")]
    pub max_wall_ms: u64,
    /// Max `LocalPatch` replans per plan execution (default [`DEFAULT_MAX_REPLAN_CALLS`]).
    #[serde(default = "default_max_replan_calls")]
    pub max_replan_calls: u32,
    /// Tool names that must not be used (presumably checked in addition to the allowlist).
    #[serde(default)]
    pub deny_tools: Vec<String>,
}
52
53fn default_max_steps() -> u32 {
54    32
55}
56fn default_max_depth() -> u32 {
57    8
58}
59fn default_max_wall_ms() -> u64 {
60    DEFAULT_MAX_WALL_MS
61}
62fn default_max_replan_calls() -> u32 {
63    DEFAULT_MAX_REPLAN_CALLS
64}
65
66impl Default for PolicyCaps {
67    fn default() -> Self {
68        Self {
69            max_steps: default_max_steps(),
70            max_depth: default_max_depth(),
71            max_wall_ms: default_max_wall_ms(),
72            max_replan_calls: default_max_replan_calls(),
73            deny_tools: Vec::new(),
74        }
75    }
76}
77
/// Details of a failed plan step, presumably sent along with a repair / replan request.
///
/// All fields are required on the wire (no serde defaults).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RepairContext {
    /// `id` of the [`PlanStep`] that failed.
    pub failed_step_id: String,
    /// `tool` of the failed step.
    pub failed_step_tool: String,
    /// Failure message (e.g. from [`PlanStepError::to_message`] — confirm against the caller).
    pub error_msg: String,
    /// Arbitrary JSON; presumably the outputs of steps completed before the failure.
    pub prior_outputs: serde_json::Value,
}
85
/// Planner output: tool steps to execute plus memory-graph writes.
///
/// Every field is optional on the wire; an omitted `confidence` deserializes as `0.0`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct DeterministicPlan {
    /// Steps of the plan; ordering constraints are expressed via [`PlanStep::depends_on`].
    #[serde(default)]
    pub steps: Vec<PlanStep>,
    /// Graph writes, applied via [`apply_graph_writes`].
    #[serde(default)]
    pub graph_writes: Vec<GraphWrite>,
    /// Planner confidence; [`deterministic_plan_from_procedure`] keeps it within `[0.0, 1.0]`.
    #[serde(default)]
    pub confidence: f32,
    /// Presumably ids of steps at which model reasoning is still required — confirm with server.
    #[serde(default)]
    pub reasoning_required_at: Vec<String>,
}
97
98/// Project a portable procedure into the deterministic planner protocol.
99#[must_use]
100pub fn deterministic_plan_from_procedure(artifact: &ProcedureArtifact) -> DeterministicPlan {
101    let execution = procedure_execution_plan_from_artifact(artifact);
102    let steps = execution
103        .steps
104        .iter()
105        .map(|step| PlanStep {
106            id: step.step_id.clone(),
107            tool: if step.executor == "tool" {
108                step.operation.clone()
109            } else {
110                format!("procedure_{}", step.executor)
111            },
112            args: if step.args_schema.is_null() {
113                serde_json::json!({
114                    "procedure_id": execution.procedure_id,
115                    "operation": step.operation,
116                    "title": step.title,
117                })
118            } else {
119                serde_json::json!({
120                    "procedure_id": execution.procedure_id,
121                    "operation": step.operation,
122                    "args_schema": step.args_schema,
123                })
124            },
125            depends_on: step.depends_on.clone(),
126            on_error: OnErrorPolicy::LocalPatch,
127            idempotency_key: Some(format!("{}:{}", execution.procedure_id, step.step_id)),
128            optional: false,
129            expected_output_schema: None,
130        })
131        .collect::<Vec<_>>();
132    DeterministicPlan {
133        steps,
134        graph_writes: vec![GraphWrite {
135            node_type: "semantic".into(),
136            label: format!("procedure_used:{}", artifact.id),
137            payload: serde_json::json!({
138                "fact": format!("Procedure {} was projected into a deterministic plan", artifact.id),
139                "procedure_id": artifact.id,
140            }),
141            fitness_delta: None,
142        }],
143        confidence: artifact.fitness.clamp(0.0, 1.0),
144        reasoning_required_at: Vec::new(),
145    }
146}
147
148#[must_use]
149pub fn procedure_execution_plan_from_artifact(
150    artifact: &ProcedureArtifact,
151) -> ProcedureExecutionPlan {
152    let mut prior_step: Option<String> = None;
153    let steps = artifact
154        .steps
155        .iter()
156        .map(|step| {
157            let (executor, operation, args_schema) = match &step.kind {
158                ProcedureStepKind::ToolCall { tool, args_schema } => {
159                    ("tool".to_string(), tool.clone(), args_schema.clone())
160                }
161                ProcedureStepKind::AdapterCall { adapter, op } => (
162                    "adapter".to_string(),
163                    format!("{adapter}.{op}"),
164                    serde_json::Value::Null,
165                ),
166                ProcedureStepKind::Validate { target } => (
167                    "validate".to_string(),
168                    target.clone(),
169                    serde_json::Value::Null,
170                ),
171                ProcedureStepKind::Branch { condition } => (
172                    "branch".to_string(),
173                    condition.clone(),
174                    serde_json::Value::Null,
175                ),
176                ProcedureStepKind::HumanReview { reason } => (
177                    "human_review".to_string(),
178                    reason.clone(),
179                    serde_json::Value::Null,
180                ),
181                ProcedureStepKind::Instruction { text } => (
182                    "instruction".to_string(),
183                    text.clone(),
184                    serde_json::Value::Null,
185                ),
186            };
187            let depends_on = prior_step.iter().cloned().collect::<Vec<_>>();
188            prior_step = Some(step.step_id.clone());
189            ainl_contracts::ProcedureExecutionStep {
190                step_id: step.step_id.clone(),
191                title: step.title.clone(),
192                executor,
193                operation,
194                args_schema,
195                depends_on,
196                on_error: "local_patch".into(),
197            }
198        })
199        .collect();
200    ProcedureExecutionPlan {
201        procedure_id: artifact.id.clone(),
202        schema_version: ainl_contracts::LEARNER_SCHEMA_VERSION,
203        steps,
204        verification: artifact.verification.clone(),
205    }
206}
207
/// One executable step of a [`DeterministicPlan`].
///
/// Only `id` and `tool` are required on the wire; every other field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlanStep {
    /// Step identifier, referenced by other steps' `depends_on`.
    pub id: String,
    /// Tool to invoke (for procedure projections, possibly a `procedure_<executor>` pseudo-tool).
    pub tool: String,
    /// Tool arguments as JSON; `null` when omitted.
    #[serde(default)]
    pub args: serde_json::Value,
    /// Ids of steps this step depends on.
    #[serde(default)]
    pub depends_on: Vec<String>,
    /// Failure handling for this step; defaults to [`OnErrorPolicy::Abort`].
    #[serde(default)]
    pub on_error: OnErrorPolicy,
    /// Optional key for de-duplicating repeated executions of this step.
    #[serde(default)]
    pub idempotency_key: Option<String>,
    /// Defaults to `false`; presumably marks steps whose failure need not fail the plan.
    #[serde(default)]
    pub optional: bool,
    /// Optional JSON schema the step's output is expected to satisfy.
    #[serde(default)]
    pub expected_output_schema: Option<serde_json::Value>,
}
225
/// What to do when a [`PlanStep`] fails.
///
/// Serialized in snake_case: `"retry_once"`, `"local_patch"`, `"abort"`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OnErrorPolicy {
    /// Presumably re-run the failed step a single time.
    RetryOnce,
    /// Ask for a local replan; the count is bounded by [`PolicyCaps::max_replan_calls`].
    LocalPatch,
    /// Default policy; presumably stops plan execution.
    #[default]
    Abort,
}
234
/// A memory-graph write requested by the planner, materialized by [`apply_graph_writes`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GraphWrite {
    /// Target node kind; `apply_graph_writes` accepts `semantic`, `persona` or `procedural`
    /// (case-insensitive).
    pub node_type: String,
    /// Label; used as the fallback for the node's primary text field when the payload lacks it.
    pub label: String,
    /// Kind-specific fields as a JSON object; `null` when omitted.
    #[serde(default)]
    pub payload: serde_json::Value,
    /// Fitness adjustment; `apply_graph_writes` applies it to procedural nodes only.
    #[serde(default)]
    pub fitness_delta: Option<f32>,
}
244
/// Typed tool-step failure for escalation without string parsing.
///
/// Serializable so it can cross the planner protocol; `Display` comes from the
/// `#[error]` attributes.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, thiserror::Error)]
pub enum PlanStepError {
    /// The requested tool does not exist; carries the tool name.
    #[error("tool not found: {0}")]
    ToolNotFound(String),
    /// A policy rule prevented the call.
    #[error("policy blocked: {reason}")]
    PolicyBlocked { reason: String },
    /// A failure classified as transient (presumably worth retrying).
    #[error("transient: {0}")]
    Transient(String),
    /// A failure classified as deterministic (presumably not worth retrying unchanged).
    #[error("deterministic: {0}")]
    Deterministic(String),
    /// The step exceeded its time budget.
    #[error("timeout")]
    Timeout,
}
259
260impl PlanStepError {
261    pub fn to_message(&self) -> String {
262        self.to_string()
263    }
264}
265
/// Default lookup window, in seconds, for the non-episodic snapshot kinds (semantic,
/// procedural, persona): 30 days.
///
/// Operators can override it via `[runtime_limits] snapshot_non_episodic_window_secs`.
pub const DEFAULT_NON_EPISODIC_WINDOW_SECS: i64 = 30 * 24 * 60 * 60;
269
270#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
271pub struct SnapshotPolicy {
272    pub episodic_window_secs: i64,
273    pub episodic_max: usize,
274    pub semantic_top_n: usize,
275    pub procedural_top_n: usize,
276    pub persona_top_n: usize,
277    /// Lookup window (seconds) for semantic, procedural, and persona node types.
278    /// Defaults to [`DEFAULT_NON_EPISODIC_WINDOW_SECS`] (30 days).
279    pub non_episodic_window_secs: i64,
280}
281
282impl Default for SnapshotPolicy {
283    fn default() -> Self {
284        Self {
285            episodic_window_secs: 1800,
286            episodic_max: 10,
287            semantic_top_n: 20,
288            procedural_top_n: 10,
289            persona_top_n: 5,
290            non_episodic_window_secs: DEFAULT_NON_EPISODIC_WINDOW_SECS,
291        }
292    }
293}
294
/// Reasons [`apply_graph_writes`] can reject a batch of [`GraphWrite`]s.
#[derive(Debug, thiserror::Error)]
pub enum GraphWriteError {
    /// `node_type` is not one of the supported kinds; carries the lowercased value.
    #[error("invalid node_type for graph write: {0}")]
    InvalidNodeType(String),
    /// `node_type` names a kind owned by the executor (`episode`, `episodic`, `patch`).
    #[error("episodic and patch writes are not allowed via apply_graph_writes")]
    DisallowedKind,
    /// Node construction failure. Not returned by `apply_graph_writes` in this file as written.
    #[error("failed to build node: {0}")]
    Build(String),
}
304
305/// Map planner graph writes to concrete memory nodes for `GraphMemory::write_node`.
306///
307/// Rejects `episode` / `episodic` / `patch` — those paths are owned by the executor / dispatch_patches.
308pub fn apply_graph_writes(
309    writes: &[GraphWrite],
310    agent_id: &str,
311    now_ms: i64,
312) -> Result<Vec<AinlMemoryNode>, GraphWriteError> {
313    use ainl_memory::AinlMemoryNode;
314    use uuid::Uuid;
315
316    let mut out = Vec::with_capacity(writes.len());
317    for w in writes {
318        let nt = w.node_type.to_lowercase();
319        match nt.as_str() {
320            "episode" | "episodic" | "patch" => return Err(GraphWriteError::DisallowedKind),
321            "semantic" => {
322                let fact = w
323                    .payload
324                    .get("fact")
325                    .and_then(|v| v.as_str())
326                    .unwrap_or(&w.label)
327                    .to_string();
328                let confidence = w
329                    .payload
330                    .get("confidence")
331                    .and_then(|v| v.as_f64())
332                    .map(|f| f as f32)
333                    .unwrap_or(0.8);
334                let source_turn_id = w
335                    .payload
336                    .get("source_turn_id")
337                    .and_then(|v| v.as_str())
338                    .and_then(|s| Uuid::parse_str(s).ok())
339                    .unwrap_or_else(Uuid::new_v4);
340                let mut node = AinlMemoryNode::new_fact(fact, confidence, source_turn_id);
341                node.id = Uuid::new_v4();
342                node.agent_id = agent_id.to_string();
343                out.push(node);
344            }
345            "persona" => {
346                let trait_name = w
347                    .payload
348                    .get("trait_name")
349                    .and_then(|v| v.as_str())
350                    .unwrap_or(&w.label)
351                    .to_string();
352                let strength = w
353                    .payload
354                    .get("strength")
355                    .and_then(|v| v.as_f64())
356                    .map(|f| f as f32)
357                    .unwrap_or(0.7);
358                let learned_from = w
359                    .payload
360                    .get("learned_from")
361                    .and_then(|v| v.as_array())
362                    .map(|arr| {
363                        arr.iter()
364                            .filter_map(|x| x.as_str().and_then(|s| Uuid::parse_str(s).ok()))
365                            .collect()
366                    })
367                    .unwrap_or_default();
368                let mut node = AinlMemoryNode::new_persona(trait_name, strength, learned_from);
369                node.id = Uuid::new_v4();
370                node.agent_id = agent_id.to_string();
371                out.push(node);
372            }
373            "procedural" => {
374                let pattern_name = w
375                    .payload
376                    .get("pattern_name")
377                    .and_then(|v| v.as_str())
378                    .unwrap_or(&w.label)
379                    .to_string();
380                let tool_sequence: Vec<String> = w
381                    .payload
382                    .get("tool_sequence")
383                    .and_then(|v| v.as_array())
384                    .map(|arr| {
385                        arr.iter()
386                            .filter_map(|x| x.as_str().map(String::from))
387                            .collect()
388                    })
389                    .unwrap_or_default();
390                let confidence = w
391                    .payload
392                    .get("confidence")
393                    .and_then(|v| v.as_f64())
394                    .map(|f| f as f32)
395                    .unwrap_or(0.75);
396                let mut node = if tool_sequence.is_empty() {
397                    let compiled = w
398                        .payload
399                        .get("compiled_graph")
400                        .and_then(|v| v.as_array())
401                        .map(|a| {
402                            a.iter()
403                                .filter_map(|x| x.as_u64().map(|u| u as u8))
404                                .collect()
405                        })
406                        .unwrap_or_default();
407                    AinlMemoryNode::new_pattern(pattern_name, compiled)
408                } else {
409                    AinlMemoryNode::new_procedural_tools(pattern_name, tool_sequence, confidence)
410                };
411                node.id = Uuid::new_v4();
412                node.agent_id = agent_id.to_string();
413                if let Some(fd) = w.fitness_delta {
414                    if let ainl_memory::AinlNodeType::Procedural { ref mut procedural } =
415                        node.node_type
416                    {
417                        procedural.fitness = Some(procedural.fitness.unwrap_or(0.5) + fd);
418                    }
419                }
420                let _ = now_ms;
421                out.push(node);
422            }
423            other => return Err(GraphWriteError::InvalidNodeType(other.to_string())),
424        }
425    }
426    Ok(out)
427}
428
/// JSON discriminator for structured planner output (`InferOutput.structured`); presumably
/// tags a payload holding a serialized [`DeterministicPlan`].
pub const STRUCTURED_KIND_DETERMINISTIC_PLAN: &str = "deterministic_plan";
/// Structured response when server-side plan validation fails after repair attempt.
pub const STRUCTURED_KIND_PLANNER_INVALID_PLAN: &str = "planner_invalid_plan";
433
#[cfg(test)]
mod procedure_tests {
    use super::*;
    use ainl_contracts::{
        ProcedureArtifact, ProcedureArtifactFormat, ProcedureLifecycle, ProcedureStep,
        ProcedureStepKind, ProcedureVerification, LEARNER_SCHEMA_VERSION,
    };

    /// Promoted procedure `proc:test` with a single `file_read` tool step (`s1`).
    fn artifact() -> ProcedureArtifact {
        ProcedureArtifact {
            schema_version: LEARNER_SCHEMA_VERSION,
            id: "proc:test".into(),
            title: "Test Procedure".into(),
            intent: "test".into(),
            summary: "summary".into(),
            required_tools: vec!["file_read".into()],
            required_adapters: vec![],
            inputs: vec![],
            outputs: vec![],
            preconditions: vec![],
            steps: vec![ProcedureStep {
                step_id: "s1".into(),
                title: "Read".into(),
                kind: ProcedureStepKind::ToolCall {
                    tool: "file_read".into(),
                    args_schema: serde_json::json!({"type":"object"}),
                },
                rationale: None,
            }],
            verification: ProcedureVerification::default(),
            known_failures: vec![],
            recovery: vec![],
            source_trajectory_ids: vec![],
            source_failure_ids: vec![],
            fitness: 0.9,
            observation_count: 3,
            lifecycle: ProcedureLifecycle::Promoted,
            render_targets: vec![ProcedureArtifactFormat::PromptOnly],
        }
    }

    /// [`artifact`] with an instruction step (`s2`) appended after the tool step.
    fn two_step_artifact() -> ProcedureArtifact {
        let mut a = artifact();
        a.steps.push(ProcedureStep {
            step_id: "s2".into(),
            title: "Explain".into(),
            kind: ProcedureStepKind::Instruction {
                text: "summarize the file".into(),
            },
            rationale: None,
        });
        a
    }

    /// A graph write of the given kind with an empty payload.
    fn write(node_type: &str) -> GraphWrite {
        GraphWrite {
            node_type: node_type.into(),
            label: "label".into(),
            payload: serde_json::Value::Null,
            fitness_delta: None,
        }
    }

    #[test]
    fn projects_procedure_to_deterministic_plan() {
        let plan = deterministic_plan_from_procedure(&artifact());
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].tool, "file_read");
        assert_eq!(plan.steps[0].on_error, OnErrorPolicy::LocalPatch);
        assert_eq!(plan.confidence, 0.9);
        assert_eq!(
            plan.steps[0].args,
            serde_json::json!({
                "procedure_id": "proc:test",
                "operation": "file_read",
                "args_schema": {"type": "object"},
            })
        );
        assert_eq!(plan.steps[0].idempotency_key.as_deref(), Some("proc:test:s1"));
        assert!(plan.steps[0].depends_on.is_empty());
        assert_eq!(plan.graph_writes.len(), 1);
        assert_eq!(plan.graph_writes[0].node_type, "semantic");
        assert_eq!(plan.graph_writes[0].label, "procedure_used:proc:test");
    }

    #[test]
    fn chains_steps_and_routes_non_tool_executors() {
        let plan = deterministic_plan_from_procedure(&two_step_artifact());
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[1].depends_on, vec!["s1".to_string()]);
        assert_eq!(plan.steps[1].tool, "procedure_instruction");
        assert_eq!(
            plan.steps[1].args,
            serde_json::json!({
                "procedure_id": "proc:test",
                "operation": "summarize the file",
                "title": "Explain",
            })
        );
        assert_eq!(plan.steps[1].idempotency_key.as_deref(), Some("proc:test:s2"));
    }

    #[test]
    fn projected_graph_writes_are_applicable() {
        let plan = deterministic_plan_from_procedure(&artifact());
        let nodes = apply_graph_writes(&plan.graph_writes, "agent-1", 0)
            .expect("semantic graph writes are allowed");
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].agent_id, "agent-1");
    }

    #[test]
    fn apply_graph_writes_rejects_executor_owned_and_unknown_kinds() {
        for kind in ["Episodic", "episode", "PATCH"] {
            assert!(
                matches!(
                    apply_graph_writes(&[write(kind)], "agent-1", 0),
                    Err(GraphWriteError::DisallowedKind)
                ),
                "{kind} should be rejected as executor-owned"
            );
        }
        assert!(matches!(
            apply_graph_writes(&[write("Bogus")], "agent-1", 0),
            Err(GraphWriteError::InvalidNodeType(ref t)) if t == "bogus"
        ));
    }
}