Skip to main content

ainl_procedure_learning/
lib.rs

1//! Procedure learning for AINL hosts.
2//!
3//! This crate is the reusable “experience → procedure artifact → reuse → patch” core. Hosts such
4//! as ArmaraOS provide storage, validation, and execution; this crate provides deterministic,
5//! serializable learning decisions without depending on `openfang-*`.
6
7use std::collections::{BTreeSet, HashSet};
8
9use ainl_contracts::{
10    ExperienceBundle, ProcedureArtifact, ProcedureArtifactFormat, ProcedureExecutionPlan,
11    ProcedureExecutionStep, ProcedureLifecycle, ProcedurePatch, ProcedureStep, ProcedureStepKind,
12    ProcedureVerification, TrajectoryOutcome, LEARNER_SCHEMA_VERSION,
13};
14use sha2::{Digest, Sha256};
15
/// String identifiers for the kinds of learning proposals this crate can emit.
pub mod proposal_kind {
    //! Each constant is the wire value of one proposal kind.

    /// Mint a brand-new procedure from distilled experience.
    pub const PROCEDURE_MINT: &str = "procedure_mint";
    /// Patch an existing procedure (e.g. after a failed reuse).
    pub const PROCEDURE_PATCH: &str = "procedure_patch";
    /// Promote a procedure to a more trusted lifecycle stage.
    pub const PROCEDURE_PROMOTE: &str = "procedure_promote";
    /// Deprecate a procedure.
    pub const PROCEDURE_DEPRECATE: &str = "procedure_deprecate";
    /// Patch a graph based on a learned procedure.
    pub const GRAPH_PATCH_FROM_PROCEDURE: &str = "graph_patch_from_procedure";
}
23
/// Thresholds an experience bundle must meet before `distill_procedure` mints a procedure.
#[derive(Debug, Clone)]
pub struct DistillPolicy {
    /// Bundles observed fewer times than this are rejected.
    pub min_observations: u32,
    /// Bundles whose fitness is below this value are rejected.
    pub min_fitness: f32,
    /// When `true`, only bundles with a `TrajectoryOutcome::Success` outcome are distilled.
    pub require_success: bool,
}

