1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use async_trait::async_trait;
11use rand::TryRngCore;
12use tokio::sync::Mutex;
13
14use crate::engine::Stores;
15use crate::errors::{ErrorCategory, ErrorInfo, IoError};
16use crate::events::{Event, EventEnvelope, FactRecorded, DOMAIN_EVENT_FACT_RECORDED};
17use crate::hashing::{canonical_json_bytes, CanonicalJsonError};
18use crate::ids::{ArtifactId, ErrorCode, FactKey, RunId, StateId};
19use crate::io::{IoCall, IoProvider, IoResult};
20use crate::stores::{ArtifactKind, ArtifactStore};
21
22fn info(code: &'static str, category: ErrorCategory, message: &'static str) -> ErrorInfo {
23 ErrorInfo {
24 code: ErrorCode(code.to_string()),
25 category,
26 retryable: false,
27 message: message.to_string(),
28 details: None,
29 }
30}
31
32fn io_other(code: &'static str, category: ErrorCategory, message: &'static str) -> IoError {
33 IoError::Other(info(code, category, message))
34}
35
36#[derive(Clone, Default)]
40pub struct FactIndex {
41 inner: Arc<Mutex<HashMap<FactKey, ArtifactId>>>,
42}
43
44impl FactIndex {
45 pub fn from_event_stream(stream: &[EventEnvelope]) -> Self {
50 let mut m = HashMap::new();
51 for e in stream {
52 let Event::Domain(de) = &e.event else {
53 continue;
54 };
55 if de.name != DOMAIN_EVENT_FACT_RECORDED {
56 continue;
57 }
58
59 let Ok(fr) = serde_json::from_value::<FactRecorded>(de.payload.clone()) else {
60 continue;
61 };
62
63 m.entry(fr.key).or_insert(fr.payload_id);
65 }
66
67 Self {
68 inner: Arc::new(Mutex::new(m)),
69 }
70 }
71
72 pub async fn get(&self, key: &FactKey) -> Option<ArtifactId> {
74 self.inner.lock().await.get(key).cloned()
75 }
76
77 pub async fn bind_if_unset(&self, key: FactKey, payload_id: ArtifactId) -> (ArtifactId, bool) {
82 let mut inner = self.inner.lock().await;
83 match inner.get(&key) {
84 Some(existing) => (existing.clone(), false),
85 None => {
86 inner.insert(key, payload_id.clone());
87 (payload_id, true)
88 }
89 }
90 }
91
92 pub async fn unbind_if_matches(&self, key: &FactKey, payload_id: &ArtifactId) -> bool {
97 let mut inner = self.inner.lock().await;
98 match inner.get(key) {
99 Some(existing) if existing == payload_id => {
100 inner.remove(key);
101 true
102 }
103 _ => false,
104 }
105 }
106}
107
108#[async_trait]
110pub trait LiveIoTransport: Send {
111 async fn call(&mut self, call: IoCall) -> Result<serde_json::Value, IoError>;
113}
114
115#[derive(Clone)]
117pub struct LiveIoEnv {
118 pub stores: Stores,
120 pub run_id: RunId,
122 pub state_id: StateId,
124 pub attempt: u32,
126}
127
128pub trait LiveIoTransportFactory: Send + Sync {
130 fn namespace_group(&self) -> &str;
132
133 fn make(&self, env: LiveIoEnv) -> Box<dyn LiveIoTransport>;
135}
136
137struct UnimplementedLiveIoTransport;
138
139#[async_trait]
140impl LiveIoTransport for UnimplementedLiveIoTransport {
141 async fn call(&mut self, _call: IoCall) -> Result<serde_json::Value, IoError> {
142 Err(io_other(
143 "io_unimplemented",
144 ErrorCategory::Unknown,
145 "live io transport is not configured",
146 ))
147 }
148}
149
150#[derive(Clone, Default)]
152pub struct UnimplementedLiveIoTransportFactory;
153
154impl LiveIoTransportFactory for UnimplementedLiveIoTransportFactory {
155 fn namespace_group(&self) -> &str {
156 "unimplemented"
157 }
158
159 fn make(&self, _env: LiveIoEnv) -> Box<dyn LiveIoTransport> {
160 Box::new(UnimplementedLiveIoTransport)
161 }
162}
163
164pub struct LiveIo {
166 run_id: RunId,
167 state_id: StateId,
168 attempt: u32,
169 call_ordinal: u64,
170 artifacts: Arc<dyn ArtifactStore>,
171 facts: FactIndex,
172 fact_recorder: Arc<dyn FactRecorder>,
173 transport: Box<dyn LiveIoTransport>,
174}
175
176impl LiveIo {
177 pub fn new(
179 run_id: RunId,
180 state_id: StateId,
181 attempt: u32,
182 artifacts: Arc<dyn ArtifactStore>,
183 facts: FactIndex,
184 fact_recorder: Arc<dyn FactRecorder>,
185 transport: Box<dyn LiveIoTransport>,
186 ) -> Self {
187 Self {
188 run_id,
189 state_id,
190 attempt,
191 call_ordinal: 0,
192 artifacts,
193 facts,
194 fact_recorder,
195 transport,
196 }
197 }
198
199 fn derived_fact_key(&mut self, kind: &str) -> FactKey {
200 let ord = self.call_ordinal;
201 self.call_ordinal += 1;
202 FactKey(format!(
203 "mfm:{kind}|run:{}|state:{}|attempt:{}|ord:{ord}",
204 self.run_id.0,
205 self.state_id.as_str(),
206 self.attempt
207 ))
208 }
209
210 async fn record_fact_json(
211 &mut self,
212 key: FactKey,
213 value: serde_json::Value,
214 ) -> Result<(serde_json::Value, ArtifactId), IoError> {
215 if let Some(payload_id) = self.facts.get(&key).await {
216 let bytes = self.artifacts.get(&payload_id).await.map_err(|_| {
217 io_other(
218 "fact_payload_get_failed",
219 ErrorCategory::Storage,
220 "failed to read fact payload",
221 )
222 })?;
223 let v = serde_json::from_slice::<serde_json::Value>(&bytes).map_err(|_| {
224 io_other(
225 "fact_payload_decode_failed",
226 ErrorCategory::ParsingInput,
227 "failed to decode fact payload",
228 )
229 })?;
230 return Ok((v, payload_id));
231 }
232
233 let bytes = canonical_json_bytes(&value).map_err(|e| match e {
234 CanonicalJsonError::FloatNotAllowed => io_other(
235 "fact_payload_not_canonical",
236 ErrorCategory::ParsingInput,
237 "fact payload is not canonical-json-hashable (floats are forbidden)",
238 ),
239 CanonicalJsonError::SecretsNotAllowed => io_other(
240 "secrets_detected",
241 ErrorCategory::Unknown,
242 "fact payload contained secrets (policy forbids persisting secrets)",
243 ),
244 })?;
245
246 let payload_id = self
247 .artifacts
248 .put(ArtifactKind::FactPayload, bytes)
249 .await
250 .map_err(|_| {
251 io_other(
252 "fact_payload_put_failed",
253 ErrorCategory::Storage,
254 "failed to store fact payload",
255 )
256 })?;
257
258 let (bound_id, inserted) = self.facts.bind_if_unset(key.clone(), payload_id).await;
259 if inserted {
260 if let Err(e) = self
261 .fact_recorder
262 .record_fact_binding(key.clone(), bound_id.clone())
263 .await
264 {
265 let _ = self.facts.unbind_if_matches(&key, &bound_id).await;
267 return Err(e);
268 }
269 Ok((value, bound_id))
270 } else {
271 let bytes = self.artifacts.get(&bound_id).await.map_err(|_| {
273 io_other(
274 "fact_payload_get_failed",
275 ErrorCategory::Storage,
276 "failed to read fact payload",
277 )
278 })?;
279 let v = serde_json::from_slice::<serde_json::Value>(&bytes).map_err(|_| {
280 io_other(
281 "fact_payload_decode_failed",
282 ErrorCategory::ParsingInput,
283 "failed to decode fact payload",
284 )
285 })?;
286 Ok((v, bound_id))
287 }
288 }
289
290 async fn record_fact_bytes(
291 &mut self,
292 key: FactKey,
293 bytes: Vec<u8>,
294 ) -> Result<(Vec<u8>, ArtifactId), IoError> {
295 if let Some(payload_id) = self.facts.get(&key).await {
296 let got = self.artifacts.get(&payload_id).await.map_err(|_| {
297 io_other(
298 "fact_payload_get_failed",
299 ErrorCategory::Storage,
300 "failed to read fact payload",
301 )
302 })?;
303 return Ok((got, payload_id));
304 }
305
306 let payload_id = self
307 .artifacts
308 .put(ArtifactKind::FactPayload, bytes.clone())
309 .await
310 .map_err(|_| {
311 io_other(
312 "fact_payload_put_failed",
313 ErrorCategory::Storage,
314 "failed to store fact payload",
315 )
316 })?;
317
318 let (bound_id, inserted) = self.facts.bind_if_unset(key.clone(), payload_id).await;
319 if inserted {
320 if let Err(e) = self
321 .fact_recorder
322 .record_fact_binding(key.clone(), bound_id.clone())
323 .await
324 {
325 let _ = self.facts.unbind_if_matches(&key, &bound_id).await;
327 return Err(e);
328 }
329 }
330
331 Ok((bytes, bound_id))
332 }
333}
334
335#[async_trait]
336impl IoProvider for LiveIo {
337 async fn call(&mut self, call: IoCall) -> Result<IoResult, IoError> {
338 let Some(key) = call.fact_key.clone() else {
339 let response = self.transport.call(call).await?;
340 return Ok(IoResult {
341 response,
342 recorded_payload_id: None,
343 });
344 };
345
346 if let Some(payload_id) = self.facts.get(&key).await {
347 let bytes = self.artifacts.get(&payload_id).await.map_err(|_| {
348 io_other(
349 "fact_payload_get_failed",
350 ErrorCategory::Storage,
351 "failed to read fact payload",
352 )
353 })?;
354 let response = serde_json::from_slice::<serde_json::Value>(&bytes).map_err(|_| {
355 io_other(
356 "fact_payload_decode_failed",
357 ErrorCategory::ParsingInput,
358 "failed to decode fact payload",
359 )
360 })?;
361 return Ok(IoResult {
362 response,
363 recorded_payload_id: Some(payload_id),
364 });
365 }
366
367 let response = self.transport.call(call).await?;
368 let (response, payload_id) = self.record_fact_json(key, response).await?;
369 Ok(IoResult {
370 response,
371 recorded_payload_id: Some(payload_id),
372 })
373 }
374
375 async fn get_recorded_fact(&mut self, key: &FactKey) -> Result<Option<ArtifactId>, IoError> {
376 Ok(self.facts.get(key).await)
377 }
378
379 async fn record_value(
380 &mut self,
381 key: FactKey,
382 value: serde_json::Value,
383 ) -> Result<ArtifactId, IoError> {
384 let (_, payload_id) = self.record_fact_json(key, value).await?;
385 Ok(payload_id)
386 }
387
388 async fn now_millis(&mut self) -> Result<u64, IoError> {
389 let ms = SystemTime::now()
390 .duration_since(UNIX_EPOCH)
391 .map_err(|_| {
392 io_other(
393 "time_unavailable",
394 ErrorCategory::Unknown,
395 "system time not available",
396 )
397 })?
398 .as_millis() as u64;
399
400 let key = self.derived_fact_key("now_millis");
401 let (v, _payload_id) = self
402 .record_fact_json(key, serde_json::Value::Number(ms.into()))
403 .await?;
404
405 let n = v.as_u64().ok_or_else(|| {
406 io_other(
407 "fact_payload_invalid",
408 ErrorCategory::ParsingInput,
409 "recorded time fact payload was not a u64",
410 )
411 })?;
412 Ok(n)
413 }
414
415 async fn random_bytes(&mut self, n: usize) -> Result<Vec<u8>, IoError> {
416 let mut bytes = vec![0u8; n];
417 let mut rng = rand::rngs::OsRng;
418 rng.try_fill_bytes(&mut bytes).map_err(|_| {
419 io_other(
420 "random_unavailable",
421 ErrorCategory::Unknown,
422 "os randomness not available",
423 )
424 })?;
425
426 let key = self.derived_fact_key("random_bytes");
427 let (got, _payload_id) = self.record_fact_bytes(key, bytes).await?;
428 Ok(got)
429 }
430}
431
432#[async_trait]
436pub trait FactRecorder: Send + Sync {
437 async fn record_fact_binding(
439 &self,
440 key: FactKey,
441 payload_id: ArtifactId,
442 ) -> Result<(), IoError>;
443}
444
445#[derive(Clone, Default)]
447pub struct NoopFactRecorder;
448
449#[async_trait]
450impl FactRecorder for NoopFactRecorder {
451 async fn record_fact_binding(
452 &self,
453 _key: FactKey,
454 _payload_id: ArtifactId,
455 ) -> Result<(), IoError> {
456 Ok(())
457 }
458}