1use anyhow::{anyhow, bail, Result};
10use reflow_media_types::{PacketMetadata, TensorDType, TensorPacket, TensorShape};
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16#[cfg(feature = "external-litert")]
17pub use external_litert::LiteRtBackend;
18
/// Declared name, element type, and shape for a single model tensor.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TensorSpec {
    /// Tensor name as declared by the model manifest / signature.
    pub name: String,
    /// Element data type (f32, u8, ...).
    pub dtype: TensorDType,
    /// Dimension sizes of the tensor.
    pub shape: TensorShape,
}
26
27impl TensorSpec {
28 pub fn byte_len(&self) -> usize {
29 self.shape.element_count() * self.dtype.bytes_per_element()
30 }
31}
32
/// Descriptor for a model to load: which backend serves it, what task it
/// performs, and its (optionally) declared input/output tensor specs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelInfo {
    /// Model identifier; echoed back in run metadata as `modelId`.
    pub id: String,
    /// Name of the backend expected to serve this model (e.g. "mock", "litert").
    pub backend: String,
    /// Free-form task label (e.g. "landmark").
    pub task: String,
    /// Declared input specs; may be empty when the manifest omits them.
    #[serde(default)]
    pub inputs: Vec<TensorSpec>,
    /// Declared output specs; may be empty when the manifest omits them.
    #[serde(default)]
    pub outputs: Vec<TensorSpec>,
    /// Backend-specific options (e.g. the "accelerators" hint read by the
    /// LiteRT backend).
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
46
47impl ModelInfo {
48 pub fn mock(id: impl Into<String>, task: impl Into<String>, outputs: Vec<TensorSpec>) -> Self {
49 Self {
50 id: id.into(),
51 backend: "mock".to_string(),
52 task: task.into(),
53 inputs: Vec::new(),
54 outputs,
55 metadata: HashMap::new(),
56 }
57 }
58}
59
/// A named tensor handed to a session's `run` call; `name` identifies the
/// model input slot this tensor feeds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InferenceInput {
    /// Input slot name (matched against `TensorSpec::name` by the LiteRT backend).
    pub name: String,
    /// The tensor payload.
    pub tensor: TensorPacket,
}
66
/// Result of one inference run: output tensors plus backend-reported
/// diagnostic metadata (backend name, model id, acceleration info, ...).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InferenceOutput {
    /// Output tensors, in the model's output order.
    #[serde(default)]
    pub tensors: Vec<TensorPacket>,
    /// Free-form metadata emitted by the backend for this run.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
75
/// Factory for inference sessions. Implementations must be shareable across
/// threads (`Send + Sync`).
pub trait InferenceBackend: Send + Sync {
    /// Stable backend identifier (e.g. "mock", "litert").
    fn name(&self) -> &str;

    /// Loads `model` — optionally with its raw model bytes — and returns a
    /// session ready to run inference.
    fn load_model(
        &self,
        model: ModelInfo,
        model_data: Option<Arc<Vec<u8>>>,
    ) -> Result<Box<dyn InferenceSession>>;
}
85
/// A loaded model that can execute inference runs.
pub trait InferenceSession: Send + Sync {
    /// Descriptor of the model this session was created from.
    fn model_info(&self) -> &ModelInfo;

    /// Executes one inference over `inputs`, returning output tensors plus
    /// backend metadata.
    fn run(&self, inputs: &[InferenceInput]) -> Result<InferenceOutput>;
}
91
/// Backend that fabricates deterministic tensors instead of running a real
/// model; intended for tests and pipeline wiring checks.
#[derive(Debug, Clone, Default)]
pub struct MockBackend;

