Skip to main content

nnrp_core/
cache.rs

1use crate::{
2    NnrpError, CACHE_ERROR_DEPENDENCY_INVALID, CACHE_ERROR_LEASE_EXPIRED, CACHE_ERROR_MISS,
3    CACHE_ERROR_SCHEMA_MISMATCH, CACHE_ERROR_VERSION_MISMATCH,
4};
5
/// Encoded size of [`CachePutMetadata`]: eight little-endian `u32` words.
pub const CACHE_PUT_METADATA_LEN: usize = 8 * 4;
/// Encoded size of [`CacheAckMetadata`]: seven little-endian `u32` words.
pub const CACHE_ACK_METADATA_LEN: usize = 7 * 4;
/// Encoded size of [`CacheInvalidateMetadata`]: five little-endian `u32` words.
pub const CACHE_INVALIDATE_METADATA_LEN: usize = 5 * 4;
/// Bits of [`CachePutMetadata::flags`] that are currently assigned; any other set bit is
/// rejected as reserved by both `CachePutMetadata::parse` and `CachePutMetadata::write`.
pub const CACHE_PUT_FLAGS_KNOWN_MASK: u32 = 0b11;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
12#[repr(u32)]
13pub enum CacheObjectKind {
14    CameraBlock = 0x0001,
15    TileIndexBlock = 0x0002,
16    TensorSectionTable = 0x0003,
17    CodecTable = 0x0004,
18    ReusableResultObject = 0x0005,
19    PayloadLayoutTemplate = 0x0006,
20    PromptSegment = 0x0007,
21    ToolSchema = 0x0008,
22    StructuredEventSchema = 0x0009,
23}
24
25impl CacheObjectKind {
26    pub fn try_from_u32(value: u32) -> Result<Self, NnrpError> {
27        match value {
28            0x0001 => Ok(Self::CameraBlock),
29            0x0002 => Ok(Self::TileIndexBlock),
30            0x0003 => Ok(Self::TensorSectionTable),
31            0x0004 => Ok(Self::CodecTable),
32            0x0005 => Ok(Self::ReusableResultObject),
33            0x0006 => Ok(Self::PayloadLayoutTemplate),
34            0x0007 => Ok(Self::PromptSegment),
35            0x0008 => Ok(Self::ToolSchema),
36            0x0009 => Ok(Self::StructuredEventSchema),
37            _ => Err(NnrpError::UnknownEnumValue {
38                enum_name: "cache_object_kind",
39                value: value as u64,
40            }),
41        }
42    }
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46#[repr(u32)]
47pub enum CacheAckStatus {
48    Accepted = 0,
49    Rejected = 1,
50    Replaced = 2,
51}
52
53impl CacheAckStatus {
54    pub fn try_from_u32(value: u32) -> Result<Self, NnrpError> {
55        match value {
56            0 => Ok(Self::Accepted),
57            1 => Ok(Self::Rejected),
58            2 => Ok(Self::Replaced),
59            _ => Err(NnrpError::UnknownEnumValue {
60                enum_name: "cache_ack_status",
61                value: value as u64,
62            }),
63        }
64    }
65}
66
/// Selects which cached objects a cache INVALIDATE applies to (see
/// `CacheObjectId::matches_invalidate`); the discriminant is the encoded scope word.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum CacheInvalidateScope {
    /// Every object, regardless of namespace or key.
    WholeSession = 0,
    /// Every object in `cache_namespace`.
    Namespace = 1,
    /// Objects in `cache_namespace` whose kind code equals `cache_key_hi`.
    ObjectKind = 2,
    /// Objects in `cache_namespace` whose key words equal `cache_key_hi` / `cache_key_lo`
    /// (the object kind is not compared).
    ObjectKey = 3,
}
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
77pub struct CacheObjectId {
78    pub cache_namespace: u32,
79    pub cache_key_hi: u32,
80    pub cache_key_lo: u32,
81    pub object_kind: CacheObjectKind,
82}
83
84impl CacheObjectId {
85    pub fn from_put(metadata: &CachePutMetadata) -> Self {
86        Self {
87            cache_namespace: metadata.cache_namespace,
88            cache_key_hi: metadata.cache_key_hi,
89            cache_key_lo: metadata.cache_key_lo,
90            object_kind: metadata.object_kind,
91        }
92    }
93
94    pub fn matches_invalidate(&self, metadata: &CacheInvalidateMetadata) -> bool {
95        match metadata.invalidate_scope {
96            CacheInvalidateScope::WholeSession => true,
97            CacheInvalidateScope::Namespace => self.cache_namespace == metadata.cache_namespace,
98            CacheInvalidateScope::ObjectKind => {
99                self.cache_namespace == metadata.cache_namespace
100                    && self.object_kind as u32 == metadata.cache_key_hi
101            }
102            CacheInvalidateScope::ObjectKey => {
103                self.cache_namespace == metadata.cache_namespace
104                    && self.cache_key_hi == metadata.cache_key_hi
105                    && self.cache_key_lo == metadata.cache_key_lo
106            }
107        }
108    }
109}
110
111#[derive(Debug, Clone, Copy, PartialEq, Eq)]
112#[repr(u8)]
113pub enum CacheLeaseOwnerScope {
114    Connection = 0,
115    Session = 1,
116    Operation = 2,
117}
118
119#[derive(Debug, Clone, Copy, PartialEq, Eq)]
120pub struct CacheLease {
121    pub object_id: CacheObjectId,
122    pub object_version: u64,
123    pub lease_id: u64,
124    pub owner_scope: CacheLeaseOwnerScope,
125    pub owner_id: u64,
126    pub granted_at_ms: u64,
127    pub ttl_ms: u32,
128}
129
130impl CacheLease {
131    pub fn expires_at_ms(&self) -> u64 {
132        self.granted_at_ms.saturating_add(self.ttl_ms as u64)
133    }
134
135    pub fn is_expired_at(&self, now_ms: u64) -> bool {
136        now_ms >= self.expires_at_ms()
137    }
138
139    pub fn validate_live_at(&self, now_ms: u64) -> Result<(), CacheValidationFailure> {
140        if self.is_expired_at(now_ms) {
141            return Err(CacheValidationFailure::LeaseExpired);
142        }
143
144        Ok(())
145    }
146
147    pub fn validate_version(&self, expected_version: u64) -> Result<(), CacheValidationFailure> {
148        if self.object_version != expected_version {
149            return Err(CacheValidationFailure::VersionMismatch);
150        }
151
152        Ok(())
153    }
154}
155
156#[derive(Debug, Clone, Copy, PartialEq, Eq)]
157pub struct CacheDependency {
158    pub object_id: CacheObjectId,
159    pub required_version: u64,
160}
161
162#[derive(Debug, Clone, Copy, PartialEq, Eq)]
163pub struct CacheDependencyState {
164    pub object_id: CacheObjectId,
165    pub current_version: u64,
166    pub invalidated: bool,
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq)]
170pub enum CacheValidationFailure {
171    Miss,
172    LeaseExpired,
173    VersionMismatch,
174    DependencyInvalid,
175    SchemaMismatch,
176}
177
178impl CacheValidationFailure {
179    pub fn error_code(self) -> u32 {
180        match self {
181            Self::Miss => CACHE_ERROR_MISS,
182            Self::LeaseExpired => CACHE_ERROR_LEASE_EXPIRED,
183            Self::VersionMismatch => CACHE_ERROR_VERSION_MISMATCH,
184            Self::DependencyInvalid => CACHE_ERROR_DEPENDENCY_INVALID,
185            Self::SchemaMismatch => CACHE_ERROR_SCHEMA_MISMATCH,
186        }
187    }
188}
189
190pub fn validate_cache_dependencies(
191    dependencies: &[CacheDependency],
192    states: &[CacheDependencyState],
193) -> Result<(), CacheValidationFailure> {
194    for dependency in dependencies {
195        let state = states
196            .iter()
197            .find(|state| state.object_id == dependency.object_id)
198            .ok_or(CacheValidationFailure::DependencyInvalid)?;
199
200        if state.invalidated || state.current_version != dependency.required_version {
201            return Err(CacheValidationFailure::DependencyInvalid);
202        }
203    }
204
205    Ok(())
206}
207
208impl CacheInvalidateScope {
209    pub fn try_from_u32(value: u32) -> Result<Self, NnrpError> {
210        match value {
211            0 => Ok(Self::WholeSession),
212            1 => Ok(Self::Namespace),
213            2 => Ok(Self::ObjectKind),
214            3 => Ok(Self::ObjectKey),
215            _ => Err(NnrpError::UnknownEnumValue {
216                enum_name: "cache_invalidate_scope",
217                value: value as u64,
218            }),
219        }
220    }
221}
222
223#[derive(Debug, Clone, Copy, PartialEq, Eq)]
224pub struct CachePutMetadata {
225    pub cache_namespace: u32,
226    pub cache_key_hi: u32,
227    pub cache_key_lo: u32,
228    pub object_kind: CacheObjectKind,
229    pub ttl_ms: u32,
230    pub object_bytes: u32,
231    pub codec_bitmap: u32,
232    pub flags: u32,
233}
234
235impl CachePutMetadata {
236    pub fn parse(source: &[u8]) -> Result<Self, NnrpError> {
237        require_len(source, CACHE_PUT_METADATA_LEN)?;
238        let flags = read_u32(source, 28);
239        validate_mask_u32(flags, CACHE_PUT_FLAGS_KNOWN_MASK)?;
240
241        Ok(Self {
242            cache_namespace: read_u32(source, 0),
243            cache_key_hi: read_u32(source, 4),
244            cache_key_lo: read_u32(source, 8),
245            object_kind: CacheObjectKind::try_from_u32(read_u32(source, 12))?,
246            ttl_ms: read_u32(source, 16),
247            object_bytes: read_u32(source, 20),
248            codec_bitmap: read_u32(source, 24),
249            flags,
250        })
251    }
252
253    pub fn write(&self, destination: &mut [u8]) -> Result<(), NnrpError> {
254        require_destination_len(destination, CACHE_PUT_METADATA_LEN)?;
255        validate_mask_u32(self.flags, CACHE_PUT_FLAGS_KNOWN_MASK)?;
256
257        write_u32(destination, 0, self.cache_namespace);
258        write_u32(destination, 4, self.cache_key_hi);
259        write_u32(destination, 8, self.cache_key_lo);
260        write_u32(destination, 12, self.object_kind as u32);
261        write_u32(destination, 16, self.ttl_ms);
262        write_u32(destination, 20, self.object_bytes);
263        write_u32(destination, 24, self.codec_bitmap);
264        write_u32(destination, 28, self.flags);
265        Ok(())
266    }
267
268    pub fn to_bytes(&self) -> Result<[u8; CACHE_PUT_METADATA_LEN], NnrpError> {
269        let mut bytes = [0u8; CACHE_PUT_METADATA_LEN];
270        self.write(&mut bytes)?;
271        Ok(bytes)
272    }
273}
274
275#[derive(Debug, Clone, Copy, PartialEq, Eq)]
276pub struct CacheAckMetadata {
277    pub cache_namespace: u32,
278    pub cache_key_hi: u32,
279    pub cache_key_lo: u32,
280    pub status: CacheAckStatus,
281    pub accepted_ttl_ms: u32,
282    pub max_object_bytes: u32,
283    pub detail_code: u32,
284}
285
286impl CacheAckMetadata {
287    pub fn parse(source: &[u8]) -> Result<Self, NnrpError> {
288        require_len(source, CACHE_ACK_METADATA_LEN)?;
289        Ok(Self {
290            cache_namespace: read_u32(source, 0),
291            cache_key_hi: read_u32(source, 4),
292            cache_key_lo: read_u32(source, 8),
293            status: CacheAckStatus::try_from_u32(read_u32(source, 12))?,
294            accepted_ttl_ms: read_u32(source, 16),
295            max_object_bytes: read_u32(source, 20),
296            detail_code: read_u32(source, 24),
297        })
298    }
299
300    pub fn write(&self, destination: &mut [u8]) -> Result<(), NnrpError> {
301        require_destination_len(destination, CACHE_ACK_METADATA_LEN)?;
302        write_u32(destination, 0, self.cache_namespace);
303        write_u32(destination, 4, self.cache_key_hi);
304        write_u32(destination, 8, self.cache_key_lo);
305        write_u32(destination, 12, self.status as u32);
306        write_u32(destination, 16, self.accepted_ttl_ms);
307        write_u32(destination, 20, self.max_object_bytes);
308        write_u32(destination, 24, self.detail_code);
309        Ok(())
310    }
311
312    pub fn to_bytes(&self) -> Result<[u8; CACHE_ACK_METADATA_LEN], NnrpError> {
313        let mut bytes = [0u8; CACHE_ACK_METADATA_LEN];
314        self.write(&mut bytes)?;
315        Ok(bytes)
316    }
317}
318
319#[derive(Debug, Clone, Copy, PartialEq, Eq)]
320pub struct CacheInvalidateMetadata {
321    pub invalidate_scope: CacheInvalidateScope,
322    pub cache_namespace: u32,
323    pub cache_key_hi: u32,
324    pub cache_key_lo: u32,
325    pub reason_code: u32,
326}
327
328impl CacheInvalidateMetadata {
329    pub fn parse(source: &[u8]) -> Result<Self, NnrpError> {
330        require_len(source, CACHE_INVALIDATE_METADATA_LEN)?;
331        Ok(Self {
332            invalidate_scope: CacheInvalidateScope::try_from_u32(read_u32(source, 0))?,
333            cache_namespace: read_u32(source, 4),
334            cache_key_hi: read_u32(source, 8),
335            cache_key_lo: read_u32(source, 12),
336            reason_code: read_u32(source, 16),
337        })
338    }
339
340    pub fn write(&self, destination: &mut [u8]) -> Result<(), NnrpError> {
341        require_destination_len(destination, CACHE_INVALIDATE_METADATA_LEN)?;
342        write_u32(destination, 0, self.invalidate_scope as u32);
343        write_u32(destination, 4, self.cache_namespace);
344        write_u32(destination, 8, self.cache_key_hi);
345        write_u32(destination, 12, self.cache_key_lo);
346        write_u32(destination, 16, self.reason_code);
347        Ok(())
348    }
349
350    pub fn to_bytes(&self) -> Result<[u8; CACHE_INVALIDATE_METADATA_LEN], NnrpError> {
351        let mut bytes = [0u8; CACHE_INVALIDATE_METADATA_LEN];
352        self.write(&mut bytes)?;
353        Ok(bytes)
354    }
355}
356
357fn require_len(source: &[u8], expected: usize) -> Result<(), NnrpError> {
358    if source.len() < expected {
359        return Err(NnrpError::SourceTooShort {
360            expected,
361            actual: source.len(),
362        });
363    }
364    Ok(())
365}
366
367fn require_destination_len(destination: &[u8], expected: usize) -> Result<(), NnrpError> {
368    if destination.len() < expected {
369        return Err(NnrpError::DestinationTooShort {
370            expected,
371            actual: destination.len(),
372        });
373    }
374    Ok(())
375}
376
377fn validate_mask_u32(value: u32, allowed: u32) -> Result<(), NnrpError> {
378    if value & !allowed != 0 {
379        return Err(NnrpError::ReservedBitsSet {
380            value: value as u64,
381            allowed: allowed as u64,
382        });
383    }
384    Ok(())
385}
386
/// Reads the little-endian `u32` stored at `source[offset..offset + 4]`.
///
/// # Panics
/// If those four bytes are not all in bounds; the parsers call `require_len` first.
fn read_u32(source: &[u8], offset: usize) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&source[offset..offset + 4]);
    u32::from_le_bytes(word)
}
390
/// Stores `value` little-endian into `destination[offset..offset + 4]`, leaving other bytes
/// untouched.
///
/// # Panics
/// If those four bytes are not all in bounds; the writers call `require_destination_len` first.
fn write_u32(destination: &mut [u8], offset: usize, value: u32) {
    let bytes = value.to_le_bytes();
    destination[offset..offset + bytes.len()].copy_from_slice(&bytes);
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Golden vectors shared with the Python encoder: every header must decode to the expected
    /// fields and re-encode to identical bytes.
    #[test]
    fn cache_metadata_round_trips_python_golden_vectors() {
        let put_bytes =
            hex_to_bytes("01000000040302010807060501000000983a0000000800000300000003000000");
        let put = CachePutMetadata::parse(&put_bytes).unwrap();
        assert_eq!(put.cache_namespace, 1);
        assert_eq!(put.cache_key_hi, 0x0102_0304);
        assert_eq!(put.cache_key_lo, 0x0506_0708);
        assert_eq!(put.object_kind, CacheObjectKind::CameraBlock);
        assert_eq!(put.ttl_ms, 15_000);
        assert_eq!(put.object_bytes, 2048);
        assert_eq!(put.codec_bitmap, 3);
        assert_eq!(put.flags, 3);
        assert_eq!(put.to_bytes().unwrap().as_slice(), put_bytes.as_slice());

        // Bytes past the fixed-size header are ignored by `parse`.
        let mut put_with_trailer = put_bytes;
        put_with_trailer.push(0xff);
        assert_eq!(CachePutMetadata::parse(&put_with_trailer), Ok(put));

        let ack_bytes = hex_to_bytes("01000000040302010807060500000000983a00000020000000000000");
        let ack = CacheAckMetadata::parse(&ack_bytes).unwrap();
        assert_eq!(ack.cache_namespace, 1);
        assert_eq!(ack.cache_key_hi, 0x0102_0304);
        assert_eq!(ack.cache_key_lo, 0x0506_0708);
        assert_eq!(ack.status, CacheAckStatus::Accepted);
        assert_eq!(ack.accepted_ttl_ms, 15_000);
        assert_eq!(ack.max_object_bytes, 8192);
        assert_eq!(ack.detail_code, 0);
        assert_eq!(ack.to_bytes().unwrap().as_slice(), ack_bytes.as_slice());

        let invalidate_bytes = hex_to_bytes("0000000001000000040302010807060502000000");
        let invalidate = CacheInvalidateMetadata::parse(&invalidate_bytes).unwrap();
        assert_eq!(
            invalidate.invalidate_scope,
            CacheInvalidateScope::WholeSession
        );
        assert_eq!(invalidate.cache_namespace, 1);
        assert_eq!(invalidate.cache_key_hi, 0x0102_0304);
        assert_eq!(invalidate.cache_key_lo, 0x0506_0708);
        assert_eq!(invalidate.reason_code, 2);
        assert_eq!(
            invalidate.to_bytes().unwrap().as_slice(),
            invalidate_bytes.as_slice()
        );
    }

    #[test]
    fn cache_metadata_rejects_unknown_assignments_and_flags() {
        for value in 1..=9 {
            assert!(CacheObjectKind::try_from_u32(value).is_ok());
        }
        for value in 0..=2 {
            assert!(CacheAckStatus::try_from_u32(value).is_ok());
        }
        for value in 0..=3 {
            assert!(CacheInvalidateScope::try_from_u32(value).is_ok());
        }

        // Codes just outside each assigned range are rejected, as well as far-off ones.
        for &value in &[0u32, 10, 0xffff] {
            assert_eq!(
                CacheObjectKind::try_from_u32(value),
                Err(NnrpError::UnknownEnumValue {
                    enum_name: "cache_object_kind",
                    value: u64::from(value)
                })
            );
        }
        for &value in &[4u32, 0xff] {
            assert_eq!(
                CacheInvalidateScope::try_from_u32(value),
                Err(NnrpError::UnknownEnumValue {
                    enum_name: "cache_invalidate_scope",
                    value: u64::from(value)
                })
            );
        }
        for &value in &[3u32, 99] {
            assert_eq!(
                CacheAckStatus::try_from_u32(value),
                Err(NnrpError::UnknownEnumValue {
                    enum_name: "cache_ack_status",
                    value: u64::from(value)
                })
            );
        }

        let mut put_bytes = [0u8; CACHE_PUT_METADATA_LEN];
        write_u32(&mut put_bytes, 12, CacheObjectKind::CameraBlock as u32);
        write_u32(&mut put_bytes, 28, 0x4);
        assert_eq!(
            CachePutMetadata::parse(&put_bytes),
            Err(NnrpError::ReservedBitsSet {
                value: 0x4,
                allowed: CACHE_PUT_FLAGS_KNOWN_MASK as u64
            })
        );

        // Reserved flag bits are reported before an unknown object kind.
        let mut bad_kind_and_flags = [0u8; CACHE_PUT_METADATA_LEN];
        write_u32(&mut bad_kind_and_flags, 28, 0x4);
        assert_eq!(
            CachePutMetadata::parse(&bad_kind_and_flags),
            Err(NnrpError::ReservedBitsSet {
                value: 0x4,
                allowed: CACHE_PUT_FLAGS_KNOWN_MASK as u64
            })
        );

        // Unknown enum words are rejected when they arrive on the wire, not just by
        // `try_from_u32`.
        let mut ack_bytes = [0u8; CACHE_ACK_METADATA_LEN];
        write_u32(&mut ack_bytes, 12, 99);
        assert_eq!(
            CacheAckMetadata::parse(&ack_bytes),
            Err(NnrpError::UnknownEnumValue {
                enum_name: "cache_ack_status",
                value: 99
            })
        );
        let mut invalidate_bytes = [0u8; CACHE_INVALIDATE_METADATA_LEN];
        write_u32(&mut invalidate_bytes, 0, 4);
        assert_eq!(
            CacheInvalidateMetadata::parse(&invalidate_bytes),
            Err(NnrpError::UnknownEnumValue {
                enum_name: "cache_invalidate_scope",
                value: 4
            })
        );

        assert_eq!(
            CachePutMetadata::parse(&[0u8; CACHE_PUT_METADATA_LEN - 1]),
            Err(NnrpError::SourceTooShort {
                expected: CACHE_PUT_METADATA_LEN,
                actual: CACHE_PUT_METADATA_LEN - 1
            })
        );
        assert_eq!(
            CacheAckMetadata::parse(&[0u8; CACHE_ACK_METADATA_LEN - 1]),
            Err(NnrpError::SourceTooShort {
                expected: CACHE_ACK_METADATA_LEN,
                actual: CACHE_ACK_METADATA_LEN - 1
            })
        );
        assert_eq!(
            CacheInvalidateMetadata::parse(&[]),
            Err(NnrpError::SourceTooShort {
                expected: CACHE_INVALIDATE_METADATA_LEN,
                actual: 0
            })
        );

        let put = CachePutMetadata {
            cache_namespace: 1,
            cache_key_hi: 2,
            cache_key_lo: 3,
            object_kind: CacheObjectKind::CameraBlock,
            ttl_ms: 4,
            object_bytes: 5,
            codec_bitmap: 6,
            flags: 0,
        };
        assert_eq!(
            put.write(&mut [0u8; CACHE_PUT_METADATA_LEN - 1]),
            Err(NnrpError::DestinationTooShort {
                expected: CACHE_PUT_METADATA_LEN,
                actual: CACHE_PUT_METADATA_LEN - 1
            })
        );

        // A rejected write leaves the destination untouched.
        let mut destination = [0u8; CACHE_PUT_METADATA_LEN];
        let reserved_put = CachePutMetadata { flags: 0x4, ..put };
        assert_eq!(
            reserved_put.write(&mut destination),
            Err(NnrpError::ReservedBitsSet {
                value: 0x4,
                allowed: CACHE_PUT_FLAGS_KNOWN_MASK as u64
            })
        );
        assert_eq!(destination, [0u8; CACHE_PUT_METADATA_LEN]);
    }

    #[test]
    fn cache_lease_exports_stable_validation_failures() {
        let object_id = CacheObjectId {
            cache_namespace: 1,
            cache_key_hi: 2,
            cache_key_lo: 3,
            object_kind: CacheObjectKind::PromptSegment,
        };
        let lease = CacheLease {
            object_id,
            object_version: 7,
            lease_id: 99,
            owner_scope: CacheLeaseOwnerScope::Session,
            owner_id: 42,
            granted_at_ms: 1_000,
            ttl_ms: 500,
        };

        assert_eq!(lease.expires_at_ms(), 1_500);
        assert!(!lease.is_expired_at(1_499));
        assert!(lease.is_expired_at(1_500));
        assert_eq!(lease.validate_live_at(1_499), Ok(()));
        assert_eq!(
            lease.validate_live_at(1_500),
            Err(CacheValidationFailure::LeaseExpired)
        );
        assert_eq!(lease.validate_version(7), Ok(()));
        assert_eq!(
            lease.validate_version(8),
            Err(CacheValidationFailure::VersionMismatch)
        );

        // Expiry saturates at u64::MAX instead of wrapping around.
        let late_lease = CacheLease {
            granted_at_ms: u64::MAX - 100,
            ..lease
        };
        assert_eq!(late_lease.expires_at_ms(), u64::MAX);
        assert!(!late_lease.is_expired_at(u64::MAX - 1));
        assert!(late_lease.is_expired_at(u64::MAX));

        assert_eq!(CacheValidationFailure::Miss.error_code(), CACHE_ERROR_MISS);
        assert_eq!(
            CacheValidationFailure::LeaseExpired.error_code(),
            CACHE_ERROR_LEASE_EXPIRED
        );
        assert_eq!(
            CacheValidationFailure::VersionMismatch.error_code(),
            CACHE_ERROR_VERSION_MISMATCH
        );
        assert_eq!(
            CacheValidationFailure::DependencyInvalid.error_code(),
            CACHE_ERROR_DEPENDENCY_INVALID
        );
        assert_eq!(
            CacheValidationFailure::SchemaMismatch.error_code(),
            CACHE_ERROR_SCHEMA_MISMATCH
        );
    }

    #[test]
    fn cache_dependencies_validate_versions_and_invalidations() {
        let object_id = CacheObjectId {
            cache_namespace: 1,
            cache_key_hi: 2,
            cache_key_lo: 3,
            object_kind: CacheObjectKind::PromptSegment,
        };
        let dependencies = [CacheDependency {
            object_id,
            required_version: 7,
        }];
        let states = [CacheDependencyState {
            object_id,
            current_version: 7,
            invalidated: false,
        }];

        assert_eq!(validate_cache_dependencies(&dependencies, &states), Ok(()));
        assert_eq!(validate_cache_dependencies(&[], &[]), Ok(()));

        let wrong_version = [CacheDependencyState {
            current_version: 8,
            ..states[0]
        }];
        assert_eq!(
            validate_cache_dependencies(&dependencies, &wrong_version),
            Err(CacheValidationFailure::DependencyInvalid)
        );

        let invalidated = [CacheDependencyState {
            invalidated: true,
            ..states[0]
        }];
        assert_eq!(
            validate_cache_dependencies(&dependencies, &invalidated),
            Err(CacheValidationFailure::DependencyInvalid)
        );
        assert_eq!(
            validate_cache_dependencies(&dependencies, &[]),
            Err(CacheValidationFailure::DependencyInvalid)
        );

        // Every dependency must be satisfied, and state order does not matter.
        let other_id = CacheObjectId {
            cache_key_lo: 4,
            ..object_id
        };
        let two_dependencies = [
            dependencies[0],
            CacheDependency {
                object_id: other_id,
                required_version: 1,
            },
        ];
        assert_eq!(
            validate_cache_dependencies(&two_dependencies, &states),
            Err(CacheValidationFailure::DependencyInvalid)
        );
        let both_states = [
            CacheDependencyState {
                object_id: other_id,
                current_version: 1,
                invalidated: false,
            },
            states[0],
        ];
        assert_eq!(
            validate_cache_dependencies(&two_dependencies, &both_states),
            Ok(())
        );
    }

    #[test]
    fn cache_object_id_consumes_invalidate_scopes() {
        fn scoped(
            invalidate_scope: CacheInvalidateScope,
            cache_namespace: u32,
            cache_key_hi: u32,
            cache_key_lo: u32,
        ) -> CacheInvalidateMetadata {
            CacheInvalidateMetadata {
                invalidate_scope,
                cache_namespace,
                cache_key_hi,
                cache_key_lo,
                reason_code: 0,
            }
        }

        let put = CachePutMetadata {
            cache_namespace: 7,
            cache_key_hi: 8,
            cache_key_lo: 9,
            object_kind: CacheObjectKind::ToolSchema,
            ttl_ms: 100,
            object_bytes: 64,
            codec_bitmap: 0,
            flags: 0,
        };
        let object_id = CacheObjectId::from_put(&put);
        assert_eq!(
            object_id,
            CacheObjectId {
                cache_namespace: 7,
                cache_key_hi: 8,
                cache_key_lo: 9,
                object_kind: CacheObjectKind::ToolSchema,
            }
        );

        assert!(object_id.matches_invalidate(&CacheInvalidateMetadata {
            invalidate_scope: CacheInvalidateScope::WholeSession,
            cache_namespace: 0,
            cache_key_hi: 0,
            cache_key_lo: 0,
            reason_code: 0,
        }));
        assert!(object_id.matches_invalidate(&CacheInvalidateMetadata {
            invalidate_scope: CacheInvalidateScope::Namespace,
            cache_namespace: 7,
            cache_key_hi: 0,
            cache_key_lo: 0,
            reason_code: 0,
        }));
        assert!(object_id.matches_invalidate(&CacheInvalidateMetadata {
            invalidate_scope: CacheInvalidateScope::ObjectKind,
            cache_namespace: 7,
            cache_key_hi: CacheObjectKind::ToolSchema as u32,
            cache_key_lo: 0,
            reason_code: 0,
        }));
        assert!(object_id.matches_invalidate(&CacheInvalidateMetadata {
            invalidate_scope: CacheInvalidateScope::ObjectKey,
            cache_namespace: 7,
            cache_key_hi: 8,
            cache_key_lo: 9,
            reason_code: 0,
        }));
        assert!(!object_id.matches_invalidate(&CacheInvalidateMetadata {
            invalidate_scope: CacheInvalidateScope::ObjectKey,
            cache_namespace: 7,
            cache_key_hi: 8,
            cache_key_lo: 10,
            reason_code: 0,
        }));

        // WholeSession ignores the namespace/key words entirely.
        assert!(object_id.matches_invalidate(&scoped(
            CacheInvalidateScope::WholeSession,
            99,
            1,
            1
        )));
        // Every narrower scope requires the same namespace.
        assert!(!object_id.matches_invalidate(&scoped(CacheInvalidateScope::Namespace, 8, 8, 9)));
        assert!(!object_id.matches_invalidate(&scoped(
            CacheInvalidateScope::ObjectKind,
            6,
            CacheObjectKind::ToolSchema as u32,
            0
        )));
        assert!(!object_id.matches_invalidate(&scoped(CacheInvalidateScope::ObjectKey, 6, 8, 9)));
        // ObjectKind compares the kind code carried in cache_key_hi.
        assert!(!object_id.matches_invalidate(&scoped(
            CacheInvalidateScope::ObjectKind,
            7,
            CacheObjectKind::PromptSegment as u32,
            0
        )));
        // ObjectKey compares both key words.
        assert!(!object_id.matches_invalidate(&scoped(CacheInvalidateScope::ObjectKey, 7, 9, 9)));
    }

    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        assert_eq!(hex.len() % 2, 0);
        (0..hex.len())
            .step_by(2)
            .map(|index| u8::from_str_radix(&hex[index..index + 2], 16).unwrap())
            .collect()
    }
}