1use crate::kernel::domain::Domain;
56use crate::kernel::expr::{BigFloat, BigInt, BigRat, ExprData, ExprId, PredicateKind};
57use crate::kernel::pool::ExprPool;
58use std::fs::{self, File};
59use std::io::{self, BufReader, BufWriter, Read, Write};
60use std::path::{Path, PathBuf};
61
/// Four-byte signature written at the very start of every pool file; `load_from`
/// rejects any file whose first four bytes differ.
const MAGIC: &[u8; 4] = b"ALKP";
/// Format version 1: the oldest layout `load_from` accepts. `read_node` rejects
/// node tags 10 and above for files at this version.
const POOL_FORMAT_V1: u32 = 1;
/// Format version 2: the first version for which `read_node` accepts tags 10
/// (`Forall`) and 11 (`Exists`).
const POOL_FORMAT_V2: u32 = 2;
/// Format version 3: the first version for which `read_node` accepts tag 12 (`BigO`).
const POOL_FORMAT_V3: u32 = 3;
/// Format version 4: symbol records carry an explicit `commutative` byte; for older
/// files `read_node` defaults the flag to `true`.
const POOL_FORMAT_V4: u32 = 4;
/// Format version 5: the first version for which `read_node` accepts tag 13 (`RootSum`).
const POOL_FORMAT_V5: u32 = 5;
/// Version number that `save_to` stamps into the header of every file it writes.
const POOL_FORMAT_WRITE: u32 = POOL_FORMAT_V5;
74
/// Errors produced while saving or loading a pool file.
///
/// Each variant has a human-readable message (`Display`) and an `E-IO-*` code
/// (`AlkahestError::code`).
#[derive(Debug)]
pub enum IoError {
    /// An underlying `std::io::Error`; also the target of `From<io::Error>`.
    Io(io::Error),
    /// The first four bytes of the file are not the pool magic.
    BadMagic,
    /// The header's version field is not one of the versions `load_from` knows;
    /// carries the version that was read.
    UnsupportedVersion(u32),
    /// A read of a fixed-size field or payload could not be completed.
    Truncated,
    /// A length-prefixed string payload was not valid UTF-8.
    BadUtf8,
    /// A symbol record carried a domain byte with no matching `Domain`; carries the byte.
    BadDomain(u8),
    /// A node record started with a tag that is unknown, or not permitted at the
    /// file's format version; carries the tag.
    BadTag(u8),
    /// A predicate record carried a kind byte with no matching `PredicateKind`;
    /// carries the byte.
    BadPredicateKind(u8),
    /// A serialized integer, rational, or float could not be parsed; carries a
    /// short description including the offending text.
    BadNumeric(String),
}
94
95impl std::fmt::Display for IoError {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 match self {
98 IoError::Io(e) => write!(f, "io error: {e}"),
99 IoError::BadMagic => write!(f, "not an alkahest pool file (bad magic)"),
100 IoError::UnsupportedVersion(v) => {
101 write!(
102 f,
103 "unsupported pool file version {v}; run `alkahest migrate-pool`"
104 )
105 }
106 IoError::Truncated => write!(f, "pool file truncated or incomplete"),
107 IoError::BadUtf8 => write!(f, "pool file contains invalid UTF-8"),
108 IoError::BadDomain(b) => write!(f, "pool file has unknown domain tag {b}"),
109 IoError::BadTag(b) => write!(f, "pool file has unknown node tag {b}"),
110 IoError::BadPredicateKind(b) => {
111 write!(f, "pool file has unknown predicate kind {b}")
112 }
113 IoError::BadNumeric(s) => write!(f, "pool file has invalid numeric: {s}"),
114 }
115 }
116}
117
// Marker impl: every `std::error::Error` method keeps its default, so `source()`
// reports no underlying cause. The wrapped `io::Error` of `IoError::Io` is instead
// surfaced through the `Display` message.
impl std::error::Error for IoError {}
119
120impl From<io::Error> for IoError {
121 fn from(e: io::Error) -> Self {
122 IoError::Io(e)
123 }
124}
125
126impl crate::errors::AlkahestError for IoError {
127 fn code(&self) -> &'static str {
128 match self {
129 IoError::Io(_) => "E-IO-001",
130 IoError::BadMagic => "E-IO-002",
131 IoError::UnsupportedVersion(_) => "E-IO-003",
132 IoError::Truncated => "E-IO-004",
133 IoError::BadUtf8 => "E-IO-005",
134 IoError::BadDomain(_) => "E-IO-006",
135 IoError::BadTag(_) => "E-IO-007",
136 IoError::BadPredicateKind(_) => "E-IO-008",
137 IoError::BadNumeric(_) => "E-IO-009",
138 }
139 }
140
141 fn remediation(&self) -> Option<&'static str> {
142 match self {
143 IoError::BadMagic => Some(
144 "file is not an alkahest pool; check the path or regenerate with ExprPool::checkpoint()",
145 ),
146 IoError::UnsupportedVersion(_) => Some(
147 "run the `alkahest migrate-pool` CLI to upgrade the file, or regenerate from source",
148 ),
149 IoError::Truncated => Some(
150 "file was truncated (likely a crash during checkpoint); rerun from source and checkpoint again",
151 ),
152 _ => None,
153 }
154 }
155}
156
/// Former name of [`IoError`], kept as a deprecated alias; the deprecation note
/// records that the type was renamed when the `E-IO-*` codes were introduced.
#[deprecated(since = "2.0.0", note = "renamed to IoError with E-IO-* codes")]
pub type PoolPersistError = IoError;
160
/// Writes the single byte `v` to `w`.
fn write_u8(w: &mut impl Write, v: u8) -> io::Result<()> {
    let byte = [v];
    w.write_all(&byte)
}
/// Writes `v` to `w` as four little-endian bytes.
fn write_u32(w: &mut impl Write, v: u32) -> io::Result<()> {
    let encoded: [u8; 4] = v.to_le_bytes();
    w.write_all(&encoded[..])
}
/// Writes `v` to `w` as eight little-endian bytes.
fn write_u64(w: &mut impl Write, v: u64) -> io::Result<()> {
    let encoded: [u8; 8] = v.to_le_bytes();
    w.write_all(&encoded[..])
}
174
175fn write_str(w: &mut impl Write, s: &str) -> io::Result<()> {
176 let bytes = s.as_bytes();
177 write_u32(w, bytes.len() as u32)?;
178 w.write_all(bytes)
179}
180
181fn write_ids(w: &mut impl Write, ids: &[ExprId]) -> io::Result<()> {
182 write_u32(w, ids.len() as u32)?;
183 for id in ids {
184 write_u32(w, id.0)?;
185 }
186 Ok(())
187}
188
189fn read_u8(r: &mut impl Read) -> Result<u8, IoError> {
190 let mut b = [0u8; 1];
191 r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
192 Ok(b[0])
193}
194
195fn read_u32(r: &mut impl Read) -> Result<u32, IoError> {
196 let mut b = [0u8; 4];
197 r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
198 Ok(u32::from_le_bytes(b))
199}
200
201fn read_u64(r: &mut impl Read) -> Result<u64, IoError> {
202 let mut b = [0u8; 8];
203 r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
204 Ok(u64::from_le_bytes(b))
205}
206
207fn read_str(r: &mut impl Read) -> Result<String, IoError> {
208 let len = read_u32(r)? as usize;
209 let mut buf = vec![0u8; len];
210 r.read_exact(&mut buf).map_err(|_| IoError::Truncated)?;
211 String::from_utf8(buf).map_err(|_| IoError::BadUtf8)
212}
213
214fn read_ids(r: &mut impl Read) -> Result<Vec<ExprId>, IoError> {
215 let arity = read_u32(r)? as usize;
216 let mut out = Vec::with_capacity(arity);
217 for _ in 0..arity {
218 out.push(ExprId(read_u32(r)?));
219 }
220 Ok(out)
221}
222
223fn domain_to_u8(d: &Domain) -> u8 {
228 match d {
229 Domain::Real => 0,
230 Domain::Complex => 1,
231 Domain::Integer => 2,
232 Domain::Positive => 3,
233 Domain::NonNegative => 4,
234 Domain::NonZero => 5,
235 }
236}
237
238fn u8_to_domain(b: u8) -> Result<Domain, IoError> {
239 match b {
240 0 => Ok(Domain::Real),
241 1 => Ok(Domain::Complex),
242 2 => Ok(Domain::Integer),
243 3 => Ok(Domain::Positive),
244 4 => Ok(Domain::NonNegative),
245 5 => Ok(Domain::NonZero),
246 b => Err(IoError::BadDomain(b)),
247 }
248}
249
250fn pred_to_u8(k: &PredicateKind) -> u8 {
251 match k {
253 PredicateKind::Eq => 0,
254 PredicateKind::Ne => 1,
255 PredicateKind::Lt => 2,
256 PredicateKind::Le => 3,
257 PredicateKind::Gt => 4,
258 PredicateKind::Ge => 5,
259 PredicateKind::And => 6,
260 PredicateKind::Or => 7,
261 PredicateKind::Not => 8,
262 PredicateKind::True => 9,
263 PredicateKind::False => 10,
264 }
265}
266
267fn u8_to_pred(b: u8) -> Result<PredicateKind, IoError> {
268 match b {
269 0 => Ok(PredicateKind::Eq),
270 1 => Ok(PredicateKind::Ne),
271 2 => Ok(PredicateKind::Lt),
272 3 => Ok(PredicateKind::Le),
273 4 => Ok(PredicateKind::Gt),
274 5 => Ok(PredicateKind::Ge),
275 6 => Ok(PredicateKind::And),
276 7 => Ok(PredicateKind::Or),
277 8 => Ok(PredicateKind::Not),
278 9 => Ok(PredicateKind::True),
279 10 => Ok(PredicateKind::False),
280 b => Err(IoError::BadPredicateKind(b)),
281 }
282}
283
284fn write_node(w: &mut impl Write, node: &ExprData) -> io::Result<()> {
289 match node {
290 ExprData::Symbol {
291 name,
292 domain,
293 commutative,
294 } => {
295 write_u8(w, 0)?;
296 write_u8(w, domain_to_u8(domain))?;
297 write_u8(w, u8::from(*commutative))?;
298 write_str(w, name)
299 }
300 ExprData::Integer(BigInt(n)) => {
301 write_u8(w, 1)?;
302 write_str(w, &n.to_string())
303 }
304 ExprData::Rational(BigRat(r)) => {
305 write_u8(w, 2)?;
306 write_str(w, &r.numer().to_string())?;
307 write_str(w, &r.denom().to_string())
308 }
309 ExprData::Float(BigFloat { inner, prec }) => {
310 write_u8(w, 3)?;
311 write_u32(w, *prec)?;
312 write_str(w, &inner.to_string_radix(16, None))
314 }
315 ExprData::Add(children) => {
316 write_u8(w, 4)?;
317 write_ids(w, children)
318 }
319 ExprData::Mul(children) => {
320 write_u8(w, 5)?;
321 write_ids(w, children)
322 }
323 ExprData::Pow { base, exp } => {
324 write_u8(w, 6)?;
325 write_u32(w, base.0)?;
326 write_u32(w, exp.0)
327 }
328 ExprData::Func { name, args } => {
329 write_u8(w, 7)?;
330 write_str(w, name)?;
331 write_ids(w, args)
332 }
333 ExprData::Piecewise { branches, default } => {
334 write_u8(w, 8)?;
335 write_u32(w, branches.len() as u32)?;
336 for (c, v) in branches {
337 write_u32(w, c.0)?;
338 write_u32(w, v.0)?;
339 }
340 write_u32(w, default.0)
341 }
342 ExprData::Predicate { kind, args } => {
343 write_u8(w, 9)?;
344 write_u8(w, pred_to_u8(kind))?;
345 write_ids(w, args)
346 }
347 ExprData::Forall { var, body } => {
348 write_u8(w, 10)?;
349 write_u32(w, var.0)?;
350 write_u32(w, body.0)
351 }
352 ExprData::Exists { var, body } => {
353 write_u8(w, 11)?;
354 write_u32(w, var.0)?;
355 write_u32(w, body.0)
356 }
357 ExprData::BigO(inner) => {
358 write_u8(w, 12)?;
359 write_u32(w, inner.0)
360 }
361 ExprData::RootSum { poly, var, body } => {
362 write_u8(w, 13)?;
363 write_u32(w, poly.0)?;
364 write_u32(w, var.0)?;
365 write_u32(w, body.0)
366 }
367 }
368}
369
370fn read_node(r: &mut impl Read, format_version: u32) -> Result<ExprData, IoError> {
371 let tag = read_u8(r)?;
372 match tag {
373 0 => {
374 let domain = u8_to_domain(read_u8(r)?)?;
375 let commutative = if format_version >= POOL_FORMAT_V4 {
376 read_u8(r)? != 0
377 } else {
378 true
379 };
380 let name = read_str(r)?;
381 Ok(ExprData::Symbol {
382 name,
383 domain,
384 commutative,
385 })
386 }
387 1 => {
388 let s = read_str(r)?;
389 let n: rug::Integer = s
390 .parse()
391 .map_err(|_| IoError::BadNumeric(format!("integer: {s}")))?;
392 Ok(ExprData::Integer(BigInt(n)))
393 }
394 2 => {
395 let nstr = read_str(r)?;
396 let dstr = read_str(r)?;
397 let n: rug::Integer = nstr
398 .parse()
399 .map_err(|_| IoError::BadNumeric(format!("numer: {nstr}")))?;
400 let d: rug::Integer = dstr
401 .parse()
402 .map_err(|_| IoError::BadNumeric(format!("denom: {dstr}")))?;
403 Ok(ExprData::Rational(BigRat(rug::Rational::from((n, d)))))
404 }
405 3 => {
406 let prec = read_u32(r)?;
407 let s = read_str(r)?;
408 let f = rug::Float::parse_radix(&s, 16)
409 .map_err(|_| IoError::BadNumeric(format!("float: {s}")))?;
410 let inner = rug::Float::with_val(prec, f);
411 Ok(ExprData::Float(BigFloat { inner, prec }))
412 }
413 4 => Ok(ExprData::Add(read_ids(r)?)),
414 5 => Ok(ExprData::Mul(read_ids(r)?)),
415 6 => {
416 let base = ExprId(read_u32(r)?);
417 let exp = ExprId(read_u32(r)?);
418 Ok(ExprData::Pow { base, exp })
419 }
420 7 => {
421 let name = read_str(r)?;
422 let args = read_ids(r)?;
423 Ok(ExprData::Func { name, args })
424 }
425 8 => {
426 let n = read_u32(r)? as usize;
427 let mut branches = Vec::with_capacity(n);
428 for _ in 0..n {
429 let c = ExprId(read_u32(r)?);
430 let v = ExprId(read_u32(r)?);
431 branches.push((c, v));
432 }
433 let default = ExprId(read_u32(r)?);
434 Ok(ExprData::Piecewise { branches, default })
435 }
436 9 => {
437 let kind = u8_to_pred(read_u8(r)?)?;
438 let args = read_ids(r)?;
439 Ok(ExprData::Predicate { kind, args })
440 }
441 10 => {
442 if format_version < POOL_FORMAT_V2 {
443 return Err(IoError::BadTag(10));
444 }
445 let var = ExprId(read_u32(r)?);
446 let body = ExprId(read_u32(r)?);
447 Ok(ExprData::Forall { var, body })
448 }
449 11 => {
450 if format_version < POOL_FORMAT_V2 {
451 return Err(IoError::BadTag(11));
452 }
453 let var = ExprId(read_u32(r)?);
454 let body = ExprId(read_u32(r)?);
455 Ok(ExprData::Exists { var, body })
456 }
457 12 => {
458 if format_version < POOL_FORMAT_V3 {
459 return Err(IoError::BadTag(12));
460 }
461 let inner = ExprId(read_u32(r)?);
462 Ok(ExprData::BigO(inner))
463 }
464 13 => {
465 if format_version < POOL_FORMAT_V5 {
466 return Err(IoError::BadTag(13));
467 }
468 let poly = ExprId(read_u32(r)?);
469 let var = ExprId(read_u32(r)?);
470 let body = ExprId(read_u32(r)?);
471 Ok(ExprData::RootSum { poly, var, body })
472 }
473 b => Err(IoError::BadTag(b)),
474 }
475}
476
477pub fn save_to(pool: &ExprPool, path: impl AsRef<Path>) -> Result<(), IoError> {
483 let path = path.as_ref();
484 let tmp: PathBuf = {
485 let mut p = path.to_path_buf();
486 let mut name = p
487 .file_name()
488 .map(|s| s.to_os_string())
489 .unwrap_or_else(|| std::ffi::OsString::from("pool"));
490 name.push(".tmp");
491 p.set_file_name(name);
492 p
493 };
494
495 {
496 let f = File::create(&tmp)?;
497 let mut w = BufWriter::new(f);
498
499 w.write_all(MAGIC)?;
500 write_u32(&mut w, POOL_FORMAT_WRITE)?;
501 write_u32(&mut w, 0u32)?; let count = pool.len();
504 write_u64(&mut w, count as u64)?;
505 for i in 0..count {
506 let data = pool.get(ExprId(i as u32));
507 write_node(&mut w, &data)?;
508 }
509
510 w.flush()?;
511 w.get_ref().sync_all()?;
512 }
513
514 fs::rename(&tmp, path)?;
515 Ok(())
516}
517
518pub fn load_from(path: impl AsRef<Path>) -> Result<Option<ExprPool>, IoError> {
521 let path = path.as_ref();
522 if !path.exists() {
523 return Ok(None);
524 }
525
526 let f = File::open(path)?;
527 let mut r = BufReader::new(f);
528
529 let mut magic = [0u8; 4];
530 r.read_exact(&mut magic).map_err(|_| IoError::Truncated)?;
531 if &magic != MAGIC {
532 return Err(IoError::BadMagic);
533 }
534
535 let version = read_u32(&mut r)?;
536 if version != POOL_FORMAT_V1
537 && version != POOL_FORMAT_V2
538 && version != POOL_FORMAT_V3
539 && version != POOL_FORMAT_V4
540 && version != POOL_FORMAT_V5
541 {
542 return Err(IoError::UnsupportedVersion(version));
543 }
544 let _flags = read_u32(&mut r)?;
545
546 let pool = ExprPool::new();
547 let count = read_u64(&mut r)? as usize;
548 for expected in 0..count {
549 let data = read_node(&mut r, version)?;
550 let got = pool.intern(data);
551 debug_assert_eq!(got.0 as usize, expected, "pool id drift during load");
552 }
553
554 Ok(Some(pool))
555}
556
557pub fn open_persistent(path: impl AsRef<Path>) -> Result<ExprPool, IoError> {
559 match load_from(path)? {
560 Some(p) => Ok(p),
561 None => Ok(ExprPool::new()),
562 }
563}
564
565impl ExprPool {
570 pub fn checkpoint(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
573 save_to(self, path)
574 }
575
576 pub fn open_persistent(path: impl AsRef<Path>) -> Result<Self, IoError> {
579 open_persistent(path)
580 }
581}
582
/// Round-trip and error-path tests for pool persistence.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};

    /// Returns a fresh path in the system temp directory for one test.
    ///
    /// The name combines the process id, the current time, and a process-wide
    /// counter. The counter keeps names distinct even when two tests running on
    /// parallel threads observe the same clock value, which can happen on
    /// platforms with a coarse system clock.
    fn tempfile() -> PathBuf {
        static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let seq = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let mut p = std::env::temp_dir();
        p.push(format!(
            "alkahest_pool_{}_{}_{}.akp",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
            seq
        ));
        p
    }

    /// A pool with symbols, numbers, a power, a function call, and a sum survives
    /// a checkpoint/reload with identical nodes, and re-interning after reload
    /// yields the original ids.
    #[test]
    fn round_trip_small_pool() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Positive);
        let two = p.integer(2_i32);
        let three_halves = p.rational(3, 2);
        let f = p.float(1.5_f64, 53);
        let xp = p.pow(x, two);
        let fn_node = p.func("sin", vec![xp]);
        let _sum = p.add(vec![fn_node, y, three_halves, f]);

        let path = tempfile();
        p.checkpoint(&path).unwrap();

        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.len(), p.len(), "node count must match");
        for i in 0..p.len() {
            let id = ExprId(i as u32);
            assert_eq!(p.get(id), q.get(id), "node {i} mismatch after round-trip");
        }

        // Interning existing nodes in the reloaded pool must hit the same ids.
        let q_x = q.symbol("x", Domain::Real);
        assert_eq!(q_x, x, "symbol id drifted across checkpoint");
        let q_two = q.integer(2_i32);
        assert_eq!(q_two, two);

        let _ = fs::remove_file(&path);
    }

    /// A pool containing a `RootSum` node round-trips node for node.
    #[test]
    fn round_trip_root_sum() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let t = p.symbol("t", Domain::Complex);
        let poly = p.add(vec![p.pow(t, p.integer(2_i32)), p.integer(1_i32)]);
        let body = p.mul(vec![t, p.func("log", vec![p.add(vec![x, t])])]);
        let rs = p.root_sum(poly, t, body);
        assert!(matches!(p.get(rs), ExprData::RootSum { .. }));

        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.len(), p.len());
        for i in 0..p.len() {
            let id = ExprId(i as u32);
            assert_eq!(p.get(id), q.get(id), "node {i} mismatch after round-trip");
        }
        let _ = fs::remove_file(&path);
    }

    /// A file that does not start with the pool magic is rejected with `BadMagic`.
    #[test]
    fn bad_magic_rejected() {
        let path = tempfile();
        std::fs::write(&path, b"nope1234").unwrap();
        match load_from(&path) {
            Err(IoError::BadMagic) => {}
            other => panic!("expected BadMagic, got {:?}", other.err()),
        }
        let _ = fs::remove_file(&path);
    }

    /// Opening a path with no file behind it yields an empty pool, not an error.
    #[test]
    fn missing_file_returns_fresh() {
        let path = tempfile();
        assert!(!path.exists());
        let p = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(p.len(), 0);
    }

    /// `Predicate` and `Piecewise` nodes round-trip unchanged.
    #[test]
    fn predicate_and_piecewise_round_trip() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let zero = p.integer(0_i32);
        let one = p.integer(1_i32);
        let neg_one = p.integer(-1_i32);
        let cond = p.intern(ExprData::Predicate {
            kind: PredicateKind::Gt,
            args: vec![x, zero],
        });
        let pc = p.intern(ExprData::Piecewise {
            branches: vec![(cond, one)],
            default: neg_one,
        });

        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(p.get(pc), q.get(pc));
        let _ = fs::remove_file(&path);
    }

    /// A `BigO` node round-trips unchanged.
    #[test]
    fn big_o_round_trip() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let o = p.big_o(p.pow(x, p.integer(6)));
        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.get(o), p.get(o));
        let _ = fs::remove_file(&path);
    }
}