Skip to main content

dynamo_kv_hashing/
request.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Universal [`Request`] type used to derive block hashes.
5
6use derive_builder::Builder;
7use dynamo_tokens::{Token, TokenBlockMmInfo, validate_and_sort_mm_info};
8use serde::{Deserialize, Serialize};
9
10use crate::error::KvHashingError;
11
12/// Multimodal placeholder run as carried on a [`Request`].
13///
14/// Mirrors [`dynamo_tokens::TokenBlockMmInfo`]; kept distinct so the public Request shape
15/// is owned by the kv-hashing crate. `From` conversions are provided in both directions.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
17pub struct RequestMmObjectInfo {
18    /// Hash identifying the multimodal object.
19    pub mm_hash: u64,
20    /// Start position of the placeholder run in the full token sequence (zero-based).
21    pub offset: usize,
22    /// Number of placeholder slots in the run.
23    pub length: usize,
24}
25
26impl From<RequestMmObjectInfo> for TokenBlockMmInfo {
27    fn from(v: RequestMmObjectInfo) -> Self {
28        Self {
29            mm_hash: v.mm_hash,
30            offset: v.offset,
31            length: v.length,
32        }
33    }
34}
35
36impl From<TokenBlockMmInfo> for RequestMmObjectInfo {
37    fn from(v: TokenBlockMmInfo) -> Self {
38        Self {
39            mm_hash: v.mm_hash,
40            offset: v.offset,
41            length: v.length,
42        }
43    }
44}
45
46/// Canonical Request used to derive a deterministic sequence of block hashes.
47///
48/// Construction validates `mm_info` (no overlap, no out-of-bounds, no zero-length) and
49/// sorts it by `offset`. The validated/sorted state is the only way to construct a
50/// `Request`, so all downstream block-formation code can trust the invariant.
51///
52/// Built via the owned [`RequestBuilder`] (no clones on the build path):
53///
54/// ```ignore
55/// let request = Request::builder()
56///     .tokens(tokens)
57///     .lora_name(Some("lora-x".into()))
58///     .salt(Some("model-tag".into()))
59///     .mm_info(vec![/* RequestMmObjectInfo ... */])
60///     .build()?;
61/// ```
62#[derive(Debug, Clone, Builder)]
63#[builder(
64    pattern = "owned",
65    build_fn(private, name = "build_internal", error = "KvHashingError"),
66    derive(Debug)
67)]
68pub struct Request {
69    /// Token IDs of the request.
70    #[builder(setter(into))]
71    pub(crate) tokens: Vec<Token>,
72    /// Optional LoRA adapter name.
73    #[builder(default, setter(into))]
74    pub(crate) lora_name: Option<String>,
75    /// Optional free-form caller salt mixed into the per-request `SaltHash`.
76    #[builder(default, setter(into))]
77    pub(crate) salt: Option<String>,
78    /// Multimodal placeholder runs. Validated and sorted by `build()`.
79    #[builder(default)]
80    pub(crate) mm_info: Vec<RequestMmObjectInfo>,
81}
82
83impl Request {
84    /// Returns a fresh owned [`RequestBuilder`].
85    pub fn builder() -> RequestBuilder {
86        RequestBuilder::default()
87    }
88
89    /// Returns the request tokens.
90    pub fn tokens(&self) -> &[Token] {
91        &self.tokens
92    }
93
94    /// Returns the LoRA adapter name, if any.
95    pub fn lora_name(&self) -> Option<&str> {
96        self.lora_name.as_deref()
97    }
98
99    /// Returns the free-form caller salt, if any.
100    pub fn salt(&self) -> Option<&str> {
101        self.salt.as_deref()
102    }
103
104    /// Returns the validated, sorted multimodal runs.
105    pub fn mm_info(&self) -> &[RequestMmObjectInfo] {
106        &self.mm_info
107    }
108
109    /// Returns `mm_info` projected to the dynamo-tokens type, ready for
110    /// [`dynamo_tokens::TokenBlockSequence::new_with_mm`] (already sorted/validated).
111    pub(crate) fn token_mm_info(&self) -> Vec<TokenBlockMmInfo> {
112        self.mm_info.iter().copied().map(Into::into).collect()
113    }
114}
115
116impl RequestBuilder {
117    /// Builds the [`Request`], validating and sorting `mm_info` against the token length.
118    pub fn build(self) -> Result<Request, KvHashingError> {
119        let mut request = self.build_internal()?;
120        // Validate against the actual token length, sort by offset, and write back.
121        // No clones: we move out of `request.mm_info`, transform via Into, and replace.
122        let token_mm: Vec<TokenBlockMmInfo> = std::mem::take(&mut request.mm_info)
123            .into_iter()
124            .map(Into::into)
125            .collect();
126        let validated = validate_and_sort_mm_info(&token_mm, request.tokens.len())?;
127        request.mm_info = validated.into_iter().map(Into::into).collect();
128        Ok(request)
129    }
130}
131
132impl From<derive_builder::UninitializedFieldError> for KvHashingError {
133    fn from(e: derive_builder::UninitializedFieldError) -> Self {
134        KvHashingError::MissingField(e.field_name())
135    }
136}