impl MockBackend {
    /// Creates a mock backend; equivalent to `MockBackend::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
100
101impl InferenceBackend for MockBackend {
102 fn name(&self) -> &str {
103 "mock"
104 }
105
106 fn load_model(
107 &self,
108 model: ModelInfo,
109 model_data: Option<Arc<Vec<u8>>>,
110 ) -> Result<Box<dyn InferenceSession>> {
111 if model.backend != "mock" && model.backend != "litert" {
112 bail!("mock backend cannot load backend '{}'", model.backend);
113 }
114 Ok(Box::new(MockSession {
115 model,
116 model_data_len: model_data.as_ref().map(|bytes| bytes.len()).unwrap_or(0),
117 }))
118 }
119}
120
/// Session produced by [`MockBackend`]: keeps the model descriptor and the
/// length of any supplied model bytes (surfaced as `modelBytes` metadata).
#[derive(Debug, Clone)]
struct MockSession {
    model: ModelInfo,
    // Number of model bytes handed to `load_model`, or 0 when none were given.
    model_data_len: usize,
}
126
127impl InferenceSession for MockSession {
128 fn model_info(&self) -> &ModelInfo {
129 &self.model
130 }
131
132 fn run(&self, inputs: &[InferenceInput]) -> Result<InferenceOutput> {
133 let outputs = if self.model.outputs.is_empty() {
134 infer_default_outputs(inputs)?
135 } else {
136 self.model
137 .outputs
138 .iter()
139 .map(|spec| deterministic_tensor(spec, &self.model, inputs))
140 .collect()
141 };
142
143 Ok(InferenceOutput {
144 tensors: outputs,
145 metadata: HashMap::from([
146 ("backend".to_string(), json!("mock")),
147 ("modelId".to_string(), json!(self.model.id)),
148 ("modelBytes".to_string(), json!(self.model_data_len)),
149 ("inputCount".to_string(), json!(inputs.len())),
150 ]),
151 })
152 }
153}
154
#[cfg(feature = "external-litert")]
mod external_litert {
    //! Real LiteRT-backed implementation of [`InferenceBackend`], compiled in
    //! only when the `external-litert` feature is enabled.

    use super::*;
    use litert::{
        Accelerators, CompilationOptions, CompiledModel, ElementType, Environment, Model,
        TensorBuffer,
    };
    use std::sync::Mutex;

    /// Backend that compiles and runs models through the `litert` crate.
    #[derive(Debug, Clone)]
    pub struct LiteRtBackend {
        // Default accelerator set, used when the model metadata carries no
        // "accelerators" hint (see `accelerators_from_model`).
        accelerators: Accelerators,
    }

    impl LiteRtBackend {
        /// CPU-only backend.
        pub fn new() -> Self {
            Self {
                accelerators: Accelerators::CPU,
            }
        }

        /// Backend with an explicit default accelerator set.
        pub fn with_accelerators(accelerators: Accelerators) -> Self {
            Self { accelerators }
        }
    }

    impl Default for LiteRtBackend {
        fn default() -> Self {
            Self::new()
        }
    }

    impl InferenceBackend for LiteRtBackend {
        fn name(&self) -> &str {
            "litert"
        }

