// harn_vm/corrections/mod.rs

1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6use uuid::Uuid;
7
8use crate::event_log::{AnyEventLog, EventLog, LogError, LogEvent, Topic};
9use crate::orchestration::CapabilityPolicy;
10
/// Schema identifier written into the `schema` field of every record built by
/// [`CorrectionRecord::new`].
pub const CORRECTION_SCHEMA_V0: &str = "harn-correction/v0";
/// Name of the event-log topic that correction records are appended to and read from.
pub const CORRECTIONS_TOPIC: &str = "corrections.records";
/// Event kind used when appending a correction; events of any other kind on the
/// topic are skipped when querying.
pub const CORRECTION_EVENT_KIND: &str = "correction_recorded";
14
/// How widely a correction applies.
///
/// Serialized in snake_case (`this_run`, `this_persona`, `all`). The default is
/// [`CorrectionScope::ThisRun`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CorrectionScope {
    /// Never tightens a capability policy in `policy_with_corrections`.
    #[default]
    ThisRun,
    /// Tightens the policy only for the actor whose id matches the record's `actor_id`.
    ThisPersona,
    /// Treated as applying to every actor, regardless of the record's `actor_id`.
    All,
}
23
24impl CorrectionScope {
25    pub fn as_str(self) -> &'static str {
26        match self {
27            Self::ThisRun => "this_run",
28            Self::ThisPersona => "this_persona",
29            Self::All => "all",
30        }
31    }
32
33    pub fn parse(value: &str) -> Result<Self, String> {
34        match value {
35            "this_run" => Ok(Self::ThisRun),
36            "this_persona" => Ok(Self::ThisPersona),
37            "all" => Ok(Self::All),
38            other => Err(format!(
39                "unsupported correction scope '{other}', expected this_run|this_persona|all"
40            )),
41        }
42    }
43}
44
/// A single recorded correction, stored as the JSON payload of an event-log entry.
///
/// Optional fields are omitted from the serialized form when unset or empty and
/// fall back to their defaults when missing on deserialization.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrectionRecord {
    /// Schema identifier; `CorrectionRecord::new` sets it to `CORRECTION_SCHEMA_V0`.
    pub schema: String,
    /// Unique id of the correction; `CorrectionRecord::new` generates a UUIDv7
    /// string. Queries de-duplicate records on this value.
    pub correction_id: String,
    /// Free-form JSON describing the decision being corrected.
    pub from_decision: serde_json::Value,
    /// Free-form JSON describing the decision that replaces it.
    pub to_decision: serde_json::Value,
    /// Explanation supplied for the correction.
    pub reason: String,
    /// Identifier of whoever applied the correction.
    pub applied_by: String,
    /// How widely the correction applies.
    pub scope: CorrectionScope,
    /// Creation time, serialized as an RFC 3339 string.
    #[serde(with = "time::serde::rfc3339")]
    pub timestamp: OffsetDateTime,
    /// Actor the correction targets, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actor_id: Option<String>,
    /// Action the corrected decision was about, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub action: Option<String>,
    /// Trace identifier associated with the corrected decision, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<String>,
    /// Optional step label; left as `None` by `CorrectionRecord::new`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub step: Option<String>,
    /// Free-form JSON references to supporting evidence.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub evidence_refs: Vec<serde_json::Value>,
    /// Additional free-form key/value data.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
