1use async_trait::async_trait;
2use base64::Engine as _;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::borrow::Cow;
6use std::fmt;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12use thiserror::Error;
13
/// Error type shared by the tensor helpers and the [`Engine`] trait in this file.
///
/// `Display` and `std::error::Error` are generated by `thiserror`; each
/// variant's `#[error(...)]` attribute is its display text. Variants with a
/// `source` field report it through `std::error::Error::source`.
#[derive(Error, Debug)]
pub enum EngineError {
    /// Displays as `Backend error: {message}`.
    #[error("Backend error: {message}")]
    Backend {
        /// Text interpolated into the display message.
        message: String,
        /// Optional underlying cause.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// Displays as `Invalid input: {message}`.
    ///
    /// In this file it is returned for unsupported dtype names, non-positive
    /// shape dimensions, element-count overflow and data length mismatches.
    #[error("Invalid input: {message}")]
    InvalidInput {
        /// Text interpolated into the display message.
        message: String,
        /// Optional underlying cause.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// Displays as `Model not loaded`; carries no further detail.
    #[error("Model not loaded")]
    ModelNotLoaded,
    /// Displays as `System overloaded: {message}`; see [`EngineError::is_overloaded`].
    #[error("System overloaded: {message}")]
    Overloaded {
        /// Text interpolated into the display message.
        message: String,
        /// Optional underlying cause.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// Displays as `Model load error for {path}: {source}`.
    ///
    /// Unlike the other variants the cause is mandatory and is also part of
    /// the display text. `thiserror` treats a field named `source` as the
    /// error source even without an explicit `#[source]` attribute.
    #[error("Model load error for {path}: {source}")]
    ModelLoadError {
        /// Path interpolated into the display message.
        path: String,
        /// Underlying cause of the load failure.
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// Displays as `Inference error: {reason}`.
    #[error("Inference error: {reason}")]
    InferenceError {
        /// Text interpolated into the display message.
        reason: String,
        /// Optional underlying cause.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// Displays as `Timeout: {message}`.
    #[error("Timeout: {message}")]
    TimeoutError {
        /// Text interpolated into the display message.
        message: String,
        /// Optional underlying cause.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// Displays as `Resource exhausted: {message}`.
    #[error("Resource exhausted: {message}")]
    ResourceExhausted {
        /// Text interpolated into the display message.
        message: String,
        /// Optional underlying cause.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// Displays as `Cancelled: {message}`.
    ///
    /// Returned by the default `Engine::infer_with_cancellation` when its
    /// token is cancelled.
    #[error("Cancelled: {message}")]
    Cancelled { message: String },
}
62
63impl EngineError {
64 pub fn backend(message: impl Into<String>) -> Self {
65 EngineError::Backend {
66 message: message.into(),
67 source: None,
68 }
69 }
70
71 pub fn backend_with_source(
72 message: impl Into<String>,
73 source: impl std::error::Error + Send + Sync + 'static,
74 ) -> Self {
75 EngineError::Backend {
76 message: message.into(),
77 source: Some(Box::new(source)),
78 }
79 }
80
81 pub fn invalid_input(message: impl Into<String>) -> Self {
82 EngineError::InvalidInput {
83 message: message.into(),
84 source: None,
85 }
86 }
87
88 pub fn invalid_input_with_source(
89 message: impl Into<String>,
90 source: impl std::error::Error + Send + Sync + 'static,
91 ) -> Self {
92 EngineError::InvalidInput {
93 message: message.into(),
94 source: Some(Box::new(source)),
95 }
96 }
97
98 pub fn overloaded(message: impl Into<String>) -> Self {
99 EngineError::Overloaded {
100 message: message.into(),
101 source: None,
102 }
103 }
104
105 pub fn is_overloaded(&self) -> bool {
106 matches!(self, EngineError::Overloaded { .. })
107 }
108
109 pub fn timeout(message: impl Into<String>) -> Self {
110 EngineError::TimeoutError {
111 message: message.into(),
112 source: None,
113 }
114 }
115
116 pub fn resource_exhausted(message: impl Into<String>) -> Self {
117 EngineError::ResourceExhausted {
118 message: message.into(),
119 source: None,
120 }
121 }
122
123 pub fn cancelled(message: impl Into<String>) -> Self {
124 EngineError::Cancelled {
125 message: message.into(),
126 }
127 }
128}
129
/// Snapshot of engine runtime metrics, serializable with serde.
///
/// This file only establishes the unit of `collected_at_ms`; the other
/// field notes are inferred from the field names and should be confirmed
/// against whichever engine fills them in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineMetrics {
    /// Inference duration — unit not specified here (TODO confirm).
    pub inference_time: f64,
    /// Memory in use — presumably bytes (TODO confirm).
    pub memory_usage: usize,
    /// GPU utilization — scale (0–1 vs. 0–100) not specified here.
    pub gpu_utilization: f64,
    /// Throughput — unit not specified here.
    pub throughput: f64,
    /// Batch size associated with this snapshot.
    pub batch_size: usize,
    /// Number of queued items at collection time (presumably requests).
    pub queue_depth: usize,
    /// Error rate — scale not specified here.
    pub error_rate: f64,
    /// Wall-clock collection time in milliseconds since the Unix epoch;
    /// set by `new` and `refresh_timestamp`.
    pub collected_at_ms: u64,
    /// KV-cache bytes currently in use.
    pub kv_cache_bytes_used: usize,
    /// KV-cache capacity in bytes.
    pub kv_cache_bytes_capacity: usize,
    /// Total number of KV-cache blocks.
    pub kv_cache_blocks_total: usize,
    /// Number of free KV-cache blocks.
    pub kv_cache_blocks_free: usize,
    /// Number of sequences tracked by the KV cache.
    pub kv_cache_sequences: usize,
    /// Count of evicted KV-cache blocks (presumably cumulative).
    pub kv_cache_evicted_blocks: u64,
    /// Count of evicted KV-cache sequences (presumably cumulative).
    pub kv_cache_evicted_sequences: u64,
    /// Number of packed KV-cache layers — meaning defined by the engine.
    pub kv_cache_packed_layers: usize,
}
149
150impl EngineMetrics {
151 pub fn new() -> Self {
152 Self {
153 inference_time: 0.0,
154 memory_usage: 0,
155 gpu_utilization: 0.0,
156 throughput: 0.0,
157 batch_size: 0,
158 queue_depth: 0,
159 error_rate: 0.0,
160 collected_at_ms: Self::now_ms(),
161 kv_cache_bytes_used: 0,
162 kv_cache_bytes_capacity: 0,
163 kv_cache_blocks_total: 0,
164 kv_cache_blocks_free: 0,
165 kv_cache_sequences: 0,
166 kv_cache_evicted_blocks: 0,
167 kv_cache_evicted_sequences: 0,
168 kv_cache_packed_layers: 0,
169 }
170 }
171
172 pub fn refresh_timestamp(&mut self) {
173 self.collected_at_ms = Self::now_ms();
174 }
175
176 fn now_ms() -> u64 {
177 SystemTime::now()
178 .duration_since(UNIX_EPOCH)
179 .unwrap_or_default()
180 .as_millis() as u64
181 }
182}
183
184impl Default for EngineMetrics {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
/// Element type of a tensor payload.
///
/// The canonical text name of each variant comes from `as_str`, and its
/// per-element byte width from `size_bytes`. Serde support is implemented
/// by hand in this file and goes through those text names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorDtype {
    /// `"float32"`, 4 bytes per element.
    Float32,
    /// `"float64"`, 8 bytes per element.
    Float64,
    /// `"float16"`, 2 bytes per element.
    Float16,
    /// `"int32"`, 4 bytes per element.
    Int32,
    /// `"int64"`, 8 bytes per element.
    Int64,
    /// `"uint8"`, 1 byte per element.
    Uint8,
    /// `"string"`; `size_bytes` reports 1 byte per element for this variant.
    Utf8,
}
200
201impl TensorDtype {
202 pub fn as_str(&self) -> &'static str {
203 match self {
204 TensorDtype::Float32 => "float32",
205 TensorDtype::Float64 => "float64",
206 TensorDtype::Float16 => "float16",
207 TensorDtype::Int32 => "int32",
208 TensorDtype::Int64 => "int64",
209 TensorDtype::Uint8 => "uint8",
210 TensorDtype::Utf8 => "string",
211 }
212 }
213
214 pub fn size_bytes(&self) -> usize {
215 match self {
216 TensorDtype::Float32 => 4,
217 TensorDtype::Float64 => 8,
218 TensorDtype::Float16 => 2,
219 TensorDtype::Int32 => 4,
220 TensorDtype::Int64 => 8,
221 TensorDtype::Uint8 => 1,
222 TensorDtype::Utf8 => 1,
223 }
224 }
225}
226
227impl fmt::Display for TensorDtype {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 write!(f, "{}", self.as_str())
230 }
231}
232
233impl FromStr for TensorDtype {
234 type Err = EngineError;
235
236 fn from_str(value: &str) -> Result<Self, Self::Err> {
237 match value.to_lowercase().as_str() {
238 "float32" | "fp32" => Ok(TensorDtype::Float32),
239 "float64" | "fp64" => Ok(TensorDtype::Float64),
240 "float16" | "fp16" => Ok(TensorDtype::Float16),
241 "int32" | "i32" => Ok(TensorDtype::Int32),
242 "int64" | "i64" => Ok(TensorDtype::Int64),
243 "uint8" | "u8" => Ok(TensorDtype::Uint8),
244 "string" | "utf8" => Ok(TensorDtype::Utf8),
245 other => Err(EngineError::InvalidInput {
246 message: format!("Unsupported dtype: {}", other),
247 source: None,
248 }),
249 }
250 }
251}
252
253impl Serialize for TensorDtype {
254 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
255 where
256 S: serde::Serializer,
257 {
258 serializer.serialize_str(self.as_str())
259 }
260}
261
262impl<'de> Deserialize<'de> for TensorDtype {
263 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
264 where
265 D: serde::Deserializer<'de>,
266 {
267 let value = String::deserialize(deserializer)?;
268 TensorDtype::from_str(&value).map_err(serde::de::Error::custom)
269 }
270}
271
/// Owned tensor payload: a shape, an element type and the raw bytes.
///
/// The fields are public, so a packet built with a struct literal (or via
/// deserialization) is not guaranteed to satisfy `validate`; only
/// `BinaryTensorPacket::new` checks consistency on construction.
#[derive(Debug, Clone)]
pub struct BinaryTensorPacket {
    /// Tensor dimensions; `validate` requires every entry to be positive.
    pub shape: Vec<i64>,
    /// Element type of `data`.
    pub dtype: TensorDtype,
    /// Raw tensor bytes; `validate` requires
    /// `len == product(shape) * dtype.size_bytes()`.
    pub data: Vec<u8>,
}
278
279impl serde::Serialize for BinaryTensorPacket {
280 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
281 use serde::ser::SerializeStruct;
282 let mut state = serializer.serialize_struct("BinaryTensorPacket", 4)?;
286 state.serialize_field("shape", &self.shape)?;
287 state.serialize_field("dtype", &self.dtype)?;
288 state.serialize_field("data", &Some(&self.data))?;
289 state.serialize_field("data_base64", &None::<&str>)?;
290 state.end()
291 }
292}
293
294impl<'de> Deserialize<'de> for BinaryTensorPacket {
295 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
296 where
297 D: serde::Deserializer<'de>,
298 {
299 #[derive(Deserialize)]
306 struct BinaryTensorPacketPayload<'src> {
307 shape: Vec<i64>,
308 dtype: TensorDtype,
309 #[serde(default)]
310 data: Option<Vec<u8>>,
311 #[serde(default, alias = "base64", borrow)]
312 data_base64: Option<Cow<'src, str>>,
313 }
314
315 let payload = BinaryTensorPacketPayload::deserialize(deserializer)?;
316 let data = match (payload.data, payload.data_base64) {
317 (Some(data), None) => data,
318 (None, Some(encoded)) => base64::engine::general_purpose::STANDARD
319 .decode(encoded.as_bytes())
320 .map_err(serde::de::Error::custom)?,
321 (Some(_), Some(_)) => {
322 return Err(serde::de::Error::custom(
323 "binary tensor payload must include only one of `data` or `data_base64`",
324 ))
325 }
326 (None, None) => {
327 return Err(serde::de::Error::custom(
328 "binary tensor payload must include `data` or `data_base64`",
329 ))
330 }
331 };
332
333 Ok(Self {
334 shape: payload.shape,
335 dtype: payload.dtype,
336 data,
337 })
338 }
339}
340
/// Variant of [`BinaryTensorPacket`] whose bytes may be borrowed.
///
/// `shape` is still owned; only `data` is a `Cow`, so converting from
/// `&BinaryTensorPacket` clones the shape but not the bytes.
///
/// Unlike `BinaryTensorPacket`, serde support here is derived, so there is
/// no `data_base64` field on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryTensorPacketRef<'a> {
    /// Tensor dimensions (owned).
    pub shape: Vec<i64>,
    /// Element type of `data`.
    pub dtype: TensorDtype,
    /// Tensor bytes, borrowed or owned.
    // NOTE(review): with `#[serde(borrow)]` on a `Cow<[u8]>` serde is expected
    // to deserialize through its byte-string path rather than as a sequence —
    // confirm that formats in use (e.g. JSON arrays produced by the derived
    // `Serialize`) round-trip.
    #[serde(borrow)]
    pub data: Cow<'a, [u8]>,
}
348
/// Cheap, copyable, fully borrowed view of a tensor.
///
/// Produced by `BinaryTensorPacket::view`; it borrows both the shape and
/// the bytes from the packet and owns nothing.
#[derive(Debug, Clone, Copy)]
pub struct TensorView<'a> {
    /// Borrowed tensor dimensions.
    pub shape: &'a [i64],
    /// Element type of `data`.
    pub dtype: TensorDtype,
    /// Borrowed tensor bytes.
    pub data: &'a [u8],
}
355
356impl BinaryTensorPacket {
357 pub fn new(shape: Vec<i64>, dtype: TensorDtype, data: Vec<u8>) -> Result<Self, EngineError> {
358 let packet = Self { shape, dtype, data };
359 packet.validate()?;
360 Ok(packet)
361 }
362
363 pub fn size_bytes(&self) -> usize {
364 self.data.len()
365 }
366
367 pub fn tensor_elements(&self) -> Result<usize, EngineError> {
368 shape_elements(&self.shape)
369 }
370
371 pub fn tensor_elements_cached(&self, cache: &mut Option<usize>) -> Result<usize, EngineError> {
372 if let Some(value) = *cache {
373 return Ok(value);
374 }
375 let value = self.tensor_elements()?;
376 *cache = Some(value);
377 Ok(value)
378 }
379
380 pub fn validate(&self) -> Result<(), EngineError> {
381 let elements = self.tensor_elements()?;
382 let expected = elements
383 .checked_mul(self.dtype.size_bytes())
384 .ok_or_else(|| EngineError::InvalidInput {
385 message: "Data size overflow".to_string(),
386 source: None,
387 })?;
388
389 if self.data.len() != expected {
390 return Err(EngineError::InvalidInput {
391 message: format!(
392 "Data length mismatch: expected {} bytes ({} {} values) but got {} bytes",
393 expected,
394 elements,
395 self.dtype,
396 self.data.len()
397 ),
398 source: None,
399 });
400 }
401
402 Ok(())
403 }
404
405 pub fn view(&self) -> TensorView<'_> {
406 TensorView {
407 shape: &self.shape,
408 dtype: self.dtype,
409 data: &self.data,
410 }
411 }
412
413 pub fn as_borrowed(&self) -> BinaryTensorPacketRef<'_> {
414 BinaryTensorPacketRef::from(self)
415 }
416}
417
418impl<'a> BinaryTensorPacketRef<'a> {
419 pub fn to_owned(self) -> BinaryTensorPacket {
420 BinaryTensorPacket {
421 shape: self.shape,
422 dtype: self.dtype,
423 data: self.data.into_owned(),
424 }
425 }
426}
427
428impl<'a> From<&'a BinaryTensorPacket> for BinaryTensorPacketRef<'a> {
429 fn from(packet: &'a BinaryTensorPacket) -> Self {
430 Self {
431 shape: packet.shape.clone(),
432 dtype: packet.dtype,
433 data: Cow::Borrowed(&packet.data),
434 }
435 }
436}
437
438fn shape_elements(shape: &[i64]) -> Result<usize, EngineError> {
439 if shape.is_empty() {
440 return Ok(1);
441 }
442
443 let mut prod: usize = 1;
444 for &dim in shape {
445 if dim <= 0 {
446 return Err(EngineError::InvalidInput {
447 message: format!("Invalid shape dimension: {}", dim),
448 source: None,
449 });
450 }
451 prod = prod
452 .checked_mul(dim as usize)
453 .ok_or_else(|| EngineError::InvalidInput {
454 message: "Shape multiplication overflow".to_string(),
455 source: None,
456 })?;
457 }
458
459 Ok(prod)
460}
461
/// Optional per-request settings attached to an [`InferenceRequest`].
///
/// Every field is optional and deserializes to `None` when absent. Nothing
/// in the visible part of this file reads these values, so how they are
/// interpreted is up to the engine implementation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RequestMetadata {
    /// Caller-supplied request identifier; set by
    /// `InferenceRequest::with_request_id`.
    #[serde(default)]
    pub request_id: Option<String>,
    /// Request timeout in milliseconds (per the field name).
    #[serde(default)]
    pub timeout_ms: Option<u64>,
    /// Request priority — ordering (higher vs. lower first) is not defined here.
    #[serde(default)]
    pub priority: Option<u8>,
    /// Presumably asks the engine to run on CPU — confirm against the engine.
    #[serde(default)]
    pub force_cpu: Option<bool>,
    /// Requested model version.
    #[serde(default)]
    pub model_version: Option<String>,
    /// Authentication token; omitted from serialized output when `None`.
    // NOTE(review): the derived `Debug` prints this value verbatim — consider
    // redacting it before logging requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth_token: Option<String>,

    // Generation / sampling parameters (semantics are engine-defined).
    /// Also accepted on input under the name `max_tokens`.
    #[serde(default, alias = "max_tokens")]
    pub max_new_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling parameter, going by the field name.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Top-k sampling parameter, going by the field name.
    #[serde(default)]
    pub top_k: Option<u32>,
    /// Repetition penalty.
    #[serde(default)]
    pub repetition_penalty: Option<f32>,
    /// Random seed.
    #[serde(default)]
    pub seed: Option<u64>,
    /// Token ids that stop generation; also accepted on input as `stop_ids`.
    #[serde(default, alias = "stop_ids")]
    pub stop_token_ids: Option<Vec<u32>>,
}
493
/// A tensor paired with a name; used for `InferenceRequest::additional_inputs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NamedTensor {
    /// Name of the input this tensor is supplied for.
    pub name: String,
    /// The tensor payload.
    pub tensor: BinaryTensorPacket,
}
499
/// Shared cancellation flag.
///
/// The flag lives behind an `Arc`, so clones of a token observe and set the
/// same flag. The derived `Default` yields a token that is not cancelled.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    /// `true` once `cancel` has been called on this token or any clone of it.
    cancelled: Arc<AtomicBool>,
}
504
505impl CancellationToken {
506 pub fn new() -> Self {
507 Self {
508 cancelled: Arc::new(AtomicBool::new(false)),
509 }
510 }
511
512 pub fn cancel(&self) {
513 self.cancelled.store(true, Ordering::SeqCst);
514 }
515
516 pub fn is_cancelled(&self) -> bool {
517 self.cancelled.load(Ordering::SeqCst)
518 }
519}
520
/// A single inference request: one primary tensor plus optional extras.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Primary input tensor (required).
    pub input: BinaryTensorPacket,
    /// Further named inputs; defaults to empty when absent on the wire.
    #[serde(default)]
    pub additional_inputs: Vec<NamedTensor>,
    /// Optional session identifier.
    #[serde(default)]
    pub session_id: Option<String>,
    /// Optional per-request settings.
    #[serde(default)]
    pub metadata: Option<RequestMetadata>,
    /// Optional cancellation token. It is never serialized and is always
    /// `None` after deserialization.
    #[serde(skip, default)]
    pub cancellation: Option<CancellationToken>,
}
533
534impl InferenceRequest {
535 pub fn new(input: BinaryTensorPacket) -> Self {
536 Self {
537 input,
538 additional_inputs: Vec::new(),
539 session_id: None,
540 metadata: None,
541 cancellation: None,
542 }
543 }
544
545 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
546 self.session_id = Some(session_id.into());
547 self
548 }
549
550 pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
551 self.metadata = Some(metadata);
552 self
553 }
554
555 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
556 let metadata = self.metadata.get_or_insert_with(RequestMetadata::default);
557 metadata.request_id = Some(request_id.into());
558 self
559 }
560
561 pub fn add_input(&mut self, name: impl Into<String>, tensor: BinaryTensorPacket) {
562 self.additional_inputs.push(NamedTensor {
563 name: name.into(),
564 tensor,
565 });
566 }
567}
568
/// Description of a loaded model, as returned by `Engine::model_info`.
///
/// The vectors are presumably index-aligned (entry `i` of `input_names`
/// matching entry `i` of `input_shapes`, and so on) — nothing in this file
/// enforces that.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineModelInfo {
    /// Names of the model inputs.
    pub input_names: Vec<String>,
    /// Names of the model outputs.
    pub output_names: Vec<String>,
    /// Shape of each input.
    pub input_shapes: Vec<Vec<i64>>,
    /// Shape of each output.
    pub output_shapes: Vec<Vec<i64>>,
    /// Input dtypes as free-form strings (not `TensorDtype`); omitted from
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_dtypes: Vec<String>,
    /// Output dtypes as free-form strings; omitted from serialized output
    /// when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub output_dtypes: Vec<String>,
    /// Framework name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub framework: Option<String>,
    /// Model version; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// Peak concurrency — exact meaning is engine-defined; omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub peak_concurrency: Option<u32>,
}
586
/// Boxed, pinned, `Send` stream of inference results, as returned by
/// `Engine::infer_stream`. Each item is either a tensor packet or an error.
pub type EngineStream = Pin<Box<dyn Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>>;
588
589#[async_trait]
590pub trait Engine: Send + Sync {
591 async fn load(&mut self, model_path: &std::path::Path) -> Result<(), EngineError>;
593
594 fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError>;
596
597 async fn infer_async(
599 &self,
600 request: &InferenceRequest,
601 ) -> Result<BinaryTensorPacket, EngineError> {
602 self.infer(request)
603 }
604
605 fn infer_batch(
607 &self,
608 requests: &[InferenceRequest],
609 ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
610 requests.iter().map(|req| self.infer(req)).collect()
611 }
612
613 async fn infer_batch_async(
615 &self,
616 requests: &[InferenceRequest],
617 ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
618 self.infer_batch(requests)
619 }
620
621 fn infer_stream(&self, request: &InferenceRequest) -> EngineStream;
623
624 fn infer_with_cancellation(
626 &self,
627 request: &InferenceRequest,
628 cancellation: &CancellationToken,
629 ) -> Result<BinaryTensorPacket, EngineError> {
630 if cancellation.is_cancelled() {
631 return Err(EngineError::Cancelled {
632 message: "Request cancelled".to_string(),
633 });
634 }
635 let result = self.infer(request);
636 if cancellation.is_cancelled() {
637 return Err(EngineError::Cancelled {
638 message: "Request cancelled".to_string(),
639 });
640 }
641 result
642 }
643
644 async fn warmup(&self) -> Result<(), EngineError> {
646 Ok(())
647 }
648
649 fn unload(&mut self);
651
652 fn metrics(&self) -> EngineMetrics;
654
655 fn model_info(&self) -> Option<EngineModelInfo> {
657 None
658 }
659
660 fn health_check(&self) -> Result<(), EngineError>;
662}
663
664#[async_trait]
665impl Engine for Box<dyn Engine> {
666 async fn load(&mut self, model_path: &std::path::Path) -> Result<(), EngineError> {
667 (**self).load(model_path).await
668 }
669
670 fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
671 (**self).infer(request)
672 }
673
674 async fn infer_async(
675 &self,
676 request: &InferenceRequest,
677 ) -> Result<BinaryTensorPacket, EngineError> {
678 (**self).infer_async(request).await
679 }
680
681 fn infer_batch(
682 &self,
683 requests: &[InferenceRequest],
684 ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
685 (**self).infer_batch(requests)
686 }
687
688 async fn infer_batch_async(
689 &self,
690 requests: &[InferenceRequest],
691 ) -> Result<Vec<BinaryTensorPacket>, EngineError> {
692 (**self).infer_batch_async(requests).await
693 }
694
695 fn infer_stream(&self, request: &InferenceRequest) -> EngineStream {
696 (**self).infer_stream(request)
697 }
698
699 fn infer_with_cancellation(
700 &self,
701 request: &InferenceRequest,
702 cancellation: &CancellationToken,
703 ) -> Result<BinaryTensorPacket, EngineError> {
704 (**self).infer_with_cancellation(request, cancellation)
705 }
706
707 async fn warmup(&self) -> Result<(), EngineError> {
708 (**self).warmup().await
709 }
710
711 fn unload(&mut self) {
712 (**self).unload()
713 }
714
715 fn metrics(&self) -> EngineMetrics {
716 (**self).metrics()
717 }
718
719 fn model_info(&self) -> Option<EngineModelInfo> {
720 (**self).model_info()
721 }
722
723 fn health_check(&self) -> Result<(), EngineError> {
724 (**self).health_check()
725 }
726}
727
/// Shared, reference-counted handle to an engine.
///
/// `Engine::load` and `Engine::unload` take `&mut self`, so they need
/// exclusive access to the engine rather than a shared handle.
pub type EngineHandle = Arc<dyn Engine>;
729
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a JSON value into a packet, returning serde's result unchanged.
    fn parse_packet(payload: serde_json::Value) -> serde_json::Result<BinaryTensorPacket> {
        serde_json::from_value(payload)
    }

    /// Asserts that `packet` is a `uint8` tensor with the given shape and bytes.
    fn assert_uint8_packet(packet: &BinaryTensorPacket, shape: &[i64], data: &[u8]) {
        assert_eq!(packet.shape, shape);
        assert_eq!(packet.dtype, TensorDtype::Uint8);
        assert_eq!(packet.data, data);
    }

    /// Inline bytes under `data` are taken as-is.
    #[test]
    fn binary_tensor_packet_deserializes_from_data_array() {
        let packet = parse_packet(serde_json::json!({
            "shape": [1, 2],
            "dtype": "uint8",
            "data": [1, 2]
        }))
        .expect("packet should deserialize");

        assert_uint8_packet(&packet, &[1, 2], &[1, 2]);
    }

    /// Base64 text under `data_base64` is decoded into bytes.
    #[test]
    fn binary_tensor_packet_deserializes_from_data_base64() {
        let packet = parse_packet(serde_json::json!({
            "shape": [1, 4],
            "dtype": "uint8",
            "data_base64": "AQIDBA=="
        }))
        .expect("packet should deserialize");

        assert_uint8_packet(&packet, &[1, 4], &[1, 2, 3, 4]);
    }

    /// The `base64` key is accepted as an alias for `data_base64`.
    #[test]
    fn binary_tensor_packet_deserializes_from_base64_alias() {
        let packet = parse_packet(serde_json::json!({
            "shape": [1, 3],
            "dtype": "uint8",
            "base64": "AQID"
        }))
        .expect("packet should deserialize");

        assert_uint8_packet(&packet, &[1, 3], &[1, 2, 3]);
    }

    /// Supplying both byte fields at once is rejected.
    #[test]
    fn binary_tensor_packet_rejects_both_data_and_data_base64() {
        let err = parse_packet(serde_json::json!({
            "shape": [1],
            "dtype": "uint8",
            "data": [1],
            "data_base64": "AQ=="
        }))
        .expect_err("packet should fail deserialization");

        let rendered = err.to_string();
        assert!(rendered.contains("only one of `data` or `data_base64`"));
    }
}