1use std::fs::{self, OpenOptions};
25use std::io::{ErrorKind, Write};
26use std::path::{Path, PathBuf};
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::sync::{Arc, Mutex};
29
30use ergo_adapter::capture::ExternalEventRecord;
31use ergo_adapter::{ExternalEvent, GraphId, RuntimeInvoker};
32
33use crate::{
34 CaptureBundle, Constraints, DecisionLog, DecisionLogEntry, EpisodeInvocationRecord, Supervisor,
35};
36
/// JSON layout used when a capture bundle is serialized to disk.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CaptureJsonStyle {
    /// Single-line output (`serde_json::to_vec`).
    Compact,
    /// Indented, human-readable output (`serde_json::to_vec_pretty`).
    Pretty,
}
42
43#[derive(Debug)]
44#[non_exhaustive]
45pub enum CaptureWriteError {
46 CreateOutputDirectory {
47 path: PathBuf,
48 source: std::io::Error,
49 },
50 Serialize {
51 path: PathBuf,
52 style: CaptureJsonStyle,
53 source: serde_json::Error,
54 },
55 InvalidDestination {
56 path: PathBuf,
57 },
58 CreateTempFile {
59 destination: PathBuf,
60 temp_path: PathBuf,
61 source: std::io::Error,
62 },
63 ExhaustedTempFileAttempts {
64 destination: PathBuf,
65 },
66 WriteTempFile {
67 destination: PathBuf,
68 temp_path: PathBuf,
69 source: std::io::Error,
70 },
71 SyncTempFile {
72 destination: PathBuf,
73 temp_path: PathBuf,
74 source: std::io::Error,
75 },
76 RenameTempFile {
77 destination: PathBuf,
78 temp_path: PathBuf,
79 source: std::io::Error,
80 },
81}
82
83impl std::fmt::Display for CaptureWriteError {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 Self::CreateOutputDirectory { source, .. } => {
87 write!(f, "create capture output directory: {source}")
88 }
89 Self::Serialize {
90 path,
91 style,
92 source,
93 } => write!(
94 f,
95 "serialize capture bundle '{}' ({}): {source}",
96 path.display(),
97 match style {
98 CaptureJsonStyle::Compact => "compact",
99 CaptureJsonStyle::Pretty => "pretty",
100 }
101 ),
102 Self::InvalidDestination { path } => write!(
103 f,
104 "write capture bundle '{}': destination must include a file name",
105 path.display()
106 ),
107 Self::CreateTempFile {
108 destination,
109 temp_path,
110 source,
111 } => write!(
112 f,
113 "write capture bundle '{}': create temp file '{}': {source}",
114 destination.display(),
115 temp_path.display()
116 ),
117 Self::ExhaustedTempFileAttempts { destination } => write!(
118 f,
119 "write capture bundle '{}': exhausted temp file creation attempts",
120 destination.display()
121 ),
122 Self::WriteTempFile {
123 destination,
124 temp_path,
125 source,
126 } => write!(
127 f,
128 "write capture bundle '{}': write temp file '{}': {source}",
129 destination.display(),
130 temp_path.display()
131 ),
132 Self::SyncTempFile {
133 destination,
134 temp_path,
135 source,
136 } => write!(
137 f,
138 "write capture bundle '{}': sync temp file '{}': {source}",
139 destination.display(),
140 temp_path.display()
141 ),
142 Self::RenameTempFile {
143 destination,
144 temp_path,
145 source,
146 } => write!(
147 f,
148 "write capture bundle '{}': rename temp file '{}': {source}",
149 destination.display(),
150 temp_path.display()
151 ),
152 }
153 }
154}
155
156impl std::error::Error for CaptureWriteError {
157 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
158 match self {
159 Self::CreateOutputDirectory { source, .. } => Some(source),
160 Self::Serialize { source, .. } => Some(source),
161 Self::CreateTempFile { source, .. } => Some(source),
162 Self::WriteTempFile { source, .. } => Some(source),
163 Self::SyncTempFile { source, .. } => Some(source),
164 Self::RenameTempFile { source, .. } => Some(source),
165 Self::InvalidDestination { .. } | Self::ExhaustedTempFileAttempts { .. } => None,
166 }
167 }
168}
169
/// Process-wide counter incremented on every temp-name attempt; combined with
/// the process id to make temp file names distinct.
static TEMP_FILE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Number of temp file names tried (each rejected only for "already exists")
/// before reporting `CaptureWriteError::ExhaustedTempFileAttempts`.
const MAX_TEMP_FILE_ATTEMPTS: u32 = 64;
/// Windows only: total attempts at the final replace while it keeps failing
/// with `PermissionDenied` (presumably a transient condition — assumption).
#[cfg(windows)]
const MAX_REPLACE_RETRY_ATTEMPTS: u32 = 64;
/// Windows only: pause between replace attempts, in milliseconds.
#[cfg(windows)]
const REPLACE_RETRY_DELAY_MS: u64 = 10;
176
177pub struct CapturingDecisionLog<L: DecisionLog> {
178 inner: L,
179 bundle: Arc<Mutex<CaptureBundle>>,
180}
181
182impl<L: DecisionLog> CapturingDecisionLog<L> {
183 pub fn new(inner: L, bundle: Arc<Mutex<CaptureBundle>>) -> Self {
184 Self { inner, bundle }
185 }
186}
187
188impl<L: DecisionLog> DecisionLog for CapturingDecisionLog<L> {
189 fn log(&self, entry: DecisionLogEntry) {
190 self.inner.log(entry.clone());
191
192 let record = EpisodeInvocationRecord::from(&entry);
193
194 let mut guard = self.bundle.lock().expect("capture bundle poisoned");
195 guard.decisions.push(record);
196 }
197}
198
199pub struct CapturingSession<L: DecisionLog, R: RuntimeInvoker> {
200 supervisor: Supervisor<CapturingDecisionLog<L>, R>,
201 bundle: Arc<Mutex<CaptureBundle>>,
202}
203
204impl<L: DecisionLog, R: RuntimeInvoker> CapturingSession<L, R> {
205 pub fn new(
206 graph_id: GraphId,
207 constraints: Constraints,
208 inner_log: L,
209 runtime: R,
210 runtime_provenance: String,
211 ) -> Self {
212 Self::new_with_provenance(
213 graph_id,
214 constraints,
215 inner_log,
216 runtime,
217 crate::NO_ADAPTER_PROVENANCE.to_string(),
218 runtime_provenance,
219 )
220 }
221
222 pub fn new_with_provenance(
223 graph_id: GraphId,
224 constraints: Constraints,
225 inner_log: L,
226 runtime: R,
227 adapter_provenance: String,
228 runtime_provenance: String,
229 ) -> Self {
230 let bundle = Arc::new(Mutex::new(CaptureBundle {
231 capture_version: crate::CAPTURE_FORMAT_VERSION.to_string(),
232 graph_id: graph_id.clone(),
233 config: constraints.clone(),
234 events: Vec::new(),
235 decisions: Vec::new(),
236 adapter_provenance,
237 runtime_provenance,
238 egress_provenance: None,
239 }));
240
241 let capturing_log = CapturingDecisionLog::new(inner_log, Arc::clone(&bundle));
242 let supervisor = Supervisor::with_runtime(graph_id, constraints, capturing_log, runtime);
243
244 Self { supervisor, bundle }
245 }
246
247 pub fn on_event(&mut self, event: ExternalEvent) {
248 {
249 let mut guard = self.bundle.lock().expect("capture bundle poisoned");
250 guard.events.push(ExternalEventRecord::from_event(&event));
251 }
252
253 self.supervisor.on_event(event);
254 }
255
256 pub fn into_bundle(self) -> CaptureBundle {
257 let CapturingSession { supervisor, bundle } = self;
258 drop(supervisor);
259
260 match Arc::try_unwrap(bundle) {
261 Ok(mutex) => mutex.into_inner().expect("capture bundle poisoned"),
262 Err(shared) => shared.lock().expect("capture bundle poisoned").clone(),
263 }
264 }
265}
266
267pub fn write_capture_bundle(
268 path: &Path,
269 bundle: &CaptureBundle,
270 style: CaptureJsonStyle,
271) -> Result<(), CaptureWriteError> {
272 if let Some(parent) = path.parent() {
273 if !parent.as_os_str().is_empty() {
274 fs::create_dir_all(parent).map_err(|source| {
275 CaptureWriteError::CreateOutputDirectory {
276 path: parent.to_path_buf(),
277 source,
278 }
279 })?;
280 }
281 }
282
283 let mut bytes = match style {
284 CaptureJsonStyle::Compact => {
285 serde_json::to_vec(bundle).map_err(|source| CaptureWriteError::Serialize {
286 path: path.to_path_buf(),
287 style,
288 source,
289 })?
290 }
291 CaptureJsonStyle::Pretty => {
292 serde_json::to_vec_pretty(bundle).map_err(|source| CaptureWriteError::Serialize {
293 path: path.to_path_buf(),
294 style,
295 source,
296 })?
297 }
298 };
299 bytes.push(b'\n');
300
301 write_bytes_atomic(path, &bytes)
302}
303
304fn write_bytes_atomic(path: &Path, bytes: &[u8]) -> Result<(), CaptureWriteError> {
305 let parent = path
306 .parent()
307 .filter(|p| !p.as_os_str().is_empty())
308 .unwrap_or_else(|| Path::new("."));
309 let file_name = path
310 .file_name()
311 .ok_or_else(|| CaptureWriteError::InvalidDestination {
312 path: path.to_path_buf(),
313 })?;
314 let (temp_path, mut file) = create_temp_file(path, parent, file_name)?;
315
316 if let Err(source) = file.write_all(bytes) {
317 let _ = fs::remove_file(&temp_path);
318 return Err(CaptureWriteError::WriteTempFile {
319 destination: path.to_path_buf(),
320 temp_path,
321 source,
322 });
323 }
324
325 if let Err(source) = file.sync_all() {
326 let _ = fs::remove_file(&temp_path);
327 return Err(CaptureWriteError::SyncTempFile {
328 destination: path.to_path_buf(),
329 temp_path,
330 source,
331 });
332 }
333
334 drop(file);
335 if let Err(source) = replace_destination_with_retry(&temp_path, path) {
336 let _ = fs::remove_file(&temp_path);
337 return Err(CaptureWriteError::RenameTempFile {
338 destination: path.to_path_buf(),
339 temp_path,
340 source,
341 });
342 }
343
344 Ok(())
345}
346
347fn create_temp_file(
348 destination: &Path,
349 parent: &Path,
350 file_name: &std::ffi::OsStr,
351) -> Result<(std::path::PathBuf, std::fs::File), CaptureWriteError> {
352 for _ in 0..MAX_TEMP_FILE_ATTEMPTS {
353 let suffix = TEMP_FILE_COUNTER.fetch_add(1, Ordering::Relaxed);
354 let temp_name = format!(
355 "{}.{}.{}.tmp",
356 file_name.to_string_lossy(),
357 std::process::id(),
358 suffix
359 );
360 let temp_path = parent.join(temp_name);
361 match OpenOptions::new()
362 .create_new(true)
363 .write(true)
364 .open(&temp_path)
365 {
366 Ok(file) => return Ok((temp_path, file)),
367 Err(err) if err.kind() == ErrorKind::AlreadyExists => continue,
368 Err(source) => {
369 return Err(CaptureWriteError::CreateTempFile {
370 destination: destination.to_path_buf(),
371 temp_path,
372 source,
373 });
374 }
375 }
376 }
377
378 Err(CaptureWriteError::ExhaustedTempFileAttempts {
379 destination: destination.to_path_buf(),
380 })
381}
382
/// Moves `temp_path` over `destination` with `fs::rename`, then makes the
/// rename durable on a best-effort basis by syncing the containing directory.
///
/// On POSIX systems `rename` atomically replaces an existing destination.
/// However, the new directory entry is only guaranteed to survive a crash once
/// the directory itself has been `fsync`ed. The extra sync provides the
/// durability that the Windows implementation requests via
/// `MOVEFILE_WRITE_THROUGH`.
///
/// # Errors
///
/// Returns the error from `fs::rename`. Failures while opening or syncing the
/// parent directory are ignored: by then the replace has already happened, and
/// some platforms/filesystems do not support syncing a directory handle.
#[cfg(not(windows))]
fn replace_destination_with_retry(temp_path: &Path, destination: &Path) -> std::io::Result<()> {
    fs::rename(temp_path, destination)?;

    // A bare file name lives in the current directory.
    let parent = destination
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty())
        .unwrap_or_else(|| Path::new("."));
    if let Ok(dir) = fs::File::open(parent) {
        let _ = dir.sync_all();
    }
    Ok(())
}
387
/// Moves `temp_path` over `destination`, retrying while the replace fails
/// with `PermissionDenied`.
///
/// Each failed attempt that is not the last one sleeps for
/// `REPLACE_RETRY_DELAY_MS` milliseconds, with at most
/// `MAX_REPLACE_RETRY_ATTEMPTS` attempts in total.
///
/// NOTE(review): only `ErrorKind::PermissionDenied` is retried. Confirm whether
/// other transient Windows failures (e.g. sharing violations) surface with
/// that kind.
///
/// # Errors
///
/// Returns the first non-`PermissionDenied` error immediately. Otherwise it
/// returns the error from the final attempt.
#[cfg(windows)]
fn replace_destination_with_retry(temp_path: &Path, destination: &Path) -> std::io::Result<()> {
    use std::time::Duration;

    let pause = Duration::from_millis(REPLACE_RETRY_DELAY_MS);
    let mut pending: Option<std::io::Error> = None;
    let mut attempt: u32 = 0;

    while attempt < MAX_REPLACE_RETRY_ATTEMPTS {
        let failure = match replace_destination_once(temp_path, destination) {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };

        let out_of_attempts = attempt + 1 >= MAX_REPLACE_RETRY_ATTEMPTS;
        if failure.kind() != ErrorKind::PermissionDenied || out_of_attempts {
            return Err(failure);
        }

        pending = Some(failure);
        std::thread::sleep(pause);
        attempt += 1;
    }

    // Only reachable if MAX_REPLACE_RETRY_ATTEMPTS is zero.
    Err(pending.unwrap_or_else(|| {
        std::io::Error::new(
            ErrorKind::PermissionDenied,
            "atomic replace failed after retry attempts",
        )
    }))
}
415
/// Performs a single `MoveFileExW` call that moves `temp_path` onto
/// `destination`.
///
/// It passes `MOVEFILE_REPLACE_EXISTING` (overwrite an existing destination)
/// and `MOVEFILE_WRITE_THROUGH` (do not return until the move is flushed).
///
/// # Errors
///
/// Returns `std::io::Error::last_os_error()` when `MoveFileExW` reports
/// failure.
#[cfg(windows)]
fn replace_destination_once(temp_path: &Path, destination: &Path) -> std::io::Result<()> {
    use std::os::windows::ffi::OsStrExt;

    type Dword = u32;
    type WinBool = i32;

    const MOVEFILE_REPLACE_EXISTING: Dword = 0x0000_0001;
    const MOVEFILE_WRITE_THROUGH: Dword = 0x0000_0008;

    #[link(name = "Kernel32")]
    extern "system" {
        fn MoveFileExW(
            existing_file_name: *const u16,
            new_file_name: *const u16,
            flags: Dword,
        ) -> WinBool;
    }

    /// Encodes `raw` as a NUL-terminated UTF-16 buffer for wide-char Win32 APIs.
    fn to_wide_nul(raw: &std::ffi::OsStr) -> Vec<u16> {
        let mut wide: Vec<u16> = raw.encode_wide().collect();
        wide.push(0);
        wide
    }

    let source_wide = to_wide_nul(temp_path.as_os_str());
    let target_wide = to_wide_nul(destination.as_os_str());
    let flags = MOVEFILE_REPLACE_EXISTING | MOVEFILE_WRITE_THROUGH;

    // SAFETY: both pointers come from live, NUL-terminated UTF-16 buffers owned
    // by this frame. The buffers outlive the call, and MoveFileExW only reads
    // through them.
    let moved = unsafe { MoveFileExW(source_wide.as_ptr(), target_wide.as_ptr(), flags) } != 0;

    if moved {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}
460
// Unit tests for this module live in a separate file, resolved by Rust's
// standard `mod` file lookup. They are compiled only for test builds.
#[cfg(test)]
mod tests;