Skip to main content

kapsl_engine_api/
lib.rs

1use async_trait::async_trait;
2use base64::Engine as _;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::borrow::Cow;
6use std::fmt;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum EngineError {
16    #[error("Backend error: {message}")]
17    Backend {
18        message: String,
19        #[source]
20        source: Option<Box<dyn std::error::Error + Send + Sync>>,
21    },
22    #[error("Invalid input: {message}")]
23    InvalidInput {
24        message: String,
25        #[source]
26        source: Option<Box<dyn std::error::Error + Send + Sync>>,
27    },
28    #[error("Model not loaded")]
29    ModelNotLoaded,
30    #[error("System overloaded: {message}")]
31    Overloaded {
32        message: String,
33        #[source]
34        source: Option<Box<dyn std::error::Error + Send + Sync>>,
35    },
36    #[error("Model load error for {path}: {source}")]
37    ModelLoadError {
38        path: String,
39        source: Box<dyn std::error::Error + Send + Sync>,
40    },
41    #[error("Inference error: {reason}")]
42    InferenceError {
43        reason: String,
44        #[source]
45        source: Option<Box<dyn std::error::Error + Send + Sync>>,
46    },
47    #[error("Timeout: {message}")]
48    TimeoutError {
49        message: String,
50        #[source]
51        source: Option<Box<dyn std::error::Error + Send + Sync>>,
52    },
53    #[error("Resource exhausted: {message}")]
54    ResourceExhausted {
55        message: String,
56        #[source]
57        source: Option<Box<dyn std::error::Error + Send + Sync>>,
58    },
59    #[error("Cancelled: {message}")]
60    Cancelled { message: String },
61}
62
63impl EngineError {
64    pub fn backend(message: impl Into<String>) -> Self {
65        EngineError::Backend {
66            message: message.into(),
67            source: None,
68        }
69    }
70
71    pub fn backend_with_source(
72        message: impl Into<String>,
73        source: impl std::error::Error + Send + Sync + 'static,
74    ) -> Self {
75        EngineError::Backend {
76            message: message.into(),
77            source: Some(Box::new(source)),
78        }
79    }
80
81    pub fn invalid_input(message: impl Into<String>) -> Self {
82        EngineError::InvalidInput {
83            message: message.into(),
84            source: None,
85        }
86    }
87
88    pub fn invalid_input_with_source(
89        message: impl Into<String>,
90        source: impl std::error::Error + Send + Sync + 'static,
91    ) -> Self {
92        EngineError::InvalidInput {
93            message: message.into(),
94            source: Some(Box::new(source)),
95        }
96    }
97
98    pub fn overloaded(message: impl Into<String>) -> Self {
99        EngineError::Overloaded {
100            message: message.into(),
101            source: None,
102        }
103    }
104
105    pub fn is_overloaded(&self) -> bool {
106        matches!(self, EngineError::Overloaded { .. })
107    }
108
109    pub fn timeout(message: impl Into<String>) -> Self {
110        EngineError::TimeoutError {
111            message: message.into(),
112            source: None,
113        }
114    }
115
116    pub fn resource_exhausted(message: impl Into<String>) -> Self {
117        EngineError::ResourceExhausted {
118            message: message.into(),
119            source: None,
120        }
121    }
122
123    pub fn cancelled(message: impl Into<String>) -> Self {
124        EngineError::Cancelled {
125            message: message.into(),
126        }
127    }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct EngineMetrics {
132    pub inference_time: f64,
133    pub memory_usage: usize,
134    pub gpu_utilization: f64,
135    pub throughput: f64,
136    pub batch_size: usize,
137    pub queue_depth: usize,
138    pub error_rate: f64,
139    pub collected_at_ms: u64,
140    pub kv_cache_bytes_used: usize,
141    pub kv_cache_bytes_capacity: usize,
142    pub kv_cache_blocks_total: usize,
143    pub kv_cache_blocks_free: usize,
144    pub kv_cache_sequences: usize,
145    pub kv_cache_evicted_blocks: u64,
146    pub kv_cache_evicted_sequences: u64,
147    pub kv_cache_packed_layers: usize,
148}
149
150impl EngineMetrics {
151    pub fn new() -> Self {
152        Self {
153            inference_time: 0.0,
154            memory_usage: 0,
155            gpu_utilization: 0.0,
156            throughput: 0.0,
157            batch_size: 0,
158            queue_depth: 0,
159            error_rate: 0.0,
160            collected_at_ms: Self::now_ms(),
161            kv_cache_bytes_used: 0,
162            kv_cache_bytes_capacity: 0,
163            kv_cache_blocks_total: 0,
164            kv_cache_blocks_free: 0,
165            kv_cache_sequences: 0,
166            kv_cache_evicted_blocks: 0,
167            kv_cache_evicted_sequences: 0,
168            kv_cache_packed_layers: 0,
169        }
170    }
171
172    pub fn refresh_timestamp(&mut self) {
173        self.collected_at_ms = Self::now_ms();
174    }
175
176    fn now_ms() -> u64 {
177        SystemTime::now()
178            .duration_since(UNIX_EPOCH)
179            .unwrap_or_default()
180            .as_millis() as u64
181    }
182}
183
184impl Default for EngineMetrics {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189
/// Element type tag carried by tensor packets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorDtype {
    Float32,
    Float64,
    Float16,
    Int32,
    Int64,
    Uint8,
    Utf8,
}

impl TensorDtype {
    /// Returns the canonical lowercase name of this dtype.
    ///
    /// Note that `Utf8` is spelled `"string"`. Usable in `const` contexts.
    pub const fn as_str(&self) -> &'static str {
        match self {
            TensorDtype::Float32 => "float32",
            TensorDtype::Float64 => "float64",
            TensorDtype::Float16 => "float16",
            TensorDtype::Int32 => "int32",
            TensorDtype::Int64 => "int64",
            TensorDtype::Uint8 => "uint8",
            TensorDtype::Utf8 => "string",
        }
    }

    /// Returns the per-element byte width used for length validation.
    ///
    /// `Utf8` is counted as one byte per element. Usable in `const` contexts.
    pub const fn size_bytes(&self) -> usize {
        match self {
            TensorDtype::Float16 => 2,
            TensorDtype::Float32 | TensorDtype::Int32 => 4,
            TensorDtype::Float64 | TensorDtype::Int64 => 8,
            TensorDtype::Uint8 | TensorDtype::Utf8 => 1,
        }
    }
}

