reflow_litert/lib.rs

//! Backend boundary for LiteRT-class inference in Reflow.
//!
//! The default build remains deterministic and mock-first so CI and graph
//! authoring do not require native ML libraries. The real LiteRT adapter is
//! optional and targets the published Offbit `litert` crates, wiring them in
//! behind the same backend traits without changing graph-facing actors or
//! taskpacks.

use anyhow::{anyhow, bail, Result};
use reflow_media_types::{PacketMetadata, TensorDType, TensorPacket, TensorShape};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(feature = "external-litert")]
pub use external_litert::LiteRtBackend;

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TensorSpec {
    pub name: String,
    pub dtype: TensorDType,
    pub shape: TensorShape,
}

impl TensorSpec {
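    /// Total byte size of a tensor matching this spec; e.g. an F32 spec with
    /// shape [1, 6] has 6 elements * 4 bytes = 24 bytes.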
    pub fn byte_len(&self) -> usize {
        self.shape.element_count() * self.dtype.bytes_per_element()
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelInfo {
    pub id: String,
    pub backend: String,
    pub task: String,
    #[serde(default)]
    pub inputs: Vec<TensorSpec>,
    #[serde(default)]
    pub outputs: Vec<TensorSpec>,
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}

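// With empty metadata skipped during serialization, a minimal ModelInfo
// serializes as, e.g.:
// {"id":"demo","backend":"mock","task":"generic","inputs":[],"outputs":[]}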
impl ModelInfo {
    pub fn mock(id: impl Into<String>, task: impl Into<String>, outputs: Vec<TensorSpec>) -> Self {
        Self {
            id: id.into(),
            backend: "mock".to_string(),
            task: task.into(),
            inputs: Vec::new(),
            outputs,
            metadata: HashMap::new(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InferenceInput {
    pub name: String,
    pub tensor: TensorPacket,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InferenceOutput {
    #[serde(default)]
    pub tensors: Vec<TensorPacket>,
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}

pub trait InferenceBackend: Send + Sync {
    fn name(&self) -> &str;

    fn load_model(
        &self,
        model: ModelInfo,
        model_data: Option<Arc<Vec<u8>>>,
    ) -> Result<Box<dyn InferenceSession>>;
}

pub trait InferenceSession: Send + Sync {
    fn model_info(&self) -> &ModelInfo;

    fn run(&self, inputs: &[InferenceInput]) -> Result<InferenceOutput>;
}

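/// Deterministic stand-in backend used by the default build.
///
/// A minimal usage sketch (assuming this crate is built as `reflow_litert`;
/// the model id and task below are placeholders):
///
/// ```no_run
/// use reflow_litert::{InferenceBackend, InferenceSession, MockBackend, ModelInfo};
///
/// let backend = MockBackend::new();
/// let model = ModelInfo::mock("demo", "generic", Vec::new());
/// let session = backend.load_model(model, None).unwrap();
/// assert_eq!(session.model_info().backend, "mock");
/// ```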
#[derive(Debug, Clone, Default)]
pub struct MockBackend;

impl MockBackend {
    pub fn new() -> Self {
        Self
    }
}

impl InferenceBackend for MockBackend {
    fn name(&self) -> &str {
        "mock"
    }

    fn load_model(
        &self,
        model: ModelInfo,
        model_data: Option<Arc<Vec<u8>>>,
    ) -> Result<Box<dyn InferenceSession>> {
        if model.backend != "mock" && model.backend != "litert" {
            bail!("mock backend cannot load backend '{}'", model.backend);
        }
        Ok(Box::new(MockSession {
            model,
            model_data_len: model_data.as_ref().map(|bytes| bytes.len()).unwrap_or(0),
        }))
    }
}

#[derive(Debug, Clone)]
struct MockSession {
    model: ModelInfo,
    model_data_len: usize,
}

impl InferenceSession for MockSession {
    fn model_info(&self) -> &ModelInfo {
        &self.model
    }

    fn run(&self, inputs: &[InferenceInput]) -> Result<InferenceOutput> {
        let outputs = if self.model.outputs.is_empty() {
            infer_default_outputs(inputs)?
        } else {
            self.model
                .outputs
                .iter()
                .map(|spec| deterministic_tensor(spec, &self.model, inputs))
                .collect()
        };

        Ok(InferenceOutput {
            tensors: outputs,
            metadata: HashMap::from([
                ("backend".to_string(), json!("mock")),
                ("modelId".to_string(), json!(self.model.id)),
                ("modelBytes".to_string(), json!(self.model_data_len)),
                ("inputCount".to_string(), json!(inputs.len())),
            ]),
        })
    }
}

#[cfg(feature = "external-litert")]
mod external_litert {
    use super::*;
    use litert::{
        Accelerators, CompilationOptions, CompiledModel, ElementType, Environment, Model,
        TensorBuffer,
    };
    use std::sync::Mutex;

    #[derive(Debug, Clone)]
    pub struct LiteRtBackend {
        accelerators: Accelerators,
    }

    impl LiteRtBackend {
        pub fn new() -> Self {
            Self {
                accelerators: Accelerators::CPU,
            }
        }

        pub fn with_accelerators(accelerators: Accelerators) -> Self {
            Self { accelerators }
        }
    }

    impl Default for LiteRtBackend {
        fn default() -> Self {
            Self::new()
        }
    }

    impl InferenceBackend for LiteRtBackend {
        fn name(&self) -> &str {
            "litert"
        }

        fn load_model(
            &self,
            model: ModelInfo,
            model_data: Option<Arc<Vec<u8>>>,
        ) -> Result<Box<dyn InferenceSession>> {
            if model.backend != "litert" {
                bail!("LiteRT backend cannot load backend '{}'", model.backend);
            }
            let model_data =
                model_data.ok_or_else(|| anyhow!("LiteRT backend requires model_data bytes"))?;
            let env = Environment::new().map_err(|err| anyhow!("LiteRT environment: {err}"))?;
            let litert_model = Model::from_bytes(model_data.as_ref().clone().into_boxed_slice())
                .map_err(|err| anyhow!("LiteRT model load: {err}"))?;
            let signature = litert_model
                .signature(0)
                .map_err(|err| anyhow!("LiteRT signature 0: {err}"))?;
            let input_shapes = (0..signature.input_count()?)
                .map(|index| signature.input_shape(index))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT input shape introspection: {err}"))?;
            let output_shapes = (0..signature.output_count()?)
                .map(|index| signature.output_shape(index))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT output shape introspection: {err}"))?;
            validate_specs("input", &model.inputs, &input_shapes)?;
            validate_specs("output", &model.outputs, &output_shapes)?;
            let accelerators = accelerators_from_model(&model, self.accelerators)?;
            let options = CompilationOptions::new()
                .and_then(|options| options.with_accelerators(accelerators))
                .map_err(|err| anyhow!("LiteRT compilation options: {err}"))?;
            let mut input_buffers = input_shapes
                .iter()
                .map(|shape| TensorBuffer::managed_host(&env, shape))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT input buffer allocation: {err}"))?;
            let mut output_buffers = output_shapes
                .iter()
                .map(|shape| TensorBuffer::managed_host(&env, shape))
                .collect::<litert::Result<Vec<_>>>()
                .map_err(|err| anyhow!("LiteRT output buffer allocation: {err}"))?;
            let compiled = CompiledModel::new(env, litert_model, &options)
                .map_err(|err| anyhow!("LiteRT compile: {err}"))?;
            let fully_accelerated = compiled
                .is_fully_accelerated()
                .map_err(|err| anyhow!("LiteRT acceleration query: {err}"))?;

            // The buffers were allocated against the same environment the
            // compiled model now owns, so they remain valid for the session's
            // lifetime; trim any spare Vec capacity before stashing them.
            input_buffers.shrink_to_fit();
            output_buffers.shrink_to_fit();

            Ok(Box::new(LiteRtSession {
                model,
                state: Mutex::new(LiteRtSessionState {
                    compiled,
                    input_buffers,
                    output_buffers,
                    input_shapes,
                    output_shapes,
                    accelerators,
                    fully_accelerated,
                }),
            }))
        }
    }

    struct LiteRtSession {
        model: ModelInfo,
        state: Mutex<LiteRtSessionState>,
    }

    struct LiteRtSessionState {
        compiled: CompiledModel,
        input_buffers: Vec<TensorBuffer>,
        output_buffers: Vec<TensorBuffer>,
        input_shapes: Vec<litert::TensorShape>,
        output_shapes: Vec<litert::TensorShape>,
        accelerators: Accelerators,
        fully_accelerated: bool,
    }

    impl InferenceSession for LiteRtSession {
        fn model_info(&self) -> &ModelInfo {
            &self.model
        }

        fn run(&self, inputs: &[InferenceInput]) -> Result<InferenceOutput> {
            let mut state = self
                .state
                .lock()
                .map_err(|_| anyhow!("LiteRT session state was poisoned"))?;
            if inputs.len() != state.input_shapes.len() {
                bail!(
                    "LiteRT input count mismatch: model expects {}, graph provided {}",
                    state.input_shapes.len(),
                    inputs.len()
                );
            }

            for index in 0..state.input_buffers.len() {
                let input = input_for_index(&self.model, inputs, index)?;
                let shape = state.input_shapes[index].clone();
                let buffer = &mut state.input_buffers[index];
                write_tensor_to_buffer(&input.tensor, buffer, &shape)?;
            }

            {
                let LiteRtSessionState {
                    compiled,
                    input_buffers,
                    output_buffers,
                    ..
                } = &mut *state;
                compiled
                    .run(input_buffers, output_buffers)
                    .map_err(|err| anyhow!("LiteRT inference: {err}"))?;
            }

            let mut tensors = Vec::with_capacity(state.output_buffers.len());
            for index in 0..state.output_buffers.len() {
                let buffer = &state.output_buffers[index];
                let shape = state.output_shapes[index].clone();
                let name = self
                    .model
                    .outputs
                    .get(index)
                    .map(|spec| spec.name.clone())
                    .unwrap_or_else(|| format!("output_{index}"));
                tensors.push(read_tensor_from_buffer(name, buffer, &shape)?);
            }
            let fully_accelerated = state.fully_accelerated;
            let accelerators = state.accelerators;

            Ok(InferenceOutput {
                tensors,
                metadata: HashMap::from([
                    ("backend".to_string(), json!("litert")),
                    ("modelId".to_string(), json!(self.model.id)),
                    ("inputCount".to_string(), json!(inputs.len())),
                    (
                        "accelerators".to_string(),
                        json!(accelerator_names(accelerators)),
                    ),
                    ("fullyAccelerated".to_string(), json!(fully_accelerated)),
                ]),
            })
        }
    }

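    // Resolves which graph input feeds model input `index`: match the
    // manifest spec name first (against either the wire name or the tensor's
    // own name), then fall back to positional order.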
    fn input_for_index<'a>(
        model: &ModelInfo,
        inputs: &'a [InferenceInput],
        index: usize,
    ) -> Result<&'a InferenceInput> {
        if let Some(spec) = model.inputs.get(index) {
            if let Some(input) = inputs.iter().find(|input| {
                input.name == spec.name || input.tensor.name.as_deref() == Some(spec.name.as_str())
            }) {
                return Ok(input);
            }
        }
        inputs
            .get(index)
            .ok_or_else(|| anyhow!("missing LiteRT input tensor at index {index}"))
    }

    fn write_tensor_to_buffer(
        tensor: &TensorPacket,
        buffer: &mut TensorBuffer,
        shape: &litert::TensorShape,
    ) -> Result<()> {
        let expected_dtype = dtype_from_element_type(shape.element_type)?;
        if tensor.dtype != expected_dtype {
            bail!(
                "LiteRT tensor dtype mismatch for {:?}: expected {:?}, got {:?}",
                tensor.name,
                expected_dtype,
                tensor.dtype
            );
        }
        let expected_dims = dims_from_litert(shape)?;
        if tensor.shape.dims != expected_dims {
            bail!(
                "LiteRT tensor shape mismatch for {:?}: expected {:?}, got {:?}",
                tensor.name,
                expected_dims,
                tensor.shape.dims
            );
        }
        validate_tensor_byte_len(tensor)?;

        match shape.element_type {
            ElementType::Float32 => copy_to_buffer::<f32>(buffer, &read_f32_values(tensor)?)?,
            ElementType::UInt8 => copy_to_buffer::<u8>(buffer, &tensor.data)?,
            ElementType::Int8 => copy_to_buffer::<i8>(buffer, &read_i8_values(tensor)?)?,
            ElementType::Int32 => copy_to_buffer::<i32>(buffer, &read_i32_values(tensor)?)?,
            ElementType::Int64 => copy_to_buffer::<i64>(buffer, &read_i64_values(tensor)?)?,
            ElementType::Bool => copy_to_buffer::<bool>(buffer, &read_bool_values(tensor)?)?,
            other => bail!("LiteRT input element type {:?} is not supported yet", other),
        }
        Ok(())
    }

    fn read_tensor_from_buffer(
        name: String,
        buffer: &TensorBuffer,
        shape: &litert::TensorShape,
    ) -> Result<TensorPacket> {
        let dtype = dtype_from_element_type(shape.element_type)?;
        let dims = TensorShape::new(dims_from_litert(shape)?);
        let data = match shape.element_type {
            ElementType::Float32 => {
                let values = buffer.lock_for_read::<f32>()?;
                f32_values_to_bytes(&values)
            }
            ElementType::UInt8 => buffer.lock_for_read::<u8>()?.to_vec(),
            ElementType::Int8 => buffer
                .lock_for_read::<i8>()?
                .iter()
                .map(|value| *value as u8)
                .collect(),
            ElementType::Int32 => buffer
                .lock_for_read::<i32>()?
                .iter()
                .flat_map(|value| value.to_le_bytes())
                .collect(),
            ElementType::Int64 => buffer
                .lock_for_read::<i64>()?
                .iter()
                .flat_map(|value| value.to_le_bytes())
                .collect(),
            ElementType::Bool => buffer
                .lock_for_read::<bool>()?
                .iter()
                .map(|value| u8::from(*value))
                .collect(),
            other => bail!(
                "LiteRT output element type {:?} is not supported yet",
                other
            ),
        };

        Ok(TensorPacket::new(Some(name), dtype, dims, data))
    }

    fn copy_to_buffer<T: litert::TensorElement>(
        buffer: &mut TensorBuffer,
        values: &[T],
    ) -> Result<()> {
        let mut guard = buffer.lock_for_write::<T>()?;
        if guard.len() != values.len() {
            bail!(
                "LiteRT buffer element count mismatch: expected {}, got {}",
                guard.len(),
                values.len()
            );
        }
        guard.copy_from_slice(values);
        Ok(())
    }

    fn validate_tensor_byte_len(tensor: &TensorPacket) -> Result<()> {
        let expected = tensor.expected_byte_len();
        if tensor.data.len() != expected {
            bail!(
                "tensor {:?} byte length mismatch: expected {}, got {}",
                tensor.name,
                expected,
                tensor.data.len()
            );
        }
        Ok(())
    }

    fn validate_specs(
        label: &str,
        specs: &[TensorSpec],
        shapes: &[litert::TensorShape],
    ) -> Result<()> {
        if specs.is_empty() {
            return Ok(());
        }
        if specs.len() != shapes.len() {
            bail!(
                "LiteRT {label} spec count mismatch: manifest declares {}, model exposes {}",
                specs.len(),
                shapes.len()
            );
        }
        for (index, (spec, shape)) in specs.iter().zip(shapes.iter()).enumerate() {
            let dtype = dtype_from_element_type(shape.element_type)?;
            let dims = dims_from_litert(shape)?;
            if spec.dtype != dtype || spec.shape.dims != dims {
                bail!(
                    "LiteRT {label} spec mismatch at index {index} ({:?}): manifest {:?} {:?}, model {:?} {:?}",
                    spec.name,
                    spec.dtype,
                    spec.shape.dims,
                    dtype,
                    dims
                );
            }
        }
        Ok(())
    }

    fn dims_from_litert(shape: &litert::TensorShape) -> Result<Vec<usize>> {
        shape
            .dims
            .iter()
            .map(|dim| {
                usize::try_from(*dim)
                    .map_err(|_| anyhow!("LiteRT tensor has negative dimension {dim}"))
            })
            .collect()
    }

    fn dtype_from_element_type(element_type: ElementType) -> Result<TensorDType> {
        Ok(match element_type {
            ElementType::Float32 => TensorDType::F32,
            ElementType::Int32 => TensorDType::I32,
            ElementType::Int64 => TensorDType::I64,
            ElementType::UInt8 => TensorDType::U8,
            ElementType::Int8 => TensorDType::I8,
            ElementType::Bool => TensorDType::Bool,
            other => bail!("LiteRT element type {:?} is not supported yet", other),
        })
    }

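    // Accepted metadata forms (illustrative): {"accelerators": "cpu,gpu"},
    // {"accelerators": ["cpu", "npu"]}, or the singular "accelerator" key
    // with the same values. A missing key falls back to the backend default;
    // any other JSON shape is rejected.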
    fn accelerators_from_model(model: &ModelInfo, default: Accelerators) -> Result<Accelerators> {
        let Some(value) = model
            .metadata
            .get("accelerators")
            .or_else(|| model.metadata.get("accelerator"))
        else {
            return Ok(default);
        };

        match value {
            Value::String(text) => parse_accelerator_list(text),
            Value::Array(values) => {
                let mut accelerators = Accelerators::NONE;
                for value in values {
                    let Some(name) = value.as_str() else {
                        bail!("accelerators metadata array must contain strings");
                    };
                    accelerators = accelerators | parse_accelerator(name)?;
                }
                Ok(accelerators)
            }
            _ => bail!("accelerators metadata must be a string or array of strings"),
        }
    }

    fn parse_accelerator_list(text: &str) -> Result<Accelerators> {
        let mut accelerators = Accelerators::NONE;
        for part in text.split([',', '|', '+']) {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }
            accelerators = accelerators | parse_accelerator(part)?;
        }
        Ok(accelerators)
    }

    fn parse_accelerator(name: &str) -> Result<Accelerators> {
        match name.trim().to_ascii_lowercase().as_str() {
            "none" => Ok(Accelerators::NONE),
            "cpu" => Ok(Accelerators::CPU),
            "gpu" | "metal" => Ok(Accelerators::GPU),
            "npu" => Ok(Accelerators::NPU),
            other => bail!("unsupported LiteRT accelerator '{other}'"),
        }
    }

    fn accelerator_names(accelerators: Accelerators) -> Vec<&'static str> {
        if accelerators == Accelerators::NONE {
            return vec!["none"];
        }
        let mut names = Vec::new();
        if accelerators.contains(Accelerators::CPU) {
            names.push("cpu");
        }
        if accelerators.contains(Accelerators::GPU) {
            names.push("gpu");
        }
        if accelerators.contains(Accelerators::NPU) {
            names.push("npu");
        }
        names
    }

    fn read_f32_values(tensor: &TensorPacket) -> Result<Vec<f32>> {
        tensor
            .as_f32_vec()
            .ok_or_else(|| anyhow!("expected f32 tensor bytes"))
    }

    fn read_i8_values(tensor: &TensorPacket) -> Result<Vec<i8>> {
        Ok(tensor.data.iter().map(|value| *value as i8).collect())
    }

    fn read_i32_values(tensor: &TensorPacket) -> Result<Vec<i32>> {
        read_chunks::<4, i32>(&tensor.data, i32::from_le_bytes)
    }

    fn read_i64_values(tensor: &TensorPacket) -> Result<Vec<i64>> {
        read_chunks::<8, i64>(&tensor.data, i64::from_le_bytes)
    }

    fn read_bool_values(tensor: &TensorPacket) -> Result<Vec<bool>> {
        Ok(tensor.data.iter().map(|value| *value != 0).collect())
    }

    fn read_chunks<const N: usize, T>(
        data: &[u8],
        decode: impl Fn([u8; N]) -> T,
    ) -> Result<Vec<T>> {
        if data.len() % N != 0 {
            bail!("tensor byte length {} is not divisible by {N}", data.len());
        }
        Ok(data
            .chunks_exact(N)
            .map(|chunk| {
                let mut bytes = [0u8; N];
                bytes.copy_from_slice(chunk);
                decode(bytes)
            })
            .collect())
    }

    fn f32_values_to_bytes(values: &[f32]) -> Vec<u8> {
        values
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect()
    }
}

fn infer_default_outputs(inputs: &[InferenceInput]) -> Result<Vec<TensorPacket>> {
    let first = inputs
        .first()
        .ok_or_else(|| anyhow!("mock inference requires at least one input tensor"))?;
    let spec = TensorSpec {
        name: "output".to_string(),
        dtype: TensorDType::F32,
        shape: TensorShape::new([1, first.tensor.shape.element_count().clamp(1, 16)]),
    };
    Ok(vec![deterministic_tensor(
        &spec,
        &ModelInfo::mock("mock", "generic", vec![spec.clone()]),
        inputs,
    )])
}

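// Expands a per-output seed into repeatable values. The 1_103_515_245
// multiplier is the classic ANSI C LCG constant, used here only for cheap,
// stable mixing rather than statistical quality.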
fn deterministic_tensor(
    spec: &TensorSpec,
    model: &ModelInfo,
    inputs: &[InferenceInput],
) -> TensorPacket {
    let count = spec.shape.element_count();
    let seed = stable_seed(model, inputs, &spec.name);
    let mut metadata = PacketMetadata::default();
    if let Some(first) = inputs.first() {
        metadata.merge_missing_from(&first.tensor.metadata);
    }
    metadata.fields.insert("mockSeed".to_string(), json!(seed));

    match spec.dtype {
        TensorDType::F32 => {
            let mut values = Vec::with_capacity(count);
            for i in 0..count {
                let raw = seed.wrapping_add((i as u64).wrapping_mul(1_103_515_245));
                values.push(((raw % 10_000) as f32 / 10_000.0).clamp(0.0, 1.0));
            }
            let mut tensor =
                TensorPacket::from_f32(Some(spec.name.clone()), spec.shape.clone(), &values);
            tensor.metadata = metadata;
            tensor
        }
        TensorDType::U8 => {
            let data = (0..count)
                .map(|i| seed.wrapping_add(i as u64) as u8)
                .collect::<Vec<_>>();
            let mut tensor = TensorPacket::new(
                Some(spec.name.clone()),
                TensorDType::U8,
                spec.shape.clone(),
                data,
            );
            tensor.metadata = metadata;
            tensor
        }
        _ => {
            let bytes = vec![0u8; spec.byte_len()];
            let mut tensor = TensorPacket::new(
                Some(spec.name.clone()),
                spec.dtype,
                spec.shape.clone(),
                bytes,
            );
            tensor.metadata = metadata;
            tensor
        }
    }
}

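// 64-bit FNV-1a over the model id, the output name, and each input's name and
// byte length, so mock outputs are stable across runs for identical wiring.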
fn stable_seed(model: &ModelInfo, inputs: &[InferenceInput], output_name: &str) -> u64 {
    let mut hash = 14_695_981_039_346_656_037u64;
    for byte in model.id.bytes().chain(output_name.bytes()) {
        hash ^= byte as u64;
        hash = hash.wrapping_mul(1_099_511_628_211);
    }
    for input in inputs {
        for byte in input.name.bytes() {
            hash ^= byte as u64;
            hash = hash.wrapping_mul(1_099_511_628_211);
        }
        hash ^= input.tensor.data.len() as u64;
        hash = hash.wrapping_mul(1_099_511_628_211);
    }
    hash
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mock_backend_is_deterministic() {
        let model = ModelInfo::mock(
            "hand-landmark",
            "landmark",
            vec![TensorSpec {
                name: "landmarks".to_string(),
                dtype: TensorDType::F32,
                shape: TensorShape::new([1, 6]),
            }],
        );
        let input = InferenceInput {
            name: "image".to_string(),
            tensor: TensorPacket::from_f32(
                Some("image".to_string()),
                TensorShape::new([1, 2]),
                &[0.0, 1.0],
            ),
        };
        let backend = MockBackend::new();
        let session = backend.load_model(model, None).unwrap();

        let a = session.run(std::slice::from_ref(&input)).unwrap();
        let b = session.run(&[input]).unwrap();

        assert_eq!(a, b);
        assert_eq!(a.tensors[0].shape.dims, vec![1, 6]);
    }

    #[cfg(feature = "external-litert")]
    #[test]
    fn external_litert_backend_runs_bundled_add_fixture() -> Result<()> {
        let _ = litert::set_global_log_severity(litert::LogSeverity::Error);
        let model_data = std::fs::read(
            std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
                .join("tests/data/add_10x10.tflite"),
        )?;
        let shape = TensorShape::new([10, 10]);
        let model = ModelInfo {
            id: "add_10x10".to_string(),
            backend: "litert".to_string(),
            task: "elementwise_add".to_string(),
            inputs: vec![
                TensorSpec {
                    name: "lhs".to_string(),
                    dtype: TensorDType::F32,
                    shape: shape.clone(),
                },
                TensorSpec {
                    name: "rhs".to_string(),
                    dtype: TensorDType::F32,
                    shape: shape.clone(),
                },
            ],
            outputs: vec![TensorSpec {
                name: "sum".to_string(),
                dtype: TensorDType::F32,
                shape: shape.clone(),
            }],
            metadata: HashMap::new(),
        };
        let lhs = (0..100).map(|index| index as f32).collect::<Vec<_>>();
        let rhs = (0..100)
            .map(|index| 100.0 + index as f32)
            .collect::<Vec<_>>();
        let backend = LiteRtBackend::new();
        let session = backend.load_model(model, Some(Arc::new(model_data)))?;
        let output = session.run(&[
            InferenceInput {
                name: "lhs".to_string(),
                tensor: TensorPacket::from_f32(Some("lhs".to_string()), shape.clone(), &lhs),
            },
            InferenceInput {
                name: "rhs".to_string(),
                tensor: TensorPacket::from_f32(Some("rhs".to_string()), shape, &rhs),
            },
        ])?;

        let values = output.tensors[0].as_f32_vec().unwrap();
        assert_eq!(values.len(), 100);
        for (index, value) in values.iter().enumerate() {
            assert!(
                (*value - (100.0 + 2.0 * index as f32)).abs() < 1e-6,
                "element {index}: got {value}"
            );
        }
        assert_eq!(output.metadata.get("backend"), Some(&json!("litert")));
        Ok(())
    }
}