Skip to main content

tl_ai/
predict.rs

// ThinkingLanguage — Prediction (ONNX Runtime + linfa)
2
3use std::path::Path;
4
5use crate::model::TlModel;
6use crate::tensor::TlTensor;
7use crate::train::predict_linfa;
8
9/// Run prediction on a model.
10pub fn predict(model: &TlModel, input: &TlTensor) -> Result<TlTensor, String> {
11    match model {
12        TlModel::Onnx { path, .. } => predict_onnx(path, input),
13        TlModel::Linfa { .. } => predict_linfa(model, input),
14        TlModel::LlmEndpoint { .. } => {
15            Err("Cannot use predict() on an LLM endpoint. Use ai_complete() instead.".to_string())
16        }
17    }
18}
19
20/// Run prediction using ONNX Runtime.
21pub fn predict_onnx(model_path: &Path, input: &TlTensor) -> Result<TlTensor, String> {
22    use ort::session::Session;
23
24    let mut session = Session::builder()
25        .and_then(|mut b| b.commit_from_file(model_path))
26        .map_err(|e| format!("Failed to load ONNX model: {e}"))?;
27
28    let shape = input.shape();
29    let flat_data: Vec<f32> = input.to_vec().iter().map(|&x| x as f32).collect();
30    let shape_i64: Vec<i64> = shape.iter().map(|&s| s as i64).collect();
31
32    // Create ORT tensor value: needs (shape, Vec<T>)
33    let input_value = ort::value::Tensor::from_array((shape_i64, flat_data))
34        .map_err(|e| format!("Failed to create ORT tensor: {e}"))?;
35
36    let outputs = session
37        .run(ort::inputs![input_value])
38        .map_err(|e| format!("ONNX inference failed: {e}"))?;
39
40    // Extract first output
41    let output = outputs.values().next().ok_or("No output from ONNX model")?;
42
43    let (out_shape_ref, out_flat) = output
44        .try_extract_tensor::<f32>()
45        .map_err(|e| format!("Failed to extract output: {e}"))?;
46
47    let out_shape: Vec<usize> = out_shape_ref.iter().map(|&d| d as usize).collect();
48    let out_data: Vec<f64> = out_flat.iter().map(|&x| x as f64).collect();
49
50    TlTensor::from_vec(out_data, &out_shape)
51}
52
53/// Batch prediction: split input into batches, predict, reassemble.
54pub fn predict_batch(
55    model: &TlModel,
56    input: &TlTensor,
57    batch_size: usize,
58) -> Result<TlTensor, String> {
59    let shape = input.shape();
60    if shape.len() < 2 {
61        return predict(model, input);
62    }
63
64    let n_samples = shape[0];
65    if n_samples <= batch_size {
66        return predict(model, input);
67    }
68
69    let mut all_preds = Vec::new();
70
71    for start in (0..n_samples).step_by(batch_size) {
72        let end = (start + batch_size).min(n_samples);
73        let batch = input.slice(start, end)?;
74        let preds = predict(model, &batch)?;
75        all_preds.extend(preds.to_vec());
76    }
77
78    Ok(TlTensor::from_list(all_preds))
79}