Skip to main content

harn_vm/corrections/
mod.rs

1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6use uuid::Uuid;
7
8use crate::event_log::{AnyEventLog, EventLog, LogError, LogEvent, Topic};
9use crate::orchestration::CapabilityPolicy;
10
/// Schema identifier written into the `schema` field by `CorrectionRecord::new`.
pub const CORRECTION_SCHEMA_V0: &str = "harn-correction/v0";
/// Name of the event-log topic that correction records are appended to and read from.
pub const CORRECTIONS_TOPIC: &str = "corrections.records";
/// Event kind used when appending a correction; queries skip events of any other kind.
pub const CORRECTION_EVENT_KIND: &str = "correction_recorded";
14
15#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum CorrectionScope {
18    #[default]
19    ThisRun,
20    ThisPersona,
21    All,
22}
23
24impl CorrectionScope {
25    pub fn as_str(self) -> &'static str {
26        match self {
27            Self::ThisRun => "this_run",
28            Self::ThisPersona => "this_persona",
29            Self::All => "all",
30        }
31    }
32
33    pub fn parse(value: &str) -> Result<Self, String> {
34        match value {
35            "this_run" => Ok(Self::ThisRun),
36            "this_persona" => Ok(Self::ThisPersona),
37            "all" => Ok(Self::All),
38            other => Err(format!(
39                "unsupported correction scope '{other}', expected this_run|this_persona|all"
40            )),
41        }
42    }
43}
44
45#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
46pub struct CorrectionRecord {
47    pub schema: String,
48    pub correction_id: String,
49    pub from_decision: serde_json::Value,
50    pub to_decision: serde_json::Value,
51    pub reason: String,
52    pub applied_by: String,
53    pub scope: CorrectionScope,
54    #[serde(with = "time::serde::rfc3339")]
55    pub timestamp: OffsetDateTime,
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub actor_id: Option<String>,
58    #[serde(default, skip_serializing_if = "Option::is_none")]
59    pub action: Option<String>,
60    #[serde(default, skip_serializing_if = "Option::is_none")]
61    pub trace_id: Option<String>,
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub step: Option<String>,
64    #[serde(default, skip_serializing_if = "Vec::is_empty")]
65    pub evidence_refs: Vec<serde_json::Value>,
66    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
67    pub metadata: BTreeMap<String, serde_json::Value>,
68}
69
70impl CorrectionRecord {
71    pub fn new(
72        from_decision: serde_json::Value,
73        to_decision: serde_json::Value,
74        reason: impl Into<String>,
75        applied_by: impl Into<String>,
76        scope: CorrectionScope,
77    ) -> Self {
78        let actor_id = extract_string_anywhere(
79            &from_decision,
80            &to_decision,
81            &["actor_id", "actor", "agent", "trigger_id", "binding_id"],
82        );
83        let action =
84            extract_string_anywhere(&from_decision, &to_decision, &["action", "event_kind"]);
85        let trace_id = extract_string_anywhere(&from_decision, &to_decision, &["trace_id"]);
86        Self {
87            schema: CORRECTION_SCHEMA_V0.to_string(),
88            correction_id: Uuid::now_v7().to_string(),
89            from_decision,
90            to_decision,
91            reason: reason.into(),
92            applied_by: applied_by.into(),
93            scope,
94            timestamp: OffsetDateTime::now_utc(),
95            actor_id,
96            action,
97            trace_id,
98            step: None,
99            evidence_refs: Vec::new(),
100            metadata: BTreeMap::new(),
101        }
102    }
103}
104
105#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
106#[serde(default)]
107pub struct CorrectionQueryFilters {
108    pub actor_id: Option<String>,
109    pub action: Option<String>,
110    pub scope: Option<CorrectionScope>,
111    #[serde(with = "time::serde::rfc3339::option")]
112    pub since: Option<OffsetDateTime>,
113    #[serde(with = "time::serde::rfc3339::option")]
114    pub until: Option<OffsetDateTime>,
115    pub limit: Option<usize>,
116}
117
118pub fn corrections_topic() -> Result<Topic, LogError> {
119    Topic::new(CORRECTIONS_TOPIC)
120}
121
122pub async fn append_correction_record(
123    log: &Arc<AnyEventLog>,
124    record: &CorrectionRecord,
125) -> Result<CorrectionRecord, LogError> {
126    let payload = serde_json::to_value(record)
127        .map_err(|error| LogError::Serde(format!("correction record encode error: {error}")))?;
128    let mut headers = BTreeMap::new();
129    headers.insert("correction_id".to_string(), record.correction_id.clone());
130    headers.insert("scope".to_string(), record.scope.as_str().to_string());
131    if let Some(actor_id) = record.actor_id.as_ref() {
132        headers.insert("actor_id".to_string(), actor_id.clone());
133    }
134    if let Some(action) = record.action.as_ref() {
135        headers.insert("action".to_string(), action.clone());
136    }
137    if let Some(trace_id) = record.trace_id.as_ref() {
138        headers.insert("trace_id".to_string(), trace_id.clone());
139    }
140    log.append(
141        &corrections_topic()?,
142        LogEvent::new(CORRECTION_EVENT_KIND, payload).with_headers(headers),
143    )
144    .await?;
145    Ok(record.clone())
146}
147
148pub async fn query_correction_records(
149    log: &Arc<AnyEventLog>,
150    filters: &CorrectionQueryFilters,
151) -> Result<Vec<CorrectionRecord>, LogError> {
152    let mut records = Vec::new();
153    let mut seen = HashSet::new();
154    for (_, event) in log
155        .read_range(&corrections_topic()?, None, usize::MAX)
156        .await?
157    {
158        if event.kind != CORRECTION_EVENT_KIND {
159            continue;
160        }
161        let Ok(record) = serde_json::from_value::<CorrectionRecord>(event.payload) else {
162            continue;
163        };
164        if !matches_filters(&record, filters) {
165            continue;
166        }
167        if seen.insert(record.correction_id.clone()) {
168            records.push(record);
169        }
170    }
171    records.sort_by(|left, right| {
172        left.timestamp
173            .cmp(&right.timestamp)
174            .then(left.correction_id.cmp(&right.correction_id))
175    });
176    apply_limit(&mut records, filters.limit);
177    Ok(records)
178}
179
180pub async fn apply_corrections_to_policy(
181    log: &Arc<AnyEventLog>,
182    actor_id: &str,
183    policy: CapabilityPolicy,
184) -> Result<CapabilityPolicy, LogError> {
185    let corrections = query_correction_records(
186        log,
187        &CorrectionQueryFilters {
188            actor_id: Some(actor_id.to_string()),
189            ..CorrectionQueryFilters::default()
190        },
191    )
192    .await?;
193    Ok(policy_with_corrections(actor_id, policy, &corrections))
194}
195
196pub fn policy_with_corrections(
197    actor_id: &str,
198    mut policy: CapabilityPolicy,
199    corrections: &[CorrectionRecord],
200) -> CapabilityPolicy {
201    let should_tighten = corrections.iter().any(|record| {
202        correction_applies_to_actor(record, actor_id)
203            && matches!(
204                record.scope,
205                CorrectionScope::ThisPersona | CorrectionScope::All
206            )
207    });
208    if should_tighten {
209        cap_side_effect_level(&mut policy, "read_only");
210    }
211    policy
212}
213
214pub fn correction_record_from_json(value: serde_json::Value) -> Result<CorrectionRecord, String> {
215    let object = value
216        .as_object()
217        .ok_or_else(|| "correction record must be a dict".to_string())?;
218    let from_decision = required_json(object, "from_decision")?;
219    let to_decision = required_json(object, "to_decision")?;
220    let reason = required_string(object, "reason")?;
221    let applied_by = required_string(object, "applied_by")?;
222    let scope = object
223        .get("scope")
224        .and_then(serde_json::Value::as_str)
225        .map(CorrectionScope::parse)
226        .transpose()?
227        .unwrap_or_default();
228
229    let mut record = CorrectionRecord::new(from_decision, to_decision, reason, applied_by, scope);
230    if let Some(actor_id) = optional_string(object, "actor_id")
231        .or_else(|| optional_string(object, "actor"))
232        .or_else(|| optional_string(object, "agent"))
233    {
234        record.actor_id = Some(actor_id);
235    }
236    if let Some(action) = optional_string(object, "action") {
237        record.action = Some(action);
238    }
239    if let Some(trace_id) = optional_string(object, "trace_id") {
240        record.trace_id = Some(trace_id);
241    }
242    record.step = optional_string(object, "step");
243    record.evidence_refs = match object.get("evidence_refs") {
244        Some(value) => value
245            .as_array()
246            .cloned()
247            .ok_or_else(|| "correction record field `evidence_refs` must be a list".to_string())?,
248        None => Vec::new(),
249    };
250    if let Some(metadata) = object.get("metadata") {
251        let Some(metadata_object) = metadata.as_object() else {
252            return Err("correction record field `metadata` must be a dict".to_string());
253        };
254        record.metadata.extend(
255            metadata_object
256                .iter()
257                .map(|(key, value)| (key.clone(), value.clone())),
258        );
259    }
260    Ok(record)
261}
262
263pub fn correction_query_filters_from_json(
264    value: serde_json::Value,
265) -> Result<CorrectionQueryFilters, String> {
266    let object = value
267        .as_object()
268        .ok_or_else(|| "correction query filters must be a dict".to_string())?;
269    let scope = object
270        .get("scope")
271        .and_then(serde_json::Value::as_str)
272        .map(CorrectionScope::parse)
273        .transpose()?;
274    Ok(CorrectionQueryFilters {
275        actor_id: optional_string(object, "actor_id")
276            .or_else(|| optional_string(object, "actor"))
277            .or_else(|| optional_string(object, "agent")),
278        action: optional_string(object, "action"),
279        scope,
280        since: object
281            .get("since")
282            .and_then(serde_json::Value::as_str)
283            .map(parse_rfc3339)
284            .transpose()?,
285        until: object
286            .get("until")
287            .and_then(serde_json::Value::as_str)
288            .map(parse_rfc3339)
289            .transpose()?,
290        limit: optional_limit(object, "limit")?,
291    })
292}
293
294fn matches_filters(record: &CorrectionRecord, filters: &CorrectionQueryFilters) -> bool {
295    if let Some(actor_id) = filters.actor_id.as_deref() {
296        if !correction_applies_to_actor(record, actor_id) {
297            return false;
298        }
299    }
300    if let Some(action) = filters.action.as_deref() {
301        if record.action.as_deref() != Some(action) {
302            return false;
303        }
304    }
305    if let Some(scope) = filters.scope {
306        if record.scope != scope {
307            return false;
308        }
309    }
310    if let Some(since) = filters.since {
311        if record.timestamp < since {
312            return false;
313        }
314    }
315    if let Some(until) = filters.until {
316        if record.timestamp > until {
317            return false;
318        }
319    }
320    true
321}
322
323fn correction_applies_to_actor(record: &CorrectionRecord, actor_id: &str) -> bool {
324    record.scope == CorrectionScope::All || record.actor_id.as_deref() == Some(actor_id)
325}
326
327fn cap_side_effect_level(policy: &mut CapabilityPolicy, ceiling: &str) {
328    let current = policy.side_effect_level.as_deref().unwrap_or("network");
329    if side_effect_rank(current) > side_effect_rank(ceiling) {
330        policy.side_effect_level = Some(ceiling.to_string());
331    }
332}
333
/// Maps a side-effect level name to its rank, lowest (`none`) first.
///
/// Unrecognized names rank above every known level.
fn side_effect_rank(value: &str) -> usize {
    // Ordered from least to most permissive; the index is the rank.
    const LEVELS: [&str; 5] = [
        "none",
        "read_only",
        "workspace_write",
        "process_exec",
        "network",
    ];
    LEVELS
        .iter()
        .position(|level| *level == value)
        .unwrap_or(LEVELS.len())
}
344
345fn apply_limit(records: &mut Vec<CorrectionRecord>, limit: Option<usize>) {
346    let Some(limit) = limit else {
347        return;
348    };
349    if records.len() <= limit {
350        return;
351    }
352    let keep_from = records.len() - limit;
353    records.drain(0..keep_from);
354}
355
356fn required_json(
357    object: &serde_json::Map<String, serde_json::Value>,
358    field: &str,
359) -> Result<serde_json::Value, String> {
360    object
361        .get(field)
362        .cloned()
363        .ok_or_else(|| format!("correction record missing field `{field}`"))
364}
365
366fn required_string(
367    object: &serde_json::Map<String, serde_json::Value>,
368    field: &str,
369) -> Result<String, String> {
370    optional_string(object, field)
371        .ok_or_else(|| format!("correction record missing string field `{field}`"))
372}
373
374fn optional_string(
375    object: &serde_json::Map<String, serde_json::Value>,
376    field: &str,
377) -> Option<String> {
378    object
379        .get(field)
380        .and_then(serde_json::Value::as_str)
381        .map(str::to_string)
382}
383
384fn optional_limit(
385    object: &serde_json::Map<String, serde_json::Value>,
386    field: &str,
387) -> Result<Option<usize>, String> {
388    let Some(value) = object.get(field) else {
389        return Ok(None);
390    };
391    if let Some(limit) = value.as_u64() {
392        return usize::try_from(limit)
393            .map(Some)
394            .map_err(|_| format!("correction query field `{field}` is too large"));
395    }
396    if value.as_i64().is_some() {
397        return Err(format!(
398            "correction query field `{field}` must be non-negative"
399        ));
400    }
401    Err(format!(
402        "correction query field `{field}` must be an integer"
403    ))
404}
405
406fn extract_string_anywhere(
407    first: &serde_json::Value,
408    second: &serde_json::Value,
409    keys: &[&str],
410) -> Option<String> {
411    keys.iter()
412        .find_map(|key| extract_string(first, key).or_else(|| extract_string(second, key)))
413}
414
415fn extract_string(value: &serde_json::Value, key: &str) -> Option<String> {
416    value
417        .as_object()
418        .and_then(|object| object.get(key))
419        .and_then(serde_json::Value::as_str)
420        .map(str::to_string)
421        .or_else(|| {
422            value
423                .get("metadata")
424                .and_then(|metadata| metadata.get(key))
425                .and_then(serde_json::Value::as_str)
426                .map(str::to_string)
427        })
428}
429
430fn parse_rfc3339(value: &str) -> Result<OffsetDateTime, String> {
431    OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339)
432        .map_err(|error| format!("invalid RFC3339 timestamp '{value}': {error}"))
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::MemoryEventLog;

    /// Policy fixture whose side-effect level starts at `network`.
    fn network_policy() -> CapabilityPolicy {
        CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        }
    }

    // Appending a persona-scoped correction makes it queryable and tightens the policy.
    #[tokio::test]
    async fn append_query_and_policy_round_trip() {
        let log: Arc<AnyEventLog> = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(16)));

        let from_decision = serde_json::json!({
            "actor_id": "bot",
            "action": "github.issue.opened",
            "outcome": "success"
        });
        let to_decision = serde_json::json!({"outcome": "denied"});
        let mut record = CorrectionRecord::new(
            from_decision,
            to_decision,
            "operator corrected unsafe issue triage",
            "alice",
            CorrectionScope::ThisPersona,
        );
        record.trace_id = Some("trace-correction".to_string());
        append_correction_record(&log, &record).await.unwrap();

        let filters = CorrectionQueryFilters {
            actor_id: Some("bot".to_string()),
            action: Some("github.issue.opened".to_string()),
            ..CorrectionQueryFilters::default()
        };
        let records = query_correction_records(&log, &filters).await.unwrap();

        assert_eq!(records.len(), 1);
        let stored = &records[0];
        assert_eq!(stored.applied_by, "alice");
        assert_eq!(stored.scope, CorrectionScope::ThisPersona);

        let adapted = apply_corrections_to_policy(&log, "bot", network_policy())
            .await
            .unwrap();
        assert_eq!(adapted.side_effect_level.as_deref(), Some("read_only"));
    }

    // A run-scoped correction must leave the persona's policy untouched.
    #[test]
    fn this_run_scope_does_not_change_persona_policy() {
        let record = CorrectionRecord::new(
            serde_json::json!({"actor_id": "bot", "action": "deploy"}),
            serde_json::json!({"outcome": "denied"}),
            "single-run correction",
            "alice",
            CorrectionScope::ThisRun,
        );

        let adapted = policy_with_corrections("bot", network_policy(), &[record]);

        assert_eq!(adapted.side_effect_level.as_deref(), Some("network"));
    }

    // Malformed optional fields are reported instead of being ignored.
    #[test]
    fn correction_parsers_reject_malformed_optional_fields() {
        let bad_record = serde_json::json!({
            "from_decision": {"actor_id": "bot"},
            "to_decision": {"outcome": "denied"},
            "reason": "operator corrected routing",
            "applied_by": "alice",
            "evidence_refs": "not-a-list"
        });
        let record_error = correction_record_from_json(bad_record).expect_err("invalid evidence_refs");
        assert!(record_error.contains("evidence_refs"));

        let bad_filters = serde_json::json!({ "limit": -1 });
        let limit_error = correction_query_filters_from_json(bad_filters).expect_err("invalid limit");
        assert!(limit_error.contains("non-negative"));
    }
}