impl Default for DistillPolicy {
    /// Defaults: at least 3 observations, fitness of at least 0.70, and a successful outcome.
    fn default() -> Self {
        let min_observations = 3;
        let min_fitness = 0.70;
        let require_success = true;
        Self {
            min_observations,
            min_fitness,
            require_success,
        }
    }
}
40
/// Result of `score_reuse`: how suitable one procedure looks for a new task.
#[derive(Debug, Clone, PartialEq)]
pub struct ReuseScore {
    /// Id of the scored procedure (copied from the artifact's `id`).
    pub procedure_id: String,
    /// Combined suitability score; `score_reuse` clamps it to `0.0..=1.0`.
    pub score: f32,
    /// Short machine-readable labels explaining the score (e.g. `all_required_tools_available`).
    pub reasons: Vec<String>,
}
47
/// Reasons `distill_procedure` refuses to turn an experience bundle into a procedure.
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum ProcedureLearningError {
    /// The bundle was observed fewer times than `DistillPolicy::min_observations` requires.
    #[error("insufficient observations: {observed} < {required}")]
    InsufficientObservations { observed: u32, required: u32 },
    /// The bundle's fitness did not reach `DistillPolicy::min_fitness`.
    #[error("fitness below threshold: {fitness:.3} < {required:.3}")]
    FitnessBelowThreshold { fitness: f32, required: f32 },
    /// `DistillPolicy::require_success` is set but the bundle's outcome is not `Success`.
    #[error("experience outcome is not successful")]
    NonSuccessfulOutcome,
    /// The bundle has no events to learn steps from.
    #[error("experience has no events")]
    EmptyExperience,
}
59
60#[must_use]
61pub fn sha256_hex_lower(s: &str) -> String {
62    hex::encode(Sha256::digest(s.as_bytes()))
63}
64
65#[must_use]
66pub fn procedure_fingerprint(bundle: &ExperienceBundle) -> String {
67    let mut canonical = String::new();
68    canonical.push_str(bundle.intent.trim());
69    canonical.push('\n');
70    for event in &bundle.events {
71        canonical.push_str(&event.tool_or_adapter);
72        canonical.push(':');
73        canonical.push_str(&event.operation);
74        canonical.push(':');
75        canonical.push_str(if event.success { "ok" } else { "err" });
76        canonical.push('\n');
77    }
78    sha256_hex_lower(&canonical)
79}
80
81pub fn distill_procedure(
82    bundle: &ExperienceBundle,
83    policy: &DistillPolicy,
84) -> Result<ProcedureArtifact, ProcedureLearningError> {
85    if bundle.events.is_empty() {
86        return Err(ProcedureLearningError::EmptyExperience);
87    }
88    if bundle.observation_count < policy.min_observations {
89        return Err(ProcedureLearningError::InsufficientObservations {
90            observed: bundle.observation_count,
91            required: policy.min_observations,
92        });
93    }
94    if bundle.fitness < policy.min_fitness {
95        return Err(ProcedureLearningError::FitnessBelowThreshold {
96            fitness: bundle.fitness,
97            required: policy.min_fitness,
98        });
99    }
100    if policy.require_success && bundle.outcome != TrajectoryOutcome::Success {
101        return Err(ProcedureLearningError::NonSuccessfulOutcome);
102    }
103
104    let fingerprint = procedure_fingerprint(bundle);
105    let mut required_tools = BTreeSet::new();
106    let mut required_adapters = BTreeSet::new();
107    let mut known_failures = Vec::new();
108    let steps = bundle
109        .events
110        .iter()
111        .enumerate()
112        .map(|(idx, event)| {
113            if event.success {
114                required_tools.insert(event.operation.clone());
115            } else if let Some(err) = &event.error {
116                known_failures.push(format!("{}: {}", event.operation, err));
117            }
118            if event.tool_or_adapter != "tool" {
119                required_adapters.insert(event.tool_or_adapter.clone());
120            }
121            ProcedureStep {
122                step_id: format!("step-{:02}", idx + 1),
123                title: event.operation.clone(),
124                kind: ProcedureStepKind::ToolCall {
125                    tool: event.operation.clone(),
126                    args_schema: serde_json::json!({"type":"object"}),
127                },
128                rationale: event.output_preview.clone(),
129            }
130        })
131        .collect::<Vec<_>>();
132
133    Ok(ProcedureArtifact {
134        schema_version: LEARNER_SCHEMA_VERSION,
135        id: format!("proc:{fingerprint}"),
136        title: title_from_intent(&bundle.intent),
137        intent: bundle.intent.clone(),
138        summary: format!(
139            "Learned from {} observations with fitness {:.2}.",
140            bundle.observation_count, bundle.fitness
141        ),
142        required_tools: required_tools.into_iter().collect(),
143        required_adapters: required_adapters.into_iter().collect(),
144        inputs: Vec::new(),
145        outputs: Vec::new(),
146        preconditions: vec![
147            "Use this procedure only when the user task matches the intent.".into(),
148        ],
149        steps,
150        verification: ProcedureVerification {
151            checks: vec![
152                "Confirm all required tool calls completed successfully.".into(),
153                "Summarize any errors instead of claiming success.".into(),
154            ],
155            success_criteria: vec![
156                "The requested workflow is completed or a safe failure is reported.".into(),
157            ],
158        },
159        known_failures,
160        recovery: vec!["If any step fails, stop and inspect the failure before retrying.".into()],
161        source_trajectory_ids: bundle.source_trajectory_ids.clone(),
162        source_failure_ids: bundle.source_failure_ids.clone(),
163        fitness: bundle.fitness,
164        observation_count: bundle.observation_count,
165        lifecycle: ProcedureLifecycle::Candidate,
166        render_targets: vec![
167            ProcedureArtifactFormat::MarkdownSkill,
168            ProcedureArtifactFormat::AinlGraph,
169            ProcedureArtifactFormat::PromptOnly,
170        ],
171    })
172}
173
174#[must_use]
175pub fn score_reuse(
176    artifact: &ProcedureArtifact,
177    user_intent: &str,
178    available_tools: &[String],
179) -> ReuseScore {
180    let intent_l = user_intent.to_ascii_lowercase();
181    let mut score = 0.0_f32;
182    let mut reasons = Vec::new();
183    for token in artifact
184        .intent
185        .split(|c: char| !c.is_ascii_alphanumeric())
186        .filter(|t| t.len() >= 4)
187    {
188        if intent_l.contains(&token.to_ascii_lowercase()) {
189            score += 0.15;
190        }
191    }
192    let available: HashSet<&str> = available_tools.iter().map(String::as_str).collect();
193    let required = artifact.required_tools.len().max(1) as f32;
194    let have = artifact
195        .required_tools
196        .iter()
197        .filter(|t| available.contains(t.as_str()))
198        .count() as f32;
199    let tool_score = have / required;
200    score += tool_score * 0.45;
201    if tool_score >= 1.0 {
202        reasons.push("all_required_tools_available".into());
203    } else {
204        reasons.push("some_required_tools_missing".into());
205    }
206    score += artifact.fitness.clamp(0.0, 1.0) * 0.30;
207    score += (artifact.observation_count.min(10) as f32 / 10.0) * 0.10;
208    ReuseScore {
209        procedure_id: artifact.id.clone(),
210        score: score.clamp(0.0, 1.0),
211        reasons,
212    }
213}
214
215#[must_use]
216pub fn patch_from_failure(
217    artifact: &ProcedureArtifact,
218    failure_id: impl Into<String>,
219    failure_message: impl Into<String>,
220) -> ProcedurePatch {
221    let failure_id = failure_id.into();
222    let failure_message = failure_message.into();
223    let patch_hash = sha256_hex_lower(&format!("{}:{failure_id}:{failure_message}", artifact.id));
224    ProcedurePatch {
225        schema_version: LEARNER_SCHEMA_VERSION,
226        patch_id: format!("patch:{patch_hash}"),
227        procedure_id: artifact.id.clone(),
228        rationale: format!("Patch learned from failed reuse: {failure_message}"),
229        add_steps: vec![ProcedureStep {
230            step_id: "recovery-check".into(),
231            title: "Check prior failure before retry".into(),
232            kind: ProcedureStepKind::Validate {
233                target: "previous failure is addressed".into(),
234            },
235            rationale: Some(failure_message.clone()),
236        }],
237        add_known_failures: vec![failure_message],
238        add_recovery: vec!["Do not retry unchanged inputs after this failure.".into()],
239        source_failure_ids: vec![failure_id],
240    }
241}
242
243#[must_use]
244pub fn apply_patch(artifact: &ProcedureArtifact, patch: &ProcedurePatch) -> ProcedureArtifact {
245    let mut next = artifact.clone();
246    next.steps.extend(patch.add_steps.clone());
247    next.known_failures.extend(patch.add_known_failures.clone());
248    next.recovery.extend(patch.add_recovery.clone());
249    next.source_failure_ids
250        .extend(patch.source_failure_ids.clone());
251    next.lifecycle = ProcedureLifecycle::Candidate;
252    next
253}
254
255#[must_use]
256pub fn render_execution_plan(artifact: &ProcedureArtifact) -> ProcedureExecutionPlan {
257    let mut prior_step: Option<String> = None;
258    let steps = artifact
259        .steps
260        .iter()
261        .map(|step| {
262            let (executor, operation, args_schema) = match &step.kind {
263                ProcedureStepKind::ToolCall { tool, args_schema } => {
264                    ("tool".to_string(), tool.clone(), args_schema.clone())
265                }
266                ProcedureStepKind::AdapterCall { adapter, op } => (
267                    "adapter".to_string(),
268                    format!("{adapter}.{op}"),
269                    serde_json::Value::Null,
270                ),
271                ProcedureStepKind::Validate { target } => (
272                    "validate".to_string(),
273                    target.clone(),
274                    serde_json::Value::Null,
275                ),
276                ProcedureStepKind::Branch { condition } => (
277                    "branch".to_string(),
278                    condition.clone(),
279                    serde_json::Value::Null,
280                ),
281                ProcedureStepKind::HumanReview { reason } => (
282                    "human_review".to_string(),
283                    reason.clone(),
284                    serde_json::Value::Null,
285                ),
286                ProcedureStepKind::Instruction { text } => (
287                    "instruction".to_string(),
288                    text.clone(),
289                    serde_json::Value::Null,
290                ),
291            };
292            let depends_on = prior_step.iter().cloned().collect::<Vec<_>>();
293            prior_step = Some(step.step_id.clone());
294            ProcedureExecutionStep {
295                step_id: step.step_id.clone(),
296                title: step.title.clone(),
297                executor,
298                operation,
299                args_schema,
300                depends_on,
301                on_error: "abort_and_patch".into(),
302            }
303        })
304        .collect();
305    ProcedureExecutionPlan {
306        procedure_id: artifact.id.clone(),
307        schema_version: LEARNER_SCHEMA_VERSION,
308        steps,
309        verification: artifact.verification.clone(),
310    }
311}
312
313#[must_use]
314pub fn render_markdown_skill(artifact: &ProcedureArtifact) -> String {
315    let mut out = String::new();
316    out.push_str("# ");
317    out.push_str(&artifact.title);
318    out.push_str("\n\n## Intent\n\n");
319    out.push_str(&artifact.intent);
320    out.push_str("\n\n## Summary\n\n");
321    out.push_str(&artifact.summary);
322    if !artifact.required_tools.is_empty() {
323        out.push_str("\n\n## Required Tools\n\n");
324        for tool in &artifact.required_tools {
325            out.push_str("- `");
326            out.push_str(tool);
327            out.push_str("`\n");
328        }
329    }
330    out.push_str("\n## Procedure\n\n");
331    for step in &artifact.steps {
332        out.push_str("- ");
333        out.push_str(&step.title);
334        if let Some(r) = &step.rationale {
335            out.push_str(" — ");
336            out.push_str(r);
337        }
338        out.push('\n');
339    }
340    if !artifact.known_failures.is_empty() {
341        out.push_str("\n## Known Failures\n\n");
342        for failure in &artifact.known_failures {
343            out.push_str("- ");
344            out.push_str(failure);
345            out.push('\n');
346        }
347    }
348    if !artifact.verification.checks.is_empty() {
349        out.push_str("\n## Verification\n\n");
350        for check in &artifact.verification.checks {
351            out.push_str("- ");
352            out.push_str(check);
353            out.push('\n');
354        }
355    }
356    out
357}
358
359#[must_use]
360pub fn render_ainl_compact_skeleton(artifact: &ProcedureArtifact, graph_name: &str) -> String {
361    let graph = sanitize_graph_name(graph_name);
362    let mut out = format!("# generated from {}\n{}:\n", artifact.id, graph);
363    out.push_str("  in: task\n");
364    out.push_str("  # Procedure intent: ");
365    out.push_str(&artifact.intent.replace('\n', " "));
366    out.push('\n');
367    for step in &artifact.steps {
368        out.push_str("  # ");
369        out.push_str(&step.title.replace('\n', " "));
370        out.push('\n');
371    }
372    out.push_str("  out \"procedure_skeleton:");
373    out.push_str(&artifact.id.replace('"', ""));
374    out.push_str("\"\n");
375    out
376}
377
378#[must_use]
379pub fn render_openfang_skill_toml(artifact: &ProcedureArtifact) -> String {
380    let mut out = String::new();
381    out.push_str("[skill]\n");
382    out.push_str("id = \"");
383    out.push_str(&toml_escape(&artifact.id));
384    out.push_str("\"\nname = \"");
385    out.push_str(&toml_escape(&artifact.title));
386    out.push_str("\"\ndescription = \"");
387    out.push_str(&toml_escape(&artifact.summary));
388    out.push_str("\"\nlifecycle = \"");
389    out.push_str(match artifact.lifecycle {
390        ProcedureLifecycle::Draft => "draft",
391        ProcedureLifecycle::Candidate => "candidate",
392        ProcedureLifecycle::Validated => "validated",
393        ProcedureLifecycle::Promoted => "promoted",
394        ProcedureLifecycle::Deprecated => "deprecated",
395    });
396    out.push_str("\"\n\n[procedure]\nintent = \"");
397    out.push_str(&toml_escape(&artifact.intent));
398    out.push_str("\"\nrequired_tools = [");
399    out.push_str(&quoted_toml_list(&artifact.required_tools));
400    out.push_str("]\nobservation_count = ");
401    out.push_str(&artifact.observation_count.to_string());
402    out.push_str("\nfitness = ");
403    out.push_str(&format!("{:.3}", artifact.fitness));
404    out.push('\n');
405    out
406}
407
408#[must_use]
409pub fn render_hand_metadata_toml(artifact: &ProcedureArtifact) -> String {
410    let mut out = String::new();
411    out.push_str("[hand]\n");
412    out.push_str("schema_version = \"1\"\nname = \"");
413    out.push_str(&toml_escape(&artifact.title));
414    out.push_str("\"\ndescription = \"");
415    out.push_str(&toml_escape(&artifact.summary));
416    out.push_str("\"\n\n[ainl_procedure]\nid = \"");
417    out.push_str(&toml_escape(&artifact.id));
418    out.push_str("\"\nintent = \"");
419    out.push_str(&toml_escape(&artifact.intent));
420    out.push_str("\"\nrendered_from = \"procedure_artifact\"\n");
421    out
422}
423
/// Derives a short display title from an intent.
///
/// Uses the first line of the trimmed intent, cut to 80 characters with `...` appended when
/// it was longer. A blank intent yields `"Learned Procedure"`.
fn title_from_intent(intent: &str) -> String {
    const MAX_CHARS: usize = 80;
    let Some(first_line) = intent.trim().lines().next() else {
        return "Learned Procedure".into();
    };
    // `nth(MAX_CHARS)` finds the byte offset of the first character past the limit, if any,
    // which keeps the cut on a UTF-8 character boundary.
    match first_line.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &first_line[..cut]),
        None => first_line.to_string(),
    }
}
437
438fn quoted_toml_list(values: &[String]) -> String {
439    values
440        .iter()
441        .map(|v| format!("\"{}\"", toml_escape(v)))
442        .collect::<Vec<_>>()
443        .join(", ")
444}
445
/// Escapes `value` for use inside a TOML basic (double-quoted) string.
///
/// - Backslashes and double quotes are backslash-escaped.
/// - Line breaks (`\n`, `\r\n`, or a lone `\r`) are flattened to one space, so every rendered
///   value stays on a single line.
/// - Tabs are kept as-is, since TOML allows them literally.
/// - Every other control character TOML forbids in basic strings (U+0000–U+001F and U+007F)
///   is written as a `\uXXXX` escape, so the output always parses.
fn toml_escape(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    let mut chars = value.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push(' '),
            '\r' => {
                // Treat CRLF as a single line break so it becomes one space, not two.
                if chars.peek() == Some(&'\n') {
                    chars.next();
                }
                out.push(' ');
            }
            '\t' => out.push('\t'),
            c if c <= '\u{1f}' || c == '\u{7f}' => {
                out.push_str(&format!("\\u{:04X}", u32::from(c)));
            }
            c => out.push(c),
        }
    }
    out
}
452
/// Turns an arbitrary label into a graph identifier.
///
/// Every character other than an ASCII letter, digit or underscore becomes `_`. An empty input
/// yields `learned_procedure`, and a leading digit gets an `_` prefix.
fn sanitize_graph_name(name: &str) -> String {
    let mut ident: String = name
        .chars()
        .map(|c| match c {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => c,
            _ => '_',
        })
        .collect();
    if ident.is_empty() {
        return "learned_procedure".to_string();
    }
    if ident.starts_with(|c: char| c.is_ascii_digit()) {
        ident.insert(0, '_');
    }
    ident
}
472
// Unit tests for distillation, reuse scoring, patching and rendering.
#[cfg(test)]
mod tests {
    use super::*;
    use ainl_contracts::{ContextFreshness, ExperienceEvent, ImpactDecision};

