1use crate::kernel::domain::Domain;
56use crate::kernel::expr::{BigFloat, BigInt, BigRat, ExprData, ExprId, PredicateKind};
57use crate::kernel::pool::ExprPool;
58use std::fs::{self, File};
59use std::io::{self, BufReader, BufWriter, Read, Write};
60use std::path::{Path, PathBuf};
61
/// Four-byte signature at offset 0 of every pool file.
const MAGIC: &[u8; 4] = b"ALKP";

// On-disk format versions. The loader accepts all of them; `read_node` gates
// newer fields and tags on the version found in the header.

/// Baseline format.
const POOL_FORMAT_V1: u32 = 1;
/// Adds `Forall` / `Exists` nodes (tags 10 and 11).
const POOL_FORMAT_V2: u32 = 2;
/// Adds `BigO` nodes (tag 12).
const POOL_FORMAT_V3: u32 = 3;
/// Stores the `commutative` flag byte on symbols.
const POOL_FORMAT_V4: u32 = 4;

/// Version stamped into the header of every file `save_to` produces.
const POOL_FORMAT_WRITE: u32 = POOL_FORMAT_V4;
72
/// Failure while saving or loading a persisted expression pool.
///
/// Each variant has a stable `E-IO-*` code (see the `AlkahestError` impl).
#[derive(Debug)]
pub enum IoError {
    /// An I/O error from the filesystem or the underlying reader/writer.
    Io(io::Error),
    /// The file does not begin with the `ALKP` signature.
    BadMagic,
    /// The header names a format version this build cannot read.
    UnsupportedVersion(u32),
    /// A read ran out of data before a complete value was decoded.
    Truncated,
    /// A length-prefixed string was not valid UTF-8.
    BadUtf8,
    /// Unknown domain tag byte.
    BadDomain(u8),
    /// Unknown node tag byte, or a tag newer than the file's format version.
    BadTag(u8),
    /// Unknown predicate-kind byte.
    BadPredicateKind(u8),
    /// A numeric payload could not be decoded; the string says which field.
    BadNumeric(String),
}
92
93impl std::fmt::Display for IoError {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 IoError::Io(e) => write!(f, "io error: {e}"),
97 IoError::BadMagic => write!(f, "not an alkahest pool file (bad magic)"),
98 IoError::UnsupportedVersion(v) => {
99 write!(
100 f,
101 "unsupported pool file version {v}; run `alkahest migrate-pool`"
102 )
103 }
104 IoError::Truncated => write!(f, "pool file truncated or incomplete"),
105 IoError::BadUtf8 => write!(f, "pool file contains invalid UTF-8"),
106 IoError::BadDomain(b) => write!(f, "pool file has unknown domain tag {b}"),
107 IoError::BadTag(b) => write!(f, "pool file has unknown node tag {b}"),
108 IoError::BadPredicateKind(b) => {
109 write!(f, "pool file has unknown predicate kind {b}")
110 }
111 IoError::BadNumeric(s) => write!(f, "pool file has invalid numeric: {s}"),
112 }
113 }
114}
115
116impl std::error::Error for IoError {}
117
118impl From<io::Error> for IoError {
119 fn from(e: io::Error) -> Self {
120 IoError::Io(e)
121 }
122}
123
124impl crate::errors::AlkahestError for IoError {
125 fn code(&self) -> &'static str {
126 match self {
127 IoError::Io(_) => "E-IO-001",
128 IoError::BadMagic => "E-IO-002",
129 IoError::UnsupportedVersion(_) => "E-IO-003",
130 IoError::Truncated => "E-IO-004",
131 IoError::BadUtf8 => "E-IO-005",
132 IoError::BadDomain(_) => "E-IO-006",
133 IoError::BadTag(_) => "E-IO-007",
134 IoError::BadPredicateKind(_) => "E-IO-008",
135 IoError::BadNumeric(_) => "E-IO-009",
136 }
137 }
138
139 fn remediation(&self) -> Option<&'static str> {
140 match self {
141 IoError::BadMagic => Some(
142 "file is not an alkahest pool; check the path or regenerate with ExprPool::checkpoint()",
143 ),
144 IoError::UnsupportedVersion(_) => Some(
145 "run the `alkahest migrate-pool` CLI to upgrade the file, or regenerate from source",
146 ),
147 IoError::Truncated => Some(
148 "file was truncated (likely a crash during checkpoint); rerun from source and checkpoint again",
149 ),
150 _ => None,
151 }
152 }
153}
154
155#[deprecated(since = "2.0.0", note = "renamed to IoError with E-IO-* codes")]
157pub type PoolPersistError = IoError;
158
/// Writes a single raw byte.
fn write_u8(w: &mut impl Write, v: u8) -> io::Result<()> {
    let byte = [v];
    w.write_all(&byte)
}
/// Writes `v` as four little-endian bytes (the pool format's byte order).
fn write_u32(w: &mut impl Write, v: u32) -> io::Result<()> {
    let encoded: [u8; 4] = v.to_le_bytes();
    w.write_all(&encoded)
}
/// Writes `v` as eight little-endian bytes.
fn write_u64(w: &mut impl Write, v: u64) -> io::Result<()> {
    let encoded: [u8; 8] = v.to_le_bytes();
    w.write_all(&encoded)
}
172
173fn write_str(w: &mut impl Write, s: &str) -> io::Result<()> {
174 let bytes = s.as_bytes();
175 write_u32(w, bytes.len() as u32)?;
176 w.write_all(bytes)
177}
178
179fn write_ids(w: &mut impl Write, ids: &[ExprId]) -> io::Result<()> {
180 write_u32(w, ids.len() as u32)?;
181 for id in ids {
182 write_u32(w, id.0)?;
183 }
184 Ok(())
185}
186
187fn read_u8(r: &mut impl Read) -> Result<u8, IoError> {
188 let mut b = [0u8; 1];
189 r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
190 Ok(b[0])
191}
192
193fn read_u32(r: &mut impl Read) -> Result<u32, IoError> {
194 let mut b = [0u8; 4];
195 r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
196 Ok(u32::from_le_bytes(b))
197}
198
199fn read_u64(r: &mut impl Read) -> Result<u64, IoError> {
200 let mut b = [0u8; 8];
201 r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
202 Ok(u64::from_le_bytes(b))
203}
204
205fn read_str(r: &mut impl Read) -> Result<String, IoError> {
206 let len = read_u32(r)? as usize;
207 let mut buf = vec![0u8; len];
208 r.read_exact(&mut buf).map_err(|_| IoError::Truncated)?;
209 String::from_utf8(buf).map_err(|_| IoError::BadUtf8)
210}
211
212fn read_ids(r: &mut impl Read) -> Result<Vec<ExprId>, IoError> {
213 let arity = read_u32(r)? as usize;
214 let mut out = Vec::with_capacity(arity);
215 for _ in 0..arity {
216 out.push(ExprId(read_u32(r)?));
217 }
218 Ok(out)
219}
220
221fn domain_to_u8(d: &Domain) -> u8 {
226 match d {
227 Domain::Real => 0,
228 Domain::Complex => 1,
229 Domain::Integer => 2,
230 Domain::Positive => 3,
231 Domain::NonNegative => 4,
232 Domain::NonZero => 5,
233 }
234}
235
236fn u8_to_domain(b: u8) -> Result<Domain, IoError> {
237 match b {
238 0 => Ok(Domain::Real),
239 1 => Ok(Domain::Complex),
240 2 => Ok(Domain::Integer),
241 3 => Ok(Domain::Positive),
242 4 => Ok(Domain::NonNegative),
243 5 => Ok(Domain::NonZero),
244 b => Err(IoError::BadDomain(b)),
245 }
246}
247
248fn pred_to_u8(k: &PredicateKind) -> u8 {
249 match k {
251 PredicateKind::Eq => 0,
252 PredicateKind::Ne => 1,
253 PredicateKind::Lt => 2,
254 PredicateKind::Le => 3,
255 PredicateKind::Gt => 4,
256 PredicateKind::Ge => 5,
257 PredicateKind::And => 6,
258 PredicateKind::Or => 7,
259 PredicateKind::Not => 8,
260 PredicateKind::True => 9,
261 PredicateKind::False => 10,
262 }
263}
264
265fn u8_to_pred(b: u8) -> Result<PredicateKind, IoError> {
266 match b {
267 0 => Ok(PredicateKind::Eq),
268 1 => Ok(PredicateKind::Ne),
269 2 => Ok(PredicateKind::Lt),
270 3 => Ok(PredicateKind::Le),
271 4 => Ok(PredicateKind::Gt),
272 5 => Ok(PredicateKind::Ge),
273 6 => Ok(PredicateKind::And),
274 7 => Ok(PredicateKind::Or),
275 8 => Ok(PredicateKind::Not),
276 9 => Ok(PredicateKind::True),
277 10 => Ok(PredicateKind::False),
278 b => Err(IoError::BadPredicateKind(b)),
279 }
280}
281
282fn write_node(w: &mut impl Write, node: &ExprData) -> io::Result<()> {
287 match node {
288 ExprData::Symbol {
289 name,
290 domain,
291 commutative,
292 } => {
293 write_u8(w, 0)?;
294 write_u8(w, domain_to_u8(domain))?;
295 write_u8(w, u8::from(*commutative))?;
296 write_str(w, name)
297 }
298 ExprData::Integer(BigInt(n)) => {
299 write_u8(w, 1)?;
300 write_str(w, &n.to_string())
301 }
302 ExprData::Rational(BigRat(r)) => {
303 write_u8(w, 2)?;
304 write_str(w, &r.numer().to_string())?;
305 write_str(w, &r.denom().to_string())
306 }
307 ExprData::Float(BigFloat { inner, prec }) => {
308 write_u8(w, 3)?;
309 write_u32(w, *prec)?;
310 write_str(w, &inner.to_string_radix(16, None))
312 }
313 ExprData::Add(children) => {
314 write_u8(w, 4)?;
315 write_ids(w, children)
316 }
317 ExprData::Mul(children) => {
318 write_u8(w, 5)?;
319 write_ids(w, children)
320 }
321 ExprData::Pow { base, exp } => {
322 write_u8(w, 6)?;
323 write_u32(w, base.0)?;
324 write_u32(w, exp.0)
325 }
326 ExprData::Func { name, args } => {
327 write_u8(w, 7)?;
328 write_str(w, name)?;
329 write_ids(w, args)
330 }
331 ExprData::Piecewise { branches, default } => {
332 write_u8(w, 8)?;
333 write_u32(w, branches.len() as u32)?;
334 for (c, v) in branches {
335 write_u32(w, c.0)?;
336 write_u32(w, v.0)?;
337 }
338 write_u32(w, default.0)
339 }
340 ExprData::Predicate { kind, args } => {
341 write_u8(w, 9)?;
342 write_u8(w, pred_to_u8(kind))?;
343 write_ids(w, args)
344 }
345 ExprData::Forall { var, body } => {
346 write_u8(w, 10)?;
347 write_u32(w, var.0)?;
348 write_u32(w, body.0)
349 }
350 ExprData::Exists { var, body } => {
351 write_u8(w, 11)?;
352 write_u32(w, var.0)?;
353 write_u32(w, body.0)
354 }
355 ExprData::BigO(inner) => {
356 write_u8(w, 12)?;
357 write_u32(w, inner.0)
358 }
359 }
360}
361
362fn read_node(r: &mut impl Read, format_version: u32) -> Result<ExprData, IoError> {
363 let tag = read_u8(r)?;
364 match tag {
365 0 => {
366 let domain = u8_to_domain(read_u8(r)?)?;
367 let commutative = if format_version >= POOL_FORMAT_V4 {
368 read_u8(r)? != 0
369 } else {
370 true
371 };
372 let name = read_str(r)?;
373 Ok(ExprData::Symbol {
374 name,
375 domain,
376 commutative,
377 })
378 }
379 1 => {
380 let s = read_str(r)?;
381 let n: rug::Integer = s
382 .parse()
383 .map_err(|_| IoError::BadNumeric(format!("integer: {s}")))?;
384 Ok(ExprData::Integer(BigInt(n)))
385 }
386 2 => {
387 let nstr = read_str(r)?;
388 let dstr = read_str(r)?;
389 let n: rug::Integer = nstr
390 .parse()
391 .map_err(|_| IoError::BadNumeric(format!("numer: {nstr}")))?;
392 let d: rug::Integer = dstr
393 .parse()
394 .map_err(|_| IoError::BadNumeric(format!("denom: {dstr}")))?;
395 Ok(ExprData::Rational(BigRat(rug::Rational::from((n, d)))))
396 }
397 3 => {
398 let prec = read_u32(r)?;
399 let s = read_str(r)?;
400 let f = rug::Float::parse_radix(&s, 16)
401 .map_err(|_| IoError::BadNumeric(format!("float: {s}")))?;
402 let inner = rug::Float::with_val(prec, f);
403 Ok(ExprData::Float(BigFloat { inner, prec }))
404 }
405 4 => Ok(ExprData::Add(read_ids(r)?)),
406 5 => Ok(ExprData::Mul(read_ids(r)?)),
407 6 => {
408 let base = ExprId(read_u32(r)?);
409 let exp = ExprId(read_u32(r)?);
410 Ok(ExprData::Pow { base, exp })
411 }
412 7 => {
413 let name = read_str(r)?;
414 let args = read_ids(r)?;
415 Ok(ExprData::Func { name, args })
416 }
417 8 => {
418 let n = read_u32(r)? as usize;
419 let mut branches = Vec::with_capacity(n);
420 for _ in 0..n {
421 let c = ExprId(read_u32(r)?);
422 let v = ExprId(read_u32(r)?);
423 branches.push((c, v));
424 }
425 let default = ExprId(read_u32(r)?);
426 Ok(ExprData::Piecewise { branches, default })
427 }
428 9 => {
429 let kind = u8_to_pred(read_u8(r)?)?;
430 let args = read_ids(r)?;
431 Ok(ExprData::Predicate { kind, args })
432 }
433 10 => {
434 if format_version < POOL_FORMAT_V2 {
435 return Err(IoError::BadTag(10));
436 }
437 let var = ExprId(read_u32(r)?);
438 let body = ExprId(read_u32(r)?);
439 Ok(ExprData::Forall { var, body })
440 }
441 11 => {
442 if format_version < POOL_FORMAT_V2 {
443 return Err(IoError::BadTag(11));
444 }
445 let var = ExprId(read_u32(r)?);
446 let body = ExprId(read_u32(r)?);
447 Ok(ExprData::Exists { var, body })
448 }
449 12 => {
450 if format_version < POOL_FORMAT_V3 {
451 return Err(IoError::BadTag(12));
452 }
453 let inner = ExprId(read_u32(r)?);
454 Ok(ExprData::BigO(inner))
455 }
456 b => Err(IoError::BadTag(b)),
457 }
458}
459
460pub fn save_to(pool: &ExprPool, path: impl AsRef<Path>) -> Result<(), IoError> {
466 let path = path.as_ref();
467 let tmp: PathBuf = {
468 let mut p = path.to_path_buf();
469 let mut name = p
470 .file_name()
471 .map(|s| s.to_os_string())
472 .unwrap_or_else(|| std::ffi::OsString::from("pool"));
473 name.push(".tmp");
474 p.set_file_name(name);
475 p
476 };
477
478 {
479 let f = File::create(&tmp)?;
480 let mut w = BufWriter::new(f);
481
482 w.write_all(MAGIC)?;
483 write_u32(&mut w, POOL_FORMAT_WRITE)?;
484 write_u32(&mut w, 0u32)?; let count = pool.len();
487 write_u64(&mut w, count as u64)?;
488 for i in 0..count {
489 let data = pool.get(ExprId(i as u32));
490 write_node(&mut w, &data)?;
491 }
492
493 w.flush()?;
494 w.get_ref().sync_all()?;
495 }
496
497 fs::rename(&tmp, path)?;
498 Ok(())
499}
500
501pub fn load_from(path: impl AsRef<Path>) -> Result<Option<ExprPool>, IoError> {
504 let path = path.as_ref();
505 if !path.exists() {
506 return Ok(None);
507 }
508
509 let f = File::open(path)?;
510 let mut r = BufReader::new(f);
511
512 let mut magic = [0u8; 4];
513 r.read_exact(&mut magic).map_err(|_| IoError::Truncated)?;
514 if &magic != MAGIC {
515 return Err(IoError::BadMagic);
516 }
517
518 let version = read_u32(&mut r)?;
519 if version != POOL_FORMAT_V1
520 && version != POOL_FORMAT_V2
521 && version != POOL_FORMAT_V3
522 && version != POOL_FORMAT_V4
523 {
524 return Err(IoError::UnsupportedVersion(version));
525 }
526 let _flags = read_u32(&mut r)?;
527
528 let pool = ExprPool::new();
529 let count = read_u64(&mut r)? as usize;
530 for expected in 0..count {
531 let data = read_node(&mut r, version)?;
532 let got = pool.intern(data);
533 debug_assert_eq!(got.0 as usize, expected, "pool id drift during load");
534 }
535
536 Ok(Some(pool))
537}
538
539pub fn open_persistent(path: impl AsRef<Path>) -> Result<ExprPool, IoError> {
541 match load_from(path)? {
542 Some(p) => Ok(p),
543 None => Ok(ExprPool::new()),
544 }
545}
546
547impl ExprPool {
552 pub fn checkpoint(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
555 save_to(self, path)
556 }
557
558 pub fn open_persistent(path: impl AsRef<Path>) -> Result<Self, IoError> {
561 open_persistent(path)
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};

    /// Returns a fresh, not-yet-created path in the system temp directory.
    ///
    /// The name combines the process id, a wall-clock timestamp and a
    /// per-process counter. Tests run on parallel threads, and with a coarse
    /// clock two calls could otherwise produce the same path and clobber each
    /// other's files.
    fn tempfile() -> PathBuf {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static NEXT: AtomicUsize = AtomicUsize::new(0);
        let seq = NEXT.fetch_add(1, Ordering::Relaxed);

        let mut p = std::env::temp_dir();
        p.push(format!(
            "alkahest_pool_{}_{}_{}.akp",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
            seq
        ));
        p
    }

    /// Builds the first bytes of a pool file: magic followed by `version`.
    fn header(version: u32) -> Vec<u8> {
        let mut bytes = MAGIC.to_vec();
        bytes.extend_from_slice(&version.to_le_bytes());
        bytes
    }

    #[test]
    fn round_trip_small_pool() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Positive);
        let two = p.integer(2_i32);
        let three_halves = p.rational(3, 2);
        let f = p.float(1.5_f64, 53);
        let xp = p.pow(x, two);
        let fn_node = p.func("sin", vec![xp]);
        let _sum = p.add(vec![fn_node, y, three_halves, f]);

        let path = tempfile();
        p.checkpoint(&path).unwrap();

        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.len(), p.len(), "node count must match");
        for i in 0..p.len() {
            let id = ExprId(i as u32);
            assert_eq!(p.get(id), q.get(id), "node {i} mismatch after round-trip");
        }

        let q_x = q.symbol("x", Domain::Real);
        assert_eq!(q_x, x, "symbol id drifted across checkpoint");
        let q_two = q.integer(2_i32);
        assert_eq!(q_two, two);

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn bad_magic_rejected() {
        let path = tempfile();
        std::fs::write(&path, b"nope1234").unwrap();
        match load_from(&path) {
            Err(IoError::BadMagic) => {}
            other => panic!("expected BadMagic, got {:?}", other.err()),
        }
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn unsupported_version_rejected() {
        let path = tempfile();
        std::fs::write(&path, header(99)).unwrap();
        let result = load_from(&path);
        let _ = fs::remove_file(&path);
        match result {
            Err(IoError::UnsupportedVersion(99)) => {}
            other => panic!("expected UnsupportedVersion(99), got {:?}", other.err()),
        }
    }

    #[test]
    fn truncated_body_rejected() {
        let path = tempfile();
        let mut bytes = header(POOL_FORMAT_WRITE);
        bytes.extend_from_slice(&0u32.to_le_bytes()); // flags
        bytes.extend_from_slice(&3u64.to_le_bytes()); // claims 3 nodes, provides none
        std::fs::write(&path, &bytes).unwrap();
        let result = load_from(&path);
        let _ = fs::remove_file(&path);
        match result {
            Err(IoError::Truncated) => {}
            other => panic!("expected Truncated, got {:?}", other.err()),
        }
    }

    #[test]
    fn empty_file_is_truncated() {
        let path = tempfile();
        std::fs::write(&path, b"").unwrap();
        let result = load_from(&path);
        let _ = fs::remove_file(&path);
        match result {
            Err(IoError::Truncated) => {}
            other => panic!("expected Truncated, got {:?}", other.err()),
        }
    }

    #[test]
    fn missing_file_returns_fresh() {
        let path = tempfile();
        assert!(!path.exists());
        let p = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(p.len(), 0);
    }

    #[test]
    fn predicate_and_piecewise_round_trip() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let zero = p.integer(0_i32);
        let one = p.integer(1_i32);
        let neg_one = p.integer(-1_i32);
        let cond = p.intern(ExprData::Predicate {
            kind: PredicateKind::Gt,
            args: vec![x, zero],
        });
        let pc = p.intern(ExprData::Piecewise {
            branches: vec![(cond, one)],
            default: neg_one,
        });

        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(p.get(pc), q.get(pc));
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn big_o_round_trip() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let o = p.big_o(p.pow(x, p.integer(6)));
        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.get(o), p.get(o));
        let _ = fs::remove_file(&path);
    }
}