lfm 0.1.0 — Rust ONNX inference for the LiquidAI LFM2.5-VL (vision-language) models. Implements the engine-agnostic `llmtask::Task` contract and uses llguidance for schema-constrained sampling.
Documentation
[docs.rs source-view line-number gutter removed — extraction artifact; the module source follows.]
//! ORT session building + strict input/output validation.
//!
//! Per spec §8.5: ONNX I/O contract (tests/fixtures/onnx_io_contract.json).

use std::path::Path;

use ort::{
  session::Session,
  value::{Outlet, TensorElementType, ValueType},
};

use crate::{
  error::{Error, Result},
  options::Options,
};

/// Build an ORT session from a path with the given options.
///
/// Wires `optimization_level` and thread counts from `Options`. EP
/// registration (cuda/tensorrt/etc.) is feature-gated below per
/// spec §5.3 EP-feature pattern.
#[allow(dead_code)]
pub(crate) fn build_session(graph: &Path, opts: &Options) -> Result<Session> {
  if !graph.exists() {
    return Err(Error::NotFound(graph.to_path_buf()));
  }
  let level = opts.optimization_level();
  // Session::builder() returns ort::Result<SessionBuilder>.
  // with_* methods return BuilderResult = Result<SessionBuilder, Error<SessionBuilder>>.
  // Error<SessionBuilder> converts to ort::Error (Error<()>) via From.
  let mut builder = Session::builder()
    .map_err(Error::Ort)?
    .with_optimization_level(level)
    .map_err(|e| Error::Ort(ort::Error::from(e)))?;

  if let Some(t) = opts.thread().intra_threads() {
    builder = builder
      .with_intra_threads(t)
      .map_err(|e| Error::Ort(ort::Error::from(e)))?;
  }
  if let Some(t) = opts.thread().inter_threads() {
    builder = builder
      .with_inter_threads(t)
      .map_err(|e| Error::Ort(ort::Error::from(e)))?;
  }

  // register the requested execution provider
  // per Cargo feature. Without this, enabling `--features cuda` (etc.)
  // turned on the underlying ort EP support but never told the session
  // to use it, so workloads silently ran on CPU.
  //
  // ort 2.0's `with_execution_providers` takes an iterable of
  // `ExecutionProviderDispatch`; we register at most one per feature.
  // Multiple GPU features compiled in together stack in declaration
  // order — the first whose runtime is available wins.
  #[allow(unused_mut)]
  let mut eps: Vec<ort::execution_providers::ExecutionProviderDispatch> = Vec::new();
  #[cfg(feature = "cuda")]
  {
    eps.push(ort::execution_providers::CUDAExecutionProvider::default().build());
  }
  #[cfg(feature = "tensorrt")]
  {
    eps.push(ort::execution_providers::TensorRTExecutionProvider::default().build());
  }
  #[cfg(feature = "directml")]
  {
    eps.push(ort::execution_providers::DirectMLExecutionProvider::default().build());
  }
  #[cfg(feature = "rocm")]
  {
    eps.push(ort::execution_providers::ROCmExecutionProvider::default().build());
  }
  #[cfg(feature = "coreml")]
  {
    eps.push(ort::execution_providers::CoreMLExecutionProvider::default().build());
  }
  if !eps.is_empty() {
    builder = builder
      .with_execution_providers(eps)
      .map_err(|e| Error::Ort(ort::Error::from(e)))?;
  }

  let session = builder.commit_from_file(graph).map_err(Error::Ort)?;
  Ok(session)
}