    // A bundle that passes `DistillPolicy::default()`: 3 observations, fitness 0.8, a
    // successful outcome, and two successful `tool` events (`file_read`, `shell_exec`).
    fn sample_bundle() -> ExperienceBundle {
        ExperienceBundle {
            schema_version: LEARNER_SCHEMA_VERSION,
            bundle_id: "bundle-1".into(),
            agent_id: "agent-1".into(),
            intent: "Review a pull request".into(),
            outcome: TrajectoryOutcome::Success,
            host_outcome: None,
            observation_count: 3,
            fitness: 0.8,
            events: vec![
                ExperienceEvent {
                    event_id: "e1".into(),
                    timestamp_ms: 1,
                    tool_or_adapter: "tool".into(),
                    operation: "file_read".into(),
                    success: true,
                    input_preview: None,
                    output_preview: Some("read diff".into()),
                    error: None,
                    duration_ms: 10,
                    vitals: None,
                    freshness_at_step: None,
                },
                ExperienceEvent {
                    event_id: "e2".into(),
                    timestamp_ms: 2,
                    tool_or_adapter: "tool".into(),
                    operation: "shell_exec".into(),
                    success: true,
                    input_preview: None,
                    output_preview: Some("tests pass".into()),
                    error: None,
                    duration_ms: 20,
                    vitals: None,
                    freshness_at_step: None,
                },
            ],
            source_trajectory_ids: vec!["traj-1".into()],
            source_failure_ids: vec![],
            freshness: ContextFreshness::Fresh,
            impact_decision: ImpactDecision::AllowExecute,
        }
    }

