Skip to main content

llama_cpp_bindings/context/
llama_state_seq_flags.rs

//! Flags for extended state sequence operations on hybrid/recurrent models.
2
/// Flags controlling which parts of state to save/restore for sequence operations.
///
/// Used with the `state_seq_*_ext` methods on [`super::LlamaContext`] to enable
/// partial state operations (e.g., saving only recurrent/SSM state for hybrid models).
///
/// The [`Default`] value has no bits set and compares equal to
/// [`LlamaStateSeqFlags::empty`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LlamaStateSeqFlags {
    /// Raw bit mask; each set bit corresponds to one enabled flag.
    flags: u32,
}

impl LlamaStateSeqFlags {
    /// Save/restore only partial (recurrent/SSM) state, skipping attention KV cache.
    pub const PARTIAL_ONLY: Self = Self { flags: 1 };

    /// Returns a value with no flags set (raw bits `0`).
    #[must_use]
    pub const fn empty() -> Self {
        Self { flags: 0 }
    }

    /// Returns the raw bit representation.
    #[must_use]
    pub const fn bits(&self) -> u32 {
        self.flags
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is contained in every value, including the empty one,
    /// because there are no bits that could be missing.
    #[must_use]
    pub const fn contains(&self, other: Self) -> bool {
        (self.flags & other.flags) == other.flags
    }
}
34
35impl std::ops::BitOr for LlamaStateSeqFlags {
36    type Output = Self;
37
38    fn bitor(self, rhs: Self) -> Self {
39        Self {
40            flags: self.flags | rhs.flags,
41        }
42    }
43}
44
// Unit tests covering the public surface of `LlamaStateSeqFlags`:
// raw bit values, `|` composition, and subset checks via `contains`.
#[cfg(test)]
mod tests {
    use super::LlamaStateSeqFlags;

    #[test]
    fn empty_has_no_bits_set() {
        assert_eq!(LlamaStateSeqFlags::empty().bits(), 0);
    }

    #[test]
    fn partial_only_has_bit_one() {
        assert_eq!(LlamaStateSeqFlags::PARTIAL_ONLY.bits(), 1);
    }

    #[test]
    fn bitor_combines_flags() {
        let empty = LlamaStateSeqFlags::empty();
        let partial = LlamaStateSeqFlags::PARTIAL_ONLY;

        // Both operand orders must give the same union.
        assert_eq!((empty | partial).bits(), 1);
        assert_eq!((partial | empty).bits(), 1);
        // OR-ing a flag with itself must not change it.
        assert_eq!(partial | partial, partial);
    }

    #[test]
    fn contains_detects_set_flag() {
        let flags = LlamaStateSeqFlags::PARTIAL_ONLY;

        assert!(flags.contains(LlamaStateSeqFlags::PARTIAL_ONLY));
    }

    #[test]
    fn empty_does_not_contain_partial_only() {
        let flags = LlamaStateSeqFlags::empty();

        assert!(!flags.contains(LlamaStateSeqFlags::PARTIAL_ONLY));
    }

    #[test]
    fn every_value_contains_empty() {
        let nothing = LlamaStateSeqFlags::empty();

        // The empty set is a subset of every set, itself included.
        assert!(nothing.contains(nothing));
        assert!(LlamaStateSeqFlags::PARTIAL_ONLY.contains(nothing));
    }
}