Skip to main content

harn_vm/corrections/
mod.rs

1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6use uuid::Uuid;
7
8use crate::event_log::{AnyEventLog, EventLog, LogError, LogEvent, Topic};
9use crate::orchestration::CapabilityPolicy;
10
/// Schema identifier written into [`CorrectionRecord::schema`] by [`CorrectionRecord::new`].
pub const CORRECTION_SCHEMA_V0: &str = "harn-correction/v0";
/// Name of the event-log topic that [`corrections_topic`] builds.
pub const CORRECTIONS_TOPIC: &str = "corrections.records";
/// Event kind used when appending a correction; queries skip events of any other kind.
pub const CORRECTION_EVENT_KIND: &str = "correction_recorded";
14
/// How widely a correction applies.
///
/// Serialized in `snake_case` (`this_run`, `this_persona`, `all`), matching
/// [`CorrectionScope::as_str`] and [`CorrectionScope::parse`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CorrectionScope {
    /// Default scope. `policy_with_corrections` does not tighten policy for it.
    #[default]
    ThisRun,
    /// Tightens policy only for the actor whose id matches the record's `actor_id`.
    ThisPersona,
    /// Matches every actor, regardless of the record's `actor_id`.
    All,
}
23
24impl CorrectionScope {
25    pub fn as_str(self) -> &'static str {
26        match self {
27            Self::ThisRun => "this_run",
28            Self::ThisPersona => "this_persona",
29            Self::All => "all",
30        }
31    }
32
33    pub fn parse(value: &str) -> Result<Self, String> {
34        match value {
35            "this_run" => Ok(Self::ThisRun),
36            "this_persona" => Ok(Self::ThisPersona),
37            "all" => Ok(Self::All),
38            other => Err(format!(
39                "unsupported correction scope '{other}', expected this_run|this_persona|all"
40            )),
41        }
42    }
43}
44
/// A single recorded correction, as stored in the corrections event-log topic.
///
/// Optional and empty fields are omitted from the serialized form.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CorrectionRecord {
    /// Schema identifier; `CorrectionRecord::new` sets it to `CORRECTION_SCHEMA_V0`.
    pub schema: String,
    /// Unique id (a UUID v7 string when built by `new`); queries de-duplicate on it.
    pub correction_id: String,
    /// JSON for the decision being corrected; opaque here except for key extraction in `new`.
    pub from_decision: serde_json::Value,
    /// JSON for the replacement decision; opaque here except for key extraction in `new`.
    pub to_decision: serde_json::Value,
    /// Free-form reason text (presumably human-readable; not interpreted by this module).
    pub reason: String,
    /// Identifier of whoever applied the correction (not interpreted by this module).
    pub applied_by: String,
    /// How widely the correction applies.
    pub scope: CorrectionScope,
    /// Creation time, serialized as RFC 3339; `new` uses the current UTC time.
    #[serde(with = "time::serde::rfc3339")]
    pub timestamp: OffsetDateTime,
    /// Actor the correction targets; used for actor filtering and policy matching.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actor_id: Option<String>,
    /// Action the correction refers to; compared exactly by the `action` query filter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub action: Option<String>,
    /// Trace identifier, also emitted as an event header when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<String>,
    /// Optional step label; `new` leaves it unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub step: Option<String>,
    /// Arbitrary JSON evidence references; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub evidence_refs: Vec<serde_json::Value>,
    /// Arbitrary extra key/value data; omitted when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