    // Happy path: the sample distills into a Candidate with one step per event.
    #[test]
    fn distills_successful_recurrent_bundle() {
        let artifact = distill_procedure(&sample_bundle(), &DistillPolicy::default()).unwrap();
        assert_eq!(artifact.lifecycle, ProcedureLifecycle::Candidate);
        assert!(artifact.required_tools.contains(&"file_read".to_string()));
        assert_eq!(artifact.steps.len(), 2);
    }

    // An observation count below the default minimum (3) must be rejected.
    #[test]
    fn rejects_low_observation_bundle() {
        let mut b = sample_bundle();
        b.observation_count = 1;
        let err = distill_procedure(&b, &DistillPolicy::default()).unwrap_err();
        assert!(matches!(
            err,
            ProcedureLearningError::InsufficientObservations { .. }
        ));
    }

    // Intent overlap, full tool availability and good fitness should score above 0.7.
    #[test]
    fn scores_reuse_from_intent_tools_and_fitness() {
        let artifact = distill_procedure(&sample_bundle(), &DistillPolicy::default()).unwrap();
        let score = score_reuse(
            &artifact,
            "Please review this pull request",
            &["file_read".into(), "shell_exec".into()],
        );
        assert!(score.score > 0.7, "{score:?}");
    }

    // A failure patch must record the failure message and append at least one step.
    #[test]
    fn failure_patch_applies_to_artifact() {
        let artifact = distill_procedure(&sample_bundle(), &DistillPolicy::default()).unwrap();
        let patch = patch_from_failure(&artifact, "f1", "shell timed out");
        let next = apply_patch(&artifact, &patch);
        assert!(next.known_failures.iter().any(|f| f.contains("timed out")));
        assert!(next.steps.len() > artifact.steps.len());
    }