        /// Compiles `model_data` into a LiteRT session:
        /// 1. introspects signature 0 for input/output shapes,
        /// 2. cross-checks them against the manifest's declared specs,
        /// 3. pre-allocates host tensor buffers that are reused on every run.
        ///
        /// Errors if the manifest targets a different backend, if no model
        /// bytes were supplied, or on any LiteRT API failure.
        fn load_model(
            &self,
            model: ModelInfo,
            model_data: Option<Arc<Vec<u8>>>,
        ) -> Result<Box<dyn InferenceSession>> {
            if model.backend != "litert" {
                bail!("LiteRT backend cannot load backend '{}'", model.backend);
            }
            let model_data =
                model_data.ok_or_else(|| anyhow!("LiteRT backend requires model_data bytes"))?;
            let env = Environment::new().map_err(|err| anyhow!("LiteRT environment: {err}"))?;
            // NOTE(review): `Model::from_bytes` appears to take ownership of the
            // bytes, hence the clone out of the shared Arc.
            let litert_model = Model::from_bytes(model_data.as_ref().clone().into_boxed_slice())
                .map_err(|err| anyhow!("LiteRT model load: {err}"))?;
            // Only the first (index 0) signature is supported.
            let signature = litert_model
                .signature(0)
                .map_err(|err| anyhow!("LiteRT signature 0: {err}"))?;
            let input_shapes = (0..signature.input_count()?)
                .map(|index| signature.input_shape(index))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT input shape introspection: {err}"))?;
            let output_shapes = (0..signature.output_count()?)
                .map(|index| signature.output_shape(index))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT output shape introspection: {err}"))?;
            // Manifest specs (when present) must agree with the model itself.
            validate_specs("input", &model.inputs, &input_shapes)?;
            validate_specs("output", &model.outputs, &output_shapes)?;
            let accelerators = accelerators_from_model(&model, self.accelerators)?;
            let options = CompilationOptions::new()
                .and_then(|options| options.with_accelerators(accelerators))
                .map_err(|err| anyhow!("LiteRT compilation options: {err}"))?;
            // Host-side staging buffers, one per model input/output.
            let mut input_buffers = input_shapes
                .iter()
                .map(|shape| TensorBuffer::managed_host(&env, shape))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT input buffer allocation: {err}"))?;
            let mut output_buffers = output_shapes
                .iter()
                .map(|shape| TensorBuffer::managed_host(&env, shape))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT output buffer allocation: {err}"))?;
            let compiled = CompiledModel::new(env, litert_model, &options)
                .map_err(|err| anyhow!("LiteRT compile: {err}"))?;
            // Recorded once here and surfaced as `fullyAccelerated` metadata.
            let fully_accelerated = compiled
                .is_fully_accelerated()
                .map_err(|err| anyhow!("LiteRT acceleration query: {err}"))?;

            input_buffers.shrink_to_fit();
            output_buffers.shrink_to_fit();

            Ok(Box::new(LiteRtSession {
                model,
                state: Mutex::new(LiteRtSessionState {
                    compiled,
                    input_buffers,
                    output_buffers,
                    input_shapes,
                    output_shapes,
                    accelerators,
                    fully_accelerated,
                }),
            }))
        }
    }

    /// Session over a compiled LiteRT model. All mutable runtime state sits
    /// behind a `Mutex`, so concurrent `run` calls are serialized.
    struct LiteRtSession {
        model: ModelInfo,
        state: Mutex<LiteRtSessionState>,
    }

    /// Mutable per-session state: the compiled model plus reusable I/O
    /// buffers and the shapes they were allocated for.
    struct LiteRtSessionState {
        compiled: CompiledModel,
        input_buffers: Vec<TensorBuffer>,
        output_buffers: Vec<TensorBuffer>,
        input_shapes: Vec<litert::TensorShape>,
        output_shapes: Vec<litert::TensorShape>,
        accelerators: Accelerators,
        fully_accelerated: bool,
    }

    impl InferenceSession for LiteRtSession {
        fn model_info(&self) -> &ModelInfo {
            &self.model
        }