/// Verify an outlet matches the expected dtype + shape.
///
/// `expected_shape` semantics:
/// - `-1` means "this axis MUST be dynamic in the graph". A static
///   dim there is rejected.
/// - any other value means "exact match (or `-1` ok)". The graph may
///   bake a concrete dim or declare it dynamic; both work at runtime.
///
/// Mirrors siglip2's `check_outlet` exactly.
#[allow(dead_code)]
pub(crate) fn check_outlet(
  outlets: &[Outlet],
  name: &'static str,
  expected_dtype: TensorElementType,
  expected_shape: &[i64],
) -> Result<()> {
  // The outlet must exist at all before anything else can be checked.
  let Some(outlet) = outlets.iter().find(|o| o.name() == name) else {
    return Err(Error::SessionShapeMismatch {
      input: name,
      expected: "outlet present in session",
      got: vec![],
    });
  };

  // Only tensor-valued outlets are accepted.
  let ValueType::Tensor { ty, shape, .. } = outlet.dtype() else {
    return Err(Error::SessionShapeMismatch {
      input: name,
      expected: "tensor",
      got: vec![],
    });
  };

  if *ty != expected_dtype {
    return Err(Error::SessionContractMismatch {
      input: name,
      expected: "matching tensor dtype",
      got: *ty,
    });
  }

  let actual: &[i64] = shape;
  if actual.len() != expected_shape.len() {
    return Err(Error::SessionShapeMismatch {
      input: name,
      expected: "matching tensor rank",
      got: actual.to_vec(),
    });
  }

  // Per-axis check; the first offending dim wins.
  for (&want, &got) in expected_shape.iter().zip(actual.iter()) {
    let (ok, why) = if want == -1 {
      // Expected dynamic axis. The graph MUST declare it dynamic — a
      // static dim here would fail at runtime with variable batch sizes.
      (got == -1, "dynamic axis required")
    } else {
      // Expected concrete dim. Graph may match exactly or declare the
      // axis dynamic (-1) — both work at runtime.
      (got == -1 || got == want, "matching static dim")
    };
    if !ok {
      return Err(Error::SessionShapeMismatch {
        input: name,
        expected: why,
        got: actual.to_vec(),
      });
    }
  }
  Ok(())
}

/// Validate the vision encoder session against the contract.
/// pixel_values is PRE-PATCHIFIED [batch, num_patches, 768] (not image-shaped).
#[allow(dead_code)]
pub(crate) fn validate_vision_session(s: &Session) -> Result<()> {
  check_outlet(
    s.inputs(),
    "pixel_values",
    TensorElementType::Float32,
    &[-1, -1, 768],
  )?;
  check_outlet(
    s.inputs(),
    "pixel_attention_mask",
    TensorElementType::Int64,
    &[-1, -1],
  )?;
  check_outlet(
    s.inputs(),
    "spatial_shapes",
    TensorElementType::Int64,
    &[-1, 2],
  )?;
  // Output: rank 2 [num_image_tokens, 1024]. NOT rank 3.
  check_outlet(
    s.outputs(),
    "image_features",
    TensorElementType::Float32,
    &[-1, 1024],
  )?;
  Ok(())
}

/// Validate the embed_tokens session.
#[allow(dead_code)]
pub(crate) fn validate_embed_session(s: &Session) -> Result<()> {
  // Input: token ids [batch, seq], both axes dynamic.
  check_outlet(s.inputs(), "input_ids", TensorElementType::Int64, &[-1, -1])?;
  // Output: embeddings [batch, seq, 1024]; hidden size fixed at 1024.
  check_outlet(
    s.outputs(),
    "inputs_embeds",
    TensorElementType::Float32,
    &[-1, -1, 1024],
  )
}