/// Displays the dtype as its canonical name (see [`TensorDtype::as_str`]).
impl fmt::Display for TensorDtype {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written verbatim: width/alignment flags on the outer formatter are
        // not applied, matching the previous `write!(f, "{}", ..)` behavior.
        f.write_str(self.as_str())
    }
}
232
233impl FromStr for TensorDtype {
234    type Err = EngineError;
235
236    fn from_str(value: &str) -> Result<Self, Self::Err> {
237        match value.to_lowercase().as_str() {
238            "float32" | "fp32" => Ok(TensorDtype::Float32),
239            "float64" | "fp64" => Ok(TensorDtype::Float64),
240            "float16" | "fp16" => Ok(TensorDtype::Float16),
241            "int32" | "i32" => Ok(TensorDtype::Int32),
242            "int64" | "i64" => Ok(TensorDtype::Int64),
243            "uint8" | "u8" => Ok(TensorDtype::Uint8),
244            "string" | "utf8" => Ok(TensorDtype::Utf8),
245            other => Err(EngineError::InvalidInput {
246                message: format!("Unsupported dtype: {}", other),
247                source: None,
248            }),
249        }
250    }
251}
252
253impl Serialize for TensorDtype {
254    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
255    where
256        S: serde::Serializer,
257    {
258        serializer.serialize_str(self.as_str())
259    }
260}
261
262impl<'de> Deserialize<'de> for TensorDtype {
263    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
264    where
265        D: serde::Deserializer<'de>,
266    {
267        let value = String::deserialize(deserializer)?;
268        TensorDtype::from_str(&value).map_err(serde::de::Error::custom)
269    }
270}
271
272#[derive(Debug, Clone)]
273pub struct BinaryTensorPacket {
274    pub shape: Vec<i64>,
275    pub dtype: TensorDtype,
276    pub data: Vec<u8>,
277}
278
279impl serde::Serialize for BinaryTensorPacket {
280    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
281        use serde::ser::SerializeStruct;
282        // Must match the 4-field layout of BinaryTensorPacketPayload in the Deserialize impl
283        // so that bincode round-trips correctly (derived Serialize only emits 3 fields,
284        // causing a field-count mismatch on the bincode decode side).
285        let mut state = serializer.serialize_struct("BinaryTensorPacket", 4)?;
286        state.serialize_field("shape", &self.shape)?;
287        state.serialize_field("dtype", &self.dtype)?;
288        state.serialize_field("data", &Some(&self.data))?;
289        state.serialize_field("data_base64", &None::<&str>)?;
290        state.end()
291    }
292}
293
294impl<'de> Deserialize<'de> for BinaryTensorPacket {
295    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
296    where
297        D: serde::Deserializer<'de>,
298    {
299        // Borrow `data_base64` as `&'de str` rather than allocating a `String`:
300        // since we deserialize from a buffered byte slice (from_slice), serde can hand
301        // us a reference directly into the JSON input, saving one large memcpy before decode.
302        // Use Cow<str> for data_base64: borrows directly from the JSON buffer when
303        // deserializing via from_slice (zero-copy), allocates when deserializing
304        // from an owned Value (test/fallback path).
305        #[derive(Deserialize)]
306        struct BinaryTensorPacketPayload<'src> {
307            shape: Vec<i64>,
308            dtype: TensorDtype,
309            #[serde(default)]
310            data: Option<Vec<u8>>,
311            #[serde(default, alias = "base64", borrow)]
312            data_base64: Option<Cow<'src, str>>,
313        }
314
315        let payload = BinaryTensorPacketPayload::deserialize(deserializer)?;
316        let data = match (payload.data, payload.data_base64) {
317            (Some(data), None) => data,
318            (None, Some(encoded)) => base64::engine::general_purpose::STANDARD
319                .decode(encoded.as_bytes())
320                .map_err(serde::de::Error::custom)?,
321            (Some(_), Some(_)) => {
322                return Err(serde::de::Error::custom(
323                    "binary tensor payload must include only one of `data` or `data_base64`",
324                ))
325            }
326            (None, None) => {
327                return Err(serde::de::Error::custom(
328                    "binary tensor payload must include `data` or `data_base64`",
329                ))
330            }
331        };
332
333        Ok(Self {
334            shape: payload.shape,
335            dtype: payload.dtype,
336            data,
337        })
338    }
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct BinaryTensorPacketRef<'a> {
343    pub shape: Vec<i64>,
344    pub dtype: TensorDtype,
345    #[serde(borrow)]
346    pub data: Cow<'a, [u8]>,
347}
348
349#[derive(Debug, Clone, Copy)]
350pub struct TensorView<'a> {
351    pub shape: &'a [i64],
352    pub dtype: TensorDtype,
353    pub data: &'a [u8],
354}
355
356impl BinaryTensorPacket {
357    pub fn new(shape: Vec<i64>, dtype: TensorDtype, data: Vec<u8>) -> Result<Self, EngineError> {
358        let packet = Self { shape, dtype, data };
359        packet.validate()?;
360        Ok(packet)
361    }
362
363    pub fn size_bytes(&self) -> usize {
364        self.data.len()
365    }
366
367    pub fn tensor_elements(&self) -> Result<usize, EngineError> {
368        shape_elements(&self.shape)
369    }
370
371    pub fn tensor_elements_cached(&self, cache: &mut Option<usize>) -> Result<usize, EngineError> {
372        if let Some(value) = *cache {
373            return Ok(value);
374        }
375        let value = self.tensor_elements()?;
376        *cache = Some(value);
377        Ok(value)
378    }
379
380    pub fn validate(&self) -> Result<(), EngineError> {
381        let elements = self.tensor_elements()?;
382        let expected = elements
383            .checked_mul(self.dtype.size_bytes())
384            .ok_or_else(|| EngineError::InvalidInput {
385                message: "Data size overflow".to_string(),
386                source: None,
387            })?;
388
389        if self.data.len() != expected {
390            return Err(EngineError::InvalidInput {
391                message: format!(
392                    "Data length mismatch: expected {} bytes ({} {} values) but got {} bytes",
393                    expected,
394                    elements,
395                    self.dtype,
396                    self.data.len()
397                ),
398                source: None,
399            });
400        }
401
402        Ok(())
403    }
404
405    pub fn view(&self) -> TensorView<'_> {
406        TensorView {
407            shape: &self.shape,
408            dtype: self.dtype,
409            data: &self.data,
410        }
411    }
412
413    pub fn as_borrowed(&self) -> BinaryTensorPacketRef<'_> {
414        BinaryTensorPacketRef::from(self)
415    }
416}
417
418impl<'a> BinaryTensorPacketRef<'a> {
419    pub fn to_owned(self) -> BinaryTensorPacket {
420        BinaryTensorPacket {
421            shape: self.shape,
422            dtype: self.dtype,
423            data: self.data.into_owned(),
424        }
425    }
426}
427
428impl<'a> From<&'a BinaryTensorPacket> for BinaryTensorPacketRef<'a> {
429    fn from(packet: &'a BinaryTensorPacket) -> Self {
430        Self {
431            shape: packet.shape.clone(),
432            dtype: packet.dtype,
433            data: Cow::Borrowed(&packet.data),
434        }
435    }
436}
437
438fn shape_elements(shape: &[i64]) -> Result<usize, EngineError> {
439    if shape.is_empty() {
440        return Ok(1);
441    }
442
443    let mut prod: usize = 1;
444    for &dim in shape {
445        if dim <= 0 {
446            return Err(EngineError::InvalidInput {
447                message: format!("Invalid shape dimension: {}", dim),
448                source: None,
449            });
450        }
451        prod = prod
452            .checked_mul(dim as usize)
453            .ok_or_else(|| EngineError::InvalidInput {
454                message: "Shape multiplication overflow".to_string(),
455                source: None,
456            })?;
457    }
458
459    Ok(prod)
460}
461
462#[derive(Debug, Clone, Default, Serialize, Deserialize)]
463pub struct RequestMetadata {
464    #[serde(default)]
465    pub request_id: Option<String>,
466    #[serde(default)]
467    pub timeout_ms: Option<u64>,
468    #[serde(default)]
469    pub priority: Option<u8>,
470    #[serde(default)]
471    pub force_cpu: Option<bool>,
472    #[serde(default)]
473    pub model_version: Option<String>,
474    #[serde(default, skip_serializing_if = "Option::is_none")]
475    pub auth_token: Option<String>,
476
477    // === Optional LLM overrides ===
478    #[serde(default, alias = "max_tokens")]
479    pub max_new_tokens: Option<u32>,
480    #[serde(default)]
481    pub temperature: Option<f32>,
482    #[serde(default)]
483    pub top_p: Option<f32>,
484    #[serde(default)]
485    pub top_k: Option<u32>,
486    #[serde(default)]
487    pub repetition_penalty: Option<f32>,
488    #[serde(default)]
489    pub seed: Option<u64>,
490    #[serde(default, alias = "stop_ids")]
491    pub stop_token_ids: Option<Vec<u32>>,
492}
493
494#[derive(Debug, Clone, Serialize, Deserialize)]
495pub struct NamedTensor {
496    pub name: String,
497    pub tensor: BinaryTensorPacket,
498}
499
/// Cloneable cancellation flag.
///
/// All clones share one atomic boolean, so cancelling any clone is observed by
/// every other clone. Once set, the flag is never cleared by this type.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a token in the not-cancelled state.
    #[must_use]
    pub fn new() -> Self {
        // `Default` already yields a fresh `Arc<AtomicBool>` holding `false`.
        Self::default()
    }

    /// Marks the token (and every clone of it) as cancelled.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once `cancel` has been called on this token or a clone.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }
}
520
521#[derive(Debug, Clone, Serialize, Deserialize)]
522pub struct InferenceRequest {
523    pub input: BinaryTensorPacket,
524    #[serde(default)]
525    pub additional_inputs: Vec<NamedTensor>,
526    #[serde(default)]
527    pub session_id: Option<String>,
528    #[serde(default)]
529    pub metadata: Option<RequestMetadata>,
530    #[serde(skip, default)]
531    pub cancellation: Option<CancellationToken>,
532}
533
534impl InferenceRequest {
535    pub fn new(input: BinaryTensorPacket) -> Self {
536        Self {
537            input,
538            additional_inputs: Vec::new(),
539            session_id: None,
540            metadata: None,
541            cancellation: None,
542        }
543    }
544
545    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
546        self.session_id = Some(session_id.into());
547        self
548    }
549
550    pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
551        self.metadata = Some(metadata);
552        self
553    }
554
555    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
556        let metadata = self.metadata.get_or_insert_with(RequestMetadata::default);
557        metadata.request_id = Some(request_id.into());
558        self
559    }
560
561    pub fn add_input(&mut self, name: impl Into<String>, tensor: BinaryTensorPacket) {
562        self.additional_inputs.push(NamedTensor {
563            name: name.into(),
564            tensor,
565        });
566    }
567}
568
569#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct EngineModelInfo {
571    pub input_names: Vec<String>,
572    pub output_names: Vec<String>,
573    pub input_shapes: Vec<Vec<i64>>,
574    pub output_shapes: Vec<Vec<i64>>,
575    #[serde(default, skip_serializing_if = "Vec::is_empty")]
576    pub input_dtypes: Vec<String>,
577    #[serde(default, skip_serializing_if = "Vec::is_empty")]
578    pub output_dtypes: Vec<String>,
579    #[serde(default, skip_serializing_if = "Option::is_none")]
580    pub framework: Option<String>,
581    #[serde(default, skip_serializing_if = "Option::is_none")]
582    pub model_version: Option<String>,
583    #[serde(default, skip_serializing_if = "Option::is_none")]
584    pub peak_concurrency: Option<u32>,
585}
586
587pub type EngineStream = Pin<Box<dyn Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>>;
588
589#[async_trait]
590pub trait Engine: Send + Sync {
591    /// Load model weights and prepare runtime state.
592    async fn load(&mut self, model_path: &std::path::Path) -> Result<(), EngineError>;
593
594    /// Run a single inference request and return the output tensor.
595    fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError>;
596
597    /// Run a single inference request asynchronously.
598    async fn infer_async(
599        &self,
600        request: &InferenceRequest,
601    ) -> Result<BinaryTensorPacket, EngineError> {
602        self.infer(request)
603    }
604
605    /// Run a batch of inference requests.
606    fn infer_batch(
607        &self,
608        requests: &[InferenceRequest],
609    ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
610        requests.iter().map(|req| self.infer(req)).collect()
611    }
612
613    /// Run a batch of inference requests asynchronously.
614    async fn infer_batch_async(
615        &self,
616        requests: &[InferenceRequest],
617    ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
618        self.infer_batch(requests)
619    }
620
621    /// Run a streaming inference request.
622    fn infer_stream(&self, request: &InferenceRequest) -> EngineStream;
623
624    /// Run inference with cancellation support.
625    fn infer_with_cancellation(
626        &self,
627        request: &InferenceRequest,
628        cancellation: &CancellationToken,
629    ) -> Result<BinaryTensorPacket, EngineError> {
630        if cancellation.is_cancelled() {
631            return Err(EngineError::Cancelled {
632                message: "Request cancelled".to_string(),
633            });
634        }
635        let result = self.infer(request);
636        if cancellation.is_cancelled() {
637            return Err(EngineError::Cancelled {
638                message: "Request cancelled".to_string(),
639            });
640        }
641        result
642    }
643
644    /// Warm up the model runtime before serving requests.
645    async fn warmup(&self) -> Result<(), EngineError> {
646        Ok(())
647    }
648
649    /// Release any held resources.
650    fn unload(&mut self);
651
652    /// Report the latest metrics snapshot.
653    fn metrics(&self) -> EngineMetrics;
654
655    /// Report model metadata when available.
656    fn model_info(&self) -> Option<EngineModelInfo> {
657        None
658    }
659
660    /// Check if the model is healthy.
661    fn health_check(&self) -> Result<(), EngineError>;
662}
663
664#[async_trait]
665impl Engine for Box<dyn Engine> {
666    async fn load(&mut self, model_path: &std::path::Path) -> Result<(), EngineError> {
667        (**self).load(model_path).await
668    }
669
670    fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
671        (**self).infer(request)
672    }
673
674    async fn infer_async(
675        &self,
676        request: &InferenceRequest,
677    ) -> Result<BinaryTensorPacket, EngineError> {
678        (**self).infer_async(request).await
679    }
680
681    fn infer_batch(
682        &self,
683        requests: &[InferenceRequest],
684    ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
685        (**self).infer_batch(requests)
686    }
687
688    async fn infer_batch_async(
689        &self,
690        requests: &[InferenceRequest],
691    ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
692        (**self).infer_batch_async(requests).await
693    }
694
695    fn infer_stream(&self, request: &InferenceRequest) -> EngineStream {
696        (**self).infer_stream(request)
697    }
698
699    fn infer_with_cancellation(
700        &self,
701        request: &InferenceRequest,
702        cancellation: &CancellationToken,
703    ) -> Result<BinaryTensorPacket, EngineError> {
704        (**self).infer_with_cancellation(request, cancellation)
705    }
706
707    async fn warmup(&self) -> Result<(), EngineError> {
708        (**self).warmup().await
709    }
710
711    fn unload(&mut self) {
712        (**self).unload()
713    }
714
715    fn metrics(&self) -> EngineMetrics {
716        (**self).metrics()
717    }
718
719    fn model_info(&self) -> Option<EngineModelInfo> {
720        (**self).model_info()
721    }
722
723    fn health_check(&self) -> Result<(), EngineError> {
724        (**self).health_check()
725    }
726}
727
728pub type EngineHandle = Arc<dyn Engine>;
729
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a JSON value into a packet through the custom `Deserialize` impl.
    fn decode(payload: serde_json::Value) -> Result<BinaryTensorPacket, serde_json::Error> {
        serde_json::from_value(payload)
    }

    #[test]
    fn binary_tensor_packet_deserializes_from_data_array() {
        let packet = decode(serde_json::json!({
            "shape": [1, 2],
            "dtype": "uint8",
            "data": [1, 2]
        }))
        .expect("packet should deserialize");

        assert_eq!(packet.shape, vec![1, 2]);
        assert_eq!(packet.dtype, TensorDtype::Uint8);
        assert_eq!(packet.data, vec![1, 2]);
    }

    #[test]
    fn binary_tensor_packet_deserializes_from_data_base64() {
        let packet = decode(serde_json::json!({
            "shape": [1, 4],
            "dtype": "uint8",
            "data_base64": "AQIDBA=="
        }))
        .expect("packet should deserialize");

        assert_eq!(packet.shape, vec![1, 4]);
        assert_eq!(packet.dtype, TensorDtype::Uint8);
        assert_eq!(packet.data, vec![1, 2, 3, 4]);
    }

    #[test]
    fn binary_tensor_packet_deserializes_from_base64_alias() {
        let packet = decode(serde_json::json!({
            "shape": [1, 3],
            "dtype": "uint8",
            "base64": "AQID"
        }))
        .expect("packet should deserialize");

        assert_eq!(packet.shape, vec![1, 3]);
        assert_eq!(packet.dtype, TensorDtype::Uint8);
        assert_eq!(packet.data, vec![1, 2, 3]);
    }

    #[test]
    fn binary_tensor_packet_rejects_both_data_and_data_base64() {
        let failure = decode(serde_json::json!({
            "shape": [1],
            "dtype": "uint8",
            "data": [1],
            "data_base64": "AQ=="
        }))
        .expect_err("packet should fail deserialization");

        let rendered = failure.to_string();
        assert!(rendered.contains("only one of `data` or `data_base64`"));
    }
}