        /// Copies each provided input into its staging buffer, invokes the
        /// compiled model, then reads every output buffer back into a
        /// `TensorPacket`. Input count must exactly match the model's.
        fn run(&self, inputs: &[InferenceInput]) -> Result<InferenceOutput> {
            let mut state = self
                .state
                .lock()
                .map_err(|_| anyhow!("LiteRT session state was poisoned"))?;
            if inputs.len() != state.input_shapes.len() {
                bail!(
                    "LiteRT input count mismatch: model expects {}, graph provided {}",
                    state.input_shapes.len(),
                    inputs.len()
                );
            }

            for index in 0..state.input_buffers.len() {
                let input = input_for_index(&self.model, inputs, index)?;
                // Shape cloned first so the mutable borrow of the buffer below
                // doesn't conflict with borrowing `state.input_shapes`.
                let shape = state.input_shapes[index].clone();
                let buffer = &mut state.input_buffers[index];
                write_tensor_to_buffer(&input.tensor, buffer, &shape)?;
            }

            {
                // Destructure to borrow the compiled model and both buffer
                // vectors simultaneously from the same guard.
                let LiteRtSessionState {
                    compiled,
                    input_buffers,
                    output_buffers,
                    ..
                } = &mut *state;
                compiled
                    .run(input_buffers, output_buffers)
                    .map_err(|err| anyhow!("LiteRT inference: {err}"))?;
            }

            let mut tensors = Vec::with_capacity(state.output_buffers.len());
            for index in 0..state.output_buffers.len() {
                let buffer = &state.output_buffers[index];
                let shape = state.output_shapes[index].clone();
                // Use the manifest's output name when declared, otherwise a
                // positional fallback ("output_0", "output_1", ...).
                let name = self
                    .model
                    .outputs
                    .get(index)
                    .map(|spec| spec.name.clone())
                    .unwrap_or_else(|| format!("output_{index}"));
                tensors.push(read_tensor_from_buffer(name, buffer, &shape)?);
            }
            let fully_accelerated = state.fully_accelerated;
            let accelerators = state.accelerators;

            Ok(InferenceOutput {
                tensors,
                metadata: HashMap::from([
                    ("backend".to_string(), json!("litert")),
                    ("modelId".to_string(), json!(self.model.id)),
                    ("inputCount".to_string(), json!(inputs.len())),
                    (
                        "accelerators".to_string(),
                        json!(accelerator_names(accelerators)),
                    ),
                    ("fullyAccelerated".to_string(), json!(fully_accelerated)),
                ]),
            })
        }
    }

