Skip to main content

alkahest_cas/kernel/
pool_persist.rs

1//! V1-14 — Persistent / incremental `ExprPool`.
2//!
3//! Opt-in serialization of the intern table to disk so long-running notebooks
4//! and repeated simplifications don't rebuild the pool from scratch on every
5//! process start.
6//!
7//! # Status
8//!
9//! This is the v1.0 scope: a **versioned binary file** (not a true mmap-backed
10//! arena).  `checkpoint()` writes the full node vector atomically (temp file +
11//! `rename`); `open_persistent(path)` reads it back if it exists.  Structural
12//! hashes line up by construction — the re-interned `ExprData` values hash
13//! identically, so a subsequent `pool.add([x, y])` lookup hits the rebuilt
14//! index.
15//!
16//! A true mmap/CapnProto arena with `ExprData` stored inline is tracked as a
17//! v2.0 follow-up; it requires a ground-up redesign of `ExprData` to avoid
18//! heap allocations for `Vec<ExprId>` children.
19//!
20//! # File format (v1)
21//!
22//! ```text
23//!   Magic     = "ALKP"             (4 bytes)
//!   Version   = u32 (**5** = `RootSum` tag 13; **4** = symbol `commutative` flag on tag 0; **3** = BigO tag 12; **2** = quantifiers 10–11; **1** = original 0–9)
//!   Flags     = u32                 (reserved; always written as 0 and ignored on read)
26//!   NodeCount = u64
27//!   Nodes     = NodeCount × TaggedNode
28//! ```
29//!
30//! Each `TaggedNode`:
31//! ```text
32//!   tag : u8
33//!     0 Symbol     -> domain:u8, [commutative:u8 if format≥4], len:u32, name
34//!     1 Integer    -> len:u32, base-10 digits (ASCII, optionally '-' prefix)
35//!     2 Rational   -> numer_len:u32, numer, denom_len:u32, denom
36//!     3 Float      -> prec:u32, len:u32, base-16 mantissa (rug to_string_radix)
37//!     4 Add        -> arity:u32, ExprId.0 (u32) × arity
38//!     5 Mul        -> arity:u32, ExprId.0 × arity
39//!     6 Pow        -> base:u32, exp:u32
40//!     7 Func       -> len:u32, name, arity:u32, ExprId.0 × arity
41//!     8 Piecewise  -> n_branches:u32, (cond:u32, val:u32) × n, default:u32
42//!     9 Predicate  -> kind:u8, arity:u32, ExprId.0 × arity
43//!     10 Forall   -> var:u32, body:u32
44//!     11 Exists   -> var:u32, body:u32
//!     12 BigO     -> inner:u32
//!     13 RootSum  -> poly:u32, var:u32, body:u32
46//! ```
47//!
//! File version (`Version` u32 field): **1** is the original v1.0 layout (tags 0–9 only).
//! **2** adds tags 10–11 for quantifiers. **3** adds tag 12 for `BigO`. **4** adds
//! `commutative: u8` after `domain` on symbol nodes (V3-2). **5** adds tag 13 for
//! `RootSum` (algebraic-residue logarithmic part).
//! Current writers emit version **5**; readers accept **1** … **5**.
52//!
53//! All integers are little-endian.
54
55use crate::kernel::domain::Domain;
56use crate::kernel::expr::{BigFloat, BigInt, BigRat, ExprData, ExprId, PredicateKind};
57use crate::kernel::pool::ExprPool;
58use std::fs::{self, File};
59use std::io::{self, BufReader, BufWriter, Read, Write};
60use std::path::{Path, PathBuf};
61
/// File signature stored in the first four bytes of every pool file.
const MAGIC: &[u8; 4] = b"ALKP";
/// Original v1.0 layout: node tags 0–9 (symbols through predicates).
const POOL_FORMAT_V1: u32 = 1;
/// Adds `Forall` / `Exists` node tags 10–11.
const POOL_FORMAT_V2: u32 = 2;
/// Adds `BigO` tag 12 (V2-15 series API).
const POOL_FORMAT_V3: u32 = 3;
/// Symbol nodes carry `commutative: u8` after `domain` (V3-2).
const POOL_FORMAT_V4: u32 = 4;
/// Adds `RootSum` tag 13 (algebraic-residue logarithmic part).
const POOL_FORMAT_V5: u32 = 5;
/// Version stamped into the header of every file this module writes; it is the
/// newest format defined above.
const POOL_FORMAT_WRITE: u32 = POOL_FORMAT_V5;
74
75// ---------------------------------------------------------------------------
76// Error
77// ---------------------------------------------------------------------------
78
/// I/O errors from checkpoint and restore operations on `ExprPool`.
///
/// Codes: `E-IO-001` … `E-IO-009`.
#[derive(Debug)]
pub enum IoError {
    /// An underlying filesystem / OS operation failed (`E-IO-001`).
    Io(io::Error),
    /// The file does not start with the `ALKP` magic bytes (`E-IO-002`).
    BadMagic,
    /// The header's `Version` field is not one this build can read (`E-IO-003`).
    UnsupportedVersion(u32),
    /// A header field or node record could not be read in full, typically
    /// because the file ends early (`E-IO-004`).
    Truncated,
    /// A length-prefixed string field is not valid UTF-8 (`E-IO-005`).
    BadUtf8,
    /// A symbol record has an unknown domain byte (`E-IO-006`).
    BadDomain(u8),
    /// Unknown node tag, or a tag not defined for the file's version (`E-IO-007`).
    BadTag(u8),
    /// A predicate record has an unknown kind byte (`E-IO-008`).
    BadPredicateKind(u8),
    /// A numeric field could not be decoded (`E-IO-009`).
    BadNumeric(String),
}

