dynamo_kv_hashing/request.rs
1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Universal [`Request`] type used to derive block hashes.
5
6use derive_builder::Builder;
7use dynamo_tokens::{Token, TokenBlockMmInfo, validate_and_sort_mm_info};
8use serde::{Deserialize, Serialize};
9
10use crate::error::KvHashingError;
11
12/// Multimodal placeholder run as carried on a [`Request`].
13///
14/// Mirrors [`dynamo_tokens::TokenBlockMmInfo`]; kept distinct so the public Request shape
15/// is owned by the kv-hashing crate. `From` conversions are provided in both directions.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
17pub struct RequestMmObjectInfo {
18 /// Hash identifying the multimodal object.
19 pub mm_hash: u64,
20 /// Start position of the placeholder run in the full token sequence (zero-based).
21 pub offset: usize,
22 /// Number of placeholder slots in the run.
23 pub length: usize,
24}
25
26impl From<RequestMmObjectInfo> for TokenBlockMmInfo {
27 fn from(v: RequestMmObjectInfo) -> Self {
28 Self {
29 mm_hash: v.mm_hash,
30 offset: v.offset,
31 length: v.length,
32 }
33 }
34}
35
36impl From<TokenBlockMmInfo> for RequestMmObjectInfo {
37 fn from(v: TokenBlockMmInfo) -> Self {
38 Self {
39 mm_hash: v.mm_hash,
40 offset: v.offset,
41 length: v.length,
42 }
43 }
44}
45
46/// Canonical Request used to derive a deterministic sequence of block hashes.
47///
48/// Construction validates `mm_info` (no overlap, no out-of-bounds, no zero-length) and
49/// sorts it by `offset`. The validated/sorted state is the only way to construct a
50/// `Request`, so all downstream block-formation code can trust the invariant.
51///
52/// Built via the owned [`RequestBuilder`] (no clones on the build path):
53///
54/// ```ignore
55/// let request = Request::builder()
56/// .tokens(tokens)
57/// .lora_name(Some("lora-x".into()))
58/// .salt(Some("model-tag".into()))
59/// .mm_info(vec![/* RequestMmObjectInfo ... */])
60/// .build()?;
61/// ```
62#[derive(Debug, Clone, Builder)]
63#[builder(
64 pattern = "owned",
65 build_fn(private, name = "build_internal", error = "KvHashingError"),
66 derive(Debug)
67)]
68pub struct Request {
69 /// Token IDs of the request.
70 #[builder(setter(into))]
71 pub(crate) tokens: Vec<Token>,
72 /// Optional LoRA adapter name.
73 #[builder(default, setter(into))]
74 pub(crate) lora_name: Option<String>,
75 /// Optional free-form caller salt mixed into the per-request `SaltHash`.
76 #[builder(default, setter(into))]
77 pub(crate) salt: Option<String>,
78 /// Multimodal placeholder runs. Validated and sorted by `build()`.
79 #[builder(default)]
80 pub(crate) mm_info: Vec<RequestMmObjectInfo>,
81}
82
83impl Request {
84 /// Returns a fresh owned [`RequestBuilder`].
85 pub fn builder() -> RequestBuilder {
86 RequestBuilder::default()
87 }
88
89 /// Returns the request tokens.
90 pub fn tokens(&self) -> &[Token] {
91 &self.tokens
92 }
93
94 /// Returns the LoRA adapter name, if any.
95 pub fn lora_name(&self) -> Option<&str> {
96 self.lora_name.as_deref()
97 }
98
99 /// Returns the free-form caller salt, if any.
100 pub fn salt(&self) -> Option<&str> {
101 self.salt.as_deref()
102 }
103
104 /// Returns the validated, sorted multimodal runs.
105 pub fn mm_info(&self) -> &[RequestMmObjectInfo] {
106 &self.mm_info
107 }
108
109 /// Returns `mm_info` projected to the dynamo-tokens type, ready for
110 /// [`dynamo_tokens::TokenBlockSequence::new_with_mm`] (already sorted/validated).
111 pub(crate) fn token_mm_info(&self) -> Vec<TokenBlockMmInfo> {
112 self.mm_info.iter().copied().map(Into::into).collect()
113 }
114}
115
116impl RequestBuilder {
117 /// Builds the [`Request`], validating and sorting `mm_info` against the token length.
118 pub fn build(self) -> Result<Request, KvHashingError> {
119 let mut request = self.build_internal()?;
120 // Validate against the actual token length, sort by offset, and write back.
121 // No clones: we move out of `request.mm_info`, transform via Into, and replace.
122 let token_mm: Vec<TokenBlockMmInfo> = std::mem::take(&mut request.mm_info)
123 .into_iter()
124 .map(Into::into)
125 .collect();
126 let validated = validate_and_sort_mm_info(&token_mm, request.tokens.len())?;
127 request.mm_info = validated.into_iter().map(Into::into).collect();
128 Ok(request)
129 }
130}
131
132impl From<derive_builder::UninitializedFieldError> for KvHashingError {
133 fn from(e: derive_builder::UninitializedFieldError) -> Self {
134 KvHashingError::MissingField(e.field_name())
135 }
136}