/// Validate the decoder session.
/// decoder has NO `position_ids` input.
/// cache uses sparse layer indices
/// (conv at [0,1,3,4,6,7,9,11,13,15], attn at [2,5,8,10,12,14] × {key,value}).
///
/// Checks run in a fixed order (dense inputs → position_ids rejection →
/// cache discovery → index sets → per-entry dtype/shape → logits), so
/// the FIRST contract violation is the one surfaced to the caller.
#[allow(dead_code)]
pub(crate) fn validate_decoder_session(s: &Session) -> Result<()> {
  // Dense inputs: embeddings [batch, seq, 1024] + i64 mask [batch, seq].
  check_outlet(
    s.inputs(),
    "inputs_embeds",
    TensorElementType::Float32,
    &[-1, -1, 1024],
  )?;
  check_outlet(
    s.inputs(),
    "attention_mask",
    TensorElementType::Int64,
    &[-1, -1],
  )?;

  // actively REJECT position_ids if
  // present. Decoder::step does not pass position_ids; an ONNX export
  // that requires it would silently fail at first session.run with an
  // opaque ORT error. Catch it at construction.
  if s.inputs().iter().any(|o| o.name() == "position_ids") {
    return Err(Error::SessionShapeMismatch {
      input: "position_ids",
      expected: "must NOT be a required input (Decoder::step doesn't pass it)",
      got: vec![],
    });
  }

  // Count check: 10 conv-cache inputs; 6 attn layers × {key,value} = 12.
  let cache = collect_cache_inputs(s.inputs())?;
  if cache.conv.len() != 10 || cache.attn.len() != 12 {
    return Err(Error::DecoderCacheMismatch {
      expected_conv: 10,
      expected_attn: 12,
      got_conv: cache.conv.len(),
      got_attn: cache.attn.len(),
    });
  }
  // Sparse-index check: collect indices from discovered names, verify
  // they exactly match the expected sets.
  const EXPECTED_CONV: &[u32] = &[0, 1, 3, 4, 6, 7, 9, 11, 13, 15];
  const EXPECTED_ATTN: &[u32] = &[2, 5, 8, 10, 12, 14];
  let mut conv_indices: Vec<u32> = cache
    .conv
    .iter()
    .filter_map(|n| parse_conv_index(n))
    .collect();
  conv_indices.sort_unstable();
  if conv_indices != EXPECTED_CONV {
    return Err(Error::SessionShapeMismatch {
      input: "past_conv.*",
      expected: "sparse indices [0,1,3,4,6,7,9,11,13,15]",
      got: conv_indices.into_iter().map(i64::from).collect(),
    });
  }
  // Attn indices appear twice (once per .key, once per .value), hence
  // the dedup after sorting before comparing against the expected set.
  let mut attn_indices: Vec<u32> = cache
    .attn
    .iter()
    .filter_map(|n| parse_attn_index(n))
    .collect();
  attn_indices.sort_unstable();
  attn_indices.dedup();
  if attn_indices != EXPECTED_ATTN {
    return Err(Error::SessionShapeMismatch {
      input: "past_key_values.*.{key,value}",
      expected: "sparse indices [2,5,8,10,12,14]",
      got: attn_indices.into_iter().map(i64::from).collect(),
    });
  }

  // validate dtype + shape for EACH
  // past_* cache input AND its corresponding present_* output. The
  // fix already required present_* to exist for every past_*;
  // this adds the dtype/shape contract so an ONNX export with same
  // names but changed dimensions (e.g., a different head dim, or
  // float16 instead of float32) fails at construction instead of at
  // first decode-step with an opaque ORT shape error.
  //
  // `leak_static` is needed because check_outlet takes `&'static str`;
  // the leak is bounded per validation (see leak_static's docs).
  //
  // Conv cache: shape [1, 1024, 3], dtype f32 ().
  for name in &cache.conv {
    let owned: &'static str = leak_static(name);
    check_outlet(s.inputs(), owned, TensorElementType::Float32, &[1, 1024, 3])?;
    // unwrap_or(u32::MAX) is unreachable in practice: the index-set
    // check above already proved every conv name parses.
    let present = format!(
      "present_conv.{}",
      parse_conv_index(name).unwrap_or(u32::MAX)
    );
    let present_owned: &'static str = leak_static(&present);
    check_outlet(
      s.outputs(),
      present_owned,
      TensorElementType::Float32,
      &[1, 1024, 3],
    )?;
  }
  // Attn cache: shape [1, 8, past_len, 64], dtype f32 ().
  // past_len is dynamic (-1) on inputs; present is also dynamic since
  // it's past_len + seq.
  for name in &cache.attn {
    let owned: &'static str = leak_static(name);
    check_outlet(
      s.inputs(),
      owned,
      TensorElementType::Float32,
      &[1, 8, -1, 64],
    )?;
    // present_X.key / present_X.value derived from past_key_values.X.{key,value}
    if let Some(rest) = name.strip_prefix("past_key_values.") {
      let present = format!("present.{rest}");
      let present_owned: &'static str = leak_static(&present);
      check_outlet(
        s.outputs(),
        present_owned,
        TensorElementType::Float32,
        &[1, 8, -1, 64],
      )?;
    }
  }

  // Final output: logits [batch, seq, vocab] with vocab fixed at 65536.
  check_outlet(
    s.outputs(),
    "logits",
    TensorElementType::Float32,
    &[-1, -1, 65536],
  )?;
  Ok(())
}