69
70impl CorrectionRecord {
71    pub fn new(
72        from_decision: serde_json::Value,
73        to_decision: serde_json::Value,
74        reason: impl Into<String>,
75        applied_by: impl Into<String>,
76        scope: CorrectionScope,
77    ) -> Self {
78        let actor_id = extract_string_anywhere(
79            &from_decision,
80            &to_decision,
81            &["actor_id", "actor", "agent", "trigger_id", "binding_id"],
82        );
83        let action =
84            extract_string_anywhere(&from_decision, &to_decision, &["action", "event_kind"]);
85        let trace_id = extract_string_anywhere(&from_decision, &to_decision, &["trace_id"]);
86        Self {
87            schema: CORRECTION_SCHEMA_V0.to_string(),
88            correction_id: Uuid::now_v7().to_string(),
89            from_decision,
90            to_decision,
91            reason: reason.into(),
92            applied_by: applied_by.into(),
93            scope,
94            timestamp: OffsetDateTime::now_utc(),
95            actor_id,
96            action,
97            trace_id,
98            step: None,
99            evidence_refs: Vec::new(),
100            metadata: BTreeMap::new(),
101        }
102    }
103}
104
/// Optional criteria for selecting correction records; an unset field imposes
/// no restriction.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct CorrectionQueryFilters {
    /// Keep records whose `actor_id` equals this value, plus every record with
    /// scope `all`.
    pub actor_id: Option<String>,
    /// Keep only records whose `action` equals this value exactly.
    pub action: Option<String>,
    /// Keep only records with exactly this scope.
    pub scope: Option<CorrectionScope>,
    /// Inclusive lower bound on the record timestamp (RFC 3339 when serialized).
    #[serde(with = "time::serde::rfc3339::option")]
    pub since: Option<OffsetDateTime>,
    /// Inclusive upper bound on the record timestamp (RFC 3339 when serialized).
    #[serde(with = "time::serde::rfc3339::option")]
    pub until: Option<OffsetDateTime>,
    /// Maximum number of records to return; when more match, the ones latest in
    /// the sorted result are kept.
    pub limit: Option<usize>,
}
117
/// Builds the event-log [`Topic`] named by [`CORRECTIONS_TOPIC`].
///
/// # Errors
///
/// Propagates whatever error `Topic::new` reports for the topic name.
pub fn corrections_topic() -> Result<Topic, LogError> {
    Topic::new(CORRECTIONS_TOPIC)
}
121
122pub async fn append_correction_record(
123    log: &Arc<AnyEventLog>,
124    record: &CorrectionRecord,
125) -> Result<CorrectionRecord, LogError> {
126    let payload = serde_json::to_value(record)
127        .map_err(|error| LogError::Serde(format!("correction record encode error: {error}")))?;
128    let mut headers = BTreeMap::new();
129    headers.insert("correction_id".to_string(), record.correction_id.clone());
130    headers.insert("scope".to_string(), record.scope.as_str().to_string());
131    if let Some(actor_id) = record.actor_id.as_ref() {
132        headers.insert("actor_id".to_string(), actor_id.clone());
133    }
134    if let Some(action) = record.action.as_ref() {
135        headers.insert("action".to_string(), action.clone());
136    }
137    if let Some(trace_id) = record.trace_id.as_ref() {
138        headers.insert("trace_id".to_string(), trace_id.clone());
139    }
140    log.append(
141        &corrections_topic()?,
142        LogEvent::new(CORRECTION_EVENT_KIND, payload).with_headers(headers),
143    )
144    .await?;
145    Ok(record.clone())
146}
147
148pub async fn query_correction_records(
149    log: &Arc<AnyEventLog>,
150    filters: &CorrectionQueryFilters,
151) -> Result<Vec<CorrectionRecord>, LogError> {
152    let mut records = Vec::new();
153    let mut seen = HashSet::new();
154    for (_, event) in log
155        .read_range(&corrections_topic()?, None, usize::MAX)
156        .await?
157    {
158        if event.kind != CORRECTION_EVENT_KIND {
159            continue;
160        }
161        let Ok(record) = serde_json::from_value::<CorrectionRecord>(event.payload) else {
162            continue;
163        };
164        if !matches_filters(&record, filters) {
165            continue;
166        }
167        if seen.insert(record.correction_id.clone()) {
168            records.push(record);
169        }
170    }
171    records.sort_by(|left, right| {
172        left.timestamp
173            .cmp(&right.timestamp)
174            .then(left.correction_id.cmp(&right.correction_id))
175    });
176    apply_limit(&mut records, filters.limit);
177    Ok(records)
178}
179
180pub async fn apply_corrections_to_policy(
181    log: &Arc<AnyEventLog>,
182    actor_id: &str,
183    policy: CapabilityPolicy,
184) -> Result<CapabilityPolicy, LogError> {
185    let corrections = query_correction_records(
186        log,
187        &CorrectionQueryFilters {
188            actor_id: Some(actor_id.to_string()),
189            ..CorrectionQueryFilters::default()
190        },
191    )
192    .await?;
193    Ok(policy_with_corrections(actor_id, policy, &corrections))
194}
195
196pub fn policy_with_corrections(
197    actor_id: &str,
198    mut policy: CapabilityPolicy,
199    corrections: &[CorrectionRecord],
200) -> CapabilityPolicy {
201    let should_tighten = corrections.iter().any(|record| {
202        correction_applies_to_actor(record, actor_id)
203            && matches!(
204                record.scope,
205                CorrectionScope::ThisPersona | CorrectionScope::All
206            )
207    });
208    if should_tighten {
209        cap_side_effect_level(&mut policy, "read_only");
210    }
211    policy
212}
213
214pub fn correction_record_from_json(value: serde_json::Value) -> Result<CorrectionRecord, String> {
215    let object = value
216        .as_object()
217        .ok_or_else(|| "correction record must be a dict".to_string())?;
218    let from_decision = required_json(object, "from_decision")?;
219    let to_decision = required_json(object, "to_decision")?;
220    let reason = required_string(object, "reason")?;
221    let applied_by = required_string(object, "applied_by")?;
222    let scope = object
223        .get("scope")
224        .and_then(serde_json::Value::as_str)
225        .map(CorrectionScope::parse)
226        .transpose()?
227        .unwrap_or_default();
228
229    let mut record = CorrectionRecord::new(from_decision, to_decision, reason, applied_by, scope);
230    if let Some(actor_id) = optional_string(object, "actor_id")
231        .or_else(|| optional_string(object, "actor"))
232        .or_else(|| optional_string(object, "agent"))
233    {
234        record.actor_id = Some(actor_id);
235    }
236    if let Some(action) = optional_string(object, "action") {
237        record.action = Some(action);
238    }
239    if let Some(trace_id) = optional_string(object, "trace_id") {
240        record.trace_id = Some(trace_id);
241    }
242    record.step = optional_string(object, "step");
243    record.evidence_refs = match object.get("evidence_refs") {
244        Some(value) => value
245            .as_array()
246            .cloned()
247            .ok_or_else(|| "correction record field `evidence_refs` must be a list".to_string())?,
248        None => Vec::new(),
249    };
250    if let Some(metadata) = object.get("metadata") {
251        let Some(metadata_object) = metadata.as_object() else {
252            return Err("correction record field `metadata` must be a dict".to_string());
253        };
254        record.metadata.extend(
255            metadata_object
256                .iter()
257                .map(|(key, value)| (key.clone(), value.clone())),
258        );
259    }
260    Ok(record)
261}
262
263pub fn correction_query_filters_from_json(
264    value: serde_json::Value,
265) -> Result<CorrectionQueryFilters, String> {
266    let object = value
267        .as_object()
268        .ok_or_else(|| "correction query filters must be a dict".to_string())?;
269    let scope = object
270        .get("scope")
271        .and_then(serde_json::Value::as_str)
272        .map(CorrectionScope::parse)
273        .transpose()?;
274    Ok(CorrectionQueryFilters {
275        actor_id: optional_string(object, "actor_id")
276            .or_else(|| optional_string(object, "actor"))
277            .or_else(|| optional_string(object, "agent")),
278        action: optional_string(object, "action"),
279        scope,
280        since: object
281            .get("since")
282            .and_then(serde_json::Value::as_str)
283            .map(parse_rfc3339)
284            .transpose()?,
285        until: object
286            .get("until")
287            .and_then(serde_json::Value::as_str)
288            .map(parse_rfc3339)
289            .transpose()?,
290        limit: optional_limit(object, "limit")?,
291    })
292}
293
294fn matches_filters(record: &CorrectionRecord, filters: &CorrectionQueryFilters) -> bool {
295    if let Some(actor_id) = filters.actor_id.as_deref() {
296        if !correction_applies_to_actor(record, actor_id) {
297            return false;
298        }
299    }
300    if let Some(action) = filters.action.as_deref() {
301        if record.action.as_deref() != Some(action) {
302            return false;
303        }
304    }
305    if let Some(scope) = filters.scope {
306        if record.scope != scope {
307            return false;
308        }
309    }
310    if let Some(since) = filters.since {
311        if record.timestamp < since {
312            return false;
313        }
314    }
315    if let Some(until) = filters.until {
316        if record.timestamp > until {
317            return false;
318        }
319    }
320    true
321}
322
323fn correction_applies_to_actor(record: &CorrectionRecord, actor_id: &str) -> bool {
324    record.scope == CorrectionScope::All || record.actor_id.as_deref() == Some(actor_id)
325}
326
327fn cap_side_effect_level(policy: &mut CapabilityPolicy, ceiling: &str) {
328    use crate::tool_annotations::SideEffectLevel;
329    let current = policy.side_effect_level.as_deref().unwrap_or("network");
330    if SideEffectLevel::rank_str(current) > SideEffectLevel::rank_str(ceiling) {
331        policy.side_effect_level = Some(ceiling.to_string());
332    }
333}
334
335fn apply_limit(records: &mut Vec<CorrectionRecord>, limit: Option<usize>) {
336    let Some(limit) = limit else {
337        return;
338    };
339    if records.len() <= limit {
340        return;
341    }
342    let keep_from = records.len() - limit;
343    records.drain(0..keep_from);
344}
345
346fn required_json(
347    object: &serde_json::Map<String, serde_json::Value>,
348    field: &str,
349) -> Result<serde_json::Value, String> {
350    object
351        .get(field)
352        .cloned()
353        .ok_or_else(|| format!("correction record missing field `{field}`"))
354}
355
356fn required_string(
357    object: &serde_json::Map<String, serde_json::Value>,
358    field: &str,
359) -> Result<String, String> {
360    optional_string(object, field)
361        .ok_or_else(|| format!("correction record missing string field `{field}`"))
362}
363
364fn optional_string(
365    object: &serde_json::Map<String, serde_json::Value>,
366    field: &str,
367) -> Option<String> {
368    object
369        .get(field)
370        .and_then(serde_json::Value::as_str)
371        .map(str::to_string)
372}
373
374fn optional_limit(
375    object: &serde_json::Map<String, serde_json::Value>,
376    field: &str,
377) -> Result<Option<usize>, String> {
378    let Some(value) = object.get(field) else {
379        return Ok(None);
380    };
381    if let Some(limit) = value.as_u64() {
382        return usize::try_from(limit)
383            .map(Some)
384            .map_err(|_| format!("correction query field `{field}` is too large"));
385    }
386    if value.as_i64().is_some() {
387        return Err(format!(
388            "correction query field `{field}` must be non-negative"
389        ));
390    }
391    Err(format!(
392        "correction query field `{field}` must be an integer"
393    ))
394}
395
396fn extract_string_anywhere(
397    first: &serde_json::Value,
398    second: &serde_json::Value,
399    keys: &[&str],
400) -> Option<String> {
401    keys.iter()
402        .find_map(|key| extract_string(first, key).or_else(|| extract_string(second, key)))
403}
404
405fn extract_string(value: &serde_json::Value, key: &str) -> Option<String> {
406    value
407        .as_object()
408        .and_then(|object| object.get(key))
409        .and_then(serde_json::Value::as_str)
410        .map(str::to_string)
411        .or_else(|| {
412            value
413                .get("metadata")
414                .and_then(|metadata| metadata.get(key))
415                .and_then(serde_json::Value::as_str)
416                .map(str::to_string)
417        })
418}
419
420fn parse_rfc3339(value: &str) -> Result<OffsetDateTime, String> {
421    OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339)
422        .map_err(|error| format!("invalid RFC3339 timestamp '{value}': {error}"))
423}
424
// Unit tests for recording, querying and applying corrections.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::MemoryEventLog;

    // Appends one `this_persona` correction to an in-memory log, reads it back
    // through the actor/action filters, and checks that it caps the actor's
    // policy at `read_only`.
    #[tokio::test]
    async fn append_query_and_policy_round_trip() {
        let log: Arc<AnyEventLog> = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(16)));
        // `actor_id` and `action` are picked up from the `from_decision` payload.
        let mut record = CorrectionRecord::new(
            serde_json::json!({
                "actor_id": "bot",
                "action": "github.issue.opened",
                "outcome": "success"
            }),
            serde_json::json!({"outcome": "denied"}),
            "operator corrected unsafe issue triage",
            "alice",
            CorrectionScope::ThisPersona,
        );
        record.trace_id = Some("trace-correction".to_string());
        append_correction_record(&log, &record).await.unwrap();

        let records = query_correction_records(
            &log,
            &CorrectionQueryFilters {
                actor_id: Some("bot".to_string()),
                action: Some("github.issue.opened".to_string()),
                ..CorrectionQueryFilters::default()
            },
        )
        .await
        .unwrap();

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].applied_by, "alice");
        assert_eq!(records[0].scope, CorrectionScope::ThisPersona);

        let policy = CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        };
        let adapted = apply_corrections_to_policy(&log, "bot", policy)
            .await
            .unwrap();
        // The persona-scoped correction lowers the level from `network`.
        assert_eq!(adapted.side_effect_level.as_deref(), Some("read_only"));
    }

    // A `this_run` correction for the same actor must leave the policy as-is.
    #[test]
    fn this_run_scope_does_not_change_persona_policy() {
        let record = CorrectionRecord::new(
            serde_json::json!({"actor_id": "bot", "action": "deploy"}),
            serde_json::json!({"outcome": "denied"}),
            "single-run correction",
            "alice",
            CorrectionScope::ThisRun,
        );
        let policy = CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        };

        let adapted = policy_with_corrections("bot", policy, &[record]);

        assert_eq!(adapted.side_effect_level.as_deref(), Some("network"));
    }

    // The JSON parsers report wrongly typed optional fields instead of
    // ignoring them.
    #[test]
    fn correction_parsers_reject_malformed_optional_fields() {
        // `evidence_refs` must be a list, not a string.
        let record_error = correction_record_from_json(serde_json::json!({
            "from_decision": {"actor_id": "bot"},
            "to_decision": {"outcome": "denied"},
            "reason": "operator corrected routing",
            "applied_by": "alice",
            "evidence_refs": "not-a-list"
        }))
        .expect_err("invalid evidence_refs");
        assert!(record_error.contains("evidence_refs"));

        // A negative `limit` gets the dedicated "non-negative" message.
        let limit_error = correction_query_filters_from_json(serde_json::json!({
            "limit": -1
        }))
        .expect_err("invalid limit");
        assert!(limit_error.contains("non-negative"));
    }
}
509}