1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6use uuid::Uuid;
7
8use crate::event_log::{AnyEventLog, EventLog, LogError, LogEvent, Topic};
9use crate::orchestration::CapabilityPolicy;
10
/// Schema tag that [`CorrectionRecord::new`] writes into [`CorrectionRecord::schema`].
pub const CORRECTION_SCHEMA_V0: &str = "harn-correction/v0";
/// Event-log topic that correction records are appended to and read back from.
pub const CORRECTIONS_TOPIC: &str = "corrections.records";
/// Event kind used for correction events; events of any other kind on the topic are skipped
/// when querying.
pub const CORRECTION_EVENT_KIND: &str = "correction_recorded";
14
/// How widely a correction applies.
///
/// Serialized in `snake_case` (`this_run`, `this_persona`, `all`). The default is
/// [`CorrectionScope::ThisRun`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CorrectionScope {
    /// Presumably limited to the corrected run (by name). It never tightens persona policy
    /// in `policy_with_corrections`.
    #[default]
    ThisRun,
    /// Applies to the record's `actor_id`; tightens that actor's policy.
    ThisPersona,
    /// Applies to every actor regardless of the record's `actor_id`; tightens every policy.
    All,
}
23
24impl CorrectionScope {
25 pub fn as_str(self) -> &'static str {
26 match self {
27 Self::ThisRun => "this_run",
28 Self::ThisPersona => "this_persona",
29 Self::All => "all",
30 }
31 }
32
33 pub fn parse(value: &str) -> Result<Self, String> {
34 match value {
35 "this_run" => Ok(Self::ThisRun),
36 "this_persona" => Ok(Self::ThisPersona),
37 "all" => Ok(Self::All),
38 other => Err(format!(
39 "unsupported correction scope '{other}', expected this_run|this_persona|all"
40 )),
41 }
42 }
43}
44
/// A single correction: the decision that was made, what it was changed to, and why.
///
/// Records are appended to the [`CORRECTIONS_TOPIC`] topic by `append_correction_record`
/// and read back by `query_correction_records`.
// NOTE: serde serializes fields in declaration order, so reordering fields changes the
// stored JSON layout — keep the order stable.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrectionRecord {
    /// Schema tag; `CorrectionRecord::new` sets it to [`CORRECTION_SCHEMA_V0`].
    pub schema: String,
    /// Record id (a UUIDv7 string when created via `new`); queries de-duplicate on it.
    pub correction_id: String,
    /// The original decision, as arbitrary JSON.
    pub from_decision: serde_json::Value,
    /// The corrected decision, as arbitrary JSON.
    pub to_decision: serde_json::Value,
    /// Free-text explanation for the correction.
    pub reason: String,
    /// Who applied the correction (e.g. an operator name).
    pub applied_by: String,
    /// How widely the correction applies; drives policy tightening.
    pub scope: CorrectionScope,
    /// Creation time, serialized as RFC 3339; `new` uses the current UTC time.
    #[serde(with = "time::serde::rfc3339")]
    pub timestamp: OffsetDateTime,
    /// Actor the correction targets. `new` extracts it from the decisions; it is used for
    /// actor filtering and policy tightening. Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actor_id: Option<String>,
    /// Action / event kind the decision concerned. Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub action: Option<String>,
    /// Trace id linking the correction to the traced run. Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<String>,
    /// Optional step label; `new` leaves it unset. Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub step: Option<String>,
    /// Opaque references to supporting evidence. Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub evidence_refs: Vec<serde_json::Value>,
    /// Additional free-form key/value data. Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