impl std::fmt::Display for IoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            IoError::Io(e) => write!(f, "io error: {e}"),
            IoError::BadMagic => write!(f, "not an alkahest pool file (bad magic)"),
            IoError::UnsupportedVersion(v) => {
                write!(
                    f,
                    "unsupported pool file version {v}; run `alkahest migrate-pool`"
                )
            }
            IoError::Truncated => write!(f, "pool file truncated or incomplete"),
            IoError::BadUtf8 => write!(f, "pool file contains invalid UTF-8"),
            IoError::BadDomain(b) => write!(f, "pool file has unknown domain tag {b}"),
            IoError::BadTag(b) => write!(f, "pool file has unknown node tag {b}"),
            IoError::BadPredicateKind(b) => {
                write!(f, "pool file has unknown predicate kind {b}")
            }
            IoError::BadNumeric(s) => write!(f, "pool file has invalid numeric: {s}"),
        }
    }
}

impl std::error::Error for IoError {
    /// Exposes the wrapped `io::Error` for [`IoError::Io`] so error-chain
    /// walkers (logging, `anyhow`-style reporters) can reach the OS cause.
    /// Every other variant is a leaf error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            IoError::Io(e) => Some(e),
            IoError::BadMagic
            | IoError::UnsupportedVersion(_)
            | IoError::Truncated
            | IoError::BadUtf8
            | IoError::BadDomain(_)
            | IoError::BadTag(_)
            | IoError::BadPredicateKind(_)
            | IoError::BadNumeric(_) => None,
        }
    }
}

impl From<io::Error> for IoError {
    fn from(e: io::Error) -> Self {
        IoError::Io(e)
    }
}
125
126impl crate::errors::AlkahestError for IoError {
127    fn code(&self) -> &'static str {
128        match self {
129            IoError::Io(_) => "E-IO-001",
130            IoError::BadMagic => "E-IO-002",
131            IoError::UnsupportedVersion(_) => "E-IO-003",
132            IoError::Truncated => "E-IO-004",
133            IoError::BadUtf8 => "E-IO-005",
134            IoError::BadDomain(_) => "E-IO-006",
135            IoError::BadTag(_) => "E-IO-007",
136            IoError::BadPredicateKind(_) => "E-IO-008",
137            IoError::BadNumeric(_) => "E-IO-009",
138        }
139    }
140
141    fn remediation(&self) -> Option<&'static str> {
142        match self {
143            IoError::BadMagic => Some(
144                "file is not an alkahest pool; check the path or regenerate with ExprPool::checkpoint()",
145            ),
146            IoError::UnsupportedVersion(_) => Some(
147                "run the `alkahest migrate-pool` CLI to upgrade the file, or regenerate from source",
148            ),
149            IoError::Truncated => Some(
150                "file was truncated (likely a crash during checkpoint); rerun from source and checkpoint again",
151            ),
152            _ => None,
153        }
154    }
155}
156
157/// Deprecated alias — use [`IoError`] instead.
158#[deprecated(since = "2.0.0", note = "renamed to IoError with E-IO-* codes")]
159pub type PoolPersistError = IoError;
160
161// ---------------------------------------------------------------------------
162// Low-level binary helpers
163// ---------------------------------------------------------------------------
164
/// Write a single byte.
fn write_u8(w: &mut impl Write, v: u8) -> io::Result<()> {
    w.write_all(&[v])
}