    /// Resolves which provided input feeds model input slot `index`: prefers
    /// a name match against the manifest spec (input name or tensor name),
    /// then falls back to positional lookup.
    fn input_for_index<'a>(
        model: &ModelInfo,
        inputs: &'a [InferenceInput],
        index: usize,
    ) -> Result<&'a InferenceInput> {
        if let Some(spec) = model.inputs.get(index) {
            if let Some(input) = inputs.iter().find(|input| {
                input.name == spec.name || input.tensor.name.as_deref() == Some(spec.name.as_str())
            }) {
                return Ok(input);
            }
        }
        inputs
            .get(index)
            .ok_or_else(|| anyhow!("missing LiteRT input tensor at index {index}"))
    }

    /// Validates dtype, dims, and byte length of `tensor` against the
    /// LiteRT-reported `shape`, then copies its elements into `buffer`.
    fn write_tensor_to_buffer(
        tensor: &TensorPacket,
        buffer: &mut TensorBuffer,
        shape: &litert::TensorShape,
    ) -> Result<()> {
        let expected_dtype = dtype_from_element_type(shape.element_type)?;
        if tensor.dtype != expected_dtype {
            bail!(
                "LiteRT tensor dtype mismatch for {:?}: expected {:?}, got {:?}",
                tensor.name,
                expected_dtype,
                tensor.dtype
            );
        }
        let expected_dims = dims_from_litert(shape)?;
        if tensor.shape.dims != expected_dims {
            bail!(
                "LiteRT tensor shape mismatch for {:?}: expected {:?}, got {:?}",
                tensor.name,
                expected_dims,
                tensor.shape.dims
            );
        }
        validate_tensor_byte_len(tensor)?;

        // Decode the packet's little-endian bytes into typed elements before
        // copying; u8 can be copied directly.
        match shape.element_type {
            ElementType::Float32 => copy_to_buffer::<f32>(buffer, &read_f32_values(tensor)?)?,
            ElementType::UInt8 => copy_to_buffer::<u8>(buffer, &tensor.data)?,
            ElementType::Int8 => copy_to_buffer::<i8>(buffer, &read_i8_values(tensor)?)?,
            ElementType::Int32 => copy_to_buffer::<i32>(buffer, &read_i32_values(tensor)?)?,
            ElementType::Int64 => copy_to_buffer::<i64>(buffer, &read_i64_values(tensor)?)?,
            ElementType::Bool => copy_to_buffer::<bool>(buffer, &read_bool_values(tensor)?)?,
            other => bail!("LiteRT input element type {:?} is not supported yet", other),
        }
        Ok(())
    }

    /// Reads a LiteRT output buffer into a `TensorPacket`, serializing
    /// multi-byte element types as little-endian bytes.
    fn read_tensor_from_buffer(
        name: String,
        buffer: &TensorBuffer,
        shape: &litert::TensorShape,
    ) -> Result<TensorPacket> {
        let dtype = dtype_from_element_type(shape.element_type)?;
        let dims = TensorShape::new(dims_from_litert(shape)?);
        let data = match shape.element_type {
            ElementType::Float32 => {
                let values = buffer.lock_for_read::<f32>()?;
                f32_values_to_bytes(&values)
            }
            ElementType::UInt8 => buffer.lock_for_read::<u8>()?.to_vec(),
            ElementType::Int8 => buffer
                .lock_for_read::<i8>()?
                .iter()
                .map(|value| *value as u8)
                .collect(),
            ElementType::Int32 => buffer
                .lock_for_read::<i32>()?
                .iter()
                .flat_map(|value| value.to_le_bytes())
                .collect(),
            ElementType::Int64 => buffer
                .lock_for_read::<i64>()?
                .iter()
                .flat_map(|value| value.to_le_bytes())
                .collect(),
            ElementType::Bool => buffer
                .lock_for_read::<bool>()?
                .iter()
                .map(|value| u8::from(*value))
                .collect(),
            other => bail!(
                "LiteRT output element type {:?} is not supported yet",
                other
            ),
        };

        Ok(TensorPacket::new(Some(name), dtype, dims, data))
    }

    /// Copies `values` into the buffer after checking the element counts line up.
    fn copy_to_buffer<T: litert::TensorElement>(
        buffer: &mut TensorBuffer,
        values: &[T],
    ) -> Result<()> {
        let mut guard = buffer.lock_for_write::<T>()?;
        if guard.len() != values.len() {
            bail!(
                "LiteRT buffer element count mismatch: expected {}, got {}",
                guard.len(),
                values.len()
            );
        }
        guard.copy_from_slice(values);
        Ok(())
    }

    /// Ensures the packet's raw byte length matches its declared dtype/shape.
    fn validate_tensor_byte_len(tensor: &TensorPacket) -> Result<()> {
        let expected = tensor.expected_byte_len();
        if tensor.data.len() != expected {
            bail!(
                "tensor {:?} byte length mismatch: expected {}, got {}",
                tensor.name,
                expected,
                tensor.data.len()
            );
        }
        Ok(())
    }

    /// Cross-checks manifest-declared specs against the shapes LiteRT reports.
    /// An empty spec list means the manifest omitted them: validation is skipped.
    fn validate_specs(
        label: &str,
        specs: &[TensorSpec],
        shapes: &[litert::TensorShape],
    ) -> Result<()> {
        if specs.is_empty() {
            return Ok(());
        }
        if specs.len() != shapes.len() {
            bail!(
                "LiteRT {label} spec count mismatch: manifest declares {}, model exposes {}",
                specs.len(),
                shapes.len()
            );
        }
        for (index, (spec, shape)) in specs.iter().zip(shapes.iter()).enumerate() {
            let dtype = dtype_from_element_type(shape.element_type)?;
            let dims = dims_from_litert(shape)?;
            if spec.dtype != dtype || spec.shape.dims != dims {
                bail!(
                    "LiteRT {label} spec mismatch at index {index} ({:?}): manifest {:?} {:?}, model {:?} {:?}",
                    spec.name,
                    spec.dtype,
                    spec.shape.dims,
                    dtype,
                    dims
                );
            }
        }
        Ok(())
    }

    /// Converts LiteRT's signed dims to `usize`. Negative dims (presumably
    /// dynamic dimensions — TODO confirm against the litert docs) are rejected.
    fn dims_from_litert(shape: &litert::TensorShape) -> Result<Vec<usize>> {
        shape
            .dims
            .iter()
            .map(|dim| {
                usize::try_from(*dim)
                    .map_err(|_| anyhow!("LiteRT tensor has negative dimension {dim}"))
            })
            .collect()
    }

    /// Maps a LiteRT element type onto this crate's `TensorDType`.
    fn dtype_from_element_type(element_type: ElementType) -> Result<TensorDType> {
        Ok(match element_type {
            ElementType::Float32 => TensorDType::F32,
            ElementType::Int32 => TensorDType::I32,
            ElementType::Int64 => TensorDType::I64,
            ElementType::UInt8 => TensorDType::U8,
            ElementType::Int8 => TensorDType::I8,
            ElementType::Bool => TensorDType::Bool,
            other => bail!("LiteRT element type {:?} is not supported yet", other),
        })
    }

    /// Reads the accelerator hint from model metadata ("accelerators" or the
    /// singular "accelerator"); falls back to `default` when absent. Accepts
    /// either a delimiter-separated string or an array of strings.
    fn accelerators_from_model(model: &ModelInfo, default: Accelerators) -> Result<Accelerators> {
        let Some(value) = model
            .metadata
            .get("accelerators")
            .or_else(|| model.metadata.get("accelerator"))
        else {
            return Ok(default);
        };

        match value {
            Value::String(text) => parse_accelerator_list(text),
            Value::Array(values) => {
                let mut accelerators = Accelerators::NONE;
                for value in values {
                    let Some(name) = value.as_str() else {
                        bail!("accelerators metadata array must contain strings");
                    };
                    accelerators = accelerators | parse_accelerator(name)?;
                }
                Ok(accelerators)
            }
            _ => bail!("accelerators metadata must be a string or array of strings"),
        }
    }

    /// Parses a list like "cpu,gpu", "cpu|npu", or "cpu + gpu" into a
    /// combined accelerator set; empty segments are ignored.
    fn parse_accelerator_list(text: &str) -> Result<Accelerators> {
        let mut accelerators = Accelerators::NONE;
        for part in text.split([',', '|', '+']) {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }
            accelerators = accelerators | parse_accelerator(part)?;
        }
        Ok(accelerators)
    }

    /// Parses a single accelerator name (case-insensitive); "metal" is
    /// treated as an alias for GPU.
    fn parse_accelerator(name: &str) -> Result<Accelerators> {
        match name.trim().to_ascii_lowercase().as_str() {
            "none" => Ok(Accelerators::NONE),
            "cpu" => Ok(Accelerators::CPU),
            "gpu" | "metal" => Ok(Accelerators::GPU),
            "npu" => Ok(Accelerators::NPU),
            other => bail!("unsupported LiteRT accelerator '{other}'"),
        }
    }

    /// Human-readable names for an accelerator set, used in run metadata.
    fn accelerator_names(accelerators: Accelerators) -> Vec<&'static str> {
        if accelerators == Accelerators::NONE {
            return vec!["none"];
        }
        let mut names = Vec::new();
        if accelerators.contains(Accelerators::CPU) {
            names.push("cpu");
        }
        if accelerators.contains(Accelerators::GPU) {
            names.push("gpu");
        }
        if accelerators.contains(Accelerators::NPU) {
            names.push("npu");
        }
        names
    }

    /// Decodes the packet's bytes as f32 values; errors if the packet is not
    /// an f32 tensor.
    fn read_f32_values(tensor: &TensorPacket) -> Result<Vec<f32>> {
        tensor
            .as_f32_vec()
            .ok_or_else(|| anyhow!("expected f32 tensor bytes"))
    }

    /// Reinterprets each raw byte as an i8.
    fn read_i8_values(tensor: &TensorPacket) -> Result<Vec<i8>> {
        Ok(tensor.data.iter().map(|value| *value as i8).collect())
    }

    /// Decodes little-endian i32 values from the packet's bytes.
    fn read_i32_values(tensor: &TensorPacket) -> Result<Vec<i32>> {
        read_chunks::<4, i32>(&tensor.data, i32::from_le_bytes)
    }

    /// Decodes little-endian i64 values from the packet's bytes.
    fn read_i64_values(tensor: &TensorPacket) -> Result<Vec<i64>> {
        read_chunks::<8, i64>(&tensor.data, i64::from_le_bytes)
    }

    /// Maps each byte to a bool (non-zero => true).
    fn read_bool_values(tensor: &TensorPacket) -> Result<Vec<bool>> {
        Ok(tensor.data.iter().map(|value| *value != 0).collect())
    }

    /// Splits `data` into exact N-byte chunks and decodes each; errors when
    /// the length is not a multiple of N.
    fn read_chunks<const N: usize, T>(
        data: &[u8],
        decode: impl Fn([u8; N]) -> T,
    ) -> Result<Vec<T>> {
        if data.len() % N != 0 {
            bail!("tensor byte length {} is not divisible by {N}", data.len());
        }
        Ok(data
            .chunks_exact(N)
            .map(|chunk| {
                let mut bytes = [0u8; N];
                bytes.copy_from_slice(chunk);
                decode(bytes)
            })
            .collect())
    }

    /// Serializes f32 values as little-endian bytes.
    fn f32_values_to_bytes(values: &[f32]) -> Vec<u8> {
        values
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect()
    }
}
630
631fn infer_default_outputs(inputs: &[InferenceInput]) -> Result<Vec<TensorPacket>> {
632 let first = inputs
633 .first()
634 .ok_or_else(|| anyhow!("mock inference requires at least one input tensor"))?;
635 let spec = TensorSpec {
636 name: "output".to_string(),
637 dtype: TensorDType::F32,
638 shape: TensorShape::new([1, first.tensor.shape.element_count().clamp(1, 16)]),
639 };
640 Ok(vec![deterministic_tensor(
641 &spec,
642 &ModelInfo::mock("mock", "generic", vec![spec.clone()]),
643 inputs,
644 )])
645}
646
647fn deterministic_tensor(
648 spec: &TensorSpec,
649 model: &ModelInfo,
650 inputs: &[InferenceInput],
651) -> TensorPacket {
652 let count = spec.shape.element_count();
653 let seed = stable_seed(model, inputs, &spec.name);
654 let mut metadata = PacketMetadata::default();
655 if let Some(first) = inputs.first() {
656 metadata.merge_missing_from(&first.tensor.metadata);
657 }
658 metadata.fields.insert("mockSeed".to_string(), json!(seed));
659
660 match spec.dtype {
661 TensorDType::F32 => {
662 let mut values = Vec::with_capacity(count);
663 for i in 0..count {
664 let raw = seed.wrapping_add((i as u64).wrapping_mul(1_103_515_245));
665 values.push(((raw % 10_000) as f32 / 10_000.0).clamp(0.0, 1.0));
666 }
667 let mut tensor =
668 TensorPacket::from_f32(Some(spec.name.clone()), spec.shape.clone(), &values);
669 tensor.metadata = metadata;
670 tensor
671 }
672 TensorDType::U8 => {
673 let data = (0..count)
674 .map(|i| seed.wrapping_add(i as u64) as u8)
675 .collect::<Vec<_>>();
676 let mut tensor = TensorPacket::new(
677 Some(spec.name.clone()),
678 TensorDType::U8,
679 spec.shape.clone(),
680 data,
681 );
682 tensor.metadata = metadata;
683 tensor
684 }
685 _ => {
686 let bytes = vec![0u8; spec.byte_len()];
687 let mut tensor = TensorPacket::new(
688 Some(spec.name.clone()),
689 spec.dtype,
690 spec.shape.clone(),
691 bytes,
692 );
693 tensor.metadata = metadata;
694 tensor
695 }
696 }
697}
698
699fn stable_seed(model: &ModelInfo, inputs: &[InferenceInput], output_name: &str) -> u64 {
700 let mut hash = 14_695_981_039_346_656_037u64;
701 for byte in model.id.bytes().chain(output_name.bytes()) {
702 hash ^= byte as u64;
703 hash = hash.wrapping_mul(1_099_511_628_211);
704 }
705 for input in inputs {
706 for byte in input.name.bytes() {
707 hash ^= byte as u64;
708 hash = hash.wrapping_mul(1_099_511_628_211);
709 }
710 hash ^= input.tensor.data.len() as u64;
711 hash = hash.wrapping_mul(1_099_511_628_211);
712 }
713 hash
714}
715
#[cfg(test)]
mod tests {
    use super::*;