69
70impl CorrectionRecord {
71 pub fn new(
72 from_decision: serde_json::Value,
73 to_decision: serde_json::Value,
74 reason: impl Into<String>,
75 applied_by: impl Into<String>,
76 scope: CorrectionScope,
77 ) -> Self {
78 let actor_id = extract_string_anywhere(
79 &from_decision,
80 &to_decision,
81 &["actor_id", "actor", "agent", "trigger_id", "binding_id"],
82 );
83 let action =
84 extract_string_anywhere(&from_decision, &to_decision, &["action", "event_kind"]);
85 let trace_id = extract_string_anywhere(&from_decision, &to_decision, &["trace_id"]);
86 Self {
87 schema: CORRECTION_SCHEMA_V0.to_string(),
88 correction_id: Uuid::now_v7().to_string(),
89 from_decision,
90 to_decision,
91 reason: reason.into(),
92 applied_by: applied_by.into(),
93 scope,
94 timestamp: OffsetDateTime::now_utc(),
95 actor_id,
96 action,
97 trace_id,
98 step: None,
99 evidence_refs: Vec::new(),
100 metadata: BTreeMap::new(),
101 }
102 }
103}
104
/// Filters accepted by `query_correction_records`.
///
/// Every field is optional and `None` means "no constraint". Fields missing from serialized
/// input take their `Default` value (`#[serde(default)]`).
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct CorrectionQueryFilters {
    /// Keep records whose `actor_id` equals this, plus every `All`-scoped record.
    pub actor_id: Option<String>,
    /// Keep records whose `action` equals this exactly.
    pub action: Option<String>,
    /// Keep records with exactly this scope.
    pub scope: Option<CorrectionScope>,
    /// Inclusive lower bound on `timestamp` (RFC 3339 when serialized).
    #[serde(with = "time::serde::rfc3339::option")]
    pub since: Option<OffsetDateTime>,
    /// Inclusive upper bound on `timestamp` (RFC 3339 when serialized).
    #[serde(with = "time::serde::rfc3339::option")]
    pub until: Option<OffsetDateTime>,
    /// Keep at most this many results, retaining the most recent ones.
    pub limit: Option<usize>,
}
117
118pub fn corrections_topic() -> Result<Topic, LogError> {
119 Topic::new(CORRECTIONS_TOPIC)
120}
121
122pub async fn append_correction_record(
123 log: &Arc<AnyEventLog>,
124 record: &CorrectionRecord,
125) -> Result<CorrectionRecord, LogError> {
126 let payload = serde_json::to_value(record)
127 .map_err(|error| LogError::Serde(format!("correction record encode error: {error}")))?;
128 let mut headers = BTreeMap::new();
129 headers.insert("correction_id".to_string(), record.correction_id.clone());
130 headers.insert("scope".to_string(), record.scope.as_str().to_string());
131 if let Some(actor_id) = record.actor_id.as_ref() {
132 headers.insert("actor_id".to_string(), actor_id.clone());
133 }
134 if let Some(action) = record.action.as_ref() {
135 headers.insert("action".to_string(), action.clone());
136 }
137 if let Some(trace_id) = record.trace_id.as_ref() {
138 headers.insert("trace_id".to_string(), trace_id.clone());
139 }
140 log.append(
141 &corrections_topic()?,
142 LogEvent::new(CORRECTION_EVENT_KIND, payload).with_headers(headers),
143 )
144 .await?;
145 Ok(record.clone())
146}
147
148pub async fn query_correction_records(
149 log: &Arc<AnyEventLog>,
150 filters: &CorrectionQueryFilters,
151) -> Result<Vec<CorrectionRecord>, LogError> {
152 let mut records = Vec::new();
153 let mut seen = HashSet::new();
154 for (_, event) in log
155 .read_range(&corrections_topic()?, None, usize::MAX)
156 .await?
157 {
158 if event.kind != CORRECTION_EVENT_KIND {
159 continue;
160 }
161 let Ok(record) = serde_json::from_value::<CorrectionRecord>(event.payload) else {
162 continue;
163 };
164 if !matches_filters(&record, filters) {
165 continue;
166 }
167 if seen.insert(record.correction_id.clone()) {
168 records.push(record);
169 }
170 }
171 records.sort_by(|left, right| {
172 left.timestamp
173 .cmp(&right.timestamp)
174 .then(left.correction_id.cmp(&right.correction_id))
175 });
176 apply_limit(&mut records, filters.limit);
177 Ok(records)
178}
179
180pub async fn apply_corrections_to_policy(
181 log: &Arc<AnyEventLog>,
182 actor_id: &str,
183 policy: CapabilityPolicy,
184) -> Result<CapabilityPolicy, LogError> {
185 let corrections = query_correction_records(
186 log,
187 &CorrectionQueryFilters {
188 actor_id: Some(actor_id.to_string()),
189 ..CorrectionQueryFilters::default()
190 },
191 )
192 .await?;
193 Ok(policy_with_corrections(actor_id, policy, &corrections))
194}
195
196pub fn policy_with_corrections(
197 actor_id: &str,
198 mut policy: CapabilityPolicy,
199 corrections: &[CorrectionRecord],
200) -> CapabilityPolicy {
201 let should_tighten = corrections.iter().any(|record| {
202 correction_applies_to_actor(record, actor_id)
203 && matches!(
204 record.scope,
205 CorrectionScope::ThisPersona | CorrectionScope::All
206 )
207 });
208 if should_tighten {
209 cap_side_effect_level(&mut policy, "read_only");
210 }
211 policy
212}
213
214pub fn correction_record_from_json(value: serde_json::Value) -> Result<CorrectionRecord, String> {
215 let object = value
216 .as_object()
217 .ok_or_else(|| "correction record must be a dict".to_string())?;
218 let from_decision = required_json(object, "from_decision")?;
219 let to_decision = required_json(object, "to_decision")?;
220 let reason = required_string(object, "reason")?;
221 let applied_by = required_string(object, "applied_by")?;
222 let scope = object
223 .get("scope")
224 .and_then(serde_json::Value::as_str)
225 .map(CorrectionScope::parse)
226 .transpose()?
227 .unwrap_or_default();
228
229 let mut record = CorrectionRecord::new(from_decision, to_decision, reason, applied_by, scope);
230 if let Some(actor_id) = optional_string(object, "actor_id")
231 .or_else(|| optional_string(object, "actor"))
232 .or_else(|| optional_string(object, "agent"))
233 {
234 record.actor_id = Some(actor_id);
235 }
236 if let Some(action) = optional_string(object, "action") {
237 record.action = Some(action);
238 }
239 if let Some(trace_id) = optional_string(object, "trace_id") {
240 record.trace_id = Some(trace_id);
241 }
242 record.step = optional_string(object, "step");
243 record.evidence_refs = match object.get("evidence_refs") {
244 Some(value) => value
245 .as_array()
246 .cloned()
247 .ok_or_else(|| "correction record field `evidence_refs` must be a list".to_string())?,
248 None => Vec::new(),
249 };
250 if let Some(metadata) = object.get("metadata") {
251 let Some(metadata_object) = metadata.as_object() else {
252 return Err("correction record field `metadata` must be a dict".to_string());
253 };
254 record.metadata.extend(
255 metadata_object
256 .iter()
257 .map(|(key, value)| (key.clone(), value.clone())),
258 );
259 }
260 Ok(record)
261}
262
263pub fn correction_query_filters_from_json(
264 value: serde_json::Value,
265) -> Result<CorrectionQueryFilters, String> {
266 let object = value
267 .as_object()
268 .ok_or_else(|| "correction query filters must be a dict".to_string())?;
269 let scope = object
270 .get("scope")
271 .and_then(serde_json::Value::as_str)
272 .map(CorrectionScope::parse)
273 .transpose()?;
274 Ok(CorrectionQueryFilters {
275 actor_id: optional_string(object, "actor_id")
276 .or_else(|| optional_string(object, "actor"))
277 .or_else(|| optional_string(object, "agent")),
278 action: optional_string(object, "action"),
279 scope,
280 since: object
281 .get("since")
282 .and_then(serde_json::Value::as_str)
283 .map(parse_rfc3339)
284 .transpose()?,
285 until: object
286 .get("until")
287 .and_then(serde_json::Value::as_str)
288 .map(parse_rfc3339)
289 .transpose()?,
290 limit: optional_limit(object, "limit")?,
291 })
292}
293
294fn matches_filters(record: &CorrectionRecord, filters: &CorrectionQueryFilters) -> bool {
295 if let Some(actor_id) = filters.actor_id.as_deref() {
296 if !correction_applies_to_actor(record, actor_id) {
297 return false;
298 }
299 }
300 if let Some(action) = filters.action.as_deref() {
301 if record.action.as_deref() != Some(action) {
302 return false;
303 }
304 }
305 if let Some(scope) = filters.scope {
306 if record.scope != scope {
307 return false;
308 }
309 }
310 if let Some(since) = filters.since {
311 if record.timestamp < since {
312 return false;
313 }
314 }
315 if let Some(until) = filters.until {
316 if record.timestamp > until {
317 return false;
318 }
319 }
320 true
321}
322
323fn correction_applies_to_actor(record: &CorrectionRecord, actor_id: &str) -> bool {
324 record.scope == CorrectionScope::All || record.actor_id.as_deref() == Some(actor_id)
325}
326
327fn cap_side_effect_level(policy: &mut CapabilityPolicy, ceiling: &str) {
328 let current = policy.side_effect_level.as_deref().unwrap_or("network");
329 if side_effect_rank(current) > side_effect_rank(ceiling) {
330 policy.side_effect_level = Some(ceiling.to_string());
331 }
332}
333
/// Maps a side-effect level name to its privilege rank (higher = more permissive).
///
/// Unrecognised names rank above `"network"` (rank 5), so capping treats them conservatively.
fn side_effect_rank(value: &str) -> usize {
    // Ordered from least to most privileged; a level's index is its rank.
    const LEVELS: [&str; 5] = [
        "none",
        "read_only",
        "workspace_write",
        "process_exec",
        "network",
    ];
    LEVELS
        .iter()
        .position(|level| *level == value)
        .unwrap_or(LEVELS.len())
}
344
/// Truncates `records` in place so that at most `limit` entries remain.
///
/// The trailing entries are kept; callers sort oldest-first, so these are the newest.
/// `None` means unlimited, and `Some(0)` empties the vector. The function is generic so it
/// works for any element type; existing `Vec<CorrectionRecord>` callers are unaffected.
fn apply_limit<T>(records: &mut Vec<T>, limit: Option<usize>) {
    if let Some(limit) = limit {
        // Number of leading entries to discard; zero when already within the limit.
        let excess = records.len().saturating_sub(limit);
        records.drain(..excess);
    }
}
355
356fn required_json(
357 object: &serde_json::Map<String, serde_json::Value>,
358 field: &str,
359) -> Result<serde_json::Value, String> {
360 object
361 .get(field)
362 .cloned()
363 .ok_or_else(|| format!("correction record missing field `{field}`"))
364}
365
366fn required_string(
367 object: &serde_json::Map<String, serde_json::Value>,
368 field: &str,
369) -> Result<String, String> {
370 optional_string(object, field)
371 .ok_or_else(|| format!("correction record missing string field `{field}`"))
372}
373
374fn optional_string(
375 object: &serde_json::Map<String, serde_json::Value>,
376 field: &str,
377) -> Option<String> {
378 object
379 .get(field)
380 .and_then(serde_json::Value::as_str)
381 .map(str::to_string)
382}
383
384fn optional_limit(
385 object: &serde_json::Map<String, serde_json::Value>,
386 field: &str,
387) -> Result<Option<usize>, String> {
388 let Some(value) = object.get(field) else {
389 return Ok(None);
390 };
391 if let Some(limit) = value.as_u64() {
392 return usize::try_from(limit)
393 .map(Some)
394 .map_err(|_| format!("correction query field `{field}` is too large"));
395 }
396 if value.as_i64().is_some() {
397 return Err(format!(
398 "correction query field `{field}` must be non-negative"
399 ));
400 }
401 Err(format!(
402 "correction query field `{field}` must be an integer"
403 ))
404}
405
406fn extract_string_anywhere(
407 first: &serde_json::Value,
408 second: &serde_json::Value,
409 keys: &[&str],
410) -> Option<String> {
411 keys.iter()
412 .find_map(|key| extract_string(first, key).or_else(|| extract_string(second, key)))
413}
414
415fn extract_string(value: &serde_json::Value, key: &str) -> Option<String> {
416 value
417 .as_object()
418 .and_then(|object| object.get(key))
419 .and_then(serde_json::Value::as_str)
420 .map(str::to_string)
421 .or_else(|| {
422 value
423 .get("metadata")
424 .and_then(|metadata| metadata.get(key))
425 .and_then(serde_json::Value::as_str)
426 .map(str::to_string)
427 })
428}
429
430fn parse_rfc3339(value: &str) -> Result<OffsetDateTime, String> {
431 OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339)
432 .map_err(|error| format!("invalid RFC3339 timestamp '{value}': {error}"))
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::MemoryEventLog;

    // End-to-end: append a persona-scoped correction, query it back by actor + action, and
    // check that it caps the actor's policy at read-only.
    #[tokio::test]
    async fn append_query_and_policy_round_trip() {
        let log: Arc<AnyEventLog> = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(16)));
        // `actor_id` and `action` are extracted from `from_decision` by `CorrectionRecord::new`.
        let mut record = CorrectionRecord::new(
            serde_json::json!({
                "actor_id": "bot",
                "action": "github.issue.opened",
                "outcome": "success"
            }),
            serde_json::json!({"outcome": "denied"}),
            "operator corrected unsafe issue triage",
            "alice",
            CorrectionScope::ThisPersona,
        );
        record.trace_id = Some("trace-correction".to_string());
        append_correction_record(&log, &record).await.unwrap();

        let records = query_correction_records(
            &log,
            &CorrectionQueryFilters {
                actor_id: Some("bot".to_string()),
                action: Some("github.issue.opened".to_string()),
                ..CorrectionQueryFilters::default()
            },
        )
        .await
        .unwrap();

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].applied_by, "alice");
        assert_eq!(records[0].scope, CorrectionScope::ThisPersona);

        // A `ThisPersona` correction for "bot" should cap "network" down to "read_only".
        let policy = CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        };
        let adapted = apply_corrections_to_policy(&log, "bot", policy)
            .await
            .unwrap();
        assert_eq!(adapted.side_effect_level.as_deref(), Some("read_only"));
    }

    // `ThisRun` corrections must leave the persona's policy untouched.
    #[test]
    fn this_run_scope_does_not_change_persona_policy() {
        let record = CorrectionRecord::new(
            serde_json::json!({"actor_id": "bot", "action": "deploy"}),
            serde_json::json!({"outcome": "denied"}),
            "single-run correction",
            "alice",
            CorrectionScope::ThisRun,
        );
        let policy = CapabilityPolicy {
            side_effect_level: Some("network".to_string()),
            ..CapabilityPolicy::default()
        };

        let adapted = policy_with_corrections("bot", policy, &[record]);

        assert_eq!(adapted.side_effect_level.as_deref(), Some("network"));
    }

    // Malformed optional fields are reported as errors rather than silently ignored.
    #[test]
    fn correction_parsers_reject_malformed_optional_fields() {
        // `evidence_refs` must be a list.
        let record_error = correction_record_from_json(serde_json::json!({
            "from_decision": {"actor_id": "bot"},
            "to_decision": {"outcome": "denied"},
            "reason": "operator corrected routing",
            "applied_by": "alice",
            "evidence_refs": "not-a-list"
        }))
        .expect_err("invalid evidence_refs");
        assert!(record_error.contains("evidence_refs"));

        // Negative limits are rejected with a dedicated message.
        let limit_error = correction_query_filters_from_json(serde_json::json!({
            "limit": -1
        }))
        .expect_err("invalid limit");
        assert!(limit_error.contains("non-negative"));
    }
}