    // Smoke-tests every renderer. The graph name `review-pr` is sanitized to `review_pr`, and
    // the skeleton uses a quoted `out` string rather than an `out {` block.
    #[test]
    fn renders_markdown_and_ainl_skeleton() {
        let artifact = distill_procedure(&sample_bundle(), &DistillPolicy::default()).unwrap();
        let md = render_markdown_skill(&artifact);
        assert!(md.contains("## Procedure"));
        let ainl = render_ainl_compact_skeleton(&artifact, "review-pr");
        assert!(ainl.contains("review_pr:"));
        assert!(ainl.contains("procedure_skeleton"));
        assert!(!ainl.contains("out {"));
        let skill_toml = render_openfang_skill_toml(&artifact);
        assert!(skill_toml.contains("[skill]"));
        assert!(skill_toml.contains("required_tools"));
        let hand_toml = render_hand_metadata_toml(&artifact);
        assert!(hand_toml.contains("[hand]"));
        assert!(hand_toml.contains("[ainl_procedure]"));
    }

    // Guards the crate-level promise that this crate does not depend on `openfang-*` crates.
    #[test]
    fn crate_manifest_has_no_openfang_dependency() {
        let manifest = include_str!("../Cargo.toml");
        assert!(!manifest.contains("openfang-"));
    }

    // The plan mirrors the artifact's steps, and tool-call steps map to the `tool` executor.
    #[test]
    fn renders_portable_execution_plan() {
        let artifact = distill_procedure(&sample_bundle(), &DistillPolicy::default()).unwrap();
        let plan = render_execution_plan(&artifact);
        assert_eq!(plan.procedure_id, artifact.id);
        assert_eq!(plan.steps.len(), artifact.steps.len());
        assert_eq!(plan.steps[0].executor, "tool");
    }
}
589        assert_eq!(plan.procedure_id, artifact.id);
590        assert_eq!(plan.steps.len(), artifact.steps.len());
591        assert_eq!(plan.steps[0].executor, "tool");
592    }
593}