/// Write `v` as 4 little-endian bytes.
fn write_u32(w: &mut impl Write, v: u32) -> io::Result<()> {
    w.write_all(&v.to_le_bytes())
}

/// Write `v` as 8 little-endian bytes.
fn write_u64(w: &mut impl Write, v: u64) -> io::Result<()> {
    w.write_all(&v.to_le_bytes())
}

/// Write `s` as a little-endian `u32` byte length followed by its UTF-8 bytes.
///
/// # Errors
///
/// Propagates errors from `w`. A string longer than `u32::MAX` bytes cannot be
/// represented by the length prefix; it is rejected with
/// `ErrorKind::InvalidInput` before anything is written, instead of emitting a
/// wrapped length that would make the file unreadable.
fn write_str(w: &mut impl Write, s: &str) -> io::Result<()> {
    use std::convert::TryFrom;

    let bytes = s.as_bytes();
    let len = u32::try_from(bytes.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "string longer than u32::MAX bytes cannot be length-prefixed",
        )
    })?;
    write_u32(w, len)?;
    w.write_all(bytes)
}
180
181fn write_ids(w: &mut impl Write, ids: &[ExprId]) -> io::Result<()> {
182    write_u32(w, ids.len() as u32)?;
183    for id in ids {
184        write_u32(w, id.0)?;
185    }
186    Ok(())
187}
188
189fn read_u8(r: &mut impl Read) -> Result<u8, IoError> {
190    let mut b = [0u8; 1];
191    r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
192    Ok(b[0])
193}
194
195fn read_u32(r: &mut impl Read) -> Result<u32, IoError> {
196    let mut b = [0u8; 4];
197    r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
198    Ok(u32::from_le_bytes(b))
199}
200
201fn read_u64(r: &mut impl Read) -> Result<u64, IoError> {
202    let mut b = [0u8; 8];
203    r.read_exact(&mut b).map_err(|_| IoError::Truncated)?;
204    Ok(u64::from_le_bytes(b))
205}
206
207fn read_str(r: &mut impl Read) -> Result<String, IoError> {
208    let len = read_u32(r)? as usize;
209    let mut buf = vec![0u8; len];
210    r.read_exact(&mut buf).map_err(|_| IoError::Truncated)?;
211    String::from_utf8(buf).map_err(|_| IoError::BadUtf8)
212}
213
214fn read_ids(r: &mut impl Read) -> Result<Vec<ExprId>, IoError> {
215    let arity = read_u32(r)? as usize;
216    let mut out = Vec::with_capacity(arity);
217    for _ in 0..arity {
218        out.push(ExprId(read_u32(r)?));
219    }
220    Ok(out)
221}
222
223// ---------------------------------------------------------------------------
224// Domain <-> u8
225// ---------------------------------------------------------------------------
226
227fn domain_to_u8(d: &Domain) -> u8 {
228    match d {
229        Domain::Real => 0,
230        Domain::Complex => 1,
231        Domain::Integer => 2,
232        Domain::Positive => 3,
233        Domain::NonNegative => 4,
234        Domain::NonZero => 5,
235    }
236}
237
238fn u8_to_domain(b: u8) -> Result<Domain, IoError> {
239    match b {
240        0 => Ok(Domain::Real),
241        1 => Ok(Domain::Complex),
242        2 => Ok(Domain::Integer),
243        3 => Ok(Domain::Positive),
244        4 => Ok(Domain::NonNegative),
245        5 => Ok(Domain::NonZero),
246        b => Err(IoError::BadDomain(b)),
247    }
248}
249
250fn pred_to_u8(k: &PredicateKind) -> u8 {
251    // Enumerate all variants in a stable order.
252    match k {
253        PredicateKind::Eq => 0,
254        PredicateKind::Ne => 1,
255        PredicateKind::Lt => 2,
256        PredicateKind::Le => 3,
257        PredicateKind::Gt => 4,
258        PredicateKind::Ge => 5,
259        PredicateKind::And => 6,
260        PredicateKind::Or => 7,
261        PredicateKind::Not => 8,
262        PredicateKind::True => 9,
263        PredicateKind::False => 10,
264    }
265}
266
267fn u8_to_pred(b: u8) -> Result<PredicateKind, IoError> {
268    match b {
269        0 => Ok(PredicateKind::Eq),
270        1 => Ok(PredicateKind::Ne),
271        2 => Ok(PredicateKind::Lt),
272        3 => Ok(PredicateKind::Le),
273        4 => Ok(PredicateKind::Gt),
274        5 => Ok(PredicateKind::Ge),
275        6 => Ok(PredicateKind::And),
276        7 => Ok(PredicateKind::Or),
277        8 => Ok(PredicateKind::Not),
278        9 => Ok(PredicateKind::True),
279        10 => Ok(PredicateKind::False),
280        b => Err(IoError::BadPredicateKind(b)),
281    }
282}
283
284// ---------------------------------------------------------------------------
285// Node ↔ bytes
286// ---------------------------------------------------------------------------
287
288fn write_node(w: &mut impl Write, node: &ExprData) -> io::Result<()> {
289    match node {
290        ExprData::Symbol {
291            name,
292            domain,
293            commutative,
294        } => {
295            write_u8(w, 0)?;
296            write_u8(w, domain_to_u8(domain))?;
297            write_u8(w, u8::from(*commutative))?;
298            write_str(w, name)
299        }
300        ExprData::Integer(BigInt(n)) => {
301            write_u8(w, 1)?;
302            write_str(w, &n.to_string())
303        }
304        ExprData::Rational(BigRat(r)) => {
305            write_u8(w, 2)?;
306            write_str(w, &r.numer().to_string())?;
307            write_str(w, &r.denom().to_string())
308        }
309        ExprData::Float(BigFloat { inner, prec }) => {
310            write_u8(w, 3)?;
311            write_u32(w, *prec)?;
312            // rug::Float::to_string_radix(16, None) round-trips exactly.
313            write_str(w, &inner.to_string_radix(16, None))
314        }
315        ExprData::Add(children) => {
316            write_u8(w, 4)?;
317            write_ids(w, children)
318        }
319        ExprData::Mul(children) => {
320            write_u8(w, 5)?;
321            write_ids(w, children)
322        }
323        ExprData::Pow { base, exp } => {
324            write_u8(w, 6)?;
325            write_u32(w, base.0)?;
326            write_u32(w, exp.0)
327        }
328        ExprData::Func { name, args } => {
329            write_u8(w, 7)?;
330            write_str(w, name)?;
331            write_ids(w, args)
332        }
333        ExprData::Piecewise { branches, default } => {
334            write_u8(w, 8)?;
335            write_u32(w, branches.len() as u32)?;
336            for (c, v) in branches {
337                write_u32(w, c.0)?;
338                write_u32(w, v.0)?;
339            }
340            write_u32(w, default.0)
341        }
342        ExprData::Predicate { kind, args } => {
343            write_u8(w, 9)?;
344            write_u8(w, pred_to_u8(kind))?;
345            write_ids(w, args)
346        }
347        ExprData::Forall { var, body } => {
348            write_u8(w, 10)?;
349            write_u32(w, var.0)?;
350            write_u32(w, body.0)
351        }
352        ExprData::Exists { var, body } => {
353            write_u8(w, 11)?;
354            write_u32(w, var.0)?;
355            write_u32(w, body.0)
356        }
357        ExprData::BigO(inner) => {
358            write_u8(w, 12)?;
359            write_u32(w, inner.0)
360        }
361        ExprData::RootSum { poly, var, body } => {
362            write_u8(w, 13)?;
363            write_u32(w, poly.0)?;
364            write_u32(w, var.0)?;
365            write_u32(w, body.0)
366        }
367    }
368}
369
370fn read_node(r: &mut impl Read, format_version: u32) -> Result<ExprData, IoError> {
371    let tag = read_u8(r)?;
372    match tag {
373        0 => {
374            let domain = u8_to_domain(read_u8(r)?)?;
375            let commutative = if format_version >= POOL_FORMAT_V4 {
376                read_u8(r)? != 0
377            } else {
378                true
379            };
380            let name = read_str(r)?;
381            Ok(ExprData::Symbol {
382                name,
383                domain,
384                commutative,
385            })
386        }
387        1 => {
388            let s = read_str(r)?;
389            let n: rug::Integer = s
390                .parse()
391                .map_err(|_| IoError::BadNumeric(format!("integer: {s}")))?;
392            Ok(ExprData::Integer(BigInt(n)))
393        }
394        2 => {
395            let nstr = read_str(r)?;
396            let dstr = read_str(r)?;
397            let n: rug::Integer = nstr
398                .parse()
399                .map_err(|_| IoError::BadNumeric(format!("numer: {nstr}")))?;
400            let d: rug::Integer = dstr
401                .parse()
402                .map_err(|_| IoError::BadNumeric(format!("denom: {dstr}")))?;
403            Ok(ExprData::Rational(BigRat(rug::Rational::from((n, d)))))
404        }
405        3 => {
406            let prec = read_u32(r)?;
407            let s = read_str(r)?;
408            let f = rug::Float::parse_radix(&s, 16)
409                .map_err(|_| IoError::BadNumeric(format!("float: {s}")))?;
410            let inner = rug::Float::with_val(prec, f);
411            Ok(ExprData::Float(BigFloat { inner, prec }))
412        }
413        4 => Ok(ExprData::Add(read_ids(r)?)),
414        5 => Ok(ExprData::Mul(read_ids(r)?)),
415        6 => {
416            let base = ExprId(read_u32(r)?);
417            let exp = ExprId(read_u32(r)?);
418            Ok(ExprData::Pow { base, exp })
419        }
420        7 => {
421            let name = read_str(r)?;
422            let args = read_ids(r)?;
423            Ok(ExprData::Func { name, args })
424        }
425        8 => {
426            let n = read_u32(r)? as usize;
427            let mut branches = Vec::with_capacity(n);
428            for _ in 0..n {
429                let c = ExprId(read_u32(r)?);
430                let v = ExprId(read_u32(r)?);
431                branches.push((c, v));
432            }
433            let default = ExprId(read_u32(r)?);
434            Ok(ExprData::Piecewise { branches, default })
435        }
436        9 => {
437            let kind = u8_to_pred(read_u8(r)?)?;
438            let args = read_ids(r)?;
439            Ok(ExprData::Predicate { kind, args })
440        }
441        10 => {
442            if format_version < POOL_FORMAT_V2 {
443                return Err(IoError::BadTag(10));
444            }
445            let var = ExprId(read_u32(r)?);
446            let body = ExprId(read_u32(r)?);
447            Ok(ExprData::Forall { var, body })
448        }
449        11 => {
450            if format_version < POOL_FORMAT_V2 {
451                return Err(IoError::BadTag(11));
452            }
453            let var = ExprId(read_u32(r)?);
454            let body = ExprId(read_u32(r)?);
455            Ok(ExprData::Exists { var, body })
456        }
457        12 => {
458            if format_version < POOL_FORMAT_V3 {
459                return Err(IoError::BadTag(12));
460            }
461            let inner = ExprId(read_u32(r)?);
462            Ok(ExprData::BigO(inner))
463        }
464        13 => {
465            if format_version < POOL_FORMAT_V5 {
466                return Err(IoError::BadTag(13));
467            }
468            let poly = ExprId(read_u32(r)?);
469            let var = ExprId(read_u32(r)?);
470            let body = ExprId(read_u32(r)?);
471            Ok(ExprData::RootSum { poly, var, body })
472        }
473        b => Err(IoError::BadTag(b)),
474    }
475}
476
477// ---------------------------------------------------------------------------
478// Public API
479// ---------------------------------------------------------------------------
480
481/// Write the pool's full node table to `path` atomically (temp + rename).
482pub fn save_to(pool: &ExprPool, path: impl AsRef<Path>) -> Result<(), IoError> {
483    let path = path.as_ref();
484    let tmp: PathBuf = {
485        let mut p = path.to_path_buf();
486        let mut name = p
487            .file_name()
488            .map(|s| s.to_os_string())
489            .unwrap_or_else(|| std::ffi::OsString::from("pool"));
490        name.push(".tmp");
491        p.set_file_name(name);
492        p
493    };
494
495    {
496        let f = File::create(&tmp)?;
497        let mut w = BufWriter::new(f);
498
499        w.write_all(MAGIC)?;
500        write_u32(&mut w, POOL_FORMAT_WRITE)?;
501        write_u32(&mut w, 0u32)?; // flags
502
503        let count = pool.len();
504        write_u64(&mut w, count as u64)?;
505        for i in 0..count {
506            let data = pool.get(ExprId(i as u32));
507            write_node(&mut w, &data)?;
508        }
509
510        w.flush()?;
511        w.get_ref().sync_all()?;
512    }
513
514    fs::rename(&tmp, path)?;
515    Ok(())
516}
517
518/// Load a pool from `path`.  Returns `Ok(None)` if the file does not exist,
519/// so callers can use `load_or_new` semantics.
520pub fn load_from(path: impl AsRef<Path>) -> Result<Option<ExprPool>, IoError> {
521    let path = path.as_ref();
522    if !path.exists() {
523        return Ok(None);
524    }
525
526    let f = File::open(path)?;
527    let mut r = BufReader::new(f);
528
529    let mut magic = [0u8; 4];
530    r.read_exact(&mut magic).map_err(|_| IoError::Truncated)?;
531    if &magic != MAGIC {
532        return Err(IoError::BadMagic);
533    }
534
535    let version = read_u32(&mut r)?;
536    if version != POOL_FORMAT_V1
537        && version != POOL_FORMAT_V2
538        && version != POOL_FORMAT_V3
539        && version != POOL_FORMAT_V4
540        && version != POOL_FORMAT_V5
541    {
542        return Err(IoError::UnsupportedVersion(version));
543    }
544    let _flags = read_u32(&mut r)?;
545
546    let pool = ExprPool::new();
547    let count = read_u64(&mut r)? as usize;
548    for expected in 0..count {
549        let data = read_node(&mut r, version)?;
550        let got = pool.intern(data);
551        debug_assert_eq!(got.0 as usize, expected, "pool id drift during load");
552    }
553
554    Ok(Some(pool))
555}
556
557/// Load if `path` exists, else return a fresh pool.
558pub fn open_persistent(path: impl AsRef<Path>) -> Result<ExprPool, IoError> {
559    match load_from(path)? {
560        Some(p) => Ok(p),
561        None => Ok(ExprPool::new()),
562    }
563}
564
565// ---------------------------------------------------------------------------
566// ExprPool convenience methods
567// ---------------------------------------------------------------------------
568
569impl ExprPool {
570    /// V1-14 — write the current pool to `path` atomically.  Equivalent to
571    /// [`save_to`].
572    pub fn checkpoint(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
573        save_to(self, path)
574    }
575
576    /// V1-14 — load a persisted pool, or return a fresh one if the file does
577    /// not exist.  Equivalent to [`open_persistent`].
578    pub fn open_persistent(path: impl AsRef<Path>) -> Result<Self, IoError> {
579        open_persistent(path)
580    }
581}
582
583// ---------------------------------------------------------------------------
584// Tests
585// ---------------------------------------------------------------------------
586
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Unique scratch path in the system temp dir.
    ///
    /// Tests run on parallel threads of one process, so pid + wall-clock
    /// timestamp alone can collide when the clock is coarse; the per-process
    /// sequence number guarantees distinct paths.
    fn tempfile() -> PathBuf {
        static SEQ: AtomicUsize = AtomicUsize::new(0);
        let seq = SEQ.fetch_add(1, Ordering::Relaxed);
        let mut p = std::env::temp_dir();
        p.push(format!(
            "alkahest_pool_{}_{}_{}.akp",
            std::process::id(),
            seq,
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        p
    }

    /// Raw header bytes: magic, `version`, zero flags, and `count`.
    fn header(version: u32, count: u64) -> Vec<u8> {
        let mut bytes = MAGIC.to_vec();
        bytes.extend_from_slice(&version.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&count.to_le_bytes());
        bytes
    }

    #[test]
    fn round_trip_small_pool() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Positive);
        let two = p.integer(2_i32);
        let three_halves = p.rational(3, 2);
        let f = p.float(1.5_f64, 53);
        let xp = p.pow(x, two);
        let fn_node = p.func("sin", vec![xp]);
        let _sum = p.add(vec![fn_node, y, three_halves, f]);

        let path = tempfile();
        p.checkpoint(&path).unwrap();

        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.len(), p.len(), "node count must match");
        for i in 0..p.len() {
            let id = ExprId(i as u32);
            assert_eq!(p.get(id), q.get(id), "node {i} mismatch after round-trip");
        }

        // Re-interning the same structures under q must collide with the
        // restored IDs — this is the hash-cons stability guarantee.
        let q_x = q.symbol("x", Domain::Real);
        assert_eq!(q_x, x, "symbol id drifted across checkpoint");
        let q_two = q.integer(2_i32);
        assert_eq!(q_two, two);

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn round_trip_root_sum() {
        // RootSum nodes (format V5) must round-trip through checkpoint/restore.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let t = p.symbol("t", Domain::Complex);
        let poly = p.add(vec![p.pow(t, p.integer(2_i32)), p.integer(1_i32)]);
        let body = p.mul(vec![t, p.func("log", vec![p.add(vec![x, t])])]);
        let rs = p.root_sum(poly, t, body);
        assert!(matches!(p.get(rs), ExprData::RootSum { .. }));

        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.len(), p.len());
        for i in 0..p.len() {
            let id = ExprId(i as u32);
            assert_eq!(p.get(id), q.get(id), "node {i} mismatch after round-trip");
        }
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn bad_magic_rejected() {
        let path = tempfile();
        std::fs::write(&path, b"nope1234").unwrap();
        match load_from(&path) {
            Err(IoError::BadMagic) => {}
            other => panic!("expected BadMagic, got {:?}", other.err()),
        }
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn unsupported_version_rejected() {
        let path = tempfile();
        let bogus = POOL_FORMAT_WRITE + 1;
        std::fs::write(&path, header(bogus, 0)).unwrap();
        let res = load_from(&path);
        let _ = fs::remove_file(&path);
        match res {
            Err(IoError::UnsupportedVersion(v)) => assert_eq!(v, bogus),
            other => panic!("expected UnsupportedVersion, got {:?}", other.err()),
        }
    }

    #[test]
    fn truncated_node_table_rejected() {
        // The header promises one node but the file ends right after it.
        let path = tempfile();
        std::fs::write(&path, header(POOL_FORMAT_WRITE, 1)).unwrap();
        let res = load_from(&path);
        let _ = fs::remove_file(&path);
        match res {
            Err(IoError::Truncated) => {}
            other => panic!("expected Truncated, got {:?}", other.err()),
        }
    }

    #[test]
    fn missing_file_returns_fresh() {
        let path = tempfile();
        assert!(!path.exists());
        let p = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(p.len(), 0);
    }

    #[test]
    fn predicate_and_piecewise_round_trip() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let zero = p.integer(0_i32);
        let one = p.integer(1_i32);
        let neg_one = p.integer(-1_i32);
        let cond = p.intern(ExprData::Predicate {
            kind: PredicateKind::Gt,
            args: vec![x, zero],
        });
        let pc = p.intern(ExprData::Piecewise {
            branches: vec![(cond, one)],
            default: neg_one,
        });

        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(p.get(pc), q.get(pc));
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn big_o_round_trip() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let o = p.big_o(p.pow(x, p.integer(6)));
        let path = tempfile();
        p.checkpoint(&path).unwrap();
        let q = ExprPool::open_persistent(&path).unwrap();
        assert_eq!(q.get(o), p.get(o));
        let _ = fs::remove_file(&path);
    }
}