69
70impl CorrectionRecord {
71    pub fn new(
72        from_decision: serde_json::Value,
73        to_decision: serde_json::Value,
74        reason: impl Into<String>,
75        applied_by: impl Into<String>,
76        scope: CorrectionScope,
77    ) -> Self {
78        let actor_id = extract_string_anywhere(
79            &from_decision,
80            &to_decision,
81            &["actor_id", "actor", "agent", "trigger_id", "binding_id"],
82        );
83        let action =
84            extract_string_anywhere(&from_decision, &to_decision, &["action", "event_kind"]);
85        let trace_id = extract_string_anywhere(&from_decision, &to_decision, &["trace_id"]);
86        Self {
87            schema: CORRECTION_SCHEMA_V0.to_string(),
88            correction_id: Uuid::now_v7().to_string(),
89            from_decision,
90            to_decision,
91            reason: reason.into(),
92            applied_by: applied_by.into(),
93            scope,
94            timestamp: OffsetDateTime::now_utc(),
95            actor_id,
96            action,
97            trace_id,
98            step: None,
99            evidence_refs: Vec::new(),
100            metadata: BTreeMap::new(),
101        }
102    }
103}
104
/// Optional filters for `query_correction_records`; an unset field matches everything.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct CorrectionQueryFilters {
    /// Keep records that apply to this actor (records scoped `All` always match).
    pub actor_id: Option<String>,
    /// Keep records whose `action` equals this value exactly.
    pub action: Option<String>,
    /// Keep records with exactly this scope.
    pub scope: Option<CorrectionScope>,
    /// Inclusive lower bound on the record timestamp (RFC 3339 when serialized).
    #[serde(with = "time::serde::rfc3339::option")]
    pub since: Option<OffsetDateTime>,
    /// Inclusive upper bound on the record timestamp (RFC 3339 when serialized).
    #[serde(with = "time::serde::rfc3339::option")]
    pub until: Option<OffsetDateTime>,
    /// Maximum number of records returned; the latest ones (after sorting) are kept.
    pub limit: Option<usize>,
}
117
/// Builds the event-log [`Topic`] named by [`CORRECTIONS_TOPIC`].
///
/// # Errors
///
/// Propagates whatever error `Topic::new` reports for the topic name.
pub fn corrections_topic() -> Result<Topic, LogError> {
    Topic::new(CORRECTIONS_TOPIC)
}
121
122pub async fn append_correction_record(
123    log: &Arc<AnyEventLog>,
124    record: &CorrectionRecord,
125) -> Result<CorrectionRecord, LogError> {
126    let payload = serde_json::to_value(record)
127        .map_err(|error| LogError::Serde(format!("correction record encode error: {error}")))?;
128    let mut headers = BTreeMap::new();
129    headers.insert("correction_id".to_string(), record.correction_id.clone());
130    headers.insert("scope".to_string(), record.scope.as_str().to_string());
131    if let Some(actor_id) = record.actor_id.as_ref() {
132        headers.insert("actor_id".to_string(), actor_id.clone());
133    }
134    if let Some(action) = record.action.as_ref() {
135        headers.insert("action".to_string(), action.clone());
136    }
137    if let Some(trace_id) = record.trace_id.as_ref() {
138        headers.insert("trace_id".to_string(), trace_id.clone());
139    }
140    log.append(
141        &corrections_topic()?,
142        LogEvent::new(CORRECTION_EVENT_KIND, payload).with_headers(headers),
143    )
144    .await?;
145    Ok(record.clone())
146}
147
148pub async fn query_correction_records(
149    log: &Arc<AnyEventLog>,
150    filters: &CorrectionQueryFilters,
151) -> Result<Vec<CorrectionRecord>, LogError> {
152    let mut records = Vec::new();
153    let mut seen = HashSet::new();
154    for (_, event) in log
155        .read_range(&corrections_topic()?, None, usize::MAX)
156        .await?
157    {
158        if event.kind != CORRECTION_EVENT_KIND {
159            continue;
160        }
161        let Ok(record) = serde_json::from_value::<CorrectionRecord>(event.payload) else {
162            continue;
163        };
164        if !matches_filters(&record, filters) {
165            continue;
166        }
167        if seen.insert(record.correction_id.clone()) {
168            records.push(record);
169        }
170    }
171    records.sort_by(|left, right| {
172        left.timestamp
173            .cmp(&right.timestamp)
174            .then(left.correction_id.cmp(&right.correction_id))
175    });
176    apply_limit(&mut records, filters.limit);
177    Ok(records)
178}
179
180pub async fn apply_corrections_to_policy(
181    log: &Arc<AnyEventLog>,
182    actor_id: &str,
183    policy: CapabilityPolicy,
184) -> Result<CapabilityPolicy, LogError> {
185    let corrections = query_correction_records(
186        log,
187        &CorrectionQueryFilters {
188            actor_id: Some(actor_id.to_string()),
189            ..CorrectionQueryFilters::default()
190        },
191    )
192    .await?;
193    Ok(policy_with_corrections(actor_id, policy, &corrections))
194}
195
196pub fn policy_with_corrections(
197    actor_id: &str,
198    mut policy: CapabilityPolicy,
199    corrections: &[CorrectionRecord],
200) -> CapabilityPolicy {
201    let should_tighten = corrections.iter().any(|record| {
202        correction_applies_to_actor(record, actor_id)
203            && matches!(
204                record.scope,
205                CorrectionScope::ThisPersona | CorrectionScope::All
206            )
207    });
208    if should_tighten {
209        cap_side_effect_level(&mut policy, "read_only");
210    }
211    policy
212}
213
214pub fn correction_record_from_json(value: serde_json::Value) -> Result<CorrectionRecord, String> {
215    let object = value
216        .as_object()
217        .ok_or_else(|| "correction record must be a dict".to_string())?;
218    let from_decision = required_json(object, "from_decision")?;
219    let to_decision = required_json(object, "to_decision")?;
220    let reason = required_string(object, "reason")?;
221    let applied_by = required_string(object, "applied_by")?;
222    let scope = object
223        .get("scope")
224        .and_then(serde_json::Value::as_str)
225        .map(CorrectionScope::parse)
226        .transpose()?
227        .unwrap_or_default();
228
229    let mut record = CorrectionRecord::new(
230        from_decision.clone(),
231        to_decision.clone(),
232        reason,
233        applied_by,
234        scope,
235    );
236    if let Some(actor_id) = optional_string(object, "actor_id")
237        .or_else(|| optional_string(object, "actor"))
238        .or_else(|| optional_string(object, "agent"))
239    {
240        record.actor_id = Some(actor_id);
241    }
242    if let Some(action) = optional_string(object, "action") {
243        record.action = Some(action);
244    }
245    if let Some(trace_id) = optional_string(object, "trace_id") {
246        record.trace_id = Some(trace_id);
247    }
248    record.step = optional_string(object, "step");
249    record.evidence_refs = match object.get("evidence_refs") {
250        Some(value) => value
251            .as_array()
252            .cloned()
253            .ok_or_else(|| "correction record field `evidence_refs` must be a list".to_string())?,
254        None => Vec::new(),
255    };
256    if let Some(metadata) = object.get("metadata") {
257        let Some(metadata_object) = metadata.as_object() else {
258            return Err("correction record field `metadata` must be a dict".to_string());
259        };
260        record.metadata.extend(
261            metadata_object
262                .iter()
263                .map(|(key, value)| (key.clone(), value.clone())),
264        );
265    }
266    Ok(record)
267}
268
269pub fn correction_query_filters_from_json(
270    value: serde_json::Value,
271) -> Result<CorrectionQueryFilters, String> {
272    let object = value
273        .as_object()
274        .ok_or_else(|| "correction query filters must be a dict".to_string())?;
275    let scope = object
276        .get("scope")
277        .and_then(serde_json::Value::as_str)
278        .map(CorrectionScope::parse)
279        .transpose()?;
280    Ok(CorrectionQueryFilters {
281        actor_id: optional_string(object, "actor_id")
282            .or_else(|| optional_string(object, "actor"))
283            .or_else(|| optional_string(object, "agent")),
284        action: optional_string(object, "action"),
285        scope,
286        since: object
287            .get("since")
288            .and_then(serde_json::Value::as_str)
289            .map(parse_rfc3339)
290            .transpose()?,
291        until: object
292            .get("until")
293            .and_then(serde_json::Value::as_str)
294            .map(parse_rfc3339)
295            .transpose()?,
296        limit: optional_limit(object, "limit")?,
297    })
298}
299
300fn matches_filters(record: &CorrectionRecord, filters: &CorrectionQueryFilters) -> bool {
301    if let Some(actor_id) = filters.actor_id.as_deref() {
302        if !correction_applies_to_actor(record, actor_id) {
303            return false;
304        }
305    }
306    if let Some(action) = filters.action.as_deref() {
307        if record.action.as_deref() != Some(action) {
308            return false;
309        }
310    }
311    if let Some(scope) = filters.scope {
312        if record.scope != scope {
313            return false;
314        }
315    }
316    if let Some(since) = filters.since {
317        if record.timestamp < since {
318            return false;
319        }
320    }
321    if let Some(until) = filters.until {
322        if record.timestamp > until {
323            return false;
324        }
325    }
326    true
327}
328
329fn correction_applies_to_actor(record: &CorrectionRecord, actor_id: &str) -> bool {
330    record.scope == CorrectionScope::All || record.actor_id.as_deref() == Some(actor_id)
331}
332
333fn cap_side_effect_level(policy: &mut CapabilityPolicy, ceiling: &str) {
334    let current = policy.side_effect_level.as_deref().unwrap_or("network");
335    if side_effect_rank(current) > side_effect_rank(ceiling) {
336        policy.side_effect_level = Some(ceiling.to_string());
337    }
338}
339
/// Maps a side-effect level name to its rank, least permissive first.
///
/// Unrecognized names rank above every known level.
fn side_effect_rank(value: &str) -> usize {
    // Index in this table is the rank.
    const LEVELS: [&str; 5] = [
        "none",
        "read_only",
        "workspace_write",
        "process_exec",
        "network",
    ];
    LEVELS
        .iter()
        .position(|level| *level == value)
        .unwrap_or(LEVELS.len())
}
350
351fn apply_limit(records: &mut Vec<CorrectionRecord>, limit: Option<usize>) {
352    let Some(limit) = limit else {
353        return;
354    };
355    if records.len() <= limit {
356        return;
357    }
358    let keep_from = records.len() - limit;
359    records.drain(0..keep_from);
360}
361
362fn required_json(
363    object: &serde_json::Map<String, serde_json::Value>,
364    field: &str,
365) -> Result<serde_json::Value, String> {
366    object
367        .get(field)
368        .cloned()
369        .ok_or_else(|| format!("correction record missing field `{field}`"))
370}
371
372fn required_string(
373    object: &serde_json::Map<String, serde_json::Value>,
374    field: &str,
375) -> Result<String, String> {
376    optional_string(object, field)
377        .ok_or_else(|| format!("correction record missing string field `{field}`"))
378}
379
380fn optional_string(
381    object: &serde_json::Map<String, serde_json::Value>,
382    field: &str,
383) -> Option<String> {
384    object
385        .get(field)
386        .and_then(serde_json::Value::as_str)
387        .map(str::to_string)
388}
389
390fn optional_limit(
391    object: &serde_json::Map<String, serde_json::Value>,
392    field: &str,
393) -> Result<Option<usize>, String> {
394    let Some(value) = object.get(field) else {
395        return Ok(None);
396    };
397    if let Some(limit) = value.as_u64() {
398        return usize::try_from(limit)
399            .map(Some)
400            .map_err(|_| format!("correction query field `{field}` is too large"));
401    }
402    if value.as_i64().is_some() {
403        return Err(format!(
404            "correction query field `{field}` must be non-negative"
405        ));
406    }
407    Err(format!(
408        "correction query field `{field}` must be an integer"
409    ))
410}
411
412fn extract_string_anywhere(
413    first: &serde_json::Value,
414    second: &serde_json::Value,
415    keys: &[&str],
416) -> Option<String> {
417    keys.iter()
418        .find_map(|key| extract_string(first, key).or_else(|| extract_string(second, key)))
419}
420
421fn extract_string(value: &serde_json::Value, key: &str) -> Option<String> {
422    value
423        .as_object()
424        .and_then(|object| object.get(key))
425        .and_then(serde_json::Value::as_str)
426        .map(str::to_string)
427        .or_else(|| {
428            value
429                .get("metadata")
430                .and_then(|metadata| metadata.get(key))
431                .and_then(serde_json::Value::as_str)
432                .map(str::to_string)
433        })
434}
435
436fn parse_rfc3339(value: &str) -> Result<OffsetDateTime, String> {
437    OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339)
438        .map_err(|error| format!("invalid RFC3339 timestamp '{value}': {error}"))
439}
440
// Unit tests for recording, querying and applying corrections.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::MemoryEventLog;

    // Appends one persona-scoped correction to an in-memory log, reads it
    // back through the actor/action filters, and checks that the actor's
    // policy is capped at `read_only`.
    #[tokio::test]
    async fn append_query_and_policy_round_trip() {
        let log: Arc<AnyEventLog> = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(16)));
        // `actor_id` and `action` are inferred from `from_decision`.
        let mut record = CorrectionRecord::new(
            serde_json::json!({
                "actor_id": "bot",
                "action": "github.issue.opened",
                "outcome": "success"
            }),
            serde_json::json!({"outcome": "denied"}),
            "operator corrected unsafe issue triage",
            "alice",
            CorrectionScope::ThisPersona,
        );
        record.trace_id = Some("trace-correction".to_string());
        append_correction_record(&log, &record).await.unwrap();

        let records = query_correction_records(
            &log,
            &CorrectionQueryFilters {
                actor_id: Some("bot".to_string()),
                action: Some("github.issue.opened".to_string()),
                ..CorrectionQueryFilters::default()
            },
        )
        .await
        .unwrap();

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].applied_by, "alice");
        assert_eq!(records[0].scope, CorrectionScope::ThisPersona);

        let policy = CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        };
        let adapted = apply_corrections_to_policy(&log, "bot", policy)
            .await
            .unwrap();
        assert_eq!(adapted.side_effect_level.as_deref(), Some("read_only"));
    }

    // A `ThisRun` correction must leave the policy's side-effect level alone.
    #[test]
    fn this_run_scope_does_not_change_persona_policy() {
        let record = CorrectionRecord::new(
            serde_json::json!({"actor_id": "bot", "action": "deploy"}),
            serde_json::json!({"outcome": "denied"}),
            "single-run correction",
            "alice",
            CorrectionScope::ThisRun,
        );
        let policy = CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        };

        let adapted = policy_with_corrections("bot", policy, &[record]);

        assert_eq!(adapted.side_effect_level.as_deref(), Some("network"));
    }

    // Wrongly-typed `evidence_refs` and a negative `limit` are reported as
    // errors rather than ignored.
    #[test]
    fn correction_parsers_reject_malformed_optional_fields() {
        let record_error = correction_record_from_json(serde_json::json!({
            "from_decision": {"actor_id": "bot"},
            "to_decision": {"outcome": "denied"},
            "reason": "operator corrected routing",
            "applied_by": "alice",
            "evidence_refs": "not-a-list"
        }))
        .expect_err("invalid evidence_refs");
        assert!(record_error.contains("evidence_refs"));

        let limit_error = correction_query_filters_from_json(serde_json::json!({
            "limit": -1
        }))
        .expect_err("invalid limit");
        assert!(limit_error.contains("non-negative"));
    }
}
525}