/// Leak a `String` to obtain a `&'static str` for `check_outlet`'s
/// `name: &'static str` parameter. Called O(layer count) times per
/// session construction (≤22 outlets × 2 = 44 leaks); the leaked
/// memory persists for the process lifetime and is bounded.
fn leak_static(s: &str) -> &'static str {
  let boxed: Box<str> = s.to_owned().into_boxed_str();
  Box::leak(boxed)
}

/// Cache input names grouped by kind, discovered at session-build time.
#[allow(dead_code)]
pub(crate) struct CacheInputs {
  // `past_conv.N` input names (conv-layer cache).
  pub(crate) conv: Vec<String>,
  // `past_key_values.N.{key,value}` input names (attention KV cache).
  pub(crate) attn: Vec<String>,
}

/// Discover the decoder's cache inputs and bucket them by kind,
/// preserving the session's declaration order within each bucket.
/// Non-cache inputs (inputs_embeds, attention_mask, ...) are skipped.
#[allow(dead_code)]
pub(crate) fn collect_cache_inputs(outlets: &[Outlet]) -> Result<CacheInputs> {
  let mut cache = CacheInputs {
    conv: Vec::new(),
    attn: Vec::new(),
  };
  for name in outlets.iter().map(|o| o.name()) {
    if name.starts_with("past_conv.") {
      cache.conv.push(name.to_string());
    } else if name.starts_with("past_key_values.") {
      cache.attn.push(name.to_string());
    }
  }
  Ok(cache)
}

/// Extract `N` from a `past_conv.N` input name; `None` when the prefix
/// is absent or the index is not a valid `u32`.
fn parse_conv_index(name: &str) -> Option<u32> {
  name
    .strip_prefix("past_conv.")
    .and_then(|idx| idx.parse::<u32>().ok())
}

/// Extract `N` from a `past_key_values.N.key` / `.value` input name;
/// `None` when the prefix is absent, no `.` follows the index, or the
/// index is not a valid `u32`. The suffix itself is not inspected.
#[allow(dead_code)]
fn parse_attn_index(name: &str) -> Option<u32> {
  let rest = name.strip_prefix("past_key_values.")?;
  let (index, _suffix) = rest.split_once('.')?;
  index.parse().ok()
}

#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn parse_conv_index_works() {
    // Well-formed names parse to their numeric index.
    for (name, want) in [("past_conv.0", 0), ("past_conv.15", 15)] {
      assert_eq!(parse_conv_index(name), Some(want));
    }
    // Wrong prefix, empty index, non-numeric index all reject.
    for bad in ["past_kv.0", "past_conv.", "past_conv.foo"] {
      assert_eq!(parse_conv_index(bad), None);
    }
  }

  #[test]
  fn parse_attn_index_works() {
    assert_eq!(parse_attn_index("past_key_values.2.key"), Some(2));
    assert_eq!(parse_attn_index("past_key_values.14.value"), Some(14));
    // Not an attention-cache name; missing .key/.value suffix.
    for bad in ["past_conv.0", "past_key_values.2"] {
      assert_eq!(parse_attn_index(bad), None);
    }
  }

  // Note: validators that require a real ort::Session are tested at the
  // integration level (Task 15) — they need actual ONNX files. The
  // shape-discovery + sparse-index sorting logic here is testable via
  // string parsing, which we cover above.
}