    // Two runs of the mock session with the same input must produce
    // byte-identical outputs shaped per the declared spec.
    #[test]
    fn mock_backend_is_deterministic() {
        let model = ModelInfo::mock(
            "hand-landmark",
            "landmark",
            vec![TensorSpec {
                name: "landmarks".to_string(),
                dtype: TensorDType::F32,
                shape: TensorShape::new([1, 6]),
            }],
        );
        let input = InferenceInput {
            name: "image".to_string(),
            tensor: TensorPacket::from_f32(
                Some("image".to_string()),
                TensorShape::new([1, 2]),
                &[0.0, 1.0],
            ),
        };
        let backend = MockBackend::new();
        let session = backend.load_model(model, None).unwrap();

        let a = session.run(std::slice::from_ref(&input)).unwrap();
        let b = session.run(&[input]).unwrap();

        assert_eq!(a, b);
        assert_eq!(a.tensors[0].shape.dims, vec![1, 6]);
    }

    // End-to-end LiteRT check against a bundled elementwise-add fixture:
    // out[i] = lhs[i] + rhs[i] over a 10x10 f32 grid.
    #[cfg(feature = "external-litert")]
    #[test]
    fn external_litert_backend_runs_bundled_add_fixture() -> Result<()> {
        // Quiet LiteRT's default logging for the test run.
        let _ = litert::set_global_log_severity(litert::LogSeverity::Error);
        let model_data = std::fs::read(
            std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
                .join("tests/data/add_10x10.tflite"),
        )?;
        let shape = TensorShape::new([10, 10]);
        let model = ModelInfo {
            id: "add_10x10".to_string(),
            backend: "litert".to_string(),
            task: "elementwise_add".to_string(),
            inputs: vec![
                TensorSpec {
                    name: "lhs".to_string(),
                    dtype: TensorDType::F32,
                    shape: shape.clone(),
                },
                TensorSpec {
                    name: "rhs".to_string(),
                    dtype: TensorDType::F32,
                    shape: shape.clone(),
                },
            ],
            outputs: vec![TensorSpec {
                name: "sum".to_string(),
                dtype: TensorDType::F32,
                shape: shape.clone(),
            }],
            metadata: HashMap::new(),
        };
        // lhs = 0..99, rhs = 100..199, so sum[i] = 100 + 2*i.
        let lhs = (0..100).map(|index| index as f32).collect::<Vec<_>>();
        let rhs = (0..100)
            .map(|index| 100.0 + index as f32)
            .collect::<Vec<_>>();
        let backend = LiteRtBackend::new();
        let session = backend.load_model(model, Some(Arc::new(model_data)))?;
        let output = session.run(&[
            InferenceInput {
                name: "lhs".to_string(),
                tensor: TensorPacket::from_f32(Some("lhs".to_string()), shape.clone(), &lhs),
            },
            InferenceInput {
                name: "rhs".to_string(),
                tensor: TensorPacket::from_f32(Some("rhs".to_string()), shape, &rhs),
            },
        ])?;

        let values = output.tensors[0].as_f32_vec().unwrap();
        assert_eq!(values.len(), 100);
        for (index, value) in values.iter().enumerate() {
            assert!(
                (*value - (100.0 + 2.0 * index as f32)).abs() < 1e-6,
                "element {index}: got {value}"
            );
        }
        assert_eq!(output.metadata.get("backend"), Some(&json!("litert")));
        Ok(())
    }
}
809}