1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6use uuid::Uuid;
7
8use crate::event_log::{AnyEventLog, EventLog, LogError, LogEvent, Topic};
9use crate::orchestration::CapabilityPolicy;
10
/// Schema identifier written into the `schema` field of every record built by
/// `CorrectionRecord::new`.
pub const CORRECTION_SCHEMA_V0: &str = "harn-correction/v0";

/// Name of the event-log topic that correction records are appended to and read from.
pub const CORRECTIONS_TOPIC: &str = "corrections.records";

/// Event kind attached to every appended correction; the query path skips any other kind.
pub const CORRECTION_EVENT_KIND: &str = "correction_recorded";
14
15#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum CorrectionScope {
18 #[default]
19 ThisRun,
20 ThisPersona,
21 All,
22}
23
24impl CorrectionScope {
25 pub fn as_str(self) -> &'static str {
26 match self {
27 Self::ThisRun => "this_run",
28 Self::ThisPersona => "this_persona",
29 Self::All => "all",
30 }
31 }
32
33 pub fn parse(value: &str) -> Result<Self, String> {
34 match value {
35 "this_run" => Ok(Self::ThisRun),
36 "this_persona" => Ok(Self::ThisPersona),
37 "all" => Ok(Self::All),
38 other => Err(format!(
39 "unsupported correction scope '{other}', expected this_run|this_persona|all"
40 )),
41 }
42 }
43}
44
45#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
46pub struct CorrectionRecord {
47 pub schema: String,
48 pub correction_id: String,
49 pub from_decision: serde_json::Value,
50 pub to_decision: serde_json::Value,
51 pub reason: String,
52 pub applied_by: String,
53 pub scope: CorrectionScope,
54 #[serde(with = "time::serde::rfc3339")]
55 pub timestamp: OffsetDateTime,
56 #[serde(default, skip_serializing_if = "Option::is_none")]
57 pub actor_id: Option<String>,
58 #[serde(default, skip_serializing_if = "Option::is_none")]
59 pub action: Option<String>,
60 #[serde(default, skip_serializing_if = "Option::is_none")]
61 pub trace_id: Option<String>,
62 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub step: Option<String>,
64 #[serde(default, skip_serializing_if = "Vec::is_empty")]
65 pub evidence_refs: Vec<serde_json::Value>,
66 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
67 pub metadata: BTreeMap<String, serde_json::Value>,
68}
69
70impl CorrectionRecord {
71 pub fn new(
72 from_decision: serde_json::Value,
73 to_decision: serde_json::Value,
74 reason: impl Into<String>,
75 applied_by: impl Into<String>,
76 scope: CorrectionScope,
77 ) -> Self {
78 let actor_id = extract_string_anywhere(
79 &from_decision,
80 &to_decision,
81 &["actor_id", "actor", "agent", "trigger_id", "binding_id"],
82 );
83 let action =
84 extract_string_anywhere(&from_decision, &to_decision, &["action", "event_kind"]);
85 let trace_id = extract_string_anywhere(&from_decision, &to_decision, &["trace_id"]);
86 Self {
87 schema: CORRECTION_SCHEMA_V0.to_string(),
88 correction_id: Uuid::now_v7().to_string(),
89 from_decision,
90 to_decision,
91 reason: reason.into(),
92 applied_by: applied_by.into(),
93 scope,
94 timestamp: OffsetDateTime::now_utc(),
95 actor_id,
96 action,
97 trace_id,
98 step: None,
99 evidence_refs: Vec::new(),
100 metadata: BTreeMap::new(),
101 }
102 }
103}
104
105#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
106#[serde(default)]
107pub struct CorrectionQueryFilters {
108 pub actor_id: Option<String>,
109 pub action: Option<String>,
110 pub scope: Option<CorrectionScope>,
111 #[serde(with = "time::serde::rfc3339::option")]
112 pub since: Option<OffsetDateTime>,
113 #[serde(with = "time::serde::rfc3339::option")]
114 pub until: Option<OffsetDateTime>,
115 pub limit: Option<usize>,
116}
117
118pub fn corrections_topic() -> Result<Topic, LogError> {
119 Topic::new(CORRECTIONS_TOPIC)
120}
121
122pub async fn append_correction_record(
123 log: &Arc<AnyEventLog>,
124 record: &CorrectionRecord,
125) -> Result<CorrectionRecord, LogError> {
126 let payload = serde_json::to_value(record)
127 .map_err(|error| LogError::Serde(format!("correction record encode error: {error}")))?;
128 let mut headers = BTreeMap::new();
129 headers.insert("correction_id".to_string(), record.correction_id.clone());
130 headers.insert("scope".to_string(), record.scope.as_str().to_string());
131 if let Some(actor_id) = record.actor_id.as_ref() {
132 headers.insert("actor_id".to_string(), actor_id.clone());
133 }
134 if let Some(action) = record.action.as_ref() {
135 headers.insert("action".to_string(), action.clone());
136 }
137 if let Some(trace_id) = record.trace_id.as_ref() {
138 headers.insert("trace_id".to_string(), trace_id.clone());
139 }
140 log.append(
141 &corrections_topic()?,
142 LogEvent::new(CORRECTION_EVENT_KIND, payload).with_headers(headers),
143 )
144 .await?;
145 Ok(record.clone())
146}
147
148pub async fn query_correction_records(
149 log: &Arc<AnyEventLog>,
150 filters: &CorrectionQueryFilters,
151) -> Result<Vec<CorrectionRecord>, LogError> {
152 let mut records = Vec::new();
153 let mut seen = HashSet::new();
154 for (_, event) in log
155 .read_range(&corrections_topic()?, None, usize::MAX)
156 .await?
157 {
158 if event.kind != CORRECTION_EVENT_KIND {
159 continue;
160 }
161 let Ok(record) = serde_json::from_value::<CorrectionRecord>(event.payload) else {
162 continue;
163 };
164 if !matches_filters(&record, filters) {
165 continue;
166 }
167 if seen.insert(record.correction_id.clone()) {
168 records.push(record);
169 }
170 }
171 records.sort_by(|left, right| {
172 left.timestamp
173 .cmp(&right.timestamp)
174 .then(left.correction_id.cmp(&right.correction_id))
175 });
176 apply_limit(&mut records, filters.limit);
177 Ok(records)
178}
179
180pub async fn apply_corrections_to_policy(
181 log: &Arc<AnyEventLog>,
182 actor_id: &str,
183 policy: CapabilityPolicy,
184) -> Result<CapabilityPolicy, LogError> {
185 let corrections = query_correction_records(
186 log,
187 &CorrectionQueryFilters {
188 actor_id: Some(actor_id.to_string()),
189 ..CorrectionQueryFilters::default()
190 },
191 )
192 .await?;
193 Ok(policy_with_corrections(actor_id, policy, &corrections))
194}
195
196pub fn policy_with_corrections(
197 actor_id: &str,
198 mut policy: CapabilityPolicy,
199 corrections: &[CorrectionRecord],
200) -> CapabilityPolicy {
201 let should_tighten = corrections.iter().any(|record| {
202 correction_applies_to_actor(record, actor_id)
203 && matches!(
204 record.scope,
205 CorrectionScope::ThisPersona | CorrectionScope::All
206 )
207 });
208 if should_tighten {
209 cap_side_effect_level(&mut policy, "read_only");
210 }
211 policy
212}
213
214pub fn correction_record_from_json(value: serde_json::Value) -> Result<CorrectionRecord, String> {
215 let object = value
216 .as_object()
217 .ok_or_else(|| "correction record must be a dict".to_string())?;
218 let from_decision = required_json(object, "from_decision")?;
219 let to_decision = required_json(object, "to_decision")?;
220 let reason = required_string(object, "reason")?;
221 let applied_by = required_string(object, "applied_by")?;
222 let scope = object
223 .get("scope")
224 .and_then(serde_json::Value::as_str)
225 .map(CorrectionScope::parse)
226 .transpose()?
227 .unwrap_or_default();
228
229 let mut record = CorrectionRecord::new(
230 from_decision.clone(),
231 to_decision.clone(),
232 reason,
233 applied_by,
234 scope,
235 );
236 if let Some(actor_id) = optional_string(object, "actor_id")
237 .or_else(|| optional_string(object, "actor"))
238 .or_else(|| optional_string(object, "agent"))
239 {
240 record.actor_id = Some(actor_id);
241 }
242 if let Some(action) = optional_string(object, "action") {
243 record.action = Some(action);
244 }
245 if let Some(trace_id) = optional_string(object, "trace_id") {
246 record.trace_id = Some(trace_id);
247 }
248 record.step = optional_string(object, "step");
249 record.evidence_refs = match object.get("evidence_refs") {
250 Some(value) => value
251 .as_array()
252 .cloned()
253 .ok_or_else(|| "correction record field `evidence_refs` must be a list".to_string())?,
254 None => Vec::new(),
255 };
256 if let Some(metadata) = object.get("metadata") {
257 let Some(metadata_object) = metadata.as_object() else {
258 return Err("correction record field `metadata` must be a dict".to_string());
259 };
260 record.metadata.extend(
261 metadata_object
262 .iter()
263 .map(|(key, value)| (key.clone(), value.clone())),
264 );
265 }
266 Ok(record)
267}
268
269pub fn correction_query_filters_from_json(
270 value: serde_json::Value,
271) -> Result<CorrectionQueryFilters, String> {
272 let object = value
273 .as_object()
274 .ok_or_else(|| "correction query filters must be a dict".to_string())?;
275 let scope = object
276 .get("scope")
277 .and_then(serde_json::Value::as_str)
278 .map(CorrectionScope::parse)
279 .transpose()?;
280 Ok(CorrectionQueryFilters {
281 actor_id: optional_string(object, "actor_id")
282 .or_else(|| optional_string(object, "actor"))
283 .or_else(|| optional_string(object, "agent")),
284 action: optional_string(object, "action"),
285 scope,
286 since: object
287 .get("since")
288 .and_then(serde_json::Value::as_str)
289 .map(parse_rfc3339)
290 .transpose()?,
291 until: object
292 .get("until")
293 .and_then(serde_json::Value::as_str)
294 .map(parse_rfc3339)
295 .transpose()?,
296 limit: optional_limit(object, "limit")?,
297 })
298}
299
300fn matches_filters(record: &CorrectionRecord, filters: &CorrectionQueryFilters) -> bool {
301 if let Some(actor_id) = filters.actor_id.as_deref() {
302 if !correction_applies_to_actor(record, actor_id) {
303 return false;
304 }
305 }
306 if let Some(action) = filters.action.as_deref() {
307 if record.action.as_deref() != Some(action) {
308 return false;
309 }
310 }
311 if let Some(scope) = filters.scope {
312 if record.scope != scope {
313 return false;
314 }
315 }
316 if let Some(since) = filters.since {
317 if record.timestamp < since {
318 return false;
319 }
320 }
321 if let Some(until) = filters.until {
322 if record.timestamp > until {
323 return false;
324 }
325 }
326 true
327}
328
329fn correction_applies_to_actor(record: &CorrectionRecord, actor_id: &str) -> bool {
330 record.scope == CorrectionScope::All || record.actor_id.as_deref() == Some(actor_id)
331}
332
333fn cap_side_effect_level(policy: &mut CapabilityPolicy, ceiling: &str) {
334 let current = policy.side_effect_level.as_deref().unwrap_or("network");
335 if side_effect_rank(current) > side_effect_rank(ceiling) {
336 policy.side_effect_level = Some(ceiling.to_string());
337 }
338}
339
/// Maps a side-effect level name to its rank, where a higher rank ranks above a lower one.
///
/// Known levels rank 0 through 4 in the order listed below. Any other name,
/// including the empty string, ranks 5 (above every known level).
fn side_effect_rank(value: &str) -> usize {
    // Position in this table is the rank.
    const LEVELS: [&str; 5] = [
        "none",
        "read_only",
        "workspace_write",
        "process_exec",
        "network",
    ];
    LEVELS
        .iter()
        .position(|level| *level == value)
        .unwrap_or(LEVELS.len())
}
350
351fn apply_limit(records: &mut Vec<CorrectionRecord>, limit: Option<usize>) {
352 let Some(limit) = limit else {
353 return;
354 };
355 if records.len() <= limit {
356 return;
357 }
358 let keep_from = records.len() - limit;
359 records.drain(0..keep_from);
360}
361
362fn required_json(
363 object: &serde_json::Map<String, serde_json::Value>,
364 field: &str,
365) -> Result<serde_json::Value, String> {
366 object
367 .get(field)
368 .cloned()
369 .ok_or_else(|| format!("correction record missing field `{field}`"))
370}
371
372fn required_string(
373 object: &serde_json::Map<String, serde_json::Value>,
374 field: &str,
375) -> Result<String, String> {
376 optional_string(object, field)
377 .ok_or_else(|| format!("correction record missing string field `{field}`"))
378}
379
380fn optional_string(
381 object: &serde_json::Map<String, serde_json::Value>,
382 field: &str,
383) -> Option<String> {
384 object
385 .get(field)
386 .and_then(serde_json::Value::as_str)
387 .map(str::to_string)
388}
389
390fn optional_limit(
391 object: &serde_json::Map<String, serde_json::Value>,
392 field: &str,
393) -> Result<Option<usize>, String> {
394 let Some(value) = object.get(field) else {
395 return Ok(None);
396 };
397 if let Some(limit) = value.as_u64() {
398 return usize::try_from(limit)
399 .map(Some)
400 .map_err(|_| format!("correction query field `{field}` is too large"));
401 }
402 if value.as_i64().is_some() {
403 return Err(format!(
404 "correction query field `{field}` must be non-negative"
405 ));
406 }
407 Err(format!(
408 "correction query field `{field}` must be an integer"
409 ))
410}
411
412fn extract_string_anywhere(
413 first: &serde_json::Value,
414 second: &serde_json::Value,
415 keys: &[&str],
416) -> Option<String> {
417 keys.iter()
418 .find_map(|key| extract_string(first, key).or_else(|| extract_string(second, key)))
419}
420
421fn extract_string(value: &serde_json::Value, key: &str) -> Option<String> {
422 value
423 .as_object()
424 .and_then(|object| object.get(key))
425 .and_then(serde_json::Value::as_str)
426 .map(str::to_string)
427 .or_else(|| {
428 value
429 .get("metadata")
430 .and_then(|metadata| metadata.get(key))
431 .and_then(serde_json::Value::as_str)
432 .map(str::to_string)
433 })
434}
435
436fn parse_rfc3339(value: &str) -> Result<OffsetDateTime, String> {
437 OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339)
438 .map_err(|error| format!("invalid RFC3339 timestamp '{value}': {error}"))
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::MemoryEventLog;

    /// Policy starting at the `network` side-effect level.
    fn network_policy() -> CapabilityPolicy {
        CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        }
    }

    #[tokio::test]
    async fn append_query_and_policy_round_trip() {
        let log: Arc<AnyEventLog> = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(16)));

        // Record a persona-scoped correction for actor `bot`.
        let from_decision = serde_json::json!({
            "actor_id": "bot",
            "action": "github.issue.opened",
            "outcome": "success"
        });
        let to_decision = serde_json::json!({"outcome": "denied"});
        let mut record = CorrectionRecord::new(
            from_decision,
            to_decision,
            "operator corrected unsafe issue triage",
            "alice",
            CorrectionScope::ThisPersona,
        );
        record.trace_id = Some("trace-correction".to_string());
        append_correction_record(&log, &record).await.unwrap();

        // It must come back when queried by actor and action.
        let filters = CorrectionQueryFilters {
            actor_id: Some("bot".to_string()),
            action: Some("github.issue.opened".to_string()),
            ..CorrectionQueryFilters::default()
        };
        let records = query_correction_records(&log, &filters).await.unwrap();

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].applied_by, "alice");
        assert_eq!(records[0].scope, CorrectionScope::ThisPersona);

        // And it must cap the actor's policy at `read_only`.
        let adapted = apply_corrections_to_policy(&log, "bot", network_policy())
            .await
            .unwrap();
        assert_eq!(adapted.side_effect_level.as_deref(), Some("read_only"));
    }

    #[test]
    fn this_run_scope_does_not_change_persona_policy() {
        let record = CorrectionRecord::new(
            serde_json::json!({"actor_id": "bot", "action": "deploy"}),
            serde_json::json!({"outcome": "denied"}),
            "single-run correction",
            "alice",
            CorrectionScope::ThisRun,
        );

        let adapted = policy_with_corrections("bot", network_policy(), &[record]);

        assert_eq!(adapted.side_effect_level.as_deref(), Some("network"));
    }

    #[test]
    fn correction_parsers_reject_malformed_optional_fields() {
        let malformed_record = serde_json::json!({
            "from_decision": {"actor_id": "bot"},
            "to_decision": {"outcome": "denied"},
            "reason": "operator corrected routing",
            "applied_by": "alice",
            "evidence_refs": "not-a-list"
        });
        let record_error =
            correction_record_from_json(malformed_record).expect_err("invalid evidence_refs");
        assert!(record_error.contains("evidence_refs"));

        let malformed_filters = serde_json::json!({
            "limit": -1
        });
        let limit_error =
            correction_query_filters_from_json(malformed_filters).expect_err("invalid limit");
        assert!(limit_error.contains("non-negative"));
    }
}